torch.nn.TripletMarginLoss
class TripletMarginLoss extends Module

new TripletMarginLoss(options?: {
  margin?: number;
  p?: number;
  eps?: number;
  swap?: boolean;
  reduction?: Reduction;
})

- readonly margin (number)
- readonly p (number)
- readonly eps (number)
- readonly swap (boolean)
- readonly reduction (Reduction)
Triplet Margin Loss: a metric learning loss for learning embeddings from relative distances.
Measures the distance relationships among three samples: an anchor, a positive (same class as the anchor), and a negative (different class). The loss pushes the anchor-positive distance below the anchor-negative distance by at least the margin. A standard loss for:
- Learning visual embeddings for face recognition
- Person re-identification
- Deep metric learning
- Ranking and retrieval problems
- Any task requiring relative similarity learning
When to use TripletMarginLoss:
- Face recognition, person re-ID, similarity learning
- Learning embeddings where relative distance is important
- Siamese/triplet network architectures
- Hard negative mining for metric learning
- Learning discriminative embeddings
- When you have triplets of (anchor, positive, negative)
Trade-offs:
- vs CosineEmbeddingLoss: triplet loss compares three samples via relative distances; cosine embedding loss is pair-wise (2 samples)
- vs MSELoss: Uses relative distances, better for metric learning
- Hard mining: Can focus on hard triplets to improve convergence
- Margin tuning: Critical hyperparameter; typical 0.5-1.0
- Distance metric: Supports different p-norms (p=2 is Euclidean)
Algorithm: For each triplet (anchor, positive, negative):
- d_pos = distance(anchor, positive)
- d_neg = distance(anchor, negative)
- loss = max(0, d_pos - d_neg + margin)
The loss is zero when the negative is farther from the anchor than the positive by at least margin. Distances use the Lp norm (default p = 2, i.e. Euclidean).
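The per-triplet formula above can be sketched in plain TypeScript, independent of the library (`euclidean` and `tripletLoss` are hypothetical helper names introduced here for illustration, not part of the API):

```typescript
// Euclidean (p = 2) distance between two embedding vectors.
function euclidean(a: number[], b: number[]): number {
  let sum = 0;
  for (let i = 0; i < a.length; i++) {
    const d = a[i] - b[i];
    sum += d * d;
  }
  return Math.sqrt(sum);
}

// Per-triplet loss: max(0, d_pos - d_neg + margin).
function tripletLoss(
  anchor: number[],
  positive: number[],
  negative: number[],
  margin = 1.0,
): number {
  const dPos = euclidean(anchor, positive);
  const dNeg = euclidean(anchor, negative);
  return Math.max(0, dPos - dNeg + margin);
}

// Negative already farther than positive by more than the margin -> loss 0.
const easy = tripletLoss([0, 0], [1, 0], [5, 0], 1.0);
// Negative too close to the anchor -> positive loss (1 - 1.5 + 1 = 0.5).
const hard = tripletLoss([0, 0], [1, 0], [1.5, 0], 1.0);
```

Note that once a triplet satisfies the margin, its loss (and gradient) is exactly zero, which is why mining informative triplets matters in practice.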
- Metric learning standard: The standard loss for learning embeddings
- Three samples required: Needs anchor, positive, and negative (not pair-wise)
- Relative distances: Optimizes relative distances, not absolute values
- Margin tuning critical: Performance depends heavily on margin choice
- Hard negative mining: Can improve by selecting hard negatives during training
- Distance metric: Use p=2 (Euclidean) for most tasks
- Embedding space: Directly optimizes the learned embedding space
- Convergence: Can converge faster with hard triplet mining
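The hard negative mining mentioned above can be sketched without the library: for each anchor, select the candidate negative closest to it. This is a minimal sketch under that definition of "hardest"; `hardestNegative` is a hypothetical helper, and practical batch-level strategies often also filter semi-hard negatives relative to the margin:

```typescript
// Euclidean (p = 2) distance between two embedding vectors.
function euclidean(a: number[], b: number[]): number {
  let sum = 0;
  for (let i = 0; i < a.length; i++) {
    const d = a[i] - b[i];
    sum += d * d;
  }
  return Math.sqrt(sum);
}

// The "hardest" negative is the candidate closest to the anchor:
// it yields the largest triplet loss and the strongest gradient signal.
function hardestNegative(anchor: number[], candidates: number[][]): number[] {
  let best = candidates[0];
  let bestDist = euclidean(anchor, best);
  for (const c of candidates.slice(1)) {
    const d = euclidean(anchor, c);
    if (d < bestDist) {
      bestDist = d;
      best = c;
    }
  }
  return best;
}

const anchorVec = [0, 0];
const candidates = [[4, 0], [1, 1], [9, 9]];
const hardest = hardestNegative(anchorVec, candidates); // [1, 1] is closest
```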
Examples
// Basic triplet loss for metric learning
const triplet_loss = new torch.nn.TripletMarginLoss({ margin: 1.0 });
// Three embeddings per sample
const anchor = torch.randn([32, 128]); // Batch of 32 embeddings
const positive = torch.randn([32, 128]); // From same class
const negative = torch.randn([32, 128]); // From different class
const loss = triplet_loss.forward(anchor, positive, negative);
// Loss encourages: dist(anchor, positive) < dist(anchor, negative) - margin

// Face recognition with triplet loss
class FaceEmbedder extends torch.nn.Module {
  fc1: torch.nn.Linear;
  fc2: torch.nn.Linear;

  constructor() {
    super();
    this.fc1 = new torch.nn.Linear(3 * 224 * 224, 512);
    this.fc2 = new torch.nn.Linear(512, 128); // Outputs 128-dim embeddings
  }

  forward(x: torch.Tensor): torch.Tensor {
    // Flatten [N, 3, 224, 224] images to [N, 3*224*224] before the linear layers
    const h = torch.nn.functional.relu(this.fc1.forward(x.reshape([-1, 3 * 224 * 224])));
    return this.fc2.forward(h);
  }
}
const model = new FaceEmbedder();
const triplet_loss = new torch.nn.TripletMarginLoss({ margin: 0.5, p: 2 });
// Create triplets: anchor is person A, positive is another photo of A, negative is person B
const anchor_img = torch.randn([32, 3, 224, 224]);
const positive_img = torch.randn([32, 3, 224, 224]);
const negative_img = torch.randn([32, 3, 224, 224]);
const anchor_emb = model.forward(anchor_img);
const positive_emb = model.forward(positive_img);
const negative_emb = model.forward(negative_img);
const loss = triplet_loss.forward(anchor_emb, positive_emb, negative_emb);
// Network learns to embed faces so the same person is close, different people are far

// Metric learning with different distance metrics
const embeddings_anchor = torch.randn([64, 256]);
const embeddings_positive = torch.randn([64, 256]);
const embeddings_negative = torch.randn([64, 256]);
// Euclidean distance (default)
const euclidean = new torch.nn.TripletMarginLoss({ margin: 1.0, p: 2 });
const loss_euclidean = euclidean.forward(embeddings_anchor, embeddings_positive, embeddings_negative);
// Manhattan distance
const manhattan = new torch.nn.TripletMarginLoss({ margin: 1.0, p: 1 });
const loss_manhattan = manhattan.forward(embeddings_anchor, embeddings_positive, embeddings_negative);

// Hard negative mining (advanced)
// Collect multiple negatives and select hardest ones
const anchor = torch.randn([32, 128]);
const positive = torch.randn([32, 128]);
const all_negatives = torch.randn([32, 100, 128]); // 100 candidates
// Find hardest negative for each anchor (closest to anchor)
const triplet_loss = new torch.nn.TripletMarginLoss({ margin: 0.5 });
// Batch processing with hard negatives
let total_loss: torch.Tensor | null = null;
for (let i = 0; i < 32; i++) {
  // Could add logic to select the hardest negative here
  const negative = all_negatives[i][0]; // Simplified: just take the first candidate
  const batch_anchor = anchor[i].unsqueeze(0);
  const batch_positive = positive[i].unsqueeze(0);
  const batch_negative = negative.unsqueeze(0);
  const loss = triplet_loss.forward(batch_anchor, batch_positive, batch_negative);
  // Accumulate as a tensor so the sum stays in the autograd graph
  total_loss = total_loss === null ? loss : total_loss.add(loss);
}

// Tuning margin for different problems
const anchor = torch.randn([32, 128]);
const positive = torch.randn([32, 128]);
const negative = torch.randn([32, 128]);
// Small margin: easier training, less separation
const easy = new torch.nn.TripletMarginLoss({ margin: 0.3 });
const loss_easy = easy.forward(anchor, positive, negative);
// Medium margin: balanced (default)
const balanced = new torch.nn.TripletMarginLoss({ margin: 1.0 });
const loss_balanced = balanced.forward(anchor, positive, negative);
// Large margin: harder training, strong separation
const hard = new torch.nn.TripletMarginLoss({ margin: 2.0 });
const loss_hard = hard.forward(anchor, positive, negative);
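To see how the margin alone changes the loss, the scalar formula can be evaluated for fixed distances (plain TypeScript, independent of the library; the distance values are illustrative):

```typescript
// loss = max(0, d_pos - d_neg + margin) for fixed distances
const dPos = 1.0; // anchor-positive distance
const dNeg = 1.5; // anchor-negative distance
const losses = [0.3, 1.0, 2.0].map(
  (margin) => Math.max(0, dPos - dNeg + margin),
);
// margin 0.3 is already satisfied (loss 0); larger margins demand more separation
```

The same pair of distances that satisfies a small margin still incurs loss under a large one, which is why margin tuning interacts directly with how tightly classes cluster in the embedding space.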