torch.nn.BackwardPreHook
export type BackwardPreHook = (
module: Module,
gradOutput: (Tensor | null)[]
) => (Tensor | null)[] | void;

Hook called before the backward pass of a module.
BackwardPreHook is executed at the start of a module's backward computation. It receives the gradients with respect to the module's outputs (flowing from downstream) and can optionally modify them before the module's backward implementation processes them.
Signature:
(module: Module, gradOutput: (Tensor | null)[]) => (Tensor | null)[] | void

Parameters:
module: The module whose backward is about to run
gradOutput: Gradients w.r.t. module outputs (from downstream)
Return Value:
- Return a modified gradOutput array to replace the gradients the module's backward receives
- Return void to leave the original gradients unchanged
Use Cases:
- Inspect incoming gradients before processing
- Validate gradient properties before backward
- Implement gradient scaling or normalization
- Monitor gradient flow at different depths
- Implement gradient checkpointing strategies
Examples
const hook: BackwardPreHook = (module, gradOutput) => {
  // Log the incoming gradient norm before the module's backward runs
  console.log(
    'Incoming grad norm:',
    gradOutput[0]?.norm().item() ?? 'null'
  );
  // Return void: the original gradients are used unchanged
  return;
};
module.register_full_backward_pre_hook(hook);
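A modifying hook returns a new gradient array instead of void. The sketch below illustrates that pattern with gradient scaling; it uses minimal mock Tensor and Module types (the makeTensor helper and the mul method are illustrative stand-ins, not part of the real API) so the hook logic can be followed end to end:

```typescript
// Minimal mock types standing in for the library's Tensor and Module,
// purely so the hook contract can be demonstrated self-contained.
interface Tensor {
  data: number[];
  mul(scalar: number): Tensor;
}

interface Module {
  name: string;
}

type BackwardPreHook = (
  module: Module,
  gradOutput: (Tensor | null)[]
) => (Tensor | null)[] | void;

// Hypothetical helper: builds a mock tensor over a plain number array.
function makeTensor(data: number[]): Tensor {
  return {
    data,
    mul(scalar: number): Tensor {
      return makeTensor(data.map((x) => x * scalar));
    },
  };
}

// Scale incoming gradients by 0.5 before the module's backward runs.
// Returning a new array (rather than void) replaces the gradients;
// null entries are passed through untouched.
const scaleHook: BackwardPreHook = (_module, gradOutput) => {
  return gradOutput.map((g) => (g === null ? null : g.mul(0.5)));
};

// Simulate what the framework does when it invokes the hook:
// a void return would fall back to the original gradients.
const grads: (Tensor | null)[] = [makeTensor([2, 4]), null];
const result = scaleHook({ name: 'linear' }, grads) ?? grads;
console.log(result[0]?.data); // [1, 2]
```

The `?? grads` fallback mirrors the documented contract: a void return means the original gradients flow into the module's backward unchanged.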