torch.special.polygamma
function polygamma<S extends Shape>(n: number, input: Tensor<S, 'float32'>, _options?: SpecialUnaryOptions<S>): Tensor<S, 'float32'>

Computes the n-th derivative of the digamma function.
The polygamma function ψ^(n)(x) = d^n/dx^n ψ(x) (n-th derivative of digamma) appears throughout Bayesian statistics and variational inference. Essential for:
- Variational inference: computing KL divergences for conjugate priors (gamma, Dirichlet, Wishart)
- Bayesian deep learning: gradient-based posterior approximation, ELBO optimization, reparameterization
- Statistical learning: maximum likelihood estimation with gamma/Dirichlet, moment matching
- Approximate inference: black-box variational inference, gradient estimators for discrete latent variables
- Optimization: parameterized exponential families, natural gradient descent in information geometry
- Risk modeling: loss function derivatives, portfolio optimization with risk measures
- Probabilistic modeling: graphical models with continuous latent variables, deep generative models
Hierarchy of Derivatives: n=0 is digamma ψ(x), n=1 is trigamma ψ'(x), n≥2 are higher polygammas. Each order adds new structure: trigamma controls variance in Bayesian updates, higher orders affect higher cumulants.
Central to Exponential Families: Log-partition function of gamma/Dirichlet/Wishart involves log Γ(x); digamma ψ and polygamma ψ^(n) encode mean/variance/cumulants of natural parameter space.
- Order n=0: Digamma ψ(x) = d/dx log Γ(x); fundamental special function related to derivatives of gamma
- Order n=1: Trigamma ψ'(x) = variance term in exponential family distributions; appears in Hessian of ELBO
- Higher orders (n≥2): Higher cumulants; rarely needed but appear in refined approximations
- Digamma properties: ψ(1) = -γ (Euler constant), ψ(n+1) = ψ(n) + 1/n (recurrence), ψ(1/2) = -2 ln(2) - γ
- Digamma connection to means: ψ(α) = E[log X] for X ~ Gamma(α, β); fundamental for variational inference
- Trigamma positivity: ψ'(x) > 0 for all x > 0; equivalent to the log-convexity of Γ
- Asymptotic series: ψ(x) ~ log(x) - 1/(2x) - 1/(12x²) + ... for large x (used for numerics)
- Domain x > 0 required: Undefined for x ≤ 0 (poles and singularities); handle carefully
- Poles at negative integers: If input contains negative integers, results are NaN/Inf
- Numerical stability for small x: For x < 0.1, use the forward recurrence to shift the argument into the asymptotic regime, or a reflection/transformation
- Large n polygamma: Higher orders behave like (n-1)!/x^n for large x; may underflow for large x and moderate n
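The recurrence and asymptotic series listed above are enough to sketch a scalar digamma in plain TypeScript. This is illustrative of the standard evaluation strategy only (assumed names, not the library's actual kernel):

```typescript
// Minimal scalar digamma sketch: forward recurrence into the asymptotic
// regime, then the series psi(x) ~ log(x) - 1/(2x) - 1/(12x^2) + ...
function digamma(x: number): number {
  if (x <= 0) return NaN; // domain x > 0 (poles at 0, -1, -2, ...)
  let result = 0;
  // Forward recurrence psi(x) = psi(x + 1) - 1/x until x >= 6
  while (x < 6) {
    result -= 1 / x;
    x += 1;
  }
  // Asymptotic series: log(x) - 1/(2x) - 1/(12x^2) + 1/(120x^4) - 1/(252x^6)
  const inv2 = 1 / (x * x);
  result += Math.log(x) - 0.5 / x
    - inv2 * (1 / 12 - inv2 * (1 / 120 - inv2 / 252));
  return result;
}

console.log(digamma(1)); // ≈ -0.5772 (−γ)
console.log(digamma(2)); // ≈ 0.4228 (recurrence: ψ(2) = ψ(1) + 1)
```

With the cutoff at x ≥ 6, the truncated series is accurate to roughly 1e-9, which is plenty for float32 tensors.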
Parameters
- n: number – Order of derivative (non-negative integer). n=0 gives digamma ψ(x), n=1 gives trigamma ψ'(x), n≥2 higher derivatives
- input: Tensor<S, 'float32'> – Input tensor x. Must satisfy x > 0 for mathematical convergence. Can be scalar or Tensor
- _options: SpecialUnaryOptions<S> – optional
Returns
Tensor<S, 'float32'> – Tensor with ψ^(n)(x) values

Examples
// Order 0: digamma function ψ(x)
const x = torch.linspace(0.1, 3, 5);
const digamma = torch.special.polygamma(0, x); // ψ(x) = d/dx log Γ(x)
// digamma(1) ≈ -0.5772 (Euler-Mascheroni constant)
// digamma(2) ≈ 0.4228
// digamma(3) ≈ 0.9227
// Order 1: trigamma function ψ'(x)
const trigamma = torch.special.polygamma(1, x); // ψ'(x) = d²/dx² log Γ(x)
// trigamma(1) = π²/6 ≈ 1.6449 (Basel problem!)
// Appears as variance term in Dirichlet/Gamma variational inference
// Variational inference: KL divergence for gamma distributions
// KL(q_gamma || p_gamma) involves digamma differences and trigamma variance
const alpha_q = torch.tensor([2.0, 3.0, 5.0]); // Variational parameters
const beta_q = torch.tensor([1.0, 1.0, 1.0]); // Rate parameters
const digamma_term = torch.special.polygamma(0, alpha_q); // E[log X] in variational
const trigamma_term = torch.special.polygamma(1, alpha_q); // Variance correction
// KL divergence uses these to compute E_q[log q(x)] - E_q[log p(x)]
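The closed form behind this can be sketched with scalar helpers (plain numbers rather than the tensor API; the Lanczos lgamma and the function names here are assumptions for illustration):

```typescript
// Scalar digamma (recurrence + asymptotic series)
function digamma(x: number): number {
  let r = 0;
  while (x < 6) { r -= 1 / x; x += 1; }
  const inv2 = 1 / (x * x);
  return r + Math.log(x) - 0.5 / x
    - inv2 * (1 / 12 - inv2 * (1 / 120 - inv2 / 252));
}

// Log-gamma via the standard Lanczos approximation (g = 7)
const LANCZOS = [676.5203681218851, -1259.1392167224028, 771.32342877765313,
  -176.61502916214059, 12.507343278686905, -0.13857109526572012,
  9.9843695780195716e-6, 1.5056327351493116e-7];
function lgamma(x: number): number {
  if (x < 0.5) return Math.log(Math.PI / Math.sin(Math.PI * x)) - lgamma(1 - x);
  x -= 1;
  let a = 0.99999999999980993;
  const t = x + 7.5;
  for (let i = 0; i < 8; i++) a += LANCZOS[i] / (x + i + 1);
  return 0.5 * Math.log(2 * Math.PI) + (x + 0.5) * Math.log(t) - t + Math.log(a);
}

// KL(Gamma(aq, bq) || Gamma(ap, bp)), shape/rate parameterization:
// (aq-ap)ψ(aq) - lnΓ(aq) + lnΓ(ap) + ap(ln bq - ln bp) + aq(bp - bq)/bq
function klGamma(aq: number, bq: number, ap: number, bp: number): number {
  return (aq - ap) * digamma(aq) - lgamma(aq) + lgamma(ap)
    + ap * (Math.log(bq) - Math.log(bp)) + aq * (bp - bq) / bq;
}

console.log(klGamma(2, 1, 2, 1)); // ≈ 0 (identical distributions)
```

Note how the digamma term (aq − ap)ψ(aq) is exactly the E[log X] difference mentioned above; trigamma enters only when you also need the Hessian of this objective.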
// Dirichlet variational autoencoder: softmax approximation
const alpha = torch.tensor([1.0, 2.0, 3.0, 4.0]); // Dirichlet concentration parameters
const psi_alpha = torch.special.polygamma(0, alpha); // Digamma ψ(α_k)
const psi_sum = torch.special.polygamma(0, alpha.sum()); // Digamma ψ(Σ α_k)
const mean_logp = psi_alpha.sub(psi_sum); // E_Dirichlet[log p_k] (variational bound)
// Entropy and KL divergence computations use these expectations
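The Dirichlet expectation E[log p_k] = ψ(α_k) − ψ(Σ α_j) used above can be sketched over plain arrays (a scalar illustration, not the library's tensor API; helper names are assumptions):

```typescript
// Scalar digamma (recurrence + asymptotic series)
function digamma(x: number): number {
  let r = 0;
  while (x < 6) { r -= 1 / x; x += 1; }
  const inv2 = 1 / (x * x);
  return r + Math.log(x) - 0.5 / x
    - inv2 * (1 / 12 - inv2 * (1 / 120 - inv2 / 252));
}

// E_Dirichlet[log p_k] = psi(alpha_k) - psi(sum of alphas)
function dirichletMeanLog(alpha: number[]): number[] {
  const total = alpha.reduce((s, a) => s + a, 0);
  const psiSum = digamma(total);
  return alpha.map((a) => digamma(a) - psiSum);
}

console.log(dirichletMeanLog([1, 1])); // each entry = ψ(1) − ψ(2) = −1 exactly
```

The symmetric case is a nice sanity check: for α = [1, 1] the recurrence ψ(2) = ψ(1) + 1 makes each expectation exactly −1, strictly below log E[p_k] = log(1/2) as Jensen's inequality requires.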
// Higher order: tetragamma (n=2)
const x_point = torch.tensor([2.0]);
const polygamma_order2 = torch.special.polygamma(2, x_point); // ψ''(x)
// Third-cumulant (skewness-related) term; less common but appears in high-order approximations
// Optimization dynamics: natural gradient in exponential family
const param = torch.tensor([0.5, 1.0, 2.0, 5.0]);
const fisher_info = torch.special.polygamma(1, param); // Trigamma = Fisher information for gamma
// Natural gradient descent step uses fisher_info^{-1} * gradient
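One such preconditioned step can be sketched in scalar form, with trigamma as the Fisher information term described above (an illustrative sketch with assumed names, not the library's API):

```typescript
// Scalar trigamma: recurrence psi'(x) = psi'(x + 1) + 1/x^2, then the
// asymptotic series psi'(x) ~ 1/x + 1/(2x^2) + 1/(6x^3) - 1/(30x^5) + 1/(42x^7)
function trigamma(x: number): number {
  if (x <= 0) return NaN;
  let r = 0;
  while (x < 6) { r += 1 / (x * x); x += 1; }
  const inv = 1 / x;
  const inv2 = inv * inv;
  return r + inv + 0.5 * inv2
    + inv * inv2 * (1 / 6 - inv2 * (1 / 30 - inv2 / 42));
}

// Natural gradient step: precondition the Euclidean gradient by the
// inverse Fisher information, theta <- theta - lr * grad / F(theta)
function naturalStep(theta: number, grad: number, lr: number): number {
  return theta - (lr * grad) / trigamma(theta);
}

console.log(trigamma(1)); // ≈ 1.6449 = π²/6 (Basel problem again)
```

Because trigamma is large for small θ and decays for large θ, the natural step automatically takes cautious steps where the parameterization is sensitive and bolder steps where it is flat.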
// Batch computation: different orders
const x_batch = torch.tensor([1.0, 2.0, 3.0]);
const digamma_batch = torch.special.polygamma(0, x_batch); // [ψ(1), ψ(2), ψ(3)]
const trigamma_batch = torch.special.polygamma(1, x_batch); // [ψ'(1), ψ'(2), ψ'(3)]
// Forms basis for gradient-based inference algorithms
// Domain and asymptotics
const x_small = torch.tensor([0.1]); // Near 0: large negative (digamma)
const x_large = torch.tensor([100.0]); // Large x: ψ(x) grows like log(x)
const psi_small = torch.special.polygamma(0, x_small); // Large negative: ψ(x) ≈ -1/x - γ near 0
const psi_large = torch.special.polygamma(0, x_large); // ~ log(x)
// Asymptotic behavior: ψ(x) ~ log(x) - 1/(2x) as x → ∞

See Also
- PyTorch torch.special.polygamma()
- torch.special.digamma - Digamma function (special case n=0, convenience alias)
- torch.lgamma - Log-gamma function log Γ(x); digamma is its derivative
- torch.special.multigammaln - Multivariate log-gamma (uses polygamma for derivatives)