torch.nn.functional.multi_head_attention_forward
function multi_head_attention_forward(query: Tensor, key: Tensor, value: Tensor, embed_dim_to_check: number, num_heads: number, in_proj_weight: Tensor | null, in_proj_bias: Tensor | null, bias_k: Tensor | null, bias_v: Tensor | null, add_zero_attn: boolean, dropout_p: number, out_proj_weight: Tensor, out_proj_bias: Tensor | null, options?: MultiHeadAttentionFunctionalOptions): [Tensor, Tensor | null]

Multi-head attention forward pass.
This implements scaled dot-product attention with multiple heads:

Attention(Q, K, V) = softmax(Q * K^T / sqrt(d_k)) * V
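The formula above can be sketched in plain Python for a single head (a minimal illustration of the math only; the actual function operates on Tensors and adds projections, heads, and dropout):

```python
import math

def softmax(xs):
    # numerically stable softmax over a list of floats
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    s = sum(exps)
    return [e / s for e in exps]

def attention(Q, K, V):
    # Q: (L, d_k), K: (S, d_k), V: (S, d_v), as nested lists
    d_k = len(K[0])
    out = []
    for q in Q:
        # scaled dot product of this query row against every key row
        scores = [sum(qi * ki for qi, ki in zip(q, k)) / math.sqrt(d_k) for k in K]
        weights = softmax(scores)  # one weight per source position, sums to 1
        # output row = attention-weighted sum of value rows
        out.append([sum(w * v[j] for w, v in zip(weights, V))
                    for j in range(len(V[0]))])
    return out
```

Each output row is a convex combination of the value rows, so it always lies within the range spanned by V along each dimension.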
Parameters
query: Tensor - Query tensor of shape (L, N, E), where L is the target sequence length, N is the batch size, and E is the embedding dimension
key: Tensor - Key tensor of shape (S, N, E), where S is the source sequence length
value: Tensor - Value tensor of shape (S, N, E)
embed_dim_to_check: number - Expected embedding dimension E, checked against the inputs
num_heads: number - Number of attention heads; E must be divisible by num_heads
in_proj_weight: Tensor | null - Packed input projection weight for query, key, and value, of shape (3*E, E)
in_proj_bias: Tensor | null - Packed input projection bias for query, key, and value, of shape (3*E)
bias_k: Tensor | null - Optional bias appended to the key sequence
bias_v: Tensor | null - Optional bias appended to the value sequence
add_zero_attn: boolean - If true, append a row of zeros to the key and value sequences, allowing attention to an all-zero position
dropout_p: number - Dropout probability applied to the attention weights
out_proj_weight: Tensor - Output projection weight of shape (E, E)
out_proj_bias: Tensor | null - Output projection bias of shape (E)
options: MultiHeadAttentionFunctionalOptions (optional) - Additional settings (e.g. attention masks, training mode)

Returns
[Tensor, Tensor | null] - The attention output of shape (L, N, E), and the attention weights, or null when weights are not requested
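The parameter shapes above fit together as follows; this is a plain-Python sketch (the helper name `mha_shapes` is invented for illustration, and the per-head layout reflects how PyTorch reshapes internally):

```python
def mha_shapes(L, S, N, E, num_heads):
    """Return the tensor shapes multi_head_attention_forward expects,
    given target length L, source length S, batch size N, embed dim E."""
    if E % num_heads != 0:
        raise ValueError("embedding dimension E must be divisible by num_heads")
    head_dim = E // num_heads
    return {
        "query": (L, N, E),
        "key": (S, N, E),
        "value": (S, N, E),
        # one packed weight projects q, k, and v in a single matmul
        "in_proj_weight": (3 * E, E),
        "in_proj_bias": (3 * E,),
        "out_proj_weight": (E, E),
        "out_proj_bias": (E,),
        # each head attends over a head_dim-sized slice of the embedding
        "per_head_query": (N * num_heads, L, head_dim),
    }

# Example: L=7 target positions, S=9 source positions, batch 2,
# embedding 16 split across 4 heads of size 4 each
shapes = mha_shapes(L=7, S=9, N=2, E=16, num_heads=4)
```

Here `shapes["in_proj_weight"]` is `(48, 16)`: the three (E, E) blocks for query, key, and value are stacked along the first axis.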