torch.distributions.RelaxedOneHotCategorical
class RelaxedOneHotCategorical extends Distribution

new RelaxedOneHotCategorical(temperature: number | Tensor, options: { probs?: number[] | Tensor; logits?: number[] | Tensor } & DistributionOptions)

Properties (all readonly):
- temperature (Tensor) – Temperature parameter controlling the relaxation. Lower temperature = more one-hot-like.
- arg_constraints (unknown)
- support (unknown)
- has_rsample (unknown)
- probs (Tensor) – Probability of each category.
- logits (Tensor) – Log-probability of each category.
- mean (Tensor)
- mode (Tensor)
- variance (Tensor)
RelaxedOneHotCategorical: continuous relaxation of the discrete categorical distribution via the Gumbel-Softmax trick.
Parameterized by temperature T and probabilities/logits over K categories. The key idea is making discrete sampling differentiable: instead of sampling a hard one-hot vector (which blocks gradients), sample a soft probability vector on the simplex by replacing the argmax in the Gumbel-max trick with a temperature-controlled softmax. As T → 0, samples approach one-hot vectors; as T → ∞, samples approach uniform. Crucial for differentiable discrete optimization. Essential for:
- Reparameterized gradient estimation through discrete categorical choices
- Variational inference with discrete latent variables (VIMCO, REBAR, etc.)
- Differentiable discrete sequence learning (learned discrete selections)
- Temperature-controlled annealing from continuous to discrete
- Gumbel-Softmax trick for discrete generative models
- Structured output learning (discrete decisions in neural networks)
- Reinforcement learning policy sampling (continuous approximation)
- Discrete representation learning with soft selection
The Gumbel-Softmax Trick: Instead of sampling a discrete one-hot vector and losing gradients, draw Gumbel(0,1) noise for each category, add it to the log-probabilities, divide by temperature T, and apply softmax. Result: a continuous vector on the simplex that approximates one-hot, with full gradient flow. Temperature controls softness: low T = closer to one-hot, high T = closer to uniform.
Gradient Trick: Hard discrete sampling is non-differentiable. Gumbel-Softmax makes it differentiable by using softmax approximation that becomes sharper as T → 0 during annealing.
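The recipe above can be sketched as a standalone function. This is a plain-array illustration of the sampler's math (an assumption for clarity, not this library's internal implementation, which operates on Tensors):

```typescript
// Draw a standard Gumbel(0, 1) variate via inverse CDF: -log(-log(U)), U ~ Uniform(0, 1).
// The small epsilons guard against log(0).
function sampleGumbel(): number {
  const u = Math.random();
  return -Math.log(-Math.log(u + 1e-20) + 1e-20);
}

// Gumbel-Softmax: perturb each logit with Gumbel noise, divide by temperature, softmax.
function gumbelSoftmax(logits: number[], temperature: number): number[] {
  const perturbed = logits.map((l) => (l + sampleGumbel()) / temperature);
  const maxVal = Math.max(...perturbed); // subtract max for numerical stability
  const exps = perturbed.map((p) => Math.exp(p - maxVal));
  const sum = exps.reduce((a, b) => a + b, 0);
  return exps.map((e) => e / sum); // a point on the simplex
}

// One soft sample over 3 categories at T = 0.5; components are >= 0 and sum to 1.
const sample = gumbelSoftmax([Math.log(0.7), Math.log(0.2), Math.log(0.1)], 0.5);
```

In the library, `rsample()` performs the equivalent computation on Tensors, so gradients flow from the sample back into `logits` and `temperature`.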
- Reparameterization trick: rsample() has gradients (uses Gumbel), sample() may not
- Temperature controls discreteness: Lower T = more one-hot-like, higher T = smoother
- Gumbel-Softmax trick: Adds Gumbel(0,1) noise before softmax to create one-hot approximation
- Simplex support: All samples are valid probability distributions (sum to 1, all ≥ 0)
- Continuous approximation: Provides gradient-friendly approximation to discrete sampling
- Annealing strategy: Often cool temperature during training for better discreteness
- Gradient estimator: One of several methods for discrete variational inference (REBAR, RELAX alternatives)
- Temperature must be positive: T ≤ 0 causes errors or numerical issues
- Very small T: T < 0.01 can cause numerical instability (log_prob underflow)
- Not fully discrete: samples are never truly one-hot (any T > 0 gives soft vectors)
- Approximation error: low T approximates OneHotCategorical, but is never identical
Examples
// Simple Gumbel-Softmax: temperature=0.5, uniform categories
const roc = new torch.distributions.RelaxedOneHotCategorical(0.5, {
probs: torch.tensor([0.25, 0.25, 0.25, 0.25])
});
const sample = roc.sample(); // [~0.35, ~0.20, ~0.25, ~0.20] (soft one-hot)
// Temperature annealing: gradually cool temperature during training
// Start hot (soft) for good gradient flow, gradually cool toward discrete
for (let epoch = 0; epoch < 100; epoch++) {
const T = Math.max(0.1, 1.0 * Math.exp(-0.01 * epoch)); // exponential annealing
const dist = new torch.distributions.RelaxedOneHotCategorical(T, { logits }); // logits from your model
const z = dist.rsample(); // reparameterized sample (has gradients!)
const log_prob = dist.log_prob(z); // for loss computation
// Optimize with gradients flowing through z
}
// Variational autoencoder with discrete latent: 5 discrete choices
const latent_size = 5;
const temperature = 0.67; // moderate relaxation
const logits = encoder(x); // [batch, 5]
const latent_dist = new torch.distributions.RelaxedOneHotCategorical(temperature, { logits });
const z = latent_dist.rsample(); // [batch, 5] soft one-hot (differentiable!)
const recon = decoder(z);
const log_prob = latent_dist.log_prob(z); // for KL divergence term
// Low temperature: closer to discrete
const low_temp = new torch.distributions.RelaxedOneHotCategorical(0.01, {
probs: torch.tensor([0.7, 0.2, 0.1])
});
const hard_sample = low_temp.sample(); // [~0.98, ~0.02, ~0.00]
// High temperature: softer (more uniform-like)
const high_temp = new torch.distributions.RelaxedOneHotCategorical(10, {
probs: torch.tensor([0.7, 0.2, 0.1])
});
const soft_sample = high_temp.sample(); // [~0.40, ~0.35, ~0.25]
// Batched sampling with different temperatures
const temps = torch.tensor([0.1, 0.5, 1.0, 5.0]);
const logits = torch.randn([4, 3]); // [4 batch, 3 categories]
const dist = new torch.distributions.RelaxedOneHotCategorical(temps, { logits });
const samples = dist.rsample(); // [4, 3] shaped soft one-hot vectors
// First row hardest (T=0.1), last row softest (T=5.0)