Skip to main content
torch.js has not been released yet.
torch.js logotorch.js logotorch.js
PlaygroundContact
Login
Documentation
IntroductionType SafetyTensor ExpressionsTensor IndexingEinsumEinopsAutogradTraining a ModelProfiling & MemoryPyTorch MigrationBest PracticesRuntimesPerformancePyTorch CompatibilityBenchmarksDType Coverage
ArgConstraintsBernoulliBetaBinomialbroadcast_allCategoricalCauchycdfcheckChi2ClampOptionsclampTensorConstraintDirichletDistributionDistributionOptionsentropyentropyenumerate_supportEnumerateSupportOptionsexpandExpandOptionsExponentialExponentialFamilyextendedShapeFisherSnedecorGammaGeometricgetDeviceFromTensorsgetDTypeFromTensorsgreater_thangreater_than_eqGumbelhalf_open_intervalHalfCauchyHalfNormalicdfindependentIndependentinteger_intervalintervalInverseGammakl_divergenceKumaraswamyLaplacelazy_propertyless_thanlog_probLogisticNormallogits_to_probsLogitsToProbsOptionsLogNormalLowRankMultivariateNormalMixtureSameFamilyMultinomialMultivariateNormalNegativeBinomialNormalOneHotCategoricalParetoperplexityPoissonprobs_to_logitsProbsToLogitsOptionsregister_klRelaxedBernoulliRelaxedOneHotCategoricalrsamplesamplesample_nSampleOptionsset_default_validate_argsstackStackOptionsStudentTStudentTOptionssumRightmosttoStringtoStringtoTensorTransformedDistributionUniformVonMisesWeibullWishart
absacosacoshAdaptivePool1dShapeAdaptivePool2dShapeaddaddbmmAddbmmOptionsaddcdivAddcdivOptionsaddcmulAddcmulOptionsaddmmAddmmOptionsaddmvAddmvOptionsaddrAddrOptionsadjointallallcloseAllcloseOptionsAlphaBetaOptionsamaxaminaminmaxAminmaxOptionsangleanyapplyOutarangeare_deterministic_algorithms_enabledargmaxargminargsortargwhereas_stridedas_tensorasinasinhAssertNoShapeErrorAssertNotErrorAsStridedOptionsAtat_error_index_out_of_boundsatanatan2atanhatleast_1datleast_2datleast_3dAtShapeautocast_decrement_nestingautocast_increment_nestingautograd_gradient_mismatch_errorautograd_not_registered_errorAutogradConfigAutogradDeviceAutogradDTypeAutogradEntryAutogradHandleAutogradHandleImplAxesRecordBackwardFnbaddbmmBaddbmmOptionsbartlett_windowBaseKernelConfigbatch_dimensions_do_not_match_errorbernoulliBernoulliOptionsBinaryBackwardFnBinaryBroadcastResultBinaryDTypeBinaryKernelConfigCPUBinaryKernelCPUBinaryOpConfigBinaryOpNamesBinaryOpSchemaBinaryOptionsbincountBincountOptionsbitwise_andbitwise_left_shiftbitwise_notbitwise_orbitwise_right_shiftbitwise_xorblackman_windowblock_diagbmmBooleanDTypeRulebroadcast_error_incompatible_dimensionsbroadcast_shapesbroadcast_tensorsbroadcast_toBroadcastShapeBroadcastShapeRulebroadcastShapesbucketizeBucketizeOptionsBufferUsagebuildEinopsErrorbuildErrorMessagecanBroadcastTocartesian_prodcatCatOptionsCatShapeCauchyOptionscdistCdistOptionsceilceluCeluFunctionalOptionschain_matmulCheckShapeErrorCholeskyShapechunkchunk_error_dim_out_of_rangeChunkOptionsclampClampOptionsclear_autocast_cacheclearEinopsCacheclearEinsumCacheclonecolumn_stackcombinationsCombinationsOptionscompiled_with_cxx11_abicomplexconjconj_physicalcontiguousConv1dShapeConv2dShapeConv3dShapeConvTranspose2dShapecopysigncorrcoefcoscoshcount_nonzeroCountNonzeroOptionscovcoverage_reportcoverageReportCoverageReportCovOptionsCPUForwardFnCPUKernelConfigCPUKernelEntryCPUOnlyResultCPUTensorDatacreateCumExtremeResultcreateTorchCreationOpSchemaCumExtremeResultcummaxcummincumprodCumShapecumsumcumul
ative_trapezoidCumulativeOptionsCumulativeOptionsWithDimdeg2raddetachDeterministicOptionsDetShapeDevicedevice_error_requiresDeviceBufferDeviceCapabilitiesDeviceCheckedResultDeviceConfigDeviceContextDeviceEntryDeviceHandleDeviceInputDeviceOptionsDeviceRegistryDeviceTypediagdiag_embedDiagEmbedOptionsdiagflatDiagflatOptionsDiagFlatOptionsdiagonal_scatterDiagonalOptionsDiagonalScatterOptionsDiagOptionsDiagShapediffDiffOptionsdigammadimension_error_out_of_rangeDispatchConfigdistDistOptionsdivdotDotShapeRuleDoubleDoubleDimdropoutDropoutFunctionalOptionsdsplitdstackDTypedtype_already_registered_errordtype_components_mismatch_errordtype_not_found_errorDTypeComponentsDTypeConfigDTypeCoverageReportDTypeDisplayConfigDTypeEntryDTypeHandleDTypeHandleImplDTypeInfoDTypeRegistryDTypeRuleDTypeSerializationConfigDynamicShapeEigShapeeinops_error_ambiguous_decompositioneinops_error_anonymous_in_outputeinops_error_dimension_mismatcheinops_error_invalid_patterneinops_error_reduce_undefined_outputeinops_error_repeat_missing_sizeeinops_error_undefined_axiseinsumeinsum_error_dimension_mismatcheinsum_error_index_out_of_rangeeinsum_error_invalid_equationeinsum_error_invalid_sublist_elementeinsum_error_operand_count_mismatcheinsum_error_subscript_rank_mismatcheinsum_error_unknown_output_indexEinsumOptionsEinsumOutputShapeEllipsiseluelu_EluFunctionalOptionsembedding_bag_error_requires_2d_inputemptyempty_cacheempty_likeeqequalerferfcerfinvexpexp2expandexpand_asexpand_error_incompatibleExpandShapeexpm1ExponentialOptionseyeEyeOptionsfftFFTOptionsfindKernelWithPredicatefindSimilarPatternsflattenFlattenOptionsFlattenShapeflipflip_error_dim_out_of_rangefliplrFlipShapeflipudfloat_powerFloatDTypeRulefloorfloor_dividefmaxfminfmodformatEquationErrorformatShapefracfrexpfrombufferfullfull_likefunction_already_registered_errorFunctionConfigFunctionEntryFunctionHandlegathergather_error_dim_out_of_rangeGatherShapegcdgegeluGeometricOptionsget_autocast_cpu_dtypeget_autocast_gpu_dtypeget_autocast_ipu_dtypeget_au
tocast_xla_dtypeget_default_deviceget_default_dtypeget_deterministic_debug_modeget_device_configget_device_contextget_device_moduleget_dtype_infoget_file_pathget_float32_matmul_precisionget_num_interop_threadsget_num_threadsget_op_infoget_printoptionsget_real_dtypeget_rng_stategetAutogradgetDTypegetEinopsCacheSizegetEinsumCacheSizegetFunctiongetKernelgetMethodgetOpInfoGetOpKindGetOpSchemagetScalarKernelgluGluFunctionalOptionsGradContextGradFnGradientsForgtHalfHalfDimhamming_windowhann_windowhardshrinkhardsigmoidhardswishhardtanhhardtanh_HardtanhFunctionalOptionshas_autogradhas_devicehas_dtypehas_kernelhasAutogradhasDTypehasFunctionhasKernelhasMethodhasScalarKernelHasShapeErrorheavisidehistcHistcOptionshistogramHistogramOptionsHistogramResulthsplithstackhypoti0IdentityShapeifftimagindex_addindex_copyindex_fillindex_putindex_reduceindex_selectindex_select_error_dim_out_of_rangeIndexPutOptionsIndexSelectShapeIndexSpecIndicesOptionsIndicesSpecinitialize_deviceInputsForInsertDiminvalid_config_errorinverseInverseShapeirfftis_anomaly_check_nan_enabledis_anomaly_enabledis_autocast_cache_enabledis_autocast_cpu_enabledis_autocast_ipu_enabledis_autocast_xla_enabledis_complexis_complex_dtypeis_cpu_only_modeis_deterministic_algorithms_warn_only_enabledis_floating_pointis_floating_point_dtypeis_inference_mode_enabledis_nonzerois_tensoris_warn_always_enabledis_webgpu_availableIs2DIsAtLeast1DIsBinaryOpIsBinaryOpNameiscloseIscloseOptionsisfiniteisinisinfisnanisneginfisposinfisrealIsReductionOpIsReductionOpNameIsRegistryErrorIsShapeErroristftISTFTOptionsIsUnaryOpIsUnaryOpNameitem_error_not_scalarItemResultkaiser_windowKaiserWindowOptionskernel_not_registered_errorkernel_signature_mismatch_errorKernelConfigKernelConfigWebGPUKernelEntryKernelHandleKernelInfoKernelPredicateKernelRegistryKernelWebGPUkronkthvalueKthvalueOptionslcmldexpleleaky_reluleaky_relu_LeakyReluFunctionalOptionslerplevenshteinDistancelgammalinalg_error_not_square_matrixlinalg_error_requires_2dlinalg_error_requires_at
_least_2dlinearlinspacelist_custom_deviceslist_custom_dtypeslist_deviceslist_dtypeslist_functionslist_kernelslist_methodslist_opslistCustomDTypeslistDTypeslistFunctionslistKernelsListKernelsOptionslistMethodslistOpsListOpsOptionsloglog_softmaxlog10log1plog2logaddexplogaddexp2logcumsumexplogical_andlogical_notlogical_orlogical_xorLogitOptionsLogNormalOptionsLogOptionslogsigmoidlogspacelogsumexpLogsumexpOptionsltLUShapeLuSolveOptionsmasked_fillmasked_selectmasked_select_asyncMaskSpecmatmulmatmul_error_inner_dimensions_do_not_matchMatmul2DShapeMatmulShapeMatmulShapeRuleMatrixTransposeShapemaxmaximummeanmedianmemory_statsmemory_summarymeshgridmethod_already_registered_errormethod_dtype_not_supported_errorMethodConfigMethodEntryMethodHandleminminimummishmmMMShapeRulemodemovedimmsortmulmultinomialmultinomial_asyncMultinomialAsyncOptionsMultinomialOptionsMultiplyBymvMVShapeRulenan_to_numnanmeannanmediannanquantileNanReductionOptionsnansumNanToNumOptionsnarrownarrow_copynarrow_error_length_exceeds_boundsnarrow_error_start_out_of_boundsNarrowShapeneneedsBroadcastnegNegativeDimnextafternonzeroNonzeroOptionsnormnormalNormalOptionsNormOptionsnumelonesones_likeop_kind_mismatch_errorop_not_found_errorOpCoverageEntryOpInfoOpKindOpNameOpSchemaOpSchemasouterOuterShapepackPackShapepermutepermute_error_dimension_count_mismatchPermuteShapepoissonpolarPool1dShapePool2dShapePool3dShapepositivepowpreluPrintOptionsprodprofiler_allow_cudagraph_cupti_lazy_reinit_cuda12promote_typesPromoteDTypeRulePutOptionsquantileQuantileOptionsrad2degrandrand_likerandintrandint_likeRandintLikeOptionsRandintOptionsrandnrandn_likeRandomLikeOptionsRandomOptionsrandpermRangeSpecRankravelrealrearrangeRearrangeOptionsRearrangeShapereciprocalreduceReduceOperationReduceOptionsReduceShapeReductionKernelConfigCPUReductionKernelCPUReductionOpNamesReductionOpSchemaReductionOptionsReductionShapeRuleregister_backwardregister_deviceregister_dtyperegister_forwardregister_functionregister_methodregister_scalar_forwardregis
terAutogradRegisterBackwardOptionsregisterBinaryOpregisterDTypeRegisterDTypeOptionsRegisteredDTyperegisterFunctionRegisterFunctionOptionsregisterKernelRegisterKernelOptionsregisterMethodRegisterMethodOptionsregisterScalarKernelregisterUnaryOpregistration_failed_errorrelurelu_relu6ReluFunctionalOptionsremainderRemoveDimrepeatrepeat_interleaveRepeatInterleaveOptionsRepeatOptionsRepeatShapeReplaceDimrequireWebGPUreset_peak_memory_statsreshapeReshapeShaperesult_typerfftrollRollOptionsrot90Rot90Optionsroundrrelurrelu_RreluFunctionalOptionsrsqrtSafeExpandShapeSameDTypeRuleSameShapeRuleSaveForBackwardScalarCPUForwardFnScalarCPUKernelConfigScalarKernelEntryScalarKernelHandleScalarWebGPUKernelConfigScaleDimscatterscatter_addscatter_add_scatter_error_dim_out_of_rangescatter_reducescatter_reduce_ScatterReduceOptionsScatterShapesearchsortedSearchSortedOptionsselectselect_error_index_out_of_boundsselect_scatterSelectShapeseluset_default_deviceset_default_tensor_typeset_deterministic_debug_modeset_float32_matmul_precisionset_printoptionsset_warn_alwaysSetupContextFnShapeShapeCheckedResultShapedTensorShapeErrorMessageShapeOpSchemaShapeRulesigmoidsignsignbitsilusinsincsinhSizeOptionsslice_error_out_of_boundsslice_scatterSliceOptionsSliceScatterOptionsSliceShapeSliceSpecsoftmaxsoftmax_error_dim_out_of_rangeSoftmaxShapesoftminSoftminFunctionalOptionssoftplusSoftplusFunctionalOptionssoftshrinksoftsignsortSortOptionssplitsplit_error_dim_out_of_rangeSplitOptionssqrtsquaresqueezeSqueezeOptionsSqueezeShapestackStackOptionsStackShapestdstd_meanStdVarMeanOptionsStdVarOptionsstftSTFTOptionsStrideOptionssubSublistSublistElementSubscriptIndexsumSVDShapeswapaxessym_floatsym_intsym_notttaketake_along_dimTakeAlongDimOptionstantanhtanhshrinktensortensor_splitTensorCreatorTensorDatatensordotTensordotOptionsTensorLikeTensorMetaTensorOptionsTensorStoragethresholdthreshold_tileTileShapeToOptionstopkTopkOptionsTorchtraceTraceShapetransposetranspose_dims_error_out_of_rangetranspose_error_requires_2d_ten
sorTransposeDimsShapeTransposeDimsShapeCheckedTransposeShapetrapezoidTrapezoidOptionsTriangularOptionstriltril_indicesTriOptionsTripletriutriu_indicestrue_dividetruncTupleOfLengthTypedArrayTypedArrayForTypedStorageTypeOptionsUnaryBackwardFnUnaryDTypeUnaryKernelConfigCPUUnaryKernelCPUUnaryOpConfigUnaryOpFnUnaryOpNamesUnaryOpParamsUnaryOpSchemaUnaryOptionsunbindunbind_error_dim_out_of_rangeUnbindOptionsunflattenUniformOptionsuniqueunique_consecutiveUniqueConsecutiveOptionsUniqueOptionsunpackUnpackShapeunravel_indexunregister_deviceunsqueezeUnsqueezeOptionsUnsqueezeShapeuse_deterministic_algorithmsValidateBatchedSquareMatrixValidateChunkDimValidatedEinsumShapevalidateDeviceValidateDeviceValidatedRearrangeShapeValidatedReduceShapeValidatedRepeatShapevalidateDTypeValidateEinsumValidateOperandCountValidateRanksValidateScalarValidateSplitDimValidateSquareMatrixValidateUnbindDimValueOptionsvar_var_meanvdotviewview_as_complexview_as_realvmapvsplitvstackWebGPUKernelConfigWebGPUOnlyResultWebGPUTensorDatawhereWindowOptionsxlogyzeroszeros_like
torch.js· 2026
LegalTerms of UsePrivacy Policy
/
/
  1. docs
  2. torch.js
  3. torch
  4. distributions
  5. RelaxedOneHotCategorical

torch.distributions.RelaxedOneHotCategorical

class RelaxedOneHotCategorical extends Distribution
new RelaxedOneHotCategorical(temperature: number | Tensor, options: { probs?: number[] | Tensor; logits?: number[] | Tensor } & DistributionOptions)
readonly temperature (Tensor)
– Temperature parameter controlling the relaxation. Lower temperature = more one-hot-like.
readonly arg_constraints (unknown)
readonly support (unknown)
readonly has_rsample (unknown)
readonly probs (Tensor)
– Get the probability of each category.
readonly logits (Tensor)
– Get the log-probabilities of each category.
readonly mean (Tensor)
readonly mode (Tensor)
readonly variance (Tensor)

RelaxedOneHotCategorical: continuous relaxation of discrete categorical via Gumbel-Softmax trick.

Parameterized by temperature T and probabilities/logits over K categories. A clever "trick" for making discrete sampling differentiable: instead of sampling a hard one-hot vector, sample a soft probability vector (on the simplex) by replacing the argmax of the Gumbel-max trick with a temperature-controlled softmax (the Gumbel-Softmax trick). As T → 0, samples approach one-hot vectors; as T → ∞, samples approach the uniform vector (1/K, …, 1/K). Crucial for differentiable discrete optimization. Essential for:

  • Reparameterized gradient estimation through discrete categorical choices
  • Variational inference with discrete latent variables (VIMCO, REBAR, etc.)
  • Differentiable discrete sequence learning (learned discrete selections)
  • Temperature-controlled annealing from continuous to discrete
  • Gumbel-Softmax trick for discrete generative models
  • Structured output learning (discrete decisions in neural networks)
  • Reinforcement learning policy sampling (continuous approximation)
  • Discrete representation learning with soft selection

The Gumbel-Softmax Trick: Instead of sampling a discrete one-hot vector and losing gradients, draw i.i.d. Gumbel(0, 1) noise, add it to the log-probabilities, divide by the temperature T, and apply softmax. Result: a continuous vector on the simplex that approximates a one-hot vector, with full gradient flow. Temperature controls softness: low T = closer to one-hot, high T = closer to the uniform vector (1/K, …, 1/K).

Gradient Trick: Hard discrete sampling is non-differentiable. Gumbel-Softmax makes it differentiable by using softmax approximation that becomes sharper as T → 0 during annealing.

$$
\begin{aligned}
\text{Sampling: } & X_k = \frac{\exp\left(\frac{\log p_k + G_k}{T}\right)}{\sum_{j=1}^{K} \exp\left(\frac{\log p_j + G_j}{T}\right)} \\
& \text{where } G_k \overset{\text{i.i.d.}}{\sim} \text{Gumbel}(0, 1),\ T > 0 \text{ is the temperature},\ K \text{ is the number of categories} \\
\text{Support: } & \text{the probability simplex } \Delta^{K-1} = \{x \in \mathbb{R}^K : x_k \geq 0,\ \textstyle\sum_k x_k = 1\} \\
\text{Limit as } T \to 0^+: & \quad \text{RelaxedOneHotCategorical}(T, p) \to \text{OneHotCategorical}(p) \\
\text{Limit as } T \to \infty: & \quad \text{RelaxedOneHotCategorical}(T, p) \to \delta_{(1/K, \ldots, 1/K)} \ \text{(point mass at the simplex centroid)}
\end{aligned}
$$
  • Reparameterization trick: rsample() has gradients (uses Gumbel), sample() may not
  • Temperature controls discreteness: Lower T = more one-hot-like, higher T = smoother
  • Gumbel-Softmax trick: Adds Gumbel(0,1) noise before softmax to create one-hot approximation
  • Simplex support: All samples are valid probability distributions (sum to 1, all ≥ 0)
  • Continuous approximation: Provides gradient-friendly approximation to discrete sampling
  • Annealing strategy: Often cool temperature during training for better discreteness
  • Gradient estimator: One of several approaches to discrete variational inference; alternatives include REBAR and RELAX, which combine score-function (REINFORCE) estimators with relaxation-based control variates
  • Temperature must be positive: T ≤ 0 causes errors or numerical issues
  • Very small T: T < 0.01 can cause numerical instability (log_prob underflow)
  • Not fully discrete: Samples are never truly one-hot (any T > 0 gives soft vectors)
  • Approximation error: Low T approximates OneHotCategorical, but never identical

Examples

// Simple Gumbel-Softmax: temperature=0.5, uniform categories
const roc = new torch.distributions.RelaxedOneHotCategorical(0.5, {
  probs: torch.tensor([0.25, 0.25, 0.25, 0.25])
});
const sample = roc.sample();  // e.g. [~0.35, ~0.20, ~0.25, ~0.20] (random soft one-hot; sums to 1)

// Temperature annealing: gradually cool temperature during training
// Start hot (soft) for good gradient flow, gradually cool toward discrete
const class_logits = torch.randn([5]);  // stand-in for your learnable logits over 5 categories
for (let epoch = 0; epoch < 100; epoch++) {
  // T0 = 1.0 with exponential decay, never allowed below 0.1
  const T = Math.max(0.1, 1.0 * Math.exp(-0.01 * epoch));
  const dist = new torch.distributions.RelaxedOneHotCategorical(T, { logits: class_logits });
  const z = dist.rsample();  // reparameterized sample (has gradients!)
  const log_prob = dist.log_prob(z);  // for loss computation
  // Optimize with gradients flowing through z
}

// Variational autoencoder with discrete latent: 5 discrete choices
// (`encoder`, `decoder` and the input `x` stand for your own model pieces)
const temperature = 0.67;  // moderate relaxation
const logits = encoder(x);  // [batch, 5]
const latent_dist = new torch.distributions.RelaxedOneHotCategorical(temperature, { logits });
const z = latent_dist.rsample();  // [batch, 5] soft one-hot (differentiable!)
const recon = decoder(z);
const log_prob = latent_dist.log_prob(z);  // for KL divergence term

// Low temperature: closer to discrete
const low_temp = new torch.distributions.RelaxedOneHotCategorical(0.01, {
  probs: torch.tensor([0.7, 0.2, 0.1])
});
// Nearly one-hot, e.g. [~1.00, ~0.00, ~0.00]; which category "wins" is random
// (drawn with probabilities 0.7 / 0.2 / 0.1)
const hard_sample = low_temp.sample();

// High temperature: softer (close to the uniform vector [1/3, 1/3, 1/3])
const high_temp = new torch.distributions.RelaxedOneHotCategorical(10, {
  probs: torch.tensor([0.7, 0.2, 0.1])
});
const soft_sample = high_temp.sample();  // e.g. [~0.36, ~0.33, ~0.31]

// Batched sampling with different temperatures (one per row)
// Shape [4, 1] so each row's temperature broadcasts across its 3 categories
const temps = torch.tensor([[0.1], [0.5], [1.0], [5.0]]);
const batch_logits = torch.randn([4, 3]);  // [4 batch, 3 categories]
const dist = new torch.distributions.RelaxedOneHotCategorical(temps, { logits: batch_logits });
const samples = dist.rsample();  // [4, 3] shaped soft one-hot vectors
// First row hardest (T=0.1), last row softest (T=5.0)
Previous
RelaxedBernoulli
Next
SampleOptions