Skip to main content
torch.js has not been released yet.
torch.js logotorch.js logotorch.js
PlaygroundContact
Login
Documentation
IntroductionType SafetyTensor ExpressionsTensor IndexingEinsumEinopsAutogradTraining a ModelProfiling & MemoryPyTorch MigrationBest PracticesRuntimesPerformancePyTorch CompatibilityBenchmarksDType Coverage
ActivationOptionsAdaptiveAvgPool1dAdaptiveAvgPool2dAdaptiveAvgPool3dAdaptiveLogSoftmaxOptionsAdaptiveLogSoftmaxWithLossAdaptiveMaxPool1dAdaptiveMaxPool1dOptionsAdaptiveMaxPool2dAdaptiveMaxPool2dOptionsAdaptiveMaxPool3dAdaptiveMaxPool3dOptionsadd_moduleAlphaDropoutappendappendapplyAvgPool1dAvgPool1dOptionsAvgPool2dAvgPool2dOptionsAvgPool3dAvgPool3dOptionsBackwardHookBackwardPreHookBatchNorm1dBatchNorm2dBatchNorm3dBatchNormOptionsBCELossBCEWithLogitsLossBilinearBilinearOptionsBufferBufferOptionsBufferRegistrationHookbufferscallCELUCELUOptionsChannelShufflechildrenCircularPad1dCircularPad2dCircularPad3dclearConstantPad1dConstantPad2dConstantPad3dConv1dConv2dConv3dConvOptionsConvTranspose1dConvTranspose2dConvTranspose3dConvTransposeOptionsCosineEmbeddingLossCosineEmbeddingLossOptionsCosineSimilarityCosineSimilarityOptionscreatecreateCrossEntropyLossCTCLossdecodedecodedeleteDropoutDropout1dDropout2dDropout3dDropoutOptionsELUELUOptionsEmbeddingEmbeddingBagEmbeddingBagForwardOptionsEmbeddingBagFromPretrainedOptionsEmbeddingBagOptionsEmbeddingFromPretrainedOptionsEmbeddingOptionsencodeencodeentriesentriesevalextendFeatureAlphaDropoutFlattenFlattenOptionsFoldFoldOptionsforwardforwardforwardforwardforwardforwardforwardforwardforwardforwardforwardforwardforward_with_targetForwardHookForwardPreHookFractionalMaxPool2dFractionalMaxPool3dFractionalMaxPoolOptionsfrom_pretrainedfrom_pretrainedGaussianNLLLossGELUGELUOptionsgenerate_square_subsequent_maskgetgetgetgetgetget_bufferget_parameterget_submoduleGLUGLUOptionsGroupNormGroupNormOptionsGRUGRUCellHardshrinkHardshrinkOptionsHardsigmoidHardswishHardtanhHardtanhOptionshashasHingeEmbeddingLossHingeEmbeddingLossOptionsHuberLossHuberLossOptionsIdentityInstanceNorm1dInstanceNorm2dInstanceNorm3dInstanceNormOptionsis_uninitialized_bufferis_uninitialized_parameteriterator]iterator]iterator]iterator]keyskeysKLDivLossL1LossL1LossOptionsLayerNormLayerNormOptionsLazyBatchNorm1dLazyBatchNorm2dLazyBatchNorm3dLazyConv1dLazyConv2dLazyConv3dLazyCon
vOptionsLazyConvTranspose1dLazyConvTranspose2dLazyConvTranspose3dLazyConvTransposeOptionsLazyInstanceNorm1dLazyInstanceNorm2dLazyInstanceNorm3dLazyLinearLeakyReLULeakyReLUOptionsLinearLinearOptionsload_state_dictload_state_dictLocalResponseNormLocalResponseNormOptionslog_probLogSigmoidLogSoftmaxLogSoftmaxOptionsLPPool1dLPPool1dOptionsLPPool2dLPPool2dOptionsLPPool3dLPPool3dOptionsLSTMLSTMCellLSTMCellOptionsMarginRankingLossMarginRankingLossOptionsmaterializematerializematerialize_uninitializedmaterialize_uninitializedMaxPool1dMaxPool1dOptionsMaxPool2dMaxPool2dOptionsMaxPool3dMaxPool3dOptionsMaxUnpool1dMaxUnpool1dOptionsMaxUnpool2dMaxUnpool2dOptionsMaxUnpool3dMaxUnpool3dOptionsMishModuleModuleBuffersModuleChildrenModuleDictModuleDictOptionsModuleListModuleListOptionsModuleParametersModuleRegistrationHookmodulesMSELossMSELossOptionsmultihead_attnMultiheadAttentionMultiheadAttentionOptionsMultiheadAttnOptionsMultiLabelMarginLossMultiLabelMarginLossOptionsMultiLabelSoftMarginLossMultiMarginLossnamed_buffersnamed_childrennamed_modulesnamed_parametersNamedModulesOptionsNamedRecurseOptionsNLLLossnum_parametersNumParametersOptionsPairwiseDistancePairwiseDistanceOptionsParameterParameterDictParameterDictOptionsParameterListParameterListOptionsParameterOptionsParameterRegistrationHookparametersPixelShufflePixelUnshufflePoissonNLLLosspoppopPReLUPReLUOptionsRecurseOptionsReflectionPad1dReflectionPad2dReflectionPad3dregister_backward_hookregister_bufferregister_forward_hookregister_forward_pre_hookregister_full_backward_hookregister_full_backward_pre_hookregister_module_backward_hookregister_module_buffer_registration_hookregister_module_forward_hookregister_module_forward_pre_hookregister_module_full_backward_hookregister_module_full_backward_pre_hookregister_module_module_registration_hookregister_module_parameter_registration_hookregister_parameterReLUReLU6RemovableHandleremoveReplicationPad1dReplicationPad2dReplicationPad3dRMSNormRMSNormOptionsRNNRNNBaseRNNBaseOptionsRNNCellR
NNCellOptionsRReLURReLUOptionsrunrunSELUSequentialsetsetsetSigmoidSiLUSmoothL1LossSmoothL1LossOptionsSoftMarginLossSoftMarginLossOptionsSoftmaxSoftmax2dSoftmaxOptionsSoftminSoftminOptionsSoftplusSoftplusOptionsSoftshrinkSoftshrinkOptionsSoftsignstate_dictstate_dictStateDictOptionsstepSyncBatchNormTanhTanhshrinkThresholdThresholdOptionstotototrainTrainOptionsTransformerTransformerDecoderTransformerDecoderDecodeOptionsTransformerDecoderLayerTransformerDecoderLayerDecodeOptionsTransformerDecoderLayerOptionsTransformerDecoderOptionsTransformerEncoderTransformerEncoderEncodeOptionsTransformerEncoderLayerTransformerEncoderLayerEncodeOptionsTransformerEncoderLayerOptionsTransformerEncoderOptionsTransformerOptionsTransformerRunOptionsTripletMarginLossTripletMarginWithDistanceLossUnflattenUnfoldUnfoldOptionsUninitializedBufferUninitializedOptionsUninitializedParameterupdateUpsampleUpsamplingBilinear2dUpsamplingNearest2dvaluesvalueszero_gradZeroPad1dZeroPad2dZeroPad3d
absacosacoshAdaptivePool1dShapeAdaptivePool2dShapeaddaddbmmAddbmmOptionsaddcdivAddcdivOptionsaddcmulAddcmulOptionsaddmmAddmmOptionsaddmvAddmvOptionsaddrAddrOptionsadjointallallcloseAllcloseOptionsAlphaBetaOptionsamaxaminaminmaxAminmaxOptionsangleanyapplyOutarangeare_deterministic_algorithms_enabledargmaxargminargsortargwhereas_stridedas_tensorasinasinhAssertNoShapeErrorAssertNotErrorAsStridedOptionsAtat_error_index_out_of_boundsatanatan2atanhatleast_1datleast_2datleast_3dAtShapeautocast_decrement_nestingautocast_increment_nestingautograd_gradient_mismatch_errorautograd_not_registered_errorAutogradConfigAutogradDeviceAutogradDTypeAutogradEntryAutogradHandleAutogradHandleImplAxesRecordBackwardFnbaddbmmBaddbmmOptionsbartlett_windowBaseKernelConfigbatch_dimensions_do_not_match_errorbernoulliBernoulliOptionsBinaryBackwardFnBinaryBroadcastResultBinaryDTypeBinaryKernelConfigCPUBinaryKernelCPUBinaryOpConfigBinaryOpNamesBinaryOpSchemaBinaryOptionsbincountBincountOptionsbitwise_andbitwise_left_shiftbitwise_notbitwise_orbitwise_right_shiftbitwise_xorblackman_windowblock_diagbmmBooleanDTypeRulebroadcast_error_incompatible_dimensionsbroadcast_shapesbroadcast_tensorsbroadcast_toBroadcastShapeBroadcastShapeRulebroadcastShapesbucketizeBucketizeOptionsBufferUsagebuildEinopsErrorbuildErrorMessagecanBroadcastTocartesian_prodcatCatOptionsCatShapeCauchyOptionscdistCdistOptionsceilceluCeluFunctionalOptionschain_matmulCheckShapeErrorCholeskyShapechunkchunk_error_dim_out_of_rangeChunkOptionsclampClampOptionsclear_autocast_cacheclearEinopsCacheclearEinsumCacheclonecolumn_stackcombinationsCombinationsOptionscompiled_with_cxx11_abicomplexconjconj_physicalcontiguousConv1dShapeConv2dShapeConv3dShapeConvTranspose2dShapecopysigncorrcoefcoscoshcount_nonzeroCountNonzeroOptionscovcoverage_reportcoverageReportCoverageReportCovOptionsCPUForwardFnCPUKernelConfigCPUKernelEntryCPUOnlyResultCPUTensorDatacreateCumExtremeResultcreateTorchCreationOpSchemaCumExtremeResultcummaxcummincumprodCumShapecumsumcumul
ative_trapezoidCumulativeOptionsCumulativeOptionsWithDimdeg2raddetachDeterministicOptionsDetShapeDevicedevice_error_requiresDeviceBufferDeviceCapabilitiesDeviceCheckedResultDeviceConfigDeviceContextDeviceEntryDeviceHandleDeviceInputDeviceOptionsDeviceRegistryDeviceTypediagdiag_embedDiagEmbedOptionsdiagflatDiagflatOptionsDiagFlatOptionsdiagonal_scatterDiagonalOptionsDiagonalScatterOptionsDiagOptionsDiagShapediffDiffOptionsdigammadimension_error_out_of_rangeDispatchConfigdistDistOptionsdivdotDotShapeRuleDoubleDoubleDimdropoutDropoutFunctionalOptionsdsplitdstackDTypedtype_already_registered_errordtype_components_mismatch_errordtype_not_found_errorDTypeComponentsDTypeConfigDTypeCoverageReportDTypeDisplayConfigDTypeEntryDTypeHandleDTypeHandleImplDTypeInfoDTypeRegistryDTypeRuleDTypeSerializationConfigDynamicShapeEigShapeeinops_error_ambiguous_decompositioneinops_error_anonymous_in_outputeinops_error_dimension_mismatcheinops_error_invalid_patterneinops_error_reduce_undefined_outputeinops_error_repeat_missing_sizeeinops_error_undefined_axiseinsumeinsum_error_dimension_mismatcheinsum_error_index_out_of_rangeeinsum_error_invalid_equationeinsum_error_invalid_sublist_elementeinsum_error_operand_count_mismatcheinsum_error_subscript_rank_mismatcheinsum_error_unknown_output_indexEinsumOptionsEinsumOutputShapeEllipsiseluelu_EluFunctionalOptionsembedding_bag_error_requires_2d_inputemptyempty_cacheempty_likeeqequalerferfcerfinvexpexp2expandexpand_asexpand_error_incompatibleExpandShapeexpm1ExponentialOptionseyeEyeOptionsfftFFTOptionsfindKernelWithPredicatefindSimilarPatternsflattenFlattenOptionsFlattenShapeflipflip_error_dim_out_of_rangefliplrFlipShapeflipudfloat_powerFloatDTypeRulefloorfloor_dividefmaxfminfmodformatEquationErrorformatShapefracfrexpfrombufferfullfull_likefunction_already_registered_errorFunctionConfigFunctionEntryFunctionHandlegathergather_error_dim_out_of_rangeGatherShapegcdgegeluGeometricOptionsget_autocast_cpu_dtypeget_autocast_gpu_dtypeget_autocast_ipu_dtypeget_au
tocast_xla_dtypeget_default_deviceget_default_dtypeget_deterministic_debug_modeget_device_configget_device_contextget_device_moduleget_dtype_infoget_file_pathget_float32_matmul_precisionget_num_interop_threadsget_num_threadsget_op_infoget_printoptionsget_real_dtypeget_rng_stategetAutogradgetDTypegetEinopsCacheSizegetEinsumCacheSizegetFunctiongetKernelgetMethodgetOpInfoGetOpKindGetOpSchemagetScalarKernelgluGluFunctionalOptionsGradContextGradFnGradientsForgtHalfHalfDimhamming_windowhann_windowhardshrinkhardsigmoidhardswishhardtanhhardtanh_HardtanhFunctionalOptionshas_autogradhas_devicehas_dtypehas_kernelhasAutogradhasDTypehasFunctionhasKernelhasMethodhasScalarKernelHasShapeErrorheavisidehistcHistcOptionshistogramHistogramOptionsHistogramResulthsplithstackhypoti0IdentityShapeifftimagindex_addindex_copyindex_fillindex_putindex_reduceindex_selectindex_select_error_dim_out_of_rangeIndexPutOptionsIndexSelectShapeIndexSpecIndicesOptionsIndicesSpecinitialize_deviceInputsForInsertDiminvalid_config_errorinverseInverseShapeirfftis_anomaly_check_nan_enabledis_anomaly_enabledis_autocast_cache_enabledis_autocast_cpu_enabledis_autocast_ipu_enabledis_autocast_xla_enabledis_complexis_complex_dtypeis_cpu_only_modeis_deterministic_algorithms_warn_only_enabledis_floating_pointis_floating_point_dtypeis_inference_mode_enabledis_nonzerois_tensoris_warn_always_enabledis_webgpu_availableIs2DIsAtLeast1DIsBinaryOpIsBinaryOpNameiscloseIscloseOptionsisfiniteisinisinfisnanisneginfisposinfisrealIsReductionOpIsReductionOpNameIsRegistryErrorIsShapeErroristftISTFTOptionsIsUnaryOpIsUnaryOpNameitem_error_not_scalarItemResultkaiser_windowKaiserWindowOptionskernel_not_registered_errorkernel_signature_mismatch_errorKernelConfigKernelConfigWebGPUKernelEntryKernelHandleKernelInfoKernelPredicateKernelRegistryKernelWebGPUkronkthvalueKthvalueOptionslcmldexpleleaky_reluleaky_relu_LeakyReluFunctionalOptionslerplevenshteinDistancelgammalinalg_error_not_square_matrixlinalg_error_requires_2dlinalg_error_requires_at
_least_2dlinearlinspacelist_custom_deviceslist_custom_dtypeslist_deviceslist_dtypeslist_functionslist_kernelslist_methodslist_opslistCustomDTypeslistDTypeslistFunctionslistKernelsListKernelsOptionslistMethodslistOpsListOpsOptionsloglog_softmaxlog10log1plog2logaddexplogaddexp2logcumsumexplogical_andlogical_notlogical_orlogical_xorLogitOptionsLogNormalOptionsLogOptionslogsigmoidlogspacelogsumexpLogsumexpOptionsltLUShapeLuSolveOptionsmasked_fillmasked_selectmasked_select_asyncMaskSpecmatmulmatmul_error_inner_dimensions_do_not_matchMatmul2DShapeMatmulShapeMatmulShapeRuleMatrixTransposeShapemaxmaximummeanmedianmemory_statsmemory_summarymeshgridmethod_already_registered_errormethod_dtype_not_supported_errorMethodConfigMethodEntryMethodHandleminminimummishmmMMShapeRulemodemovedimmsortmulmultinomialmultinomial_asyncMultinomialAsyncOptionsMultinomialOptionsMultiplyBymvMVShapeRulenan_to_numnanmeannanmediannanquantileNanReductionOptionsnansumNanToNumOptionsnarrownarrow_copynarrow_error_length_exceeds_boundsnarrow_error_start_out_of_boundsNarrowShapeneneedsBroadcastnegNegativeDimnextafternonzeroNonzeroOptionsnormnormalNormalOptionsNormOptionsnumelonesones_likeop_kind_mismatch_errorop_not_found_errorOpCoverageEntryOpInfoOpKindOpNameOpSchemaOpSchemasouterOuterShapepackPackShapepermutepermute_error_dimension_count_mismatchPermuteShapepoissonpolarPool1dShapePool2dShapePool3dShapepositivepowpreluPrintOptionsprodprofiler_allow_cudagraph_cupti_lazy_reinit_cuda12promote_typesPromoteDTypeRulePutOptionsquantileQuantileOptionsrad2degrandrand_likerandintrandint_likeRandintLikeOptionsRandintOptionsrandnrandn_likeRandomLikeOptionsRandomOptionsrandpermRangeSpecRankravelrealrearrangeRearrangeOptionsRearrangeShapereciprocalreduceReduceOperationReduceOptionsReduceShapeReductionKernelConfigCPUReductionKernelCPUReductionOpNamesReductionOpSchemaReductionOptionsReductionShapeRuleregister_backwardregister_deviceregister_dtyperegister_forwardregister_functionregister_methodregister_scalar_forwardregis
terAutogradRegisterBackwardOptionsregisterBinaryOpregisterDTypeRegisterDTypeOptionsRegisteredDTyperegisterFunctionRegisterFunctionOptionsregisterKernelRegisterKernelOptionsregisterMethodRegisterMethodOptionsregisterScalarKernelregisterUnaryOpregistration_failed_errorrelurelu_relu6ReluFunctionalOptionsremainderRemoveDimrepeatrepeat_interleaveRepeatInterleaveOptionsRepeatOptionsRepeatShapeReplaceDimrequireWebGPUreset_peak_memory_statsreshapeReshapeShaperesult_typerfftrollRollOptionsrot90Rot90Optionsroundrrelurrelu_RreluFunctionalOptionsrsqrtSafeExpandShapeSameDTypeRuleSameShapeRuleSaveForBackwardScalarCPUForwardFnScalarCPUKernelConfigScalarKernelEntryScalarKernelHandleScalarWebGPUKernelConfigScaleDimscatterscatter_addscatter_add_scatter_error_dim_out_of_rangescatter_reducescatter_reduce_ScatterReduceOptionsScatterShapesearchsortedSearchSortedOptionsselectselect_error_index_out_of_boundsselect_scatterSelectShapeseluset_default_deviceset_default_tensor_typeset_deterministic_debug_modeset_float32_matmul_precisionset_printoptionsset_warn_alwaysSetupContextFnShapeShapeCheckedResultShapedTensorShapeErrorMessageShapeOpSchemaShapeRulesigmoidsignsignbitsilusinsincsinhSizeOptionsslice_error_out_of_boundsslice_scatterSliceOptionsSliceScatterOptionsSliceShapeSliceSpecsoftmaxsoftmax_error_dim_out_of_rangeSoftmaxShapesoftminSoftminFunctionalOptionssoftplusSoftplusFunctionalOptionssoftshrinksoftsignsortSortOptionssplitsplit_error_dim_out_of_rangeSplitOptionssqrtsquaresqueezeSqueezeOptionsSqueezeShapestackStackOptionsStackShapestdstd_meanStdVarMeanOptionsStdVarOptionsstftSTFTOptionsStrideOptionssubSublistSublistElementSubscriptIndexsumSVDShapeswapaxessym_floatsym_intsym_notttaketake_along_dimTakeAlongDimOptionstantanhtanhshrinktensortensor_splitTensorCreatorTensorDatatensordotTensordotOptionsTensorLikeTensorMetaTensorOptionsTensorStoragethresholdthreshold_tileTileShapeToOptionstopkTopkOptionsTorchtraceTraceShapetransposetranspose_dims_error_out_of_rangetranspose_error_requires_2d_ten
sorTransposeDimsShapeTransposeDimsShapeCheckedTransposeShapetrapezoidTrapezoidOptionsTriangularOptionstriltril_indicesTriOptionsTripletriutriu_indicestrue_dividetruncTupleOfLengthTypedArrayTypedArrayForTypedStorageTypeOptionsUnaryBackwardFnUnaryDTypeUnaryKernelConfigCPUUnaryKernelCPUUnaryOpConfigUnaryOpFnUnaryOpNamesUnaryOpParamsUnaryOpSchemaUnaryOptionsunbindunbind_error_dim_out_of_rangeUnbindOptionsunflattenUniformOptionsuniqueunique_consecutiveUniqueConsecutiveOptionsUniqueOptionsunpackUnpackShapeunravel_indexunregister_deviceunsqueezeUnsqueezeOptionsUnsqueezeShapeuse_deterministic_algorithmsValidateBatchedSquareMatrixValidateChunkDimValidatedEinsumShapevalidateDeviceValidateDeviceValidatedRearrangeShapeValidatedReduceShapeValidatedRepeatShapevalidateDTypeValidateEinsumValidateOperandCountValidateRanksValidateScalarValidateSplitDimValidateSquareMatrixValidateUnbindDimValueOptionsvar_var_meanvdotviewview_as_complexview_as_realvmapvsplitvstackWebGPUKernelConfigWebGPUOnlyResultWebGPUTensorDatawhereWindowOptionsxlogyzeroszeros_like
torch.js· 2026
LegalTerms of UsePrivacy Policy
/
/
  1. docs
  2. torch.js
  3. torch
  4. nn
  5. ModuleList

torch.nn.ModuleList

class ModuleList extends Module
new ModuleList(options?: ModuleListOptions)
readonly length (number)
– Get the number of modules in the list.

Holds submodules in a list with automatic parameter registration.

ModuleList stores modules in a container that behaves like a Python list. Unlike building a model by assigning modules to attributes (e.g., this.layer1 = new Linear(...)), ModuleList allows layer sequences to be constructed and modified dynamically. It is essential for:

  • Stacking variable numbers of layers (determined at runtime)
  • Sequential processing of layer sequences (encoder stacks, decoder stacks)
  • Conditional layer selection based on input or configuration
  • Dynamic model construction (pruning, adaptation, curriculum learning)
  • Iterating over layers and applying the same operation to all

Key difference from array of modules: ModuleList automatically registers submodules for proper parameter tracking, gradient computation, device movement, and training/eval mode propagation. Plain Python/TypeScript arrays won't track parameters correctly.

When to use ModuleList:

  • Variable-depth networks (number of layers determined at runtime)
  • Sequential stacks where each layer processes previous layer's output
  • Encoder/decoder with many similar layers (transformers, ResNets with many blocks)
  • Dynamic architectures that add/remove layers during execution
  • Simpler than ModuleDict when you don't need named access

When NOT to use:

  • Fixed layer counts - use attributes (this.fc1, this.fc2, etc.)
  • Complex interconnection patterns - might be clearer with attributes
  • Named module access preferred - use ModuleDict instead

Notes:

  • Parameter tracking: All modules in list are automatically tracked for parameters/buffers
  • Device movement: .to(device) applies to all submodules
  • Training mode: .train()/.eval() propagates to all submodules
  • Iteration: Use for...of to iterate over modules
  • No forward method: ModuleList does not define forward() - the containing module must iterate over the list and call each submodule itself
  • Index-based access: get(i) or index operator (if supported)
  • Dynamic append: Use append() to add modules during construction

Examples

// Build a ResNet-style stack with variable depth
/**
 * Example module that chains a configurable number of blocks held in a
 * ModuleList.
 *
 * NOTE(review): `ResNetBlock` is not defined in this snippet; it is assumed to
 * be a user-supplied module whose constructor accepts the feature count.
 */
class ResNetWithModuleList extends torch.nn.Module {
  layers: torch.nn.ModuleList;

  /**
   * @param num_blocks - Number of ResNetBlock instances appended to the list.
   * @param num_features - Value forwarded to every ResNetBlock constructor.
   */
  constructor(num_blocks: number, num_features: number) {
    super();
    this.layers = new torch.nn.ModuleList();
    // Count down instead of up; the number of appended blocks is the same.
    for (let remaining = num_blocks; remaining > 0; remaining--) {
      this.layers.append(new ResNetBlock(num_features));
    }
  }

  /**
   * Passes the input through each block yielded by iterating the list,
   * feeding every block the previous block's output.
   *
   * @param x - Input tensor.
   * @returns The output of the last block, or `x` itself when the list is empty.
   */
  forward(x: torch.Tensor): torch.Tensor {
    let out = x;
    for (const block of this.layers) {
      out = block.forward(out);
    }
    return out;
  }
}

const model = new ResNetWithModuleList(50, 64);  // 50 blocks
// Transformer encoder stack
/**
 * Example encoder: a ModuleList of TransformerEncoderLayer instances followed
 * by a single LayerNorm applied to the final output.
 */
class TransformerEncoder extends torch.nn.Module {
  layers: torch.nn.ModuleList;
  norm: torch.nn.LayerNorm;

  /**
   * @param num_layers - Number of encoder layers appended to the list.
   * @param d_model - Passed to every encoder layer and to the LayerNorm.
   * @param nhead - Passed to every encoder layer.
   */
  constructor(num_layers: number, d_model: number, nhead: number) {
    super();
    this.layers = new torch.nn.ModuleList();
    for (let remaining = num_layers; remaining > 0; remaining--) {
      // NOTE(review): the trailing positional arguments presumably mirror
      // PyTorch's dim_feedforward, dropout and activation - confirm against
      // the torch.js TransformerEncoderLayer signature.
      this.layers.append(
        new torch.nn.TransformerEncoderLayer(d_model, nhead, 2048, 0.1, 'relu')
      );
    }
    this.norm = new torch.nn.LayerNorm(d_model);
  }

  /**
   * Runs the input through each layer yielded by iterating the list, then
   * applies the final normalization.
   *
   * @param x - Input tensor.
   * @returns The normalized output of the last layer.
   */
  forward(x: torch.Tensor): torch.Tensor {
    let hidden = x;
    for (const encoderLayer of this.layers) {
      hidden = encoderLayer.forward(hidden);
    }
    return this.norm.forward(hidden);
  }
}
// MLP with dynamic width and depth
/**
 * Example multi-layer perceptron whose layer count and layer widths come from
 * an array supplied at construction time.
 */
class DynamicMLP extends torch.nn.Module {
  layers: torch.nn.ModuleList;

  /**
   * @param input_size - First argument of the first Linear layer.
   * @param hidden_sizes - One entry per Linear layer; each entry is that
   *   layer's second argument and the next layer's first argument.
   */
  constructor(input_size: number, hidden_sizes: number[]) {
    super();
    this.layers = new torch.nn.ModuleList();

    let fanIn = input_size;
    for (const fanOut of hidden_sizes) {
      this.layers.append(new torch.nn.Linear(fanIn, fanOut));
      // The next layer consumes what this layer produces.
      fanIn = fanOut;
    }
  }

  /**
   * Applies the layers by index, with a ReLU after every layer except the
   * last one.
   *
   * @param x - Input tensor.
   * @returns The output of the last layer, or `x` itself when there are no layers.
   */
  forward(x: torch.Tensor): torch.Tensor {
    const layerCount = this.layers.length;
    let out = x;
    for (let idx = 0; idx < layerCount; idx += 1) {
      out = this.layers.get(idx).forward(out);
      const isLastLayer = idx === layerCount - 1;
      if (!isLastLayer) {
        out = torch.nn.functional.relu(out);
      }
    }
    return out;
  }
}

const mlp = new DynamicMLP(784, [512, 256, 128, 10]);

See Also

  • PyTorch torch.nn.ModuleList
Previous
ModuleDictOptions
Next
ModuleList.[Symbol.iterator]