torch.distributions.kl_divergence
function kl_divergence(p: Distribution, q: Distribution): Tensor

Compute the Kullback-Leibler (KL) divergence between two distributions.
Measures how one probability distribution P differs from a reference distribution Q. KL(P || Q) quantifies the expected number of extra nats (or extra bits, if base-2 logarithms are used) needed to encode samples from P using a code optimized for Q. It is a fundamental measure of distributional divergence, though it is asymmetric (KL(P||Q) ≠ KL(Q||P)). Useful for:
- Variational inference: Measuring divergence between approximate and true posteriors
- Generative models: Training objectives for VAEs, diffusion models, and other latent variable models
- Model comparison: Determining how well one distribution approximates another
- Optimization: Fitting distributions to data by minimizing KL divergence
- Information theory: Quantifying the information loss from distribution approximation
The computation uses pre-registered implementations optimized for each distribution pair, ensuring numerical stability and efficiency. Common pairs (Normal, Exponential, Bernoulli, Categorical, Laplace, Gamma, Beta, Dirichlet) are built-in.
- Asymmetric measure: KL(P||Q) ≠ KL(Q||P). Ensure correct parameter order.
- Always non-negative: KL divergence ≥ 0, with equality iff P = Q almost everywhere.
- Support constraint: P's support must be contained in Q's support, else KL = ∞.
- Numerical stability: Implementations use log-space computation to prevent underflow/overflow.
- Batch support: Handles batched distributions with multiple parameter sets automatically.
- Missing registrations: Raises error if the distribution pair is not registered. Use register_kl() to add support.
- Divergence to self: While KL(P||P) = 0 in theory, numerical precision may yield small non-zero values.
Parameters
p: Distribution – The first (reference) distribution, the one samples are drawn from
q: Distribution – The second distribution, the one used to approximate p
Returns
Tensor – A tensor containing KL(p || q) values (scalar or batch)

Examples
// Compare two normal distributions
const p = torch.distributions.Normal(torch.tensor(0), torch.tensor(1));
const q = torch.distributions.Normal(torch.tensor(1), torch.tensor(2));
const kl = torch.distributions.kl_divergence(p, q);

// VAE loss - compare learned posterior to standard normal prior
const posterior = torch.distributions.Normal(mean, log_var.mul(0.5).exp());
const prior = torch.distributions.Normal(
torch.zeros(latent_dim),
torch.ones(latent_dim)
);
const kl_loss = torch.distributions.kl_divergence(posterior, prior).mean();

// Batch KL divergence - multiple distribution pairs
const p = torch.distributions.Categorical(logits_p); // Shape: [batch_size, num_classes]
const q = torch.distributions.Categorical(logits_q); // Shape: [batch_size, num_classes]
const kl = torch.distributions.kl_divergence(p, q); // Shape: [batch_size]

See Also
- PyTorch torch.distributions.kl_divergence()
- register_kl - Register custom KL divergence implementations for distribution pairs
- Normal - Gaussian (normal) distribution
- Categorical - Categorical distribution for discrete outcomes
- Bernoulli - Bernoulli distribution for binary outcomes