torch.js has not been released yet.
torch.js· 2026

torch.nn.functional.gaussian_nll_loss

function gaussian_nll_loss(input: Tensor, target: Tensor, var_: Tensor, options?: { full?: boolean; eps?: number; reduction?: 'none' | 'mean' | 'sum'; }): Tensor

Gaussian (normal distribution) negative log likelihood loss for continuous predictions.

Measures the negative log likelihood of the targets under a Gaussian whose mean and variance are both predicted by the model. Because the model predicts a variance alongside the mean, this loss supports uncertainty quantification and heteroscedastic regression. It computes -log P(target | predicted_mean, predicted_variance) as a measure of prediction quality. Common uses:

  • Heteroscedastic regression (predict mean and uncertainty together)
  • Uncertainty quantification in neural networks (confidence in predictions)
  • Aleatoric uncertainty (data noise) modeling vs epistemic (model) uncertainty
  • Generative models (VAE, diffusion models) that predict variance
  • Bayesian neural networks (variational inference with Gaussian posteriors)
  • Robust regression that adapts to varying noise levels across samples
  • Time-series forecasting with confidence intervals/prediction bands
  • Probabilistic regression (predict distribution, not just point estimate)

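Gaussian likelihood interpretation: Assumes each target follows N(μ, σ²), with predicted mean μ and predicted variance σ². The negative log likelihood is -log P(y | μ, σ²) = 0.5*[log(σ²) + (y-μ)²/σ²] + 0.5*log(2π); the constant 0.5*log(2π) is included only when full=true. The non-constant part combines two terms:

  1. log(σ²): penalizes large variances, which rewards confident (low-variance) predictions
  2. (y-μ)²/σ²: penalizes errors relative to the predicted variance (the heteroscedastic term)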
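As a minimal sketch of how the two terms trade off, the dependency-free TypeScript below scores a single element on plain numbers rather than tensors. `gaussianNllTerms` and `NllTerms` are illustrative names, not part of torch.js, and the constant 0.5*log(2π) is left out, as with full=false.

```typescript
// Per-element Gaussian NLL split into its two terms; the constant 0.5 * log(2π) is omitted (full=false).
interface NllTerms {
  logTerm: number;   // 0.5 * log(σ²): grows as the predicted variance grows
  errorTerm: number; // 0.5 * (y - μ)² / σ²: squared error scaled by 1 / σ²
  loss: number;      // logTerm + errorTerm
}

function gaussianNllTerms(mu: number, y: number, variance: number): NllTerms {
  if (!(variance > 0)) {
    throw new RangeError("variance must be > 0");
  }
  const logTerm = 0.5 * Math.log(variance);
  const errorTerm = (0.5 * (y - mu) ** 2) / variance;
  return { logTerm, errorTerm, loss: logTerm + errorTerm };
}

// The same residual (y - μ = 2) scored under two predicted variances.
const confident = gaussianNllTerms(0, 2, 1); // logTerm 0, errorTerm 2, loss 2
const hedged = gaussianNllTerms(0, 2, 4);    // logTerm ≈ 0.693, errorTerm 0.5, loss ≈ 1.193
console.log(confident.loss, hedged.loss);
```

For this residual, predicting σ² = 4 instead of 1 lowers the loss from 2 to about 1.19: the log term grows, but the error term shrinks by more.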

Heteroscedasticity and adaptive weighting: The network learns to increase σ² for hard-to-predict samples, which shrinks their error term, so high-variance samples such as outliers are naturally downweighted. With σ² = 1 the loss reduces to half the squared error (MSE up to a factor of 0.5, plus a constant when full=true); σ² > 1 downweights the error and σ² < 1 upweights it. This gives per-sample importance weighting without specifying weights by hand.
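The weighting claim can be checked numerically. The sketch below uses plain TypeScript with hypothetical helper names and assumes a single element with a fixed residual e = y - μ: setting the derivative with respect to σ² to zero gives σ² = e², and the gradient with respect to μ is scaled by 1/σ².

```typescript
// Simplified per-element NLL as a function of the variance, for a fixed residual e = y - μ.
function nllForVariance(residual: number, variance: number): number {
  return 0.5 * (Math.log(variance) + (residual * residual) / variance);
}

// d/dσ² of 0.5 * (log σ² + e²/σ²) is 0.5 * (1/σ² - e²/σ⁴), which is zero at σ² = e².
function optimalVariance(residual: number): number {
  return residual * residual;
}

// ∂L/∂μ = -(y - μ) / σ² = -e / σ²: the 1/σ² factor acts as a per-sample weight on the error.
function gradWrtMean(residual: number, variance: number): number {
  return -residual / variance;
}

const r = 3;
const best = optimalVariance(r); // 9
console.log(nllForVariance(r, best), nllForVariance(r, 8), nllForVariance(r, 10));
console.log(gradWrtMean(2, 1), gradWrtMean(2, 4)); // -2 and -0.5: 4x the variance, 1/4 the weight
```

Because the optimum is σ² = e², a sample the model already fits perfectly pushes its variance toward zero, which is one reason the eps clamp and variance regularization matter.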

Aleatoric vs epistemic uncertainty:

  • Aleatoric (data noise): captured by the predicted per-sample variance σ²
  • Epistemic (model): captured separately, e.g. by disagreement across an ensemble or MC-dropout passes
  • Together: total uncertainty = aleatoric + epistemic (for robust prediction)
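The aleatoric/epistemic split corresponds to the law of total variance. A small sketch, assuming T stochastic forward passes that each return a mean and a variance for one output; it works on plain numbers, and `decomposeUncertainty` is a hypothetical name:

```typescript
interface UncertaintyDecomposition {
  mean: number;      // average of the T predicted means
  aleatoric: number; // average of the T predicted variances (data noise)
  epistemic: number; // spread of the T predicted means (model disagreement)
  total: number;     // aleatoric + epistemic
}

// Law of total variance over T stochastic passes (e.g. MC-dropout or ensemble members) for one output.
function decomposeUncertainty(means: number[], variances: number[]): UncertaintyDecomposition {
  if (means.length === 0 || means.length !== variances.length) {
    throw new RangeError("means and variances must be non-empty and the same length");
  }
  const T = means.length;
  const mean = means.reduce((acc, m) => acc + m, 0) / T;
  const aleatoric = variances.reduce((acc, v) => acc + v, 0) / T;
  // Population variance (divide by T) of the predicted means.
  const epistemic = means.reduce((acc, m) => acc + (m - mean) ** 2, 0) / T;
  return { mean, aleatoric, epistemic, total: aleatoric + epistemic };
}

const d = decomposeUncertainty([1, 2, 3], [0.5, 0.5, 0.5]);
console.log(d); // mean 2, aleatoric 0.5, epistemic ≈ 0.667, total ≈ 1.167
```

This sketch uses the population variance for the epistemic part; PyTorch's `Tensor.var` defaults to the unbiased estimator (divide by T - 1), so a tensor-based version can differ slightly for small T.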

Mathematical form: Loss = 0.5 * [log(variance) + (prediction - target)² / variance]

  • full=false: simplified form (sufficient for training)
  • full=true: adds the constant 0.5 * log(2π) (exact Gaussian NLL)

$$
\begin{aligned}
L &= 0.5 \left[ \log(\sigma^2) + \frac{(x - y)^2}{\sigma^2} \right] \\
L_{\text{full}} &= L + 0.5 \log(2\pi) \\
\sigma^2 &= \max(\text{var}, \epsilon) \quad \text{(clamped for stability)}
\end{aligned}
$$

where x is the input (predicted mean) and y is the target.
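For reference, here is a plain-TypeScript sketch of the formula above on flat arrays, with the same `full`, `eps`, and `reduction` options. It is not the torch.js implementation, and `gaussianNllLossRef` is a hypothetical name. Rejecting negative variances while clamping values below eps follows PyTorch's reference behavior and is an assumption here.

```typescript
type Reduction = "none" | "mean" | "sum";

interface GaussianNllOptions {
  full?: boolean;
  eps?: number;
  reduction?: Reduction;
}

// Reference Gaussian NLL over flat arrays (a sketch, not the torch.js kernel):
//   σ² = max(var, eps);  L = 0.5 * [log σ² + (x - y)² / σ²]  (+ 0.5 * log(2π) when full)
function gaussianNllLossRef(
  input: number[],
  target: number[],
  variance: number[],
  { full = false, eps = 1e-6, reduction = "mean" }: GaussianNllOptions = {},
): number | number[] {
  if (input.length !== target.length || input.length !== variance.length) {
    throw new RangeError("input, target and variance must have the same length");
  }
  const constant = full ? 0.5 * Math.log(2 * Math.PI) : 0;
  const losses = input.map((x, i) => {
    // Assumption borrowed from PyTorch: negative variances are rejected outright.
    if (variance[i] < 0) throw new RangeError("variance has negative entries");
    const v = Math.max(variance[i], eps); // clamp small variances for stability
    return 0.5 * (Math.log(v) + (x - target[i]) ** 2 / v) + constant;
  });
  if (reduction === "none") return losses;
  const sum = losses.reduce((acc, l) => acc + l, 0);
  return reduction === "sum" ? sum : sum / losses.length;
}

const mu = [0.5, -1, 2];
const y = [0, 0, 0];
const ones = [1, 1, 1];
const simplified = gaussianNllLossRef(mu, y, ones) as number; // mean of 0.5 * (x - y)²
const exact = gaussianNllLossRef(mu, y, ones, { full: true }) as number;
console.log(simplified, exact, exact - simplified); // difference ≈ 0.919 = 0.5 * log(2π)
```

The two results differ only by the constant 0.5*log(2π), so full=true changes the reported value but not the gradients.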
  • Variance must be positive: σ² > 0; ensure it in the output layer with exp, softplus, relu + eps, or a squared output
  • Heteroscedastic weighting: Higher variance → lower contribution from the error term (automatic importance weighting)
  • Aleatoric uncertainty: The variance represents sample-specific data noise
  • Split output: A single network output can be split into mean and variance parts, or two separate heads can be used
  • Log-variance stability: Predicting log(σ²) and exponentiating is usually more stable than predicting σ² directly
  • Full parameter: full=false is usually sufficient; full=true only adds a constant, so it changes the reported loss value but not the gradients
  • Targets: Any continuous (real-valued) targets are accepted; the loss is always computable, but it is a true likelihood only under the Gaussian assumption
  • Gradient stability: Clamping the variance at eps prevents log(0) and division by zero
  • Invalid variance: Negative variances are invalid (PyTorch raises an error for them); values below eps, including 0, are clamped to eps
  • Variance exploitation: With unconstrained variance, the model can learn to predict very large variances on hard samples instead of reducing the error
  • Variance regularization: Regularizing the variance may be needed to prevent collapse toward eps
  • Distribution assumption: The loss assumes Gaussian noise; for non-Gaussian targets, use other losses
  • Sample efficiency: Fitting the variance requires more data than fitting the mean alone
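The positivity notes above usually come down to a choice of output transform. A small sketch comparing exp(log-variance) with a numerically stable softplus; the function names and the `floor` parameter are hypothetical, chosen only for illustration:

```typescript
// Numerically stable softplus: log(1 + e^z) = max(z, 0) + log1p(e^(-|z|)).
function softplus(z: number): number {
  return Math.max(z, 0) + Math.log1p(Math.exp(-Math.abs(z)));
}

// Two common ways to turn an unconstrained output z into a variance σ² > 0.
function varianceFromLogVar(z: number): number {
  return Math.exp(z); // z is read as log σ²
}

function varianceFromSoftplus(z: number, floor = 1e-6): number {
  return softplus(z) + floor; // hypothetical floor keeps σ² strictly positive, in the spirit of eps
}

console.log(varianceFromLogVar(0), varianceFromSoftplus(0));   // 1 and ≈ 0.6931
console.log(varianceFromLogVar(10), varianceFromSoftplus(10)); // ≈ 22026 vs ≈ 10
```

exp grows very quickly (in float32 it overflows for inputs above roughly 88.7), while softplus grows only linearly for large inputs; which one trains better is problem-dependent.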

Parameters

input (Tensor)
Predicted mean μ of the Gaussian distribution. Shape [...] (any number of dimensions); represents E[target] under the predicted distribution. Example: [batch, output_dim] from a final regression layer.
target (Tensor)
Target values (observations) assumed drawn from the Gaussian distribution. Shape must match input; continuous, unbounded values. Example: actual continuous targets or regression labels.
var_ (Tensor)
Predicted variance σ² of the Gaussian distribution. Shape must match input; values must be > 0. The network must ensure positivity (e.g. exp(logvar), softplus, relu + eps). Example: variance predictions from a separate output head.
options ({ full?: boolean; eps?: number; reduction?: 'none' | 'mean' | 'sum'; }, optional)
Optional configuration:
  • full: Include the constant term in the NLL (default: false)
    ◦ true: exact NLL = 0.5*[log(var) + (pred-target)²/var + log(2π)]
    ◦ false: simplified = 0.5*[log(var) + (pred-target)²/var] (usually sufficient)
  • eps: Numerical-stability floor for the variance (default: 1e-6); the variance is clamped to max(variance, eps) to prevent log(0) and division by 0
  • reduction: How to aggregate the losses (default: 'mean')
    ◦ 'none': per-element losses, same shape as input
    ◦ 'mean': average over all elements
    ◦ 'sum': sum over all elements

Returns

Tensor: loss tensor (same shape as input/target if reduction='none', scalar otherwise)

Examples

```typescript
// Heteroscedastic regression: predict both mean and variance
// Assumes `model` and `input` are defined elsewhere and `model` returns [batch, 100].
const batch_size = 32;

// Split network output into mean and log-variance
const output = model(input);  // [batch, 100]
const mu = output.slice([0], [50]);        // First 50: predicted mean
const logvar = output.slice([50], [100]);  // Last 50: predicted log-variance
const variance = logvar.exp();             // Exponentiate to ensure > 0 (`var` is a reserved word in JS)

const target = torch.randn([batch_size, 50]);  // Ground truth

const loss = torch.nn.functional.gaussian_nll_loss(
  mu, target, variance,
  { full: false }  // Simplified loss
);
// Network learns to predict both an accurate mean and an appropriate variance
```

```typescript
// Uncertainty quantification: prediction with confidence intervals
// Assumes `model_backbone` (input -> [1, 64] features) is defined elsewhere.
const x = torch.randn([1, 10]);  // Single input

// Network with two output heads
const mean_head = new torch.nn.Linear(64, 5);
const logvar_head = new torch.nn.Linear(64, 5);  // Log-variance for numerical stability

const hidden = model_backbone(x);                    // [1, 64]
const mu = mean_head.forward(hidden);                // [1, 5]
const variance = logvar_head.forward(hidden).exp();  // [1, 5], exponentiate

const targets = torch.tensor([[1, 2, 3, 4, 5]]);

const nll = torch.nn.functional.gaussian_nll_loss(mu, targets, variance);

// Prediction confidence: lower variance → more confident
// Approximate 95% interval under the Gaussian assumption: mu ± 1.96*sqrt(variance)
const std = variance.sqrt();
const ci_lower = mu.sub(std.mul(1.96));
const ci_upper = mu.add(std.mul(1.96));
```

```typescript
// Aleatoric uncertainty in computer vision: depth regression
// Assumes `model` and `logvar_model` are defined elsewhere.
const batch_images = torch.randn([8, 3, 64, 64]);

// Network predicts a depth map with per-pixel uncertainty
const predicted_depth = model(batch_images);          // [8, 1, 64, 64]
const predicted_logvar = logvar_model(batch_images);  // [8, 1, 64, 64]
const predicted_var = predicted_logvar.exp();         // Ensure positivity

const target_depth = torch.randn([8, 1, 64, 64]);     // Ground truth depth

const depth_loss = torch.nn.functional.gaussian_nll_loss(
  predicted_depth,
  target_depth,
  predicted_var
);
// Per-pixel uncertainty: expect high variance at occlusions and depth discontinuities
// Optionally add a smoothness penalty on the variance map as an auxiliary loss to reduce overfitting
```

```typescript
// Bayesian deep learning: ensemble-like uncertainty via MC-dropout
// Assumes `model` (with forward_mean/forward_var and dropout kept active), `x`, `target`,
// and the `Tensor` type are in scope.
const num_samples = 10;
const predictions: Tensor[] = [];
const uncertainties: Tensor[] = [];

for (let i = 0; i < num_samples; i++) {
  // Forward pass with dropout enabled (MC-dropout)
  const mu_i = model.forward_mean(x);
  const var_i = model.forward_var(x);
  predictions.push(mu_i);
  uncertainties.push(var_i);
}

// Aleatoric uncertainty: average predicted variance (data noise)
const aleatoric = torch.stack(uncertainties).mean(0);

// Epistemic uncertainty: variance of the predicted means (model disagreement)
const epistemic = torch.stack(predictions).var(0);

// Total uncertainty = aleatoric + epistemic
const total_unc = aleatoric.add(epistemic);

// Use the mean prediction with aleatoric uncertainty in the NLL
const mean_pred = torch.stack(predictions).mean(0);
const nll = torch.nn.functional.gaussian_nll_loss(mean_pred, target, aleatoric);
```

```typescript
// Robust regression: outliers automatically downweighted
// Assumes `model` is defined elsewhere and maps [100, 5] -> [100, 2],
// and that slice([null, k]) selects column k like x[:, k] in PyTorch.
const x = torch.randn([100, 5]);
const targets = torch.randn([100]);  // One target per sample, matching mu's shape

// Network learns to increase variance for outliers
const predictions = model(x);                         // [100, 2]: [mean, logvar]
const mu = predictions.slice([null, 0]);              // Column 0 → [100]
const variance = predictions.slice([null, 1]).exp();  // Column 1 → [100], positive variance

const loss = torch.nn.functional.gaussian_nll_loss(mu, targets, variance);

// Samples with large errors tend to receive higher variance predictions
// This reduces the influence of outliers (each error is weighted by 1/variance)
// Comparison: MSE treats all errors equally; NLL downweights outliers
```

```typescript
// Exact vs simplified NLL (full parameter)
const mu = torch.randn([32, 10]);
const target = torch.randn([32, 10]);
const variance = torch.ones([32, 10]);  // Constant variance = 1

// Simplified NLL (omits the constant; usually sufficient)
const simplified = torch.nn.functional.gaussian_nll_loss(mu, target, variance, {
  full: false
});
// Per element: 0.5 * [log(1) + (mu-target)²/1] = 0.5 * (mu-target)², averaged by the default 'mean' reduction

// Exact NLL (includes constant term)
const exact = torch.nn.functional.gaussian_nll_loss(mu, target, variance, {
  full: true
});
// Per element: 0.5 * [0 + (mu-target)² + log(2π)] = 0.5 * ((mu-target)² + log(2π))
// Difference = 0.5 * log(2π) ≈ 0.919 (constant for all samples)
```

See Also

  • PyTorch torch.nn.functional.gaussian_nll_loss
  • poisson_nll_loss - For count data (Poisson distribution)
  • torch.nn.GaussianNLLLoss - Module wrapper
  • mse_loss - Simpler alternative (assumes fixed variance)
  • torch.distributions.Normal - For sampling/likelihood computation