torch.autograd.set_grad_enabled
function set_grad_enabled(enabled: boolean): void
Set whether gradient computation is enabled globally.
Sets the global gradient state directly; the change persists until it is changed again. Unlike no_grad() and enable_grad(), which automatically restore the previous state when they exit, set_grad_enabled() leaves the new state in place. It is useful for high-level switching between training and inference modes, though no_grad()/enable_grad() are preferred for local control since they auto-restore.
Difference from no_grad/enable_grad:
- set_grad_enabled: direct state change, persists until changed again
- no_grad/enable_grad: context managers, auto-restore previous state
Prefer context managers (no_grad/enable_grad) for local control to avoid accidentally leaving gradients disabled. Use set_grad_enabled for global training/inference switches.
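The distinction can be sketched with a minimal in-file stand-in for the global flag (plain JavaScript, not the real torch internals; the names here model the behavior described above):

```javascript
// Minimal model of the global gradient flag (illustration only).
let gradEnabled = true;

// set_grad_enabled-style call: mutates the flag directly; nothing restores it.
function setGradEnabled(enabled) {
  gradEnabled = enabled;
}

// no_grad-style context manager: saves the flag, runs the callback, restores.
function noGrad(fn) {
  const prev = gradEnabled;
  gradEnabled = false;
  try {
    return fn();
  } finally {
    gradEnabled = prev; // auto-restore, even if fn throws
  }
}

setGradEnabled(false);
console.log(gradEnabled); // false — the change persists after the call

setGradEnabled(true);
noGrad(() => console.log(gradEnabled)); // false inside the callback
console.log(gradEnabled); // true — restored when the callback returns
```

The context manager composes safely because each scope restores whatever state it found, while the direct setter relies on the caller to remember the previous state.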
Effect: After calling this, torch.is_grad_enabled() returns the value you set, and all subsequent operations respect this setting until it's changed again or a context manager changes it.
Use Cases:
- Main training/inference mode switch
- Initialization of gradient state at application start
- Global configuration (not typical in production; prefer context managers)
Behavior:
- Persistent change: affects all subsequent operations until changed again
- No auto-restore: unlike context managers, you must restore the previous state manually
- Global effect: changes behavior for the entire application
- Queryable: use torch.is_grad_enabled() to check the current state
- Interacts with contexts: no_grad/enable_grad save and restore this state
Caveats:
- Avoid in production: prefer context managers (no_grad/enable_grad) for safety
- Easy to forget: accidentally leaving gradients disabled is a common bug
- Thread-unsafe: changing global state can affect concurrent code
- Not local: affects the whole application, unlike a scoped context manager
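When only the direct setter is available, an auto-restoring scope can be built on top of it. The sketch below uses hypothetical in-file stand-ins for torch.is_grad_enabled()/torch.set_grad_enabled() so it runs standalone; withGradEnabled is an illustrative helper, not part of the library:

```javascript
// Stand-ins for torch.is_grad_enabled / torch.set_grad_enabled (illustration only).
let _gradEnabled = true;
const is_grad_enabled = () => _gradEnabled;
const set_grad_enabled = (enabled) => { _gradEnabled = enabled; };

// Hypothetical helper: scope a state change and always restore the
// previous value, even when fn throws.
function withGradEnabled(enabled, fn) {
  const prev = is_grad_enabled();
  set_grad_enabled(enabled);
  try {
    return fn();
  } finally {
    set_grad_enabled(prev);
  }
}

withGradEnabled(false, () => {
  console.log(is_grad_enabled()); // false inside the scope
});
console.log(is_grad_enabled()); // true — restored on exit
```

The try/finally is what makes this safe: the previous state is restored on every exit path, which is exactly the guarantee the built-in context managers provide.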
Parameters
enabled: boolean - true to enable gradient tracking, false to disable
Examples
// Disable gradients globally (inference mode)
torch.set_grad_enabled(false);
const result = model.forward(input); // No gradients tracked
console.log(torch.is_grad_enabled()); // false
// Re-enable gradients globally (training mode)
torch.set_grad_enabled(true);
const result2 = model.forward(input); // Gradients tracked
console.log(torch.is_grad_enabled()); // true

// Initialize based on mode at application start
const isTraining = process.env.MODE === 'train';
torch.set_grad_enabled(isTraining);
// Now all code follows this setting
for (let epoch = 0; epoch < num_epochs; epoch++) {
  const loss = model.forward(batch);
  if (isTraining) {
    loss.backward();
    optimizer.step();
  }
}

// Typical pattern: use context managers instead
torch.set_grad_enabled(true); // Enable for training
// Use no_grad() for local inference regions
torch.no_grad(() => {
  const val_loss = model.forward(val_input); // No gradients
});
// Back to training
const train_loss = model.forward(train_input); // Gradients tracked

// Save and restore pattern (for special cases)
const prev_state = torch.is_grad_enabled();
torch.set_grad_enabled(false);
try {
  // Do something with gradients disabled
  expensive_computation();
} finally {
  // Always restore (but prefer context managers instead!)
  torch.set_grad_enabled(prev_state);
}

See Also
- PyTorch torch.set_grad_enabled()
- torch.is_grad_enabled - Check current gradient state
- torch.no_grad - Disable gradients with auto-restore (preferred)
- torch.enable_grad - Enable gradients with auto-restore (preferred)
- torch.grad_mode - Set mode with auto-restore