torch.autograd.is_grad_enabled
function is_grad_enabled(): boolean
Check if gradient computation is currently enabled globally.
Returns whether automatic differentiation is active. When true, operations on tensors with requires_grad=true will build the autograd graph. When false (inside no_grad()), all operations skip gradient computation regardless of tensor settings. Used primarily for debugging and conditional logic based on training/inference mode.
Global State: This returns the global gradient state set by set_grad_enabled(), no_grad(), enable_grad(), etc. The state applies to all new operations globally; individual tensors can still opt out via requires_grad=false, but a tensor cannot opt back in while gradients are globally disabled.
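The way the global flag combines with per-tensor settings can be sketched with a minimal stand-in. This is an illustrative model only, not the library's implementation; `gradEnabled`, `tracksGrad`, and the plain-object tensors are hypothetical:

```javascript
// Hypothetical model: an operation records itself in the autograd graph
// only when BOTH the global flag and the tensor's requires_grad are true.
let gradEnabled = true; // stand-in for the global state

function tracksGrad(tensor) {
  return gradEnabled && tensor.requires_grad;
}

const a = { requires_grad: true };
const b = { requires_grad: false };

console.log(tracksGrad(a)); // true: global on, tensor opted in
console.log(tracksGrad(b)); // false: tensor opt-out overrides global on

gradEnabled = false; // as inside no_grad()
console.log(tracksGrad(a)); // false: global off wins despite requires_grad=true
```

Note the asymmetry: requires_grad=false always disables tracking, but requires_grad=true is not enough on its own when the global state is off.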
Use Cases:
- Conditional model behavior: different code paths for training vs inference
- Debugging: verify gradient state when unexpected behavior occurs
- Logging: record whether gradients were enabled during a computation
- Assertions: ensure code is running in expected mode (training or inference)
Notes:
- Global state: returns the current setting that applies to all new operations
- Individual override: tensors can still opt out of gradients via requires_grad=false
- Affected by contexts: changed by no_grad(), enable_grad(), and set_grad_enabled()
- Inspection only: this function never modifies state, it only queries it
- Per-operation behavior: an operation tracks gradients only if both this flag and the input tensor's requires_grad are true
- Not sufficient alone: requires_grad=false on a tensor overrides an enabled global state
- Debugging aid: prefer explicit contexts (no_grad(), etc.) over branching on this flag for production gradient control
- Check-then-act hazard: if other code changes the global state between the query and subsequent operations, results may be unexpected
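The context-restoration behavior the notes describe (each context saves the prior state and restores it on exit, even on error) can be sketched with a simplified stand-in. The names mirror the documented API, but `set_grad_enabled_during` and the implementations below are hypothetical, not the library's code:

```javascript
// Minimal sketch of callback-style gradient contexts with state restoration.
let gradEnabled = true;

function is_grad_enabled() {
  return gradEnabled;
}

function set_grad_enabled_during(mode, fn) {
  const prev = gradEnabled; // save the current global state
  gradEnabled = mode;
  try {
    return fn(); // run the callback under the new state
  } finally {
    gradEnabled = prev; // always restore, even if fn throws
  }
}

const no_grad = (fn) => set_grad_enabled_during(false, fn);
const enable_grad = (fn) => set_grad_enabled_during(true, fn);

no_grad(() => {
  enable_grad(() => {
    console.log(is_grad_enabled()); // true: the innermost context wins
  });
  console.log(is_grad_enabled()); // false: restored to the no_grad state
});
console.log(is_grad_enabled()); // true: original state restored
```

The try/finally is what makes nesting safe: an exception inside any callback still restores the surrounding state.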
Returns
boolean – true if gradients are currently being tracked, false if disabled (no_grad active)
Examples
// Check gradient status
console.log(torch.is_grad_enabled()); // true (default)
torch.no_grad(() => {
  console.log(torch.is_grad_enabled()); // false
});
console.log(torch.is_grad_enabled()); // true (restored)

// Conditional forward pass based on gradient state
function forward(model, input) {
  if (torch.is_grad_enabled()) {
    // Training: use full precision, build graph
    return model.forward_full_precision(input);
  } else {
    // Inference: use quantized, faster version
    return model.forward_quantized(input);
  }
}

// Debug: verify gradient state after complex nesting
torch.no_grad(() => {
  torch.enable_grad(() => {
    torch.no_grad(() => {
      if (torch.is_grad_enabled()) {
        console.log('Gradients enabled');
      } else {
        console.log('Gradients disabled'); // This prints
      }
    });
  });
});

// Assertion for training code
function training_step(model, batch) {
  // Ensure we're not accidentally running in no_grad()
  console.assert(torch.is_grad_enabled(), 'Gradients must be enabled for training!');
  const pred = model.forward(batch.input);
  const loss = criterion(pred, batch.target);
  loss.backward(); // Safe because gradients are enabled
}

See Also
- PyTorch torch.is_grad_enabled()
- torch.set_grad_enabled - Change the global gradient state
- torch.no_grad - Temporarily disable gradients
- torch.enable_grad - Re-enable gradients inside no_grad
- torch.Tensor.requires_grad - Per-tensor gradient control