torch.distributions.Distribution
class Distribution
new Distribution(batch_shape: readonly number[], event_shape: readonly number[], options: DistributionOptions = {})
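The two shape arguments combine with the sample_shape passed at sampling time. The sketch below assumes this port follows PyTorch's convention: a draw has shape sample_shape + batch_shape + event_shape, and log_prob() drops the event dimensions. The helper names are hypothetical and only do the shape arithmetic.

```typescript
// Shape bookkeeping sketch (illustrative, not this library's implementation).
// Assumes the PyTorch convention: a draw has shape
// sample_shape ++ batch_shape ++ event_shape, and log_prob
// returns one value per event, i.e. sample_shape ++ batch_shape.

type Shape = readonly number[];

// Shape of the tensor returned by sample({ sample_shape }).
function sampleResultShape(sampleShape: Shape, batchShape: Shape, eventShape: Shape): number[] {
  return [...sampleShape, ...batchShape, ...eventShape];
}

// Shape of log_prob(x) when x has the shape produced above.
function logProbResultShape(sampleShape: Shape, batchShape: Shape): number[] {
  return [...sampleShape, ...batchShape];
}

// Normal([0, 1], [1, 2]): two independent scalar normals, batch_shape [2].
console.log(sampleResultShape([1000], [2], [])); // → [1000, 2]
// MultivariateNormal of dimension 3, unbatched: event_shape [3].
console.log(sampleResultShape([1000], [], [3])); // → [1000, 3]
console.log(logProbResultShape([1000], []));     // → [1000]
```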
- readonly batch_shape (readonly number[]): The batch shape over which parameters are batched. For example, Normal([0, 1], [1, 2]) has batch_shape=[2].
- readonly event_shape (readonly number[]): The shape of a single sample, excluding batch dimensions. For univariate distributions this is []; for a MultivariateNormal of dimension d it is [d].
- readonly validate_args (boolean): Whether to validate arguments.
- readonly arg_constraints (ArgConstraints): Maps each argument name to the constraint it must satisfy. Subclasses should override this.
- readonly support (Constraint): Constraint describing the support of this distribution. Subclasses should override this.
- readonly has_rsample (boolean): Whether this distribution supports reparameterized sampling. If true, rsample() can be used for differentiable sampling.
- readonly has_enumerate_support (boolean): Whether this distribution can enumerate its support. If true, enumerate_support() can be used.
- readonly mean (Tensor): Mean of the distribution.
- readonly mode (Tensor): Mode of the distribution.
- readonly variance (Tensor): Variance of the distribution.
- readonly stddev (Tensor): Standard deviation of the distribution.
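arg_constraints, support and validate_args are meant to work together. The sketch below assumes PyTorch-style behavior: when validation is on, parameters are checked against arg_constraints at construction time and values passed to log_prob() are checked against support. Constraint, real, positive, SketchDistribution and ValidatedNormal are hypothetical names, not exports of this library, and plain numbers stand in for tensors.

```typescript
// Hypothetical sketch of how arg_constraints, support and validate_args
// could interact. None of these names are this library's exports.

interface Constraint {
  name: string;
  check(value: number): boolean;
}

const real: Constraint = { name: "real", check: (v) => Number.isFinite(v) };
const positive: Constraint = { name: "positive", check: (v) => Number.isFinite(v) && v > 0 };

abstract class SketchDistribution {
  abstract readonly arg_constraints: Record<string, Constraint>;
  abstract readonly support: Constraint;

  constructor(readonly validate_args: boolean) {}

  // Throws if any parameter violates its declared constraint.
  protected validateParams(params: Record<string, number>): void {
    if (!this.validate_args) return;
    for (const [name, constraint] of Object.entries(this.arg_constraints)) {
      if (!constraint.check(params[name])) {
        throw new RangeError(`Expected parameter ${name} to satisfy ${constraint.name}, got ${params[name]}`);
      }
    }
  }

  // Throws if a value passed to log_prob lies outside the support.
  protected validateSample(value: number): void {
    if (this.validate_args && !this.support.check(value)) {
      throw new RangeError(`Value ${value} is outside the support (${this.support.name})`);
    }
  }
}

class ValidatedNormal extends SketchDistribution {
  // Subclasses "override" the constraint tables, as the property docs ask.
  readonly arg_constraints = { loc: real, scale: positive };
  readonly support = real;

  constructor(readonly loc: number, readonly scale: number, validate_args = true) {
    super(validate_args);
    this.validateParams({ loc, scale });
  }

  log_prob(x: number): number {
    this.validateSample(x);
    const z = (x - this.loc) / this.scale;
    return -0.5 * z * z - Math.log(this.scale) - 0.5 * Math.log(2 * Math.PI);
  }
}

new ValidatedNormal(0, 1); // passes validation
try {
  new ValidatedNormal(0, -1);
} catch (e) {
  console.log((e as Error).message); // Expected parameter scale to satisfy positive, got -1
}
new ValidatedNormal(0, -1, false); // validation skipped, no error
```

Validating once in the constructor surfaces bad parameters early, while the validate_args switch lets performance-sensitive code skip the per-call checks.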
Abstract base class for probability distributions.
All distributions support:
- sample() for generating random samples
- log_prob() for computing log probabilities
- mean, variance, and mode properties
- entropy() for computing the distribution's entropy
- cdf() and icdf() for the cumulative distribution function and its inverse
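To make this interface concrete, here is a minimal univariate normal in plain TypeScript, with numbers standing in for tensors. It is a hypothetical sketch, not this library's Normal: SketchNormal and erf are illustrative names, sample() uses the Box-Muller transform, cdf() relies on the Abramowitz and Stegun 7.1.26 approximation to erf (absolute error below 1.5e-7) because JavaScript's Math has no erf, and icdf() inverts cdf() by bisection rather than with a closed form.

```typescript
// Minimal univariate normal illustrating the common Distribution interface.
// Hypothetical sketch: plain numbers stand in for tensors.

// Abramowitz & Stegun 7.1.26 approximation of erf (|error| < 1.5e-7).
function erf(x: number): number {
  const sign = x < 0 ? -1 : 1;
  const ax = Math.abs(x);
  const t = 1 / (1 + 0.3275911 * ax);
  const poly =
    t * (0.254829592 + t * (-0.284496736 + t * (1.421413741 + t * (-1.453152027 + t * 1.061405429))));
  return sign * (1 - poly * Math.exp(-ax * ax));
}

class SketchNormal {
  constructor(readonly loc: number, readonly scale: number) {}

  get mean(): number { return this.loc; }
  get mode(): number { return this.loc; }
  get variance(): number { return this.scale * this.scale; }
  get stddev(): number { return this.scale; }

  // Box-Muller transform; `rng` must return uniforms in [0, 1).
  sample(n: number, rng: () => number = Math.random): number[] {
    const out: number[] = [];
    for (let i = 0; i < n; i++) {
      const u1 = 1 - rng(); // in (0, 1], so log(u1) is finite
      const u2 = rng();
      const z = Math.sqrt(-2 * Math.log(u1)) * Math.cos(2 * Math.PI * u2);
      out.push(this.loc + this.scale * z);
    }
    return out;
  }

  log_prob(x: number): number {
    const z = (x - this.loc) / this.scale;
    return -0.5 * z * z - Math.log(this.scale) - 0.5 * Math.log(2 * Math.PI);
  }

  // Differential entropy: 0.5 * log(2 * pi * e * sigma^2).
  entropy(): number {
    return 0.5 * Math.log(2 * Math.PI * Math.E * this.variance);
  }

  cdf(x: number): number {
    return 0.5 * (1 + erf((x - this.loc) / (this.scale * Math.SQRT2)));
  }

  // Inverse CDF by bisection on cdf(); fine for a sketch, slow in practice.
  icdf(p: number): number {
    if (!(p > 0 && p < 1)) throw new RangeError("p must lie in (0, 1)");
    let lo = this.loc - 40 * this.scale;
    let hi = this.loc + 40 * this.scale;
    for (let i = 0; i < 100; i++) {
      const mid = 0.5 * (lo + hi);
      if (this.cdf(mid) < p) lo = mid; else hi = mid;
    }
    return 0.5 * (lo + hi);
  }
}

const std = new SketchNormal(0, 1);
console.log(std.log_prob(0)); // ≈ -0.9189 (that is, -0.5 * log(2π))
console.log(std.entropy());   // ≈ 1.4189
```

Passing a seeded generator as `rng` makes sample() reproducible, which is convenient for tests.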
Reparameterized distributions also support:
- rsample() for differentiable sampling (gradients flow through the sample back to the distribution's parameters)
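Reparameterized sampling is what lets gradients reach the parameters: instead of drawing x directly, a reparameterized normal draws parameter-free noise eps ~ N(0, 1) and returns loc + scale * eps, so dx/dloc = 1 and dx/dscale = eps. The sketch below works that pathwise gradient out by hand for E[x^2], without autograd. It assumes the standard location-scale form for a normal (other distributions reparameterize differently), and mulberry32, standardNormal and reparamGradSquare are hypothetical helpers, not part of this library.

```typescript
// Reparameterization trick for Normal(loc, scale), worked by hand.
// Illustrative sketch only; this is not the library's rsample().

// Small seeded PRNG (mulberry32) so the estimate is reproducible.
function mulberry32(seed: number): () => number {
  let a = seed >>> 0;
  return () => {
    a = (a + 0x6d2b79f5) >>> 0;
    let t = a;
    t = Math.imul(t ^ (t >>> 15), t | 1);
    t ^= t + Math.imul(t ^ (t >>> 7), t | 61);
    return ((t ^ (t >>> 14)) >>> 0) / 4294967296; // uniform in [0, 1)
  };
}

// One standard normal draw via Box-Muller.
function standardNormal(rng: () => number): number {
  const u1 = 1 - rng(); // in (0, 1], so log(u1) is finite
  const u2 = rng();
  return Math.sqrt(-2 * Math.log(u1)) * Math.cos(2 * Math.PI * u2);
}

// Monte Carlo estimate of d/dloc and d/dscale of E[x^2], x ~ N(loc, scale).
// Exact values: E[x^2] = loc^2 + scale^2, so the gradients are 2*loc and 2*scale.
function reparamGradSquare(
  loc: number,
  scale: number,
  n: number,
  seed = 42,
): { dLoc: number; dScale: number } {
  const rng = mulberry32(seed);
  let dLoc = 0;
  let dScale = 0;
  for (let i = 0; i < n; i++) {
    const eps = standardNormal(rng);
    const x = loc + scale * eps; // the reparameterized sample
    const dfdx = 2 * x;          // derivative of f(x) = x^2
    dLoc += dfdx * 1;            // chain rule: dx/dloc = 1
    dScale += dfdx * eps;        // chain rule: dx/dscale = eps
  }
  return { dLoc: dLoc / n, dScale: dScale / n };
}

// The estimates land near the exact gradients 3 and 1.
console.log(reparamGradSquare(1.5, 0.5, 100000));
```

A non-reparameterized sample() gives no path from x back to loc and scale, which is why has_rsample matters when a loss depends on samples.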
Examples
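```typescript
// Create a standard normal distribution (loc = 0, scale = 1)
const normal = new torch.distributions.Normal(0, 1);

// Draw 1000 samples; batch_shape and event_shape are both [], so the result has shape [1000]
const samples = normal.sample({ sample_shape: [1000] });

// Compute the log probability of each value; the result has shape [3]
const x = torch.tensor([0.5, 1.0, -0.5]);
const log_probs = normal.log_prob(x);

// Access properties (each is a Tensor)
normal.mean;     // 0
normal.variance; // 1
```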