torch.distributions.OneHotCategorical
class OneHotCategorical extends Distribution

new OneHotCategorical(options: { probs?: number[] | Tensor; logits?: number[] | Tensor } & DistributionOptions)
arg_constraints (unknown) - readonly
support (unknown) - readonly
has_enumerate_support (unknown) - readonly
probs (Tensor) - readonly
logits (Tensor) - readonly
param_shape (readonly number[]) - readonly
mean (Tensor) - readonly
mode (Tensor) - readonly
variance (Tensor) - readonly
OneHotCategorical distribution: categorical with one-hot encoded output vectors instead of indices.
Parameterized by probabilities or logits over K categories. Behaviorally identical to Categorical, but samples are returned as one-hot vectors (all zeros except for a single 1) instead of integer indices. More convenient for neural networks and models expecting dense representations. Essential for:
- Classification neural networks with dense output layers (before softmax)
- Discrete choice modeling with one-hot representation
- Mixture model components (for linear combinations)
- Attention mechanisms and weighted aggregation
- Sparse tensor operations (one-hot vectors are naturally sparse)
- Policy networks in reinforcement learning (action selection)
- Generative models with discrete latent variables
Relationship to Categorical: OneHotCategorical samples as [0,0,1,0] while Categorical samples as 2 (the index). Functionally equivalent; just different output representation. Use OneHotCategorical when you want dense one-hot vectors, Categorical when you want indices.
Output Format: Sample shape is [...batch_dims, K] where K = number of categories. Each sample has exactly one 1 and K-1 zeros. Argmax of sample gives category index.
- One-hot representation: Exactly one element is 1, rest are 0 (by construction)
- Sparse output: One-hot vectors are naturally sparse (only 1 non-zero element)
- Equivalent to Categorical: Same distribution, just different output format
- Entropy same as Categorical: Entropy only depends on probs, not representation
- Integer conversion: Argmax gives integer index; useful for indexing or comparisons
- Linear combinations: One-hot useful for weighted sums (e.g., mixture models)
- Support enumeration: All K one-hot vectors can be enumerated (finite support)
- Probs vs Logits: Exactly one of probs or logits must be specified, not both
- K dimension last: Probabilities/logits shape is [..., K]; one-hot output is [..., K]
- Output always one-hot: Every sample has sum=1 and exactly one element=1
- Not equivalent to Bernoulli: elements are binary, but not independent; exactly one element is 1, unlike a vector of independent Bernoullis
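The index/one-hot equivalence described above can be sketched without the library, using plain arrays and two small helpers (`oneHot` and `argmax`, defined inline for illustration):

```javascript
// Convert a category index to a one-hot vector of length k.
function oneHot(index, k) {
  const v = new Array(k).fill(0);
  v[index] = 1;
  return v;
}

// Recover the index from a one-hot vector (argmax).
function argmax(v) {
  return v.indexOf(Math.max(...v));
}

const k = 4;
const onehot = oneHot(2, k);  // [0, 0, 1, 0]
const index = argmax(onehot); // 2

// log_prob is log(p[index]) in both representations: the one-hot
// vector carries no extra information beyond the selected index.
const probs = [0.25, 0.25, 0.25, 0.25];
const logProb = Math.log(probs[argmax(onehot)]); // log(0.25)
```

This is why OneHotCategorical and Categorical share the same entropy, mean over indices, and log-likelihoods: only the sample encoding differs.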
Examples
// 4-category one-hot categorical (uniform)
const ohc = new torch.distributions.OneHotCategorical({
probs: torch.tensor([0.25, 0.25, 0.25, 0.25])
});
const sample = ohc.sample(); // e.g. [0, 0, 1, 0] (exactly one 1)
const samples = ohc.sample([1000]); // [1000, 4] shaped samples
// Biased distribution: first category more likely
const biased = new torch.distributions.OneHotCategorical({
probs: torch.tensor([0.7, 0.2, 0.05, 0.05])
});
const biased_sample = biased.sample(); // mostly [1, 0, 0, 0], sometimes others
// Neural network classification: using logits from model
const batch_logits = model(x); // [batch_size, num_classes]
const dist = new torch.distributions.OneHotCategorical({ logits: batch_logits });
const one_hot_samples = dist.sample(); // [batch_size, num_classes]
const class_indices = one_hot_samples.argmax(-1); // [batch_size] class indices
const log_probs = dist.log_prob(one_hot_samples); // [batch_size] log-likelihoods
// Mixture model: one-hot selection of mixture component
const num_components = 3;
const mixture_probs = torch.tensor([0.4, 0.35, 0.25]);
const selector = new torch.distributions.OneHotCategorical({
probs: mixture_probs
});
const component_selector = selector.sample(); // one-hot vector of shape [3] selecting a component
// Use for: y = Σ component_selector[i] * component[i].sample()
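The weighted sum above can be checked with plain arrays (a minimal sketch; `componentValues` stands in for one sample drawn from each component):

```javascript
// A one-hot selector turns a linear combination into an index lookup:
// y = Σ selector[i] * value[i] keeps exactly one term.
const selectorVec = [0, 1, 0];              // one-hot draw from the selector distribution
const componentValues = [10.0, 20.0, 30.0]; // e.g. one sample from each mixture component
const y = selectorVec.reduce((acc, s, i) => acc + s * componentValues[i], 0);
// y picks out componentValues[1] because selectorVec is one-hot at index 1
```

This is the "linear combinations" use case from the notes: the one-hot encoding lets component selection be expressed as a differentiable-looking sum rather than an integer index.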
// Batched one-hot categorical with different probabilities
const batch_probs = torch.tensor([
[0.5, 0.3, 0.2], // batch 0: favor category 0
[0.2, 0.3, 0.5], // batch 1: favor category 2
]); // [2, 3] shape
const batch_dist = new torch.distributions.OneHotCategorical({ probs: batch_probs });
const batch_samples = batch_dist.sample(); // [2, 3] shaped one-hot samples
const batch_log_probs = batch_dist.log_prob(batch_samples); // [2] log-likelihoods
// Entropy: measure uncertainty over categories
const certain = new torch.distributions.OneHotCategorical({
probs: torch.tensor([0.99, 0.005, 0.005])
});
const entropy_low = certain.entropy(); // near 0 (almost deterministic)

const fair = new torch.distributions.OneHotCategorical({
probs: torch.tensor([0.33, 0.33, 0.34])
});
const entropy_high = fair.entropy(); // ≈ 1.0985, close to the maximum log(3) ≈ 1.0986
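The entropy values in the example can be verified directly from the categorical formula H(p) = -Σ p_i log p_i, which is independent of the one-hot representation (a plain-array sketch, no library assumed):

```javascript
// Categorical entropy: H(p) = -Σ p_i * log(p_i); the 0*log(0) = 0 convention
// is handled by skipping zero entries.
const entropy = (p) =>
  -p.reduce((acc, pi) => acc + (pi > 0 ? pi * Math.log(pi) : 0), 0);

const certainProbs = [0.99, 0.005, 0.005];
const fairProbs = [0.33, 0.33, 0.34];

entropy(certainProbs); // ≈ 0.063, almost deterministic
entropy(fairProbs);    // ≈ 1.0985, close to the maximum log(3) ≈ 1.0986
```

Entropy is maximized by the exactly-uniform distribution; [0.33, 0.33, 0.34] falls just short of log(3).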