Skip to main content
torch.js
Getting StartedPlaygroundContact
Login
torch.js
Documentation
IntroductionType SafetyTensor IndexingEinsumEinopsAutogradTraining a ModelProfiling & MemoryPyTorch MigrationBest PracticesRuntimesPerformance
ActivationOptionsAdaptiveAvgPool1dAdaptiveAvgPool2dAdaptiveAvgPool3dAdaptiveLogSoftmaxWithLossAdaptiveMaxPool1dAdaptiveMaxPool2dAdaptiveMaxPool3dadd_moduleAlphaDropoutappendappendapplyAvgPool1dAvgPool1dOptionsAvgPool2dAvgPool2dOptionsAvgPool3dAvgPool3dOptionsBackwardHookBackwardPreHookBatchNorm1dBatchNorm2dBatchNorm3dBatchNormOptionsBCELossBCEWithLogitsLossBilinearBufferBufferOptionsBufferRegistrationHookbufferscallCELUCELUOptionsChannelShufflechildrenCircularPad1dCircularPad2dCircularPad3dclearConstantPad1dConstantPad2dConstantPad3dConv1dConv2dConv3dConvOptionsConvTranspose1dConvTranspose2dConvTranspose3dConvTransposeOptionsCosineEmbeddingLossCosineEmbeddingLossOptionsCosineSimilarityCosineSimilarityOptionscreatecreateCrossEntropyLossCTCLossdecodedecodedeleteDropoutDropout1dDropout2dDropout3dDropoutOptionsELUELUOptionsEmbeddingEmbeddingBagencodeencodeentriesentriesevalextendFeatureAlphaDropoutFlattenFlattenOptionsFoldFoldOptionsforwardforwardforwardforwardforwardforwardforwardforwardforward_with_targetForwardHookForwardPreHookFractionalMaxPool2dFractionalMaxPool3dfrom_pretrainedfrom_pretrainedGaussianNLLLossGELUGELUOptionsgenerate_square_subsequent_maskgetgetgetgetgetget_bufferget_parameterget_submoduleGLUGLUOptionsGroupNormGroupNormOptionsGRUGRUCellHardshrinkHardshrinkOptionsHardsigmoidHardswishHardtanhHardtanhOptionshashasHingeEmbeddingLossHingeEmbeddingLossOptionsHuberLossHuberLossOptionsIdentityInstanceNorm1dInstanceNorm2dInstanceNorm3dInstanceNormOptionsis_uninitialized_bufferis_uninitialized_parameteriterator]iterator]iterator]keyskeysKLDivLossL1LossL1LossOptionsLayerNormLayerNormOptionsLazyBatchNorm1dLazyBatchNorm2dLazyBatchNorm3dLazyConv1dLazyConv2dLazyConv3dLazyConvOptionsLazyConvTranspose1dLazyConvTranspose2dLazyConvTranspose3dLazyConvTransposeOptionsLazyInstanceNorm1dLazyInstanceNorm2dLazyInstanceNorm3dLazyLinearLeakyReLULeakyReLUOptionsLinearLinearOptionsload_state_dictLocalResponseNormLocalResponseNormOptionslog_probLogSigmoidLogSoftmaxLogSoftmaxOption
sLPPool1dLPPool1dOptionsLPPool2dLPPool2dOptionsLPPool3dLPPool3dOptionsLSTMLSTMCellLSTMCellOptionsMarginRankingLossMarginRankingLossOptionsmaterializematerializematerialize_uninitializedmaterialize_uninitializedMaxPool1dMaxPool1dOptionsMaxPool2dMaxPool2dOptionsMaxPool3dMaxPool3dOptionsMaxUnpool1dMaxUnpool1dOptionsMaxUnpool2dMaxUnpool2dOptionsMaxUnpool3dMaxUnpool3dOptionsMishModuleModuleBuffersModuleChildrenModuleDictModuleListModuleParametersModuleRegistrationHookmodulesMSELossMSELossOptionsmultihead_attnMultiheadAttentionMultiheadAttentionOptionsMultiLabelMarginLossMultiLabelMarginLossOptionsMultiLabelSoftMarginLossMultiMarginLossnamed_buffersnamed_childrennamed_modulesnamed_parametersNLLLossnum_parametersPairwiseDistancePairwiseDistanceOptionsParameterParameterDictParameterListParameterOptionsParameterRegistrationHookparametersPixelShufflePixelUnshufflePoissonNLLLosspopPReLUPReLUOptionsReflectionPad1dReflectionPad2dReflectionPad3dregister_backward_hookregister_bufferregister_forward_hookregister_forward_pre_hookregister_full_backward_hookregister_full_backward_pre_hookregister_module_backward_hookregister_module_buffer_registration_hookregister_module_forward_hookregister_module_forward_pre_hookregister_module_full_backward_hookregister_module_full_backward_pre_hookregister_module_module_registration_hookregister_module_parameter_registration_hookregister_parameterReLUReLU6RemovableHandleremoveReplicationPad1dReplicationPad2dReplicationPad3dRMSNormRNNRNNBaseRNNBaseOptionsRNNCellRNNCellOptionsRReLURReLUOptionsrunSELUSequentialsetsetsetSigmoidSiLUSmoothL1LossSmoothL1LossOptionsSoftMarginLossSoftMarginLossOptionsSoftmaxSoftmax2dSoftmaxOptionsSoftminSoftminOptionsSoftplusSoftplusOptionsSoftshrinkSoftshrinkOptionsSoftsignstate_dictSyncBatchNormTanhTanhshrinkThresholdThresholdOptionstotrainTransformerTransformerDecoderTransformerDecoderLayerTransformerDecoderLayerOptionsTransformerDecoderOptionsTransformerEncoderTransformerEncoderLayerTransformerEncoderLayerOptionsTransf
ormerEncoderOptionsTransformerOptionsTripletMarginLossTripletMarginWithDistanceLossUnflattenUnfoldUnfoldOptionsUninitializedBufferUninitializedOptionsUninitializedParameterupdateUpsampleUpsamplingBilinear2dUpsamplingNearest2dvaluesvalueszero_gradZeroPad1dZeroPad2dZeroPad3d
absacosacoshaddaddbmmAddbmmOptionsaddcdivAddcdivOptionsaddcmulAddcmulOptionsaddmmAddmmOptionsaddmvAddmvOptionsaddrAddrOptionsadjointallallcloseAllcloseOptionsamaxaminaminmaxangleanyapplyOutarangeare_deterministic_algorithms_enabledargmaxargminargsortargwhereas_stridedas_tensorasinasinhAssertNoShapeErrorAtat_error_index_out_of_boundsatanatan2atanhatleast_1datleast_2datleast_3dAtShapeautocast_decrement_nestingautocast_increment_nestingAxesRecordbaddbmmBaddbmmOptionsbatch_dimensions_do_not_match_errorbernoulliBinaryOptionsbincountbitwise_andbitwise_left_shiftbitwise_notbitwise_orbitwise_right_shiftbitwise_xorblock_diagbmmbroadcast_error_incompatible_dimensionsbroadcast_shapesbroadcast_tensorsbroadcast_toBroadcastShapebroadcastShapesbucketizecanBroadcastTocartesian_prodcatCatShapecdistceilchain_matmulCholeskyShapechunkchunk_error_dim_out_of_rangeclampclear_autocast_cacheclonecolumn_stackcombinationscompiled_with_cxx11_abicomplexconjconj_physicalcontiguouscopysigncorrcoefcoscoshcount_nonzerocovCPUTensorDatacreateTorchCumExtremeResultcummaxcummincumprodCumShapecumsumcumulative_trapezoidCumulativeOptionsdeg2raddetachDetShapeDeviceDeviceInputDeviceTypediagdiag_embeddiagflatdiagonal_scatterDiagShapediffdigammadimension_error_out_of_rangedistdivdotdsplitdstackDTypeDynamicShapeEigShapeeinops_error_ambiguous_decompositioneinops_error_anonymous_in_outputeinops_error_dimension_mismatcheinops_error_invalid_patterneinops_error_reduce_undefined_outputeinops_error_repeat_missing_sizeeinops_error_undefined_axiseinsumeinsum_error_dimension_mismatcheinsum_error_index_out_of_rangeeinsum_error_invalid_equationeinsum_error_invalid_sublist_elementeinsum_error_operand_count_mismatcheinsum_error_subscript_rank_mismatcheinsum_error_unknown_output_indexEinsumOutputShapeEllipsiseluembedding_bag_error_requires_2d_inputemptyempty_cacheempty_likeeqequalerferfcerfinvexpexp2expandexpand_asexpand_error_incompatibleExpandShapeexpm1eyeEyeOptionsflattenFlattenShapeflipflip_error_dim_out_of_rangefliplrFli
pShapeflipudfloat_powerfloorfloor_dividefmaxfminfmodfracfrexpfrombufferfullfull_likegathergather_error_dim_out_of_rangeGatherShapegcdgegeluget_autocast_cpu_dtypeget_autocast_gpu_dtypeget_autocast_ipu_dtypeget_autocast_xla_dtypeget_default_deviceget_default_dtypeget_deterministic_debug_modeget_device_moduleget_file_pathget_float32_matmul_precisionget_num_interop_threadsget_num_threadsget_printoptionsget_rng_stateGradFngthardsigmoidhardswishHasShapeErrorheavisidehistchistogramHistogramResulthsplithstackhypoti0imagindex_addindex_copyindex_fillindex_putindex_reduceindex_selectindex_select_error_dim_out_of_rangeIndexSelectShapeIndexSpecIndicesSpecinverseInverseShapeis_anomaly_check_nan_enabledis_anomaly_enabledis_autocast_cache_enabledis_autocast_cpu_enabledis_autocast_ipu_enabledis_autocast_xla_enabledis_complexis_complex_dtypeis_cpu_only_modeis_deterministic_algorithms_warn_only_enabledis_floating_pointis_floating_point_dtypeis_inference_mode_enabledis_nonzerois_tensoris_warn_always_enabledis_webgpu_availableIs2DIsAtLeast1DiscloseIscloseOptionsisfiniteisinisinfisnanisneginfisposinfisrealIsShapeErroritem_error_not_scalarItemResultkronkthvalueKthvalueOptionslcmldexpleleaky_relulerplgammalinalg_error_not_square_matrixlinalg_error_requires_2dlinalg_error_requires_at_least_2dlinspaceloglog10log1plog2logaddexplogaddexp2logcumsumexplogical_andlogical_notlogical_orlogical_xorlogitlogspacelogsumexpltLUShapemasked_selectmasked_select_asyncMaskSpecmatmulmatmul_error_inner_dimensions_do_not_matchMatmul2DShapeMatmulShapemaxmaximummeanmedianmemory_statsmemory_summarymeshgridminminimummmmodemovedimmsortmulmultinomialmultinomial_asyncmvnan_to_numnanmeannanmediannanquantilenansumnarrownarrow_copynarrow_error_length_exceeds_boundsnarrow_error_start_out_of_boundsNarrowShapeneneedsBroadcastnegNegativeDimnextafternonzeronormnormalNormOptionsnumelonesones_likeouterpackPackShapepermutepermute_error_dimension_count_mismatchPermuteShapepoissonpolarpositivepowPrintOptionsprodprofiler_allow_cuda
graph_cupti_lazy_reinit_cuda12promote_typesquantileQuantileOptionsrad2degrandrand_likerandintrandint_likerandnrandn_likerandpermRangeSpecRankravelrealRearrangeShapereciprocalreduceReduceOperationReduceShapeReductionOptionsreluremainderrepeatrepeat_interleaveRepeatInterleaveOptionsRepeatShaperequireWebGPUreset_peak_memory_statsreshapeReshapeShaperesult_typerollrot90roundrsqrtscatterscatter_addscatter_add_scatter_error_dim_out_of_rangescatter_reducescatter_reduce_ScatterShapesearchsortedselectselect_error_index_out_of_boundsselect_scatterSelectShapeseluset_default_deviceset_default_tensor_typeset_deterministic_debug_modeset_float32_matmul_precisionset_printoptionsset_warn_alwaysShapeShapedTensorsigmoidsignsignbitsilusinsincsinhslice_error_out_of_boundsslice_scatterSliceShapeSliceSpecsoftmax_error_dim_out_of_rangeSoftmaxShapesoftplussoftsignsortSortOptionssplitsplit_error_dim_out_of_rangesqrtsquaresqueezeSqueezeShapestackstdstd_meanStdVarOptionssubSublistSublistElementSubscriptIndexsumSVDShapeswapaxessym_floatsym_intsym_notttaketake_along_dimtantanhtensortensor_splitTensorCreatorTensorDatatensordotTensorOptionsTensorStoragetileTileShapetopkTopkOptionsTorchtraceTraceShapetransposetranspose_dims_error_out_of_rangetranspose_error_requires_2d_tensorTransposeDimsShapeTransposeDimsShapeCheckedTransposeShapetrapezoidtriltril_indicestriutriu_indicestruncTypedArrayTypedStorageUnaryOptionsunbindunbind_error_dim_out_of_rangeunflattenuniqueunique_consecutiveunpackUnpackShapeunravel_indexunsqueezeUnsqueezeShapeuse_deterministic_algorithmsValidateBatchedSquareMatrixValidateChunkDimValidatedEinsumShapevalidateDeviceValidatedRearrangeShapeValidatedReduceShapeValidatedRepeatShapevalidateDTypeValidateEinsumValidateOperandCountValidateRanksValidateScalarValidateSplitDimValidateSquareMatrixValidateUnbindDimvar_var_meanvdotviewview_as_complexview_as_realvmapvsplitvstackWebGPUTensorDatawherexlogyzeroszeros_like
torch.js· 2026
LegalTerms of UsePrivacy Policy
/
/
  1. docs
  2. torch.js
  3. torch
  4. nn
  5. LSTMCellOptions

torch.nn.LSTMCellOptions

Options for torch.nn.LSTMCell — a long short-term memory (LSTM) cell: a single-timestep recurrent unit with memory.

Maintains both a hidden state h_t and a cell state c_t. The cell state acts as long-term memory, while gates control what information flows in and out. Each step computes four gate activations — the input gate, forget gate, cell gate (the candidate values), and output gate — and then updates c_t and h_t. Useful for:

  • Learning long-range dependencies (often spanning 100+ timesteps)
  • Mitigating the vanishing-gradient problem of vanilla RNNs
  • Processing sequences where early tokens matter much later
  • Stable gradient flow through time
  • Fine-grained control over what's remembered vs forgotten

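Unlike RNNCell, which carries only a single hidden state, LSTMCell adds a cell state c_t that acts as protected memory. The forget gate decides what to discard from memory, the input gate decides what to add, and the output gate decides what to expose. Because c_t is updated additively (see the update equations below), gradients flow through it far more easily over many timesteps than through a vanilla RNN's repeatedly squashed hidden state.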

When to use LSTMCell:

  • Building custom architectures with LSTM-level stability
  • Implementing attention on LSTM hidden/cell states
  • Variable-length sequences with masking
  • Teacher forcing with LSTM memory
  • Fine-grained control over cell state for analysis/visualization
  • Bidirectional processing (separate forward/backward LSTM cells)

Trade-offs:

  • vs RNNCell: LSTM has a cell state for longer-range dependencies; RNNCell is simpler
  • vs GRUCell: LSTM is more expressive, with separate forget/input gates; GRU is more compact
  • Parameters: LSTM weight matrices have 4 × hidden_size rows (one block per gate) vs hidden_size rows for RNNCell, so roughly 4× the parameters and weight memory
  • Stability: LSTM is much more stable on long sequences (100+ steps)
  • Speed: LSTM is slower per step, but its better convergence is often worth the cost
  • Gradient flow: LSTM preserves gradients via the additive cell-state path; plain RNNCell gradients tend to vanish, while GRUCell mitigates this with its own gating

LSTM Gate Equations: four gates process the input x_t and the previous hidden state h_{t-1} (σ is the logistic sigmoid and @ is matrix multiplication):

  1. Input gate: i_t = σ(W_ii @ x_t + b_ii + W_hi @ h_{t-1} + b_hi)
  2. Forget gate: f_t = σ(W_if @ x_t + b_if + W_hf @ h_{t-1} + b_hf)
  3. Cell gate: g_t = tanh(W_ig @ x_t + b_ig + W_hg @ h_{t-1} + b_hg)
  4. Output gate: o_t = σ(W_io @ x_t + b_io + W_ho @ h_{t-1} + b_ho)

Then the cell and hidden states are updated (⊙ denotes element-wise multiplication):

  • c_t = f_t ⊙ c_{t-1} + i_t ⊙ g_t (forget old + add new)
  • h_t = o_t ⊙ tanh(c_t) (expose filtered cell state)

Definition

/**
 * Options accepted by the LSTM cell (presumably passed to the `torch.nn.LSTMCell`
 * constructor — confirm against the LSTMCell signature).
 *
 * Every field is optional; an omitted field falls back to its documented default.
 */
export interface LSTMCellOptions {
  /**
   * Whether to include bias terms (default: true).
   *
   * NOTE(review): when false, the cell's bias tensors are presumably left unset, so the
   * gate pre-activations have no b_i* / b_h* terms — confirm against LSTMCell.
   */
  bias?: boolean;
}
bias: boolean (optional)
– Whether to include bias terms (default: true)

Examples

// Process a sequence step by step, maintaining both hidden and cell state.
// The sizes are named once so the input shape, the zero states, and the loop
// bound cannot drift out of sync when one of them is changed.
const batch_size = 32;
const seq_len = 100;
const input_size = 10;
const hidden_size = 20;

const lstm_cell = new torch.nn.LSTMCell(input_size, hidden_size);

const x = torch.randn([batch_size, seq_len, input_size]);  // [batch, seq_len, input_size]
let h = torch.zeros([batch_size, hidden_size]);  // Initial hidden state
let c = torch.zeros([batch_size, hidden_size]);  // Initial cell state

// Process the sequence with a manual loop, one timestep at a time
const outputs: torch.Tensor[] = [];
for (let t = 0; t < seq_len; t++) {
  const x_t = x.select(1, t);  // Timestep t: [batch, input_size]
  [h, c] = lstm_cell.step(x_t, [h, c]);  // Step returns [h_t, c_t]
  outputs.push(h);
}

const output = torch.stack(outputs, 1);  // [batch, seq_len, hidden_size]
// Language model with LSTM: predicting the next token at every position
class LSTMLanguageModel extends torch.nn.Module {
  embedding: torch.nn.Embedding;
  lstm_cell: torch.nn.LSTMCell;
  output_proj: torch.nn.Linear;
  /** Size of the LSTM hidden/cell state; the LSTMCell is built with this hidden size. */
  readonly hidden_dim: number;

  /**
   * @param vocab_size - Number of tokens in the vocabulary (embedding rows and output logits).
   * @param embed_dim - Embedding width, which is also the LSTMCell input size.
   * @param hidden_dim - LSTMCell hidden size, used for the initial h/c states and the projection input.
   */
  constructor(vocab_size: number, embed_dim: number, hidden_dim: number) {
    super();
    this.hidden_dim = hidden_dim;
    this.embedding = new torch.nn.Embedding(vocab_size, embed_dim);
    this.lstm_cell = new torch.nn.LSTMCell(embed_dim, hidden_dim);
    this.output_proj = new torch.nn.Linear(hidden_dim, vocab_size);
  }

  /**
   * Runs the LSTM cell over every position and projects each hidden state to vocabulary logits.
   *
   * @param token_ids - Token indices, presumably [batch, seq_len].
   * @returns Logits stacked along dim 1: [batch, seq_len, vocab_size].
   */
  forward(token_ids: torch.Tensor): torch.Tensor {
    const embedded = this.embedding.forward(token_ids);  // [batch, seq_len, embed_dim]
    const batch_size = embedded.shape[0];

    // Initial states must be [batch, hidden_dim] to match the LSTMCell; a hard-coded
    // width would break every model whose hidden_dim differs from it.
    let h = torch.zeros([batch_size, this.hidden_dim]);
    let c = torch.zeros([batch_size, this.hidden_dim]);
    const logits: torch.Tensor[] = [];

    for (let t = 0; t < embedded.shape[1]; t++) {
      const x_t = embedded.select(1, t);  // [batch, embed_dim]
      [h, c] = this.lstm_cell.step(x_t, [h, c]);
      logits.push(this.output_proj.forward(h));
    }

    return torch.stack(logits, 1);  // [batch, seq_len, vocab_size]
  }
}
// Bidirectional LSTM: one cell reads the sequence forward, another reads it backward
const lstm_fwd = new torch.nn.LSTMCell(10, 20);
const lstm_bwd = new torch.nn.LSTMCell(10, 20);

const x = torch.randn([32, 100, 10]);  // [batch=32, seq_len=100, input_size=10]
// Derive the loop bounds from the input instead of repeating the literal sequence length.
const seq_len = x.shape[1];

// Forward pass: t = 0 .. seq_len - 1
let h_fwd = torch.zeros([32, 20]);
let c_fwd = torch.zeros([32, 20]);
const fwd_outputs: torch.Tensor[] = [];

for (let t = 0; t < seq_len; t++) {
  [h_fwd, c_fwd] = lstm_fwd.step(x.select(1, t), [h_fwd, c_fwd]);
  fwd_outputs.push(h_fwd);
}

// Backward pass: t = seq_len - 1 .. 0
let h_bwd = torch.zeros([32, 20]);
let c_bwd = torch.zeros([32, 20]);
const bwd_outputs: torch.Tensor[] = [];

for (let t = seq_len - 1; t >= 0; t--) {
  [h_bwd, c_bwd] = lstm_bwd.step(x.select(1, t), [h_bwd, c_bwd]);
  bwd_outputs.push(h_bwd);
}
// Collected in reverse time order; one reverse (instead of unshift on every step,
// which shifts the whole array each time) realigns bwd_outputs[t] with timestep t.
bwd_outputs.reverse();

// Concatenate forward and backward hidden states per timestep: [batch, 2 * 20]
const bidir_outputs: torch.Tensor[] = [];
for (let t = 0; t < seq_len; t++) {
  bidir_outputs.push(torch.cat([fwd_outputs[t], bwd_outputs[t]], -1));
}
// Analyzing LSTM internals: gate activations
const lstm = new torch.nn.LSTMCell(100, 256);  // input_size=100, hidden_size=256
const x = torch.randn([1, 100]);  // [batch=1, input_size=100]
const h = torch.zeros([1, 256]);  // Previous hidden state h_{t-1}

// Manually compute the gate pre-activations to visualize them:
// x @ W_ih^T + b_ih + h @ W_hh^T + b_hh, matching the gate equations above.
let gates_combined = x.matmul(lstm.weight_ih.t()).add(
  h.matmul(lstm.weight_hh.t())
);

// Add BOTH bias vectors: omitting bias_hh would shift every gate's pre-activation
// away from what the cell actually computes. Each add is guarded because the bias
// tensors may be unset (e.g. a cell built with bias: false).
if (lstm.bias_ih) {
  gates_combined = gates_combined.add(lstm.bias_ih);
}
if (lstm.bias_hh) {
  gates_combined = gates_combined.add(lstm.bias_hh);
}

// Split along dim 1 into the four gate blocks, in input/forget/cell/output order
const chunked = gates_combined.chunk(4, 1);
const i_gate = chunked[0].sigmoid();  // Input gate (should learn what's important)
const f_gate = chunked[1].sigmoid();  // Forget gate (should learn what to keep)
const g_gate = chunked[2].tanh();     // Cell candidate
const o_gate = chunked[3].sigmoid();  // Output gate

console.log('Input gate mean:', i_gate.mean().item());
console.log('Forget gate mean:', f_gate.mean().item());
Previous
LSTMCell
Next
MarginRankingLoss