torch.nn.register_module_parameter_registration_hook
function register_module_parameter_registration_hook(hook: ParameterRegistrationHook): RemovableHandle

Registers a global hook that executes whenever a parameter is registered in any module.
Installs a hook that is called for ALL modules whenever register_parameter() is invoked. The hook can inspect, validate, or modify trainable parameters before they are stored, and can prevent registration entirely by returning null. Useful for:
- Parameter validation: Ensuring parameters have correct shape, dtype, or initialization
- Parameter transformation: Auto-converting parameters (e.g., to float32, moving to device)
- Parameter tracking: Monitoring what trainable parameters are added to models
- Initialization control: Custom initialization or constraint application
- Trainability control: Making certain parameters non-trainable conditionally
The hook receives the module, the parameter name (e.g., 'weight', 'bias'), and the Parameter object. Returning null prevents the parameter from being registered; returning nothing (undefined) registers it unchanged; returning a modified Parameter registers the modified version instead.
- Model construction time: Called during register_parameter() calls
- Can modify/block: Can return modified parameter or null to prevent registration
- Trainability control: Can set requires_grad before registration
- Parameter objects: Receives Parameter instances, not just Tensors
- Blocking parameters: Returning null prevents registration entirely, so the parameter never becomes part of the model or its training
- Type safety: Modified parameters must be compatible with layer expectations
- Early execution: Runs during model construction, before training starts
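The three return-value behaviors can be sketched with a mock dispatcher. The `registerParameter` function and plain-object modules below are hypothetical stand-ins for illustration, not the library's internals:

```javascript
// Minimal sketch of how a registration dispatcher could interpret the three
// possible hook return values: null, undefined, or a modified Parameter.
const globalHooks = [];

function registerParameter(module, name, param) {
  for (const hook of globalHooks) {
    const result = hook(module, name, param);
    if (result === null) return;              // null: block registration entirely
    if (result !== undefined) param = result; // Parameter: register the modified version
    // undefined: fall through and register unchanged
  }
  module.parameters[name] = param;
}

// Demo hook: blocks names starting with '_' and tags everything else
globalHooks.push((module, name, param) => {
  if (name.startsWith('_')) return null;
  return { ...param, tagged: true };
});

const mod = { parameters: {} };
registerParameter(mod, 'weight', { shape: [2, 2] });
registerParameter(mod, '_scratch', { shape: [1] });
console.log(Object.keys(mod.parameters)); // → [ 'weight' ]
console.log(mod.parameters.weight.tagged); // → true
```

The real dispatch happens inside Module.register_parameter; this sketch only shows why null, undefined, and a returned Parameter lead to different outcomes.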
Parameters
hook: ParameterRegistrationHook – Hook called with (module, name, param) when register_parameter() is invoked
Returns
RemovableHandle – Handle that unregisters this hook via .remove()
Examples
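The returned handle follows the usual removable-handle pattern: calling .remove() uninstalls the hook. A minimal sketch of that pattern, with a hypothetical registerHook standing in for the real registry:

```javascript
// Sketch of the removable-handle pattern (hypothetical internals, for
// illustration only; the real handle is constructed by the library).
const hooks = [];

function registerHook(hook) {
  hooks.push(hook);
  return {
    remove() {
      const i = hooks.indexOf(hook);
      if (i !== -1) hooks.splice(i, 1); // idempotent: removing twice is a no-op
    },
  };
}

const handle = registerHook((module, name, param) => param);
console.log(hooks.length); // → 1
handle.remove();
console.log(hooks.length); // → 0
```

Keeping the handle around matters for global hooks: without it there is no way to stop the hook from firing on every subsequent module construction.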
// Ensure all parameters are float32
const hook = (module, name, param) => {
  if (param && param.dtype !== 'float32') {
    console.log(`Converting parameter '${name}' to float32`);
    return param.to({ dtype: 'float32' });
  }
};
torch.nn.register_module_parameter_registration_hook(hook);

// Log all parameter registrations
const hook = (module, name, param) => {
  if (param) {
    console.log(`Parameter '${name}' in ${module.constructor.name}: shape=${param.shape}`);
  }
};
torch.nn.register_module_parameter_registration_hook(hook);

// Make certain parameters non-trainable
const hook = (module, name, param) => {
  if (param && name.startsWith('_')) {
    param.requires_grad = false;
  }
};
torch.nn.register_module_parameter_registration_hook(hook);

// Custom parameter initialization
const hook = (module, name, param) => {
  if (param && name === 'weight' && module instanceof torch.nn.Linear) {
    // Custom weight initialization
    param.data.uniform_(-0.05, 0.05);
  }
};
torch.nn.register_module_parameter_registration_hook(hook);

See Also
- PyTorch torch.nn.modules.module.register_module_parameter_registration_hook
- register_module_buffer_registration_hook - Hook for buffers
- register_module_module_registration_hook - Hook for submodules
- Module.register_parameter - Method that triggers this hook
- Parameter - Parameter type used by this hook