torch.nn.MultiheadAttention.forward
MultiheadAttention.forward<S extends Shape, D extends DType = DType, Dev extends DeviceType = DeviceType>(...inputs: Tensor<S, D, Dev>[]): Tensor<S, D, Dev>

Forward pass for multi-head attention, defaulting to self-attention.
This is a simplified interface that defaults to self-attention mode where query, key, and value come from the same input tensor. For more control over masking, dropout, and output options, use the multihead_attn method directly.
Input shape: Accepts a variable number of arguments to support both self-attention and cross-attention:
- Self-attention: forward(x), where x is [seq_len, batch_size, embed_dim] (or [batch_size, seq_len, embed_dim] if batch_first)
- Cross-attention: forward(query, key, value), where key and value may have different embedding dimensions (kdim, vdim)
Output shape: Same as query input shape (or transposed if batch_first)
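As a sketch of the expected shapes, here is the equivalent call against PyTorch's own torch.nn.MultiheadAttention (note that PyTorch's forward returns an (output, weights) pair, whereas the simplified interface documented here returns only the output tensor):

```python
import torch

embed_dim, num_heads = 16, 4
mha = torch.nn.MultiheadAttention(embed_dim, num_heads)  # batch_first=False by default

# Self-attention: query, key, and value are all the same tensor.
x = torch.randn(10, 2, embed_dim)   # [seq_len, batch_size, embed_dim]
out, _ = mha(x, x, x)
print(out.shape)                    # torch.Size([10, 2, 16])

# Cross-attention: key/value come from a sequence of a different length.
q = torch.randn(10, 2, embed_dim)   # target sequence, length L = 10
kv = torch.randn(7, 2, embed_dim)   # source sequence, length S = 7
out, _ = mha(q, kv, kv)
print(out.shape)                    # output shape matches the query: torch.Size([10, 2, 16])
```

In both cases the output has the query's shape; the key/value length S only affects the attention weights, not the output shape.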
- Default behavior: if key or value is not provided, the method defaults to self-attention, where all three come from the query input.
- Residual connections: in practice, always use with a skip connection: out = x + attn(x), or with layer norm: out = norm(x + attn(x))
- Batch first: if batch_first=true, input and output shapes have the batch dimension first. This method handles the transposition automatically.
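The residual-plus-norm pattern above can be sketched with PyTorch's equivalents (torch.nn.MultiheadAttention and torch.nn.LayerNorm; the dimensions are arbitrary illustration values):

```python
import torch

embed_dim, num_heads = 16, 4
mha = torch.nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
norm = torch.nn.LayerNorm(embed_dim)

x = torch.randn(2, 10, embed_dim)   # [batch, seq, embed] because batch_first=True

attn_out, _ = mha(x, x, x)          # self-attention
out = norm(x + attn_out)            # post-norm residual: out = norm(x + attn(x))
print(out.shape)                    # torch.Size([2, 10, 16]), same shape as x
```

Because the attention output has the same shape as the query input, the elementwise addition x + attn_out is always well-defined.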
Parameters
inputs: Tensor<S, D, Dev>[] - Variable number of tensors:
- inputs[0]: query tensor of shape (L, N, E), or (N, L, E) if batch_first
- inputs[1]: optional key tensor of shape (S, N, kdim), or (N, S, kdim) if batch_first
- inputs[2]: optional value tensor of shape (S, N, vdim), or (N, S, vdim) if batch_first
Returns
Tensor<S, D, Dev> – Attention output tensor with the same shape as the query input: the weighted sum of values, with attention weights computed over key positions. This method does not return attention weights; use multihead_attn() with need_weights=true to obtain them.

See Also
- multihead_attn - Full-featured method with masking and weight output options
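In PyTorch, the corresponding knob is the need_weights argument of forward; a sketch of retrieving the weights (by default they are averaged over heads, giving shape [batch, target_len, source_len]):

```python
import torch

mha = torch.nn.MultiheadAttention(16, 4, batch_first=True)
x = torch.randn(2, 10, 16)

# need_weights=True additionally returns the attention weights,
# averaged across heads: shape [batch, target_len, source_len].
out, weights = mha(x, x, x, need_weights=True)
print(weights.shape)                # torch.Size([2, 10, 10])

# Each row is a softmax over key positions, so it sums to 1.
row_sums = weights.sum(dim=-1)      # all entries ~1.0
```

Pass average_attn_weights=False (in PyTorch) to get per-head weights of shape [batch, num_heads, target_len, source_len] instead.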