Skip to main content
torch.js has not been released yet.
torch.js logotorch.js logotorch.js
PlaygroundContact
Login
Documentation
IntroductionType SafetyTensor ExpressionsTensor IndexingEinsumEinopsAutogradTraining a ModelProfiling & MemoryPyTorch MigrationBest PracticesRuntimesPerformancePyTorch CompatibilityBenchmarksDType Coverage
ActivationOptionsAdaptiveAvgPool1dAdaptiveAvgPool2dAdaptiveAvgPool3dAdaptiveLogSoftmaxOptionsAdaptiveLogSoftmaxWithLossAdaptiveMaxPool1dAdaptiveMaxPool1dOptionsAdaptiveMaxPool2dAdaptiveMaxPool2dOptionsAdaptiveMaxPool3dAdaptiveMaxPool3dOptionsadd_moduleAlphaDropoutappendappendapplyAvgPool1dAvgPool1dOptionsAvgPool2dAvgPool2dOptionsAvgPool3dAvgPool3dOptionsBackwardHookBackwardPreHookBatchNorm1dBatchNorm2dBatchNorm3dBatchNormOptionsBCELossBCEWithLogitsLossBilinearBilinearOptionsBufferBufferOptionsBufferRegistrationHookbufferscallCELUCELUOptionsChannelShufflechildrenCircularPad1dCircularPad2dCircularPad3dclearConstantPad1dConstantPad2dConstantPad3dConv1dConv2dConv3dConvOptionsConvTranspose1dConvTranspose2dConvTranspose3dConvTransposeOptionsCosineEmbeddingLossCosineEmbeddingLossOptionsCosineSimilarityCosineSimilarityOptionscreatecreateCrossEntropyLossCTCLossdecodedecodedeleteDropoutDropout1dDropout2dDropout3dDropoutOptionsELUELUOptionsEmbeddingEmbeddingBagEmbeddingBagForwardOptionsEmbeddingBagFromPretrainedOptionsEmbeddingBagOptionsEmbeddingFromPretrainedOptionsEmbeddingOptionsencodeencodeentriesentriesevalextendFeatureAlphaDropoutFlattenFlattenOptionsFoldFoldOptionsforwardforwardforwardforwardforwardforwardforwardforwardforwardforwardforwardforwardforward_with_targetForwardHookForwardPreHookFractionalMaxPool2dFractionalMaxPool3dFractionalMaxPoolOptionsfrom_pretrainedfrom_pretrainedGaussianNLLLossGELUGELUOptionsgenerate_square_subsequent_maskgetgetgetgetgetget_bufferget_parameterget_submoduleGLUGLUOptionsGroupNormGroupNormOptionsGRUGRUCellHardshrinkHardshrinkOptionsHardsigmoidHardswishHardtanhHardtanhOptionshashasHingeEmbeddingLossHingeEmbeddingLossOptionsHuberLossHuberLossOptionsIdentityInstanceNorm1dInstanceNorm2dInstanceNorm3dInstanceNormOptionsis_uninitialized_bufferis_uninitialized_parameteriterator]iterator]iterator]iterator]keyskeysKLDivLossL1LossL1LossOptionsLayerNormLayerNormOptionsLazyBatchNorm1dLazyBatchNorm2dLazyBatchNorm3dLazyConv1dLazyConv2dLazyConv3dLazyConvOptionsLazyConvTranspose1dLazyConvTranspose2dLazyConvTranspose3dLazyConvTransposeOptionsLazyInstanceNorm1dLazyInstanceNorm2dLazyInstanceNorm3dLazyLinearLeakyReLULeakyReLUOptionsLinearLinearOptionsload_state_dictload_state_dictLocalResponseNormLocalResponseNormOptionslog_probLogSigmoidLogSoftmaxLogSoftmaxOptionsLPPool1dLPPool1dOptionsLPPool2dLPPool2dOptionsLPPool3dLPPool3dOptionsLSTMLSTMCellLSTMCellOptionsMarginRankingLossMarginRankingLossOptionsmaterializematerializematerialize_uninitializedmaterialize_uninitializedMaxPool1dMaxPool1dOptionsMaxPool2dMaxPool2dOptionsMaxPool3dMaxPool3dOptionsMaxUnpool1dMaxUnpool1dOptionsMaxUnpool2dMaxUnpool2dOptionsMaxUnpool3dMaxUnpool3dOptionsMishModuleModuleBuffersModuleChildrenModuleDictModuleDictOptionsModuleListModuleListOptionsModuleParametersModuleRegistrationHookmodulesMSELossMSELossOptionsmultihead_attnMultiheadAttentionMultiheadAttentionOptionsMultiheadAttnOptionsMultiLabelMarginLossMultiLabelMarginLossOptionsMultiLabelSoftMarginLossMultiMarginLossnamed_buffersnamed_childrennamed_modulesnamed_parametersNamedModulesOptionsNamedRecurseOptionsNLLLossnum_parametersNumParametersOptionsPairwiseDistancePairwiseDistanceOptionsParameterParameterDictParameterDictOptionsParameterListParameterListOptionsParameterOptionsParameterRegistrationHookparametersPixelShufflePixelUnshufflePoissonNLLLosspoppopPReLUPReLUOptionsRecurseOptionsReflectionPad1dReflectionPad2dReflectionPad3dregister_backward_hookregister_bufferregister_forward_hookregister_forward_pre_hookregister_full_backward_hookregister_full_backward_pre_hookregister_module_backward_hookregister_module_buffer_registration_hookregister_module_forward_hookregister_module_forward_pre_hookregister_module_full_backward_hookregister_module_full_backward_pre_hookregister_module_module_registration_hookregister_module_parameter_registration_hookregister_parameterReLUReLU6RemovableHandleremoveReplicationPad1dReplicationPad2dReplicationPad3dRMSNormRMSNormOptionsRNNRNNBaseRNNBaseOptionsRNNCellRNNCellOptionsRReLURReLUOptionsrunrunSELUSequentialsetsetsetSigmoidSiLUSmoothL1LossSmoothL1LossOptionsSoftMarginLossSoftMarginLossOptionsSoftmaxSoftmax2dSoftmaxOptionsSoftminSoftminOptionsSoftplusSoftplusOptionsSoftshrinkSoftshrinkOptionsSoftsignstate_dictstate_dictStateDictOptionsstepSyncBatchNormTanhTanhshrinkThresholdThresholdOptionstotototrainTrainOptionsTransformerTransformerDecoderTransformerDecoderDecodeOptionsTransformerDecoderLayerTransformerDecoderLayerDecodeOptionsTransformerDecoderLayerOptionsTransformerDecoderOptionsTransformerEncoderTransformerEncoderEncodeOptionsTransformerEncoderLayerTransformerEncoderLayerEncodeOptionsTransformerEncoderLayerOptionsTransformerEncoderOptionsTransformerOptionsTransformerRunOptionsTripletMarginLossTripletMarginWithDistanceLossUnflattenUnfoldUnfoldOptionsUninitializedBufferUninitializedOptionsUninitializedParameterupdateUpsampleUpsamplingBilinear2dUpsamplingNearest2dvaluesvalueszero_gradZeroPad1dZeroPad2dZeroPad3d
absacosacoshAdaptivePool1dShapeAdaptivePool2dShapeaddaddbmmAddbmmOptionsaddcdivAddcdivOptionsaddcmulAddcmulOptionsaddmmAddmmOptionsaddmvAddmvOptionsaddrAddrOptionsadjointallallcloseAllcloseOptionsAlphaBetaOptionsamaxaminaminmaxAminmaxOptionsangleanyapplyOutarangeare_deterministic_algorithms_enabledargmaxargminargsortargwhereas_stridedas_tensorasinasinhAssertNoShapeErrorAssertNotErrorAsStridedOptionsAtat_error_index_out_of_boundsatanatan2atanhatleast_1datleast_2datleast_3dAtShapeautocast_decrement_nestingautocast_increment_nestingautograd_gradient_mismatch_errorautograd_not_registered_errorAutogradConfigAutogradDeviceAutogradDTypeAutogradEntryAutogradHandleAutogradHandleImplAxesRecordBackwardFnbaddbmmBaddbmmOptionsbartlett_windowBaseKernelConfigbatch_dimensions_do_not_match_errorbernoulliBernoulliOptionsBinaryBackwardFnBinaryBroadcastResultBinaryDTypeBinaryKernelConfigCPUBinaryKernelCPUBinaryOpConfigBinaryOpNamesBinaryOpSchemaBinaryOptionsbincountBincountOptionsbitwise_andbitwise_left_shiftbitwise_notbitwise_orbitwise_right_shiftbitwise_xorblackman_windowblock_diagbmmBooleanDTypeRulebroadcast_error_incompatible_dimensionsbroadcast_shapesbroadcast_tensorsbroadcast_toBroadcastShapeBroadcastShapeRulebroadcastShapesbucketizeBucketizeOptionsBufferUsagebuildEinopsErrorbuildErrorMessagecanBroadcastTocartesian_prodcatCatOptionsCatShapeCauchyOptionscdistCdistOptionsceilceluCeluFunctionalOptionschain_matmulCheckShapeErrorCholeskyShapechunkchunk_error_dim_out_of_rangeChunkOptionsclampClampOptionsclear_autocast_cacheclearEinopsCacheclearEinsumCacheclonecolumn_stackcombinationsCombinationsOptionscompiled_with_cxx11_abicomplexconjconj_physicalcontiguousConv1dShapeConv2dShapeConv3dShapeConvTranspose2dShapecopysigncorrcoefcoscoshcount_nonzeroCountNonzeroOptionscovcoverage_reportcoverageReportCoverageReportCovOptionsCPUForwardFnCPUKernelConfigCPUKernelEntryCPUOnlyResultCPUTensorDatacreateCumExtremeResultcreateTorchCreationOpSchemaCumExtremeResultcummaxcummincumprodCumShapecumsumcumulative_trapezoidCumulativeOptionsCumulativeOptionsWithDimdeg2raddetachDeterministicOptionsDetShapeDevicedevice_error_requiresDeviceBufferDeviceCapabilitiesDeviceCheckedResultDeviceConfigDeviceContextDeviceEntryDeviceHandleDeviceInputDeviceOptionsDeviceRegistryDeviceTypediagdiag_embedDiagEmbedOptionsdiagflatDiagflatOptionsDiagFlatOptionsdiagonal_scatterDiagonalOptionsDiagonalScatterOptionsDiagOptionsDiagShapediffDiffOptionsdigammadimension_error_out_of_rangeDispatchConfigdistDistOptionsdivdotDotShapeRuleDoubleDoubleDimdropoutDropoutFunctionalOptionsdsplitdstackDTypedtype_already_registered_errordtype_components_mismatch_errordtype_not_found_errorDTypeComponentsDTypeConfigDTypeCoverageReportDTypeDisplayConfigDTypeEntryDTypeHandleDTypeHandleImplDTypeInfoDTypeRegistryDTypeRuleDTypeSerializationConfigDynamicShapeEigShapeeinops_error_ambiguous_decompositioneinops_error_anonymous_in_outputeinops_error_dimension_mismatcheinops_error_invalid_patterneinops_error_reduce_undefined_outputeinops_error_repeat_missing_sizeeinops_error_undefined_axiseinsumeinsum_error_dimension_mismatcheinsum_error_index_out_of_rangeeinsum_error_invalid_equationeinsum_error_invalid_sublist_elementeinsum_error_operand_count_mismatcheinsum_error_subscript_rank_mismatcheinsum_error_unknown_output_indexEinsumOptionsEinsumOutputShapeEllipsiseluelu_EluFunctionalOptionsembedding_bag_error_requires_2d_inputemptyempty_cacheempty_likeeqequalerferfcerfinvexpexp2expandexpand_asexpand_error_incompatibleExpandShapeexpm1ExponentialOptionseyeEyeOptionsfftFFTOptionsfindKernelWithPredicatefindSimilarPatternsflattenFlattenOptionsFlattenShapeflipflip_error_dim_out_of_rangefliplrFlipShapeflipudfloat_powerFloatDTypeRulefloorfloor_dividefmaxfminfmodformatEquationErrorformatShapefracfrexpfrombufferfullfull_likefunction_already_registered_errorFunctionConfigFunctionEntryFunctionHandlegathergather_error_dim_out_of_rangeGatherShapegcdgegeluGeometricOptionsget_autocast_cpu_dtypeget_autocast_gpu_dtypeget_autocast_ipu_dtypeget_autocast_xla_dtypeget_default_deviceget_default_dtypeget_deterministic_debug_modeget_device_configget_device_contextget_device_moduleget_dtype_infoget_file_pathget_float32_matmul_precisionget_num_interop_threadsget_num_threadsget_op_infoget_printoptionsget_real_dtypeget_rng_stategetAutogradgetDTypegetEinopsCacheSizegetEinsumCacheSizegetFunctiongetKernelgetMethodgetOpInfoGetOpKindGetOpSchemagetScalarKernelgluGluFunctionalOptionsGradContextGradFnGradientsForgtHalfHalfDimhamming_windowhann_windowhardshrinkhardsigmoidhardswishhardtanhhardtanh_HardtanhFunctionalOptionshas_autogradhas_devicehas_dtypehas_kernelhasAutogradhasDTypehasFunctionhasKernelhasMethodhasScalarKernelHasShapeErrorheavisidehistcHistcOptionshistogramHistogramOptionsHistogramResulthsplithstackhypoti0IdentityShapeifftimagindex_addindex_copyindex_fillindex_putindex_reduceindex_selectindex_select_error_dim_out_of_rangeIndexPutOptionsIndexSelectShapeIndexSpecIndicesOptionsIndicesSpecinitialize_deviceInputsForInsertDiminvalid_config_errorinverseInverseShapeirfftis_anomaly_check_nan_enabledis_anomaly_enabledis_autocast_cache_enabledis_autocast_cpu_enabledis_autocast_ipu_enabledis_autocast_xla_enabledis_complexis_complex_dtypeis_cpu_only_modeis_deterministic_algorithms_warn_only_enabledis_floating_pointis_floating_point_dtypeis_inference_mode_enabledis_nonzerois_tensoris_warn_always_enabledis_webgpu_availableIs2DIsAtLeast1DIsBinaryOpIsBinaryOpNameiscloseIscloseOptionsisfiniteisinisinfisnanisneginfisposinfisrealIsReductionOpIsReductionOpNameIsRegistryErrorIsShapeErroristftISTFTOptionsIsUnaryOpIsUnaryOpNameitem_error_not_scalarItemResultkaiser_windowKaiserWindowOptionskernel_not_registered_errorkernel_signature_mismatch_errorKernelConfigKernelConfigWebGPUKernelEntryKernelHandleKernelInfoKernelPredicateKernelRegistryKernelWebGPUkronkthvalueKthvalueOptionslcmldexpleleaky_reluleaky_relu_LeakyReluFunctionalOptionslerplevenshteinDistancelgammalinalg_error_not_square_matrixlinalg_error_requires_2dlinalg_error_requires_at_least_2dlinearlinspacelist_custom_deviceslist_custom_dtypeslist_deviceslist_dtypeslist_functionslist_kernelslist_methodslist_opslistCustomDTypeslistDTypeslistFunctionslistKernelsListKernelsOptionslistMethodslistOpsListOpsOptionsloglog_softmaxlog10log1plog2logaddexplogaddexp2logcumsumexplogical_andlogical_notlogical_orlogical_xorLogitOptionsLogNormalOptionsLogOptionslogsigmoidlogspacelogsumexpLogsumexpOptionsltLUShapeLuSolveOptionsmasked_fillmasked_selectmasked_select_asyncMaskSpecmatmulmatmul_error_inner_dimensions_do_not_matchMatmul2DShapeMatmulShapeMatmulShapeRuleMatrixTransposeShapemaxmaximummeanmedianmemory_statsmemory_summarymeshgridmethod_already_registered_errormethod_dtype_not_supported_errorMethodConfigMethodEntryMethodHandleminminimummishmmMMShapeRulemodemovedimmsortmulmultinomialmultinomial_asyncMultinomialAsyncOptionsMultinomialOptionsMultiplyBymvMVShapeRulenan_to_numnanmeannanmediannanquantileNanReductionOptionsnansumNanToNumOptionsnarrownarrow_copynarrow_error_length_exceeds_boundsnarrow_error_start_out_of_boundsNarrowShapeneneedsBroadcastnegNegativeDimnextafternonzeroNonzeroOptionsnormnormalNormalOptionsNormOptionsnumelonesones_likeop_kind_mismatch_errorop_not_found_errorOpCoverageEntryOpInfoOpKindOpNameOpSchemaOpSchemasouterOuterShapepackPackShapepermutepermute_error_dimension_count_mismatchPermuteShapepoissonpolarPool1dShapePool2dShapePool3dShapepositivepowpreluPrintOptionsprodprofiler_allow_cudagraph_cupti_lazy_reinit_cuda12promote_typesPromoteDTypeRulePutOptionsquantileQuantileOptionsrad2degrandrand_likerandintrandint_likeRandintLikeOptionsRandintOptionsrandnrandn_likeRandomLikeOptionsRandomOptionsrandpermRangeSpecRankravelrealrearrangeRearrangeOptionsRearrangeShapereciprocalreduceReduceOperationReduceOptionsReduceShapeReductionKernelConfigCPUReductionKernelCPUReductionOpNamesReductionOpSchemaReductionOptionsReductionShapeRuleregister_backwardregister_deviceregister_dtyperegister_forwardregister_functionregister_methodregister_scalar_forwardregisterAutogradRegisterBackwardOptionsregisterBinaryOpregisterDTypeRegisterDTypeOptionsRegisteredDTyperegisterFunctionRegisterFunctionOptionsregisterKernelRegisterKernelOptionsregisterMethodRegisterMethodOptionsregisterScalarKernelregisterUnaryOpregistration_failed_errorrelurelu_relu6ReluFunctionalOptionsremainderRemoveDimrepeatrepeat_interleaveRepeatInterleaveOptionsRepeatOptionsRepeatShapeReplaceDimrequireWebGPUreset_peak_memory_statsreshapeReshapeShaperesult_typerfftrollRollOptionsrot90Rot90Optionsroundrrelurrelu_RreluFunctionalOptionsrsqrtSafeExpandShapeSameDTypeRuleSameShapeRuleSaveForBackwardScalarCPUForwardFnScalarCPUKernelConfigScalarKernelEntryScalarKernelHandleScalarWebGPUKernelConfigScaleDimscatterscatter_addscatter_add_scatter_error_dim_out_of_rangescatter_reducescatter_reduce_ScatterReduceOptionsScatterShapesearchsortedSearchSortedOptionsselectselect_error_index_out_of_boundsselect_scatterSelectShapeseluset_default_deviceset_default_tensor_typeset_deterministic_debug_modeset_float32_matmul_precisionset_printoptionsset_warn_alwaysSetupContextFnShapeShapeCheckedResultShapedTensorShapeErrorMessageShapeOpSchemaShapeRulesigmoidsignsignbitsilusinsincsinhSizeOptionsslice_error_out_of_boundsslice_scatterSliceOptionsSliceScatterOptionsSliceShapeSliceSpecsoftmaxsoftmax_error_dim_out_of_rangeSoftmaxShapesoftminSoftminFunctionalOptionssoftplusSoftplusFunctionalOptionssoftshrinksoftsignsortSortOptionssplitsplit_error_dim_out_of_rangeSplitOptionssqrtsquaresqueezeSqueezeOptionsSqueezeShapestackStackOptionsStackShapestdstd_meanStdVarMeanOptionsStdVarOptionsstftSTFTOptionsStrideOptionssubSublistSublistElementSubscriptIndexsumSVDShapeswapaxessym_floatsym_intsym_notttaketake_along_dimTakeAlongDimOptionstantanhtanhshrinktensortensor_splitTensorCreatorTensorDatatensordotTensordotOptionsTensorLikeTensorMetaTensorOptionsTensorStoragethresholdthreshold_tileTileShapeToOptionstopkTopkOptionsTorchtraceTraceShapetransposetranspose_dims_error_out_of_rangetranspose_error_requires_2d_tensorTransposeDimsShapeTransposeDimsShapeCheckedTransposeShapetrapezoidTrapezoidOptionsTriangularOptionstriltril_indicesTriOptionsTripletriutriu_indicestrue_dividetruncTupleOfLengthTypedArrayTypedArrayForTypedStorageTypeOptionsUnaryBackwardFnUnaryDTypeUnaryKernelConfigCPUUnaryKernelCPUUnaryOpConfigUnaryOpFnUnaryOpNamesUnaryOpParamsUnaryOpSchemaUnaryOptionsunbindunbind_error_dim_out_of_rangeUnbindOptionsunflattenUniformOptionsuniqueunique_consecutiveUniqueConsecutiveOptionsUniqueOptionsunpackUnpackShapeunravel_indexunregister_deviceunsqueezeUnsqueezeOptionsUnsqueezeShapeuse_deterministic_algorithmsValidateBatchedSquareMatrixValidateChunkDimValidatedEinsumShapevalidateDeviceValidateDeviceValidatedRearrangeShapeValidatedReduceShapeValidatedRepeatShapevalidateDTypeValidateEinsumValidateOperandCountValidateRanksValidateScalarValidateSplitDimValidateSquareMatrixValidateUnbindDimValueOptionsvar_var_meanvdotviewview_as_complexview_as_realvmapvsplitvstackWebGPUKernelConfigWebGPUOnlyResultWebGPUTensorDatawhereWindowOptionsxlogyzeroszeros_like
torch.js· 2026
LegalTerms of UsePrivacy Policy
/
/
  1. docs
  2. torch.js
  3. torch
  4. nn
  5. MultiheadAttention

torch.nn.MultiheadAttention

class MultiheadAttention extends Module
new MultiheadAttention(options: MultiheadAttentionOptions)
readonlyembed_dim(number)
readonlynum_heads(number)
readonlyhead_dim(number)
readonlydropout(number)
readonlybatch_first(boolean)
readonlykdim(number)
readonlyvdim(number)
readonlyadd_zero_attn(boolean)
readonlyq_proj_weight(Parameter)
readonlyk_proj_weight(Parameter)
readonlyv_proj_weight(Parameter)
readonlyout_proj(Linear)
readonlyin_proj_bias(Parameter | null)
readonlybias_k(Parameter | null)
readonlybias_v(Parameter | null)

MultiheadAttention module for neural sequence processing.

Multi-head attention is the core mechanism that powers modern Transformers and enables models to selectively focus on different parts of the input sequence when producing each output element. Rather than using a single attention function, multi-head attention applies multiple "attention heads" in parallel, each learning to attend to different aspects of the input (e.g., one head might focus on syntactic structure while another focuses on semantic relationships). This allows the model to jointly attend to information from multiple representation subspaces simultaneously.

Multi-head attention is the fundamental building block of the Transformer architecture and has become the standard for sequence modeling tasks ranging from NLP (language models, machine translation) to vision (image classification with ViT) to multimodal applications. The parallel heads enable computational efficiency (each head processes a smaller embedding dimension) while maintaining expressive power through diversity of learned attention patterns.

Core idea: Instead of a single attention computation, split the embedding into multiple "heads", compute scaled dot-product attention for each head independently, then concatenate and project: MultiHead(Q, K, V) = Concat(head_1, ..., head_h) W^O where head_i = Attention(Q W^Q_i, K W^K_i, V W^V_i)

Each head learns to attend to different semantic or syntactic aspects of the input, enabling richer information flow and more expressive transformations.

Key mechanisms:

  • Scaled dot-product attention: Core similarity function: Attention(Q,K,V) = softmax(QK^T/√d_k)V
  • Parallel heads: num_heads independent attention computations, each with embed_dim/num_heads features
  • Query/Key/Value projections: Learned linear transformations before attention computation
  • Output projection: Learned linear transformation after concatenating all heads
  • Masking: Support for causal masks (autoregressive), padding masks, and custom attention masks

When to use MultiheadAttention:

  • Sequence modeling: RNNs/LSTMs replacements for any sequential data (text, time series, audio)
  • Transformer layers: Core component of encoder/decoder Transformer architectures
  • Self-attention: When query, key, and value are the same (attends to positions in same sequence)
  • Cross-attention: When query comes from one sequence and key/value from another (e.g., encoder-decoder)
  • Vision Transformers: Image patches processed as sequences with spatial attention
  • Multimodal: Attending across modalities (text attending to image regions, etc.)

Attention types:

  • Self-attention: query = key = value (looks at same sequence), most common
  • Cross-attention: query from one sequence, key/value from different sequence (encoder-decoder)
  • Causal attention: Only attends to current and past positions (autoregressive generation)
  • Masked attention: Key padding masks prevent attention to padding tokens

Algorithm:

  1. Project Q, K, V with separate learned weight matrices
  2. Split embeddings into num_heads parts (each head gets embed_dim / num_heads features)
  3. For each head: compute scaled dot-product attention with scaling factor 1/√(embed_dim/num_heads)
  4. Apply optional masking (causal, padding, or custom attention mask)
  5. Apply softmax to get attention probabilities
  6. Apply dropout for regularization (during training only)
  7. Multiply attention weights by values to get weighted context
  8. Concatenate outputs from all heads
  9. Apply final output projection
MultiHead(Q,K,V)=Concat(head1,…,headh)WOheadi=Attention(QWiQ,KWiK,VWiV)Attention(Q,K,V)=softmax(QKTdk)Vdk=embed_dimnum_heads\begin{aligned} \text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, \ldots, \text{head}_h) W^O \\ \text{head}_i = \text{Attention}(Q W^Q_i, K W^K_i, V W^V_i) \\ \text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right) V \\ d_k = \frac{\text{embed\_dim}}{\text{num\_heads}} \end{aligned}MultiHead(Q,K,V)=Concat(head1​,…,headh​)WOheadi​=Attention(QWiQ​,KWiK​,VWiV​)Attention(Q,K,V)=softmax(dk​​QKT​)Vdk​=num_headsembed_dim​​
  • Standard Transformer building block: MultiheadAttention + FFN are the core components.
  • Embed_dim must be divisible by num_heads: embed_dim % num_heads must equal 0.
  • Parallel computation: All heads compute independently, enabling GPU parallelization.
  • Self-attention default: If key/value not specified, defaults to self-attention (query=key=value).
  • Cross-attention: Set different kdim/vdim to attend to different input spaces.
  • Causal masking: Essential for autoregressive models (language generation, next-token prediction).
  • Padding masks: Important for efficient batch processing with variable-length sequences.
  • Residual connections: Always use with skip connections: out = norm(x + attn(x)) in practice.
  • Layer normalization: Pre-norm (layer_norm - attn) is more stable than post-norm (attn - layer_norm).

Examples

// Transformer encoder block with self-attention
class TransformerEncoderBlock extends torch.nn.Module {
  private attn: torch.nn.MultiheadAttention;
  private norm1: torch.nn.LayerNorm;
  private norm2: torch.nn.LayerNorm;
  private fc: torch.nn.Sequential;  // Feed-forward network

  constructor(embed_dim: number, num_heads: number) {
    super();
    this.attn = new torch.nn.MultiheadAttention({
      embed_dim: embed_dim,
      num_heads: num_heads,
      dropout: 0.1
    });
    this.norm1 = new torch.nn.LayerNorm([embed_dim]);
    this.norm2 = new torch.nn.LayerNorm([embed_dim]);

    // Feed-forward network: linear -> ReLU -> linear
    this.fc = new torch.nn.Sequential(
      new torch.nn.Linear(embed_dim, 4 * embed_dim),
      new torch.nn.ReLU(),
      new torch.nn.Linear(4 * embed_dim, embed_dim)
    );
  }

  forward(x: torch.Tensor): torch.Tensor {
    // x: [seq_len, batch_size, embed_dim]

    // Self-attention with residual connection
    const attn_out = this.attn.forward(x);
    x = x.add(attn_out);  // Residual connection
    x = this.norm1.forward(x);  // Layer normalization

    // Feed-forward with residual connection
    const ff_out = this.fc.forward(x);
    x = x.add(ff_out);  // Residual connection
    x = this.norm2.forward(x);  // Layer normalization

    return x;
  }
}
// Cross-attention: decoder attends to encoder outputs
const batch_size = 32, seq_len = 50, tgt_len = 40, embed_dim = 512;

const attn = new torch.nn.MultiheadAttention({
  embed_dim: embed_dim,
  num_heads: 8,
  dropout: 0.1,
  kdim: embed_dim,  // Key dimension (from encoder)
  vdim: embed_dim   // Value dimension (from encoder)
});

// Encoder outputs: key and value come from encoder
const encoder_output = torch.randn([seq_len, batch_size, embed_dim]);

// Decoder hidden state: query comes from decoder
const decoder_hidden = torch.randn([tgt_len, batch_size, embed_dim]);

// Cross-attention: decoder queries attend to encoder outputs
const [attn_output] = attn.multihead_attn(
  decoder_hidden,  // query (from decoder)
  encoder_output,  // key (from encoder)
  encoder_output,  // value (from encoder)
  { need_weights: true }
);
// attn_output: [tgt_len, batch_size, embed_dim]
// Causal attention for autoregressive generation
const batch_size = 16, seq_len = 100, embed_dim = 768, num_heads = 12;

const attn = new torch.nn.MultiheadAttention({
  embed_dim: embed_dim,
  num_heads: num_heads,
  dropout: 0.1,
  batch_first: true  // Input shape: [batch, seq_len, embed_dim]
});

const x = torch.randn([batch_size, seq_len, embed_dim]);

// Causal mask: each position only attends to itself and previous positions
const [output, attn_weights] = attn.multihead_attn(x, x, x, {
  is_causal: true,    // Only look at past and current positions
  need_weights: true
});
// output: [batch_size, seq_len, embed_dim]
// Ensures autoregressive property: position i depends only on <i
// Padding mask for variable-length sequences
const batch_size = 8, max_seq_len = 50, embed_dim = 256;

const attn = new torch.nn.MultiheadAttention({
  embed_dim: embed_dim,
  num_heads: 8,
  batch_first: true
});

// Some sequences are shorter than max_seq_len, padded with zeros
const sequences = torch.randn([batch_size, max_seq_len, embed_dim]);
const actual_lengths = torch.tensor([50, 32, 48, 50, 25, 50, 40, 50]);

// Create padding mask: true for padded positions
const padding_mask = torch.arange(max_seq_len)
  .unsqueeze(0)
  .expand([batch_size, -1])
  .ge(actual_lengths.unsqueeze(1));  // [batch_size, max_seq_len]

const [output] = attn.multihead_attn(sequences, sequences, sequences, {
  key_padding_mask: padding_mask,  // Prevent attention to padding
  need_weights: false
});
// Padded positions won't contribute to attention output

See Also

  • PyTorch torch.nn.MultiheadAttention
Previous
MSELossOptions
Next
MultiheadAttention.forward