torch.nn.BCEWithLogitsLoss
class BCEWithLogitsLoss extends Module

new BCEWithLogitsLoss(options?: { weight?: Tensor; reduction?: Reduction; pos_weight?: Tensor })
Binary Cross Entropy with Logits: numerically stable BCE for binary/multi-label classification.
Combines the sigmoid activation and binary cross entropy into a single, numerically stable operation, avoiding the precision loss of computing sigmoid and then log separately. Recommended choice for:
- Binary classification (more stable than BCELoss)
- Multi-label classification (independent binary decisions)
- Imbalanced binary problems (with pos_weight parameter)
- Object detection (foreground/background per anchor)
- Anomaly detection (normality scores per sample)
- Any task producing binary decisions
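Concretely, the fused operation can be written per element for a logit x and target z ∈ {0, 1}. The second form below is the standard numerically stable rearrangement of the first (a sketch of what implementations of this loss typically compute):

```latex
\ell(x, z) = -\bigl[\, z \log \sigma(x) + (1 - z)\log\bigl(1 - \sigma(x)\bigr) \,\bigr]
           = \max(x, 0) - x z + \log\bigl(1 + e^{-|x|}\bigr)
```

The right-hand form never takes the log of a value near zero, so it remains finite for arbitrarily large |x|, whereas the left-hand form overflows once σ(x) rounds to exactly 0 or 1.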
Takes raw, unbounded logits instead of probabilities and applies the sigmoid internally. This is numerically more stable than BCELoss, especially for extreme logit values, and should be preferred in new code.
- Input format: Takes raw logits, NOT probabilities (critical difference from BCELoss)
- Numerical stability: Uses the log-sum-exp trick, so extreme logits neither overflow nor lose precision
- Preferred over BCELoss: Always use this unless input is already sigmoid-ed
- pos_weight usage: For imbalanced data where minority class is important
- Multi-label: Each output treated as independent binary classification
- Gradient behavior: The gradient with respect to each logit is sigmoid(x) − target, which is smooth and bounded in (−1, 1), good for optimization
- Common pattern: FC → (no sigmoid) → BCEWithLogitsLoss
- Sigmoid application: Applied internally during loss computation
- Computational cost: O(batch_size × num_elements), i.e. linear in the number of logits
- Prediction threshold: Use 0.5 as the threshold for converting probabilities to binary decisions (equivalently, threshold the raw logits at 0, since sigmoid(0) = 0.5)
Examples
// Binary classification with logits (preferred over BCELoss)
const bce_logits = new torch.nn.BCEWithLogitsLoss();
// Raw logits from model (NOT sigmoid-ed)
const logits = torch.randn([32, 1]);
// Binary targets (0 or 1)
const targets = torch.randint(0, 2, [32, 1]);
// Compute loss (applies sigmoid internally)
const loss = bce_logits.forward(logits, targets);

// Binary classification network with logits
class BinaryClassifier extends torch.nn.Module {
  fc1: torch.nn.Linear;
  fc2: torch.nn.Linear;

  constructor(input_dim: number) {
    super();
    this.fc1 = new torch.nn.Linear(input_dim, 64);
    this.fc2 = new torch.nn.Linear(64, 1);
  }

  forward(x: torch.Tensor): torch.Tensor {
    x = this.fc1.forward(x);
    x = torch.nn.functional.relu(x);
    x = this.fc2.forward(x); // Return logits, NOT sigmoid
    return x;
  }
}
const model = new BinaryClassifier(100);
const bce = new torch.nn.BCEWithLogitsLoss();
const batch_x = torch.randn([32, 100]);
const batch_y = torch.randint(0, 2, [32, 1]);
const logits = model.forward(batch_x); // Raw logits
const loss = bce.forward(logits, batch_y);

// Handling imbalanced data with pos_weight
const pos_weight = torch.tensor([5.0]); // Positive class 5x more important
const bce_weighted = new torch.nn.BCEWithLogitsLoss({ pos_weight });
// In datasets with 90% negatives, 10% positives
// Use pos_weight ≈ num_negatives / num_positives ≈ 9
const logits = torch.randn([32, 1]);
const targets = torch.randint(0, 2, [32, 1]);
const loss = bce_weighted.forward(logits, targets);

// Multi-label classification with logits
const num_labels = 10;
const bce = new torch.nn.BCEWithLogitsLoss();
const logits = torch.randn([32, num_labels]); // Multiple independent binary decisions
const targets = torch.randint(0, 2, [32, num_labels]); // Multi-hot encoded
const loss = bce.forward(logits, targets); // Each label treated independently

// Prediction: converting logits to probabilities
const logits = torch.tensor([[-2.0, 0.0, 2.0], [-1.0, 1.0, 3.0]]);
const probs = torch.sigmoid(logits); // Convert to probabilities
const predictions = torch.where(probs > 0.5, 1, 0); // Threshold at 0.5
// For loss computation (training)
const bce = new torch.nn.BCEWithLogitsLoss();
const targets = torch.tensor([[0, 0, 1], [0, 1, 1]]);
const loss = bce.forward(logits, targets);