torch.nn.TripletMarginWithDistanceLoss
class TripletMarginWithDistanceLoss extends Module

new TripletMarginWithDistanceLoss(options?: {
distance_function?: (x1: Tensor, x2: Tensor) => Tensor;
margin?: number;
swap?: boolean;
reduction?: Reduction;
})
readonly margin (number)
readonly swap (boolean)
readonly reduction (Reduction)
Triplet Margin With Distance Loss: triplet loss with a custom distance function.
A flexible variant of TripletMarginLoss that accepts any differentiable distance metric in place of the fixed Lp norm: cosine distance, a learned metric network, or a domain-specific measure.
When to use TripletMarginWithDistanceLoss:
- Need custom distance metric (not just Lp norms)
- Using learned distance functions
- Domain-specific distance measures
- Cosine distance instead of Euclidean
- Advanced metric learning scenarios
Trade-offs:
- vs TripletMarginLoss: More flexible but requires custom distance function
- Flexibility: Can use any differentiable distance metric
- Complexity: More setup required (custom function)
- Performance: May be slower with complex distance functions
Algorithm: For each triplet (anchor, positive, negative):
- d_pos = distance_function(anchor, positive)
- d_neg = distance_function(anchor, negative)
- loss = max(0, d_pos - d_neg + margin)
If swap is true, d_neg is replaced by min(d_neg, distance_function(positive, negative)) before the hinge. Otherwise identical to TripletMarginLoss, but with the custom distance function.
- Custom distance: Allows any differentiable distance metric
- Flexibility: More powerful than fixed Lp norm triplet loss
- Default: Uses Euclidean L2 if no distance function provided
- Backprop: Distance function must be differentiable
- Advanced: Typically for research/specialized applications
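The per-triplet math above can be sketched dependency-free in plain TypeScript. This is an illustration of the formula, not the library implementation; the `euclidean` helper and `tripletLoss` function are assumptions standing in for the default distance and the reduction-free loss.

```typescript
// Illustrative Euclidean (L2) distance, standing in for the default.
function euclidean(a: number[], b: number[]): number {
  return Math.sqrt(a.reduce((s, v, i) => s + (v - b[i]) ** 2, 0));
}

// Per-triplet hinge: max(0, d_pos - d_neg + margin).
function tripletLoss(
  anchor: number[],
  positive: number[],
  negative: number[],
  distance: (a: number[], b: number[]) => number,
  margin = 1.0,
  swap = false,
): number {
  const dPos = distance(anchor, positive);
  let dNeg = distance(anchor, negative);
  if (swap) {
    // Distance swap: take the smaller of d(anchor, negative)
    // and d(positive, negative), i.e. the harder negative distance.
    dNeg = Math.min(dNeg, distance(positive, negative));
  }
  return Math.max(0, dPos - dNeg + margin);
}

// Positive close to anchor, negative far away: hinge clamps to zero.
console.log(tripletLoss([1, 0], [0.9, 0.1], [-3, 0], euclidean, 1.0)); // → 0
// Hard negative near the anchor: positive loss drives separation.
console.log(tripletLoss([1, 0], [0.9, 0.1], [1, 0.2], euclidean, 1.0));
```

Passing any other `(a, b) => number` function in place of `euclidean` mirrors what `distance_function` does for tensors.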
Examples
// Cosine distance instead of Euclidean
const cosine_distance = (x1: torch.Tensor, x2: torch.Tensor): torch.Tensor => {
// Cosine distance = 1 - cosine_similarity
return torch.ones([]).sub(
torch.nn.functional.cosine_similarity(x1, x2)
);
};
const triplet = new torch.nn.TripletMarginWithDistanceLoss({
distance_function: cosine_distance,
margin: 0.5
});
const anchor = torch.randn([32, 128]);
const positive = torch.randn([32, 128]);
const negative = torch.randn([32, 128]);
const loss = triplet.forward(anchor, positive, negative);
// Uses cosine distance instead of Euclidean

// Learned distance metric (using another network)
class LearnedMetric extends torch.nn.Module {
fc1: torch.nn.Linear;
fc2: torch.nn.Linear;
constructor() {
super();
this.fc1 = new torch.nn.Linear(256, 64);
this.fc2 = new torch.nn.Linear(64, 1);
}
forward(x1: torch.Tensor, x2: torch.Tensor): torch.Tensor {
// Concatenate embeddings and learn distance
const combined = torch.cat([x1, x2], 1);
let h = torch.nn.functional.relu(this.fc1.forward(combined));
return torch.sigmoid(this.fc2.forward(h)); // Distance in [0, 1]
}
}
const metric_net = new LearnedMetric();
const triplet = new torch.nn.TripletMarginWithDistanceLoss({
distance_function: (x1, x2) => metric_net.forward(x1, x2),
margin: 0.3
});
const anchor = torch.randn([32, 128]);
const positive = torch.randn([32, 128]);
const negative = torch.randn([32, 128]);
const loss = triplet.forward(anchor, positive, negative);
// Learns both embeddings AND distance metric jointly
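As a numeric sanity check of the cosine-distance formulation used in the first example, here is a dependency-free sketch on plain vectors; the `dot` and `cosineDistance` helpers are illustrative, not library APIs.

```typescript
function dot(a: number[], b: number[]): number {
  return a.reduce((s, v, i) => s + v * b[i], 0);
}

// cosineDistance = 1 - cosineSimilarity; bounded to [0, 2].
function cosineDistance(a: number[], b: number[]): number {
  const denom = Math.sqrt(dot(a, a)) * Math.sqrt(dot(b, b));
  return 1 - dot(a, b) / denom;
}

console.log(cosineDistance([1, 0], [2, 0]));  // → 0 (same direction)
console.log(cosineDistance([1, 0], [0, 1]));  // → 1 (orthogonal)
console.log(cosineDistance([1, 0], [-1, 0])); // → 2 (opposite)
```

Because cosine distance ignores vector magnitude and is bounded, smaller margins (like the 0.5 in the first example) are sensible, whereas Euclidean distance on unnormalized embeddings is unbounded.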