torch.autograd.enable_grad
function enable_grad<T>(fn: () => T): T

Context manager that enables gradient computation.
Re-enables gradient tracking inside a no_grad() block. This is essential when you need to compute gradients for specific operations within code that otherwise disables them, and it gives fine-grained control over which parts of a computation build the autograd graph. It is the complement of no_grad(): it turns gradients back on wherever they have been disabled. Useful for:
- Mixed training/inference: parts of forward pass don't need gradients, parts do
- Conditional gradients: enable gradients based on runtime conditions
- Gradient checkpointing: temporarily enable gradients to save and recompute
- Multi-phase computations: some phases train, others evaluate
- Custom loss computation: use gradients for part of loss, not all
- Advanced training: selective gradient computation for efficiency
Nesting Behavior: enable_grad saves the current gradient state (which might be no_grad), sets gradients to enabled, executes the function, then restores the previous state. This allows arbitrary nesting depth without manual state management.
Interaction with is_grad_enabled(): Inside enable_grad, torch.is_grad_enabled() returns true. Operations build the autograd graph and track gradients normally. Exiting enable_grad restores previous state.
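The save/set/restore mechanism can be sketched in a few lines of plain TypeScript. This is a hypothetical stand-in for the library's internal flag (the names withGradState, isGradEnabled, enableGrad, and noGrad here are illustrative, not the actual implementation):

```typescript
// Hypothetical stand-in for the library's internal gradient flag;
// none of these names are part of the torch API.
let gradEnabled = true;

function isGradEnabled(): boolean {
  return gradEnabled;
}

// Save the current state, force it to `enabled`, run fn, then restore.
function withGradState<T>(enabled: boolean, fn: () => T): T {
  const previous = gradEnabled; // save (might itself be "disabled")
  gradEnabled = enabled;
  try {
    return fn();
  } finally {
    gradEnabled = previous; // restore on both return and throw
  }
}

function enableGrad<T>(fn: () => T): T {
  return withGradState(true, fn);
}

function noGrad<T>(fn: () => T): T {
  return withGradState(false, fn);
}

// Arbitrary nesting depth: each level restores whatever it found.
noGrad(() => {
  console.log(isGradEnabled()); // false
  enableGrad(() => {
    console.log(isGradEnabled()); // true
    noGrad(() => console.log(isGradEnabled())); // false again
    console.log(isGradEnabled()); // back to true
  });
  console.log(isGradEnabled()); // back to false
});
```

Because each context only touches one saved value, no manual bookkeeping is needed at any nesting depth.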
- State restoration: Saves the previous gradient state on entry and restores it on exit, even if the function throws an exception
- Nesting: Can nest arbitrarily deep (no_grad → enable_grad → no_grad → ...)
- Gradient tracking: The autograd graph is only built for tensors with requires_grad=true
- Global toggle: Affects torch.is_grad_enabled() only within this context
- Lightweight: Minimal overhead - just saves/restores a boolean flag
- Counterintuitive nesting: Inside no_grad → enable_grad → no_grad, gradients are disabled - the innermost context wins
- Not for all cases: Overuse indicates poor architecture; consider refactoring
- Requires careful thinking: Nesting can be confusing; document intent clearly
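The exception-safety guarantee matters in practice: because the previous state is restored in a finally-style cleanup, a throw inside the wrapped function cannot leak an enabled gradient state into surrounding no_grad code. A minimal sketch using a hypothetical boolean flag (enableGradSketch is illustrative, not the library's internals):

```typescript
// Hypothetical flag standing in for the library's gradient state;
// enableGradSketch is illustrative, not a real torch function.
let gradOn = false; // pretend we are inside a no_grad block

function enableGradSketch<T>(fn: () => T): T {
  const prev = gradOn;
  gradOn = true;
  try {
    return fn();
  } finally {
    gradOn = prev; // runs even if fn throws
  }
}

try {
  enableGradSketch(() => {
    throw new Error('loss computation failed');
  });
} catch {
  // swallowed for the demo
}
console.log(gradOn); // false - the previous (disabled) state survived the throw
```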
Parameters
fn: () => T – Function to execute with gradient tracking enabled. Can be sync or async. Inside this function, torch.is_grad_enabled() returns true and gradients are tracked.
Returns
T – The result of the function, with gradients tracked if requires_grad=true

Examples
// Re-enable gradients inside a no_grad block
const x = torch.tensor([1, 2, 3], { requires_grad: true, dtype: 'float32' });
torch.no_grad(() => {
  console.log(torch.is_grad_enabled()); // false
  const y = x.mul(2);
  console.log(y.requires_grad); // false - gradients not tracked

  // Temporarily enable gradients
  const z = torch.enable_grad(() => {
    console.log(torch.is_grad_enabled()); // true
    return x.pow(2); // gradients ARE tracked here
  });
  console.log(z.requires_grad); // true
  // z.backward() would work!
});

// Mixed forward pass: some operations need gradients, some don't
function forward_with_caching(model, input) {
  return torch.no_grad(() => {
    // Heavy preprocessing without gradients
    const preprocessed = model.expensive_preprocess(input);

    // Enable gradients just for the main computation
    const output = torch.enable_grad(() => {
      return model.main_forward(preprocessed);
    });

    // Post-processing without gradients (deterministic)
    return model.detach_postprocess(output);
  });
}

// Gradient checkpointing (simplified sketch): trade compute for memory
function checkpoint(segment_fn, ...input_tensors) {
  return torch.no_grad(() => {
    // Forward without gradients: intermediate activations are not kept,
    // which is what saves memory
    segment_fn(...input_tensors);

    // Real checkpointing recomputes the segment under enable_grad at
    // backward time; this flattened sketch recomputes it immediately
    return torch.enable_grad(() => {
      return segment_fn(...input_tensors);
    });
  });
}

// Conditional gradient computation
function smart_forward(model, input, compute_gradients) {
  if (compute_gradients) {
    // Training: compute gradients
    return model.forward(input);
  } else {
    // Inference: no gradients
    return torch.no_grad(() => {
      if (some_trigger_condition) { // placeholder for a runtime check
        // But if triggered, enable gradients for this part
        return torch.enable_grad(() => model.forward(input));
      }
      return model.forward(input);
    });
  }
}

// Loss with both differentiable and non-differentiable components
function compute_loss(model, batch) {
  let total_loss = null;
  torch.no_grad(() => {
    // Non-differentiable regularization (expensive computation)
    const reg_loss = compute_expensive_regularization();

    // Differentiable main loss
    const main_loss = torch.enable_grad(() => {
      const pred = model.forward(batch.input);
      return criterion(pred, batch.target);
    });
    total_loss = main_loss.add(reg_loss);
  });
  return total_loss;
}

See Also
- PyTorch torch.enable_grad()
- torch.no_grad - Disable gradients context
- torch.set_grad_enabled - Globally set gradient state
- torch.is_grad_enabled - Check if gradients currently enabled
- torch.inference_mode - Inference-only mode (stricter than no_grad)