torch.nn.functional.triplet_margin_loss
function triplet_margin_loss(anchor: Tensor, positive: Tensor, negative: Tensor): Tensor
function triplet_margin_loss(anchor: Tensor, positive: Tensor, negative: Tensor, margin: number, p: number, eps: number, swap: boolean, size_average: boolean | null, reduce: boolean | null, reduction: 'none' | 'mean' | 'sum', options: TripletMarginLossFunctionalOptions): Tensor
Triplet Margin Loss: learns embeddings by enforcing relative distances within triplets.
Measures distances in a triplet: (anchor, positive, negative). Pulls positive close to anchor while pushing negative far from anchor, with explicit margin separation. Essential for:
- Siamese/Triplet networks (standard architecture for metric learning)
- Face recognition (FaceNet, VGGFace) - foundational loss for face embeddings
- Person re-identification (ReID) - cross-camera pedestrian matching
- Image retrieval and ranking (pull relevant images close, irrelevant far)
- One-shot/few-shot learning (learn from minimal labeled examples)
- Metric learning and distance-based classification
- Deep metric learning (modern application to large-scale retrieval)
Triplet loss intuition: Loss = max(0, d(anchor, positive) - d(anchor, negative) + margin). The loss optimizes relative distances: the positive should be closer to the anchor than the negative by at least margin. Unlike classification losses (absolute targets), triplet loss focuses on relative ranking.
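The hinge formula can be checked with plain numbers, no tensor library needed (`l2Distance` and `tripletLoss` here are illustrative helpers, not part of the API):

```typescript
// Euclidean (L2) distance between two embedding vectors
function l2Distance(a: number[], b: number[]): number {
  return Math.sqrt(a.reduce((s, ai, i) => s + (ai - b[i]) ** 2, 0));
}

// Loss = max(0, d(anchor, positive) - d(anchor, negative) + margin)
function tripletLoss(anchor: number[], positive: number[], negative: number[], margin: number): number {
  const dPos = l2Distance(anchor, positive);
  const dNeg = l2Distance(anchor, negative);
  // Hinge: zero loss once the negative is farther than the positive by at least margin
  return Math.max(0, dPos - dNeg + margin);
}

const anchor = [0, 0];
const positive = [1, 0];        // d(a, p) = 1
const negativeFar = [5, 0];     // d(a, n) = 5 → max(0, 1 - 5 + 1) = 0 (margin satisfied)
const negativeNear = [1.5, 0];  // d(a, n) = 1.5 → max(0, 1 - 1.5 + 1) = 0.5

console.log(tripletLoss(anchor, positive, negativeFar, 1.0));  // 0
console.log(tripletLoss(anchor, positive, negativeNear, 1.0)); // 0.5
```

Note that the easy negative contributes zero loss (and zero gradient), which is exactly why hard negative mining matters in practice.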
Key properties:
- Relative ranking: Only relative distances matter, not absolute values
- Margin parameter: Explicit safety gap between positive and negative
- Hard negatives: Gradients are dominated by difficult negatives (hard mining is crucial)
- Metric learning: Directly optimizes Lp distances (L2 standard)
FaceNet, the landmark paper: modern face recognition builds on triplet loss + CNN embeddings. Training relies on hard negative mining. Embeddings of the same person cluster together; embeddings of different people separate.
- Hard negative mining critical: Random negatives often easy; mining hard ones crucial
- Batch construction: Sampling strategy (batch hard mining) significantly impacts performance
- Margin parameter: Tune based on embedding scale; typically 0.5-2.0 for L2
- Embedding normalization: Often normalize embeddings to unit sphere (cosine distance)
- Relative ranking: Only relative ordering matters; absolute distances less important
- Convergence: Triplet loss converges more slowly than classification losses but is better suited to metric learning
- Data efficiency: Triplet/siamese networks learn from limited labeled data
- Sampling strategy matters: Random sampling often yields too easy negatives
- Margin tuning: Too small → underconstrained; too large → may be unsatisfiable
- Batch size: Small batches limit mining diversity; 32+ recommended
- In-batch negatives: when no dedicated negative is available, other samples in the batch can serve as negatives
- Non-differentiable max: max(0, ...) has zero gradient when margin satisfied
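The batch-hard mining strategy mentioned above can be sketched over plain arrays; `findHardestPositive` and `findHardestNegative` are illustrative helpers (matching the names used in the batch-mining example below), not library functions:

```typescript
// Euclidean distance between two embedding vectors
function l2(a: number[], b: number[]): number {
  return Math.sqrt(a.reduce((s, ai, i) => s + (ai - b[i]) ** 2, 0));
}

// Hardest positive: same label as the anchor, largest distance from it
function findHardestPositive(emb: number[][], labels: number[], i: number): number {
  let best = -1, bestDist = -Infinity;
  for (let j = 0; j < emb.length; j++) {
    if (j !== i && labels[j] === labels[i] && l2(emb[i], emb[j]) > bestDist) {
      bestDist = l2(emb[i], emb[j]);
      best = j;
    }
  }
  return best;
}

// Hardest negative: different label from the anchor, smallest distance from it
function findHardestNegative(emb: number[][], labels: number[], i: number): number {
  let best = -1, bestDist = Infinity;
  for (let j = 0; j < emb.length; j++) {
    if (labels[j] !== labels[i] && l2(emb[i], emb[j]) < bestDist) {
      bestDist = l2(emb[i], emb[j]);
      best = j;
    }
  }
  return best;
}

const emb = [[0, 0], [1, 0], [3, 0], [0.5, 0]];
const labels = [0, 0, 0, 1];
console.log(findHardestPositive(emb, labels, 0)); // 2: same class, farthest away
console.log(findHardestNegative(emb, labels, 0)); // 3: different class, closest
```

This O(batch²) scan is fine for typical batch sizes; production implementations usually compute the full pairwise distance matrix once and mask by label.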
Parameters
anchor: Tensor - Anchor embedding tensor of shape [batch, embedding_dim]. Example: first face image embeddings [batch, 128]
positive: Tensor - Positive embedding tensor of shape [batch, embedding_dim], similar to anchor. Example: second image of the same person [batch, 128]
negative: Tensor - Negative embedding tensor of shape [batch, embedding_dim], dissimilar to anchor. Example: image of a different person [batch, 128]
margin: number - Minimum desired gap between the anchor-positive and anchor-negative distances (default: 1.0)
p: number - Norm degree for the pairwise distance (default: 2, Euclidean)
eps: number - Small constant for numerical stability (default: 1e-6)
swap: boolean - If true, also computes the positive-negative distance and uses it when smaller than the anchor-negative distance (default: false)
reduction: 'none' | 'mean' | 'sum' - How to reduce the per-sample losses (default: 'mean')
Returns
Tensor - Loss tensor (scalar if reduction='mean', or [batch] if reduction='none')
Examples
// FaceNet-style face recognition
const anchor_embed = model(person_a_img); // [batch, 128] - reference image
const positive_embed = model(person_a_img2); // [batch, 128] - another image same person
const negative_embed = model(person_b_img); // [batch, 128] - different person
const loss = torch.nn.functional.triplet_margin_loss(
  anchor_embed, positive_embed, negative_embed,
  1.0,  // margin
  2     // p
);
// positive_dist + margin ≤ negative_dist → loss = 0

// Person re-identification with hard negative mining
const anchor = encoder(person_query); // [batch, 2048]
const positive = encoder(same_person_gallery); // [batch, 2048]
// Hard negatives: most difficult examples (closest imposters)
// Simple strategy: random different person (in practice: mine hardest negatives)
const negative = encoder(different_person_hard); // [batch, 2048]
const reid_loss = torch.nn.functional.triplet_margin_loss(
  anchor, positive, negative,
  0.5  // margin
);

// One-shot learning: siamese network with triplet loss
const query_embedding = siamese_net(query_image); // [1, 256]
const support_positive = siamese_net(support_same_class); // [1, 256]
const support_negative = siamese_net(support_diff_class); // [1, 256]
const loss = torch.nn.functional.triplet_margin_loss(
  query_embedding.expand([batch, 256]),
  support_positive.expand([batch, 256]),
  support_negative.expand([batch, 256]),
  0.5  // margin
);

// Image retrieval: learn embeddings for semantic search
const anchor_img = model(query_image); // [batch, 512]
const pos_img = model(relevant_image); // [batch, 512]
const neg_img = model(irrelevant_image); // [batch, 512]
const retrieval_loss = torch.nn.functional.triplet_margin_loss(
  anchor_img, pos_img, neg_img,
  1.0,  // margin
  2     // p
);
// Embeddings clustered: relevant images close, irrelevant far

// Batch hard triplet mining for training efficiency
const embeddings = model(batch_images); // [batch_size, 128]
// Construct triplets within batch (online hard mining)
let total_loss = 0;
for (let i = 0; i < batch_size; i++) {
  // Find hardest positive and negative in batch
  const pos_idx = findHardestPositive(embeddings, labels, i);
  const neg_idx = findHardestNegative(embeddings, labels, i);
  const triplet_loss = torch.nn.functional.triplet_margin_loss(
    embeddings.slice([i, i + 1], 0),
    embeddings.slice([pos_idx, pos_idx + 1], 0),
    embeddings.slice([neg_idx, neg_idx + 1], 0),
    0.5  // margin
  );
  total_loss += triplet_loss;
}

// Different Lp norms comparison
const anchor = torch.randn([32, 256]);
const pos = torch.randn([32, 256]);
const neg = torch.randn([32, 256]);
// L2 norm (Euclidean): standard for embeddings
const loss_l2 = torch.nn.functional.triplet_margin_loss(anchor, pos, neg, 1.0, 2);
// L1 norm (Manhattan): more robust to outliers
const loss_l1 = torch.nn.functional.triplet_margin_loss(anchor, pos, neg, 1.0, 1);

See Also
- PyTorch torch.nn.functional.triplet_margin_loss
- cosine_embedding_loss - Similar but uses cosine similarity (angle-based)
- torch.nn.TripletMarginLoss - Module (class) wrapper around this function
- contrastive_loss - Pairwise alternative (simpler but often less effective)