torch.poisson
function poisson<S extends Shape, D extends DType = DType, Dev extends DeviceType = DeviceType>(input: Tensor<S, D, Dev>): Tensor<S, D, Dev>
Draws samples from a Poisson distribution with specified rate parameters.
For each rate λ in the input tensor, samples the number of events occurring in a fixed interval. Returns non-negative integers. The Poisson distribution models count data and rare events. Essential for:
- Count data modeling: Events, arrivals, occurrences
- Stochastic simulation: Generating count-based random variables
- Data augmentation: Adding Poisson noise to counts
- Probabilistic models: Modeling count-dependent phenomena
- Queue simulation: Modeling arrivals in queueing systems
- Spike trains: Modeling neural spike counts and timing
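Implementation: On CPU, uses Knuth's algorithm, which multiplies uniform random numbers until their product falls to or below e^(-λ) and returns the number of factors minus one. On GPU, a dedicated kernel performs efficient batch sampling.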
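To make the CPU strategy concrete, here is a minimal TypeScript sketch of Knuth's multiplication method for a single rate. It is not the library's source: `poissonKnuth` and the seeded generator `mulberry32` are illustrative helpers, and the library's global RNG (controlled by manual_seed) is replaced by an explicit `rand` callback so the draws are reproducible.

```typescript
// Illustrative seeded PRNG (mulberry32); stands in for the library's global RNG.
function mulberry32(seed: number): () => number {
  let a = seed >>> 0;
  return () => {
    a = (a + 0x6d2b79f5) >>> 0;
    let t = a;
    t = Math.imul(t ^ (t >>> 15), t | 1);
    t ^= t + Math.imul(t ^ (t >>> 7), t | 61);
    return ((t ^ (t >>> 14)) >>> 0) / 4294967296; // uniform in [0, 1)
  };
}

// Knuth's method (hypothetical helper): one Poisson(lambda) draw.
function poissonKnuth(lambda: number, rand: () => number): number {
  if (!(lambda >= 0)) throw new RangeError(`rate must be >= 0, got ${lambda}`);
  const limit = Math.exp(-lambda);
  let k = 0;
  let p = 1;
  // Multiply uniforms until the running product falls to or below e^(-lambda).
  do {
    k++;
    p *= rand();
  } while (p > limit);
  return k - 1; // number of factors minus one
}

const rng = mulberry32(2024);
const draws = [1, 4, 10].map((rate) => poissonKnuth(rate, rng));
console.log(draws); // three non-negative whole numbers; exact values depend on the seed
```

The loop runs about λ + 1 times on average, so the cost grows linearly with λ. Once e^(-λ) underflows to 0 in double precision (λ above roughly 745), the stopping threshold can no longer be represented at all, which is why large rates call for a different strategy.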
- Non-negative output: Results are always integers ≥ 0
- Rate interpretation: Higher λ → higher variance and mean
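- Integer output: Values are exact whole-number counts, not rounded approximations; they are stored in the input's dtype D, matching the return type Tensor<S, D, Dev>
- Knuth algorithm: Efficient for small λ (below roughly 30)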
- Shape preservation: Output has same shape as input
- Variance equals mean: For a Poisson distribution, Var[X] = E[X] = λ
- GPU efficient: GPU kernel handles batch operations efficiently
- Non-negative rates: Input must be ≥ 0; negative rates are invalid
- No gradient flow: Sampling operation is non-differentiable
- Knuth efficiency: Method is O(λ) per sample; slow for very large λ
- Random seed: Results depend on global RNG seed; use manual_seed for reproducibility
- Large λ approximation: For very large λ, sampling may switch to a normal approximation; whether and at what threshold this happens depends on the implementation
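The large-λ note leaves the exact strategy open. As one possible approach, not necessarily what this library does, a Poisson(λ) draw can be approximated by a normal variate with mean λ and variance λ, rounded to a whole count and clamped at 0. The sketch below uses hypothetical helpers (`poissonNormalApprox`, `standardNormal`, `meanAndVariance`, `mulberry32`), draws 20,000 samples at λ = 100, and checks the Var = E = λ property empirically.

```typescript
// Illustrative seeded PRNG (mulberry32), used only to make the sketch reproducible.
function mulberry32(seed: number): () => number {
  let a = seed >>> 0;
  return () => {
    a = (a + 0x6d2b79f5) >>> 0;
    let t = a;
    t = Math.imul(t ^ (t >>> 15), t | 1);
    t ^= t + Math.imul(t ^ (t >>> 7), t | 61);
    return ((t ^ (t >>> 14)) >>> 0) / 4294967296;
  };
}

// Standard normal variate via the Box-Muller transform.
function standardNormal(rand: () => number): number {
  const u1 = 1 - rand(); // in (0, 1], so Math.log(u1) is finite
  const u2 = rand();
  return Math.sqrt(-2 * Math.log(u1)) * Math.cos(2 * Math.PI * u2);
}

// Normal approximation: Poisson(lambda) ≈ N(lambda, lambda), rounded and clamped at 0.
function poissonNormalApprox(lambda: number, rand: () => number): number {
  if (!(lambda >= 0)) throw new RangeError(`rate must be >= 0, got ${lambda}`);
  const draw = Math.round(lambda + Math.sqrt(lambda) * standardNormal(rand));
  return Math.max(0, draw);
}

// Sample mean and population variance of a list of counts.
function meanAndVariance(xs: number[]): { mean: number; variance: number } {
  const mean = xs.reduce((s, x) => s + x, 0) / xs.length;
  const variance = xs.reduce((s, x) => s + (x - mean) ** 2, 0) / xs.length;
  return { mean, variance };
}

const rng = mulberry32(123);
const counts = Array.from({ length: 20000 }, () => poissonNormalApprox(100, rng));
console.log(meanAndVariance(counts)); // both values should land near λ = 100
```

Rounding adds roughly 1/12 to the variance, and clamping at 0 distorts the distribution when λ is small, so this shortcut is only reasonable for large rates; for small λ an exact method such as Knuth's is the better choice.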
Parameters
input: Tensor<S, D, Dev> - Tensor of rate parameters (λ values, must be ≥ 0) with shape (...)
Returns
Tensor<S, D, Dev> - Tensor with the same shape as input, containing non-negative integer counts
Examples
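The Parameters and Returns entries describe an element-wise contract: one independent draw per rate, an output shape equal to the input shape, and whole-number counts stored in the input's element type. A minimal sketch of that contract in plain TypeScript follows. `RateTensor` is a hypothetical stand-in for the library's `Tensor` (a flat row-major `Float64Array` plus a shape), `Math.random` stands in for the library's RNG, and throwing a `RangeError` on negative rates is an assumption, since the documentation only says such rates are invalid.

```typescript
// Hypothetical stand-in for a tensor: row-major data plus a shape (not the library's Tensor type).
interface RateTensor {
  shape: number[];
  data: Float64Array;
}

// Knuth's multiplication method for one small rate.
function samplePoisson(lambda: number): number {
  const limit = Math.exp(-lambda);
  let k = 0;
  let p = 1;
  do {
    k++;
    p *= Math.random();
  } while (p > limit);
  return k - 1;
}

// Element-wise sampling: validates every rate, keeps the shape, and stores
// whole-number counts in a floating-point buffer (same element type as the input).
function poissonElementwise(input: RateTensor): RateTensor {
  const expected = input.shape.reduce((a, b) => a * b, 1);
  if (expected !== input.data.length) throw new RangeError("shape does not match data length");
  const out = new Float64Array(input.data.length);
  input.data.forEach((lambda, i) => {
    if (!(lambda >= 0)) throw new RangeError(`rate at flat index ${i} is ${lambda}; rates must be >= 0`);
    out[i] = samplePoisson(lambda);
  });
  return { shape: [...input.shape], data: out };
}

// 3x3 matrix of rates, stored row-major.
const rates: RateTensor = { shape: [3, 3], data: Float64Array.from([1, 2, 3, 0.5, 1.5, 2.5, 2, 3, 4]) };
const sampled = poissonElementwise(rates);
console.log(sampled.shape, Array.from(sampled.data));
```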
// Simple Poisson sampling
const rates = torch.tensor([1.0, 4.0, 10.0]);
const samples = torch.poisson(rates); // e.g., [0, 5, 9]
// Model event arrivals (e.g., customer arrivals)
const arrival_rate = 5.0; // Average 5 arrivals per hour
const n_hours = 10;
const rate_tensor = torch.full([n_hours], arrival_rate);
const arrivals = torch.poisson(rate_tensor); // Counts per hour
// arrivals might be [4, 6, 5, 3, 7, 5, 4, 6, 5, 4]
// Add Poisson noise to count data
const true_counts = torch.tensor([10.0, 20.0, 15.0, 25.0, 30.0]); // Float rates
const noisy_counts = torch.poisson(true_counts); // Each count resampled with mean equal to its true value (shot noise)
// Model spike counts in neural recordings
const firing_rates = torch.randn(100, 50).abs(); // 100 neurons, 50 time bins
const spike_counts = torch.poisson(firing_rates); // Integer spike counts per neuron per bin
// Batch sampling with varying rates
const batch_rates = torch.tensor([
[1.0, 2.0, 3.0],
[0.5, 1.5, 2.5],
[2.0, 3.0, 4.0]
]); // 3x3 matrix of rates
const counts = torch.poisson(batch_rates); // 3x3 matrix of counts
// Generate synthetic event sequences
const event_rate = 3.0; // Average 3 events per interval
const n_intervals = 1000;
const events = torch.poisson(torch.full([n_intervals], event_rate)); // Event count per interval
See Also
- PyTorch torch.poisson()
- bernoulli - Binary random sampling
- randint - Discrete uniform sampling
- multinomial - Categorical distribution sampling
- normal - Normal/Gaussian distribution sampling
- rand - Continuous uniform sampling