Skip to main content
torch.js has not been released yet.
torch.js logotorch.js logotorch.js
PlaygroundContact
Login
Documentation
IntroductionType SafetyTensor ExpressionsTensor IndexingEinsumEinopsAutogradTraining a ModelProfiling & MemoryPyTorch MigrationBest PracticesRuntimesPerformancePyTorch CompatibilityBenchmarksDType Coverage
AdadeltaAdadeltaOptionsAdafactorAdafactorOptionsAdagradAdagradOptionsAdamAdamaxAdamaxOptionsAdamOptionsAdamWAdamWOptionsadd_param_groupASGDASGDOptionsget_averaged_paramLBFGSLBFGSOptionsLBFGSStepOptionsload_state_dictLoadStateDictPostHookLoadStateDictPreHookMuonMuonOptionsNAdamNAdamOptionsOptimizerOptimizerStateOptimizerStateDictParamGroupRAdamRAdamOptionsregister_load_state_dict_post_hookregister_load_state_dict_pre_hookregister_state_dict_post_hookregister_state_dict_pre_hookregister_step_post_hookregister_step_pre_hookRemovableHandleRMSpropRMSpropOptionsRpropRpropOptionsSGDSGDOptionsSparseAdamSparseAdamOptionsstate_dictstate_dict_asyncStateDictPostHookStateDictPreHookstepstepStepPostHookStepPreHookstepWithClosurezero_grad
absacosacoshAdaptivePool1dShapeAdaptivePool2dShapeaddaddbmmAddbmmOptionsaddcdivAddcdivOptionsaddcmulAddcmulOptionsaddmmAddmmOptionsaddmvAddmvOptionsaddrAddrOptionsadjointallallcloseAllcloseOptionsAlphaBetaOptionsamaxaminaminmaxAminmaxOptionsangleanyapplyOutarangeare_deterministic_algorithms_enabledargmaxargminargsortargwhereas_stridedas_tensorasinasinhAssertNoShapeErrorAssertNotErrorAsStridedOptionsAtat_error_index_out_of_boundsatanatan2atanhatleast_1datleast_2datleast_3dAtShapeautocast_decrement_nestingautocast_increment_nestingautograd_gradient_mismatch_errorautograd_not_registered_errorAutogradConfigAutogradDeviceAutogradDTypeAutogradEntryAutogradHandleAutogradHandleImplAxesRecordBackwardFnbaddbmmBaddbmmOptionsbartlett_windowBaseKernelConfigbatch_dimensions_do_not_match_errorbernoulliBernoulliOptionsBinaryBackwardFnBinaryBroadcastResultBinaryDTypeBinaryKernelConfigCPUBinaryKernelCPUBinaryOpConfigBinaryOpNamesBinaryOpSchemaBinaryOptionsbincountBincountOptionsbitwise_andbitwise_left_shiftbitwise_notbitwise_orbitwise_right_shiftbitwise_xorblackman_windowblock_diagbmmBooleanDTypeRulebroadcast_error_incompatible_dimensionsbroadcast_shapesbroadcast_tensorsbroadcast_toBroadcastShapeBroadcastShapeRulebroadcastShapesbucketizeBucketizeOptionsBufferUsagebuildEinopsErrorbuildErrorMessagecanBroadcastTocartesian_prodcatCatOptionsCatShapeCauchyOptionscdistCdistOptionsceilceluCeluFunctionalOptionschain_matmulCheckShapeErrorCholeskyShapechunkchunk_error_dim_out_of_rangeChunkOptionsclampClampOptionsclear_autocast_cacheclearEinopsCacheclearEinsumCacheclonecolumn_stackcombinationsCombinationsOptionscompiled_with_cxx11_abicomplexconjconj_physicalcontiguousConv1dShapeConv2dShapeConv3dShapeConvTranspose2dShapecopysigncorrcoefcoscoshcount_nonzeroCountNonzeroOptionscovcoverage_reportcoverageReportCoverageReportCovOptionsCPUForwardFnCPUKernelConfigCPUKernelEntryCPUOnlyResultCPUTensorDatacreateCumExtremeResultcreateTorchCreationOpSchemaCumExtremeResultcummaxcummincumprodCumShapecumsumcumul
ative_trapezoidCumulativeOptionsCumulativeOptionsWithDimdeg2raddetachDeterministicOptionsDetShapeDevicedevice_error_requiresDeviceBufferDeviceCapabilitiesDeviceCheckedResultDeviceConfigDeviceContextDeviceEntryDeviceHandleDeviceInputDeviceOptionsDeviceRegistryDeviceTypediagdiag_embedDiagEmbedOptionsdiagflatDiagflatOptionsDiagFlatOptionsdiagonal_scatterDiagonalOptionsDiagonalScatterOptionsDiagOptionsDiagShapediffDiffOptionsdigammadimension_error_out_of_rangeDispatchConfigdistDistOptionsdivdotDotShapeRuleDoubleDoubleDimdropoutDropoutFunctionalOptionsdsplitdstackDTypedtype_already_registered_errordtype_components_mismatch_errordtype_not_found_errorDTypeComponentsDTypeConfigDTypeCoverageReportDTypeDisplayConfigDTypeEntryDTypeHandleDTypeHandleImplDTypeInfoDTypeRegistryDTypeRuleDTypeSerializationConfigDynamicShapeEigShapeeinops_error_ambiguous_decompositioneinops_error_anonymous_in_outputeinops_error_dimension_mismatcheinops_error_invalid_patterneinops_error_reduce_undefined_outputeinops_error_repeat_missing_sizeeinops_error_undefined_axiseinsumeinsum_error_dimension_mismatcheinsum_error_index_out_of_rangeeinsum_error_invalid_equationeinsum_error_invalid_sublist_elementeinsum_error_operand_count_mismatcheinsum_error_subscript_rank_mismatcheinsum_error_unknown_output_indexEinsumOptionsEinsumOutputShapeEllipsiseluelu_EluFunctionalOptionsembedding_bag_error_requires_2d_inputemptyempty_cacheempty_likeeqequalerferfcerfinvexpexp2expandexpand_asexpand_error_incompatibleExpandShapeexpm1ExponentialOptionseyeEyeOptionsfftFFTOptionsfindKernelWithPredicatefindSimilarPatternsflattenFlattenOptionsFlattenShapeflipflip_error_dim_out_of_rangefliplrFlipShapeflipudfloat_powerFloatDTypeRulefloorfloor_dividefmaxfminfmodformatEquationErrorformatShapefracfrexpfrombufferfullfull_likefunction_already_registered_errorFunctionConfigFunctionEntryFunctionHandlegathergather_error_dim_out_of_rangeGatherShapegcdgegeluGeometricOptionsget_autocast_cpu_dtypeget_autocast_gpu_dtypeget_autocast_ipu_dtypeget_au
tocast_xla_dtypeget_default_deviceget_default_dtypeget_deterministic_debug_modeget_device_configget_device_contextget_device_moduleget_dtype_infoget_file_pathget_float32_matmul_precisionget_num_interop_threadsget_num_threadsget_op_infoget_printoptionsget_real_dtypeget_rng_stategetAutogradgetDTypegetEinopsCacheSizegetEinsumCacheSizegetFunctiongetKernelgetMethodgetOpInfoGetOpKindGetOpSchemagetScalarKernelgluGluFunctionalOptionsGradContextGradFnGradientsForgtHalfHalfDimhamming_windowhann_windowhardshrinkhardsigmoidhardswishhardtanhhardtanh_HardtanhFunctionalOptionshas_autogradhas_devicehas_dtypehas_kernelhasAutogradhasDTypehasFunctionhasKernelhasMethodhasScalarKernelHasShapeErrorheavisidehistcHistcOptionshistogramHistogramOptionsHistogramResulthsplithstackhypoti0IdentityShapeifftimagindex_addindex_copyindex_fillindex_putindex_reduceindex_selectindex_select_error_dim_out_of_rangeIndexPutOptionsIndexSelectShapeIndexSpecIndicesOptionsIndicesSpecinitialize_deviceInputsForInsertDiminvalid_config_errorinverseInverseShapeirfftis_anomaly_check_nan_enabledis_anomaly_enabledis_autocast_cache_enabledis_autocast_cpu_enabledis_autocast_ipu_enabledis_autocast_xla_enabledis_complexis_complex_dtypeis_cpu_only_modeis_deterministic_algorithms_warn_only_enabledis_floating_pointis_floating_point_dtypeis_inference_mode_enabledis_nonzerois_tensoris_warn_always_enabledis_webgpu_availableIs2DIsAtLeast1DIsBinaryOpIsBinaryOpNameiscloseIscloseOptionsisfiniteisinisinfisnanisneginfisposinfisrealIsReductionOpIsReductionOpNameIsRegistryErrorIsShapeErroristftISTFTOptionsIsUnaryOpIsUnaryOpNameitem_error_not_scalarItemResultkaiser_windowKaiserWindowOptionskernel_not_registered_errorkernel_signature_mismatch_errorKernelConfigKernelConfigWebGPUKernelEntryKernelHandleKernelInfoKernelPredicateKernelRegistryKernelWebGPUkronkthvalueKthvalueOptionslcmldexpleleaky_reluleaky_relu_LeakyReluFunctionalOptionslerplevenshteinDistancelgammalinalg_error_not_square_matrixlinalg_error_requires_2dlinalg_error_requires_at
_least_2dlinearlinspacelist_custom_deviceslist_custom_dtypeslist_deviceslist_dtypeslist_functionslist_kernelslist_methodslist_opslistCustomDTypeslistDTypeslistFunctionslistKernelsListKernelsOptionslistMethodslistOpsListOpsOptionsloglog_softmaxlog10log1plog2logaddexplogaddexp2logcumsumexplogical_andlogical_notlogical_orlogical_xorLogitOptionsLogNormalOptionsLogOptionslogsigmoidlogspacelogsumexpLogsumexpOptionsltLUShapeLuSolveOptionsmasked_fillmasked_selectmasked_select_asyncMaskSpecmatmulmatmul_error_inner_dimensions_do_not_matchMatmul2DShapeMatmulShapeMatmulShapeRuleMatrixTransposeShapemaxmaximummeanmedianmemory_statsmemory_summarymeshgridmethod_already_registered_errormethod_dtype_not_supported_errorMethodConfigMethodEntryMethodHandleminminimummishmmMMShapeRulemodemovedimmsortmulmultinomialmultinomial_asyncMultinomialAsyncOptionsMultinomialOptionsMultiplyBymvMVShapeRulenan_to_numnanmeannanmediannanquantileNanReductionOptionsnansumNanToNumOptionsnarrownarrow_copynarrow_error_length_exceeds_boundsnarrow_error_start_out_of_boundsNarrowShapeneneedsBroadcastnegNegativeDimnextafternonzeroNonzeroOptionsnormnormalNormalOptionsNormOptionsnumelonesones_likeop_kind_mismatch_errorop_not_found_errorOpCoverageEntryOpInfoOpKindOpNameOpSchemaOpSchemasouterOuterShapepackPackShapepermutepermute_error_dimension_count_mismatchPermuteShapepoissonpolarPool1dShapePool2dShapePool3dShapepositivepowpreluPrintOptionsprodprofiler_allow_cudagraph_cupti_lazy_reinit_cuda12promote_typesPromoteDTypeRulePutOptionsquantileQuantileOptionsrad2degrandrand_likerandintrandint_likeRandintLikeOptionsRandintOptionsrandnrandn_likeRandomLikeOptionsRandomOptionsrandpermRangeSpecRankravelrealrearrangeRearrangeOptionsRearrangeShapereciprocalreduceReduceOperationReduceOptionsReduceShapeReductionKernelConfigCPUReductionKernelCPUReductionOpNamesReductionOpSchemaReductionOptionsReductionShapeRuleregister_backwardregister_deviceregister_dtyperegister_forwardregister_functionregister_methodregister_scalar_forwardregis
terAutogradRegisterBackwardOptionsregisterBinaryOpregisterDTypeRegisterDTypeOptionsRegisteredDTyperegisterFunctionRegisterFunctionOptionsregisterKernelRegisterKernelOptionsregisterMethodRegisterMethodOptionsregisterScalarKernelregisterUnaryOpregistration_failed_errorrelurelu_relu6ReluFunctionalOptionsremainderRemoveDimrepeatrepeat_interleaveRepeatInterleaveOptionsRepeatOptionsRepeatShapeReplaceDimrequireWebGPUreset_peak_memory_statsreshapeReshapeShaperesult_typerfftrollRollOptionsrot90Rot90Optionsroundrrelurrelu_RreluFunctionalOptionsrsqrtSafeExpandShapeSameDTypeRuleSameShapeRuleSaveForBackwardScalarCPUForwardFnScalarCPUKernelConfigScalarKernelEntryScalarKernelHandleScalarWebGPUKernelConfigScaleDimscatterscatter_addscatter_add_scatter_error_dim_out_of_rangescatter_reducescatter_reduce_ScatterReduceOptionsScatterShapesearchsortedSearchSortedOptionsselectselect_error_index_out_of_boundsselect_scatterSelectShapeseluset_default_deviceset_default_tensor_typeset_deterministic_debug_modeset_float32_matmul_precisionset_printoptionsset_warn_alwaysSetupContextFnShapeShapeCheckedResultShapedTensorShapeErrorMessageShapeOpSchemaShapeRulesigmoidsignsignbitsilusinsincsinhSizeOptionsslice_error_out_of_boundsslice_scatterSliceOptionsSliceScatterOptionsSliceShapeSliceSpecsoftmaxsoftmax_error_dim_out_of_rangeSoftmaxShapesoftminSoftminFunctionalOptionssoftplusSoftplusFunctionalOptionssoftshrinksoftsignsortSortOptionssplitsplit_error_dim_out_of_rangeSplitOptionssqrtsquaresqueezeSqueezeOptionsSqueezeShapestackStackOptionsStackShapestdstd_meanStdVarMeanOptionsStdVarOptionsstftSTFTOptionsStrideOptionssubSublistSublistElementSubscriptIndexsumSVDShapeswapaxessym_floatsym_intsym_notttaketake_along_dimTakeAlongDimOptionstantanhtanhshrinktensortensor_splitTensorCreatorTensorDatatensordotTensordotOptionsTensorLikeTensorMetaTensorOptionsTensorStoragethresholdthreshold_tileTileShapeToOptionstopkTopkOptionsTorchtraceTraceShapetransposetranspose_dims_error_out_of_rangetranspose_error_requires_2d_ten
sorTransposeDimsShapeTransposeDimsShapeCheckedTransposeShapetrapezoidTrapezoidOptionsTriangularOptionstriltril_indicesTriOptionsTripletriutriu_indicestrue_dividetruncTupleOfLengthTypedArrayTypedArrayForTypedStorageTypeOptionsUnaryBackwardFnUnaryDTypeUnaryKernelConfigCPUUnaryKernelCPUUnaryOpConfigUnaryOpFnUnaryOpNamesUnaryOpParamsUnaryOpSchemaUnaryOptionsunbindunbind_error_dim_out_of_rangeUnbindOptionsunflattenUniformOptionsuniqueunique_consecutiveUniqueConsecutiveOptionsUniqueOptionsunpackUnpackShapeunravel_indexunregister_deviceunsqueezeUnsqueezeOptionsUnsqueezeShapeuse_deterministic_algorithmsValidateBatchedSquareMatrixValidateChunkDimValidatedEinsumShapevalidateDeviceValidateDeviceValidatedRearrangeShapeValidatedReduceShapeValidatedRepeatShapevalidateDTypeValidateEinsumValidateOperandCountValidateRanksValidateScalarValidateSplitDimValidateSquareMatrixValidateUnbindDimValueOptionsvar_var_meanvdotviewview_as_complexview_as_realvmapvsplitvstackWebGPUKernelConfigWebGPUOnlyResultWebGPUTensorDatawhereWindowOptionsxlogyzeroszeros_like
torch.js· 2026
LegalTerms of UsePrivacy Policy
/
/
  1. docs
  2. torch.js
  3. torch
  4. optim
  5. SparseAdam

torch.optim.SparseAdam

class SparseAdam extends Optimizer
new SparseAdam(params: Tensor[] | Iterable<Tensor>, options: SparseAdamOptions = {})

SparseAdam optimizer: Adam variant optimized for parameters with sparse gradients.

SparseAdam is a modification of Adam designed for parameters with sparse gradients, where only a small fraction of positions have nonzero gradients at each step. A classic example is an embedding layer in NLP, where only a few words from the vocabulary appear in each batch.

Key Optimization:

  • Only updates state (first and second moments) for positions with nonzero gradients
  • Positions with zero gradient don't affect state (m_t, v_t unchanged)
  • This is crucial for embedding layers where each batch uses different embeddings
  • Dramatically reduces computation and memory access for sparse scenarios

Motivation: With standard Adam, computing m_t and v_t for every one of a million positions in an embedding layer is wasteful when only ~1000 are actually used. SparseAdam only touches the positions that were used, providing a massive speedup in sparse scenarios.

When to use SparseAdam:

  • Embedding layers (NLP, recommendation systems, graphs)
  • Any layer where gradient is sparse (most positions are zero)
  • Large embedding tables where sparse access is critical for performance
  • Recommendation systems with massive item embeddings
  • Graph neural networks with node embeddings
  • Any sparse feature scenario

Trade-offs:

  • Only works with sparse gradient patterns (don't use with dense layers)
  • Requires manually specifying which params are sparse
  • Dense layers should use standard Adam (SparseAdam adds masking overhead)
  • Mixing sparse and dense parameters needs careful setup

Implementation Note: The current implementation works with dense tensors but only updates state where gradients are nonzero. True sparse tensor support would be better for ultra-high-performance workloads, but this approach handles 90% of use cases efficiently.

$$
\begin{aligned}
\text{mask}_t &= (\nabla f(\mathbf{p}) \neq 0) \\
m_t &= \beta_1 m_{t-1} + (1 - \beta_1) \nabla f(\mathbf{p}) \cdot \text{mask}_t, \quad v_t = \beta_2 v_{t-1} + (1 - \beta_2) (\nabla f(\mathbf{p}))^2 \cdot \text{mask}_t \\
\hat{m}_t &= \frac{m_t}{1 - \beta_1^t}, \quad \hat{v}_t = \frac{v_t}{1 - \beta_2^t} \\
\mathbf{p} &\leftarrow \mathbf{p} - \alpha \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon} \cdot \text{mask}_t
\end{aligned}
$$
  • Sparse gradients only: Only beneficial if gradients are actually sparse (mostly zero).
  • Embedding layers: Primary use case - word embeddings, item embeddings, node embeddings.
  • Performance critical: Can provide 10-100x speedup for large embedding tables.
  • Don't mix: Use SparseAdam for sparse layers, Adam for dense layers.
  • Default lr: 1e-3, the same as Adam; it usually doesn't need adjustment.
  • Dense implementation: Current implementation uses dense tensors with masking.
  • True sparse tensors: Would be even faster, but dense tensors with masking are sufficient.
  • Batch size effects: More efficient with smaller batches (more sparsity).
  • Embedding specialization: SparseAdam is designed specifically for embedding patterns.
  • Not for general use: Only use if you have actual sparse gradient patterns.
  • Hyperparameter tuning: Same as Adam - lr, betas rarely need adjustment.

Examples

// SparseAdam for embedding layers (common in NLP)
const embedding = new torch.nn.Embedding(vocab_size, embedding_dim);
const sparse_adam = new torch.optim.SparseAdam(embedding.parameters(), { lr: 1e-3 });

for (const batch of train_loader) {
  // batch.token_ids: [batch_size, seq_len] - only ~batch_size*seq_len words used
  const embeddings = embedding.forward(batch.token_ids);
  const loss = model.loss(embeddings, batch.y);

  sparse_adam.zero_grad();
  // loss.backward();
  sparse_adam.step();  // Only updates embeddings that appeared
}
// Recommendation system with sparse item embeddings
const item_embeddings = new torch.nn.Embedding(num_items, embedding_dim);
const user_embeddings = new torch.nn.Embedding(num_users, embedding_dim);
const sparse_adam = new torch.optim.SparseAdam([
  item_embeddings.parameters(),
  user_embeddings.parameters()
], { lr: 1e-3 });

// With millions of items, only ~100 appear per batch
// SparseAdam only updates those 100, not all millions
// Graph neural network with node embeddings
const node_embeddings = new torch.nn.Embedding(num_nodes, embedding_dim);
const sparse_adam = new torch.optim.SparseAdam(
  node_embeddings.parameters(),
  { lr: 1e-3 }
);

// In subgraph batching, only ~1000 nodes per batch even with millions total
// SparseAdam handles this efficiently
// Multi-task learning with sparse task embeddings
const task_embeddings = new torch.nn.Embedding(num_tasks, task_dim);
const sparse_adam = new torch.optim.SparseAdam(
  task_embeddings.parameters(),
  { lr: 2e-3, betas: [0.9, 0.999] }
);

// Different tasks in each batch mean sparse gradient patterns
// Comparison: SparseAdam vs standard Adam for embeddings
const sparse_adam = new torch.optim.SparseAdam(
  embedding.parameters(), { lr: 1e-3 }
);
const adam = new torch.optim.Adam(
  embedding.parameters(), { lr: 1e-3 }
);

// SparseAdam: only updates used embeddings, ~100x faster
// Adam: updates all million embeddings every step, very slow
// Use SparseAdam for embedding layers, Adam for everything else

See Also

  • PyTorch torch.optim.SparseAdam
Previous
SGDOptions
Next
SparseAdamOptions