torch.distributions.Categorical
class Categorical extends Distribution

new Categorical(options: { probs?: number[] | Tensor; logits?: number[] | Tensor } & DistributionOptions)

Properties (all readonly)
- arg_constraints: unknown
- support: Constraint
- has_enumerate_support: unknown
- probs: Tensor – Get probs, computing from logits if needed.
- logits: Tensor – Get logits, computing from probs if needed.
- param_shape: readonly number[] – Shape of the parameter tensor.
- mean: Tensor
- mode: Tensor
- variance: Tensor
Categorical distribution: discrete probability over K categories.
Parameterized by probabilities or logits over K discrete categories, this is the fundamental distribution for modeling discrete choices. Essential for:
- Classification in neural networks (output layer, cross-entropy loss)
- Natural language processing (token prediction, language models)
- Reinforcement learning (action selection, policy)
- Mixture models and topic modeling
- Multinomial sampling and generation tasks
- Gumbel-max sampling and its differentiable Gumbel-softmax relaxation
The probability mass function: P(X=k) = p_k for k in {0, ..., K-1}
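The probs/logits relationship can be sketched in plain JavaScript, independent of the library: probs are the softmax of the logits. The `softmax` helper below is an illustrative stand-in, not part of the library API.

```javascript
// softmax: convert unnormalized logits into a probability vector.
// Subtracting the max logit first keeps Math.exp from overflowing.
function softmax(logits) {
  const m = Math.max(...logits);
  const exps = logits.map(x => Math.exp(x - m));
  const z = exps.reduce((a, b) => a + b, 0);
  return exps.map(e => e / z);
}

const probs = softmax([2.0, 1.0, -1.0, -2.0]);
console.log(probs); // largest logit gets the largest probability; entries sum to 1
```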
- Probs vs Logits: Use logits for numerical stability (no normalization needed)
- Automatic normalization: Probs don't need to sum exactly to 1 (will be normalized)
- Discrete sampling: Samples are always integers 0 to K-1
- No rsample: sample() is not differentiable; for reparameterized (differentiable) sampling, use a Gumbel-softmax relaxation instead
- One-hot encoding: Use OneHotCategorical for one-hot samples instead
- Batch support: Can handle batches of distributions (one per row of logits/probs)
- Entropy: more concentrated distributions have lower entropy; the uniform distribution maximizes it at log(K)
- Mutually exclusive: Specify exactly one of probs or logits, not both
- Valid indices: Values passed to log_prob must be valid category indices in 0 to K-1
- Shape broadcasting: Batch dimensions of probs/logits must be broadcastable
- Numerical stability: Very small/large logits can cause numerical issues
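The Gumbel-max trick mentioned above can be sketched in plain JavaScript, with no library dependency: adding independent Gumbel(0, 1) noise to each logit and taking the argmax yields an exact categorical sample. The `sampleGumbelMax` helper is hypothetical, for illustration only.

```javascript
// Gumbel-max trick: argmax(logits + Gumbel noise) ~ Categorical(softmax(logits)).
function sampleGumbelMax(logits) {
  let best = 0;
  let bestVal = -Infinity;
  for (let k = 0; k < logits.length; k++) {
    // Gumbel(0, 1) noise: -log(-log(U)), with U ~ Uniform(0, 1)
    const g = -Math.log(-Math.log(Math.random()));
    if (logits[k] + g > bestVal) {
      bestVal = logits[k] + g;
      best = k;
    }
  }
  return best;
}

// Empirical check: with heavily biased logits, class 0 should dominate.
const counts = [0, 0, 0];
for (let i = 0; i < 10000; i++) counts[sampleGumbelMax([5.0, 0.0, 0.0])]++;
console.log(counts); // class 0 drawn far more often than classes 1 and 2
```

Note the argmax itself is not differentiable; replacing it with a temperature-controlled softmax gives the Gumbel-softmax relaxation used for gradient-based training.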
Examples
// Uniform categorical over 4 classes (equal probability)
const uniform = new torch.distributions.Categorical({
probs: torch.tensor([0.25, 0.25, 0.25, 0.25])
});
const sample = uniform.sample(); // 0, 1, 2, or 3 with equal probability
// Biased categorical distribution
const biased = new torch.distributions.Categorical({
probs: torch.tensor([0.7, 0.2, 0.05, 0.05])
});
const samples = biased.sample([1000]); // mostly 0s, some 1s, few 2s/3s
// Using logits instead of probs (more numerically stable)
const logits = torch.tensor([2.0, 1.0, -1.0, -2.0]); // unnormalized
const dist = new torch.distributions.Categorical({ logits });
const prediction = dist.sample(); // class with highest logit most likely
// Multi-class classification: batched distributions
const batch_logits = torch.randn([batch_size, num_classes]); // from model
const batch_dist = new torch.distributions.Categorical({ logits: batch_logits });
const pred_classes = batch_dist.sample(); // [batch_size] predictions
const log_probs = batch_dist.log_prob(targets); // negate for NLL loss
// Policy network in RL: actor samples actions
const policy_logits = policy_network(state); // [num_actions]
const action_dist = new torch.distributions.Categorical({ logits: policy_logits });
const action = action_dist.sample(); // sample action from policy
const log_prob = action_dist.log_prob(action); // for policy gradient
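What log_prob computes can be sketched in plain JavaScript: log p_k = logits[k] - logsumexp(logits), with the max-subtraction keeping the computation numerically stable. The `logProb` helper below is an illustrative stand-in, not the library's implementation.

```javascript
// Stable log-probability of class k from unnormalized logits.
function logProb(logits, k) {
  const m = Math.max(...logits);
  const lse = m + Math.log(logits.reduce((s, x) => s + Math.exp(x - m), 0));
  return logits[k] - lse;
}

const lp = logProb([2.0, 1.0, -1.0, -2.0], 0);
console.log(Math.exp(lp)); // probability of class 0, ≈ 0.70
```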
// Temperature scaling: control exploration vs exploitation
const raw_logits = torch.tensor([1.0, 2.0, 3.0]);
const temperature = 0.1; // low temperature = more peaked (exploit); high = flatter (explore)
const scaled_logits = raw_logits.div(temperature);
const temp_dist = new torch.distributions.Categorical({ logits: scaled_logits });
const temp_sample = temp_dist.sample(); // almost always class 2 at this temperature
// Entropy: measure of uncertainty
const uniform_cat = new torch.distributions.Categorical({
probs: torch.ones([10]).div(10) // uniform
});
const entropy_max = uniform_cat.entropy(); // log(10) ≈ 2.3
const certain_cat = new torch.distributions.Categorical({
probs: torch.tensor([0.99, 0.01])
});
const entropy_low = certain_cat.entropy(); // near 0
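The entropy values in the example above can be checked with a plain-JavaScript sketch of H = -Σ p_k log p_k (natural log, matching the library's convention as assumed here):

```javascript
// Categorical entropy: -sum p_k * log(p_k), skipping zero-probability terms.
function entropy(probs) {
  return -probs.reduce((h, p) => h + (p > 0 ? p * Math.log(p) : 0), 0);
}

console.log(entropy(Array(10).fill(0.1))); // log(10) ≈ 2.3026, the maximum for K = 10
console.log(entropy([0.99, 0.01]));        // ≈ 0.056, near certainty
```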