torch.nn.MultiLabelSoftMarginLoss
class MultiLabelSoftMarginLoss extends Module
new MultiLabelSoftMarginLoss(options?: { weight?: Tensor; reduction?: Reduction })
- weight (Tensor | null) - readonly
- reduction (Reduction) - readonly
Multi-Label Soft Margin Loss: sigmoid-based loss for multi-label classification.
Computes binary cross-entropy loss for multi-label classification, treating each class independently with sigmoid activation. Essential for:
- Multi-label classification where multiple labels are correct per sample
- Image tagging, content recommendation, document classification
- Any task requiring independent binary decisions for multiple classes per sample
- Learning independent probabilities for each class
- Imbalanced multi-label datasets with per-class weighting
This is the most common loss for multi-label classification. It applies sigmoid to each output independently, then computes binary cross-entropy for each class. This treats each class decision as independent, unlike MultiLabelMarginLoss which uses ranking/margins.
When to use MultiLabelSoftMarginLoss:
- Multi-label classification (multiple independent labels per sample)
- Want probability estimates for each class (not ranking-based)
- Imbalanced classes (use per-class weights to adjust)
- Standard choice for most multi-label problems
- When you want sigmoid outputs (probabilities in [0, 1])
- Image tagging, movie genre classification, disease diagnosis
Trade-offs:
- vs MultiLabelMarginLoss: Soft (sigmoid + BCE) vs hard (margin-based ranking)
- vs CrossEntropyLoss: Multiple labels per sample vs single label
- Class independence: Each class predicted independently (no interaction)
- Probability output: Natural probability interpretation via sigmoid
- Imbalanced data: Can use per-class weights to handle imbalance
Algorithm: For each sample and class pair (i, k):
- BCE_loss_ik = -(target_ik * log(σ(score_ik)) + (1 - target_ik) * log(1 - σ(score_ik)))
- Final loss = mean over all (i, k) pairs, optionally weighted
Where σ(x) = 1 / (1 + exp(-x)) is the sigmoid function. The sigmoid is applied internally; provide raw logits, not probabilities.
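The per-element formula can be sketched in plain TypeScript (a standalone numeric illustration, not the library's internal implementation, which additionally handles batching, weighting, and numerical stability):

```typescript
// Sigmoid: maps a raw logit to a probability in (0, 1).
const sigmoid = (x: number): number => 1 / (1 + Math.exp(-x));

// Binary cross-entropy for a single (sample, class) pair, given a raw logit
// and a binary target. This is BCE_loss_ik from the formula above.
function bceFromLogit(logit: number, target: number): number {
  const p = sigmoid(logit);
  return -(target * Math.log(p) + (1 - target) * Math.log(1 - p));
}

// A confident correct prediction yields a small loss...
console.log(bceFromLogit(4.0, 1)); // ~0.018
// ...while the same confidence with the wrong target yields a large loss.
console.log(bceFromLogit(4.0, 0)); // ~4.018
```

The final loss averages this quantity over all (sample, class) pairs, scaling each class's term by its weight when a weight tensor is provided.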
- Logits vs probabilities: Input should be raw logits, sigmoid is applied internally
- Independent classes: Each class treated independently (no softmax interaction)
- Binary targets: Target values should be 0 or 1 (not probabilities or logits)
- Per-class weighting: Use weight parameter to emphasize rare/important classes
- No softmax: Unlike CrossEntropyLoss, no softmax is used (per-class probabilities need not sum to 1)
- Probability output: Use sigmoid(logits) during inference to get class probabilities
- Threshold selection: Default threshold for classification is 0.5, but can be tuned
- Gradient flow: Allows independent optimization per class without competition
- Input should be raw logits, NOT probabilities (don't apply sigmoid before passing)
- Target values should be binary (0 or 1), consistent with the binary cross-entropy formulation
- The weight parameter, if provided, should have shape [num_classes]
- Be careful with extreme logit values: very large/small values can cause numerical issues
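The extreme-logit caveat above is why implementations typically compute the loss with the log-sum-exp trick rather than evaluating `log(σ(x))` directly. A plain TypeScript sketch of the stable form (the helper name is illustrative, not part of the torch.nn API):

```typescript
// Numerically stable BCE-with-logits. Algebraically equivalent to
// -(y*log(σ(x)) + (1-y)*log(1-σ(x))), but rewritten so that log(0)
// is never evaluated, even for very large |logit|.
function stableBceFromLogit(logit: number, target: number): number {
  return (
    Math.max(logit, 0) -
    logit * target +
    Math.log(1 + Math.exp(-Math.abs(logit)))
  );
}

// The naive form would compute log(1 - σ(1000)) = log(0) = -Infinity here;
// the stable form stays finite.
console.log(stableBceFromLogit(1000, 0)); // 1000, not Infinity
```

This is the standard approach in BCE-with-logits implementations; you get it automatically by passing raw logits to the loss rather than pre-applying sigmoid.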
Examples
// Image tagging: classify which objects are in an image
class ImageTagger extends torch.nn.Module {
  conv1: torch.nn.Conv2d;
  pool: torch.nn.MaxPool2d;
  fc1: torch.nn.Linear;
  fc2: torch.nn.Linear;

  constructor() {
    super();
    this.conv1 = new torch.nn.Conv2d(3, 32, 3, { padding: 1 });
    this.pool = new torch.nn.MaxPool2d(2);
    this.fc1 = new torch.nn.Linear(32 * 112 * 112, 128);
    this.fc2 = new torch.nn.Linear(128, 20); // 20 possible tags
  }

  forward(images: torch.Tensor): torch.Tensor {
    let x = torch.relu(this.conv1.forward(images));
    x = this.pool.forward(x);
    x = x.reshape([x.shape[0], -1]); // Flatten
    x = torch.relu(this.fc1.forward(x));
    return this.fc2.forward(x); // Raw logits
  }
}
const model = new ImageTagger();
const loss_fn = new torch.nn.MultiLabelSoftMarginLoss();
// Batch of images: [batch=32, channels=3, height=224, width=224]
const images = torch.randn([32, 3, 224, 224]);
const logits = model.forward(images); // [32, 20]
// Ground truth: each image can have multiple tags
// E.g., one image may have tags: cat, dog, outdoor (classes 0, 2, 5)
const tags = torch.zeros([32, 20]);
tags[0][0] = 1; // Image 0: has tag 0 (cat)
tags[0][2] = 1; // Image 0: has tag 2 (dog)
tags[0][5] = 1; // Image 0: has tag 5 (outdoor)
// ... set more tags for other images
const loss = loss_fn.forward(logits, tags);
// Model learns to predict cat, dog, outdoor as 1 and other tags as 0

// Recommendation system: predict which items user will interact with
const mlsml = new torch.nn.MultiLabelSoftMarginLoss();
// Model outputs: raw scores for 100 possible items
const batch_size = 32;
const user_embeddings = torch.randn([batch_size, 50]);
const item_scores = torch.randn([batch_size, 100]); // Raw logits
// Ground truth: which items user actually interacted with
const interactions = torch.zeros([batch_size, 100]);
// Set interactions[i][j] = 1 for items user i interacted with
const loss = mlsml.forward(item_scores, interactions);
// Model learns to score interacted items high, non-interacted items low

// Imbalanced multi-label problem: use per-class weights
// Some classes are much rarer than others
const class_weights = torch.tensor([1.0, 1.0, 2.0, 0.5, 3.0]); // Class 4 is rare (weight=3)
const mlsml = new torch.nn.MultiLabelSoftMarginLoss({
weight: class_weights,
reduction: 'mean'
});
const logits = torch.randn([16, 5]); // Predictions for 5 classes
const targets = torch.zeros([16, 5]); // Binary targets
targets[0][4] = 1; // Sample 0 has rare class 4
const loss = mlsml.forward(logits, targets);
// Rare classes contribute more to loss (weight=3), forcing model to learn them

// Movie genre classification: predict multiple genres for a movie
const mlsml = new torch.nn.MultiLabelSoftMarginLoss({ reduction: 'mean' });
// Network predicts logits for 15 movie genres
const batch_size = 3; // Must match the number of rows in genre_targets below
const movie_features = torch.randn([batch_size, 256]);
const genre_logits = torch.randn([batch_size, 15]);
// Ground truth: each movie has multiple genres
const genre_targets = torch.tensor([
[1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], // Drama, Thriller
[0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], // Action, Comedy
[1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], // Drama, Action
// ... more movies
]);
const loss = mlsml.forward(genre_logits, genre_targets);
// Model learns independent probabilities for each genre
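At inference time, per the notes above, apply sigmoid to the logits and threshold each class independently (0.5 by default, tunable per class). A plain TypeScript sketch over a number array; the example logit values and the helper name are hypothetical:

```typescript
const sigmoid = (x: number): number => 1 / (1 + Math.exp(-x));

// Turn one sample's raw logits into the set of predicted class indices.
// A class is predicted whenever its independent probability exceeds
// the threshold; unlike argmax, any number of classes can be selected.
function predictLabels(logits: number[], threshold = 0.5): number[] {
  return logits
    .map((logit, idx) => (sigmoid(logit) > threshold ? idx : -1))
    .filter((idx) => idx >= 0);
}

const genreLogits = [2.1, -1.3, 0.7, -4.0, 0.1]; // raw scores for 5 genres
console.log(predictLabels(genreLogits)); // indices whose probability > 0.5
```

Raising the threshold trades recall for precision; on imbalanced data it is common to tune a separate threshold per class on a validation set.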