torch.nn.HingeEmbeddingLoss
class HingeEmbeddingLoss extends Module
new HingeEmbeddingLoss(options?: HingeEmbeddingLossOptions)
margin (number) - readonly
reduction (Reduction) - readonly
Hinge Embedding Loss: metric learning loss for pairs with binary labels (similar/dissimilar).
Computes a margin-based hinge loss on an input x with binary labels (1 = similar, -1 = dissimilar). The input is typically a distance, such as the embedding's norm or a pairwise distance between two embeddings. When target = 1, the loss is x itself, penalizing large distances (pull closer). When target = -1, the loss is max(0, margin - x), penalizing distances below the margin (push farther). Used in Siamese networks and metric learning for similarity judgments.
When to use HingeEmbeddingLoss:
- Siamese networks (learning embeddings of pairs)
- Binary similarity/dissimilarity judgment
- Learning distance metrics
- When you want a simple hinge-based metric loss
- Comparing single embeddings against a reference
Trade-offs:
- vs TripletMarginLoss: Hinge works on single embeddings; triplet needs 3 samples
- vs CosineEmbeddingLoss: Different geometry (hinge vs cosine similarity)
- vs ContrastiveLoss: Same pull/push idea; contrastive squares the distances, hinge uses them directly
- Simplicity: Simpler than triplet but less expressive
Algorithm: For each input value x with target y ∈ {1, -1}:
- If y == 1: loss = x (penalizes large distances: pull similar samples closer)
- If y == -1: loss = max(0, margin - x) (penalizes distances below the margin: push dissimilar samples apart)
The input x is a distance, typically the embedding's norm (distance from origin) or a precomputed pairwise distance.
- Distance-based: Operates on a distance input (e.g. the embedding norm or a pairwise distance)
- Binary labels: Requires target to be 1 (similar) or -1 (dissimilar)
- Margin role: Boundary between similar and dissimilar embeddings
- Simple metric: Easier to tune than triplet but less expressive
- Origin-centric: With the norm as input, embeddings are compared to the origin, not to other samples
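The per-element rule above can be sketched as a standalone TypeScript helper. This is a minimal reference sketch, not the library's actual implementation: `hingeEmbeddingLoss` is a hypothetical name, and mean reduction is assumed.

```typescript
// Reference sketch of the hinge embedding loss rule (hypothetical helper,
// mean reduction). Inputs are per-element values x (typically distances)
// and targets y in {1, -1}.
function hingeEmbeddingLoss(
  input: number[],
  target: number[],
  margin: number = 1.0
): number {
  let total = 0;
  for (let i = 0; i < input.length; i++) {
    if (target[i] === 1) {
      total += input[i];                       // y = 1: loss = x
    } else {
      total += Math.max(0, margin - input[i]); // y = -1: loss = max(0, margin - x)
    }
  }
  return total / input.length;                 // mean reduction
}

// Similar pair at distance 0.2 contributes 0.2; dissimilar pair at
// distance 0.3 contributes max(0, 1.0 - 0.3) = 0.7; mean = 0.45
const example = hingeEmbeddingLoss([0.2, 0.3], [1, -1], 1.0); // → 0.45
```

Note that a dissimilar pair already separated by more than the margin contributes zero loss, which is what makes the margin a hard boundary rather than a soft penalty.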
Examples
// Siamese network: learn embeddings for similarity
const hinge_loss = new torch.nn.HingeEmbeddingLoss({ margin: 1.0 });
// Embedding for similar pair (same class)
const similar_embedding = torch.randn([32, 128]);
// Embedding for dissimilar pair (different class)
const dissimilar_embedding = torch.randn([32, 128]);
const similar_target = torch.ones([32]); // 1 = similar
const dissimilar_target = torch.ones([32]).mul(-1); // -1 = dissimilar
const loss_sim = hinge_loss.forward(similar_embedding, similar_target);
const loss_dissim = hinge_loss.forward(dissimilar_embedding, dissimilar_target);
const total_loss = loss_sim.add(loss_dissim);
// Siamese network with pairwise comparison
class SiameseNetwork extends torch.nn.Module {
encoder: torch.nn.Linear;
constructor() {
super();
this.encoder = new torch.nn.Linear(784, 128);
}
forward(x: torch.Tensor): torch.Tensor {
return this.encoder.forward(x); // Output embeddings
}
}
const model = new SiameseNetwork();
const hinge_loss = new torch.nn.HingeEmbeddingLoss({ margin: 1.0 });
// Train with pairs
const img1 = torch.randn([32, 784]);
const img2 = torch.randn([32, 784]);
const labels = torch.cat([torch.ones([16]), torch.ones([16]).mul(-1)]);
const emb1 = model.forward(img1);
const emb2 = model.forward(img2);
// In practice, compute a distance between emb1 and emb2 and pass it
// as the loss input; simplified here by feeding emb1 directly
const loss = hinge_loss.forward(emb1, labels);
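The distance computation elided in the example above can be sketched in plain TypeScript. `pairwiseDistance` is a hypothetical helper (the library may expose an equivalent tensor op): given two batches of embeddings, it produces one Euclidean distance per pair, which is the natural input to the hinge loss.

```typescript
// Euclidean distance per row between two batches of embeddings
// (hypothetical helper, plain arrays instead of tensors)
function pairwiseDistance(a: number[][], b: number[][]): number[] {
  return a.map((row, i) => {
    let sumSq = 0;
    for (let j = 0; j < row.length; j++) {
      const d = row[j] - b[i][j];
      sumSq += d * d;
    }
    return Math.sqrt(sumSq);
  });
}

const embA = [[0, 3], [1, 1]];
const embB = [[4, 0], [1, 1]];
const dists = pairwiseDistance(embA, embB); // → [5, 0]
```

The resulting distance vector has one value per pair, matching the shape of the 1/-1 target vector, so the loss can be applied directly.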