torch.js has not been released yet.
torch.js· 2026

torch.nn.BCELoss

class BCELoss extends Module
new BCELoss(options?: { weight?: Tensor; reduction?: Reduction })
readonly weight: Tensor | null
readonly reduction: Reduction

Binary Cross Entropy (BCE) Loss: standard loss for binary classification and multi-label classification.

Measures the binary cross entropy between predicted probabilities and targets (normally exactly 0 or 1). The input must already be sigmoid-transformed, so every value lies in [0, 1]. Typical uses:

  • Binary classification (predicting one output, two classes)
  • Multi-label classification (multiple independent binary decisions per sample)
  • Any task with sigmoid output layer
  • Pixel-wise classification (semantic segmentation masks)
  • Per-element binary decisions

Important: This expects sigmoid-transformed probabilities as input. If you have raw logits, use BCEWithLogitsLoss instead (more numerically stable).

When to use BCELoss:

  • Your model outputs probabilities (after sigmoid)
  • Binary or multi-label classification tasks
  • Custom sigmoid application before loss
  • When you explicitly want to separate sigmoid from loss computation

Trade-offs:

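  • vs BCEWithLogitsLoss: it accepts raw logits and is more numerically stable; prefer it when possible
  • vs CrossEntropyLoss: use BCE for multi-label targets (several positive labels per sample) and CrossEntropyLoss for single-label targets (exactly one class per sample)
  • Numerical stability: BCELoss takes log(pred) directly, so predictions that saturate to exactly 0 or 1 hit log(0); BCEWithLogitsLoss avoids this
  • Explicit sigmoid: BCELoss requires you to apply the sigmoid yourself; BCEWithLogitsLoss does it internally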
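The stability gap between the two losses already shows up in scalar arithmetic. The sketch below is plain TypeScript, not torch.js: it compares "sigmoid, then the BCE formula" against the standard fused rewrite max(x, 0) − x·y + log(1 + e^(−|x|)), which is algebraically equal. Whether BCEWithLogitsLoss in torch.js uses exactly this rewrite internally is an assumption, and the function names are hypothetical.

```typescript
// Logistic sigmoid: maps a raw logit to a probability in (0, 1).
const sigmoid = (x: number): number => 1 / (1 + Math.exp(-x));

// Two-step form: apply the sigmoid, then the BCE formula to the probability.
function bceFromProbability(x: number, y: number): number {
  const p = sigmoid(x);
  return -(y * Math.log(p) + (1 - y) * Math.log(1 - p));
}

// Fused form on the raw logit: max(x, 0) - x * y + log(1 + exp(-|x|)).
// Algebraically equal to the two-step form, but exp() only sees non-positive
// arguments, so it never overflows and log1p() never receives 0.
function bceFromLogit(x: number, y: number): number {
  return Math.max(x, 0) - x * y + Math.log1p(Math.exp(-Math.abs(x)));
}

// Moderate logit: both forms agree.
console.log(bceFromProbability(0.5, 1), bceFromLogit(0.5, 1));

// Saturated logit: sigmoid(40) rounds to exactly 1.0 in float64, so the
// two-step form evaluates log(0) and returns Infinity; the fused form gives ~40.
console.log(bceFromProbability(40, 0), bceFromLogit(40, 0));
```

The fused form never materializes the probability, which is why a logits-based loss can stay finite where a probability-based one cannot.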

Algorithm: for each element of the input:

  • loss_i = -(target_i * log(pred_i) + (1 - target_i) * log(1 - pred_i))
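A minimal reference sketch of the per-element formula, an optional per-element weight, and a reduction step, written in plain TypeScript over number arrays rather than tensors. The reduction names "none" | "mean" | "sum" and the "mean" default are assumed to mirror PyTorch; the function names are hypothetical, and the sketch does not guard against predictions of exactly 0 or 1.

```typescript
// Reduction modes, assumed here to mirror PyTorch's "none" | "mean" | "sum".
type Reduction = "none" | "mean" | "sum";

// Per-element loss: -(y * log(p) + (1 - y) * log(1 - p)), optionally scaled by w.
function bceElementwise(input: number[], target: number[], weight?: number[]): number[] {
  if (input.length !== target.length || (weight !== undefined && weight.length !== input.length)) {
    throw new Error("input, target and weight must have the same length");
  }
  return input.map((p, i) => {
    const y = target[i];
    const loss = -(y * Math.log(p) + (1 - y) * Math.log(1 - p));
    return weight !== undefined ? weight[i] * loss : loss;
  });
}

// Combine the per-element losses according to the reduction mode.
function bceLoss(
  input: number[],
  target: number[],
  reduction: Reduction = "mean",
  weight?: number[],
): number | number[] {
  const losses = bceElementwise(input, target, weight);
  if (reduction === "none") return losses;
  const total = losses.reduce((sum, l) => sum + l, 0);
  return reduction === "sum" ? total : total / losses.length;
}

const probs = [0.9, 0.2, 0.6];
const labels = [1, 0, 1];
console.log(bceLoss(probs, labels, "none")); // one loss per element
console.log(bceLoss(probs, labels, "mean")); // average of the element losses
console.log(bceLoss(probs, labels, "sum", [2, 1, 1])); // weighted total
```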

The formula is numerically unstable when a prediction is exactly 0 or 1, because log(0) diverges. If your model produces raw logits, use BCEWithLogitsLoss to avoid these issues.

$$
\begin{aligned}
L_i &= -\left(\text{target}_i \log(\text{input}_i) + (1 - \text{target}_i)\log(1 - \text{input}_i)\right) \\
\text{With weights: } L_i^{(w)} &= w_i \cdot L_i
\end{aligned}
$$
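A hand-worked instance of the formula, using the natural logarithm and two illustrative elements that both have target 1 (so the (1 − target) term vanishes). The second prediction is confidently wrong, and its loss is roughly 22 times larger than the first.

$$
\begin{aligned}
\text{target}_1 = 1,\ \text{input}_1 = 0.9:&\quad L_1 = -\log(0.9) \approx 0.1054 \\
\text{target}_2 = 1,\ \text{input}_2 = 0.1:&\quad L_2 = -\log(0.1) \approx 2.3026 \\
\text{average of } L_1, L_2:&\quad (0.1054 + 2.3026)/2 \approx 1.2040 \\
\text{weight } w_1 = 2:&\quad w_1 L_1 = 2 \times 0.1054 \approx 0.2107
\end{aligned}
$$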
  • Input requirements: Input must be probabilities in (0, 1), NOT raw logits
  • Use BCEWithLogitsLoss: If you have raw logits (more numerically stable)
  • Multi-label vs multi-class: BCE for multiple binary decisions, CrossEntropy for single class
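  • Numerical stability: avoid predictions of exactly 0 or 1; if needed, clamp them to [ε, 1 − ε] for a small ε
  • Gradient behavior: the loss grows without bound as a prediction approaches the wrong label, so confidently wrong predictions produce large losses and strong gradients
  • Common pattern: FC layer → sigmoid → BCE loss
  • Weight usage: the optional weight tensor rescales the loss of individual elements, for example to emphasize important samples or labels
  • Cost: linear in the total number of input elements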
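The epsilon-clamping advice in the notes can be sketched in plain TypeScript. The value eps = 1e-7 and the function names are illustrative assumptions, not torch.js behavior; this page does not say whether torch.js clamps internally. For comparison, PyTorch's BCELoss documents a different safeguard: it clamps the log outputs to be at least −100. Note that clamping also makes the gradient with respect to p zero wherever the clamp is active.

```typescript
// Clamp a probability into [eps, 1 - eps] so neither log term sees 0.
// eps = 1e-7 is an illustrative choice, not a value taken from torch.js.
function clampProbability(p: number, eps = 1e-7): number {
  return Math.min(Math.max(p, eps), 1 - eps);
}

// BCE for one element, computed on the clamped probability.
function bceClamped(p: number, y: number, eps = 1e-7): number {
  const q = clampProbability(p, eps);
  return -(y * Math.log(q) + (1 - y) * Math.log(1 - q));
}

console.log(bceClamped(0, 1)); // finite (about 16.12) instead of Infinity
console.log(bceClamped(1, 1)); // close to 0
```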

Examples

```typescript
// Binary classification with sigmoid output
const bce = new torch.nn.BCELoss();

// Predicted probabilities (must be in [0, 1])
const predictions = torch.sigmoid(torch.randn([32, 1]));

// Binary targets (0 or 1). PyTorch's BCELoss requires floating-point targets;
// if torch.js enforces the same, convert these integer samples to a float dtype first.
const targets = torch.randint(0, 2, [32, 1]);

// Compute loss
const loss = bce.forward(predictions, targets);
```
```typescript
// Multi-label classification: predicting multiple independent labels
class MultiLabelClassifier extends torch.nn.Module {
  fc1: torch.nn.Linear;
  fc2: torch.nn.Linear;
  sigmoid: torch.nn.Sigmoid;

  constructor(input_dim: number, num_labels: number) {
    super();
    this.fc1 = new torch.nn.Linear(input_dim, 128);
    this.fc2 = new torch.nn.Linear(128, num_labels);
    this.sigmoid = new torch.nn.Sigmoid();
  }

  forward(x: torch.Tensor): torch.Tensor {
    x = this.fc1.forward(x);
    x = torch.nn.functional.relu(x);
    x = this.fc2.forward(x);       // Raw logits
    x = this.sigmoid.forward(x);   // Convert to probabilities
    return x;                      // [batch, num_labels] in (0, 1)
  }
}

const model = new MultiLabelClassifier(100, 5);  // 5 independent labels
const bce = new torch.nn.BCELoss();

const batch_x = torch.randn([32, 100]);
const batch_y = torch.randint(0, 2, [32, 5]);  // Multi-hot encoded
const probs = model.forward(batch_x);
const loss = bce.forward(probs, batch_y);  // Each label is independent
```
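The multi-hot targets used above have one column per label, set to 1 for every label that applies to the sample. A small plain-TypeScript sketch of building such rows from lists of active label indices; `multiHot` is a hypothetical helper, not part of torch.js.

```typescript
// Build one multi-hot row: 1 at each active label index, 0 elsewhere.
function multiHot(activeLabels: number[], numLabels: number): number[] {
  const row = new Array<number>(numLabels).fill(0);
  for (const label of activeLabels) {
    if (!Number.isInteger(label) || label < 0 || label >= numLabels) {
      throw new RangeError(`label ${label} is out of range [0, ${numLabels})`);
    }
    row[label] = 1;
  }
  return row;
}

// A small batch: each sample may have zero, one, or several positive labels.
const batchLabels = [[0, 3], [], [1, 2, 4]];
const targetRows = batchLabels.map((labels) => multiHot(labels, 5));
// targetRows: [[1,0,0,1,0], [0,0,0,0,0], [0,1,1,0,1]]
```

Each row is an independent set of binary decisions, which is exactly the [batch, num_labels] layout BCELoss compares against the sigmoid outputs.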

See Also

  • PyTorch torch.nn.BCELoss