torch.nn.NLLLoss
new NLLLoss(options?: { weight?: Tensor; ignore_index?: number; reduction?: Reduction })
weight(Tensor | null) - readonly
ignore_index(number) - readonly
reduction(Reduction)
Negative Log Likelihood (NLL) Loss: loss for pre-computed log-probabilities.
Computes negative log likelihood for classification, typically used with log_softmax. Similar to CrossEntropyLoss but takes log-probabilities as input instead of raw logits. Used when you want explicit control over the softmax/log_softmax computation or when working with already-computed log-probabilities.
When to use NLLLoss:
- You're using log_softmax explicitly (custom output processing)
- Working with pre-computed log-probabilities
- You want to separate softmax from loss computation
- Need explicit control over log-probability computation
- Theoretical/research code requiring explicit log probabilities
Trade-offs vs CrossEntropyLoss:
- CrossEntropyLoss recommended: Takes logits directly, simpler and more standard
- NLLLoss is lower-level: Requires you to apply log_softmax first
- Manual control: NLLLoss allows custom probability transformation
- Numerically equivalent: CE and NLL with log_softmax produce the same result
- Common pattern: CE is cleaner; use NLL for custom probability transforms
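The equivalence claim above can be checked without any tensor library. Cross-entropy from logits is logsumexp(logits) - logits[target], while NLL on log_softmax output is -(logits[target] - logsumexp(logits)); the two expressions are identical. A minimal plain-array sketch (the helper names here are illustrative, not part of the torch API):

```typescript
// Numerically stable log-sum-exp over a 1-D array.
function logSumExp(xs: number[]): number {
  const m = Math.max(...xs);
  return m + Math.log(xs.reduce((s, x) => s + Math.exp(x - m), 0));
}

// log_softmax: subtract log-sum-exp from each logit.
function logSoftmax(xs: number[]): number[] {
  const lse = logSumExp(xs);
  return xs.map((x) => x - lse);
}

// Cross-entropy computed directly from logits.
function crossEntropy(logits: number[], target: number): number {
  return logSumExp(logits) - logits[target];
}

// NLL on precomputed log-probabilities.
function nll(logProbs: number[], target: number): number {
  return -logProbs[target];
}

const logits = [2.0, -1.0, 0.5];
const target = 2;
const ce = crossEntropy(logits, target);
const viaNll = nll(logSoftmax(logits), target);
// ce and viaNll agree up to floating-point error
console.log(Math.abs(ce - viaNll) < 1e-12);
```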
Algorithm: Assumes input is log-probabilities (typically from log_softmax):
- loss_i = -w_{target_i} * log_prob[i, target_i], where w is the optional class weight (1 if no weight is given)
- Final loss = mean(loss_i) or sum(loss_i) depending on reduction; with class weights, 'mean' divides by the sum of the selected samples' weights
Supports per-class weights for imbalanced data and ignore_index for excluding targets such as padding tokens from the loss.
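The algorithm above can be sketched over plain arrays. This is a hypothetical reference implementation, not the torch API; it assumes the PyTorch convention that 'mean' with class weights divides by the sum of the selected weights, and that ignored samples contribute nothing to either sum:

```typescript
// Illustrative NLL reduction over a batch of log-probability rows.
function nllLoss(
  logProbs: number[][],          // [batch, classes], already log_softmax-ed
  targets: number[],             // [batch] class indices
  weight?: number[],             // optional per-class weight (defaults to 1)
  ignoreIndex?: number,          // targets equal to this are skipped
  reduction: "mean" | "sum" = "mean",
): number {
  let total = 0;
  let weightSum = 0;
  for (let i = 0; i < targets.length; i++) {
    const t = targets[i];
    if (t === ignoreIndex) continue;       // ignored targets contribute nothing
    const w = weight ? weight[t] : 1;
    total += -w * logProbs[i][t];          // loss_i = -w_t * log_prob[i, t]
    weightSum += w;
  }
  // 'mean' divides by the sum of the selected weights, not the batch size
  return reduction === "mean" ? total / weightSum : total;
}

// Two samples, two classes: probabilities (0.7, 0.3) and (0.2, 0.8)
const lp = [
  [Math.log(0.7), Math.log(0.3)],
  [Math.log(0.2), Math.log(0.8)],
];
const loss = nllLoss(lp, [0, 1]); // mean of -ln(0.7) and -ln(0.8)
```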
- Lower-level than CrossEntropyLoss: Requires explicit log_softmax
- Mathematically equivalent: NLL + log_softmax = CrossEntropy
- Manual probability control: Use when you need custom probability transform
- Weight support: Can weight classes for imbalanced data
- Ignore index: Can ignore certain target values (e.g., padding tokens)
- Standard pattern: Most code uses CrossEntropyLoss directly; NLL for advanced use
- Positional flexibility: Works with reshaped batches for sequence prediction
Examples
// Basic NLL loss with log_softmax
const nll_loss = new torch.nn.NLLLoss();
const logits = torch.randn([32, 10]); // Batch of 32, 10 classes
const targets = torch.randint(0, 10, [32]); // Class labels
// Convert logits to log-probabilities
const log_probs = torch.log_softmax(logits, 1);
// Compute loss
const loss = nll_loss.forward(log_probs, targets);
// Equivalent to CrossEntropyLoss applied directly to the logits

// Classification network with explicit log_softmax
class Classifier extends torch.nn.Module {
  fc1: torch.nn.Linear;
  fc2: torch.nn.Linear;

  constructor() {
    super();
    this.fc1 = new torch.nn.Linear(784, 256);
    this.fc2 = new torch.nn.Linear(256, 10);
  }

  forward(x: torch.Tensor): torch.Tensor {
    const h = torch.nn.functional.relu(this.fc1.forward(x));
    const logits = this.fc2.forward(h);
    // Model outputs logits, not log-probs
    return logits;
  }
}
const model = new Classifier();
const nll_loss = new torch.nn.NLLLoss();
const batch_x = torch.randn([32, 784]);
const batch_y = torch.randint(0, 10, [32]);
const logits = model.forward(batch_x);
const log_probs = torch.log_softmax(logits, 1); // Explicit log_softmax
const loss = nll_loss.forward(log_probs, batch_y);

// Handling imbalanced classes with class weights
// If class 0 is rare, give it higher weight
const class_weights = torch.tensor([5.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]);
const nll_weighted = new torch.nn.NLLLoss({ weight: class_weights });
const log_probs = torch.log_softmax(torch.randn([32, 10]), 1);
const targets = torch.randint(0, 10, [32]);
const loss = nll_weighted.forward(log_probs, targets);
// Rare class 0 is penalized more heavily

// NLP with padding ignore index
const batch_size = 32;
const seq_len = 100;
const vocab_size = 5000;
const log_probs = torch.log_softmax(torch.randn([batch_size, seq_len, vocab_size]), 2);
const targets = torch.randint(0, vocab_size, [batch_size, seq_len]);
// Set padding token (index 0) to be ignored
const nll_nlp = new torch.nn.NLLLoss({ ignore_index: 0 });
const loss = nll_nlp.forward(log_probs.view([-1, vocab_size]), targets.view([-1]));
// Padding tokens don't contribute to the loss

// Comparing NLLLoss vs CrossEntropyLoss
const logits = torch.randn([32, 10]);
const targets = torch.randint(0, 10, [32]);
// Using CrossEntropyLoss (direct from logits)
const ce_loss = new torch.nn.CrossEntropyLoss();
const ce = ce_loss.forward(logits, targets);
// Using NLLLoss (from log-probabilities)
const log_probs = torch.log_softmax(logits, 1);
const nll_loss = new torch.nn.NLLLoss();
const nll = nll_loss.forward(log_probs, targets);
// ce and nll should be approximately equal