torch.nn.MultiheadAttention.state_dict
MultiheadAttention.state_dict(options?: StateDictOptions): Record<string, Tensor>
MultiheadAttention.state_dict(destination: Record<string, Tensor>, prefix: string, recurse: boolean, options?: StateDictOptions): Record<string, Tensor>
Returns the state dict with PyTorch-compatible key names.
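PyTorch's nn.MultiheadAttention stores the query, key, and value projections as a single combined in_proj_weight of shape [3*embed_dim, embed_dim] (with a matching in_proj_bias of shape [3*embed_dim]) whenever kdim and vdim equal embed_dim, which is the default. It does not use separate q_proj_weight, k_proj_weight, and v_proj_weight parameters in that case.
This method converts our internal separate format to PyTorch's combined format by concatenating the query, key, and value weights along the first dimension, in that order.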
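A minimal sketch of that conversion, assuming each separate projection weight is a square [embed_dim, embed_dim] matrix and using plain nested number arrays instead of this library's Tensor type. packInProjWeight and packInProjBias are hypothetical helpers, not part of this API. They stack rows in query, key, value order, which is the order PyTorch expects when it splits in_proj_weight into three equal chunks.

```typescript
// Plain row-major 2-D matrix stand-in; this is NOT the library's Tensor type.
type Matrix = number[][];

/**
 * Hypothetical helper: packs separate Q, K, V projection weights, each
 * [embedDim, embedDim], into a PyTorch-style in_proj_weight of shape
 * [3 * embedDim, embedDim]. Rows 0..E-1 hold Q, E..2E-1 hold K, 2E..3E-1 hold V.
 */
function packInProjWeight(q: Matrix, k: Matrix, v: Matrix): Matrix {
  const embedDim = q.length;
  const parts: Array<[string, Matrix]> = [["q", q], ["k", k], ["v", v]];
  // Reject inputs that are not all [embedDim, embedDim].
  for (const [name, w] of parts) {
    if (w.length !== embedDim || w.some((row) => row.length !== embedDim)) {
      throw new Error(`${name} weight must be [${embedDim}, ${embedDim}]`);
    }
  }
  // Concatenate along dim 0 (rows), copying each row so the result is independent of the inputs.
  return [...q, ...k, ...v].map((row) => [...row]);
}

/** Hypothetical helper: three [embedDim] bias vectors become one [3 * embedDim] vector. */
function packInProjBias(qb: number[], kb: number[], vb: number[]): number[] {
  return [...qb, ...kb, ...vb];
}

// Usage: three 2x2 weights pack into a 6x2 matrix.
const packed = packInProjWeight([[1, 0], [0, 1]], [[2, 0], [0, 2]], [[3, 0], [0, 3]]);
console.log(packed.length, packed[0].length); // 6 2
```

Stacking whole blocks of rows, rather than interleaving them, is what lets the combined weight be recovered as three projections with a simple three-way split along the first dimension.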
Parameters
options: StateDictOptions (optional)
Returns
Record<string, Tensor>: the state dict, keyed by PyTorch-compatible parameter names.
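For reference, PyTorch's nn.MultiheadAttention built with its defaults (bias enabled, add_bias_kv disabled, kdim and vdim equal to embed_dim) has four state-dict entries: in_proj_weight, in_proj_bias, out_proj.weight, and out_proj.bias. The sketch below encodes that layout and reports keys whose shapes differ. expectedMhaShapes, sameShape, and mismatchedKeys are hypothetical helpers; shapes are plain number arrays rather than Tensor objects, and how you read a shape from this library's Tensor is left as an assumption.

```typescript
// Hypothetical helper: expected PyTorch state-dict keys and shapes for a default MHA module.
function expectedMhaShapes(embedDim: number): Record<string, number[]> {
  return {
    in_proj_weight: [3 * embedDim, embedDim], // packed Q, K, V weights
    in_proj_bias: [3 * embedDim],             // packed Q, K, V biases
    "out_proj.weight": [embedDim, embedDim],  // output projection weight
    "out_proj.bias": [embedDim],              // output projection bias
  };
}

// True when two shapes have the same rank and the same size in every dimension.
function sameShape(a: number[], b: number[]): boolean {
  return a.length === b.length && a.every((d, i) => d === b[i]);
}

// Hypothetical check: returns keys that are missing, unexpected, or differently shaped.
function mismatchedKeys(actual: Record<string, number[]>, embedDim: number): string[] {
  const expected = expectedMhaShapes(embedDim);
  const keys = new Set([...Object.keys(expected), ...Object.keys(actual)]);
  return Array.from(keys).filter((key) => {
    const a: number[] | undefined = actual[key];
    const e: number[] | undefined = expected[key];
    if (a === undefined || e === undefined) return true;
    return !sameShape(a, e);
  });
}

// Usage: a correctly shaped record for embed_dim = 8 has no mismatches.
const shapes = { in_proj_weight: [24, 8], in_proj_bias: [24], "out_proj.weight": [8, 8], "out_proj.bias": [8] };
console.log(mismatchedKeys(shapes, 8)); // []
```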