torch.nn.register_module_forward_hook
function register_module_forward_hook(hook: ForwardHook): RemovableHandle

Registers a global forward post-hook that executes after every module's forward pass.
Installs a hook that is called for ALL modules after their forward() methods complete. The hook receives the module, its inputs, and the computed outputs, and can inspect or modify the outputs before they propagate to downstream layers. Useful for:
- Output validation: Checking output ranges, detecting NaN/Inf, validating shapes
- Activation monitoring: Recording statistics (mean, std, min, max) of layer outputs
- Output transformation: Clipping, normalizing, or applying constraints to outputs
- Intermediate feature extraction: Collecting activations for analysis
- Debugging: Visualizing what each layer produces during forward pass
The hook receives the inputs (possibly already modified by pre-hooks) and the outputs; if it returns a value, that value replaces the original outputs.
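The replacement semantics can be sketched in plain JavaScript. This is an illustrative simulation, not the library's actual dispatch code; the names `globalHooks`, `callForward`, and the toy module are assumptions made for the example.

```javascript
// Hypothetical sketch of forward post-hook dispatch (not the real implementation).
const globalHooks = [];

function callForward(module, input) {
  let output = module.forward(input);
  for (const hook of globalHooks) {
    const result = hook(module, input, output);
    if (result !== undefined) {
      output = result; // a returned value replaces the original output
    }
  }
  return output;
}

// Toy "module" whose forward doubles a scalar
const doubler = { forward: (x) => x * 2 };

// Hook that clamps the scalar output to [0, 3]
globalHooks.push((module, input, output) => Math.min(Math.max(output, 0), 3));

const y = callForward(doubler, 5); // forward returns 10, the hook clamps it to 3
```

A hook that returns `undefined` (e.g. one that only logs statistics) leaves the output untouched, which is why the monitoring examples below need no return statement.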
- Global scope: Affects ALL modules in the model
- Called after: Executes AFTER the module's forward() method completes
- Access to inputs: Sees the inputs as passed to forward(), including any modifications made by pre-hooks
- Can modify outputs: Returning tensors replaces the original outputs
- Performance impact: Runs for every module in every forward pass
- Memory usage: Collecting activations increases memory footprint
- Output type handling: Must handle both single tensors and tensor arrays
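Since the Returns section below refers to unregistering via .remove(), here is a minimal sketch of how a RemovableHandle-style registry might behave. All names (`hooks`, `registerHook`, `handle`) are illustrative assumptions, not the library's internals.

```javascript
// Hypothetical sketch of RemovableHandle-style unregistration.
const hooks = new Map();
let nextId = 0;

function registerHook(fn) {
  const id = nextId++;
  hooks.set(id, fn);
  // RemovableHandle-like object: calling remove() unregisters this hook
  return { remove: () => hooks.delete(id) };
}

let calls = 0;
const handle = registerHook(() => { calls++; });

hooks.forEach((fn) => fn()); // simulate one forward pass: hook fires once
handle.remove();             // unregister the hook
hooks.forEach((fn) => fn()); // second forward pass: no hooks remain
```

Because the hook runs for every module on every forward pass, removing it as soon as you have collected what you need keeps the overhead and memory costs listed above bounded.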
Parameters
hook: ForwardHook – function called with (module, input, output) after every forward()
Returns
RemovableHandle – handle to unregister this hook using .remove()

Examples
// Monitor activation statistics
const hook = (module, input, output) => {
const out = Array.isArray(output) ? output[0] : output;
console.log(`${module.constructor.name} output mean:`, out.mean().item());
};
torch.nn.register_module_forward_hook(hook);

// Clip outputs to prevent divergence
const hook = (module, input, output) => {
if (Array.isArray(output)) {
return output.map(t => torch.clamp(t, -1, 1));
}
return torch.clamp(output, -1, 1);
};
torch.nn.register_module_forward_hook(hook);

// Collect intermediate representations
const activations = [];
const hook = (module, input, output) => {
activations.push(Array.isArray(output) ? output[0] : output);
};
torch.nn.register_module_forward_hook(hook);
model.forward(x);
// activations now contains outputs from every layer

See Also
- PyTorch torch.nn.modules.module.register_module_forward_hook
- register_module_forward_pre_hook - Pre-forward hook (before forward)
- register_module_backward_hook - Backward pass hook
- RemovableHandle - How to unregister the hook