torch.nn.functional.multi_margin_loss
function multi_margin_loss(input: Tensor, target: Tensor, options?: {
p?: number;
margin?: number;
weight?: Tensor;
reduction?: 'none' | 'mean' | 'sum';
}): Tensor

Multi-class margin loss function for classification with a custom margin.
Computes a margin-based loss that encourages the correct class score to be at least
margin higher than every other class score. Useful for ranking-style learning where
enforcing gaps between classes matters. Common applications:
- Metric learning and ranking tasks (learn embeddings with margin constraints)
- Multi-class classification with margin requirements
- Robust classification (prevents overconfidence on positive examples)
- Computer vision (face recognition with margin losses like ArcFace, CosFace)
- Information retrieval (ranking items with margin-based objectives)
- Siamese/Triplet networks (enforcing relative distances between classes)
Core idea: For each sample, penalize every non-target class whose score exceeds the target score minus the margin. Loss = 0 if the target class is far enough ahead of all others.
Loss formula: For each sample i with target class y_i:
- Correct score: x_i^{y_i}
- For each class j ≠ y_i: if x_i^j > x_i^{y_i} - margin, add penalty
- Total sample loss: (1/C) Σ_{j≠y_i} [max(0, margin + x_i^j - x_i^{y_i})]^p, where C = num_classes
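As a sanity check, the per-sample formula can be sketched in plain TypeScript on number arrays (a hypothetical helper for illustration, not part of this API):

```typescript
// Per-sample multi-margin loss on plain arrays, mirroring the formula above.
function multiMarginSampleLoss(
  scores: number[], // raw scores x_i for one sample, length C
  target: number,   // target class index y_i
  margin = 1.0,
  p = 1,
): number {
  const C = scores.length;
  let total = 0;
  for (let j = 0; j < C; j++) {
    if (j === target) continue;
    // Penalty is nonzero only when class j violates the margin
    const violation = Math.max(0, margin + scores[j] - scores[target]);
    total += Math.pow(violation, p);
  }
  return total / C; // averaged over the number of classes
}

// Target score 3.0 leads every other score by at least margin=1.0:
console.log(multiMarginSampleLoss([3.0, 1.0, 0.5], 0)); // → 0
// Class 1 violates the margin by 0.5, so loss = 0.5 / 3:
console.log(multiMarginSampleLoss([1.0, 0.5, -1.0], 0)); // ≈ 0.1667
```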
Why margin? Encourages decision boundary robustness:
- Large margin → larger buffer zone → more robust to perturbations
- Small margin → allows tight decision boundaries (faster convergence)
- margin=0 → hinge loss style
- margin>0 → forces explicit separation (recommended)
p parameter: Controls loss growth rate for violations
- p=1: Linear (standard hinge loss behavior)
- p=2: Quadratic (penalizes violations more heavily)
- p=other: Custom growth rate
- Zero loss: loss=0 when target_score ≥ all_other_scores + margin
- Margin interpretation: margin=1.0 means target should be ≥ 1 point above others
- No probability: Operates on raw logits, not normalized probabilities
- Batch dimension: First dimension must be batch; all samples share same margin
- Linear vs quadratic: p=1 (linear) is computationally simpler and more numerically stable
- Default p=1: Linear margin loss is most common (standard hinge loss behavior)
- Target requirement: All targets must be valid class indices [0, num_classes)
- Broadcasting: Weight tensor must have shape [num_classes]
- Target validity: Invalid target indices (< 0 or ≥ num_classes) cause errors
- Margin tuning: margin=0 disables margin (may cause degenerate solutions)
- Large p: High p values can cause numerical instability; p≤2 recommended
- Class imbalance: Equal margin for all classes; use weight for class importance
- Scale matching: If the margin is far larger or smaller than the typical logit scale, rescale the logits or adjust the margin accordingly
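The batch-level behavior described in the notes above (per-class weighting, reduction modes, target validation) can be sketched as a plain-array reference implementation; the helper name and array-based signature are illustrative, not part of this API:

```typescript
type Reduction = 'none' | 'mean' | 'sum';

// Reference sketch of the batched loss on plain arrays (no tensor library).
function multiMarginLoss(
  input: number[][], // [batch_size][num_classes] raw logits
  target: number[],  // [batch_size] target class indices
  opts: { p?: number; margin?: number; weight?: number[]; reduction?: Reduction } = {},
): number | number[] {
  const { p = 1, margin = 1.0, weight, reduction = 'mean' } = opts;
  const losses = input.map((scores, i) => {
    const y = target[i];
    if (y < 0 || y >= scores.length) throw new Error(`invalid target index ${y}`);
    let total = 0;
    for (let j = 0; j < scores.length; j++) {
      if (j === y) continue;
      const v = Math.max(0, margin + scores[j] - scores[y]);
      // As in PyTorch, each term is scaled by the target class's weight
      total += (weight ? weight[y] : 1) * Math.pow(v, p);
    }
    return total / scores.length;
  });
  if (reduction === 'none') return losses;
  const sum = losses.reduce((a, b) => a + b, 0);
  return reduction === 'sum' ? sum : sum / losses.length;
}
```

Note the weighting convention: each sample's penalty terms are scaled by weight[y], the weight of that sample's target class, which is how class imbalance is compensated.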
Parameters
input Tensor - Raw model outputs (logits) of shape [batch_size, num_classes]. Each row contains raw scores for each class.
target Tensor - Target class indices of shape [batch_size]. Each value in [0, num_classes). Indicates which class should have the highest score.
options { p?: number; margin?: number; weight?: Tensor; reduction?: 'none' | 'mean' | 'sum'; } optional - Optional configuration object:
- p: Power for the loss computation (default: 1). Higher values → stronger penalties for violations
- margin: Minimum desired margin (default: 1.0). Target score should be ≥ other score + margin
- weight: Per-class weight tensor of shape [num_classes] (default: None, uniform weights)
- reduction: How to aggregate batch losses (default: 'mean')
  - 'none': Return unreduced loss of shape [batch_size]
  - 'mean': Return scalar mean of losses
  - 'sum': Return scalar sum of losses
Returns
Tensor - Loss tensor of shape [] (scalar) if reduction='mean' or 'sum', else [batch_size]

Examples
// Multi-class classification with margin enforcement
const batch_size = 32;
const num_classes = 10;
const logits = torch.randn([batch_size, num_classes]); // Model predictions
const targets = torch.randint(0, num_classes, [batch_size]); // True labels
// Standard multi-margin loss with margin=1.0
const loss = torch.nn.functional.multi_margin_loss(logits, targets);
// loss: scalar tensor

// Importance weighting for imbalanced classes
const num_classes = 5;
const class_weights = torch.tensor([1.0, 2.0, 1.0, 1.5, 3.0]); // Weight rare classes higher
const logits = torch.randn([32, num_classes]);
const targets = torch.randint(0, num_classes, [32]);
const loss = torch.nn.functional.multi_margin_loss(logits, targets, {
weight: class_weights,
margin: 1.0,
p: 1
});

// Quadratic margin loss for stronger penalization of violations
const logits = torch.randn([32, 10]);
const targets = torch.randint(0, 10, [32]);
// p=2 means violations are squared (stronger penalties)
const loss = torch.nn.functional.multi_margin_loss(logits, targets, {
margin: 1.0,
p: 2 // Quadratic loss
});

// Custom margin for different training stages
const logits = torch.randn([32, 10]);
const targets = torch.randint(0, 10, [32]);
// Large margin: early training, enforce strong separation
const loss_strong = torch.nn.functional.multi_margin_loss(logits, targets, {
margin: 2.0 // Larger margin → more robust
});
// Small margin: fine-tuning, allow tighter boundaries
const loss_weak = torch.nn.functional.multi_margin_loss(logits, targets, {
margin: 0.5 // Smaller margin → tighter boundaries
});

// Per-sample losses for custom aggregation
const logits = torch.randn([32, 10]);
const targets = torch.randint(0, 10, [32]);
const per_sample_loss = torch.nn.functional.multi_margin_loss(logits, targets, {
margin: 1.0,
reduction: 'none' // Returns [32] tensor
});
// Custom weighting by sample difficulty
const sample_weights = torch.where(per_sample_loss.gt(0.5), 2.0, 1.0);
const weighted_loss = per_sample_loss.mul(sample_weights).mean();

See Also
- PyTorch torch.nn.functional.multi_margin_loss
- torch.nn.functional.hinge_embedding_loss - Similar margin loss for pairs
- torch.nn.functional.margin_ranking_loss - Margin loss for ranking pairs
- torch.nn.functional.triplet_margin_loss - Margin loss for triplets (metric learning)
- torch.nn.functional.cross_entropy - Softmax loss (probabilistic alternative)
- torch.nn.functional.nll_loss - Negative log-likelihood (probabilistic)