torch.nn.register_module_forward_pre_hook
function register_module_forward_pre_hook(hook: ForwardPreHook): RemovableHandle
Registers a global forward pre-hook that executes before every module's forward pass.
Installs a hook that will be called for ALL modules before their forward() methods execute. This is a global hook (not a per-module hook); it affects every module in the network. Useful for:
- Input validation: Checking or transforming inputs before forward computation
- Input augmentation: Modifying inputs dynamically (e.g., augmentation, normalization)
- Debugging/profiling: Logging input shapes, ranges, or computing input statistics
- Activation manipulation: Injecting noise, applying masks, or other preprocessing
- Training dynamics monitoring: Tracking how inputs change during training
The hook receives the module and its input tensors, and can optionally return modified inputs that replace the original inputs for that forward pass.
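This contract can be modeled in isolation. The Tensor, Module, and hook definitions below are simplified stand-ins for illustration, not the library's actual types:

```typescript
// Simplified stand-ins for the pre-hook contract (not the library's types).
type Tensor = { data: number[] };
type Module = { forward: (inputs: Tensor[]) => Tensor };
type ForwardPreHook = (module: Module, inputs: Tensor[]) => Tensor[] | undefined;

// A hook that replaces the inputs by returning a new tensor array:
const scaleHook: ForwardPreHook = (_module, inputs) => [
  { data: inputs[0].data.map((v) => v * 0.5) },
];

// A toy module that sums its first input.
const sumModule: Module = {
  forward: (inputs) => ({ data: [inputs[0].data.reduce((a, b) => a + b, 0)] }),
};

// Applying the hook before forward(), as the framework would:
let inputs: Tensor[] = [{ data: [2, 4] }];
const replaced = scaleHook(sumModule, inputs);
if (replaced !== undefined) inputs = replaced; // returned array replaces inputs
const out = sumModule.forward(inputs); // → { data: [3] }  ((2*0.5) + (4*0.5))
```

If the hook returns undefined instead, the framework calls forward() with the original inputs unchanged.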
- Global scope: Affects ALL modules in the model, not just one layer
- Called before: Executes BEFORE the module's forward() method
- Can modify inputs: Returning tensor arrays replaces the inputs
- Order matters: Multiple hooks execute in registration order
- Performance impact: Heavy computations in hooks slow down every forward pass
- Affects all modules: No way to selectively disable for specific layers
- Return type matters: Must return Tensor[] or undefined, not other types
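The ordering, return-type, and unregistration rules above can be exercised with a small stand-alone model. The registry, register function, and handle object below are illustrative stand-ins, not the library's internals:

```typescript
// Illustrative model of a global pre-hook registry with removable handles.
type PreHook = (input: number[]) => number[] | undefined;

const hooks: PreHook[] = [];

// Registration returns a handle whose remove() unregisters the hook,
// mirroring the RemovableHandle pattern.
function register(hook: PreHook): { remove: () => void } {
  hooks.push(hook);
  return {
    remove: () => {
      const i = hooks.indexOf(hook);
      if (i !== -1) hooks.splice(i, 1);
    },
  };
}

// Hooks run in registration order; undefined keeps the current input.
function runPreHooks(input: number[]): number[] {
  let current = input;
  for (const hook of hooks) {
    const out = hook(current);
    if (out !== undefined) current = out;
  }
  return current;
}

// Registration order is execution order: double first, then add one.
const doubler = register((x) => x.map((v) => v * 2));
register((x) => x.map((v) => v + 1));
const result = runPreHooks([1, 2]); // → [3, 5]
```

Calling doubler.remove() afterwards leaves only the second hook in place, so subsequent passes see just the +1 transformation.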
Parameters
hook: ForwardPreHook – hook function called with (module, input) before every forward()
Returns
RemovableHandle – handle used to unregister this hook via .remove()
Examples
// Log all inputs before forward
const hook = (module, inputs) => {
console.log(`${module.constructor.name} input shape:`, inputs[0].shape);
// Return undefined to use original inputs
};
const handle = torch.nn.register_module_forward_pre_hook(hook);

// Normalize inputs before forward
const hook = (module, inputs) => {
// Return modified inputs
return [inputs[0].sub(inputs[0].mean()).div(inputs[0].std())];
};
torch.nn.register_module_forward_pre_hook(hook);

// Detect gradient checkpointing opportunities
const hook = (module, inputs) => {
const input_size = inputs[0].numel() * 4; // bytes for float32
if (input_size > 1e6) {
console.log(`Large input: ${input_size / 1e6}MB - consider gradient checkpointing`);
}
};
torch.nn.register_module_forward_pre_hook(hook);
See Also
- PyTorch torch.nn.modules.module.register_module_forward_pre_hook
- register_module_forward_hook - Post-forward hook (after forward completes)
- RemovableHandle - How to unregister the hook
- ForwardPreHook - Type definition for the hook function