torch.nn.ModuleDict
class ModuleDict extends Module

new ModuleDict(options?: ModuleDictOptions)
- readonly length: number – Get the number of modules in the dictionary.
Holds submodules in a dictionary (key → module mapping) with automatic registration.
ModuleDict stores named modules in a dictionary-like container. Unlike ModuleList which uses integer indices, ModuleDict uses string keys for named access. Essential for:
- Conditional layer selection by name (multi-task models, mixture of experts)
- Named module access (clearer than numeric indices)
- Complex architectures with different module types accessed by role
- Branches in model (main branch, attention branch, etc.)
- Easy module lookup and modification by semantic name
Key difference from ModuleList: Uses string keys instead of integer indices. Modules are registered automatically for proper parameter tracking, gradient computation, and device/mode propagation.
When to use ModuleDict:
- Named module access is clearer than numeric indices
- Conditional selection of modules (if condition, use module_a else module_b)
- Multi-branch architectures (different heads, tasks, or modalities)
- Mixture of Experts models
- Easy modification/inspection of named submodules
When NOT to use:
- Sequential layer stacks - use ModuleList instead
- Simple linear structure - use attributes (this.fc1, this.fc2, etc.)
- Named access: String keys enable semantic module naming and retrieval
- Parameter tracking: All modules are automatically tracked for parameters/buffers
- Device movement: .to(device) applies to all submodules
- Training mode: .train()/.eval() propagates to all submodules
- Iteration: Iterate with for...of or keys()/values()/entries()
- No forward method: Must implement forward() and manually call submodules
- Key semantics: Use meaningful names like 'head', 'backbone', 'expert_0', etc.
Examples
// Multi-task model: one shared backbone feeds several task-specific heads,
// selected at call time by task name via ModuleDict.
class MultiTaskModel extends torch.nn.Module {
backbone: torch.nn.Sequential;
heads: torch.nn.ModuleDict;
constructor() {
super();
// Shared feature extractor: 784 -> 256 -> 128.
this.backbone = new torch.nn.Sequential(
new torch.nn.Linear(784, 256),
new torch.nn.ReLU(),
new torch.nn.Linear(256, 128)
);
// One output head per task, keyed by semantic task name.
const taskHeads = {
classification: new torch.nn.Linear(128, 10),
regression: new torch.nn.Linear(128, 1),
segmentation: new torch.nn.Linear(128, 5)
};
this.heads = new torch.nn.ModuleDict(taskHeads);
}
/**
 * Run the shared backbone, then the head registered under `task`.
 * @param x input tensor of shape [batch, 784]
 * @param task key of the head to apply ('classification' | 'regression' | 'segmentation')
 * @throws Error when `task` names no registered head
 */
forward(x: torch.Tensor, task: string): torch.Tensor {
const features = this.backbone.forward(x);
const head = this.heads.get(task);
if (!head) throw new Error(`Unknown task: ${task}`);
return (head).forward(features);
}
}
const model = new MultiTaskModel();
// x is assumed to be an input tensor of shape [batch, 784] (backbone expects 784 features)
const clf_output = model.forward(x, 'classification'); // [batch, 10]
const reg_output = model.forward(x, 'regression'); // [batch, 1]

// Mixture of Experts model
// Mixture of Experts: a gating network produces per-expert weights and the
// final output is the gate-weighted sum of every expert's output.
class MixtureOfExperts extends torch.nn.Module {
experts: torch.nn.ModuleDict;
gating: torch.nn.Linear;
constructor(num_experts: number, hidden_dim: number, output_dim: number) {
super();
// Experts registered under semantic keys 'expert_0' ... 'expert_{n-1}'.
this.experts = new torch.nn.ModuleDict();
for (let i = 0; i < num_experts; i++) {
const expert = new torch.nn.Sequential(
new torch.nn.Linear(hidden_dim, hidden_dim),
new torch.nn.ReLU(),
new torch.nn.Linear(hidden_dim, output_dim)
);
this.experts.set(`expert_${i}`, expert);
}
// Gating network: one logit per expert.
this.gating = new torch.nn.Linear(hidden_dim, num_experts);
}
/**
 * Compute the gate-weighted combination of all expert outputs.
 * @param x input tensor of shape [batch, hidden_dim]
 * @returns tensor of shape [batch, output_dim]
 * @throws Error when the dictionary contains no experts
 */
forward(x: torch.Tensor): torch.Tensor {
const gates = torch.nn.functional.softmax(this.gating.forward(x), -1); // [batch, num_experts]
let output: torch.Tensor | null = null;
let i = 0;
for (const [name, expert] of this.experts) {
const expert_output = (expert).forward(x); // [batch, output_dim]
// Gate column for expert i as [batch, 1] so it broadcasts over output_dim.
// NOTE(review): assumes a PyTorch-style narrow(dim, start, length) — confirm against this binding's Tensor API.
const gate = gates.narrow(1, i, 1);
const weighted = expert_output.mul(gate);
output = output === null ? weighted : output.add(weighted);
i++;
}
if (output === null) throw new Error('MixtureOfExperts has no registered experts');
return output;
}
}