torch.nn.functional.multilabel_margin_loss
function multilabel_margin_loss(input: Tensor, target: Tensor, reductionOrOptions?:
| 'none'
| 'mean'
| 'sum'
| {
reduction?: 'none' | 'mean' | 'sum';
}): Tensor

Multi-label margin loss for ranking multiple positive classes with margin separation.
Measures margin-based loss when multiple target labels are correct simultaneously. Ensures all positive classes score higher than all negative classes by a margin. Extends binary/multi-class classification to multi-label scenarios where samples can belong to multiple classes. Essential for:
- Multi-label classification (samples with multiple correct labels)
- Ranking problems with multiple relevant items (information retrieval)
- Tag prediction (images with multiple tags, documents with multiple topics)
- Scene understanding and activity recognition (multiple co-occurring objects/actions)
- Music genre/artist prediction (samples belong to multiple genres)
- Medical image analysis (multiple diseases/findings per image)
- Multi-task learning where tasks are viewed as multi-label prediction
How multi-label margin loss works: for each sample with positive classes {y₁, y₂, ...} and negative classes {j ∉ positives}, the loss sums a hinge term over all (positive, negative) pairs:

Loss = Σ over (positive, negative) pairs of max(0, 1 - (score[positive] - score[negative]))

This ensures each positive class outranks every negative class by a margin of at least 1.
Target format - crucial for understanding: target contains indices of positive classes, padded with -1 to fixed length. Example: 5 classes, sample has classes 1 and 3 positive → target=[1, 3, -1, -1, ...] The -1 padding is required and tells the loss function where positive labels end.
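To make the pairwise computation and the -1 padding concrete, here is a minimal plain-TypeScript sketch of the per-sample loss (a reference implementation for intuition, not the binding's actual code; the normalization by the number of classes follows PyTorch's convention):

```typescript
// Reference sketch: per-sample multi-label margin loss with margin = 1.
// scores: class scores for one sample; target: positive class indices,
// padded with -1 (the first -1 marks the end of the positive labels).
function multilabelMarginLossSample(scores: number[], target: number[]): number {
  const positives: number[] = [];
  for (const t of target) {
    if (t < 0) break; // -1 padding: stop at the first negative entry
    positives.push(t);
  }
  const positiveSet = new Set(positives);
  let loss = 0;
  for (const p of positives) {
    for (let j = 0; j < scores.length; j++) {
      if (positiveSet.has(j)) continue; // only (positive, negative) pairs
      loss += Math.max(0, 1 - (scores[p] - scores[j]));
    }
  }
  return loss / scores.length; // normalized by num_classes (PyTorch convention)
}

// Classes 1 and 3 positive out of 5; well-separated scores give zero loss
multilabelMarginLossSample([0, 5, 0, 5, 0], [1, 3, -1, -1, -1]); // → 0
```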
Key differences from multi_margin_loss:
- multi_margin_loss: only one positive class per sample (single-label)
- multilabel_margin_loss: multiple positive classes per sample (multi-label)
- This generates all positive-negative pairs and enforces ranking on all pairs
Difference from multilabel_soft_margin_loss:
- margin loss (hard): max(0, margin - (positive - negative)) - hinge-like
- soft margin loss: log(1 + exp(-positive)) + log(1 + exp(negative)) - logistic
- margin: ranking-focused, threshold-based; soft: probability-focused, smooth
- Target format critical: must contain class indices followed by -1 padding; the loss assumes the first -1 marks the end of the positive labels, and an incorrectly formatted target breaks the loss
- Target tensor shape: must match [batch, max_num_positive_labels], with -1 padding so samples can have different numbers of positive classes
- All positive-negative pairs: the loss considers every combination of (positive, negative) classes
- Margin = 1: the margin is hard-coded to 1 (unlike multi_margin_loss, it is not configurable)
- Relative ranking: the focus is on ensuring positive classes rank higher than negatives, not on absolute scores
- Symmetric negatives: every class not listed in the target is treated equally as a negative; there is no class weighting and no hard negative mining
- Computational cost: O(num_positive * num_negative) pairs per sample, which can be expensive for dense positive sets
- CPU tensors only: currently requires CPU tensors (GPU implementation pending); passing GPU input or target tensors raises an error
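The hard-vs-soft distinction drawn above can be seen directly in the per-term shapes. This sketch (illustrative helper names, not part of the API) shows that the hinge term drops to exactly zero once the margin is met, while the logistic terms stay strictly positive:

```typescript
// Hinge term used by multilabel_margin_loss:
// zero once the positive beats the negative by at least 1
const hingeTerm = (pos: number, neg: number): number =>
  Math.max(0, 1 - (pos - neg));

// Logistic terms used by multilabel_soft_margin_loss: smooth, never exactly zero
const softTerms = (pos: number, neg: number): number =>
  Math.log(1 + Math.exp(-pos)) + Math.log(1 + Math.exp(neg));

hingeTerm(3, 0); // → 0 (margin satisfied, contributes no gradient)
softTerms(3, 0); // > 0 (still pushes the scores apart)
```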
Parameters
input: Tensor - score tensor of shape [batch, num_classes] containing class scores. Example: logits from the final layer, [batch, 10] for 10 classes
target: Tensor - positive class indices of shape [batch, num_positive_classes], containing the indices of the positive classes for each sample, padded with -1. Example: [[1, 3, -1], [0, 2, -1]] for 2 samples where the first has classes 1 and 3 positive
reductionOrOptions: 'none' | 'mean' | 'sum' | { reduction?: 'none' | 'mean' | 'sum' } (optional)
Returns
Tensor - loss tensor (scalar if reduction is 'mean' or 'sum', shape [batch] if reduction='none')
Examples
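Before the worked examples, a note on reduction: this plain-TypeScript sketch (illustrative names; assumes the binding follows the usual PyTorch reduction semantics) shows how the three modes combine the per-sample losses:

```typescript
// Sketch of reduction semantics (assumed to mirror PyTorch):
// 'none' keeps per-sample losses, 'sum' adds them, 'mean' averages them.
function reduceLosses(
  perSample: number[],
  reduction: 'none' | 'mean' | 'sum' = 'mean'
): number[] | number {
  const total = perSample.reduce((a, b) => a + b, 0);
  if (reduction === 'none') return perSample;
  if (reduction === 'sum') return total;
  return total / perSample.length; // 'mean'
}

reduceLosses([1, 2, 3], 'sum');  // → 6
reduceLosses([1, 2, 3], 'mean'); // → 2
```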
// Image classification: each image can have multiple object classes
const batch_size = 4;
const num_classes = 10; // e.g., dog, cat, bird, car, person, bicycle, etc.
// Network scores for each class
const scores = torch.randn([batch_size, num_classes]);
// Multi-label targets (each sample has multiple positive classes)
// Sample 0: classes 1 (cat) and 5 (bicycle) are present
// Sample 1: classes 0 (dog) and 3 (car) are present
// Sample 2: class 2 (bird) only
// Sample 3: classes 4 (person), 7 (tree), 9 (sky)
const targets = torch.tensor([
[1, 5, -1, -1], // -1 padding indicates end of positive labels
[0, 3, -1, -1],
[2, -1, -1, -1],
[4, 7, 9, -1]
]);
const loss = torch.nn.functional.multilabel_margin_loss(scores, targets);
// For each (positive, negative) pair, enforces positive_score > negative_score + margin

// Tag prediction: predict multiple tags for documents
const num_docs = 32;
const num_tags = 50;
const tag_scores = model(documents); // [32, 50]
// Each document has multiple relevant tags
// Different documents can have different numbers of relevant tags
const tag_targets = torch.tensor([
[5, 12, 23, -1, -1, ...], // doc 0 has 3 relevant tags
[1, 2, 8, 15, -1, ...], // doc 1 has 4 relevant tags
[10, 20, -1, -1, -1, ...], // doc 2 has 2 relevant tags
// ... more documents
]);
const loss = torch.nn.functional.multilabel_margin_loss(tag_scores, tag_targets);

// Medical imaging: multiple diseases can co-occur
const num_images = 64;
const num_diseases = 20;
const disease_scores = model(images); // [64, 20] logits
// Target: indices of diseases present in each image
const disease_presence = torch.tensor([
[2, 5, 11, -1, -1], // image 0: pneumonia, tuberculosis, asthma
[1, 8, -1, -1, -1], // image 1: diabetes, hypertension
[3, 7, 9, 13, -1], // image 2: four diseases co-occur
// ...
]);
const disease_loss = torch.nn.functional.multilabel_margin_loss(
disease_scores,
disease_presence
);
// Loss ensures all present diseases score higher than all absent diseases

// Comparison: same data with different losses
const scores = torch.randn([8, 5]); // 8 samples, 5 classes
const targets = torch.tensor([
[0, 2, -1, -1, -1],
[1, 3, 4, -1, -1],
// ... 6 more samples
]);
// Multi-label margin loss (ranking, threshold-based)
const margin_loss = torch.nn.functional.multilabel_margin_loss(scores, targets);
// Margin loss enforces hard ranking constraints
// Each positive class must outrank every negative class by at least 1
// Good when you care about relative ordering (ranking)

// Handling variable number of labels per sample
const num_classes = 20;
const max_labels = 8; // maximum number of positive classes any sample has
const targets = torch.tensor([
[5, 12, -1, -1, -1, -1, -1, -1], // 2 positive labels
[0, 2, 8, 15, 19, -1, -1, -1], // 5 positive labels
[10, -1, -1, -1, -1, -1, -1, -1], // 1 positive label
[1, 3, 7, 11, 13, 17, 18, -1], // 7 positive labels
]);
const logits = torch.randn([4, num_classes]);
const loss = torch.nn.functional.multilabel_margin_loss(logits, targets);
// Automatically handles variable numbers of positive labels per sample

See Also
- PyTorch torch.nn.functional.multilabel_margin_loss
- multi_margin_loss - Single-label variant (only one positive class per sample)
- multilabel_soft_margin_loss - Soft margin alternative using logistic loss
- torch.nn.MultiLabelMarginLoss - Module wrapper
- cross_entropy - Probabilistic alternative for single-label classification