torch.nn.MultiheadAttention.load_state_dict
MultiheadAttention.load_state_dict(state_dict: Record<string, Tensor>): void
Loads a state dict from a PyTorch-compatible format.
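PyTorch's MHA packs the query, key, and value projections into a single combined in_proj_weight of shape [3*embed_dim, embed_dim] (when kdim and vdim equal embed_dim), instead of storing separate q_proj_weight, k_proj_weight, and v_proj_weight tensors.
This method converts PyTorch's combined format into this implementation's internal format of separate projection weights.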
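The splitting step can be sketched as follows. This is a minimal illustration, not this library's implementation: `SimpleTensor` and `splitInProjWeight` are hypothetical names, tensors are assumed to be flat row-major `Float32Array`s, and the row ordering (query rows first, then key, then value) follows PyTorch's packed layout.

```typescript
// Hypothetical minimal tensor: row-major data plus shape (not the library's Tensor class).
interface SimpleTensor {
  data: Float32Array;
  shape: number[];
}

// Split a packed [3*E, E] in_proj_weight into q, k, v weights of shape [E, E].
// Assumes PyTorch's ordering: rows [0, E) -> q, [E, 2E) -> k, [2E, 3E) -> v.
function splitInProjWeight(inProj: SimpleTensor): {
  q: SimpleTensor;
  k: SimpleTensor;
  v: SimpleTensor;
} {
  const [rows, cols] = inProj.shape;
  if (inProj.shape.length !== 2 || rows !== 3 * cols) {
    throw new Error(`expected shape [3*E, E], got [${inProj.shape.join(", ")}]`);
  }
  const e = cols;
  const block = e * e; // number of elements in one [E, E] chunk
  // Row-major storage makes each chunk a contiguous slice of the flat buffer.
  const chunk = (i: number): SimpleTensor => ({
    data: inProj.data.slice(i * block, (i + 1) * block),
    shape: [e, e],
  });
  return { q: chunk(0), k: chunk(1), v: chunk(2) };
}

// Usage: embed_dim = 2, so the packed weight has shape [6, 2] holding 0..11.
const E = 2;
const packed: SimpleTensor = {
  data: new Float32Array(3 * E * E).map((_, i) => i),
  shape: [3 * E, E],
};
const { q, k, v } = splitInProjWeight(packed);
console.log(q.data.join(",")); // 0,1,2,3
console.log(k.data.join(",")); // 4,5,6,7
console.log(v.data.join(",")); // 8,9,10,11
```

Because each projection occupies a contiguous block of rows, the split reduces to three slices of the flat buffer; no transposition is needed under the row-major assumption.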
Parameters
state_dict: Record<string, Tensor>