torch.autograd.grad_mode
function grad_mode<T>(mode: boolean, fn: () => T): T

Context manager for explicitly setting the gradient computation mode.
Temporarily enables or disables automatic gradient computation for a code block, then restores
the previous state. This is useful for conditionally controlling gradients based on runtime
conditions (training vs. evaluation, specific operations, etc.). Unlike no_grad(), which always
disables gradients, or enable_grad(), which always enables them, grad_mode() takes a boolean
that selects the desired state.
Use cases:
- Training/evaluation switching: Use a single flag to control gradient computation
- Conditional gradients: Enable gradients only for specific operations or layers
- Hybrid training: Mix gradient-enabled and gradient-disabled operations
- Nested mode control: Safely nest gradient contexts with proper restoration
- Dynamic control: Let runtime conditions determine if gradients are needed
Key differences from alternatives:
- grad_mode(bool, fn): Explicit boolean, useful for dynamic control
- no_grad(): Always disables, optimized for evaluation
- enable_grad(): Always enables, for re-enabling inside a no_grad context
- set_grad_enabled(bool): Globally sets the state (affects all subsequent code)
The function automatically saves and restores the previous gradient state, making it safe to nest or combine with other context managers.
- State restoration: Automatically restores previous mode, even if fn throws
- Not global: Only affects code within fn(), doesn't change global state outside
- Mode stacking: Safe to nest multiple grad_mode calls; innermost takes precedence
- Overrides parent mode: Explicitly sets mode, overriding any parent context
- try/finally safety: Exception-safe - always restores previous mode
- Lightweight: Minimal overhead compared to no_grad or enable_grad
- Read-only within function: The gradient mode set at the start of fn() persists throughout the function, even if set_grad_enabled() is called inside fn()
- Different from set_grad_enabled: grad_mode is temporary and local; set_grad_enabled is global and affects all subsequent code
- Performance trade-off: Enabling gradients has memory and computation overhead; use no_grad for evaluation to avoid unnecessary computation
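The save/restore behavior described above can be modeled with a short, self-contained sketch. The `gradEnabled` flag and this implementation are hypothetical, an illustration of the documented semantics rather than the library's actual internals:

```typescript
// Hypothetical model of the gradient-mode flag (illustrative only,
// not the library's real internals).
let gradEnabled = true;

// grad_mode: set an explicit mode for the duration of fn, then restore
// the previous mode -- even if fn throws (try/finally safety).
function grad_mode<T>(mode: boolean, fn: () => T): T {
  const previous = gradEnabled;
  gradEnabled = mode;
  try {
    return fn();
  } finally {
    gradEnabled = previous;
  }
}

// no_grad() and enable_grad() behave like the fixed-mode special cases.
const no_grad = <T>(fn: () => T): T => grad_mode(false, fn);
const enable_grad = <T>(fn: () => T): T => grad_mode(true, fn);

// Nesting: the innermost call takes precedence while it runs, and the
// outer mode comes back when it returns.
const modes = grad_mode(false, () => {
  const outer = gradEnabled;                        // false
  const inner = grad_mode(true, () => gradEnabled); // true
  return [outer, inner, gradEnabled];               // [false, true, false]
});

// Exception safety: the previous mode is restored even when fn throws.
try {
  grad_mode(false, () => { throw new Error("boom"); });
} catch {
  // gradEnabled is back to true here
}
```

Whether gradients are actually recorded in the real library depends on tensor-level flags as well (e.g. requires_grad), but the mode stacking and restoration follow this pattern.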
Parameters
mode: boolean - true to enable gradients, false to disable them:
- true: Gradients are computed for all operations in fn()
- false: Gradients are not computed (like no_grad)

fn: () => T - Function to execute within the specified gradient mode
Returns
T – The result of calling fn()

Examples
// Train/eval switching with single flag
const is_training = true;
const output = torch.grad_mode(is_training, () => {
return model.forward(input);
});
// Gradients computed during training, not during evaluation

// Conditional gradient computation
const x = torch.randn(5, { requires_grad: true });
let result;
const compute_gradients = model.training;
result = torch.grad_mode(compute_gradients, () => {
const y = x.square();
return y.sum();
});
if (compute_gradients) {
result.backward(); // Only works if gradients were computed
}

// Nested gradient contexts
const x = torch.randn(3, { requires_grad: true });
torch.grad_mode(false, () => {
// Gradients disabled
const y = x.square();
torch.grad_mode(true, () => {
// Gradients re-enabled
const z = y.sum();
z.backward(); // Works because re-enabled here
});
// Gradients disabled again
const w = x.sin();
});

// Selective gradient computation for different parts
const input = torch.randn(32, 10, { requires_grad: true });
// Compute embeddings without gradients
const embeddings = torch.grad_mode(false, () => {
return embedding_layer(input);
});
// Fine-tune with gradients
const output = torch.grad_mode(true, () => {
return classifier(embeddings);
});

// Using model.training property
class MyModel extends torch.nn.Module {
forward(x: torch.Tensor): torch.Tensor {
return torch.grad_mode(this.training, () => {
let y = this.encoder(x);
y = torch.relu(y);
return this.decoder(y);
});
}
}

See Also
- PyTorch torch.autograd.set_grad_enabled()
- no_grad - Context manager that always disables gradients (recommended for eval)
- enable_grad - Context manager that always enables gradients
- set_grad_enabled - Globally set gradient mode (affects all subsequent code)
- is_grad_enabled - Check current gradient mode
- torch.autograd - Autograd module