torch.nn.MultiLabelMarginLossOptions
Multi-Label Margin Loss: hinge loss for multi-label classification with target ranking.
Computes a ranking-based margin loss for multi-label classification where each sample can have multiple target classes. The loss encourages target classes to have high scores and non-target classes to have low scores, with a margin between them. Used when:
- Each sample has multiple correct labels (e.g., image tagging: "cat", "dog", "outdoor")
- You want to rank correct labels higher than incorrect ones
- Margin-based training with structured predictions
- Learning to rank scores for multi-label scenarios
Unlike CrossEntropyLoss (single label per sample) or MultiLabelSoftMarginLoss (uses sigmoid), this loss uses a ranking/margin approach. It penalizes every pair of a target label and a non-target label whose scores are not separated by the margin, encouraging target labels to rank above non-target labels by at least that margin.
When to use MultiLabelMarginLoss:
- Multi-label classification (multiple labels per sample)
- You care about target labels outscoring non-target labels rather than about calibrated probabilities
- Want margin-based learning (enforce separation between target/non-target scores)
- Image tagging, document classification with multiple topics
- When targets can be represented as a sequence of label indices
- All correct labels matter equally (the margin formula treats every listed target the same, regardless of its position in the list)
Trade-offs:
- vs MultiLabelSoftMarginLoss: Margin-based (ranking) vs sigmoid-based (probabilities)
- vs CrossEntropyLoss: Handles multiple labels vs single label per sample
- Target format: Requires label indices (not one-hot or probabilities)
- Ranking information: Ranks targets above non-targets only; the order in which targets are listed does not enter the margin formula
- Computational complexity: O(num_targets * num_classes) for each sample
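If your labels are stored as multi-hot rows, they need to be converted to padded label indices first. The helper below is a minimal sketch in plain TypeScript; `multiHotToPaddedIndices` is a hypothetical name and is not part of the library.

```typescript
// Hypothetical helper (not part of the library): converts multi-hot rows
// such as [0, 1, 0, 1, 0] into padded index rows such as [1, 3, -1, -1, -1].
function multiHotToPaddedIndices(multiHot: number[][]): number[][] {
  return multiHot.map((row) => {
    const indices: number[] = [];
    for (let c = 0; c < row.length; c++) {
      if (row[c] !== 0) indices.push(c); // class c is a target
    }
    // Pad with -1 so every row keeps the length of the class dimension.
    while (indices.length < row.length) indices.push(-1);
    return indices;
  });
}

const padded = multiHotToPaddedIndices([
  [0, 1, 0, 1, 0],
  [0, 0, 1, 0, 1],
]);
// padded is [[1, 3, -1, -1, -1], [2, 4, -1, -1, -1]]
```

The resulting rows have the same layout as the `targets` tensor in the first example and can be wrapped with `torch.tensor(..., { dtype: torch.int64 })` in the same way.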
Algorithm: For each sample with targets [t1, t2, ...] and C classes (the length of the score vector, as in PyTorch's definition):
- For each target label t and each non-target label j ∉ targets: loss_{t,j} = max(0, 1 - score[t] + score[j])
- Final loss = sum of all margin violations / C
This encourages score[target] ≥ score[non-target] + 1 (a margin of 1). The loss is non-zero whenever a non-target label scores within 1 of a target label, including when it scores higher.
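The per-sample computation can be sketched in plain TypeScript, without any tensor library. This is an illustrative reference, not the library's implementation: `multiLabelMarginLossSample` is a hypothetical name, and the sketch assumes the PyTorch convention of reading targets up to the first -1 and dividing the summed violations by the number of classes.

```typescript
// Reference sketch for one sample (hypothetical helper, not library code).
// Assumption: the sum is divided by the number of classes, as in PyTorch.
function multiLabelMarginLossSample(scores: number[], targets: number[]): number {
  const isTarget: boolean[] = scores.map(() => false);
  const targetList: number[] = [];
  // Read target indices up to the first -1 (padding).
  for (let k = 0; k < targets.length; k++) {
    const t = targets[k];
    if (t < 0) break;
    if (!isTarget[t]) {
      isTarget[t] = true;
      targetList.push(t);
    }
  }
  let total = 0;
  for (let a = 0; a < targetList.length; a++) {
    const t = targetList[a];
    for (let j = 0; j < scores.length; j++) {
      if (isTarget[j]) continue; // only non-target classes compete
      total += Math.max(0, 1 - scores[t] + scores[j]); // hinge with margin 1
    }
  }
  return total / scores.length;
}
```

For scores `[0.1, 2.5, -1.0, 1.2, 0.3]` with targets `[1, 3, -1, -1, -1]`, the only violated pair is target 3 against non-target 4 (1 - 1.2 + 0.3 ≈ 0.1), so under these assumptions the sample loss is approximately 0.1 / 5 = 0.02.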
Definition
export interface MultiLabelMarginLossOptions {
/** How to reduce loss across batch (default: 'mean') */
reduction?: Reduction;
}
reduction (Reduction, optional) – How to reduce loss across batch (default: 'mean')
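To illustrate what reducing across the batch means, here is a small sketch over plain arrays of per-sample losses. `ReductionSketch` and `reduceLosses` are hypothetical names, and the three modes are assumed to mirror PyTorch ('none', 'mean', 'sum'); the library's actual `Reduction` type may differ.

```typescript
// Assumption: these three modes mirror PyTorch; the library's actual
// Reduction type may accept different values.
type ReductionSketch = 'none' | 'mean' | 'sum';

function reduceLosses(
  perSample: number[],
  reduction: ReductionSketch = 'mean'
): number | number[] {
  if (reduction === 'none') return perSample.slice(); // one loss per sample
  let sum = 0;
  for (let i = 0; i < perSample.length; i++) sum += perSample[i];
  // 'sum' adds the per-sample losses; 'mean' averages them over the batch.
  return reduction === 'sum' ? sum : sum / perSample.length;
}
```

With per-sample losses of 0.02 and 0, 'mean' gives about 0.01 and 'sum' gives 0.02.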
Examples
// Multi-label classification: image tagging
const mlml = new torch.nn.MultiLabelMarginLoss();
// Predicted scores: [batch=2, num_classes=5]
const scores = torch.tensor([
[0.1, 2.5, -1.0, 1.2, 0.3], // Scores for 5 classes
[-0.2, 1.0, 3.1, 0.5, 2.2]
]);
// Target labels: classes 1,3 are targets for sample 1; classes 2,4 for sample 2
const targets = torch.tensor([
[1, 3, -1, -1, -1], // Sample 1: targets are classes 1 and 3 (padded with -1)
[2, 4, -1, -1, -1] // Sample 2: targets are classes 2 and 4 (padded with -1)
], { dtype: torch.int64 });
const loss = mlml.forward(scores, targets);
// Encourages score[1] > score[0,2,4] and score[3] > score[0,2,4] for sample 1
// And score[2] > score[0,1,3] and score[4] > score[0,1,3] for sample 2

// Music genre classification: one song can have multiple genres
class GenreClassifier extends torch.nn.Module {
fc1: torch.nn.Linear;
fc2: torch.nn.Linear;
constructor() {
super();
this.fc1 = new torch.nn.Linear(256, 128);
this.fc2 = new torch.nn.Linear(128, 10); // 10 genres
}
forward(audio: torch.Tensor): torch.Tensor {
const h = torch.relu(this.fc1.forward(audio));
return this.fc2.forward(h); // Scores, not probabilities
}
}
const model = new GenreClassifier();
const loss_fn = new torch.nn.MultiLabelMarginLoss();
// Batch of audio features
const audio_batch = torch.randn([32, 256]);
const predictions = model.forward(audio_batch);
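// Ground truth: some songs have multiple genre labels
// E.g., sample 0 has genres [2, 5] (rock, electronic), sample 1 has genres [1, 3, 7] (pop, jazz, folk)
// Each row is padded with -1 to the number of classes (10) so the targets match the shape of the predictions (PyTorch convention)
const true_genres = torch.tensor([
[2, 5, -1, -1, -1, -1, -1, -1, -1, -1],
[1, 3, 7, -1, -1, -1, -1, -1, -1, -1],
// ... more samples
], { dtype: torch.int64 });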
const loss = loss_fn.forward(predictions, true_genres);
// Model learns: rock songs should score high on genre 2, electronic on 5, etc.

// Document classification with multiple topics per document
// The order in which topics are listed does not affect the loss; all listed topics are treated equally
const doc_scores = torch.randn([16, 20]); // 20 topics
// Each row is padded with -1 to the number of topics (20) so the targets match the shape of the scores (PyTorch convention)
const topic_targets = torch.tensor([
[3, 7, 12, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], // Document 0: topics 3, 7, and 12
[1, 5, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], // Document 1: topics 1 and 5
[11, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], // Document 2: topic 11 only (single topic)
// ... more documents
], { dtype: torch.int64 });
const loss_fn = new torch.nn.MultiLabelMarginLoss({ reduction: 'mean' });
const loss = loss_fn.forward(doc_scores, topic_targets);
// Model learns to score each document's listed topics above all unlisted topics by the margin