torch.nn.register_module_backward_hook
function register_module_backward_hook(hook: BackwardHook): RemovableHandle
Registers a global backward hook that executes during the backward pass of every module.
Installs a hook that is called for ALL modules during their backward() computations. The hook receives both the input gradients (gradients with respect to the module's inputs) and the output gradients (gradients arriving from downstream layers), and can optionally return modified input gradients before they propagate further backward. Useful for:
- Gradient debugging: Inspecting gradient flow and detecting anomalies
- Gradient clipping: Preventing gradient explosion during training
- Gradient statistics: Monitoring gradient norms, sparsity, and distribution
- Gradient manipulation: Applying per-layer gradient scaling or filtering
- Training diagnostics: Detecting vanishing/exploding gradients
The hook runs after the backward computation completes, so both gradInput and gradOutput are available. This is the "full" variant that provides both types of gradients.
- Backward only: Called only during backward pass, not forward
- Post-backward execution: Runs after backward computation completes
- Both gradient types: Has access to both input and output gradients
- Null handling: Some gradients may be null (for layers without gradients)
- Performance impact: Heavy operations slow down backward pass
- Gradient modification: Returning modified gradients affects upstream computation
- Requires backward: Only executed when .backward() is called
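The gradient-modification note above can be sketched in isolation. The helper below is hypothetical (not part of the torch API): it builds a hook that scales every non-null input gradient by a fixed factor, and the invocation is simulated with mock tensors (plain objects with a mul() method) so the behavior can be seen without a real backward pass.

```javascript
// Hypothetical helper: builds a hook that scales every non-null input
// gradient before it propagates to earlier layers.
function makeScalingHook(scale) {
  return (module, gradInput, gradOutput) => {
    // Returning a new gradInput array replaces the gradients that flow
    // upstream; null entries are passed through untouched.
    return gradInput.map(g => (g ? g.mul(scale) : g));
  };
}

// Mock tensor for illustration only (a real tensor's mul() behaves the same).
const mockTensor = v => ({ value: v, mul: s => mockTensor(v * s) });

const hook = makeScalingHook(0.5);
const scaled = hook(null, [mockTensor(4), null], []);
console.log(scaled[0].value); // 2
console.log(scaled[1]);       // null
```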
Parameters
hook: BackwardHook – Hook function called with (module, gradInput, gradOutput)
Returns
RemovableHandle – A handle that can be used to unregister this hook via .remove()
Examples
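The register/remove lifecycle of the returned handle can be illustrated with a minimal stand-in registry (illustrative only; the real torch.nn implementation differs):

```javascript
// Minimal stand-in for the global hook registry, showing the lifecycle
// that the returned RemovableHandle provides.
const hooks = new Set();
function registerBackwardHook(hook) {
  hooks.add(hook);
  return { remove: () => hooks.delete(hook) }; // RemovableHandle-like
}

// Fire all registered hooks, as the backward pass would for each module.
function fireBackward(module, gradInput, gradOutput) {
  for (const h of hooks) h(module, gradInput, gradOutput);
}

let calls = 0;
const handle = registerBackwardHook(() => { calls += 1; });
fireBackward({}, [], []);   // hook fires once
handle.remove();
fireBackward({}, [], []);   // hook no longer fires after removal
console.log(calls); // 1
```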
// Monitor gradient statistics
const hook = (module, gradInput, gradOutput) => {
gradInput.forEach((g, i) => {
if (g) console.log(`Layer ${i} grad norm:`, g.norm().item());
});
};
torch.nn.register_module_backward_hook(hook);

// Clip gradients to prevent explosion
const hook = (module, gradInput, gradOutput) => {
return gradInput.map(g => g ? torch.clamp(g, -1, 1) : g);
};
torch.nn.register_module_backward_hook(hook);

// Detect dying layers (zero gradients)
const hook = (module, gradInput, gradOutput) => {
gradInput.forEach((g, i) => {
if (g && torch.all(g.eq(0)).item()) {
console.warn(`${module.constructor.name} layer ${i} has zero gradients!`);
}
});
};
torch.nn.register_module_backward_hook(hook);
See Also
- PyTorch torch.nn.modules.module.register_module_backward_hook
- register_module_full_backward_pre_hook - Pre-backward hook
- register_module_full_backward_hook - Alias for this function
- register_module_forward_hook - Forward pass hook
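The "gradient statistics" use case above can also accumulate norms across steps rather than logging each one. The sketch below uses illustrative names only (gradNorms, statsHook are not part of the torch API), and the invocation is simulated with a mock gradient so the bookkeeping can be verified without a real backward pass.

```javascript
// Accumulate per-layer gradient norms so they can be averaged or
// plotted after training.
const gradNorms = new Map();
const statsHook = (module, gradInput, gradOutput) => {
  const name = module.constructor.name;
  const norms = gradNorms.get(name) || [];
  gradInput.forEach(g => { if (g) norms.push(g.norm().item()); });
  gradNorms.set(name, norms);
};

// Mock gradient mimicking a tensor's norm().item() chain, for illustration.
const mockGrad = n => ({ norm: () => ({ item: () => n }) });
class Linear {}
statsHook(new Linear(), [mockGrad(3), null, mockGrad(4)], []);
console.log(gradNorms.get('Linear')); // [3, 4]
```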