Skip to main content
torch.js has not been released yet.
torch.js logotorch.js logotorch.js
PlaygroundContact
Login
Documentation
IntroductionType SafetyTensor ExpressionsTensor IndexingEinsumEinopsAutogradTraining a ModelProfiling & MemoryPyTorch MigrationBest PracticesRuntimesPerformancePyTorch CompatibilityBenchmarksDType Coverage
allow_mutation_on_saved_tensorsapplybackwardBackwardOptionsdetect_anomalydisable_saved_tensors_hooksemit_nvtxEmitNvtxContextEmitNvtxOptionsenable_gradexecuteBackwardexport_chrome_traceFunctionFunctionCtxgradgrad_modegradcheckGradcheckOptionsGradcheckResultgradgradcheckGradOptionshessianHessianOptionshvpHVPOptionsinference_modeis_anomaly_detection_enabledis_grad_enabledis_inference_mode_enabledjacobianJacobianOptionsjvpJVPOptionskey_averagesmark_dirtymark_non_differentiableno_gradNodeonce_differentiablePackHookprofileProfilerContextProfilerKeyAveragesProfilerOptionsProfilerTableOptionssave_for_backwardsave_on_cpusaved_tensors_hooksset_detect_anomalyset_grad_enabledset_materialize_gradsset_multithreading_enabledstartstoptableUnpackHookvhpVHPOptionsvjpVJPOptions
absacosacoshAdaptivePool1dShapeAdaptivePool2dShapeaddaddbmmAddbmmOptionsaddcdivAddcdivOptionsaddcmulAddcmulOptionsaddmmAddmmOptionsaddmvAddmvOptionsaddrAddrOptionsadjointallallcloseAllcloseOptionsAlphaBetaOptionsamaxaminaminmaxAminmaxOptionsangleanyapplyOutarangeare_deterministic_algorithms_enabledargmaxargminargsortargwhereas_stridedas_tensorasinasinhAssertNoShapeErrorAssertNotErrorAsStridedOptionsAtat_error_index_out_of_boundsatanatan2atanhatleast_1datleast_2datleast_3dAtShapeautocast_decrement_nestingautocast_increment_nestingautograd_gradient_mismatch_errorautograd_not_registered_errorAutogradConfigAutogradDeviceAutogradDTypeAutogradEntryAutogradHandleAutogradHandleImplAxesRecordBackwardFnbaddbmmBaddbmmOptionsbartlett_windowBaseKernelConfigbatch_dimensions_do_not_match_errorbernoulliBernoulliOptionsBinaryBackwardFnBinaryBroadcastResultBinaryDTypeBinaryKernelConfigCPUBinaryKernelCPUBinaryOpConfigBinaryOpNamesBinaryOpSchemaBinaryOptionsbincountBincountOptionsbitwise_andbitwise_left_shiftbitwise_notbitwise_orbitwise_right_shiftbitwise_xorblackman_windowblock_diagbmmBooleanDTypeRulebroadcast_error_incompatible_dimensionsbroadcast_shapesbroadcast_tensorsbroadcast_toBroadcastShapeBroadcastShapeRulebroadcastShapesbucketizeBucketizeOptionsBufferUsagebuildEinopsErrorbuildErrorMessagecanBroadcastTocartesian_prodcatCatOptionsCatShapeCauchyOptionscdistCdistOptionsceilceluCeluFunctionalOptionschain_matmulCheckShapeErrorCholeskyShapechunkchunk_error_dim_out_of_rangeChunkOptionsclampClampOptionsclear_autocast_cacheclearEinopsCacheclearEinsumCacheclonecolumn_stackcombinationsCombinationsOptionscompiled_with_cxx11_abicomplexconjconj_physicalcontiguousConv1dShapeConv2dShapeConv3dShapeConvTranspose2dShapecopysigncorrcoefcoscoshcount_nonzeroCountNonzeroOptionscovcoverage_reportcoverageReportCoverageReportCovOptionsCPUForwardFnCPUKernelConfigCPUKernelEntryCPUOnlyResultCPUTensorDatacreateCumExtremeResultcreateTorchCreationOpSchemaCumExtremeResultcummaxcummincumprodCumShapecumsumcumul
ative_trapezoidCumulativeOptionsCumulativeOptionsWithDimdeg2raddetachDeterministicOptionsDetShapeDevicedevice_error_requiresDeviceBufferDeviceCapabilitiesDeviceCheckedResultDeviceConfigDeviceContextDeviceEntryDeviceHandleDeviceInputDeviceOptionsDeviceRegistryDeviceTypediagdiag_embedDiagEmbedOptionsdiagflatDiagflatOptionsDiagFlatOptionsdiagonal_scatterDiagonalOptionsDiagonalScatterOptionsDiagOptionsDiagShapediffDiffOptionsdigammadimension_error_out_of_rangeDispatchConfigdistDistOptionsdivdotDotShapeRuleDoubleDoubleDimdropoutDropoutFunctionalOptionsdsplitdstackDTypedtype_already_registered_errordtype_components_mismatch_errordtype_not_found_errorDTypeComponentsDTypeConfigDTypeCoverageReportDTypeDisplayConfigDTypeEntryDTypeHandleDTypeHandleImplDTypeInfoDTypeRegistryDTypeRuleDTypeSerializationConfigDynamicShapeEigShapeeinops_error_ambiguous_decompositioneinops_error_anonymous_in_outputeinops_error_dimension_mismatcheinops_error_invalid_patterneinops_error_reduce_undefined_outputeinops_error_repeat_missing_sizeeinops_error_undefined_axiseinsumeinsum_error_dimension_mismatcheinsum_error_index_out_of_rangeeinsum_error_invalid_equationeinsum_error_invalid_sublist_elementeinsum_error_operand_count_mismatcheinsum_error_subscript_rank_mismatcheinsum_error_unknown_output_indexEinsumOptionsEinsumOutputShapeEllipsiseluelu_EluFunctionalOptionsembedding_bag_error_requires_2d_inputemptyempty_cacheempty_likeeqequalerferfcerfinvexpexp2expandexpand_asexpand_error_incompatibleExpandShapeexpm1ExponentialOptionseyeEyeOptionsfftFFTOptionsfindKernelWithPredicatefindSimilarPatternsflattenFlattenOptionsFlattenShapeflipflip_error_dim_out_of_rangefliplrFlipShapeflipudfloat_powerFloatDTypeRulefloorfloor_dividefmaxfminfmodformatEquationErrorformatShapefracfrexpfrombufferfullfull_likefunction_already_registered_errorFunctionConfigFunctionEntryFunctionHandlegathergather_error_dim_out_of_rangeGatherShapegcdgegeluGeometricOptionsget_autocast_cpu_dtypeget_autocast_gpu_dtypeget_autocast_ipu_dtypeget_au
tocast_xla_dtypeget_default_deviceget_default_dtypeget_deterministic_debug_modeget_device_configget_device_contextget_device_moduleget_dtype_infoget_file_pathget_float32_matmul_precisionget_num_interop_threadsget_num_threadsget_op_infoget_printoptionsget_real_dtypeget_rng_stategetAutogradgetDTypegetEinopsCacheSizegetEinsumCacheSizegetFunctiongetKernelgetMethodgetOpInfoGetOpKindGetOpSchemagetScalarKernelgluGluFunctionalOptionsGradContextGradFnGradientsForgtHalfHalfDimhamming_windowhann_windowhardshrinkhardsigmoidhardswishhardtanhhardtanh_HardtanhFunctionalOptionshas_autogradhas_devicehas_dtypehas_kernelhasAutogradhasDTypehasFunctionhasKernelhasMethodhasScalarKernelHasShapeErrorheavisidehistcHistcOptionshistogramHistogramOptionsHistogramResulthsplithstackhypoti0IdentityShapeifftimagindex_addindex_copyindex_fillindex_putindex_reduceindex_selectindex_select_error_dim_out_of_rangeIndexPutOptionsIndexSelectShapeIndexSpecIndicesOptionsIndicesSpecinitialize_deviceInputsForInsertDiminvalid_config_errorinverseInverseShapeirfftis_anomaly_check_nan_enabledis_anomaly_enabledis_autocast_cache_enabledis_autocast_cpu_enabledis_autocast_ipu_enabledis_autocast_xla_enabledis_complexis_complex_dtypeis_cpu_only_modeis_deterministic_algorithms_warn_only_enabledis_floating_pointis_floating_point_dtypeis_inference_mode_enabledis_nonzerois_tensoris_warn_always_enabledis_webgpu_availableIs2DIsAtLeast1DIsBinaryOpIsBinaryOpNameiscloseIscloseOptionsisfiniteisinisinfisnanisneginfisposinfisrealIsReductionOpIsReductionOpNameIsRegistryErrorIsShapeErroristftISTFTOptionsIsUnaryOpIsUnaryOpNameitem_error_not_scalarItemResultkaiser_windowKaiserWindowOptionskernel_not_registered_errorkernel_signature_mismatch_errorKernelConfigKernelConfigWebGPUKernelEntryKernelHandleKernelInfoKernelPredicateKernelRegistryKernelWebGPUkronkthvalueKthvalueOptionslcmldexpleleaky_reluleaky_relu_LeakyReluFunctionalOptionslerplevenshteinDistancelgammalinalg_error_not_square_matrixlinalg_error_requires_2dlinalg_error_requires_at
_least_2dlinearlinspacelist_custom_deviceslist_custom_dtypeslist_deviceslist_dtypeslist_functionslist_kernelslist_methodslist_opslistCustomDTypeslistDTypeslistFunctionslistKernelsListKernelsOptionslistMethodslistOpsListOpsOptionsloglog_softmaxlog10log1plog2logaddexplogaddexp2logcumsumexplogical_andlogical_notlogical_orlogical_xorLogitOptionsLogNormalOptionsLogOptionslogsigmoidlogspacelogsumexpLogsumexpOptionsltLUShapeLuSolveOptionsmasked_fillmasked_selectmasked_select_asyncMaskSpecmatmulmatmul_error_inner_dimensions_do_not_matchMatmul2DShapeMatmulShapeMatmulShapeRuleMatrixTransposeShapemaxmaximummeanmedianmemory_statsmemory_summarymeshgridmethod_already_registered_errormethod_dtype_not_supported_errorMethodConfigMethodEntryMethodHandleminminimummishmmMMShapeRulemodemovedimmsortmulmultinomialmultinomial_asyncMultinomialAsyncOptionsMultinomialOptionsMultiplyBymvMVShapeRulenan_to_numnanmeannanmediannanquantileNanReductionOptionsnansumNanToNumOptionsnarrownarrow_copynarrow_error_length_exceeds_boundsnarrow_error_start_out_of_boundsNarrowShapeneneedsBroadcastnegNegativeDimnextafternonzeroNonzeroOptionsnormnormalNormalOptionsNormOptionsnumelonesones_likeop_kind_mismatch_errorop_not_found_errorOpCoverageEntryOpInfoOpKindOpNameOpSchemaOpSchemasouterOuterShapepackPackShapepermutepermute_error_dimension_count_mismatchPermuteShapepoissonpolarPool1dShapePool2dShapePool3dShapepositivepowpreluPrintOptionsprodprofiler_allow_cudagraph_cupti_lazy_reinit_cuda12promote_typesPromoteDTypeRulePutOptionsquantileQuantileOptionsrad2degrandrand_likerandintrandint_likeRandintLikeOptionsRandintOptionsrandnrandn_likeRandomLikeOptionsRandomOptionsrandpermRangeSpecRankravelrealrearrangeRearrangeOptionsRearrangeShapereciprocalreduceReduceOperationReduceOptionsReduceShapeReductionKernelConfigCPUReductionKernelCPUReductionOpNamesReductionOpSchemaReductionOptionsReductionShapeRuleregister_backwardregister_deviceregister_dtyperegister_forwardregister_functionregister_methodregister_scalar_forwardregis
terAutogradRegisterBackwardOptionsregisterBinaryOpregisterDTypeRegisterDTypeOptionsRegisteredDTyperegisterFunctionRegisterFunctionOptionsregisterKernelRegisterKernelOptionsregisterMethodRegisterMethodOptionsregisterScalarKernelregisterUnaryOpregistration_failed_errorrelurelu_relu6ReluFunctionalOptionsremainderRemoveDimrepeatrepeat_interleaveRepeatInterleaveOptionsRepeatOptionsRepeatShapeReplaceDimrequireWebGPUreset_peak_memory_statsreshapeReshapeShaperesult_typerfftrollRollOptionsrot90Rot90Optionsroundrrelurrelu_RreluFunctionalOptionsrsqrtSafeExpandShapeSameDTypeRuleSameShapeRuleSaveForBackwardScalarCPUForwardFnScalarCPUKernelConfigScalarKernelEntryScalarKernelHandleScalarWebGPUKernelConfigScaleDimscatterscatter_addscatter_add_scatter_error_dim_out_of_rangescatter_reducescatter_reduce_ScatterReduceOptionsScatterShapesearchsortedSearchSortedOptionsselectselect_error_index_out_of_boundsselect_scatterSelectShapeseluset_default_deviceset_default_tensor_typeset_deterministic_debug_modeset_float32_matmul_precisionset_printoptionsset_warn_alwaysSetupContextFnShapeShapeCheckedResultShapedTensorShapeErrorMessageShapeOpSchemaShapeRulesigmoidsignsignbitsilusinsincsinhSizeOptionsslice_error_out_of_boundsslice_scatterSliceOptionsSliceScatterOptionsSliceShapeSliceSpecsoftmaxsoftmax_error_dim_out_of_rangeSoftmaxShapesoftminSoftminFunctionalOptionssoftplusSoftplusFunctionalOptionssoftshrinksoftsignsortSortOptionssplitsplit_error_dim_out_of_rangeSplitOptionssqrtsquaresqueezeSqueezeOptionsSqueezeShapestackStackOptionsStackShapestdstd_meanStdVarMeanOptionsStdVarOptionsstftSTFTOptionsStrideOptionssubSublistSublistElementSubscriptIndexsumSVDShapeswapaxessym_floatsym_intsym_notttaketake_along_dimTakeAlongDimOptionstantanhtanhshrinktensortensor_splitTensorCreatorTensorDatatensordotTensordotOptionsTensorLikeTensorMetaTensorOptionsTensorStoragethresholdthreshold_tileTileShapeToOptionstopkTopkOptionsTorchtraceTraceShapetransposetranspose_dims_error_out_of_rangetranspose_error_requires_2d_ten
sorTransposeDimsShapeTransposeDimsShapeCheckedTransposeShapetrapezoidTrapezoidOptionsTriangularOptionstriltril_indicesTriOptionsTripletriutriu_indicestrue_dividetruncTupleOfLengthTypedArrayTypedArrayForTypedStorageTypeOptionsUnaryBackwardFnUnaryDTypeUnaryKernelConfigCPUUnaryKernelCPUUnaryOpConfigUnaryOpFnUnaryOpNamesUnaryOpParamsUnaryOpSchemaUnaryOptionsunbindunbind_error_dim_out_of_rangeUnbindOptionsunflattenUniformOptionsuniqueunique_consecutiveUniqueConsecutiveOptionsUniqueOptionsunpackUnpackShapeunravel_indexunregister_deviceunsqueezeUnsqueezeOptionsUnsqueezeShapeuse_deterministic_algorithmsValidateBatchedSquareMatrixValidateChunkDimValidatedEinsumShapevalidateDeviceValidateDeviceValidatedRearrangeShapeValidatedReduceShapeValidatedRepeatShapevalidateDTypeValidateEinsumValidateOperandCountValidateRanksValidateScalarValidateSplitDimValidateSquareMatrixValidateUnbindDimValueOptionsvar_var_meanvdotviewview_as_complexview_as_realvmapvsplitvstackWebGPUKernelConfigWebGPUOnlyResultWebGPUTensorDatawhereWindowOptionsxlogyzeroszeros_like
torch.js· 2026
LegalTerms of UsePrivacy Policy
/
/
  1. docs
  2. torch.js
  3. torch
  4. autograd
  5. saved_tensors_hooks

torch.autograd.saved_tensors_hooks

function saved_tensors_hooks<T>(pack_hook: PackHook, unpack_hook: UnpackHook, fn: () => T): T

Context manager for registering custom hooks on tensors saved for backward.

Allows fine-grained control over how intermediate tensors are stored during forward pass. When operations save tensors for backward computation, these hooks intercept the save/restore. Enables memory optimization strategies like CPU offloading, compression, or checkpointing without modifying operation code. Essential for:

  • Memory optimization: save tensors on CPU instead of GPU (forward efficient, backward recomputes)
  • Compression: compress saved tensors to reduce memory footprint
  • Checkpoint strategies: implement gradient checkpointing to trade compute for memory
  • Custom storage: implement exotic storage backends (disk, network, etc.)
  • Profiling: monitor what tensors are saved and when

Hook Contract:

  • pack_hook: Called during forward, transforms tensor → arbitrary data
  • unpack_hook: Called during backward, transforms the packed data back into a tensor

The hooks must be inverses: unpack_hook(pack_hook(x)) ≈ x (within numerical error). If they don't match, backward will silently produce incorrect gradients or fail outright.

Memory Trade-offs: Hooks enable memory/compute trade-offs. For example, saving to CPU instead of GPU:

  • Forward: saves GPU memory (at the cost of slower tensor copies to CPU)
  • Backward: uses more compute (tensor copies back to GPU, possibly recomputation)

Nesting and Scope: Hooks are registered globally for the duration of the context. Any operation that saves tensors inside this context uses the hooks. Exiting the context restores previous hooks. Hooks can be nested: inner hooks temporarily override outer hooks.

  • Hook pairing: pack and unpack must be inverses (data is round-tripped)
  • Backward uses hooks: pack_hook called during forward, unpack_hook during backward
  • Global scope: Hooks apply to all operations saving tensors inside the function
  • Nesting: Inner hooks override outer hooks for their scope
  • Exception safety: Hooks are removed even if function throws exception
  • No state sharing: Each saved tensor is packed/unpacked independently
  • Performance: Hook overhead is per-saved-tensor; optimization must outweigh cost
  • Hook must invert: unpack_hook(pack_hook(x)) must equal x (approximately)
  • Hook errors surface late: errors thrown in hooks manifest as backward failures, not forward failures
  • Memory vs compute: Offloading trades GPU memory for compute/communication
  • Can be disabled: use disable_saved_tensors_hooks() to temporarily disallow saved-tensor hooks
  • Experimental: Hook behavior may differ from PyTorch edge cases

Parameters

pack_hookPackHook
Function called when a tensor is saved during forward. Signature: (tensor: Tensor) => any. Takes a tensor and returns arbitrary data (string, number, compressed data, etc.). Should transform the tensor for efficient storage.
unpack_hookUnpackHook
Function called when a saved tensor is needed during backward. Signature: (data: any) => Tensor. Takes the data from pack_hook and returns a tensor. Must reconstruct the original tensor from the packed data.
fn() => T
Function to execute with hooks active. All tensor saves during this function use the hooks. Can be sync or async.

Returns

T– The result of the function, unmodified

Examples

// CPU offloading: save to CPU to reduce GPU memory
const result = torch.autograd.graph.saved_tensors_hooks(
  (tensor) => {
    // pack: move to CPU when saving
    console.log(`Saving tensor of size ${tensor.numel()} to CPU`);
    return tensor.cpu();
  },
  (tensor) => {
    // unpack: move back to GPU for backward
    console.log(`Restoring tensor to GPU for backward`);
    return tensor.cuda();
  },
  () => {
    // Forward pass with hooks active
    const x = torch.randn(1000, 1000);
    const y = model.forward(x);  // Large tensors saved to CPU
    y.backward();  // Tensors moved back to GPU during backward
    return y;
  }
);
// Compression: reduce memory by quantizing saved tensors
torch.autograd.graph.saved_tensors_hooks(
  (tensor) => {
    // Pack: quantize to int8 to save memory
    const scale = tensor.abs().max();
    const quantized = (tensor.div(scale.mul(1.0 / 127))).round().to('int8');
    return { data: quantized, scale: scale };
  },
  (packed) => {
    // Unpack: dequantize back to float
    const dequant = packed.data.float().mul(packed.scale.mul(1.0 / 127));
    return dequant;
  },
  () => {
    // Forward pass uses compressed tensors
    const y = model.forward(x);
    y.backward();
  }
);
// Checkpointing: discard activations, recompute during backward
function checkpoint_segment(segment_fn, input) {
  // Forward: compute but don't save
  const output = torch.no_grad(() => {
    return segment_fn(input);
  });

  // During backward, recompute to get gradients
  return torch.autograd.graph.saved_tensors_hooks(
    (tensor) => {
      // Pack: save minimal info needed for recomputation
      return null;  // Don't save the tensor
    },
    (data) => {
      // Unpack: recompute on demand during backward
      return segment_fn(input);
    },
    () => {
      // The actual operation (would normally save all activations)
      return segment_fn(input);
    }
  );
}
// Profiling: log what tensors are saved
let total_saved_elements = 0;
torch.autograd.graph.saved_tensors_hooks(
  (tensor) => {
    total_saved_elements += tensor.numel();
    console.log(`Saved tensor: shape=${tensor.shape}, dtype=${tensor.dtype}`);
    return tensor;
  },
  (tensor) => tensor,  // Identity unpack
  () => {
    const y = model.forward(x);
    y.backward();
    console.log(`Total elements saved: ${total_saved_elements}`);
  }
);
// Nesting hooks: inner overrides outer
const outer_result = torch.autograd.graph.saved_tensors_hooks(
  (t) => {
    console.log('Outer pack');
    return t.cpu();
  },
  (t) => {
    console.log('Outer unpack');
    return t.cuda();
  },
  () => {
    const y1 = model.layer1(x);  // Uses outer hooks (CPU)

    // Inner hooks override for this scope
    return torch.autograd.graph.saved_tensors_hooks(
      (t) => {
        console.log('Inner pack');
        return t;  // Keep on GPU
      },
      (t) => {
        console.log('Inner unpack');
        return t;
      },
      () => {
        const y2 = model.layer2(y1);  // Uses inner hooks (GPU)
        return y2;
      }
    );
  }
);

See Also

  • PyTorch torch.autograd.graph.saved_tensors_hooks()
  • torch.autograd.graph.disable_saved_tensors_hooks - Temporarily disable hooks
  • torch.autograd.graph.save_on_cpu - Convenience hook for CPU offloading
  • torch.autograd.graph.PackHook - Type definition for pack_hook
  • torch.autograd.graph.UnpackHook - Type definition for unpack_hook
Previous
save_on_cpu
Next
set_detect_anomaly