torch.nn.MultiheadAttention
class MultiheadAttention extends Module

new MultiheadAttention(options: MultiheadAttentionOptions)
embed_dim (number) - readonly
num_heads (number) - readonly
head_dim (number) - readonly
dropout (number) - readonly
batch_first (boolean) - readonly
kdim (number) - readonly
vdim (number) - readonly
add_zero_attn (boolean) - readonly
q_proj_weight (Parameter) - readonly
k_proj_weight (Parameter) - readonly
v_proj_weight (Parameter) - readonly
out_proj (Linear) - readonly
in_proj_bias (Parameter | null) - readonly
bias_k (Parameter | null) - readonly
bias_v (Parameter | null) - readonly
MultiheadAttention module for neural sequence processing.
Multi-head attention is the core mechanism that powers modern Transformers and enables models to selectively focus on different parts of the input sequence when producing each output element. Rather than using a single attention function, multi-head attention applies multiple "attention heads" in parallel, each learning to attend to different aspects of the input (e.g., one head might focus on syntactic structure while another focuses on semantic relationships). This allows the model to jointly attend to information from multiple representation subspaces simultaneously.
Multi-head attention is the fundamental building block of the Transformer architecture and has become the standard for sequence modeling tasks ranging from NLP (language models, machine translation) to vision (image classification with ViT) to multimodal applications. The parallel heads enable computational efficiency (each head processes a smaller embedding dimension) while maintaining expressive power through diversity of learned attention patterns.
Core idea: Instead of a single attention computation, split the embedding into multiple "heads", compute scaled dot-product attention for each head independently, then concatenate and project:
MultiHead(Q, K, V) = Concat(head_1, ..., head_h) W^O
where head_i = Attention(Q W^Q_i, K W^K_i, V W^V_i)
Each head learns to attend to different semantic or syntactic aspects of the input, enabling richer information flow and more expressive transformations.
Key mechanisms:
- Scaled dot-product attention: Core similarity function: Attention(Q,K,V) = softmax(QK^T/√d_k)V
- Parallel heads: num_heads independent attention computations, each with embed_dim/num_heads features
- Query/Key/Value projections: Learned linear transformations before attention computation
- Output projection: Learned linear transformation after concatenating all heads
- Masking: Support for causal masks (autoregressive), padding masks, and custom attention masks
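The scaled dot-product formula above can be sketched on plain number arrays (illustrative only; the module itself operates on tensors, and the function names here are hypothetical):

```typescript
// Minimal sketch of Attention(Q, K, V) = softmax(Q K^T / sqrt(d_k)) V
type Matrix = number[][];

function matmul(a: Matrix, b: Matrix): Matrix {
  return a.map(row =>
    b[0].map((_, j) => row.reduce((sum, v, k) => sum + v * b[k][j], 0))
  );
}

function transpose(m: Matrix): Matrix {
  return m[0].map((_, j) => m.map(row => row[j]));
}

function softmaxRows(m: Matrix): Matrix {
  return m.map(row => {
    const max = Math.max(...row);
    const exps = row.map(v => Math.exp(v - max)); // subtract max for stability
    const sum = exps.reduce((a, b) => a + b, 0);
    return exps.map(v => v / sum);
  });
}

function scaledDotProductAttention(q: Matrix, k: Matrix, v: Matrix): Matrix {
  const dK = k[0].length; // scaling factor uses the key dimension
  const scores = matmul(q, transpose(k)).map(row =>
    row.map(s => s / Math.sqrt(dK))
  );
  return matmul(softmaxRows(scores), v); // weighted sum of value rows
}
```

Each output row is a convex combination of the value rows, weighted by the softmaxed similarity between the query and each key.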
When to use MultiheadAttention:
- Sequence modeling: a replacement for RNNs/LSTMs on any sequential data (text, time series, audio)
- Transformer layers: Core component of encoder/decoder Transformer architectures
- Self-attention: When query, key, and value are the same (attends to positions in same sequence)
- Cross-attention: When query comes from one sequence and key/value from another (e.g., encoder-decoder)
- Vision Transformers: Image patches processed as sequences with spatial attention
- Multimodal: Attending across modalities (text attending to image regions, etc.)
Attention types:
- Self-attention: query = key = value (looks at same sequence), most common
- Cross-attention: query from one sequence, key/value from different sequence (encoder-decoder)
- Causal attention: Only attends to current and past positions (autoregressive generation)
- Masked attention: Key padding masks prevent attention to padding tokens
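A causal mask can be sketched as a boolean matrix where true marks a blocked position (this is a hand-rolled illustration; with this module you would normally pass is_causal instead, as in the examples below):

```typescript
// Causal mask for a sequence of length seqLen:
// query position i may attend only to key positions j <= i.
function causalMask(seqLen: number): boolean[][] {
  return Array.from({ length: seqLen }, (_, i) =>
    Array.from({ length: seqLen }, (_, j) => j > i) // true = blocked
  );
}
```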
Algorithm:
- Project Q, K, V with separate learned weight matrices
- Split embeddings into num_heads parts (each head gets embed_dim / num_heads features)
- For each head: compute scaled dot-product attention with scaling factor 1/√(embed_dim/num_heads)
- Apply optional masking (causal, padding, or custom attention mask)
- Apply softmax to get attention probabilities
- Apply dropout for regularization (during training only)
- Multiply attention weights by values to get weighted context
- Concatenate outputs from all heads
- Apply final output projection
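The head split and concatenation steps above can be sketched on plain arrays (an illustration of the reshaping only; splitHeads and concatHeads are hypothetical names, and the real module does this on tensors):

```typescript
type Matrix = number[][]; // [seq_len][embed_dim]

// Split the embedding columns into num_heads slices of embed_dim / num_heads each
function splitHeads(x: Matrix, numHeads: number): Matrix[] {
  const embedDim = x[0].length;
  if (embedDim % numHeads !== 0) {
    throw new Error("embed_dim must be divisible by num_heads");
  }
  const headDim = embedDim / numHeads;
  return Array.from({ length: numHeads }, (_, h) =>
    x.map(row => row.slice(h * headDim, (h + 1) * headDim))
  );
}

// Concatenate per-head outputs back into [seq_len][embed_dim]
function concatHeads(heads: Matrix[]): Matrix {
  return heads[0].map((_, t) => heads.flatMap(h => h[t]));
}
```

Splitting then concatenating is an exact round trip; the attention computation in between operates on each head's slice independently.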
Notes:
- Standard Transformer building block: MultiheadAttention + FFN are the core components.
- Embed_dim must be divisible by num_heads: embed_dim % num_heads must equal 0.
- Parallel computation: All heads compute independently, enabling GPU parallelization.
- Self-attention default: If key/value not specified, defaults to self-attention (query=key=value).
- Cross-attention: Set different kdim/vdim to attend to different input spaces.
- Causal masking: Essential for autoregressive models (language generation, next-token prediction).
- Padding masks: Important for efficient batch processing with variable-length sequences.
- Residual connections: always use with skip connections in practice: out = norm(x + attn(x)).
- Layer normalization: pre-norm (norm → attn) is generally more stable to train than post-norm (attn → norm).
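The pre-norm vs post-norm orderings can be sketched with plain functions standing in for LayerNorm and attention (an ordering illustration only; real modules operate on tensors):

```typescript
type Layer = (x: number[]) => number[];

// Pre-norm: x + sublayer(norm(x))
const preNorm = (norm: Layer, sublayer: Layer): Layer =>
  x => sublayer(norm(x)).map((v, i) => x[i] + v);

// Post-norm: norm(x + sublayer(x))
const postNorm = (norm: Layer, sublayer: Layer): Layer =>
  x => {
    const s = sublayer(x);
    return norm(x.map((v, i) => v + s[i]));
  };
```

In pre-norm the residual path carries the raw input unchanged, which is what makes deep stacks easier to train.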
Examples
// Transformer encoder block with self-attention
class TransformerEncoderBlock extends torch.nn.Module {
  private attn: torch.nn.MultiheadAttention;
  private norm1: torch.nn.LayerNorm;
  private norm2: torch.nn.LayerNorm;
  private fc: torch.nn.Sequential; // Feed-forward network

  constructor(embed_dim: number, num_heads: number) {
    super();
    this.attn = new torch.nn.MultiheadAttention({
      embed_dim: embed_dim,
      num_heads: num_heads,
      dropout: 0.1
    });
    this.norm1 = new torch.nn.LayerNorm([embed_dim]);
    this.norm2 = new torch.nn.LayerNorm([embed_dim]);
    // Feed-forward network: linear -> ReLU -> linear
    this.fc = new torch.nn.Sequential(
      new torch.nn.Linear(embed_dim, 4 * embed_dim),
      new torch.nn.ReLU(),
      new torch.nn.Linear(4 * embed_dim, embed_dim)
    );
  }

  forward(x: torch.Tensor): torch.Tensor {
    // x: [seq_len, batch_size, embed_dim]
    // Self-attention (key/value default to the query) with residual connection
    const attn_out = this.attn.forward(x);
    x = x.add(attn_out);       // Residual connection
    x = this.norm1.forward(x); // Layer normalization
    // Feed-forward with residual connection
    const ff_out = this.fc.forward(x);
    x = x.add(ff_out);         // Residual connection
    x = this.norm2.forward(x); // Layer normalization
    return x;
  }
}

// Cross-attention: decoder attends to encoder outputs
const batch_size = 32, seq_len = 50, tgt_len = 40, embed_dim = 512;
const attn = new torch.nn.MultiheadAttention({
  embed_dim: embed_dim,
  num_heads: 8,
  dropout: 0.1,
  kdim: embed_dim, // Key dimension (from encoder)
  vdim: embed_dim  // Value dimension (from encoder)
});
// Encoder outputs: key and value come from encoder
const encoder_output = torch.randn([seq_len, batch_size, embed_dim]);
// Decoder hidden state: query comes from decoder
const decoder_hidden = torch.randn([tgt_len, batch_size, embed_dim]);
// Cross-attention: decoder queries attend to encoder outputs
const [attn_output] = attn.multihead_attn(
  decoder_hidden,  // query (from decoder)
  encoder_output,  // key (from encoder)
  encoder_output,  // value (from encoder)
  { need_weights: true }
);
// attn_output: [tgt_len, batch_size, embed_dim]

// Causal attention for autoregressive generation
const batch_size = 16, seq_len = 100, embed_dim = 768, num_heads = 12;
const attn = new torch.nn.MultiheadAttention({
  embed_dim: embed_dim,
  num_heads: num_heads,
  dropout: 0.1,
  batch_first: true // Input shape: [batch, seq_len, embed_dim]
});
const x = torch.randn([batch_size, seq_len, embed_dim]);
// Causal mask: each position only attends to itself and previous positions
const [output, attn_weights] = attn.multihead_attn(x, x, x, {
  is_causal: true, // Only look at past and current positions
  need_weights: true
});
// output: [batch_size, seq_len, embed_dim]
// Ensures autoregressive property: position i depends only on positions <= i

// Padding mask for variable-length sequences
const batch_size = 8, max_seq_len = 50, embed_dim = 256;
const attn = new torch.nn.MultiheadAttention({
  embed_dim: embed_dim,
  num_heads: 8,
  batch_first: true
});
// Some sequences are shorter than max_seq_len, padded with zeros
const sequences = torch.randn([batch_size, max_seq_len, embed_dim]);
const actual_lengths = torch.tensor([50, 32, 48, 50, 25, 50, 40, 50]);
// Create padding mask: true for padded positions
const padding_mask = torch.arange(max_seq_len)
  .unsqueeze(0)
  .expand([batch_size, -1])
  .ge(actual_lengths.unsqueeze(1)); // [batch_size, max_seq_len]
const [output] = attn.multihead_attn(sequences, sequences, sequences, {
  key_padding_mask: padding_mask, // Prevent attention to padding
  need_weights: false
});
// Padded positions won't contribute to the attention output