torch.nn.functional.poisson_nll_loss
function poisson_nll_loss(input: Tensor, target: Tensor, options?: {
log_input?: boolean;
full?: boolean;
eps?: number;
reduction?: 'none' | 'mean' | 'sum';
}): Tensor

Poisson negative log likelihood loss for count data and event prediction.
Measures the negative log likelihood of the targets under a Poisson distribution whose rate (mean) is given by the predictions. Computes -log P(target | predicted_rate) to measure prediction quality. Essential for:
- Count data modeling (clicks, events, occurrences per unit time)
- Rare event prediction (earthquakes, network failures, uncommon failures)
- Event rate estimation (visitors per day, calls per hour)
- Object detection (bounding box regression for object counts)
- Traffic/demand forecasting (vehicle counts, passenger loads)
- Biological/medical data (cell counts, incident counts in epidemiology)
- Time-series prediction with count outcomes (daily sales, page views)
Poisson distribution intuition: the Poisson distribution models count data, giving the probability of observing k events when the expected rate is λ:
P(k | λ) = λ^k · e^(-λ) / k!
Loss = -log P(k | λ) = λ - k·log(λ) + log(k!), where the full form expands log(k!) via Stirling's approximation as k·log(k) - k + 0.5·log(2πk).
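The formula above can be checked numerically; a minimal plain-TypeScript sketch (no torch; `poissonNll` is a hypothetical helper, not part of the API):

```typescript
// Sketch: Poisson NLL for a single prediction, computed from the definition.
// lambda is the predicted rate (> 0), k the observed count (>= 0 integer).
function poissonNll(lambda: number, k: number): number {
  // log(k!) computed exactly for small k
  let logKFactorial = 0;
  for (let i = 2; i <= k; i++) logKFactorial += Math.log(i);
  // -log P(k | lambda) = lambda - k*log(lambda) + log(k!)
  return lambda - k * Math.log(lambda) + logKFactorial;
}

// Agrees with -log of the Poisson pmf directly: P(2 | 3) = 3^2 * e^(-3) / 2!
const direct = -Math.log(Math.pow(3, 2) * Math.exp(-3) / 2);
console.log(poissonNll(3, 2), direct); // same value
```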
Two parameter modes:
- log_input=true: input is log(λ); loss = exp(input) - target*input. Better numerical stability; preferred when λ can be very small.
- log_input=false: input is λ directly; loss = input - target*log(input + eps). Use when the network already produces a non-negative rate.
Full Stirling approximation:
- full=false: simplified loss = λ - k*log(λ) (faster, sufficient for most uses)
- full=true: adds the Stirling term k*log(k) - k + 0.5*log(2πk), approximating log(k!) via k! ≈ √(2πk)·(k/e)^k; exact negative log likelihood but slower
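To see what the extra term contributes, here is a sketch with plain numbers (`poissonNllLog` is a hypothetical name, not the library function):

```typescript
// Loss in log-space (x = log(λ)) for one sample, with and without the
// Stirling term. full=true adds k*log(k) - k + 0.5*log(2πk) for k > 1.
function poissonNllLog(x: number, k: number, full: boolean): number {
  let loss = Math.exp(x) - k * x;
  if (full && k > 1) {
    loss += k * Math.log(k) - k + 0.5 * Math.log(2 * Math.PI * k);
  }
  return loss;
}

const x = Math.log(4), k = 5;
const diff = poissonNllLog(x, k, true) - poissonNllLog(x, k, false);
// diff ≈ 4.77: Stirling's estimate of log(5!) = log(120) ≈ 4.79
console.log(diff);
```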
Key parameters and trade-offs:
- log_input: stability vs input format (Poisson λ > 0, log(λ) ∈ ℝ)
- full: accuracy vs speed (usually False fine; True only if need exact NLL)
- eps: numerical safety (prevents log(0), typical 1e-8)
- Poisson assumption: Assumes target follows Poisson distribution
- Non-negative output: λ (rate) must be ≥ 0; use exp() or softplus if needed
- Log-space stability: log_input=true more stable than direct rate input
- Count data: Naturally suited for non-negative integer counts
- Mean-variance relationship: Poisson has variance = mean (unique property)
- Overdispersion: If variance > mean, use negative binomial instead
- Zero-inflation: If excess zeros, use zero-inflated Poisson
- Gradient stability: log_input=true prevents vanishing gradients for small λ
- Target non-negative: Must have target ≥ 0; negative values undefined
- log_input mismatch: Ensure log_input matches output layer design
- Very small rates: Small λ → high variance; use larger batches or regularization
- Very large counts: Large k → numerical issues in Stirling term; avoid if possible
- Full approximation: Stirling term is only added where target > 1; watch edge cases
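The log_input mismatch point above can be made concrete: the two modes compute the same quantity when the inputs are related by exp(). A minimal sketch with plain numbers (helper names are hypothetical):

```typescript
// log_input=true: loss = exp(x) - k*x, where x = log(λ)
function lossFromLogRate(x: number, k: number): number {
  return Math.exp(x) - k * x;
}
// log_input=false: loss = λ - k*log(λ + eps)
function lossFromRate(lambda: number, k: number, eps = 1e-8): number {
  return lambda - k * Math.log(lambda + eps);
}

const x = 0.7, k = 3;
const a = lossFromLogRate(x, k);          // network outputs log-rate
const b = lossFromRate(Math.exp(x), k);   // network outputs rate via exp()
console.log(a, b); // near-identical; they differ only by eps inside the log
```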
Parameters
input Tensor - Predicted Poisson rate parameter (mean/λ) for each sample. If log_input=true: log(λ) values ∈ ℝ (unrestricted), any shape. If log_input=false: λ values ≥ 0 (rates), any shape. Example: raw final-layer outputs with log_input=true, or exp()/softplus() outputs with log_input=false.
target Tensor - Target count values (non-negative integers or floats). Shape must match input; values ≥ 0 (observed counts). Example: actual number of events, objects, or occurrences.
options { log_input?: boolean; full?: boolean; eps?: number; reduction?: 'none' | 'mean' | 'sum'; } optional - Optional configuration:
- log_input: Whether input is log(λ) (default: true). true: input is the log-rate, more numerically stable. false: input is the rate directly (apply exp() or softplus() at the output layer).
- full: Include the full Stirling approximation (default: false). true: exact NLL = λ - k*log(λ) + k*log(k) - k + 0.5*log(2πk). false: simplified λ - k*log(λ) (usually sufficient).
- eps: Small constant for numerical stability (default: 1e-8). Prevents log(0) when log_input=false and λ is near zero.
- reduction: How to aggregate losses (default: 'mean'). 'none': per-sample losses [batch, ...]; 'mean': average loss; 'sum': summed loss.
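The reduction choices can be sketched with plain arrays (a hypothetical helper, not the library function):

```typescript
// Per-sample Poisson NLL (log-space inputs), then the three reductions.
function poissonNllBatch(
  xs: number[], ks: number[],
  reduction: 'none' | 'mean' | 'sum' = 'mean',
): number | number[] {
  const per = xs.map((x, i) => Math.exp(x) - ks[i] * x);
  if (reduction === 'none') return per;       // one loss per sample
  const sum = per.reduce((a, b) => a + b, 0);
  return reduction === 'sum' ? sum : sum / per.length;
}

const xs = [0.0, 0.5, 1.0];
const ks = [1, 2, 3];
console.log(poissonNllBatch(xs, ks, 'none')); // 3 per-sample values
console.log(poissonNllBatch(xs, ks, 'sum'));  // their total
console.log(poissonNllBatch(xs, ks, 'mean')); // total / 3
```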
Returns
Tensor – Loss tensor (same shape as input/target if reduction='none', scalar otherwise)

Examples
// Count prediction: predict number of events
const batch_size = 32;
// Network predicts log-rate (log-Poisson parameter)
const log_rates = torch.randn([batch_size]); // Could be -2, -1, 0, 1, etc.
// True event counts (observed data)
const event_counts = torch.tensor([0, 1, 0, 2, 1, 3, 0, 1, 2, 0, ...]); // [batch]
const loss = torch.nn.functional.poisson_nll_loss(
log_rates, event_counts,
{ log_input: true } // Input is log(λ)
);
// Loss encourages log_rates to match the observed event counts

// Traffic prediction: forecast vehicle counts
const batch_size = 48; // 48 time intervals
// Model outputs log-Poisson rates (from RNN or transformer)
const predicted_log_rates = model(historical_data); // [48]
// Actual observed vehicle counts
const actual_counts = torch.tensor([12, 15, 8, 20, 18, 25, ...]); // [48]
const traffic_loss = torch.nn.functional.poisson_nll_loss(
predicted_log_rates,
actual_counts.to('float32'),
{ log_input: true, reduction: 'mean' }
);

// Object detection: predict count of objects in regions
const batch_size = 64;
const num_regions = 100; // Divide image into grid
// Network predicts log-rates for each region
const predicted_counts = model(image); // [64, 100] log-rates (log-Poisson parameters) per region
// Ground truth object counts per region
const true_counts = torch.randint(0, 5, [64, 100]).to('float32');
const count_loss = torch.nn.functional.poisson_nll_loss(
predicted_counts,
true_counts,
{ log_input: true, full: false } // Simplified loss, usually sufficient
);

// Comparison: log_input vs direct input
const target = torch.tensor([0, 1, 2, 5, 10]); // Counts
// Approach 1: Network outputs log-rate directly
const log_rate = model(features); // Returns values in ℝ
const loss1 = torch.nn.functional.poisson_nll_loss(
log_rate, target.to('float32'),
{ log_input: true } // log_input=true for log-space output
);
// Approach 2: Network outputs rate via exp()
const rate = model_with_exp(features); // Returns exp(x), values > 0
const loss2 = torch.nn.functional.poisson_nll_loss(
rate, target.to('float32'),
{ log_input: false } // log_input=false for rate output
);
// Both approaches are equivalent; choose based on output layer design

// Rare event prediction: earthquake aftershock counts
const time_periods = 1000; // 1000 days
// Seismic model predicts log-rate of aftershocks
const predicted_log_rates = seismic_model(features); // [1000]
// Observed aftershock counts per day (mostly 0s, few large values)
const observed_counts = torch.tensor([0,0,0,1,0,0,2,0,0,0,5,0,0,1,...]);
const seismic_loss = torch.nn.functional.poisson_nll_loss(
predicted_log_rates,
observed_counts.to('float32'),
{ log_input: true, full: false, eps: 1e-8 }
);
// Poisson is well suited to sparse, rare count data

// With full Stirling approximation (exact negative log likelihood)
const predictions = torch.randn([16]);
const counts = torch.tensor([2, 3, 1, 5, 4, 2, 3, 1, ...]);
// Exact NLL with Stirling term (slower but accurate)
const exact_loss = torch.nn.functional.poisson_nll_loss(
predictions,
counts.to('float32'),
{ log_input: true, full: true }
);
// Includes: λ - k*log(λ) + k*log(k) - k + 0.5*log(2πk)
// Simplified loss (usually sufficient, faster)
const simple_loss = torch.nn.functional.poisson_nll_loss(
predictions,
counts.to('float32'),
{ log_input: true, full: false }
);
// Only: λ - k*log(λ), sufficient for training

See Also
- PyTorch torch.nn.functional.poisson_nll_loss
- gaussian_nll_loss - For continuous targets (normal distribution)
- torch.nn.PoissonNLLLoss - Module wrapper
- negative_binomial - Alternative for overdispersed count data (if available)