torch.nn.functional.multilabel_soft_margin_loss
function multilabel_soft_margin_loss(input: Tensor, target: Tensor, options?: {
weight?: Tensor;
reduction?: 'none' | 'mean' | 'sum';
}): Tensor
Multi-label soft margin loss using logistic loss for each class independently.
Applies binary cross-entropy loss to each class independently for multi-label classification. Each class has its own binary (present/absent) prediction, creating a smooth, differentiable loss suitable for probabilistic multi-label learning. Essential for:
- Multi-label classification (treat each label as independent binary prediction)
- Multi-task learning (multiple binary classification tasks)
- Tag prediction with confidence scores (images with multiple tags)
- Scene understanding (multiple objects/actions can co-occur)
- Medical diagnosis (predict presence/absence of multiple conditions)
- Recommendation systems (predict user preference for multiple items)
- One-vs-all classification variants with multiple positive classes
- Probabilistic multi-label learning with soft targets
How multi-label soft margin loss works: each class is treated independently with a binary logistic loss (sigmoid cross-entropy). For each class c:
loss_c = -[y_c * log(sigmoid(x_c)) + (1 - y_c) * log(1 - sigmoid(x_c))]
The loss is averaged across all classes and batch samples. Because the sigmoid is smooth, gradients continue to flow even when predictions are clearly wrong.
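The per-class formula can be sanity-checked in a few lines of pure Python (a sketch of the math only, independent of any tensor library):

```python
import math

def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))

def per_class_loss(x, y):
    # loss_c = -[y * log(sigmoid(x)) + (1 - y) * log(1 - sigmoid(x))]
    p = sigmoid(x)
    return -(y * math.log(p) + (1 - y) * math.log(1 - p))

# One sample, three classes: raw logits and binary targets
logits = [2.0, -1.0, 0.0]
targets = [1.0, 0.0, 1.0]
losses = [per_class_loss(x, y) for x, y in zip(logits, targets)]
loss = sum(losses) / len(losses)  # averaged over classes
# per_class_loss(0.0, 1.0) == log(2): an uninformative logit costs ~0.693
```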
Logistic loss interpretation:
- y_c = 1 (present): -log(sigmoid(x_c)) → penalizes low scores
- y_c = 0 (absent): -log(1 - sigmoid(x_c)) → penalizes high scores
- Smooth function: sigmoid(x) transitions smoothly from 0 to 1 as x increases
- Creates smooth probability interpretation: sigmoid(x) ≈ P(y=1|x)
Key differences from multilabel_margin_loss:
- multilabel_margin_loss: hard margin, ranking-focused, target indices
- multilabel_soft_margin_loss: soft logistic, probability-focused, binary targets
- margin loss: enforces thresholds; soft margin: smooth probability curves
- margin loss: O(pos × neg) pairwise complexity per sample; soft margin: linear O(classes)
- margin loss: CPU-only; soft margin: full GPU support
Target format differences:
- multilabel_margin_loss: indices of positive classes (e.g., [1, 3, -1, ...])
- multilabel_soft_margin_loss: binary labels per class (e.g., [0, 1, 0, 1, 0, ...])
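To make the format difference concrete, a small helper (hypothetical, not part of this API) converts margin-loss-style index targets into the multi-hot rows this function expects:

```python
def indices_to_multihot(idx_row, num_classes):
    # multilabel_margin_loss targets list positive class indices, terminated
    # with -1; multi-hot rows instead put 1.0 at each positive class position.
    row = [0.0] * num_classes
    for i in idx_row:
        if i < 0:  # first -1 ends the list of positives
            break
        row[i] = 1.0
    return row

indices_to_multihot([1, 3, -1, -1, -1], 5)  # -> [0.0, 1.0, 0.0, 1.0, 0.0]
```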
- Independent binary loss: each class is treated independently (one binary loss per class, averaged across classes)
- Logistic interpretation: Output can be interpreted as class probability via sigmoid
- Soft targets: Supports soft targets (0.0-1.0), not just binary (0/1)
- Smooth gradients: Sigmoid ensures non-zero gradients even for extreme predictions
- Full GPU support: Unlike multilabel_margin_loss, works on GPU tensors
- Class weighting: Optional weight tensor allows per-class importance adjustment
- Linear complexity: O(batch * num_classes) vs O(batch * pos * neg) for margin loss
- Numerically stable: Implements log-sum-exp tricks internally for stability
- Target format: Must be binary (0/1) or soft (0.0-1.0), not class indices
- Different from margin loss: Not same as multilabel_margin_loss despite similar names
- Weight tensor shape: If provided, weight must match [num_classes] dimension
- Soft labels: While soft targets are supported, training may be less stable with noisy labels
- Class imbalance: Unweighted loss treats all classes equally; use weight for imbalance
- Sigmoid saturation: Very large/small inputs → near-zero gradients (clip or normalize)
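The saturation caveat is why stable implementations avoid computing log(sigmoid(x)) directly. One standard rewrite, max(x, 0) - x*y + log(1 + exp(-|x|)), can be compared against the naive form in pure Python (illustration only, not this library's actual code):

```python
import math

def naive_loss(x, y):
    # Direct formula: fails for large |x| (sigmoid saturates to exactly 0 or 1)
    p = 1.0 / (1.0 + math.exp(-x))
    return -(y * math.log(p) + (1 - y) * math.log(1 - p))

def stable_loss(x, y):
    # Equivalent log-sum-exp form: safe for any magnitude of x
    return max(x, 0.0) - x * y + math.log1p(math.exp(-abs(x)))

# Both agree in the moderate range ...
assert abs(naive_loss(2.0, 1.0) - stable_loss(2.0, 1.0)) < 1e-9
# ... but only the stable form survives extreme logits:
stable_loss(1000.0, 0.0)  # 1000.0; naive_loss(1000.0, 0.0) raises a math domain error
```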
Parameters
input: Tensor - Score tensor of shape [batch, num_classes] or [..., num_classes]. Raw logits / unnormalized scores for each class (usually from the final layer). Example: [batch, 20] for a 20-class multi-label problem.
target: Tensor - Binary labels with the same shape as input, values ∈ {0, 1} (soft values in [0, 1] also supported). 1 indicates the class is present, 0 that it is absent. Example: [[0, 1, 0, 1, ...], [1, 0, 1, 0, ...]] - a batch of binary label rows.
options: { weight?: Tensor; reduction?: 'none' | 'mean' | 'sum'; } (optional) - Optional configuration:
- weight: per-class weights of shape [num_classes] (default: none). Allows upweighting important classes (e.g., rare diseases in medical imaging).
- reduction: how to aggregate losses (default: 'mean'). 'none': per-sample losses [batch]; 'mean': average loss across batch and classes; 'sum': sum of losses across batch and classes.
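As a cross-check of the weight and reduction semantics, here is a pure-Python reference sketch on nested lists (assuming, as in PyTorch, that per-sample losses are averaged over classes before the batch reduction):

```python
import math

def multilabel_soft_margin_ref(input, target, weight=None, reduction="mean"):
    # input/target: [batch][num_classes] nested lists; weight: [num_classes] or None
    per_sample = []
    for x_row, y_row in zip(input, target):
        total = 0.0
        for c, (x, y) in enumerate(zip(x_row, y_row)):
            p = 1.0 / (1.0 + math.exp(-x))
            l = -(y * math.log(p) + (1 - y) * math.log(1 - p))
            if weight is not None:
                l *= weight[c]  # per-class importance
            total += l
        per_sample.append(total / len(x_row))  # average over classes
    if reduction == "none":
        return per_sample           # [batch]
    if reduction == "sum":
        return sum(per_sample)      # scalar
    return sum(per_sample) / len(per_sample)  # 'mean' (default): scalar
```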
Returns
Tensor - Loss tensor (scalar if reduction='mean' or 'sum', [batch] if reduction='none')
Examples
// Multi-label image classification: each image can have multiple objects
const batch_size = 32;
const num_classes = 20;
// Network outputs (logits) for each class
const logits = torch.randn([batch_size, num_classes]);
// Binary labels: 1 if class present, 0 if absent
// Example: [1, 0, 1, 0, ...] means classes 0 and 2 are present
const labels = torch.randint(0, 2, [batch_size, num_classes]).to('float32');
const loss = torch.nn.functional.multilabel_soft_margin_loss(logits, labels);
// Loss = average of 20 independent binary cross-entropy losses per sample

// Multi-label document classification with class weights
const doc_scores = torch.randn([64, 50]); // 64 docs, 50 topics/tags
const doc_labels = torch.randint(0, 2, [64, 50]).to('float32'); // which topics present
// Upweight rare/important topics
const topic_weights = torch.ones(50);
topic_weights[5] = 2.0; // topic 5 is important: weight 2x
topic_weights[15] = 1.5; // topic 15 is moderately important
const weighted_loss = torch.nn.functional.multilabel_soft_margin_loss(
doc_scores, doc_labels,
{ weight: topic_weights }
);
// Rare topics contribute more to the total loss

// Medical imaging: predict presence/absence of multiple findings
const batch_size = 32;
const num_conditions = 10; // pneumonia, TB, COVID-19, etc.
const diagnosis_scores = model(images); // [32, 10] logits from a model defined elsewhere
const diagnosis_labels = torch.tensor([
[1, 0, 0, 1, 0, 0, 0, 0, 0, 0], // Patient 0: pneumonia + COVID-19
[0, 1, 0, 0, 0, 0, 1, 0, 0, 0], // Patient 1: TB + bronchitis
// ... more patients
]).to('float32');
const diagnosis_loss = torch.nn.functional.multilabel_soft_margin_loss(
diagnosis_scores, diagnosis_labels
);
// Each condition is independently trained to predict presence/absence

// Comparison: margin vs soft margin loss
const scores = torch.randn([8, 5]);
const labels = torch.tensor([
[1, 0, 1, 0, 0],
[0, 1, 0, 1, 1],
// ... 6 more samples
]).to('float32');
// Hard margin: enforces ranking constraints (positive-class scores above negative-class scores)
const margin_targets = torch.tensor([
[0, 2, -1, -1, -1], // classes 0 and 2 are positive
[1, 3, 4, -1, -1], // classes 1, 3, and 4 are positive
]);
// const margin_loss = torch.nn.functional.multilabel_margin_loss(scores, margin_targets);
// Soft margin: independent binary predictions for each class
const soft_loss = torch.nn.functional.multilabel_soft_margin_loss(scores, labels);
// Each class gets its own independent sigmoid loss

// Multi-hot encoding: dense vs sparse representations
const batch_size = 16;
const num_tags = 100;
// Approach 1: Multi-hot encoding (batch, num_tags)
const multihot_labels = torch.zeros([batch_size, num_tags]);
// Set to 1 for present tags
multihot_labels[0, 5] = 1;
multihot_labels[0, 12] = 1;
multihot_labels[1, 3] = 1;
const scores = torch.randn([batch_size, num_tags]);
const loss = torch.nn.functional.multilabel_soft_margin_loss(scores, multihot_labels);
// Works directly with multi-hot format (1 for present, 0 for absent)

// Per-class confidence: soft targets (values between 0 and 1)
const model_scores = torch.randn([32, 5]); // 32 samples, 5 classes
// Soft labels: confidence in class presence (not just 0/1)
const soft_labels = torch.tensor([
[0.9, 0.0, 0.8, 0.0, 0.1], // High confidence: classes 0,2; low: class 4
[0.0, 0.95, 0.0, 0.7, 0.0], // High confidence: classes 1,3
// ... more samples
]);
const loss = torch.nn.functional.multilabel_soft_margin_loss(
model_scores, soft_labels
);
// Supports soft targets (not just 0/1), enabling distillation and label smoothing

See Also
- PyTorch torch.nn.functional.multilabel_soft_margin_loss
- multilabel_margin_loss - Hard margin alternative with ranking constraints
- binary_cross_entropy - Single binary classification (sigmoid + cross-entropy)
- binary_cross_entropy_with_logits - More numerically stable binary loss
- torch.nn.MultiLabelSoftMarginLoss - Module wrapper
- cross_entropy - Multi-class single-label alternative