torch.nn.CosineEmbeddingLossOptions

Cosine Embedding Loss: metric learning loss for comparing embeddings via cosine similarity.

Measures a margin-based loss on the cosine similarity between two vectors. Designed for similarity-learning tasks where you want similar pairs to have high cosine similarity and dissimilar pairs to have low cosine similarity. Common in:

  • Siamese networks (comparing two inputs)
  • Face recognition (matching faces)
  • Person re-identification
  • Metric learning where relative similarity matters

When to use CosineEmbeddingLoss:

  • Siamese/Triplet networks (pair-wise comparison of embeddings)
  • Learning embeddings where cosine distance is the metric
  • Face recognition, person re-ID, similarity learning
  • When you want to push similar pairs together and dissimilar pairs apart
  • Learning representations where direction matters, not magnitude

Trade-offs:

  • vs TripletMarginLoss: Both are metric-learning losses; triplet requires mining 3 samples per example (anchor, positive, negative), while cosine embedding works on pairs (2 samples)
  • vs MSELoss: Cosine similarity is invariant to vector magnitude, which usually suits learned embeddings better than an elementwise squared error
  • Pair-wise: Only takes 2 embeddings vs triplet's 3 (anchor, positive, negative)
  • Margin interpretation: In cosine space (-1 to 1), margin controls separation

Algorithm: For each pair (x1, x2) with label y ∈ {1, -1}:

  • similarity = cosine_similarity(x1, x2) = dot(x1, x2) / (||x1|| * ||x2||)
  • If y == 1: loss = 1 - similarity (pull the pair together)
  • If y == -1: loss = max(0, similarity - margin) (push the pair apart once similarity exceeds the margin)

Negative target means dissimilar, positive means similar.
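
To make the per-pair formula concrete, here is a minimal reference sketch in plain TypeScript over number arrays. It is illustrative only, not the library's implementation, and the function name is hypothetical.

// Illustrative reference computation of the per-pair loss (not torch.js internals)
function cosineEmbeddingLossPair(
  x1: number[],
  x2: number[],
  y: 1 | -1,
  margin = 0,
): number {
  const dot = x1.reduce((acc, v, i) => acc + v * x2[i], 0);
  const norm1 = Math.sqrt(x1.reduce((acc, v) => acc + v * v, 0));
  const norm2 = Math.sqrt(x2.reduce((acc, v) => acc + v * v, 0));
  const similarity = dot / (norm1 * norm2);

  // y = 1: pull the pair together; y = -1: push apart beyond the margin
  return y === 1 ? 1 - similarity : Math.max(0, similarity - margin);
}

// cosineEmbeddingLossPair([1, 0], [1, 0], 1)  -> 0  (identical vectors, labelled similar)
// cosineEmbeddingLossPair([1, 0], [1, 0], -1) -> 1  (identical vectors, labelled dissimilar, margin 0)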

Definition

export interface CosineEmbeddingLossOptions {
  /** Margin for dissimilar pairs (default: 0) */
  margin?: number;
  /** How to reduce the loss ('none' | 'mean' | 'sum', default: 'mean') */
  reduction?: Reduction;
}
  • margin (number, optional) – Margin for dissimilar pairs (default: 0)
  • reduction (Reduction, optional) – How to reduce the loss ('none' | 'mean' | 'sum', default: 'mean')
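
A short sketch of passing these options, assuming the CosineEmbeddingLoss constructor also accepts a CosineEmbeddingLossOptions object (the examples below only show the positional margin form):

// Assumption: the constructor accepts a CosineEmbeddingLossOptions object
const per_pair_loss_fn = new torch.nn.CosineEmbeddingLoss({ margin: 0.25, reduction: 'none' });

const emb1 = torch.randn([8, 64]);
const emb2 = torch.randn([8, 64]);
const pair_target = torch.ones([8]);

// With reduction: 'none', forward returns one loss value per pair (shape [8])
// instead of a single averaged scalar.
const per_pair_loss = per_pair_loss_fn.forward(emb1, emb2, pair_target);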

Examples

// Siamese network: learn embeddings of paired inputs
const cosine_loss = new torch.nn.CosineEmbeddingLoss(0.0);

// Two embeddings from same class (similar)
const embedding1 = torch.randn([32, 128]);  // 32 pairs, 128-dim embeddings
const embedding2 = torch.randn([32, 128]);
const target = torch.ones([32]);  // 1 means similar

const loss = cosine_loss.forward(embedding1, embedding2, target);
// Encourages cosine_similarity(embedding1, embedding2) to be high

// With negative pairs (dissimilar)
const pos_emb1 = torch.randn([16, 128]);
const pos_emb2 = torch.randn([16, 128]);
const pos_target = torch.ones([16]);  // Similar (y=1)

const neg_emb1 = torch.randn([16, 128]);
const neg_emb2 = torch.randn([16, 128]);
const neg_target = torch.ones([16]).mul(-1);  // Dissimilar (y=-1)

// Concatenate into batch
const all_emb1 = torch.cat([pos_emb1, neg_emb1], 0);
const all_emb2 = torch.cat([pos_emb2, neg_emb2], 0);
const all_target = torch.cat([pos_target, neg_target], 0);

const cosine_loss = new torch.nn.CosineEmbeddingLoss(0.0);
const loss = cosine_loss.forward(all_emb1, all_emb2, all_target);
// Simultaneously pulls similar pairs together and pushes dissimilar apart

// Face recognition with margin
const face_extractor = new FaceFeatureExtractor();  // Outputs embeddings
const cosine_loss = new torch.nn.CosineEmbeddingLoss(0.25);  // Margin = 0.25

const face1 = face_extractor.forward(image1);  // [1, 512] embedding
const face2 = face_extractor.forward(image2);  // [1, 512] embedding
const is_same = torch.tensor([1]);  // 1 if same person, -1 if different

const loss = cosine_loss.forward(face1, face2, is_same);
// Same-person pairs are pulled toward similarity 1; different-person pairs
// are only penalized while their similarity exceeds the 0.25 margin

// Person re-identification: matching gallery to query
const query_embedding = torch.randn([256]);
const gallery_embeddings = torch.randn([1000, 256]);

// Expand query to match batch
const query_batch = query_embedding.unsqueeze(0).expand([1000, 256]);

// Assume first 100 gallery samples are same person (target=1), others different (target=-1)
const target = torch.cat([torch.ones([100]), torch.ones([900]).mul(-1)]);

const cosine_loss = new torch.nn.CosineEmbeddingLoss(0.1);
const loss = cosine_loss.forward(query_batch, gallery_embeddings, target);
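
At retrieval time you typically rank the gallery by raw cosine similarity rather than compute a loss. A sketch under the assumption that torch.nn.CosineSimilarity mirrors PyTorch and compares the two batches along dim 1:

// Assumption: CosineSimilarity defaults to comparing along dim 1, as in PyTorch
const similarity = new torch.nn.CosineSimilarity();

// One similarity score per gallery entry (shape [1000])
const scores = similarity.forward(query_batch, gallery_embeddings);

// Index of the best-matching gallery entry
const best_match = torch.argmax(scores);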