torch.nn.MultiheadAttention.multihead_attn
MultiheadAttention.multihead_attn(query: Tensor, options?: MultiheadAttnOptions): [Tensor, Tensor | null]
MultiheadAttention.multihead_attn(query: Tensor, key: Tensor, value: Tensor, key_padding_mask: Tensor | undefined, need_weights: boolean, attn_mask: Tensor | undefined, average_attn_weights: boolean, is_causal: boolean, options?: MultiheadAttnOptions): [Tensor, Tensor | null]
Full multi-head attention with comprehensive control over masking and outputs.
This method provides complete control over attention computation including support for causal masking (autoregressive generation), padding masks (variable-length sequences), custom attention masks, and returning attention weights for interpretability. It's the main entry point for all attention patterns: self-attention, cross-attention (encoder-decoder), and masked attention for generation.
Computation flow:
- Project query, key, and value tensors using learned weight matrices
- Reshape embeddings into num_heads parallel attention heads (separate subspaces)
- Compute scaled scores for each head: (QK^T)/√d_k
- Apply masks (causal, padding, or custom) to the scores before softmax, controlling which positions can attend
- Apply softmax to obtain attention weights, then dropout for regularization (training mode only)
- Multiply the weights by V, concatenate the outputs of all heads, and apply the output projection
- Optionally return attention weights (averaged across heads or per-head)
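The per-head computation above can be sketched in plain TypeScript, independent of this library's Tensor API. This is a minimal single-head, single-example illustration; all helper functions here are hypothetical, not part of the module:

```typescript
// Minimal sketch of one head's scaled dot-product attention on plain
// number[][] matrices (illustrative only; the real module operates on Tensors).

function matmul(a: number[][], b: number[][]): number[][] {
  return a.map(row =>
    b[0].map((_, j) => row.reduce((sum, v, k) => sum + v * b[k][j], 0))
  );
}

function transpose(m: number[][]): number[][] {
  return m[0].map((_, j) => m.map(row => row[j]));
}

function softmaxRows(m: number[][]): number[][] {
  return m.map(row => {
    const max = Math.max(...row); // subtract row max for numerical stability
    const exps = row.map(v => Math.exp(v - max));
    const sum = exps.reduce((a, b) => a + b, 0);
    return exps.map(e => e / sum);
  });
}

// softmax((Q K^T) / sqrt(d_k)) V for a single head
function scaledDotProductAttention(
  Q: number[][], K: number[][], V: number[][]
): { output: number[][]; weights: number[][] } {
  const dK = K[0].length;
  const scores = matmul(Q, transpose(K)).map(row =>
    row.map(v => v / Math.sqrt(dK))
  );
  const weights = softmaxRows(scores); // each row sums to 1
  return { output: matmul(weights, V), weights };
}

// Tiny example: L = 2 query positions, S = 2 keys, d_k = 2
const Q = [[1, 0], [0, 1]];
const K = [[1, 0], [0, 1]];
const V = [[1, 2], [3, 4]];
const { output, weights } = scaledDotProductAttention(Q, K, V);
```

In the full module, this runs once per head on slices of size embed_dim/num_heads, and the per-head outputs are concatenated before the output projection.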
Input shapes: (seq_len, batch, embed_dim) or (batch, seq_len, embed_dim) if batch_first
- L: target sequence length (query length)
- S: source sequence length (key/value length, can differ from L for cross-attention)
- N: batch size
- E: embedding dimension (embed_dim)
- d_k: head dimension = embed_dim / num_heads
Output shapes: same layout as the query input, i.e. (L, N, E), or (N, L, E) if batch_first
- Shape requirements: query, key, and value must share the same batch dimension, and key and value must share the same sequence length S. Query length (L) determines the output length; key/value length (S) determines which positions can be attended to.
- Causal vs attn_mask: is_causal=true and attn_mask cannot be used together. Use is_causal for autoregressive models, attn_mask for custom patterns.
- Mask semantics: boolean masks (key_padding_mask) use True to mark masked positions; additive masks (attn_mask) use large negative values to suppress attention. When both are supplied they are combined additively.
- Dropout during eval: dropout is disabled in evaluation mode (self.training=false). Always call model.eval() before inference so attention dropout is disabled and outputs are deterministic.
- Head dimension: embed_dim must be divisible by num_heads; each head processes embed_dim/num_heads dimensions. This is validated in the constructor, but is useful to know when debugging shape errors.
- Attention weights interpretation: High values indicate strong attention. weights[b, q, k] shows how much position q attends to position k. Averaged across heads gives overall pattern.
- Numerical stability: scaled dot-product attention divides scores by √d_k to prevent softmax saturation. Masking is applied before the softmax (masked scores are set to -inf) rather than by zeroing weights afterwards, which keeps the remaining weights normalized.
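As a concrete illustration of the last point, a small sketch in plain TypeScript (the helper below is hypothetical, not part of this library) shows that setting a masked score to -Infinity before softmax gives that position exactly zero weight while the remaining weights stay normalized:

```typescript
// Why masks are applied before softmax: a score of -Infinity makes
// exp(score) exactly 0, so the masked position gets zero weight and the
// remaining weights still sum to 1. (Illustrative only; the module does
// this internally on Tensors.)

function maskedSoftmax(scores: number[], masked: boolean[]): number[] {
  const s = scores.map((v, i) => (masked[i] ? -Infinity : v));
  const max = Math.max(...s.filter(v => v !== -Infinity)); // stability shift
  const exps = s.map(v => (v === -Infinity ? 0 : Math.exp(v - max)));
  const sum = exps.reduce((a, b) => a + b, 0);
  return exps.map(e => e / sum);
}

// Causal mask for query position 1 of 3: position 2 is the "future"
const scores = [0.9, 0.2, 1.5];
const causalMask = [false, false, true]; // true = masked (future position)
const w = maskedSoftmax(scores, causalMask);
// w[2] is exactly 0; w[0] + w[1] sum to 1 (up to floating point)
```

Zeroing weights after a plain softmax would instead leave the surviving weights summing to less than 1 and would require renormalization.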
Parameters
query: Tensor - Query tensor of shape (L, N, E) or (N, L, E) if batch_first. Represents the positions doing the attending; its length L determines the output length. For self-attention, pass the same tensor as key and value.
key: Tensor - Key tensor of shape (S, N, E) or (N, S, E) if batch_first. Determines which positions can be attended to; S may differ from L for cross-attention.
value: Tensor - Value tensor with the same shape as key; provides the content aggregated by the attention weights.
options?: MultiheadAttnOptions - Optional dictionary controlling attention behavior.
Masking options:
- attn_mask - Custom attention mask of shape (L, S) or (N*num_heads, L, S). True/1.0 positions are masked (set to -inf before softmax). Used for custom patterns beyond causal/padding.
- key_padding_mask - Boolean mask of shape (N, S) where True marks padded positions in key. Prevents attention to padding tokens; essential for batches with variable-length sequences.
- is_causal - If true, applies a causal mask so that position i only attends to positions [0...i]. Essential for autoregressive generation (language models, next-token prediction). Cannot be combined with attn_mask.
Output options:
- need_weights - If true (default), returns attention weights as the second tuple element. Weights are useful for visualization and interpretability but add computation cost.
- average_attn_weights - If true (default), averages weights across all heads, producing a single attention matrix of shape (N, L, S). If false, returns per-head weights of shape (N, num_heads, L, S), useful for per-head analysis.
Returns
[Tensor, Tensor | null] - Tuple of (attention_output, attention_weights):
- attention_output: Tensor of shape (L, N, E) or (N, L, E) if batch_first. Contextual representation in which each position has attended to relevant positions, weighted by attention scores.
- attention_weights: Tensor of shape (N, L, S) if average_attn_weights=true, or (N, num_heads, L, S) if average_attn_weights=false. Null if need_weights=false. Values in [0, 1] indicating how much each position attended to each key position.
Examples
// Causal self-attention for next-token prediction
const attn = new torch.nn.MultiheadAttention({
embed_dim: 512,
num_heads: 8,
dropout: 0.1,
batch_first: true
});
const x = torch.randn([batch_size, seq_len, 512]);
const [output, weights] = attn.multihead_attn(x, x, x, {
is_causal: true, // Only attend to past positions
need_weights: true
});
// weights[i, j, k] = attention from position j to position k
// For causal: weights[i, j, k] = 0 for k > j (future positions masked)

// Cross-attention with encoder-decoder
const attn = new torch.nn.MultiheadAttention({
embed_dim: 768,
num_heads: 12,
kdim: 768,
vdim: 768,
batch_first: true
});
const encoder_output = torch.randn([batch_size, src_len, 768]); // Encoder outputs
const decoder_hidden = torch.randn([batch_size, tgt_len, 768]); // Decoder hidden state
// Decoder attends to encoder outputs
const [context, weights] = attn.multihead_attn(
decoder_hidden, // query: decoder states doing the attending
encoder_output, // key/value: encoder outputs
encoder_output,
{ need_weights: false }
);
// context: decoder representations augmented with encoder information

// Variable-length batch with padding mask
const attn = new torch.nn.MultiheadAttention({
embed_dim: 256,
num_heads: 4,
batch_first: true
});
const max_len = 100;
const batch = torch.randn([32, max_len, 256]);
const actual_lens = torch.tensor([80, 95, 100, 75, ...]); // 32 values
// Create padding mask
const positions = torch.arange(max_len).unsqueeze(0); // shape [1, max_len]
const padding_mask = positions.ge(actual_lens.unsqueeze(1)); // [32, max_len], true = padded
const [output, weights] = attn.multihead_attn(
batch, batch, batch,
{ key_padding_mask: padding_mask }
);
// Padded key positions receive zero attention weight from every query
See Also
- forward - Simplified self-attention interface without options
- torch.nn.functional.scaled_dot_product_attention - Core attention computation
- torch.nn.TransformerEncoderLayer - Complete layer with attention + feedforward
- torch.nn.TransformerDecoderLayer - Decoder layer with self+cross attention