torch.autograd.PackHook
export type PackHook = (tensor: Tensor) => any;
Type for a pack hook function: transforms tensors during the forward pass for storage.
Called during the forward pass when an operation saves a tensor for later backward computation. The pack hook intercepts this save and can transform the tensor into a more memory-efficient representation. Common transformations include moving to CPU, compressing, or extracting just the metadata needed for recomputation.
Contract: takes a tensor and returns arbitrary data (which may itself be a Tensor). The returned value is stored and later passed to the corresponding unpack_hook. Every pack hook must be paired with a compatible unpack_hook: unpack_hook(pack_hook(tensor)) must reconstruct the original tensor, either exactly or within the error the chosen transformation allows (for example, quantization error).
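To make the contract concrete, the following minimal sketch models one saved-tensor slot: the forward pass keeps only what the pack hook returns, and the backward pass hands that value to the unpack hook. So that it runs standalone, it swaps the library's Tensor for a plain Float32Array; FakeTensor, HookPair, SavedTensorSlot, and copyHooks are hypothetical names, not part of this API, and the library's real save/restore machinery may differ.

```typescript
// Hypothetical stand-in for a tensor so the sketch runs without the library.
type FakeTensor = Float32Array;

// A pack/unpack pair: whatever pack returns is exactly what unpack receives.
interface HookPair<Packed> {
  pack: (t: FakeTensor) => Packed;
  unpack: (p: Packed) => FakeTensor;
}

// Models one saved-tensor slot: forward stores pack(t), backward calls unpack on it.
class SavedTensorSlot<Packed> {
  hooks: HookPair<Packed>;
  // Wrapped in an object so that a packed value of null (checkpointing) is still "saved".
  stored: { value: Packed } | null = null;

  constructor(hooks: HookPair<Packed>) {
    this.hooks = hooks;
  }

  save(t: FakeTensor): void {
    // Forward pass: only the packed form is kept alive.
    this.stored = { value: this.hooks.pack(t) };
  }

  restore(): FakeTensor {
    if (this.stored === null) throw new Error("no tensor has been saved");
    // Backward pass: rebuild the tensor from the packed form.
    return this.hooks.unpack(this.stored.value);
  }
}

// Example pair: pack copies the values into a plain number[] (standing in for
// "move to another device"), unpack copies them back into a Float32Array.
const copyHooks: HookPair<number[]> = {
  pack: (t) => {
    const out: number[] = [];
    for (let i = 0; i < t.length; i++) out.push(t[i]);
    return out;
  },
  unpack: (p) => new Float32Array(p),
};

const slot = new SavedTensorSlot(copyHooks);
const original = new Float32Array([1.5, -2, 3.25]);
slot.save(original);
const restored = slot.restore(); // a new Float32Array holding 1.5, -2, 3.25
```

The generic parameter ties the two hooks together at compile time: a pack hook that returns number[] can only be paired with an unpack hook that accepts number[].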
Typical Uses:
- CPU offloading: (t) => t.cpu(), to move the tensor to CPU and save GPU memory
- Compression: (t) => compress(t), to quantize or compress the tensor
- Metadata extraction: (t) => t.shape, to save only the metadata needed for recomputation
- Checkpointing: (t) => null, to save nothing and recompute the tensor later
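Because the unpack hook receives only what the pack hook returned, the checkpointing pattern (returning null) works only if the unpack hook can regenerate the tensor from state it already holds. The sketch below assumes the caller can supply such a recompute function, here a toy ReLU rerun from an input that is kept anyway; makeRecomputeHooks, relu, and the Float32Array stand-in for Tensor are all hypothetical.

```typescript
// Hypothetical stand-in for a tensor so the sketch runs without the library.
type FakeTensor = Float32Array;

// Hypothetical helper: pack stores nothing, unpack regenerates the tensor by calling
// a recompute thunk captured in the closure.
function makeRecomputeHooks(recompute: () => FakeTensor) {
  return {
    pack: (_t: FakeTensor): null => null, // drop the tensor; nothing is kept for backward
    unpack: (_p: null): FakeTensor => recompute(), // rebuild it when backward asks for it
  };
}

// Toy forward op: y = relu(x). x is assumed to be kept anyway; y is what we avoid storing.
function relu(x: FakeTensor): FakeTensor {
  const out = new Float32Array(x.length);
  for (let i = 0; i < x.length; i++) out[i] = Math.max(0, x[i]);
  return out;
}

const x = new Float32Array([-1, 0.5, 2]);
const y = relu(x);

let recomputeCount = 0;
const hooks = makeRecomputeHooks(() => {
  recomputeCount += 1; // counts how often the forward op is rerun
  return relu(x);
});

const packed = hooks.pack(y); // null: y itself is not retained
const yAgain = hooks.unpack(packed); // recomputed: 0, 0.5, 2
```

This trades compute for memory: y is never stored, and the forward op runs once more when backward needs the tensor.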
Key Properties:
- Deterministic: same input should produce same output
- Should be efficient: transformation cost shouldn't outweigh memory savings
- May lose information: can discard data, as long as unpack_hook can still reconstruct the tensor that backward needs
- Paired: must work correctly with the corresponding unpack_hook
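The "Deterministic" and "Paired" properties can be checked mechanically on a sample input. The following is a minimal sketch of such a check, assuming Float32Array as a stand-in for Tensor and an elementwise absolute tolerance; checkHookPair is a hypothetical helper, not part of the library.

```typescript
// Hypothetical stand-in for a tensor so the sketch runs without the library.
type FakeTensor = Float32Array;

// Hypothetical test helper: returns true if the pair is paired (unpack(pack(t)) stays within
// `tol` of t elementwise) and deterministic (packing the same input twice unpacks identically).
function checkHookPair<P>(
  pack: (t: FakeTensor) => P,
  unpack: (p: P) => FakeTensor,
  sample: FakeTensor,
  tol = 0,
): boolean {
  const first = unpack(pack(sample));
  const second = unpack(pack(sample));
  if (first.length !== sample.length || second.length !== sample.length) return false;
  for (let i = 0; i < sample.length; i++) {
    if (Math.abs(first[i] - sample[i]) > tol) return false; // paired: recovers the input
    if (first[i] !== second[i]) return false; // deterministic: same input, same output
  }
  return true;
}

const sample = new Float32Array([0.2, 1.7, -3]);

// Identity pair: exact round trip.
const identityOk = checkHookPair((t) => t, (p) => p, sample);

// Rounding pair: lossy, but every element stays within 0.5 of the input.
const roundingOk = checkHookPair((t) => t.map(Math.round), (p) => p, sample, 0.5);

// Broken pair: keeps only the length, so the values cannot be recovered.
const brokenOk = checkHookPair((t) => t.length, (n) => new Float32Array(n), sample);
```

The tolerance makes the "within error" part of the contract explicit: a lossy pair passes only if its worst-case error is declared up front.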
Examples
// Simple CPU offloading hook
const packHook: torch.autograd.graph.PackHook = (tensor) => {
  return tensor.cpu(); // Move tensor to CPU
};

// Quantization hook: compress float32 to uint8
const packHook: torch.autograd.graph.PackHook = (tensor) => {
  const min = tensor.min();
  const max = tensor.max();
  // TypeScript has no operator overloading, so tensor arithmetic needs method calls
  // (assumes PyTorch-style sub()/div() methods on Tensor)
  const scale = max.sub(min).div(255);
  const quantized = tensor.sub(min).div(scale).round().to('uint8');
  return { quantized, min, max, scale }; // Store quantization info
};

// Identity hook: don't change the tensor
const packHook: torch.autograd.graph.PackHook = (tensor) => {
  return tensor; // Save tensor unchanged
};

// Profiling hook: log saved tensors
const packHook: torch.autograd.graph.PackHook = (tensor) => {
  console.log(`Saving tensor: shape=${tensor.shape}, size=${tensor.storage.byteLength} bytes`);
  return tensor;
};
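The quantization example returns the packed record but does not show the unpack hook it must be paired with. As a minimal sketch of such a pair, the following uses plain Float32Array and Uint8Array in place of the library's Tensor so it runs standalone; packQuantized, unpackQuantized, and the Quantized record are hypothetical names.

```typescript
// Packed form of an 8-bit min-max quantized tensor (hypothetical shape, not the library's).
interface Quantized {
  q: Uint8Array; // codes in [0, 255]
  min: number; // value represented by code 0
  scale: number; // value step between adjacent codes
}

// Hypothetical pack hook: map [min, max] linearly onto the 256 uint8 codes.
function packQuantized(t: Float32Array): Quantized {
  let min = Infinity;
  let max = -Infinity;
  for (let i = 0; i < t.length; i++) {
    if (t[i] < min) min = t[i];
    if (t[i] > max) max = t[i];
  }
  // A constant tensor (max === min) would otherwise divide by zero; any positive scale works.
  const scale = max > min ? (max - min) / 255 : 1;
  const q = new Uint8Array(t.length);
  for (let i = 0; i < t.length; i++) {
    q[i] = Math.round((t[i] - min) / scale);
  }
  return { q, min, scale };
}

// Matching unpack hook: invert the affine map. Each element comes back within scale / 2
// of the original (plus float32 rounding), so the round trip is lossy but bounded.
function unpackQuantized(p: Quantized): Float32Array {
  const out = new Float32Array(p.q.length);
  for (let i = 0; i < p.q.length; i++) {
    out[i] = p.min + p.q[i] * p.scale;
  }
  return out;
}

const input = new Float32Array([-1, 0, 0.25, 3]);
const packed = packQuantized(input);
const restored = unpackQuantized(packed);
```

Once scale is known, max is not needed to dequantize, so this record omits it. The constant-tensor guard covers the case where (max - min) / 255 would be zero.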