torch.nn.Module.zero_grad
Module.zero_grad(): void
Zero out all gradients in the module.
Sets every parameter's gradient to null, discarding any accumulated gradients. Call it before each backward pass in a training loop so gradients from earlier loss computations do not bleed into the new ones.
Essential for Training:
- Clears old gradients before computing new ones
- Called before each backward() pass
- Prevents gradient accumulation over multiple batches
- Necessary for correct gradient descent steps
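Why clearing matters can be illustrated without the library at all. The sketch below (plain TypeScript, no torch binding assumed; `accumulateGrad` and `zeroGrad` are hypothetical stand-ins) models a single parameter whose grad field accumulates on each backward-like call, the way autograd engines add gradients in place:

```typescript
// Minimal model of gradient accumulation, assuming gradients are
// summed in place when they are not cleared between passes.
interface Param { value: number; grad: number | null; }

// Stand-in for loss.backward(): adds the new gradient into param.grad.
function accumulateGrad(param: Param, newGrad: number): void {
  param.grad = (param.grad ?? 0) + newGrad;
}

// Stand-in for zero_grad(): clears the stored gradient.
function zeroGrad(param: Param): void {
  param.grad = null;
}

const p: Param = { value: 1.0, grad: null };

// Without clearing: two backward passes stack their gradients.
accumulateGrad(p, 0.5);
accumulateGrad(p, 0.5);
const accumulated = p.grad; // twice the per-batch gradient

// With clearing between passes: each pass starts fresh.
zeroGrad(p);
accumulateGrad(p, 0.5);
const fresh = p.grad; // just the latest gradient
```

An optimizer stepping on `accumulated` would apply double the intended update, which is the failure mode the "forgetting zero_grad()" note below describes.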
Common Pattern:
for (const batch of dataLoader) {
  optimizer.zero_grad();                                    // Clear old gradients
  const loss = criterion(model(batch.input), batch.target); // Forward pass
  loss.backward();                                          // Compute new gradients
  optimizer.step();                                         // Update parameters
}
- Critical: Almost always needed before backward() in training
- Recursive: Zeros gradients in all child modules automatically
- Idempotent: Calling it multiple times in a row is safe; repeated calls have no additional effect
- Parameters only: Only affects parameter gradients, not intermediate activations
- Forgetting zero_grad(): Causes gradients to accumulate, leading to incorrect updates
- Different from backward(): backward() computes gradients; zero_grad() clears them
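The "Recursive" note above can be sketched in plain TypeScript. This is a hypothetical illustration, not the library's implementation: a module clears its own parameters' gradients, then recurses into its children, so one call at the root covers the whole module tree:

```typescript
// Hypothetical sketch of a recursive zero_grad(): clear this module's
// own parameter grads, then recurse into all child modules.
interface ParamLike { grad: number[] | null; }

class ModuleLike {
  params: ParamLike[] = [];
  children: ModuleLike[] = [];

  zeroGrad(): void {
    for (const p of this.params) p.grad = null;          // parameters only
    for (const child of this.children) child.zeroGrad(); // recurse
  }
}

const child = new ModuleLike();
child.params.push({ grad: [1, 2] });

const root = new ModuleLike();
root.params.push({ grad: [3] });
root.children.push(child);

root.zeroGrad();
const rootCleared = root.params[0].grad;   // cleared
const childCleared = child.params[0].grad; // cleared via recursion
```

Note that only parameter gradients are touched; nothing here (or in the real method) walks intermediate activations.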
Returns
void
Examples
// Training loop pattern
for (const batch of dataLoader) {
  // Step 1: Clear old gradients
  model.zero_grad();

  // Step 2: Forward pass
  const output = model.forward(batch.input);
  const loss = criterion(output, batch.target);

  // Step 3: Backward pass
  loss.backward();

  // Step 4: Update parameters
  optimizer.step();
}
// Check that gradients are cleared
model.zero_grad();
for (const param of model.parameters()) {
  console.log(param.grad); // null
}
// Compare with accumulation when zero_grad() is omitted
loss1.backward();
loss2.backward(); // Gradients from both losses are summed, not replaced
// Call zero_grad() between the two backward passes to get fresh gradients
See Also
- PyTorch module.zero_grad()
- backward - Compute gradients of a loss with respect to parameters
- parameters - Get all parameters in the module
- named_parameters - Get parameters with their names for inspection