torch.nn.MultiMarginLoss
class MultiMarginLoss extends Module

new MultiMarginLoss(options?: { p?: number; margin?: number; weight?: Tensor; reduction?: Reduction })

Properties:
- p (number) - readonly
- margin (number) - readonly
- weight (Tensor | null) - readonly
- reduction (Reduction)
Multi-Margin Loss: multi-class hinge loss for classification.
Computes a margin-based loss for multi-class classification. For each sample, a margin is enforced between the score of the correct class and the scores of all other classes. Useful for hard-margin classification without probabilities, similar to the SVM objective.
When to use MultiMarginLoss:
- Multi-class classification with margin-based objective
- SVM-like training for neural networks
- When you want hard margins instead of soft probabilities
- Learning discrimination between classes
- Rarely used; CrossEntropyLoss more standard
Trade-offs:
- vs CrossEntropyLoss: Hinge margin vs probabilistic; CE more common
- vs SVM: Neural network training with similar objective
- Robustness: Hinge loss can be more robust to outliers
- Smoothness: Less smooth than CrossEntropy
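The smoothness trade-off can be seen on a single example: once the target score clears every other score by the margin, the hinge loss (and its gradient) is exactly zero, while cross-entropy stays positive. A minimal plain-TypeScript sketch (illustrative only, not the library's implementation; values are made up):

```typescript
// Hinge vs. cross-entropy on one 3-class sample.
const scores = [3, 1, 0]; // raw class scores (hypothetical)
const target = 0;         // correct class index
const margin = 1;

// Multi-margin (p = 1): hinge penalty averaged over all classes.
const hingeTerms = scores
  .filter((_, j) => j !== target)
  .map((s) => Math.max(0, margin - (scores[target] - s)));
const hinge = hingeTerms.reduce((a, b) => a + b, 0) / scores.length;

// Cross-entropy: negative log-softmax of the target score.
const exps = scores.map(Math.exp);
const ce = -Math.log(exps[target] / exps.reduce((a, b) => a + b, 0));

console.log(hinge); // 0 -- margin satisfied, loss and gradient vanish
console.log(ce);    // positive -- cross-entropy never reaches exactly zero
```

Here the target score (3) beats both competitors (1 and 0) by at least the margin, so the hinge loss contributes nothing; cross-entropy keeps pushing scores apart indefinitely.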
Algorithm: For each sample with scores x and target class y:
- loss = (1/C) * sum over j != y of max(0, margin - (x[y] - x[j]))^p, where C is the number of classes
- Each non-target class j whose score comes within margin of the target contributes a penalty
- Enforces a margin between the target score and every other class score
- Raw scores: Input is scores, not probabilities
- Hard margin: Enforces hard margin between classes
- SVM-like: Similar objective to support vector machines
- Multi-class: designed for 2 or more classes; for binary problems, SoftMarginLoss is the usual choice
- p parameter: Controls loss smoothness (p=1 linear, p=2 quadratic)
- Rarely used: CrossEntropyLoss more popular in modern deep learning
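The formula above can be sketched as a reference implementation on plain arrays (a sketch of the math, not the library's actual code; it assumes the default 'mean' reduction and no class weights):

```typescript
// Multi-margin loss over a batch of raw score rows.
function multiMarginLoss(
  scores: number[][], // [batch, numClasses] raw scores
  targets: number[],  // [batch] target class indices
  margin = 1.0,
  p = 1,
): number {
  let total = 0;
  for (let i = 0; i < scores.length; i++) {
    const row = scores[i];
    const y = targets[i];
    let sampleLoss = 0;
    for (let j = 0; j < row.length; j++) {
      if (j === y) continue;
      // Penalize class j when it comes within `margin` of the target score
      sampleLoss += Math.max(0, margin - (row[y] - row[j])) ** p;
    }
    total += sampleLoss / row.length; // averaged over classes
  }
  return total / scores.length; // 'mean' reduction over the batch
}

// Target score 1 is beaten by class 1 (violation 2) and ties the
// margin against class 2 (violation 1): (2 + 1) / 3 classes = 1.
console.log(multiMarginLoss([[1, 2, 0]], [0])); // 1
```

A row whose target score already clears every competitor by the margin, e.g. [2, 1, 0] with target 0, contributes zero loss.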
Examples
// Multi-class classification with multi-margin loss
const multi_margin = new torch.nn.MultiMarginLoss({ margin: 1.0, p: 1 });
// Scores for each class
const scores = torch.randn([32, 10]);
// Class targets
const targets = torch.randint(0, 10, [32]);
const loss = multi_margin.forward(scores, targets);
// Enforces margin between target class and others

// Multi-class classifier with margin-based training
class MultiMarginClassifier extends torch.nn.Module {
fc1: torch.nn.Linear;
fc2: torch.nn.Linear;
constructor(num_classes: number) {
super();
this.fc1 = new torch.nn.Linear(100, 64);
this.fc2 = new torch.nn.Linear(64, num_classes);
}
forward(x: torch.Tensor): torch.Tensor {
let h = torch.nn.functional.relu(this.fc1.forward(x));
// Return raw scores, not probabilities
return this.fc2.forward(h);
}
}
const model = new MultiMarginClassifier(10);
const loss_fn = new torch.nn.MultiMarginLoss({ margin: 1.0, p: 2 });
const batch_x = torch.randn([32, 100]);
const batch_y = torch.randint(0, 10, [32]);
const scores = model.forward(batch_x);
const loss = loss_fn.forward(scores, batch_y);