torch.nn.CosineEmbeddingLossOptions
Cosine Embedding Loss: metric learning loss for comparing embeddings via cosine similarity.
Compares two vectors by their cosine similarity, applying a margin to dissimilar pairs. Designed for similarity learning tasks where you want similar pairs to have high cosine similarity and dissimilar pairs to have low cosine similarity. Common in:
- Siamese networks (comparing two inputs)
- Face recognition (matching faces)
- Person re-identification
- Metric learning where relative similarity matters
When to use CosineEmbeddingLoss:
- Siamese/Triplet networks (pair-wise comparison of embeddings)
- Learning embeddings where cosine distance is the metric
- Face recognition, person re-ID, similarity learning
- When you want to push similar pairs together and dissimilar pairs apart
- Learning representations where direction matters, not magnitude
Trade-offs:
- vs TripletMarginLoss: Both are metric-learning losses; triplet loss needs 3 samples per example and usually some form of hard-negative mining, while cosine embedding loss works on pairs
- vs MSELoss: Cosine similarity is invariant to magnitude, which suits embedding comparison better (see the sketch after this list)
- Pair-wise: Only takes 2 embeddings vs triplet's 3 (anchor, positive, negative)
- Margin interpretation: The margin lives in cosine-similarity space (-1 to 1); a dissimilar pair contributes no loss once its similarity falls below the margin, and values of 0 to 0.5 are typical
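The magnitude-invariance point in the MSELoss comparison can be seen with two vectors that differ only in scale. A minimal sketch in plain TypeScript; the cosineSimilarity helper is hypothetical and written here only for illustration, not part of the library:
// Hypothetical helper, illustration only: cosine similarity of two plain arrays
function cosineSimilarity(a: number[], b: number[]): number {
  const dot = a.reduce((sum, ai, i) => sum + ai * b[i], 0);
  const normA = Math.sqrt(a.reduce((sum, ai) => sum + ai * ai, 0));
  const normB = Math.sqrt(b.reduce((sum, bi) => sum + bi * bi, 0));
  return dot / (normA * normB);
}
const u = [1, 2, 3];
const v = [2, 4, 6]; // v = 2 * u: same direction, twice the magnitude
cosineSimilarity(u, v); // 1.0: scaling does not change cosine similarity
// A squared-error distance between u and v would be nonzero, penalizing the magnitude gap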
Algorithm: For each pair (x1, x2) with label y ∈ {1, -1}:
- similarity = cosine_similarity(x1, x2) = dot(x1, x2) / (||x1|| * ||x2||)
- If y == 1: loss = 1 - similarity (push together)
- If y == -1: loss = max(0, similarity - margin) (push apart)
Negative target means dissimilar, positive means similar.
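For a single pair the rule above reduces to a few lines. A minimal sketch in plain TypeScript; pairLoss is a hypothetical helper shown only to make the two branches concrete:
// Hypothetical per-pair loss given a precomputed cosine similarity (illustration only)
function pairLoss(similarity: number, y: 1 | -1, margin = 0): number {
  // y = 1: drive similarity toward 1; y = -1: push similarity below the margin
  return y === 1 ? 1 - similarity : Math.max(0, similarity - margin);
}
pairLoss(0.9, 1);        // 0.1: similar pair that is already close
pairLoss(0.9, -1, 0.25); // 0.65: dissimilar pair well above the margin
pairLoss(0.1, -1, 0.25); // 0: dissimilar pair already below the margin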
Definition
export interface CosineEmbeddingLossOptions {
/** Margin for dissimilar pairs (default: 0) */
margin?: number;
/** How to reduce the loss ('none' | 'mean' | 'sum', default: 'mean') */
reduction?: Reduction;
}

margin (number, optional) – Margin for dissimilar pairs (default: 0)
reduction (Reduction, optional) – How to reduce the loss ('none' | 'mean' | 'sum', default: 'mean')
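The examples below pass the margin positionally. Assuming the constructor also accepts a CosineEmbeddingLossOptions object (an assumption based on the interface above, not confirmed here), per-pair losses could be kept by disabling reduction:
// Assumption: the constructor accepts an options object matching CosineEmbeddingLossOptions
const per_pair = new torch.nn.CosineEmbeddingLoss({ margin: 0.25, reduction: 'none' });
// forward(x1, x2, target) would then return one loss value per pair instead of a single scalar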
Examples
// Siamese network: learn embeddings of paired inputs
const cosine_loss = new torch.nn.CosineEmbeddingLoss(0.0);
// Two embeddings from same class (similar)
const embedding1 = torch.randn([32, 128]); // 32 pairs, 128-dim embeddings
const embedding2 = torch.randn([32, 128]);
const target = torch.ones([32]); // 1 means similar
const loss = cosine_loss.forward(embedding1, embedding2, target);
// Encourages cosine_similarity(embedding1, embedding2) to be high

// With negative pairs (dissimilar)
const pos_emb1 = torch.randn([16, 128]);
const pos_emb2 = torch.randn([16, 128]);
const pos_target = torch.ones([16]); // Similar (y=1)
const neg_emb1 = torch.randn([16, 128]);
const neg_emb2 = torch.randn([16, 128]);
const neg_target = torch.ones([16]).mul(-1); // Dissimilar (y=-1)
// Concatenate into batch
const all_emb1 = torch.cat([pos_emb1, neg_emb1], 0);
const all_emb2 = torch.cat([pos_emb2, neg_emb2], 0);
const all_target = torch.cat([pos_target, neg_target], 0);
const cosine_loss = new torch.nn.CosineEmbeddingLoss(0.0);
const loss = cosine_loss.forward(all_emb1, all_emb2, all_target);
// Simultaneously pulls similar pairs together and pushes dissimilar apart

// Face recognition with margin
const face_extractor = new FaceFeatureExtractor(); // Outputs embeddings
const cosine_loss = new torch.nn.CosineEmbeddingLoss(0.25); // Margin = 0.25
const face1 = face_extractor.forward(image1); // [1, 512] embedding
const face2 = face_extractor.forward(image2); // [1, 512] embedding
const is_same = torch.tensor([1]); // 1 if same person, -1 if different
const loss = cosine_loss.forward(face1, face2, is_same);
// For same-person pairs, pushes similarity toward 1; for different people, pushes it below the margin

// Person re-identification: matching gallery to query
const query_embedding = torch.randn([256]);
const gallery_embeddings = torch.randn([1000, 256]);
// Expand query to match batch
const query_batch = query_embedding.unsqueeze(0).expand([1000, 256]);
// Assume first 100 gallery samples are same person (target=1), others different (target=-1)
const target = torch.cat([torch.ones([100]), torch.ones([900]).mul(-1)]);
const cosine_loss = new torch.nn.CosineEmbeddingLoss(0.1);
const loss = cosine_loss.forward(query_batch, gallery_embeddings, target);