torch.distributions.constraints.multinomial
function multinomial(total_count: number): _Multinomial

Creates a constraint for multinomial distributions.
The multinomial constraint ensures that sampled values are valid counts for a multinomial distribution. Specifically, sampled values must be:
- Non-negative integers - Each count must be ≥ 0 and have no fractional part
- Sum to total_count - All counts across categories must sum exactly to the specified total
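The two rules above can be expressed as a plain function over an array of counts. This is an illustrative sketch, not the library's implementation; `isValidMultinomialSample` is a hypothetical helper introduced here for clarity.

```typescript
// Illustrative sketch (not the library's implementation): the two validity
// rules expressed over a 1-D array of counts.
function isValidMultinomialSample(counts: number[], totalCount: number): boolean {
  // Rule 1: every count is a non-negative integer (no fractional part)
  const nonNegativeIntegers = counts.every((c) => Number.isInteger(c) && c >= 0);
  // Rule 2: counts sum exactly to totalCount
  const sum = counts.reduce((a, b) => a + b, 0);
  return nonNegativeIntegers && sum === totalCount;
}

console.log(isValidMultinomialSample([2, 3, 4, 5, 2, 4], 20)); // true
console.log(isValidMultinomialSample([2.5, 3.5, 14], 20)); // false (fractional counts)
console.log(isValidMultinomialSample([2, 3, 4, 5, 2, 3], 20)); // false (sums to 19)
```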
This constraint is essential for categorical sampling where you need to draw a fixed number of samples distributed across multiple categories. Common use cases include:
- Text generation - Drawing multiple tokens from vocabulary at each step
- Reinforcement learning - Sampling multiple actions from a discrete action space
- Discrete mixture models - Ensuring valid mixture component counts
- Multi-label classification - Constraining predicted class counts
How multinomial sampling works: Given probabilities for K categories and total_count draws, the multinomial distribution generates K counts that sum exactly to total_count, where each count represents how many times that category was sampled.
Relationship to categorical distribution:
- Categorical: Sample one category per draw
- Multinomial: Sample total_count categories (with replacement), count outcomes
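This relationship can be sketched directly: drawing total_count categorical samples and tallying them yields a multinomial sample. The helpers below (`sampleCategorical`, `sampleMultinomial`) are hypothetical, written for illustration, and are not part of the documented API.

```typescript
// Hypothetical helper: draw one category index from a probability vector
// via inverse-CDF sampling.
function sampleCategorical(probs: number[]): number {
  let r = Math.random();
  for (let k = 0; k < probs.length; k++) {
    r -= probs[k];
    if (r <= 0) return k;
  }
  return probs.length - 1; // guard against floating-point round-off
}

// A multinomial sample is total_count categorical draws, tallied per category.
function sampleMultinomial(probs: number[], totalCount: number): number[] {
  const counts = new Array(probs.length).fill(0);
  for (let i = 0; i < totalCount; i++) {
    counts[sampleCategorical(probs)] += 1; // one categorical draw per trial
  }
  return counts;
}

const counts = sampleMultinomial([0.1, 0.3, 0.4, 0.2], 100);
console.log(counts.reduce((a, b) => a + b, 0)); // 100 - always sums to total_count
```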
Notes
- Discrete constraint: The multinomial constraint only accepts integer values. Floating-point counts are rejected by the check.
- Event dimension: The constraint operates over the last dimension (event_dim=1), so it works naturally with batched samples.
- Sum constraint: The most important check is that counts sum exactly to total_count. No approximation or tolerance is used.
- Non-negative: All counts must be ≥ 0. Negative counts are never valid.
- PyTorch compatibility: Mirrors torch.distributions.constraints.multinomial(). Note that PyTorch's own check accepts any sum up to total_count (sum ≤ total_count), whereas the check described here requires the sum to equal total_count exactly.
- Common use case: Used internally by Multinomial distribution to validate samples.
- Strict sum requirement: Even off-by-one errors fail the check (sum must be exactly total_count).
- Integer only: Fractional counts like [2.5, 3.5, ...] are invalid.
- Large total_count: With very large total_count, ensure sufficient precision in computations.
- Zero counts allowed: Having zero counts for some categories is valid and common.
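The event_dim=1 behavior noted above means the check collapses the last dimension, producing one boolean per batch row rather than one per element. A minimal sketch, using nested arrays in place of tensors (`checkBatch` is a hypothetical stand-in, not the library's API):

```typescript
// Hedged sketch of event_dim=1 semantics: one boolean per row of the batch,
// obtained by reducing the check over the last (event) dimension.
function checkBatch(batch: number[][], totalCount: number): boolean[] {
  return batch.map((row) => {
    const ok = row.every((c) => Number.isInteger(c) && c >= 0);
    const sum = row.reduce((a, b) => a + b, 0);
    return ok && sum === totalCount; // exact sum, no tolerance
  });
}

const samples = [
  [20, 20, 20, 20, 20], // sums to 100 -> valid
  [0, 0, 100, 0, 0],    // zero counts for some categories are fine -> valid
  [25, 25, 25, 24, 0],  // sums to 99 -> invalid (off-by-one fails)
];
console.log(checkBatch(samples, 100)); // [ true, true, false ]
```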
Parameters
total_count: number - The number of draws for the multinomial distribution. Must be a positive integer. All sampled counts must sum to exactly this value.
Returns
_Multinomial - A constraint object that validates multinomial distribution samples
Examples
// Constraint for rolling a 6-sided die 20 times
const constraint = torch.distributions.constraints.multinomial(20);
// Valid sample: [2, 3, 4, 5, 2, 4] - sums to 20
const valid = torch.tensor([2, 3, 4, 5, 2, 4]);
constraint.check(valid); // true (the check reduces over the last dimension)
// Invalid sample: [2, 3, 4, 5, 2, 3] - sums to 19
const invalid = torch.tensor([2, 3, 4, 5, 2, 3]);
constraint.check(invalid); // false (sum is 19, not 20)

// Sampling from multinomial distribution
const probs = torch.tensor([0.1, 0.3, 0.4, 0.2]); // 4 categories
const constraint = torch.distributions.constraints.multinomial(100);
// Create multinomial distribution
const dist = torch.distributions.Multinomial(probs, /* total_count= */ 100);
const sample = dist.sample(); // Shape: [4], sums to 100
constraint.check(sample); // Always true for samples drawn from the distribution

// Batch sampling with multinomial constraint
const batch_probs = torch.randn(32, 10).softmax(-1); // 32 samples, 10 categories
const constraint = torch.distributions.constraints.multinomial(50);
for (let i = 0; i < batch_probs.shape[0]; i++) {
const dist = torch.distributions.Multinomial(batch_probs[i], /* total_count= */ 50);
const sample = dist.sample(); // [10], sums to 50
const is_valid = constraint.check(sample); // Always true
}

// Text generation with vocabulary sampling
const vocab_size = 50000;
const probs = torch.ones(vocab_size).div(vocab_size); // Uniform distribution
const constraint = torch.distributions.constraints.multinomial(1000);
// Sample 1000 tokens; the result counts how many times each vocab item was drawn
const token_counts = torch.distributions.Multinomial(probs, /* total_count= */ 1000).sample();
constraint.check(token_counts); // true - counts sum to 1000

// Reinforcement learning: sampling actions from an action space
const num_actions = 6;
const action_logits = torch.randn(num_actions);
const action_probs = action_logits.softmax(-1);
const constraint = torch.distributions.constraints.multinomial(30); // 30 draws
const dist = torch.distributions.Multinomial(action_probs, /* total_count= */ 30);
// Sample which actions were taken across 30 steps
const action_histogram = dist.sample();
constraint.check(action_histogram); // Ensures valid action counts

// Validating constraint properties
const constraint = torch.distributions.constraints.multinomial(100);
// Check constraint properties
console.log(constraint.is_discrete); // true - counts are discrete integers
console.log(constraint.event_dim); // 1 - operates along one dimension
// Test edge cases
const all_zeros = torch.zeros(5);
constraint.check(all_zeros); // false (sum is 0, not 100)
const valid = torch.tensor([20, 20, 20, 20, 20]);
constraint.check(valid); // true (sum is 100)

See Also
- PyTorch torch.distributions.constraints.multinomial()
- torch.distributions.Multinomial - The multinomial distribution using this constraint
- torch.distributions.constraints.independent - Wrapping constraints with extra dimensions
- torch.distributions.constraints.nonnegative_integer - Base integer constraint
- cat - Constraint for concatenated distributions
- stack - Constraint for stacked distributions