torch.autograd.saved_tensors_hooks
function saved_tensors_hooks<T>(pack_hook: PackHook, unpack_hook: UnpackHook, fn: () => T): T
Context manager for registering custom hooks on tensors saved for backward.
Allows fine-grained control over how intermediate tensors are stored during the forward pass. When operations save tensors for the backward computation, these hooks intercept the save and restore steps. This enables memory-optimization strategies such as CPU offloading, compression, or checkpointing without modifying operation code. Essential for:
- Memory optimization: save tensors on CPU instead of GPU (less GPU memory during forward, extra copies during backward)
- Compression: compress saved tensors to reduce memory footprint
- Checkpoint strategies: implement gradient checkpointing to trade compute for memory
- Custom storage: implement exotic storage backends (disk, network, etc.)
- Profiling: monitor what tensors are saved and when
Hook Contract:
- pack_hook: Called during forward, transforms tensor → arbitrary data
- unpack_hook: Called during backward, transforms data → tensor back
The hooks must be inverses: unpack_hook(pack_hook(x)) ≈ x (within numerical error). If they don't match, backward will silently produce incorrect gradients.
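The contract can be illustrated without the library. Below is a minimal, library-free sketch in which a "tensor" is a plain Float32Array; `packHook` and `unpackHook` are illustrative names, not part of the torch API:

```typescript
// Sketch of the pack/unpack contract: pack serializes a "tensor"
// (here a plain Float32Array) into a transferable form, and unpack
// reconstructs it. Names are illustrative, not the torch API.
type Packed = { bytes: ArrayBuffer; length: number };

function packHook(t: Float32Array): Packed {
  // Copy into a detached buffer (stands in for e.g. a CPU copy)
  return { bytes: t.slice().buffer, length: t.length };
}

function unpackHook(p: Packed): Float32Array {
  // Reconstruct the original tensor from the packed data
  return new Float32Array(p.bytes, 0, p.length);
}

// The contract: unpackHook(packHook(x)) must reproduce x
const x = new Float32Array([1.5, -2.25, 3.75]);
const roundTripped = unpackHook(packHook(x));
console.log(roundTripped.every((v, i) => v === x[i])); // true
```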
Memory Trade-offs: Hooks enable memory/compute trade-offs. For example, saving to CPU instead of GPU:
- Forward: reduces GPU memory (at the cost of slower GPU-to-CPU copies)
- Backward: uses more compute and transfer time (tensor copies back to GPU, possibly recomputation)
Nesting and Scope: Hooks are registered globally for the duration of the context. Any operation that saves tensors inside this context uses the hooks. Exiting the context restores previous hooks. Hooks can be nested: inner hooks temporarily override outer hooks.
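The nesting and exception-safety semantics described above can be sketched as a plain hook stack; this is an illustration of the scoping behavior, not the torch internals, and all names (`withHooks`, `currentHooks`) are hypothetical:

```typescript
// Minimal sketch of nested hook scopes: the innermost registered
// hooks win, and a scope is popped even if its function throws.
type Hooks = { pack: (x: number) => unknown; unpack: (d: unknown) => number };

const hookStack: Hooks[] = [];

function withHooks<T>(hooks: Hooks, fn: () => T): T {
  hookStack.push(hooks);
  try {
    return fn(); // any saves inside fn see the innermost hooks
  } finally {
    hookStack.pop(); // restored even if fn throws (exception safety)
  }
}

function currentHooks(): Hooks | undefined {
  return hookStack[hookStack.length - 1]; // innermost wins
}
```

Inner scopes temporarily shadow outer ones, and the try/finally guarantees the previous hooks are restored on both normal exit and exceptions.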
- Hook pairing: pack and unpack must be inverses (data is round-tripped)
- Backward uses hooks: pack_hook called during forward, unpack_hook during backward
- Global scope: Hooks apply to all operations saving tensors inside the function
- Nesting: Inner hooks override outer hooks for their scope
- Exception safety: Hooks are removed even if function throws exception
- No state sharing: Each saved tensor is packed/unpacked independently
- Performance: Hook overhead is per-saved-tensor; optimization must outweigh cost
- Hook must invert: unpack_hook(pack_hook(x)) must equal x (approximately)
- Hook errors surface late: Errors in unpack_hook raise during backward, not forward
- Memory vs compute: Offloading trades GPU memory for compute/communication
- Can be disabled: disable_saved_tensors_hooks() temporarily disallows hooks within a scope
- Experimental: Hook behavior may differ from PyTorch in edge cases
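The approximate-inverse requirement can be checked concretely. Below is a library-free version of the int8 quantize/dequantize pair (the same math as the compression example later in this page, applied to plain arrays); the rounding error per element is at most half a quantization step, i.e. scale / 254:

```typescript
// Library-free check that a quantizing pack/unpack pair is
// approximately inverse: q = round(x / scale * 127), x' = q * scale / 127.
function pack(xs: number[]): { q: Int8Array; scale: number } {
  const scale = Math.max(...xs.map(Math.abs)); // assumes a nonzero input
  const q = Int8Array.from(xs.map((v) => Math.round((v / scale) * 127)));
  return { q, scale };
}

function unpack(p: { q: Int8Array; scale: number }): number[] {
  return Array.from(p.q, (v) => (v * p.scale) / 127);
}

const xs = [0.1, -0.4, 0.25, 1.0];
const ys = unpack(pack(xs));
// Rounding error is bounded by half a quantization step: scale / 254
const maxErr = Math.max(...xs.map((v, i) => Math.abs(v - ys[i])));
console.log(maxErr <= 1.0 / 254); // true
```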
Parameters
pack_hook: PackHook - Function called when a tensor is saved during forward. Signature: (tensor: Tensor) => any. Takes a tensor and returns arbitrary data (string, number, compressed data, etc.). Should transform the tensor for efficient storage.
unpack_hook: UnpackHook - Function called when a saved tensor is needed during backward. Signature: (data: any) => Tensor. Takes the data from pack_hook and returns a tensor. Must reconstruct the original tensor from the packed data.
fn: () => T - Function to execute with hooks active. All tensor saves during this function use the hooks. Can be sync or async.
Returns
T – The result of the function, unmodified
Examples
// CPU offloading: save to CPU to reduce GPU memory
const result = torch.autograd.graph.saved_tensors_hooks(
(tensor) => {
// pack: move to CPU when saving
console.log(`Saving tensor of size ${tensor.numel()} to CPU`);
return tensor.cpu();
},
(tensor) => {
// unpack: move back to GPU for backward
console.log(`Restoring tensor to GPU for backward`);
return tensor.cuda();
},
() => {
// Forward pass with hooks active
const x = torch.randn(1000, 1000);
const y = model.forward(x); // Large tensors saved to CPU
y.backward(); // Tensors moved back to GPU during backward
return y;
}
);
// Compression: reduce memory by quantizing saved tensors
torch.autograd.graph.saved_tensors_hooks(
(tensor) => {
// Pack: quantize to int8 to save memory
const scale = tensor.abs().max();
const quantized = tensor.div(scale).mul(127).round().to('int8');
return { data: quantized, scale: scale };
},
(packed) => {
// Unpack: dequantize back to float
const dequant = packed.data.float().mul(packed.scale).div(127);
return dequant;
},
() => {
// Forward pass uses compressed tensors
const y = model.forward(x);
y.backward();
}
);
// Checkpointing: discard activations, recompute during backward
function checkpoint_segment(segment_fn, input) {
  // Simplified sketch: discard activations during forward,
  // recompute the segment during backward
  return torch.autograd.graph.saved_tensors_hooks(
    (tensor) => {
      // Pack: don't keep the activation; backward will recompute
      return null;
    },
    (data) => {
      // Unpack: recompute on demand during backward
      // (a full implementation must return the specific saved tensor,
      // not the whole segment output)
      return segment_fn(input);
    },
    () => {
      // The actual forward (would normally save all activations)
      return segment_fn(input);
    }
  );
}
// Profiling: log what tensors are saved
let total_saved_elements = 0;
torch.autograd.graph.saved_tensors_hooks(
(tensor) => {
total_saved_elements += tensor.numel();
console.log(`Saved tensor: shape=${tensor.shape}, dtype=${tensor.dtype}`);
return tensor;
},
(tensor) => tensor, // Identity unpack
() => {
const y = model.forward(x);
y.backward();
console.log(`Total elements saved: ${total_saved_elements}`);
}
);
// Nesting hooks: inner overrides outer
const outer_result = torch.autograd.graph.saved_tensors_hooks(
(t) => {
console.log('Outer pack');
return t.cpu();
},
(t) => {
console.log('Outer unpack');
return t.cuda();
},
() => {
const y1 = model.layer1(x); // Uses outer hooks (CPU)
// Inner hooks override for this scope
return torch.autograd.graph.saved_tensors_hooks(
(t) => {
console.log('Inner pack');
return t; // Keep on GPU
},
(t) => {
console.log('Inner unpack');
return t;
},
() => {
const y2 = model.layer2(y1); // Uses inner hooks (GPU)
return y2;
}
);
}
);
See Also
- PyTorch torch.autograd.graph.saved_tensors_hooks()
- torch.autograd.graph.disable_saved_tensors_hooks - Temporarily disable hooks
- torch.autograd.graph.save_on_cpu - Convenience hook for CPU offloading
- torch.autograd.graph.PackHook - Type definition for pack_hook
- torch.autograd.graph.UnpackHook - Type definition for unpack_hook