torch.nn.Module.state_dict
Module.state_dict(options?: StateDictOptions): Record<string, Tensor>
Module.state_dict(destination: Record<string, Tensor>, prefix: string, recurse: boolean, options?: StateDictOptions): Record<string, Tensor>
Returns a dictionary containing all parameters and buffers of the module.
Creates a snapshot of all parameters and buffers in this module and all submodules, keyed by their hierarchical names. The state_dict is used for saving/loading models, checkpoint management, and model transfer. Includes all parameters (weights, biases) and buffers (running statistics, masks) but excludes gradients. Essential for:
- Model persistence: save to file for later loading
- Checkpoint management: save intermediate models during training
- Model transfer: load pretrained weights into new model
- Inference serving: bundle model state with application
- Ensemble methods: combine multiple trained models
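As a sketch of the ensemble use case: weight averaging combines several state dicts by averaging corresponding entries. The helper below is hypothetical (not part of this API) and uses plain number arrays as stand-ins for Tensors.

```javascript
// Hypothetical sketch: entry-wise average of several state dicts, as used in
// weight-averaging ensembles. Plain number arrays stand in for Tensors.
function averageStateDicts(dicts) {
  const result = {};
  for (const key of Object.keys(dicts[0])) {
    const length = dicts[0][key].length;
    result[key] = Array.from({ length }, (_, i) =>
      dicts.reduce((sum, d) => sum + d[key][i], 0) / dicts.length
    );
  }
  return result;
}

const avg = averageStateDicts([
  { "layer1.weight": [1, 2], "layer1.bias": [0] },
  { "layer1.weight": [3, 4], "layer1.bias": [2] },
]);
console.log(avg); // { "layer1.weight": [2, 3], "layer1.bias": [1] }
```

The hierarchical keys make this safe: entries are matched by name, so two models with the same architecture average cleanly.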
What's Included:
- All parameters in this module (weights, biases)
- All buffers in this module (running statistics, momentum)
- All parameters/buffers in child modules recursively (if recurse=true)
What's NOT Included:
- Gradients (use .grad separately if needed)
- Optimizer state (save optimizer separately with optimizer.state_dict())
- Training statistics (loss history, epoch count, etc.)
- Hyperparameters (learning rate, dropout rate, etc.)
Naming Convention: Parameters are named with hierarchical keys: "layer1.weight", "layer1.bias", "layer2.conv.weight". Names match the attribute structure in the module: self.layer1 = Linear(...) → keys start with "layer1.". This makes state_dicts human-readable and structure-aware.
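The key-building scheme above can be sketched in isolation: each submodule's entries are prefixed with its attribute name plus a dot, recursively. This helper is illustrative only (plain objects stand in for modules, arrays for Tensors).

```javascript
// Hypothetical sketch of hierarchical key construction: recurse into child
// "modules" (plain objects), prefixing each entry with the attribute path.
function flattenState(module, prefix = "") {
  const state = {};
  for (const [name, value] of Object.entries(module)) {
    if (value !== null && typeof value === "object" && !Array.isArray(value)) {
      // Child module: recurse with an extended prefix
      Object.assign(state, flattenState(value, prefix + name + "."));
    } else {
      // Leaf entry (parameter or buffer stand-in)
      state[prefix + name] = value;
    }
  }
  return state;
}

const model = { layer1: { weight: [1, 2], bias: [0] }, layer2: { conv: { weight: [3] } } };
console.log(Object.keys(flattenState(model)));
// ["layer1.weight", "layer1.bias", "layer2.conv.weight"]
```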
Common Usage:
// Save checkpoint
const checkpoint = {
epoch: epoch,
model_state: model.state_dict(),
optimizer_state: optimizer.state_dict()
};
// ... save checkpoint to file ...
// Load checkpoint
const checkpoint = load_checkpoint(path);
model.load_state_dict(checkpoint.model_state);
optimizer.load_state_dict(checkpoint.optimizer_state);
Notes:
- Includes buffers: Unlike parameters alone, includes running statistics
- Excludes gradients: Use .grad on parameters separately if needed
- Excludes optimizer state: Save optimizer.state_dict() separately
- Recursive by default: Automatically includes all submodule state
- Human-readable keys: Hierarchical names match module structure
- Modifiable: Returned dict is mutable; changes don't affect model
- Must load into same architecture: state_dict is architecture-dependent
- Does not include hyperparameters: Dropout rate, learning rate, etc.
- Shape must match: load_state_dict will error if shapes differ
- Version compatibility: state_dict from different versions may not load
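The "shape must match" rule above can be checked before loading. The helper below is a hypothetical pre-flight check, not part of this API; shapes are represented as plain number arrays (e.g. [5, 10] for a 5x10 weight matrix).

```javascript
// Hypothetical sketch: report keys whose shapes differ between a model's
// expected shapes and an incoming state dict, before load_state_dict errors.
function findShapeMismatches(modelShapes, stateShapes) {
  const mismatches = [];
  for (const [key, shape] of Object.entries(modelShapes)) {
    const other = stateShapes[key];
    const sameShape =
      other && other.length === shape.length && other.every((d, i) => d === shape[i]);
    if (!sameShape) mismatches.push(key); // missing key or differing shape
  }
  return mismatches;
}

console.log(findShapeMismatches(
  { "layer1.weight": [5, 10], "layer1.bias": [5] },
  { "layer1.weight": [5, 10], "layer1.bias": [4] }
));
// ["layer1.bias"]
```

Running a check like this on a downloaded checkpoint gives a clearer error message than the shape mismatch raised during loading.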
Parameters
options (StateDictOptions, optional) – Optional settings for state dict generation
Returns
Record<string, Tensor> – The populated state dict (the same object as destination, if one was provided)
Examples
// Basic save/load
const model = new MyModel();
// Training...
const state = model.state_dict();
// Later, load the saved state
const loaded_model = new MyModel();
loaded_model.load_state_dict(state);

// Inspect state dict structure
const model = new torch.nn.Sequential(
new torch.nn.Linear(10, 5),
new torch.nn.ReLU(),
new torch.nn.Linear(5, 2)
);
const state = model.state_dict();
console.log(Object.keys(state));
// Output: ["0.weight", "0.bias", "2.weight", "2.bias"]

// Transfer learning: use pretrained weights
const pretrained_model = load_pretrained_model();
const pretrained_state = pretrained_model.state_dict();
// Create new model with same architecture
const new_model = new MyModel();
// Load pretrained weights
new_model.load_state_dict(pretrained_state);
// Now new_model has pretrained weights

// Selective state dict with prefix
const model = new MyModel();
const full_state = model.state_dict();
// Or get only a submodule's state
const encoder_state = {};
model.encoder.state_dict(encoder_state, 'encoder.', true);
// Keys will be "encoder.layer1.weight", etc.
See Also
- PyTorch module.state_dict()
- torch.nn.Module.load_state_dict - Load saved state
- torch.nn.Module.parameters - Get all parameters
- torch.nn.Module.named_parameters - Get parameters with names
- torch.nn.Module.buffers - Get all buffers