torch.optim.Adafactor
new Adafactor(params: Tensor[] | Iterable<Tensor>, options: AdafactorOptions = {})
Adafactor optimizer: Memory-efficient adaptive learning rates with factored second moments.
Adafactor addresses the memory bottleneck of Adam/AdamW for training large models by using factored second moment estimation. Instead of storing a full n×m matrix of second moments (n*m parameters), Adafactor stores only row and column factors (n+m scalars), reducing memory from O(n*m) to O(n+m).
Key Innovation: Standard adaptive methods such as Adam maintain a second moment estimate v_t for each parameter. For a matrix of shape (n, m), this requires n*m memory. Adafactor approximates v_t as the outer product of row and column factors: v_t ≈ r_t ⊗ c_t. This reduces memory by orders of magnitude while maintaining similar convergence.
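The factored estimate above can be sketched in a few lines. This is an illustrative reconstruction (the function name and shapes are ours, not the library's internals): the row factor is the row sums of the squared gradients, the column factor is the column sums, and the rank-1 reconstruction divides their outer product by the total sum.

```typescript
// Sketch of Adafactor's factored second-moment estimate.
// For a gradient matrix G of shape (n, m), Adam stores n*m second
// moments; Adafactor stores only a row factor r (n values) and a
// column factor c (m values), reconstructing v[i][j] ≈ r[i]*c[j]/sum(r).
function factoredSecondMoment(gradSq: number[][]): number[][] {
  const n = gradSq.length;
  const m = gradSq[0].length;
  // Row factor: sum of squared gradients across each row (n values)
  const r = gradSq.map(row => row.reduce((a, b) => a + b, 0));
  // Column factor: sum of squared gradients down each column (m values)
  const c = Array.from({ length: m }, (_, j) =>
    gradSq.reduce((a, row) => a + row[j], 0));
  // Normalizer: total sum over all entries
  const total = r.reduce((a, b) => a + b, 0);
  // Rank-1 reconstruction: v ≈ (r ⊗ c) / total
  return r.map(ri => c.map(cj => (ri * cj) / total));
}
```

When the squared-gradient matrix is exactly rank 1, this reconstruction is exact; otherwise it is the best rank-1 approximation under the divergence used in the Adafactor paper.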
Features:
- Factored second moments: O(n+m) instead of O(n*m) memory
- Automatic learning rate schedule: can auto-tune if relative_step=true
- Optional momentum: can add beta1 for first moment if desired
- Gradient clipping: prevents unstable large updates
- Works with or without explicit learning rate
Recommended for:
- Training transformer models with billions of parameters
- When memory is critical constraint (fitting in GPU memory)
- Large embedding tables and dense layers
- Fine-tuning pretrained large models
- Scenarios where Adam's optimizer state won't fit in memory but training is otherwise possible
Trade-offs:
- Slightly less accurate than Adam (factored approximation)
- Relative step schedule can be unstable in early training
- More hyperparameters to tune (decay_rate, clip_threshold)
- Not as widely adopted as Adam/AdamW
- Memory efficient: Reduces Adam memory by orders of magnitude (O(n*m)→O(n+m)).
- Transformers ideal: Perfect for large transformer models with billions of parameters.
- Auto learning rate: relative_step=true auto-computes the learning rate from parameter statistics.
- Gradient clipping: Built-in clipping prevents unstable updates from noisy gradients.
- Factored approximation: Not as accurate as Adam, but the memory savings are huge.
- Warmup recommended: Use warmup_init=true for more stable early training.
- Less established: Newer than Adam, some edge cases less well-understood.
- Momentum optional: Can add momentum for faster convergence if desired.
- No bias correction needed: Adaptive learning rate schedule replaces it.
- Used in practice: Proven effective in large language model training (T5, BART, etc.).
- Hyperparameter tuning: decay_rate and clip_threshold most important to tune.
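The relative-step schedule mentioned above can be sketched as follows. This is an illustrative reconstruction assumed to follow the Adafactor paper's rule (step size capped at min(1e-2, 1/√t), scaled by the parameter's RMS, with a linear ramp when warmup_init is enabled); the function name and defaults are ours, not the library's API.

```typescript
// Sketch of the relative-step learning-rate schedule used when
// relative_step=true (assumed from the Adafactor paper; not library code).
function relativeStepSize(
  step: number,        // update count, starting at 1
  paramRMS: number,    // root-mean-square of the parameter tensor
  warmupInit = false,  // ramp up linearly during early training
  eps = 1e-3           // floor on the parameter scale
): number {
  // Base schedule: decays like 1/sqrt(t), capped at 1e-2.
  // With warmup_init, early steps grow linearly (1e-6 * t) instead.
  const rho = warmupInit
    ? Math.min(1e-6 * step, 1 / Math.sqrt(step))
    : Math.min(1e-2, 1 / Math.sqrt(step));
  // Scale by parameter magnitude so the update is relative to the weights
  return rho * Math.max(eps, paramRMS);
}
```

This is why no bias correction is needed: the schedule itself starts small and adapts to the parameter scale.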
Examples
// Basic Adafactor with auto learning rate (typical usage)
const model = new TransformerModel();
const adafactor = new torch.optim.Adafactor(model.parameters());
// Adafactor auto-tunes learning rate based on parameter statistics
for (const batch of train_loader) {
const loss = model.loss(batch.x, batch.y);
adafactor.zero_grad();
loss.backward();
adafactor.step();
}

// Adafactor with explicit learning rate
const adafactor = new torch.optim.Adafactor(model.parameters(), {
lr: 1e-3,
relative_step: false // Use explicit lr instead of auto
});
// Disabling relative_step uses the provided lr, like Adam

// Adafactor with momentum for faster convergence
const adafactor = new torch.optim.Adafactor(model.parameters(), {
beta1: 0.9, // Enable momentum
weight_decay: 1e-5
});
// Adding momentum often speeds up convergence

// Adafactor for memory-constrained training
const adafactor = new torch.optim.Adafactor(model.parameters(), {
clip_threshold: 1.0,
decay_rate: -0.8,
warmup_init: true // Better early training stability
});
// Memory usage: O(n+m) instead of O(n*m) for Adam
// Fits models that Adam can't fit

// Adafactor configuration for fine-tuning
const adafactor = new torch.optim.Adafactor(model.parameters(), {
relative_step: true, // Auto learning rate
clip_threshold: 1.0,
weight_decay: 1e-5,
beta1: 0.9 // Momentum helps fine-tuning
});
// Good balance of memory efficiency and performance
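The clip_threshold option used in the examples above corresponds to Adafactor's built-in update clipping. A minimal sketch, assuming the paper's rule (divide the update by max(1, RMS(update)/clip_threshold) so its RMS never exceeds the threshold); the function name is illustrative, not the library's internal API.

```typescript
// Sketch of Adafactor's update clipping (assumed from the paper):
// rescale the update so its RMS never exceeds clipThreshold.
function clipUpdate(update: number[], clipThreshold = 1.0): number[] {
  const rms = Math.sqrt(
    update.reduce((a, u) => a + u * u, 0) / update.length);
  // Only ever shrink the update; scale is 1 when RMS is already small
  const scale = Math.max(1, rms / clipThreshold);
  return update.map(u => u / scale);
}
```

Unlike gradient-norm clipping, this acts on the final update, which is why it helps with noisy gradients even under the auto learning-rate schedule.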