torch.nn.register_module_buffer_registration_hook
function register_module_buffer_registration_hook(hook: BufferRegistrationHook): RemovableHandle
Registers a global hook that executes whenever a buffer is registered in any module.
Installs a hook that is called for ALL modules whenever register_buffer() is invoked. The hook can inspect, validate, or modify a buffer before it is stored in the module, and it can prevent registration by returning null. Useful for:
- Buffer validation: Ensuring buffers have correct dtype, device, or shape
- Buffer transformation: Auto-converting buffers to specific types (e.g., float32)
- Buffer tracking: Monitoring what buffers are added to models
- Constraint enforcement: Ensuring all buffers meet certain requirements
- Device management: Automatically placing buffers on specific devices
The hook receives the module, the buffer name (e.g., 'running_mean'), and the buffer tensor. Its return value decides what is stored: returning null prevents the buffer from being registered, returning nothing (undefined) registers the buffer unchanged, and returning a tensor registers that tensor in place of the original.
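The return-value rule above can be modeled with a small stand-alone sketch. `resolveBufferHookResult` is a hypothetical helper, not part of the library, and the plain objects stand in for real tensors. It assumes that "returning nothing" means the hook returned `undefined`, and it does not model the unregistration case where the hook is called with buffer=null.

```javascript
// Hypothetical model of the return-value rule; NOT the library's implementation.
// Assumption: "returning nothing" means the hook returned undefined.
function resolveBufferHookResult(original, hookResult) {
  if (hookResult === null) {
    // The hook blocked the registration.
    return { register: false, buffer: null };
  }
  if (hookResult === undefined) {
    // The hook returned nothing: keep the original buffer.
    return { register: true, buffer: original };
  }
  // The hook returned a replacement buffer.
  return { register: true, buffer: hookResult };
}

// Plain objects used as stand-ins for tensors.
const originalBuffer = { dtype: 'float64' };
const convertedBuffer = { dtype: 'float32' };

const blocked = resolveBufferHookResult(originalBuffer, null);
const unchanged = resolveBufferHookResult(originalBuffer, undefined);
const replaced = resolveBufferHookResult(originalBuffer, convertedBuffer);
```

Note that the check uses strict equality, so a hook that forgets a `return` statement falls into the "unchanged" branch rather than blocking the buffer.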
- Module creation only: Called during model construction (register_buffer)
- Can modify/block: Can return modified tensor or null to prevent registration
- Receives null for unregistration: Also called with buffer=null when a buffer is being unregistered
- Global scope: Affects all modules in the entire model
- Blocking buffers: Returning null might break model functionality
- Type safety: Ensure modified buffers are compatible with model expectations
- Early execution: Runs during model construction, before training
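Because the hook is global and runs during model construction, it may be useful to keep it installed only while a particular model is being built. The sketch below shows one way to do that with the returned handle's .remove(). `withBufferHook` is a hypothetical helper, not a library function, and the registration function is passed in as an argument so the sketch makes no assumption about how the library is imported.

```javascript
// Hypothetical helper: installs `hook`, runs `build`, then always removes the hook.
// `register` is expected to behave like
// torch.nn.register_module_buffer_registration_hook: it takes a hook and
// returns a handle with a .remove() method.
function withBufferHook(register, hook, build) {
  const handle = register(hook);
  try {
    return build();
  } finally {
    // Uninstall the hook even if build() throws, so it does not
    // leak into modules constructed later.
    handle.remove();
  }
}

// Possible usage (MyModel is a placeholder for your own module class):
// const model = withBufferHook(
//   (h) => torch.nn.register_module_buffer_registration_hook(h),
//   hook,
//   () => new MyModel()
// );
```

Wrapping the registration call in an arrow function avoids depending on how the library binds `this` when the function is passed by reference.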
Parameters
hook: BufferRegistrationHook - Called with (module, name, buffer) when register_buffer() is called
Returns
RemovableHandle - Handle that unregisters this hook when .remove() is called
Examples
// Ensure all buffers are float32
const hook = (module, name, buffer) => {
  if (buffer && buffer.dtype !== 'float32') {
    console.log(`Converting ${name} from ${buffer.dtype} to float32`);
    return buffer.to({ dtype: 'float32' });
  }
};
torch.nn.register_module_buffer_registration_hook(hook);

// Track buffer registration
const buffers = [];
const hook = (module, name, buffer) => {
  if (buffer) {
    buffers.push({
      module: module.constructor.name,
      name,
      shape: buffer.shape,
      dtype: buffer.dtype
    });
  }
};
torch.nn.register_module_buffer_registration_hook(hook);

// Prevent buffers with certain names
const hook = (module, name, buffer) => {
  if (name === 'temp_buffer') {
    console.log('Blocking temporary buffer');
    return null; // Prevent registration
  }
};
torch.nn.register_module_buffer_registration_hook(hook);

See Also
- PyTorch torch.nn.modules.module.register_module_buffer_registration_hook
- register_module_parameter_registration_hook - Hook for parameters
- register_module_module_registration_hook - Hook for submodules
- Module.register_buffer - Method that triggers this hook