torch.nn.register_module_full_backward_pre_hook
function register_module_full_backward_pre_hook(hook: BackwardPreHook): RemovableHandle
Registers a global backward pre-hook that executes before every module's backward computation.
Installs a hook that will be called for ALL modules at the start of their backward passes, before gradient computation. Receives output gradients (from downstream) and can optionally modify them before they're used by the module's backward implementation. Useful for:
- Gradient preprocessing: Normalizing, scaling, or filtering incoming gradients
- Gradient validation: Checking gradients before propagation (detect NaN/Inf early)
- Conditional backward: Skipping certain layers based on gradient properties
- Gradient checkpointing: Implementing custom memory-efficiency strategies
- Debugging: Monitoring what gradients arrive at each layer
The hook runs BEFORE backward computation, so you see the incoming gradients from the next layer before they're used. This differs from register_module_backward_hook, which runs AFTER the backward computation.
- Pre-computation: Executes BEFORE the module's backward computation runs
- Only output gradients: Has access to gradOutput (not gradInput yet)
- Early validation: Can catch gradient issues before they affect computation
- Modification possible: Can modify gradOutput before backward processes it
- Backward only: Called only during backward pass
- Early timing: Runs before backward computation, so don't assume gradInput or parameter gradients are populated yet
- Limited information: Can't see gradInput yet (computed during backward)
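The registration-and-dispatch contract above can be sketched without any torch dependency. This is a minimal illustrative model, not the library's implementation: the registry, `registerBackwardPreHook`, and `runBackward` names are assumptions made up for this sketch; only the hook signature `(module, gradOutput)`, the "return an array to replace gradOutput" rule, and the `.remove()` handle mirror the documented behavior.

```javascript
// Illustrative sketch only: models how a global backward pre-hook
// registry could behave. Names below are hypothetical, not the real API.
const globalPreHooks = new Set();

function registerBackwardPreHook(hook) {
  globalPreHooks.add(hook);
  // RemovableHandle analogue: .remove() unregisters this hook.
  return { remove: () => globalPreHooks.delete(hook) };
}

// Before a module's backward computation, every registered pre-hook runs.
// A hook that returns an array replaces gradOutput for subsequent hooks
// and for the module's backward step.
function runBackward(module, gradOutput) {
  for (const hook of globalPreHooks) {
    const result = hook(module, gradOutput);
    if (result !== undefined) gradOutput = result;
  }
  // ...the module's actual backward computation would consume gradOutput here.
  return gradOutput;
}

// Usage: a hook that scales incoming gradients, then unregisters itself.
const handle = registerBackwardPreHook((module, gradOutput) =>
  gradOutput.map(g => g * 0.5)
);
runBackward({ name: "Linear" }, [1.0, 2.0]); // → [0.5, 1]
handle.remove();
runBackward({ name: "Linear" }, [1.0, 2.0]); // → [1, 2] (hook removed)
```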
Parameters
hook: BackwardPreHook – function called with (module, gradOutput) before the module's backward computation; may return new gradients to replace gradOutput
Returns
RemovableHandle – handle to unregister this hook using .remove()
Examples
// Validate incoming gradients
const validateHook = (module, gradOutput) => {
  gradOutput.forEach((g, i) => {
    if (g && (torch.isnan(g).any().item() || torch.isinf(g).any().item())) {
      console.error(`Invalid gradients at output ${i} of ${module.constructor.name}`);
    }
  });
};
torch.nn.register_module_full_backward_pre_hook(validateHook);

// Scale incoming gradients
const scaleHook = (module, gradOutput) => {
  const scale = 0.1; // Scale down gradients
  return gradOutput.map(g => g ? g.mul(scale) : g);
};
torch.nn.register_module_full_backward_pre_hook(scaleHook);

// Monitor gradient flow at different layers
const monitorHook = (module, gradOutput) => {
  const norm = gradOutput[0]?.norm().item() ?? 0;
  console.log(`${module.constructor.name} incoming grad norm:`, norm);
};
const handle = torch.nn.register_module_full_backward_pre_hook(monitorHook);
handle.remove(); // Unregister when no longer needed
See Also
- PyTorch torch.nn.modules.module.register_module_full_backward_pre_hook
- register_module_backward_hook - Post-backward hook with both gradient types
- register_module_full_backward_hook - Post-backward full hook
- register_module_forward_pre_hook - Forward pre-hook for comparison