torch.autograd.UnpackHook
export type UnpackHook = (data: any) => Tensor;

Type for an unpack hook function. It reconstructs a tensor during the backward pass from the data that the pack hook produced.
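Called during the backward pass when a previously packed tensor is needed to compute gradients. The unpack hook receives the data produced by pack_hook and reconstructs the original tensor from it. This split enables memory optimizations: pack_hook reduces or relocates what is kept between the forward and backward passes (for example by compressing or offloading it), and unpack_hook restores it when backward needs it.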
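The timing described above can be modeled without the library. This minimal sketch assumes a hypothetical `SavedSlot` class and uses plain `number[]` arrays as a stand-in for `Tensor` (and `unknown` in place of `any`); it only illustrates when each hook runs, not how saved tensors are actually implemented.

```typescript
// Stand-in for a tensor: a flat array of numbers (assumption; not the library's Tensor type).
type FakeTensor = number[];
type PackHook = (t: FakeTensor) => unknown;
type UnpackHook = (data: unknown) => FakeTensor;

// Hypothetical saved-value slot: packs eagerly when the value is saved in the forward pass
// and unpacks lazily when the backward pass asks for it.
class SavedSlot {
  private readonly packed: unknown;
  private readonly unpack: UnpackHook;

  constructor(t: FakeTensor, pack: PackHook, unpack: UnpackHook) {
    this.packed = pack(t); // forward pass: pack_hook runs immediately
    this.unpack = unpack;
  }

  get(): FakeTensor {
    return this.unpack(this.packed); // backward pass: unpack_hook runs on demand
  }
}

const events: string[] = [];
const pack: PackHook = (t) => {
  events.push("pack");
  return t.slice(); // keep a copy of the data
};
const unpack: UnpackHook = (data) => {
  events.push("unpack");
  return (data as FakeTensor).slice();
};

const slot = new SavedSlot([1, 2, 3], pack, unpack);
events.push("forward done");
const restored = slot.get(); // events: pack, forward done, unpack
```

The unpack hook does not run until the value is requested, which is why an expensive unpack (decompression, a device transfer, recomputation) costs time only during backward.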
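Contract: Takes the data returned by pack_hook and returns a reconstructed tensor. It must be the inverse of pack_hook: unpack_hook(pack_hook(tensor)) must equal the original tensor, up to any precision the packing scheme deliberately gives up (for example, rounding in quantization). Because an unpack hook only understands the format its own pack hook produces, the two must be written as a compatible pair.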
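One way to test the contract is to compare `unpack(pack(t))` with `t` element by element. The sketch below uses plain arrays in place of tensors, and `roundTripError` is a hypothetical helper, not a library function; with real tensors you would use an allclose-style comparison instead.

```typescript
type FakeTensor = number[];
type PackHook = (t: FakeTensor) => unknown;
type UnpackHook = (data: unknown) => FakeTensor;

// Largest absolute difference between t and unpack(pack(t)).
function roundTripError(pack: PackHook, unpack: UnpackHook, t: FakeTensor): number {
  const restored = unpack(pack(t));
  if (restored.length !== t.length) return Infinity; // a size mismatch always breaks the contract
  let worst = 0;
  for (let i = 0; i < t.length; i++) {
    worst = Math.max(worst, Math.abs(restored[i] - t[i]));
  }
  return worst;
}

// Compatible pair: pack doubles every value, unpack halves it again.
const packDouble: PackHook = (t) => t.map((x) => x * 2);
const unpackHalve: UnpackHook = (d) => (d as FakeTensor).map((x) => x / 2);

// Incompatible unpack: forgets to undo the doubling.
const unpackIdentity: UnpackHook = (d) => d as FakeTensor;

const sample = [0.5, -1.25, 3];
const goodErr = roundTripError(packDouble, unpackHalve, sample); // 0
const badErr = roundTripError(packDouble, unpackIdentity, sample); // 3
```

For a lossy scheme such as quantization, the check becomes `roundTripError(...) <= tolerance`, where the tolerance comes from the scheme itself (for example half a quantization step).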
Typical Uses:
- CPU offloading: (t) => t.cuda() moves the tensor from CPU back to GPU
- Decompression: (d) => decompress(d) decompresses a quantized tensor
- Recomputation: (metadata) => recompute_fn(metadata) recomputes the original tensor
- Transform inverse: (d) => transform_inverse(d) inverts the pack_hook transformation
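The decompression case can be illustrated with a uint8 quantization pair. This is a minimal sketch under stated assumptions: `Float32Array` stands in for a float32 tensor, and the `Quantized` layout, `quantizePack`, and `dequantizeUnpack` are hypothetical names chosen to mirror the dequantization example further down, not library APIs.

```typescript
// Packed form of a quantized tensor (hypothetical layout mirroring the dequantization example).
interface Quantized {
  quantized: Uint8Array;
  min: number;
  scale: number;
}

// pack: map the range [min, max] linearly onto the integers 0..255 (lossy).
function quantizePack(t: Float32Array): Quantized {
  let min = Infinity;
  let max = -Infinity;
  for (let i = 0; i < t.length; i++) {
    min = Math.min(min, t[i]);
    max = Math.max(max, t[i]);
  }
  // A constant tensor has max === min; use scale 1 so every code is 0 and no division by zero occurs.
  const scale = max > min ? (max - min) / 255 : 1;
  const quantized = new Uint8Array(t.length);
  for (let i = 0; i < t.length; i++) {
    quantized[i] = Math.round((t[i] - min) / scale);
  }
  return { quantized, min, scale };
}

// unpack: apply the inverse affine map, q * scale + min, back to float32.
function dequantizeUnpack(d: Quantized): Float32Array {
  const out = new Float32Array(d.quantized.length);
  for (let i = 0; i < out.length; i++) {
    out[i] = d.quantized[i] * d.scale + d.min;
  }
  return out;
}

const original = Float32Array.from([-1, -0.3, 0, 0.42, 2]);
const packed = quantizePack(original);
const restored = dequantizeUnpack(packed);

// Rounding to the nearest code bounds the error by about half a step (scale / 2).
let maxErr = 0;
for (let i = 0; i < original.length; i++) {
  maxErr = Math.max(maxErr, Math.abs(restored[i] - original[i]));
}
```

Storing one byte per element instead of four cuts the saved-tensor memory roughly by a factor of four, at the cost of the bounded rounding error that the contract then has to tolerate.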
Key Properties:
- Inverse operation: should undo what pack_hook did
- Deterministic: same input should produce same output
- Must reconstruct accurately: gradients depend on this
- Paired: works correctly with the corresponding pack_hook
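The pairing and determinism properties are easy to enforce when both hooks are created together and share private state. This sketch assumes a hypothetical `makeOffloadHooks` factory and uses a `Map` as a stand-in for CPU memory or disk; the packed data is only a handle into that store, so an unpack hook from a different pair cannot resolve it.

```typescript
type FakeTensor = number[];

// Hypothetical factory: creates a pack/unpack pair that share one private store,
// so the unpack hook can only resolve keys produced by its own pack hook.
function makeOffloadHooks() {
  const store = new Map<number, FakeTensor>(); // stand-in for CPU memory or disk
  let nextKey = 0;

  const pack = (t: FakeTensor): number => {
    const key = nextKey++;
    store.set(key, t.slice()); // "offload" a copy of the data
    return key; // the packed data is just a handle
  };

  const unpack = (key: unknown): FakeTensor => {
    const t = store.get(key as number);
    if (t === undefined) {
      throw new Error(`unknown key ${String(key)}: was it packed by this pair?`);
    }
    // Return a copy and leave the store intact, so repeated unpacking yields the same values.
    return t.slice();
  };

  return { pack, unpack };
}

const hooksA = makeOffloadHooks();
const hooksB = makeOffloadHooks();
const key = hooksA.pack([4, 5, 6]);
const restored = hooksA.unpack(key); // [4, 5, 6]
// hooksB.unpack(key) throws: the key belongs to hooksA's store.
```

Keeping entries after unpacking trades memory for safety; a real offloading scheme might instead free an entry once it is certain the value will not be needed again.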
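Error Handling: Mistakes in unpack_hook often surface as incorrect gradients rather than as exceptions: a tensor with the right shape but the wrong values passes through the backward pass silently. Careful testing with gradcheck() is essential to confirm that a pack/unpack pair works correctly.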
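The silent-failure mode can be shown on a toy function, using the same idea as PyTorch's gradcheck: compare the analytic gradient against a finite-difference estimate. This is a minimal sketch with plain arrays; `gradViaHooks`, `numericGrad`, and `gradsMatch` are hypothetical helpers, not library APIs. For f(x) = sum(x_i^2), backward needs the saved input, and the gradient is 2 * x_i.

```typescript
type FakeTensor = number[];
type PackHook = (t: FakeTensor) => unknown;
type UnpackHook = (data: unknown) => FakeTensor;

// f(x) = sum(x_i^2): forward saves x through the hooks, backward computes 2 * saved_i.
function gradViaHooks(x: FakeTensor, pack: PackHook, unpack: UnpackHook): FakeTensor {
  const packed = pack(x); // saved during forward
  const saved = unpack(packed); // reconstructed during backward
  return saved.map((s) => 2 * s);
}

// Central finite differences as an independent reference gradient.
function numericGrad(x: FakeTensor, eps = 1e-6): FakeTensor {
  const f = (v: FakeTensor) => v.reduce((acc, vi) => acc + vi * vi, 0);
  return x.map((_, i) => {
    const plus = x.slice();
    const minus = x.slice();
    plus[i] += eps;
    minus[i] -= eps;
    return (f(plus) - f(minus)) / (2 * eps);
  });
}

function gradsMatch(a: FakeTensor, b: FakeTensor, tol = 1e-4): boolean {
  return a.length === b.length && a.every((ai, i) => Math.abs(ai - b[i]) <= tol);
}

const packScaled: PackHook = (t) => t.map((v) => v * 10);
const goodUnpack: UnpackHook = (d) => (d as FakeTensor).map((v) => v / 10);
const buggyUnpack: UnpackHook = (d) => d as FakeTensor; // forgets to undo the scaling

const x = [0.5, -1, 2];
const reference = numericGrad(x);
const okGrad = gradViaHooks(x, packScaled, goodUnpack); // matches the reference
const badGrad = gradViaHooks(x, packScaled, buggyUnpack); // no exception, but 10x too large
```

Neither call throws; only the comparison with the finite-difference reference reveals that the buggy pair produces wrong gradients.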
Examples
// Simple CPU offloading hook (inverse of a pack_hook that moved the tensor to CPU)
const offloadUnpackHook: torch.autograd.graph.UnpackHook = (tensor) => {
  return tensor.cuda(); // Move tensor back to GPU
};

// Dequantization hook: decompress uint8 back to float32
const dequantizeUnpackHook: torch.autograd.graph.UnpackHook = (data) => {
  const { quantized, min, scale } = data; // Fields written by the matching pack hook
  return quantized.float().mul(scale).add(min); // Approximately restores the original values
};

// Identity hook: returns the packed tensor unchanged
const identityUnpackHook: torch.autograd.graph.UnpackHook = (tensor) => {
  return tensor; // Pairs with a pack hook that also returns its input unchanged
};

// Recomputation hook: recompute the tensor instead of storing it
const recomputeUnpackHook: torch.autograd.graph.UnpackHook = (metadata) => {
  // metadata holds whatever the pack hook recorded for recomputation;
  // recompute_layer is a placeholder for user code, not a library function
  return recompute_layer(metadata);
};
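The recomputation pattern trades memory for compute: the packed data is a small recipe instead of the activation itself. This sketch uses plain arrays for tensors, and `RecomputeRecipe`, `packForRecompute`, and `unpackByRecompute` are hypothetical names. Because a pack hook by itself only receives the tensor being saved, `packForRecompute` here takes the layer input and function explicitly; how a real pack hook would obtain them is left open.

```typescript
type FakeTensor = number[];

// Packed form for recomputation: the layer input plus a pure function,
// not the activation itself (hypothetical layout).
interface RecomputeRecipe {
  input: FakeTensor;
  fn: (x: FakeTensor) => FakeTensor;
}

// Example layer: ReLU. It is pure, so recomputing it reproduces the same output.
const relu = (x: FakeTensor): FakeTensor => x.map((v) => Math.max(0, v));

// pack side: record how to rebuild the activation instead of storing it.
function packForRecompute(input: FakeTensor, fn: (x: FakeTensor) => FakeTensor): RecomputeRecipe {
  return { input: input.slice(), fn };
}

// unpack side: rerun the layer to reconstruct the activation on demand.
function unpackByRecompute(data: unknown): FakeTensor {
  const recipe = data as RecomputeRecipe;
  return recipe.fn(recipe.input);
}

const input = [-2, 0.5, 3];
const stored = relu(input); // what would normally be saved
const recipe = packForRecompute(input, relu);
const rebuilt = unpackByRecompute(recipe); // same values as `stored`
```

This only satisfies the contract when `fn` is deterministic; a layer that draws fresh randomness on each call (for example a dropout mask) would rebuild a different tensor and corrupt the gradients.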