torch.Tensor.register_hook
Registers a backward hook on the tensor.
The hook will be called every time a gradient with respect to the Tensor is computed. The hook should have the signature: hook(grad: Tensor) => Tensor | null | undefined
The hook can modify the gradient or return a new gradient to use. If the hook returns null or undefined, the original gradient is used.
Parameters
hook ((grad: Tensor) => Tensor | null | undefined) – The function to call each time a gradient with respect to the Tensor is computed.
Returns
{ remove: () => void } – A handle that can be used to remove the hook via handle.remove().
Examples
const x = torch.randn(3, 3, { requires_grad: true });
const handle = x.register_hook((grad) => {
console.log('Gradient:', grad);
return grad.mul(2); // Double the gradient
});
const y = x.sum();
y.backward();
handle.remove(); // Unregister the hook
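The replace-or-keep semantics described above can be sketched in plain JavaScript, independent of torch. Here `registerHook` and `applyHooks` are hypothetical helpers written only to illustrate the handle/remove pattern and the rule that a null/undefined return keeps the original gradient; a "gradient" is just an array of numbers:

```javascript
// Minimal sketch of register_hook semantics, without torch.
// `hooks`, `registerHook`, and `applyHooks` are hypothetical helpers.
const hooks = new Map();
let nextId = 0;

function registerHook(hook) {
  const id = nextId++;
  hooks.set(id, hook);
  // The returned handle mirrors the documented { remove: () => void } shape.
  return { remove: () => hooks.delete(id) };
}

function applyHooks(grad) {
  // Each hook may return a new gradient; null/undefined keeps the current one.
  for (const hook of hooks.values()) {
    const result = hook(grad);
    if (result !== null && result !== undefined) grad = result;
  }
  return grad;
}

const handle = registerHook((grad) => grad.map((g) => g * 2)); // replaces grad
registerHook((grad) => { /* observe only, returns undefined */ });

console.log(applyHooks([1, 2, 3])); // [2, 4, 6] — doubled, observer kept it
handle.remove();
console.log(applyHooks([1, 2, 3])); // [1, 2, 3] — only the observer remains
```

The observer hook demonstrates the second branch: because it returns undefined, the gradient passes through unchanged, which is why removing only the doubling hook restores the original values.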