torch.distributions.Dirichlet
class Dirichlet extends Distribution

new Dirichlet(concentration: number[] | Tensor, options?: DistributionOptions)

Properties
- readonly concentration (Tensor) – Concentration parameter.
- readonly arg_constraints (unknown)
- readonly support (unknown)
- readonly has_rsample (unknown)
- readonly mean (Tensor)
- readonly mode (Tensor)
- readonly variance (Tensor)
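The read-only mean, mode, and variance follow the standard Dirichlet formulas: mean_i = α_i / Σα; mode_i = (α_i − 1) / (Σα − K) when all α_i > 1; var_i = α_i (Σα − α_i) / ((Σα)² (Σα + 1)). A minimal sketch, with values worked by hand for α = [2, 3, 5]:
// Moments of Dirichlet(α) with Σα = 10, K = 3
const d = new torch.distributions.Dirichlet(torch.tensor([2, 3, 5]));
d.mean;     // ≈ [0.2, 0.3, 0.5]          (α_i / 10)
d.mode;     // ≈ [0.143, 0.286, 0.571]    ((α_i − 1) / 7)
d.variance; // ≈ [0.0145, 0.0191, 0.0227] (α_i (10 − α_i) / 1100)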
Dirichlet distribution: a distribution over the probability simplex.
The multivariate generalization of the Beta distribution. Generates probability vectors (all components non-negative, summing to 1). Essential for:
- Topic modeling (Latent Dirichlet Allocation)
- Bayesian multinomial inference (conjugate prior)
- Compositional data modeling (proportions, percentages)
- Mixture model parameters (mixing proportions)
- Word frequency distributions (document modeling)
- Allele frequencies (population genetics)
- Clustering and soft assignments
- Variational inference in probabilistic models
Parameterized by a concentration vector α ∈ ℝ^K with α_i > 0. Support: the (K−1)-simplex {x ∈ ℝ^K : x_i ≥ 0, Σ_i x_i = 1}.
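On this support the log-density is log p(x) = log Γ(Σα) − Σ_i log Γ(α_i) + Σ_i (α_i − 1) log x_i. A sketch of evaluating it; log_prob is assumed here to be inherited from Distribution, as in PyTorch's API:
// Log-density at a point on the simplex (log_prob assumed PyTorch-style)
const dist = new torch.distributions.Dirichlet(torch.tensor([2, 3, 5]));
const x = torch.tensor([0.2, 0.3, 0.5]); // x_i ≥ 0, Σ x_i = 1
dist.log_prob(x); // scalar: log p(x)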
- Simplicial support: Samples always sum to 1 and are non-negative
- Concentration effect: Higher Σ α concentrates samples around the mean; lower Σ α spreads them out (α_i < 1 pushes mass toward the simplex corners)
- Symmetric case: When all α_i are equal, the distribution is symmetric on the simplex
- Conjugate prior: To the multinomial and categorical likelihoods
- Marginals are Beta: Component i is marginally Beta(α_i, Σ α − α_i) distributed (see the sketch after this list)
- Scaling the concentration: α and c·α share the same mean α / Σ α; larger c means lower variance (more concentration)
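Illustrating the Beta marginals: component i of Dirichlet(α) is marginally Beta(α_i, Σ α − α_i). A sketch, assuming this library also provides torch.distributions.Beta with the usual (concentration1, concentration0) argument order:
// For α = [2, 3, 5], component 0 is marginally Beta(2, 10 − 2) = Beta(2, 8)
const dir = new torch.distributions.Dirichlet(torch.tensor([2, 3, 5]));
const marginal0 = new torch.distributions.Beta(torch.tensor([2]), torch.tensor([8])); // assumed API
// dir.sample()[0] and marginal0.sample() are draws from the same distribution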
- Positive concentrations: All α_i must be strictly positive
- Simplex constraint: Samples guaranteed to sum to 1 and be non-negative
- Numerical stability: Very small α_i can cause numerical issues, since samples fall near the simplex boundary where components underflow to 0
- GPU support: Currently CPU-only implementation
Examples
// Uniform distribution on 3-simplex
const uniform = new torch.distributions.Dirichlet(torch.ones([3]));
uniform.sample(); // Random probability vector [p0, p1, p2], sums to 1
// Concentrated distribution: alpha >> 1
const concentrated = new torch.distributions.Dirichlet(torch.tensor([10, 10, 10]));
concentrated.sample(); // Near [1/3, 1/3, 1/3]
// Sparse distribution: alpha << 1
const sparse = new torch.distributions.Dirichlet(torch.tensor([0.1, 0.1, 0.1]));
sparse.sample(); // One component near 1, others near 0
// Asymmetric distribution
const asymmetric = new torch.distributions.Dirichlet(torch.tensor([5, 2, 1]));
asymmetric.sample(); // Biased toward first component
// Latent Dirichlet Allocation (LDA): topic modeling
const num_topics = 10;
const num_vocab = 5000;
// Document-topic distribution (prior)
const doc_topic_alpha = torch.full([num_topics], 0.1);
const theta = new torch.distributions.Dirichlet(doc_topic_alpha);
const doc_topics = theta.sample(); // Distribution over topics for this document
// Topic-word distribution (prior)
const topic_word_alpha = torch.full([num_vocab], 0.01);
const beta = new torch.distributions.Dirichlet(topic_word_alpha);
const topic_words = beta.sample([num_topics]); // Word distributions for each topic
// Bayesian multinomial inference
// Prior: Dirichlet(α)
// Data: observed counts from multinomial
// Posterior: Dirichlet(α + counts)
const alpha_prior = torch.tensor([1, 1, 1, 1]); // Uniform prior
const observed_counts = torch.tensor([10, 15, 8, 5]); // Observed data
const alpha_posterior = alpha_prior.add(observed_counts);
const posterior = new torch.distributions.Dirichlet(alpha_posterior);
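Because the posterior is again Dirichlet, its summaries come straight from the accessors; with α_posterior = [11, 16, 9, 6] (Σ = 42), the values below are worked by hand:
posterior.mean; // ≈ [0.262, 0.381, 0.214, 0.143]  (α_i / 42)
posterior.mode; // ≈ [0.263, 0.395, 0.211, 0.132]  ((α_i − 1) / (42 − 4)), the MAP estimate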
// Mixture model: mixing proportions
const num_clusters = 5;
const concentration = 1; // Controls how concentrated vs spread out
const mixing_alpha = torch.full([num_clusters], concentration);
const mixing_dist = new torch.distributions.Dirichlet(mixing_alpha);
const mixing_weights = mixing_dist.sample(); // Cluster mixing proportions
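A sampled weight vector is typically consumed as the parameter of a categorical draw over clusters; a sketch, assuming a PyTorch-style torch.distributions.Categorical exists in this library:
// Draw a cluster assignment from the sampled proportions (Categorical assumed)
const assignment_dist = new torch.distributions.Categorical(mixing_weights);
const z = assignment_dist.sample(); // cluster index i, drawn with probability mixing_weights[i]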
// Batched distributions
const batch_alphas = torch.tensor([
[1, 1, 1],
[5, 5, 5],
[0.5, 0.5, 0.5]
]);
const batch_dist = new torch.distributions.Dirichlet(batch_alphas);
const samples = batch_dist.sample(); // One sample from each of the three Dirichlets; shape [3, 3]
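Scoring follows the same batch semantics; a hedged sketch assuming a PyTorch-style log_prob that broadcasts over the batch dimension:
batch_dist.log_prob(samples); // shape [3]: each sample scored under its own α row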