Skip to main content
torch.js has not been released yet.
torch.js logotorch.js logotorch.js
PlaygroundContact
Login
Documentation
IntroductionType SafetyTensor ExpressionsTensor IndexingEinsumEinopsAutogradTraining a ModelProfiling & MemoryPyTorch MigrationBest PracticesRuntimesPerformancePyTorch CompatibilityBenchmarksDType Coverage
adaptive_avg_pool1dadaptive_avg_pool2dadaptive_avg_pool3dadaptive_max_pool1dadaptive_max_pool1d_with_indicesadaptive_max_pool2dadaptive_max_pool2d_with_indicesadaptive_max_pool3dadaptive_max_pool3d_with_indicesAdaptiveMaxPoolFunctionalOptionsaffine_gridAffineGridFunctionalOptionsalpha_dropoutAlphaDropoutFunctionalOptionsavg_pool1davg_pool2davg_pool3dAvgPool1dFunctionalOptionsAvgPool2dFunctionalOptionsAvgPool3dFunctionalOptionsbatch_normBatchNormFunctionalOptionsbinary_cross_entropybinary_cross_entropy_with_logitsBinaryCrossEntropyFunctionalOptionsBinaryCrossEntropyWithLogitsFunctionalOptionsCeluFunctionalOptionschannel_shuffleconv_transpose1dconv_transpose2dconv_transpose3dconv1dConv1dFunctionalOptionsconv2dConv2dFunctionalOptionsconv3dConv3dFunctionalOptionsConvTranspose1dFunctionalOptionsConvTranspose2dFunctionalOptionsConvTranspose3dFunctionalOptionscosine_embedding_losscosine_similarityCosineEmbeddingLossFunctionalOptionsCosineSimilarityFunctionalOptionscross_entropyCrossEntropyFunctionalOptionsctc_lossCTCLossOptionsdropoutdropout1ddropout2ddropout3dDropoutFunctionalOptionsEluFunctionalOptionsembeddingembedding_bagEmbeddingBagFunctionalOptionsEmbeddingFunctionalOptionsfeature_alpha_dropoutfoldFoldFunctionalOptionsfractional_max_pool2dfractional_max_pool2d_with_indicesfractional_max_pool3dfractional_max_pool3d_with_indicesFractionalMaxPoolFunctionalOptionsgaussian_nll_lossGluFunctionalOptionsgrid_sampleGridSampleFunctionalOptionsgroup_normgrouped_mmGroupedMMFunctionalOptionsGroupNormFunctionalOptionsHardshrinkFunctionalOptionsHardtanhFunctionalOptionshinge_embedding_lossHingeEmbeddingLossFunctionalOptionshuber_lossHuberLossFunctionalOptionsinstance_normInstanceNormFunctionalOptionsinterpolateInterpolateFunctionalOptionskl_divKlDivFunctionalOptionsKLDivOptionsl1_lossL1LossFunctionalOptionslayer_normLayerNormFunctionalOptionsLeakyReluFunctionalOptionslinearlocal_response_normLocalResponseNormFunctionalOptionslog_softmaxlp_pool1dlp_pool2dlp_pool3dLPPoolFunctionalOpt
ionsmargin_ranking_lossMarginRankingLossFunctionalOptionsmax_pool1dmax_pool1d_with_indicesmax_pool2dmax_pool2d_with_indicesmax_pool3dmax_pool3d_with_indicesmax_unpool1dmax_unpool2dmax_unpool3dMaxPool1dFunctionalOptionsMaxPool2dFunctionalOptionsMaxPool3dFunctionalOptionsMaxUnpoolFunctionalOptionsmse_lossMseLossFunctionalOptionsmulti_head_attention_forwardmulti_margin_lossMultiHeadAttentionFunctionalOptionsmultilabel_margin_lossmultilabel_soft_margin_lossnll_lossNllLossFunctionalOptionsnormalizeNormalizeFunctionalOptionsone_hotpadPadFunctionalOptionspairwise_distancePairwiseDistanceFunctionalOptionspdistPdistFunctionalOptionspixel_shufflepixel_unshufflepoisson_nll_lossPoolWithIndicesResultReluFunctionalOptionsrms_normRmsNormFunctionalOptionsRreluFunctionalOptionsscaled_grouped_mmscaled_mmScaledDotProductAttentionFunctionalOptionsScaledGroupedMMFunctionalOptionsScaledMMFunctionalOptionssmooth_l1_lossSmoothL1LossFunctionalOptionssoft_margin_lossSoftMarginLossFunctionalOptionsSoftmaxOptionsSoftminFunctionalOptionsSoftplusFunctionalOptionsSoftshrinkFunctionalOptionstriplet_margin_losstriplet_margin_with_distance_lossTripletMarginLossFunctionalOptionsunfoldUnfoldFunctionalOptionsupsampleupsample_bilinearupsample_nearestUpsampleBilinearOptionsUpsampleNearestOptionsUpsampleOptions
ActivationOptionsAdaptiveAvgPool1dAdaptiveAvgPool2dAdaptiveAvgPool3dAdaptiveLogSoftmaxOptionsAdaptiveLogSoftmaxWithLossAdaptiveMaxPool1dAdaptiveMaxPool1dOptionsAdaptiveMaxPool2dAdaptiveMaxPool2dOptionsAdaptiveMaxPool3dAdaptiveMaxPool3dOptionsadd_moduleAlphaDropoutappendappendapplyAvgPool1dAvgPool1dOptionsAvgPool2dAvgPool2dOptionsAvgPool3dAvgPool3dOptionsBackwardHookBackwardPreHookBatchNorm1dBatchNorm2dBatchNorm3dBatchNormOptionsBCELossBCEWithLogitsLossBilinearBilinearOptionsBufferBufferOptionsBufferRegistrationHookbufferscallCELUCELUOptionsChannelShufflechildrenCircularPad1dCircularPad2dCircularPad3dclearConstantPad1dConstantPad2dConstantPad3dConv1dConv2dConv3dConvOptionsConvTranspose1dConvTranspose2dConvTranspose3dConvTransposeOptionsCosineEmbeddingLossCosineEmbeddingLossOptionsCosineSimilarityCosineSimilarityOptionscreatecreateCrossEntropyLossCTCLossdecodedecodedeleteDropoutDropout1dDropout2dDropout3dDropoutOptionsELUELUOptionsEmbeddingEmbeddingBagEmbeddingBagForwardOptionsEmbeddingBagFromPretrainedOptionsEmbeddingBagOptionsEmbeddingFromPretrainedOptionsEmbeddingOptionsencodeencodeentriesentriesevalextendFeatureAlphaDropoutFlattenFlattenOptionsFoldFoldOptionsforwardforwardforwardforwardforwardforwardforwardforwardforwardforwardforwardforwardforward_with_targetForwardHookForwardPreHookFractionalMaxPool2dFractionalMaxPool3dFractionalMaxPoolOptionsfrom_pretrainedfrom_pretrainedGaussianNLLLossGELUGELUOptionsgenerate_square_subsequent_maskgetgetgetgetgetget_bufferget_parameterget_submoduleGLUGLUOptionsGroupNormGroupNormOptionsGRUGRUCellHardshrinkHardshrinkOptionsHardsigmoidHardswishHardtanhHardtanhOptionshashasHingeEmbeddingLossHingeEmbeddingLossOptionsHuberLossHuberLossOptionsIdentityInstanceNorm1dInstanceNorm2dInstanceNorm3dInstanceNormOptionsis_uninitialized_bufferis_uninitialized_parameteriterator]iterator]iterator]iterator]keyskeysKLDivLossL1LossL1LossOptionsLayerNormLayerNormOptionsLazyBatchNorm1dLazyBatchNorm2dLazyBatchNorm3dLazyConv1dLazyConv2dLazyConv3dLazyCon
vOptionsLazyConvTranspose1dLazyConvTranspose2dLazyConvTranspose3dLazyConvTransposeOptionsLazyInstanceNorm1dLazyInstanceNorm2dLazyInstanceNorm3dLazyLinearLeakyReLULeakyReLUOptionsLinearLinearOptionsload_state_dictload_state_dictLocalResponseNormLocalResponseNormOptionslog_probLogSigmoidLogSoftmaxLogSoftmaxOptionsLPPool1dLPPool1dOptionsLPPool2dLPPool2dOptionsLPPool3dLPPool3dOptionsLSTMLSTMCellLSTMCellOptionsMarginRankingLossMarginRankingLossOptionsmaterializematerializematerialize_uninitializedmaterialize_uninitializedMaxPool1dMaxPool1dOptionsMaxPool2dMaxPool2dOptionsMaxPool3dMaxPool3dOptionsMaxUnpool1dMaxUnpool1dOptionsMaxUnpool2dMaxUnpool2dOptionsMaxUnpool3dMaxUnpool3dOptionsMishModuleModuleBuffersModuleChildrenModuleDictModuleDictOptionsModuleListModuleListOptionsModuleParametersModuleRegistrationHookmodulesMSELossMSELossOptionsmultihead_attnMultiheadAttentionMultiheadAttentionOptionsMultiheadAttnOptionsMultiLabelMarginLossMultiLabelMarginLossOptionsMultiLabelSoftMarginLossMultiMarginLossnamed_buffersnamed_childrennamed_modulesnamed_parametersNamedModulesOptionsNamedRecurseOptionsNLLLossnum_parametersNumParametersOptionsPairwiseDistancePairwiseDistanceOptionsParameterParameterDictParameterDictOptionsParameterListParameterListOptionsParameterOptionsParameterRegistrationHookparametersPixelShufflePixelUnshufflePoissonNLLLosspoppopPReLUPReLUOptionsRecurseOptionsReflectionPad1dReflectionPad2dReflectionPad3dregister_backward_hookregister_bufferregister_forward_hookregister_forward_pre_hookregister_full_backward_hookregister_full_backward_pre_hookregister_module_backward_hookregister_module_buffer_registration_hookregister_module_forward_hookregister_module_forward_pre_hookregister_module_full_backward_hookregister_module_full_backward_pre_hookregister_module_module_registration_hookregister_module_parameter_registration_hookregister_parameterReLUReLU6RemovableHandleremoveReplicationPad1dReplicationPad2dReplicationPad3dRMSNormRMSNormOptionsRNNRNNBaseRNNBaseOptionsRNNCellR
NNCellOptionsRReLURReLUOptionsrunrunSELUSequentialsetsetsetSigmoidSiLUSmoothL1LossSmoothL1LossOptionsSoftMarginLossSoftMarginLossOptionsSoftmaxSoftmax2dSoftmaxOptionsSoftminSoftminOptionsSoftplusSoftplusOptionsSoftshrinkSoftshrinkOptionsSoftsignstate_dictstate_dictStateDictOptionsstepSyncBatchNormTanhTanhshrinkThresholdThresholdOptionstotototrainTrainOptionsTransformerTransformerDecoderTransformerDecoderDecodeOptionsTransformerDecoderLayerTransformerDecoderLayerDecodeOptionsTransformerDecoderLayerOptionsTransformerDecoderOptionsTransformerEncoderTransformerEncoderEncodeOptionsTransformerEncoderLayerTransformerEncoderLayerEncodeOptionsTransformerEncoderLayerOptionsTransformerEncoderOptionsTransformerOptionsTransformerRunOptionsTripletMarginLossTripletMarginWithDistanceLossUnflattenUnfoldUnfoldOptionsUninitializedBufferUninitializedOptionsUninitializedParameterupdateUpsampleUpsamplingBilinear2dUpsamplingNearest2dvaluesvalueszero_gradZeroPad1dZeroPad2dZeroPad3d
absacosacoshAdaptivePool1dShapeAdaptivePool2dShapeaddaddbmmAddbmmOptionsaddcdivAddcdivOptionsaddcmulAddcmulOptionsaddmmAddmmOptionsaddmvAddmvOptionsaddrAddrOptionsadjointallallcloseAllcloseOptionsAlphaBetaOptionsamaxaminaminmaxAminmaxOptionsangleanyapplyOutarangeare_deterministic_algorithms_enabledargmaxargminargsortargwhereas_stridedas_tensorasinasinhAssertNoShapeErrorAssertNotErrorAsStridedOptionsAtat_error_index_out_of_boundsatanatan2atanhatleast_1datleast_2datleast_3dAtShapeautocast_decrement_nestingautocast_increment_nestingautograd_gradient_mismatch_errorautograd_not_registered_errorAutogradConfigAutogradDeviceAutogradDTypeAutogradEntryAutogradHandleAutogradHandleImplAxesRecordBackwardFnbaddbmmBaddbmmOptionsbartlett_windowBaseKernelConfigbatch_dimensions_do_not_match_errorbernoulliBernoulliOptionsBinaryBackwardFnBinaryBroadcastResultBinaryDTypeBinaryKernelConfigCPUBinaryKernelCPUBinaryOpConfigBinaryOpNamesBinaryOpSchemaBinaryOptionsbincountBincountOptionsbitwise_andbitwise_left_shiftbitwise_notbitwise_orbitwise_right_shiftbitwise_xorblackman_windowblock_diagbmmBooleanDTypeRulebroadcast_error_incompatible_dimensionsbroadcast_shapesbroadcast_tensorsbroadcast_toBroadcastShapeBroadcastShapeRulebroadcastShapesbucketizeBucketizeOptionsBufferUsagebuildEinopsErrorbuildErrorMessagecanBroadcastTocartesian_prodcatCatOptionsCatShapeCauchyOptionscdistCdistOptionsceilceluCeluFunctionalOptionschain_matmulCheckShapeErrorCholeskyShapechunkchunk_error_dim_out_of_rangeChunkOptionsclampClampOptionsclear_autocast_cacheclearEinopsCacheclearEinsumCacheclonecolumn_stackcombinationsCombinationsOptionscompiled_with_cxx11_abicomplexconjconj_physicalcontiguousConv1dShapeConv2dShapeConv3dShapeConvTranspose2dShapecopysigncorrcoefcoscoshcount_nonzeroCountNonzeroOptionscovcoverage_reportcoverageReportCoverageReportCovOptionsCPUForwardFnCPUKernelConfigCPUKernelEntryCPUOnlyResultCPUTensorDatacreateCumExtremeResultcreateTorchCreationOpSchemaCumExtremeResultcummaxcummincumprodCumShapecumsumcumul
ative_trapezoidCumulativeOptionsCumulativeOptionsWithDimdeg2raddetachDeterministicOptionsDetShapeDevicedevice_error_requiresDeviceBufferDeviceCapabilitiesDeviceCheckedResultDeviceConfigDeviceContextDeviceEntryDeviceHandleDeviceInputDeviceOptionsDeviceRegistryDeviceTypediagdiag_embedDiagEmbedOptionsdiagflatDiagflatOptionsDiagFlatOptionsdiagonal_scatterDiagonalOptionsDiagonalScatterOptionsDiagOptionsDiagShapediffDiffOptionsdigammadimension_error_out_of_rangeDispatchConfigdistDistOptionsdivdotDotShapeRuleDoubleDoubleDimdropoutDropoutFunctionalOptionsdsplitdstackDTypedtype_already_registered_errordtype_components_mismatch_errordtype_not_found_errorDTypeComponentsDTypeConfigDTypeCoverageReportDTypeDisplayConfigDTypeEntryDTypeHandleDTypeHandleImplDTypeInfoDTypeRegistryDTypeRuleDTypeSerializationConfigDynamicShapeEigShapeeinops_error_ambiguous_decompositioneinops_error_anonymous_in_outputeinops_error_dimension_mismatcheinops_error_invalid_patterneinops_error_reduce_undefined_outputeinops_error_repeat_missing_sizeeinops_error_undefined_axiseinsumeinsum_error_dimension_mismatcheinsum_error_index_out_of_rangeeinsum_error_invalid_equationeinsum_error_invalid_sublist_elementeinsum_error_operand_count_mismatcheinsum_error_subscript_rank_mismatcheinsum_error_unknown_output_indexEinsumOptionsEinsumOutputShapeEllipsiseluelu_EluFunctionalOptionsembedding_bag_error_requires_2d_inputemptyempty_cacheempty_likeeqequalerferfcerfinvexpexp2expandexpand_asexpand_error_incompatibleExpandShapeexpm1ExponentialOptionseyeEyeOptionsfftFFTOptionsfindKernelWithPredicatefindSimilarPatternsflattenFlattenOptionsFlattenShapeflipflip_error_dim_out_of_rangefliplrFlipShapeflipudfloat_powerFloatDTypeRulefloorfloor_dividefmaxfminfmodformatEquationErrorformatShapefracfrexpfrombufferfullfull_likefunction_already_registered_errorFunctionConfigFunctionEntryFunctionHandlegathergather_error_dim_out_of_rangeGatherShapegcdgegeluGeometricOptionsget_autocast_cpu_dtypeget_autocast_gpu_dtypeget_autocast_ipu_dtypeget_au
tocast_xla_dtypeget_default_deviceget_default_dtypeget_deterministic_debug_modeget_device_configget_device_contextget_device_moduleget_dtype_infoget_file_pathget_float32_matmul_precisionget_num_interop_threadsget_num_threadsget_op_infoget_printoptionsget_real_dtypeget_rng_stategetAutogradgetDTypegetEinopsCacheSizegetEinsumCacheSizegetFunctiongetKernelgetMethodgetOpInfoGetOpKindGetOpSchemagetScalarKernelgluGluFunctionalOptionsGradContextGradFnGradientsForgtHalfHalfDimhamming_windowhann_windowhardshrinkhardsigmoidhardswishhardtanhhardtanh_HardtanhFunctionalOptionshas_autogradhas_devicehas_dtypehas_kernelhasAutogradhasDTypehasFunctionhasKernelhasMethodhasScalarKernelHasShapeErrorheavisidehistcHistcOptionshistogramHistogramOptionsHistogramResulthsplithstackhypoti0IdentityShapeifftimagindex_addindex_copyindex_fillindex_putindex_reduceindex_selectindex_select_error_dim_out_of_rangeIndexPutOptionsIndexSelectShapeIndexSpecIndicesOptionsIndicesSpecinitialize_deviceInputsForInsertDiminvalid_config_errorinverseInverseShapeirfftis_anomaly_check_nan_enabledis_anomaly_enabledis_autocast_cache_enabledis_autocast_cpu_enabledis_autocast_ipu_enabledis_autocast_xla_enabledis_complexis_complex_dtypeis_cpu_only_modeis_deterministic_algorithms_warn_only_enabledis_floating_pointis_floating_point_dtypeis_inference_mode_enabledis_nonzerois_tensoris_warn_always_enabledis_webgpu_availableIs2DIsAtLeast1DIsBinaryOpIsBinaryOpNameiscloseIscloseOptionsisfiniteisinisinfisnanisneginfisposinfisrealIsReductionOpIsReductionOpNameIsRegistryErrorIsShapeErroristftISTFTOptionsIsUnaryOpIsUnaryOpNameitem_error_not_scalarItemResultkaiser_windowKaiserWindowOptionskernel_not_registered_errorkernel_signature_mismatch_errorKernelConfigKernelConfigWebGPUKernelEntryKernelHandleKernelInfoKernelPredicateKernelRegistryKernelWebGPUkronkthvalueKthvalueOptionslcmldexpleleaky_reluleaky_relu_LeakyReluFunctionalOptionslerplevenshteinDistancelgammalinalg_error_not_square_matrixlinalg_error_requires_2dlinalg_error_requires_at
_least_2dlinearlinspacelist_custom_deviceslist_custom_dtypeslist_deviceslist_dtypeslist_functionslist_kernelslist_methodslist_opslistCustomDTypeslistDTypeslistFunctionslistKernelsListKernelsOptionslistMethodslistOpsListOpsOptionsloglog_softmaxlog10log1plog2logaddexplogaddexp2logcumsumexplogical_andlogical_notlogical_orlogical_xorLogitOptionsLogNormalOptionsLogOptionslogsigmoidlogspacelogsumexpLogsumexpOptionsltLUShapeLuSolveOptionsmasked_fillmasked_selectmasked_select_asyncMaskSpecmatmulmatmul_error_inner_dimensions_do_not_matchMatmul2DShapeMatmulShapeMatmulShapeRuleMatrixTransposeShapemaxmaximummeanmedianmemory_statsmemory_summarymeshgridmethod_already_registered_errormethod_dtype_not_supported_errorMethodConfigMethodEntryMethodHandleminminimummishmmMMShapeRulemodemovedimmsortmulmultinomialmultinomial_asyncMultinomialAsyncOptionsMultinomialOptionsMultiplyBymvMVShapeRulenan_to_numnanmeannanmediannanquantileNanReductionOptionsnansumNanToNumOptionsnarrownarrow_copynarrow_error_length_exceeds_boundsnarrow_error_start_out_of_boundsNarrowShapeneneedsBroadcastnegNegativeDimnextafternonzeroNonzeroOptionsnormnormalNormalOptionsNormOptionsnumelonesones_likeop_kind_mismatch_errorop_not_found_errorOpCoverageEntryOpInfoOpKindOpNameOpSchemaOpSchemasouterOuterShapepackPackShapepermutepermute_error_dimension_count_mismatchPermuteShapepoissonpolarPool1dShapePool2dShapePool3dShapepositivepowpreluPrintOptionsprodprofiler_allow_cudagraph_cupti_lazy_reinit_cuda12promote_typesPromoteDTypeRulePutOptionsquantileQuantileOptionsrad2degrandrand_likerandintrandint_likeRandintLikeOptionsRandintOptionsrandnrandn_likeRandomLikeOptionsRandomOptionsrandpermRangeSpecRankravelrealrearrangeRearrangeOptionsRearrangeShapereciprocalreduceReduceOperationReduceOptionsReduceShapeReductionKernelConfigCPUReductionKernelCPUReductionOpNamesReductionOpSchemaReductionOptionsReductionShapeRuleregister_backwardregister_deviceregister_dtyperegister_forwardregister_functionregister_methodregister_scalar_forwardregis
terAutogradRegisterBackwardOptionsregisterBinaryOpregisterDTypeRegisterDTypeOptionsRegisteredDTyperegisterFunctionRegisterFunctionOptionsregisterKernelRegisterKernelOptionsregisterMethodRegisterMethodOptionsregisterScalarKernelregisterUnaryOpregistration_failed_errorrelurelu_relu6ReluFunctionalOptionsremainderRemoveDimrepeatrepeat_interleaveRepeatInterleaveOptionsRepeatOptionsRepeatShapeReplaceDimrequireWebGPUreset_peak_memory_statsreshapeReshapeShaperesult_typerfftrollRollOptionsrot90Rot90Optionsroundrrelurrelu_RreluFunctionalOptionsrsqrtSafeExpandShapeSameDTypeRuleSameShapeRuleSaveForBackwardScalarCPUForwardFnScalarCPUKernelConfigScalarKernelEntryScalarKernelHandleScalarWebGPUKernelConfigScaleDimscatterscatter_addscatter_add_scatter_error_dim_out_of_rangescatter_reducescatter_reduce_ScatterReduceOptionsScatterShapesearchsortedSearchSortedOptionsselectselect_error_index_out_of_boundsselect_scatterSelectShapeseluset_default_deviceset_default_tensor_typeset_deterministic_debug_modeset_float32_matmul_precisionset_printoptionsset_warn_alwaysSetupContextFnShapeShapeCheckedResultShapedTensorShapeErrorMessageShapeOpSchemaShapeRulesigmoidsignsignbitsilusinsincsinhSizeOptionsslice_error_out_of_boundsslice_scatterSliceOptionsSliceScatterOptionsSliceShapeSliceSpecsoftmaxsoftmax_error_dim_out_of_rangeSoftmaxShapesoftminSoftminFunctionalOptionssoftplusSoftplusFunctionalOptionssoftshrinksoftsignsortSortOptionssplitsplit_error_dim_out_of_rangeSplitOptionssqrtsquaresqueezeSqueezeOptionsSqueezeShapestackStackOptionsStackShapestdstd_meanStdVarMeanOptionsStdVarOptionsstftSTFTOptionsStrideOptionssubSublistSublistElementSubscriptIndexsumSVDShapeswapaxessym_floatsym_intsym_notttaketake_along_dimTakeAlongDimOptionstantanhtanhshrinktensortensor_splitTensorCreatorTensorDatatensordotTensordotOptionsTensorLikeTensorMetaTensorOptionsTensorStoragethresholdthreshold_tileTileShapeToOptionstopkTopkOptionsTorchtraceTraceShapetransposetranspose_dims_error_out_of_rangetranspose_error_requires_2d_ten
sorTransposeDimsShapeTransposeDimsShapeCheckedTransposeShapetrapezoidTrapezoidOptionsTriangularOptionstriltril_indicesTriOptionsTripletriutriu_indicestrue_dividetruncTupleOfLengthTypedArrayTypedArrayForTypedStorageTypeOptionsUnaryBackwardFnUnaryDTypeUnaryKernelConfigCPUUnaryKernelCPUUnaryOpConfigUnaryOpFnUnaryOpNamesUnaryOpParamsUnaryOpSchemaUnaryOptionsunbindunbind_error_dim_out_of_rangeUnbindOptionsunflattenUniformOptionsuniqueunique_consecutiveUniqueConsecutiveOptionsUniqueOptionsunpackUnpackShapeunravel_indexunregister_deviceunsqueezeUnsqueezeOptionsUnsqueezeShapeuse_deterministic_algorithmsValidateBatchedSquareMatrixValidateChunkDimValidatedEinsumShapevalidateDeviceValidateDeviceValidatedRearrangeShapeValidatedReduceShapeValidatedRepeatShapevalidateDTypeValidateEinsumValidateOperandCountValidateRanksValidateScalarValidateSplitDimValidateSquareMatrixValidateUnbindDimValueOptionsvar_var_meanvdotviewview_as_complexview_as_realvmapvsplitvstackWebGPUKernelConfigWebGPUOnlyResultWebGPUTensorDatawhereWindowOptionsxlogyzeroszeros_like
torch.js· 2026
LegalTerms of UsePrivacy Policy
/
/
  1. docs
  2. torch.js
  3. torch
  4. nn
  5. functional
  6. grouped_mm

torch.nn.functional.grouped_mm

function grouped_mm(input_tensor_list: Tensor[], mat2_tensor_list: Tensor[], options?: GroupedMMFunctionalOptions): Tensor[]
function grouped_mm(input: Tensor, mat2: Tensor, options?: GroupedMMFunctionalOptions): Tensor

Performs grouped (multi-headed) matrix multiplication with optional bias and dtype casting.

Computes multiple independent matrix multiplications on lists of matrices in parallel. Each input matrix is multiplied with its corresponding weight matrix independently, which is useful for multi-head attention, ensemble operations, and grouped computations. Essential for:

  • Multi-head attention: Computing attention heads in parallel
  • Grouped linear layers: Processing multiple feature groups independently
  • Ensemble inference: Running multiple models/heads simultaneously
  • Mixed-precision inference: Different dtypes for different heads
  • Distributed computation: Processing groups separately then combining
  • Efficient batch operations: More flexible than standard batched matmul
  • Conditional computation: Different weights for different groups

Operation: For each i: output[i] = input[i] @ weight[i] + bias[i] (optional)

All operations are independent and can be parallelized. Output dtype can be optionally cast to a different type for memory efficiency.

  • List lengths must match: All three lists (input, weights, bias) must have the same length
  • Inner dimensions must match: input[i].shape[-1] must equal weights[i].shape[0]
  • Optional bias: Biases can be null for some/all operations (sparse application)
  • Independent operations: Each matmul is independent and can be parallelized
  • Dtype flexibility: Can output different dtype than computation for efficiency
  • Gradient propagation: Gradients flow back to all inputs, weights, and biases
  • Multi-head friendly: Natural fit for multi-head attention architectures
  • List length mismatch: Will error if input and weight lists differ in length
  • Bias length mismatch: Bias list length must match input list if provided
  • Dimension mismatch: Inner dimensions of input/weight must match for matmul
  • dtype casting: out_dtype casting happens after computation (precision loss possible)

Parameters

input_tensor_listTensor[]
List of n input tensors with shapes:
  • 1D: [k] → matmul with [k, m] → [m]
  • 2D: [p, k] → matmul with [k, m] → [p, m]
  • Higher: [*, k] → matmul with [k, m] → [*, m]
mat2_tensor_listTensor[]
List of n weight tensors with shapes:
  • [k, m] for corresponding input shape [*, k] → output [*, m]
  • Must be same length as input_tensor_list
  • Inner dimension k must match input's last dimension
optionsGroupedMMFunctionalOptionsoptional

Returns

Tensor[]– List of output tensors with shape [*, m] for corresponding inputs

Examples

// Multi-head attention: process heads independently
const head_size = 64;
const num_heads = 8;
const seq_len = 10;
const batch_size = 32;

// Create inputs and weights for each head
const queries = [];    // 8 heads, each [batch*seq, 64]
const key_weights = []; // 8 heads, each [64, 64]
for (let h = 0; h < num_heads; h++) {
  queries.push(torch.randn(batch_size * seq_len, head_size));
  key_weights.push(torch.randn(head_size, head_size));
}

const head_outputs = torch.nn.functional.grouped_mm(queries, key_weights);
// head_outputs[i]: [batch*seq, 64] for each head
// Grouped linear layer: different weights for different feature groups
const input = [
  torch.randn(100, 50),  // Group 1: 100 samples, 50 features
  torch.randn(100, 75),  // Group 2: 100 samples, 75 features
  torch.randn(100, 30)   // Group 3: 100 samples, 30 features
];

const weights = [
  torch.randn(50, 64),   // Project group 1 to 64 dims
  torch.randn(75, 128),  // Project group 2 to 128 dims
  torch.randn(30, 32)    // Project group 3 to 32 dims
];

const biases = [
  torch.randn(64),   // Bias for group 1
  torch.randn(128),  // Bias for group 2
  torch.randn(32)    // Bias for group 3
];

const outputs = torch.nn.functional.grouped_mm(input, weights, biases);
// outputs[i]: [100, output_dim_i] for each group
// Ensemble inference: different models/weights for same inputs
const input_batch = torch.randn(batch_size, feature_dim);
const model_weights = [
  torch.randn(feature_dim, output_dim),  // Model 1 weights
  torch.randn(feature_dim, output_dim),  // Model 2 weights
  torch.randn(feature_dim, output_dim)   // Model 3 weights
];

// Duplicate input for each model (or use broadcasting)
const inputs = [input_batch, input_batch, input_batch];

const predictions = torch.nn.functional.grouped_mm(inputs, model_weights);
// predictions: [model_output1, model_output2, model_output3]
const ensemble_output = torch.stack(predictions, 0).mean(0);
// Average ensemble predictions
// Mixed-precision: compute in float32, output int8
const inputs = [torch.randn(100, 64), torch.randn(100, 64)];
const weights = [torch.randn(64, 32), torch.randn(64, 32)];

const outputs = torch.nn.functional.grouped_mm(
  inputs,
  weights,
  null,      // No bias
  'int8'     // Cast output to int8 for storage efficiency
);
// outputs: [int8 tensors] - reduced memory footprint
// Grouped convolution: treat as grouped linear for channel operations
const group_size = 32;
const num_groups = 4;
const groups = [];
const group_weights = [];

for (let g = 0; g < num_groups; g++) {
  groups.push(torch.randn(batch_size, group_size));
  group_weights.push(torch.randn(group_size, output_per_group));
}

const group_outputs = torch.nn.functional.grouped_mm(groups, group_weights);
const output = torch.cat(group_outputs, 1);  // Concatenate groups
// Grouped operation is more efficient than single large matmul

See Also

  • [PyTorch torch.nn.functional documentation](https://pytorch.org/docs/stable/nn.functional.html) (grouped matrix multiplication is internal in PyTorch, e.g. torch._int_mm)
  • scaled_grouped_mm - Batched grouped matmul with scaling
  • scaled_mm - Scaled matrix multiplication with advanced options
  • cat - Concatenate group outputs
  • stack - Stack group outputs
  • Tensor.matmul - Single matrix multiplication
Previous
group_norm
Next
GroupedMMFunctionalOptions