torch.nn.functional.cross_entropy
function cross_entropy(input: Tensor, target: Tensor): Tensor
function cross_entropy(input: Tensor, target: Tensor, weight: Tensor | null, size_average: boolean | null, ignore_index: number, reduce: boolean | null, reduction: 'none' | 'mean' | 'sum', label_smoothing: number, options: CrossEntropyFunctionalOptions): Tensor
Cross-Entropy Loss: the standard loss function for multi-class classification from raw logits.
Computes the cross-entropy loss between predicted logits and target class indices. This is the go-to loss function for classification tasks in deep learning. It combines softmax normalization and negative log-likelihood into a single, numerically stable operation. Implemented as nll_loss(log_softmax(input, { dim: -1 }), target). Essential for:
- Multi-class classification (image classification, text classification, object detection)
- Neural network training with class targets (standard supervised learning)
- Tasks with mutually exclusive categories
- Pre-training and fine-tuning classification heads
- Any situation where you want to predict one of K mutually exclusive classes
Why Cross-Entropy: Cross-entropy measures the divergence between the predicted probability distribution (from softmax) and the true one-hot distribution. Minimizing cross-entropy is equivalent to maximizing the likelihood of the correct class. Empirically, it enables faster learning than MSE loss for classification.
Softmax Intuition: Softmax converts raw logits to probability distribution (all positive, sum to 1). Cross-entropy penalizes assigning low probability to the correct class. As model confidence in correct class increases (probability → 1), loss → 0. As confidence decreases (probability → 0), loss → ∞.
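As a quick numeric check of this intuition (plain TypeScript, no tensor library needed, since the loss for a single sample is just -log of the probability assigned to the correct class):

```typescript
// -log(p) → 0 as the probability p of the correct class approaches 1,
// and grows without bound as p approaches 0.
for (const p of [0.99, 0.9, 0.5, 0.1, 0.01]) {
  console.log(p, (-Math.log(p)).toFixed(3));
}
// 0.99 → 0.010, 0.9 → 0.105, 0.5 → 0.693, 0.1 → 2.303, 0.01 → 4.605
```

Note that -log(0.5) ≈ 0.693 = log(2): a binary classifier that is maximally unsure sits at exactly this loss.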
\begin{aligned} \text{Softmax: } p_i &= \frac{\exp(\text{logit}_i)}{\sum_j \exp(\text{logit}_j)} \\ \text{Cross-Entropy: } L_n &= -\log(p_{\text{target}_n}) \\ \text{Final loss} &= \begin{cases} \frac{1}{N} \sum_n L_n & \text{if reduction = 'mean'} \\ \sum_n L_n & \text{if reduction = 'sum'} \\ (L_1, \dots, L_N) & \text{if reduction = 'none'} \end{cases} \\ \text{Gradient w.r.t. logit: } \frac{\partial L}{\partial \text{logit}_i} &= \begin{cases} p_i - 1 & \text{if } i = \text{target} \\ p_i & \text{otherwise} \end{cases} \end{aligned}
- Expects raw logits, not probabilities: Input should be unnormalized scores from the model output (e.g., from a linear layer). Do NOT apply softmax first; cross_entropy does that internally.
- Numerically stable: Uses log-softmax internally to avoid numerical overflow/underflow that would occur with naive exp(logit) / sum(exp(logits)) calculation.
- Class indices, not one-hot: Targets should be integer class indices, not one-hot encoded vectors. Unlike some frameworks, this expects integer class labels directly.
- Standard choice for classification: When unsure about loss function, use cross-entropy for multi-class classification. It's the standard and usually works best.
- Batch independence: With reduction='mean', loss is normalized by batch size, making it independent of batch size. Critical for consistent learning rates across different batch sizes.
- Gradient saturation avoided: Fusing log-softmax and negative log-likelihood into one operation (rather than computing softmax, then log, then NLL separately) avoids gradient saturation and numerical issues at the extremes.
- Identical to nll_loss(log_softmax(x, -1), y): But more stable and concise. Use cross_entropy directly unless you need pre-computed log-probabilities from another source.
- Invalid target values: Target indices must be in [0, num_classes-1]. Out-of-range indices cause undefined behavior. Always validate that targets match logits dimensionality.
- Shape mismatch: Input must be 2D [batch_size, num_classes] and targets 1D [batch_size]. Mismatched shapes throw errors. Transposed inputs are a common mistake.
- Don't pre-apply softmax: If you apply softmax before cross_entropy, it will softmax twice, giving wrong results. Pass raw logits directly.
- Batch size mismatch: If batch dimension of input and target don't match, error is thrown. Verify batching is consistent across model output and labels.
- Overflow risk with extreme logits: Very large logits (e.g., > 100) can cause overflow in intermediate calculations (though log-softmax mitigates this). Keep logits in a reasonable range.
Parameters
input: Tensor – Raw logits (unnormalized scores) from model output. Shape [batch_size, num_classes]. Can be any real-valued tensor (e.g., output of a linear layer without activation).
target: Tensor – Target class indices (0 to num_classes-1). Shape [batch_size]. Integer-like values.
Returns
Tensor – Loss tensor. Scalar (shape []) if reduction='mean' or 'sum'; shape [batch_size] if reduction='none'.
Examples
// Standard image classification: 32-sample batch, 10 classes (CIFAR-10)
const logits = torch.randn(32, 10); // Model output: raw scores
const targets = torch.tensor([...]); // Class indices: 0-9
const loss = torch.nn.functional.cross_entropy(logits, targets); // Scalar loss
// Training loop: update weights to minimize loss
for (let epoch = 0; epoch < 100; epoch++) {
const logits = model(images); // Forward pass
const loss = torch.nn.functional.cross_entropy(logits, class_labels);
optimizer.zero_grad();
loss.backward(); // Compute gradients
optimizer.step(); // Update weights
}
// Multi-class with 1000 categories (ImageNet classification)
const batch_logits = torch.randn(256, 1000); // [batch=256, classes=1000]
const batch_targets = torch.randint(0, 1000, [256]); // Random integer class indices for demo (targets must be integer-like, not float)
const loss = torch.nn.functional.cross_entropy(batch_logits, batch_targets);
// An untrained model starts near log(1000) ≈ 6.9; loss decreases toward 0 as training improves
// Analyzing per-sample losses: find hard examples
const per_sample_losses = torch.nn.functional.cross_entropy(
  logits, targets, { reduction: 'none' } // Shape [batch_size]
);
// Find hardest samples (highest loss)
const { values: hardest_losses, indices: hard_indices } = per_sample_losses.topk(10);
// Use these samples for curriculum learning or hard negative mining
// Text classification: predicting one of C topics
const num_topics = 20;
const text_embeddings = encoder(text); // [batch, embedding_dim]
const classification_logits = classifier(text_embeddings); // [batch, num_topics]
const topic_labels = torch.tensor([5, 3, 12, 7, ...]); // True topic for each text
const loss = torch.nn.functional.cross_entropy(
classification_logits, topic_labels
);
See Also
- PyTorch torch.nn.functional.cross_entropy
- nll_loss - Lower-level function for pre-computed log-probabilities
- log_softmax - Compute log-probabilities; used internally by cross_entropy
- softmax - Normalize to probabilities (not used in loss, but useful for inference)
- torch.nn.CrossEntropyLoss - Stateful class-based version with optional weight per class
- binary_cross_entropy_with_logits - Loss for binary classification from logits
- mse_loss - Alternative loss for regression targets (not recommended for classification)