torch.nn.CrossEntropyLoss
class CrossEntropyLoss extends Module

new CrossEntropyLoss(options?: {
weight?: Tensor;
ignore_index?: number;
reduction?: Reduction;
label_smoothing?: number;
})
Properties:
- readonly weight: Tensor | null
- readonly ignore_index: number
- readonly reduction: Reduction
- readonly label_smoothing: number
Cross Entropy Loss: standard loss for multi-class classification.
Combines softmax and negative log likelihood, measuring divergence between predicted probability distribution and target class distribution. Most widely used loss for:
- Image classification (ImageNet, CIFAR, MNIST: predicting class from image)
- Text classification (sentiment analysis, topic prediction, intent classification)
- Named entity recognition (NER: classifying token types)
- Machine translation (predicting next word from vocabulary)
- Speech recognition (classifying phonemes)
- Multi-class categorization (assigning single label from many classes)
Accepts raw logits (un-normalized scores) and automatically applies softmax. Numerically stable implementation prevents overflow/underflow with large logits. Perfect for one-hot encoded targets (single correct class per sample).
- Input format: Takes raw logits, NOT probabilities (applies softmax internally)
- Target format: Class indices, not one-hot encoding. For class c, use scalar c
- Softmax applied: Numerically stable implementation, handles large logits
- Class balancing: Use weights parameter for imbalanced datasets
- Label smoothing: Helps with overfitting and calibration (typical: 0.1)
- Ignore index: Useful for padding tokens (NLP) or invalid samples
- Gradient properties: Non-zero gradients, smooth learning landscape
- Common mistake: Using softmax output instead of raw logits (double softmax)
- Computational: O(batch_size × num_classes) - efficient
- Standard in NLP: Default choice for language modeling, machine translation
Examples
// Image classification: CIFAR-10 (10 classes)
const ce_loss = new torch.nn.CrossEntropyLoss();

// Model outputs raw logits with shape [batch, num_classes]
const logits = torch.randn([32, 10]); // 32 images, 10 classes

// Targets are class indices with shape [batch] — one integer per sample.
// The target count must equal the logits batch size (the original example
// listed only 10 targets against a batch of 32).
const targets = torch.randint(0, 10, [32], { dtype: 'int32' });

// Compute loss
const loss = ce_loss.forward(logits, targets); // scalar
// Internally: applies softmax to logits, then -log(p[target_class])

// Classification network
// Small CNN that returns raw class logits — no softmax at the end, since
// CrossEntropyLoss applies it internally. Expects [batch, 3, 32, 32] input.
class ImageClassifier extends torch.nn.Module {
  conv1: torch.nn.Conv2d;
  conv2: torch.nn.Conv2d;
  fc1: torch.nn.Linear;
  fc2: torch.nn.Linear;

  constructor(num_classes: number) {
    super();
    // Two conv stages; forward() follows each with ReLU + 2x2 max-pool.
    this.conv1 = new torch.nn.Conv2d(3, 32, 3, { padding: 1 });
    this.conv2 = new torch.nn.Conv2d(32, 64, 3, { padding: 1 });
    // Two 2x2 pools shrink a 32x32 input to 8x8 with 64 channels.
    this.fc1 = new torch.nn.Linear(64 * 8 * 8, 128);
    this.fc2 = new torch.nn.Linear(128, num_classes); // logits head — no softmax
  }

  forward(x: torch.Tensor): torch.Tensor {
    // conv -> relu -> pool, twice
    let h = torch.nn.functional.max_pool2d(
      torch.nn.functional.relu(this.conv1.forward(x)), 2);
    h = torch.nn.functional.max_pool2d(
      torch.nn.functional.relu(this.conv2.forward(h)), 2);
    // Flatten all feature maps into one vector per sample
    h = h.view(h.shape[0], -1);
    h = torch.nn.functional.relu(this.fc1.forward(h));
    // Raw logits out — do NOT softmax here
    return this.fc2.forward(h);
  }
}
// Train-step sketch: model produces logits, CrossEntropyLoss scores them.
const model = new ImageClassifier(10);
const ce = new torch.nn.CrossEntropyLoss();

// One mini-batch of CIFAR-10-sized images plus integer class labels
const batch_x = torch.randn([32, 3, 32, 32]);
const batch_y = torch.randint(0, 10, [32], { dtype: 'int32' });

// Forward pass: logits [32, 10] -> scalar loss
const logits = model.forward(batch_x);
const loss = ce.forward(logits, batch_y);

// Handling class imbalance with weights
// Per-class loss weights: classes with weight > 1 contribute more to the
// loss, so mistakes on underrepresented classes are penalized harder.
const class_weights = torch.tensor([1.0, 2.0, 0.5, 1.5, 1.0], { dtype: 'float32' });
const ce_weighted = new torch.nn.CrossEntropyLoss({ weight: class_weights });
// Useful for datasets where some classes appear rarely

// Label smoothing for regularization
const ce_smooth = new torch.nn.CrossEntropyLoss({ label_smoothing: 0.1 });
// Instead of one-hot [0, 0, 1, 0, 0], the effective target becomes:
// [0.025, 0.025, 0.9, 0.025, 0.025]
// Prevents the model from being overconfident, improves generalization

// Batch of 5 samples over 5 classes. The batch dimension must match the
// number of targets (the original used [32, 5] against only 5 targets).
const logits = torch.randn([5, 5]);
const targets = torch.tensor([2, 0, 4, 1, 3], { dtype: 'int32' });
const loss = ce_smooth.forward(logits, targets);

// Ignoring specific classes (e.g., padding tokens in NLP)
const PAD_ID = -1; // Padding class index
const ce_ignore = new torch.nn.CrossEntropyLoss({ ignore_index: PAD_ID });

// 6 tokens scored over a vocabulary of 1000. The batch dimension must match
// the number of targets (the original used [32, 1000] against 6 targets).
const logits = torch.randn([6, 1000]);
const targets = torch.tensor([5, 23, -1, 102, 54, -1], { dtype: 'int32' });
// Samples with target -1 (padding) don't contribute to loss
const loss = ce_ignore.forward(logits, targets);

// Evaluation: converting logits to probabilities
// Inference vs. training: softmax is only needed for prediction.
const ce = new torch.nn.CrossEntropyLoss();
const logits = torch.tensor([[2.0, 1.0, 0.1]], { dtype: 'float32' });

// Prediction path: normalize logits into probabilities, take the arg-max
const probs = torch.nn.functional.softmax(logits, -1); // [0.66, 0.24, 0.10]
const pred_class = torch.argmax(probs, -1); // 0

// Training path: feed the RAW logits (never the softmax output) to the loss
const target = torch.tensor([0], { dtype: 'int32' });
const loss = ce.forward(logits, target);