torch.nn.PoissonNLLLoss
class PoissonNLLLoss extends Module

new PoissonNLLLoss(options?: {
  log_input?: boolean;
  full?: boolean;
  eps?: number;
  reduction?: Reduction;
})

Properties (all readonly):
- log_input (boolean)
- full (boolean)
- eps (number)
- reduction (Reduction)
Poisson Negative Log Likelihood Loss: for count data regression.
Computes the negative log-likelihood loss assuming the target follows a Poisson distribution. Used for predicting count data (integers ≥ 0), where the count is modeled as Poisson with rate parameter λ estimated by the model. Useful for:
- Count prediction (number of events, clicks, occurrences)
- Document length prediction
- Arrival time prediction
- Any task where output is a non-negative integer count
When to use PoissonNLLLoss:
- Predicting count data (integers ≥ 0)
- Target has Poisson distribution (events occur independently at constant rate)
- Count regression problems
- When variance scales with mean (Poisson property)
- Rarely used; typically for specialized count prediction tasks
Trade-offs:
- vs MSELoss: Poisson for count data; MSE for continuous
- vs Negative Binomial: Poisson when variance = mean; Negative Binomial when variance > mean
- Assumption: Targets must be non-negative integers
- Distribution: Assumes Poisson distribution of targets
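The "variance = mean" property that distinguishes Poisson from Negative Binomial can be checked numerically. A minimal sketch in plain TypeScript, with no torch dependency; the seeded PRNG and Knuth-style sampler here are illustrative helpers, not part of the library API:

```typescript
// Deterministic PRNG (mulberry32) so the check is reproducible.
function mulberry32(seed: number): () => number {
  return () => {
    seed = (seed + 0x6d2b79f5) | 0;
    let t = Math.imul(seed ^ (seed >>> 15), 1 | seed);
    t = (t + Math.imul(t ^ (t >>> 7), 61 | t)) ^ t;
    return ((t ^ (t >>> 14)) >>> 0) / 4294967296;
  };
}

// Knuth's Poisson sampler: count uniform draws until their product drops below e^(-lambda).
function samplePoisson(lambda: number, rand: () => number): number {
  const L = Math.exp(-lambda);
  let k = 0;
  let p = 1;
  do {
    k++;
    p *= rand();
  } while (p > L);
  return k - 1;
}

const rand = mulberry32(42);
const lambda = 4;
const n = 50000;
const samples = Array.from({ length: n }, () => samplePoisson(lambda, rand));
const mean = samples.reduce((a, b) => a + b, 0) / n;
const variance = samples.reduce((a, x) => a + (x - mean) ** 2, 0) / n;
// For Poisson-distributed counts, sample variance ≈ sample mean ≈ lambda.
console.log(mean.toFixed(2), variance.toFixed(2));
```

If your count data shows variance well above the mean (overdispersion), a Negative Binomial likelihood is usually a better fit than this loss.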
Algorithm: The per-element Poisson NLL loss is:
- loss = λ - target * log(λ), where λ is the predicted Poisson parameter (mean)
- If log_input = true, the input is interpreted as log(λ): loss = exp(input) - target * input
- If log_input = false, the input is λ itself, and eps is added inside the log to avoid log(0): loss = input - target * log(input + eps)
- If full = true, the Stirling approximation of log(target!) is added for targets > 1: target * log(target) - target + 0.5 * log(2π * target)
- Count data: Targets must be non-negative integers
- Poisson assumption: Assumes variance equals mean (var = λ)
- Log-input: Typically true for numerical stability
- Small targets: Works well for small/moderate counts
- Specialized: Less common than MSE; for specific count prediction tasks
- Variance property: Poisson naturally handles variance scaling with mean
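The per-element formula above can be checked with plain arithmetic. A minimal sketch in plain TypeScript, with no torch dependency; poissonNLL is a hypothetical helper written for illustration, not part of the API above:

```typescript
// Per-element Poisson NLL, mirroring the formula above.
// logInput=true:  loss = exp(x) - t * x          (x is log lambda)
// logInput=false: loss = x - t * log(x + eps)    (x is lambda; eps guards log(0))
function poissonNLL(
  input: number,
  target: number,
  logInput = true,
  full = false,
  eps = 1e-8
): number {
  let loss = logInput
    ? Math.exp(input) - target * input
    : input - target * Math.log(input + eps);
  if (full && target > 1) {
    // Stirling approximation of log(target!)
    loss +=
      target * Math.log(target) - target + 0.5 * Math.log(2 * Math.PI * target);
  }
  return loss;
}

// log lambda = 0 means lambda = 1; with target = 2: loss = exp(0) - 2 * 0 = 1
console.log(poissonNLL(0, 2)); // 1
```

With reduction left at its default, the library averages this per-element value over the batch; the scalar helper above only shows the element-wise term.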
Examples
// Count prediction: predicting number of events
const poisson_loss = new torch.nn.PoissonNLLLoss({ log_input: true });
// Model predicts log-Poisson parameters
const log_lambda = torch.randn([32, 1]); // Log-parameters
// True counts (non-negative integers)
const counts = torch.tensor([[0], [1], [3], [2], [5], ...]);
const loss = poisson_loss.forward(log_lambda, counts);
// Minimizes Poisson NLL for count prediction

// Document length prediction
class DocumentLengthPredictor extends torch.nn.Module {
fc1: torch.nn.Linear;
fc2: torch.nn.Linear;
constructor() {
super();
this.fc1 = new torch.nn.Linear(768, 256); // From text encoder
this.fc2 = new torch.nn.Linear(256, 1); // Predict length
}
forward(x: torch.Tensor): torch.Tensor {
let h = torch.nn.functional.relu(this.fc1.forward(x));
// Return log of expected length (Poisson parameter)
return this.fc2.forward(h);
}
}
const model = new DocumentLengthPredictor();
const loss_fn = new torch.nn.PoissonNLLLoss({ log_input: true });
const text_encodings = torch.randn([32, 768]);
const doc_lengths = torch.tensor([100, 150, 75, 200, ...]); // Actual lengths
const log_lambda = model.forward(text_encodings);
const loss = loss_fn.forward(log_lambda, doc_lengths.unsqueeze(1).float());

// Click prediction for ads/search results
const poisson = new torch.nn.PoissonNLLLoss({ log_input: true, full: false });
// Model predicts expected number of clicks
const log_click_rate = torch.randn([1000, 1]);
// Observed click counts
const observed_clicks = torch.tensor([
[0], [1], [2], [0], [3], [1], [0], [0], [1], // ...
]);
const loss = poisson.forward(log_click_rate, observed_clicks.float());