torch.nn.BufferRegistrationHook
```typescript
export type BufferRegistrationHook = (
  module: Module,
  name: string,
  buffer: Tensor | null
) => Tensor | null | void;
```
Hook called when a buffer is registered in a module.
BufferRegistrationHook is triggered when register_buffer() is called on a module. It can inspect or modify the buffer being registered, or prevent registration by returning null.
Signature:
```typescript
(module: Module, name: string, buffer: Tensor | null) => Tensor | null | void
```
Parameters:
- module: The module registering the buffer
- name: Name/key for the buffer (e.g., 'running_mean')
- buffer: The tensor being registered (or null to unregister)
Return Value:
- Return a tensor to register it in place of the original buffer
- Return null to prevent registration
- Return nothing (undefined) to register the buffer unchanged
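To make the three return cases concrete, the sketch below shows one way a register_buffer() implementation could apply hooks and interpret their results. It is self-contained and illustrative only: MiniModule, bufferHooks, the reduced Tensor type, and the choice to run multiple hooks in order (each seeing the previous hook's result) are assumptions, not the library's actual internals.

```typescript
// Hypothetical minimal Tensor; the real class carries data, shape, device, etc.
interface Tensor {
  dtype: string;
}

type BufferRegistrationHook = (
  module: MiniModule,
  name: string,
  buffer: Tensor | null
) => Tensor | null | void;

// Hypothetical global hook list standing in for the library's registry.
const bufferHooks: BufferRegistrationHook[] = [];

// Hypothetical module exposing only register_buffer() and its buffer map.
class MiniModule {
  readonly buffers = new Map<string, Tensor | null>();

  register_buffer(name: string, buffer: Tensor | null): void {
    let current: Tensor | null = buffer;
    for (const hook of bufferHooks) {
      const result = hook(this, name, current);
      if (result === null) {
        return; // null: a hook vetoed registration, nothing is stored
      }
      if (result !== undefined) {
        current = result; // a tensor: it replaces the buffer for later hooks
      }
      // undefined (void): keep the current buffer unchanged
    }
    this.buffers.set(name, current);
  }
}
```

Chaining hooks this way lets a later hook see and further adjust a tensor that an earlier hook replaced, while any single null result stops registration outright.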
Use Cases:
- Monitor buffers being added to modules
- Implement buffer validation or constraints
- Modify buffers before registration (dtype, shape)
- Track buffer lifecycle
- Prevent certain buffers from being registered
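Several of the use cases above (validation, lifecycle tracking, and blocking) can live in one hook. The following is a minimal sketch under stated assumptions: MiniModule, BufferEvent, makeTrackingHook, the name blocklist, and the rank limit are all hypothetical, and Tensor is reduced to dtype and shape so the block runs on its own.

```typescript
// Hypothetical minimal types standing in for the library's Tensor and Module.
interface Tensor {
  dtype: string;
  shape: number[];
}
interface MiniModule {
  readonly name: string;
}

type BufferRegistrationHook = (
  module: MiniModule,
  name: string,
  buffer: Tensor | null
) => Tensor | null | void;

// One recorded registration event, used for lifecycle tracking.
interface BufferEvent {
  module: string;
  name: string;
  action: "register" | "unregister" | "blocked";
}

// Builds a hook that records every event and rejects buffers whose name is
// blocklisted or whose rank exceeds maxRank (both policies are illustrative).
function makeTrackingHook(
  events: BufferEvent[],
  blockedNames: ReadonlySet<string>,
  maxRank: number
): BufferRegistrationHook {
  return (module, name, buffer) => {
    if (buffer === null) {
      events.push({ module: module.name, name, action: "unregister" });
      return undefined; // let the unregistration proceed unchanged
    }
    if (blockedNames.has(name) || buffer.shape.length > maxRank) {
      events.push({ module: module.name, name, action: "blocked" });
      return null; // prevent registration
    }
    events.push({ module: module.name, name, action: "register" });
    return undefined; // register the buffer unchanged
  };
}
```

The returned function would then be passed to the registration call shown in the example below; keeping the policy inputs as factory arguments makes the hook easy to test in isolation.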
Examples
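```typescript
// Convert any non-float32 buffer to float32 before it is registered.
const hook: BufferRegistrationHook = (module, name, buffer) => {
  if (buffer && buffer.dtype !== 'float32') {
    console.warn(`Converting ${name} to float32`);
    // Returning a tensor registers it in place of the original buffer.
    return buffer.to({ dtype: 'float32' });
  }
  // Returning nothing registers the buffer unchanged.
  return;
};

torch.nn.register_buffer_registration_hook(hook);
```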