torch.nn.functional.triplet_margin_with_distance_loss
function triplet_margin_with_distance_loss(anchor: Tensor, positive: Tensor, negative: Tensor, options?: {
distance_function?: (x1: Tensor, x2: Tensor) => Tensor;
margin?: number;
swap?: boolean;
reduction?: 'none' | 'mean' | 'sum';
}): Tensor

Triplet margin loss with custom distance function for flexible metric learning.
Extends triplet margin loss by allowing custom distance metrics instead of fixed Lp norms. Measures relative distances within (anchor, positive, negative) triplets using any differentiable distance function, pulling the positive close to the anchor while pushing the negative away by at least the margin. Essential for:
- Custom distance metrics (cosine, angular, Mahalanobis, learned metrics)
- Siamese/Triplet networks with specialized distance measures
- Contrastive learning with task-specific similarities
- Deep metric learning (metric can be learned end-to-end)
- Person re-identification with application-specific metrics
- Image retrieval with domain-specific similarity measures
- One-shot/few-shot learning with custom comparison functions
How custom distance triplet loss works: instead of a fixed Lp distance (L2, L1), a custom distance function d(x, y) defines the metric:
Loss = max(0, d(anchor, positive) - d(anchor, negative) + margin)
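The formula can be sketched on plain number arrays, independent of any tensor library; `euclidean` here is just one illustrative choice of distance_function:

```typescript
// Illustrative distance function: Euclidean (L2) distance between vectors
function euclidean(x: number[], y: number[]): number {
  return Math.sqrt(x.reduce((s, xi, i) => s + (xi - y[i]) ** 2, 0));
}

// Per-sample triplet loss: max(0, d(a, p) - d(a, n) + margin)
function tripletLoss(
  a: number[], p: number[], n: number[],
  d: (x: number[], y: number[]) => number,
  margin = 1.0
): number {
  return Math.max(0, d(a, p) - d(a, n) + margin);
}

// Positive close, negative far: loss is zero once the margin is satisfied
tripletLoss([0, 0], [0.1, 0], [3, 0], euclidean, 1.0); // → 0
// Negative too close: positive loss drives the embeddings apart
tripletLoss([0, 0], [1, 0], [1.5, 0], euclidean, 1.0); // → 0.5
```

The real function applies this per batch row and then aggregates according to reduction.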
Key advantages over standard triplet_margin_loss:
- Flexibility: Use any differentiable distance function (cosine, learned metric, etc.)
- Swap parameter: Optional harder negative mining by using d(positive, negative)
- Custom metrics: Mahalanobis, angular distance, or learned similarity
- Metric learning: Distance function can have learnable parameters (e.g., embedding transform)
Swap parameter explanation: When swap=true, d(anchor, negative) is compared with d(positive, negative) and the minimum is used as the negative distance. This provides harder negative mining: if the negative is closer to the positive than to the anchor, the tighter constraint applies, so a triplet is not treated as solved merely because the negative sits far from the anchor while still crowding the positive.
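The swap rule reduces to a one-line choice between scalar distances (an illustrative sketch, not the library's actual code):

```typescript
// With swap, whichever constraint is harder (the smaller distance) is used
// as the negative distance in the loss.
function negativeDistance(dAN: number, dPN: number, swap: boolean): number {
  return swap ? Math.min(dAN, dPN) : dAN;
}

// Negative far from the anchor but close to the positive:
negativeDistance(4.0, 1.2, false); // → 4.0 (triplet looks easy; loss likely 0)
negativeDistance(4.0, 1.2, true);  // → 1.2 (harder constraint still yields gradient)
```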
- Flexibility vs complexity: Custom metrics provide flexibility but require differentiability
- Distance properties: Ensure distance_function returns non-negative values for optimization stability
- Gradient flow: Custom distance function must be differentiable for backpropagation
- Swap for hard negatives: swap=true helps when negatives might be closer to positive than anchor
- Metric optimization: If distance function has learnable parameters, they're optimized end-to-end
- Default L2 distance: When distance_function not provided, defaults to Euclidean distance
- Computational cost: Custom distances may be slower than built-in Lp norms (consider batching)
- Numerical stability: Clamp norm computations to avoid division by zero for custom distances
- Distance must be non-negative: returning negative values from distance_function can break the loss
- Differentiability required: distance_function must be differentiable for gradient computation
- Margin tuning: Different distance metrics have different scales (normalize or tune margin)
- Swap parameter: swap=true adds computational cost (computes extra distance d(p,n))
- Learned metrics: If distance function has parameters, ensure they're being optimized
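To illustrate the margin-tuning point above: cosine distance is bounded in [0, 2], so margin=1.0 spans half its range, while Euclidean distance is unbounded. A plain-array sketch of the bounds (no tensor library):

```typescript
// Cosine distance = 1 - cosine similarity; bounded in [0, 2].
function cosineDistance(x: number[], y: number[]): number {
  const dot = x.reduce((s, xi, i) => s + xi * y[i], 0);
  const nx = Math.sqrt(x.reduce((s, xi) => s + xi * xi, 0));
  const ny = Math.sqrt(y.reduce((s, yi) => s + yi * yi, 0));
  // Clamp the denominator to avoid division by zero for zero vectors
  return 1 - dot / Math.max(nx * ny, 1e-8);
}

cosineDistance([1, 0], [1, 0]);  // identical vectors → 0 (minimum)
cosineDistance([1, 0], [0, 1]);  // orthogonal vectors → 1
cosineDistance([1, 0], [-1, 0]); // opposite vectors → 2 (maximum)
```

A margin that works well for an unbounded L2 metric will usually need rescaling when switching to a bounded metric like this one.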
Parameters
anchor: Tensor - Anchor embedding tensor of shape [batch, embedding_dim]. Example: query image embeddings [batch, 128] or reference sample [batch, feature_dim]
positive: Tensor - Positive embedding tensor of shape [batch, embedding_dim] (similar to anchor). Example: second image of same object [batch, 128] or matching sample
negative: Tensor - Negative embedding tensor of shape [batch, embedding_dim] (dissimilar to anchor). Example: image of different object [batch, 128] or non-matching sample
options: { distance_function?: (x1: Tensor, x2: Tensor) => Tensor; margin?: number; swap?: boolean; reduction?: 'none' | 'mean' | 'sum'; } (optional) - Optional configuration:
- distance_function: Custom distance function (x1, x2) => distance_tensor (default: L2). Example: (a, p) => a.sub(p).square().sum(-1).sqrt() for Euclidean distance
- margin: Margin between positive and negative distances (default: 1.0)
  - margin=1.0: negative must be at least 1.0 farther than positive
  - Higher margin: aggressive separation; lower margin: relaxed constraint
- swap: Enable harder negative mining using d(positive, negative) (default: false)
  - true: d_an = min(d(a,n), d(p,n)) - uses harder constraint if negative is close to positive
  - false: d_an = d(a,n) - standard formulation
- reduction: How to aggregate batch losses (default: 'mean')
  - 'none': per-sample losses [batch]
  - 'mean': average loss across batch
  - 'sum': sum losses across batch
Returns
Tensor - Loss tensor (scalar if reduction='mean' or 'sum', or [batch] if reduction='none')
Examples
// Custom cosine distance metric
const cosine_distance = (x1: Tensor, x2: Tensor) => {
const dot = x1.mul(x2).sum(-1);
const norm1 = x1.square().sum(-1).sqrt();
const norm2 = x2.square().sum(-1).sqrt();
const cos_sim = dot.div(norm1.mul(norm2).clamp(1e-8, Infinity));
return cos_sim.mul(-1).add(1); // convert similarity to distance
};
const anchor_embed = model(anchor_img); // [batch, 128]
const positive_embed = model(positive_img);
const negative_embed = model(negative_img);
const loss = torch.nn.functional.triplet_margin_with_distance_loss(
anchor_embed, positive_embed, negative_embed,
{ distance_function: cosine_distance, margin: 0.5 }
);

// L1 (Manhattan) distance for robustness to outliers
const l1_distance = (x1: Tensor, x2: Tensor) => {
return x1.sub(x2).abs().sum(-1); // sum of absolute differences
};
const loss = torch.nn.functional.triplet_margin_with_distance_loss(
anchor, positive, negative,
{ distance_function: l1_distance, margin: 0.8 }
);

// With swap enabled for harder negative mining
const anchor = torch.randn([32, 256]);
const positive = torch.randn([32, 256]);
const negative = torch.randn([32, 256]);
// Default L2 distance with swap
const loss = torch.nn.functional.triplet_margin_with_distance_loss(
anchor, positive, negative,
{ margin: 1.0, swap: true }
);
// If d(p, n) < d(a, n), uses d(p, n) as negative distance (harder constraint)

// Learned metric: distance through neural network layer
class LearnedMetric extends torch.nn.Module {
distance_net: torch.nn.Sequential;
constructor() {
super();
// Network that outputs distance between pairs
this.distance_net = new torch.nn.Sequential(
new torch.nn.Linear(256, 128),
new torch.nn.ReLU(),
new torch.nn.Linear(128, 1)
);
}
forward(x1: Tensor, x2: Tensor): Tensor {
const diff = x1.sub(x2);
const dist = this.distance_net(diff);
return dist.abs(); // ensure non-negative distance
}
}
const metric = new LearnedMetric();
const loss = torch.nn.functional.triplet_margin_with_distance_loss(
anchor, positive, negative,
{ distance_function: (a, p) => metric.forward(a, p), margin: 0.5 }
);

// Mahalanobis distance with learned covariance
const mahal_distance = (x1: Tensor, x2: Tensor, precision: Tensor) => {
const diff = x1.sub(x2); // [batch, dim]
const mahal = diff.matmul(precision).mul(diff).sum(-1).sqrt();
return mahal;
};
// Precision matrix (learned inverse covariance)
const precision = torch.eye(128); // [128, 128], for embeddings of dim 128
const loss = torch.nn.functional.triplet_margin_with_distance_loss(
anchor, positive, negative,
{
distance_function: (a, p) => mahal_distance(a, p, precision),
margin: 0.5
}
);

See Also
- PyTorch torch.nn.functional.triplet_margin_with_distance_loss
- triplet_margin_loss - Standard triplet loss with fixed Lp norms
- cosine_embedding_loss - Similar concept but uses cosine similarity for pairs
- torch.nn.TripletMarginLoss - Module wrapper for standard triplet loss