torch.nn.functional.one_hot
Convert class indices to one-hot encoded vectors (categorical representation).
Takes a tensor of class indices and produces a binary matrix where each row has a 1 in the column corresponding to the class index and 0s elsewhere. Essential preprocessing for:
- Converting discrete class labels to dense categorical representation
- Creating ground truth targets for neural networks (especially older loss functions expecting one-hot)
- Data augmentation and manipulation
- Building categorical feature representations for embeddings
- Compatibility with models expecting one-hot encoded inputs
One-Hot Encoding: A categorical representation where each class has exactly one 1 and all other elements are 0. E.g., for 4 classes, index 2 becomes [0, 0, 1, 0]. More memory-intensive than indices but often required by certain architectures or older PyTorch code. Most modern loss functions (cross_entropy) accept indices directly.
Alternative: Most modern code uses indices directly with cross_entropy() instead of one-hot encoding. One-hot is mainly for compatibility or when explicitly required by specific models.
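To make the mapping concrete, here is a minimal pure-Python sketch of what one_hot computes. This is an illustration of the concept only, not the library implementation:

```python
# Illustrative sketch: map each class index to a length-num_classes 0/1 vector.
def one_hot(indices, num_classes):
    """Return a list of one-hot rows, one per input index."""
    vecs = []
    for i in indices:
        if not (0 <= i < num_classes):
            raise ValueError(f"index {i} out of range [0, {num_classes - 1}]")
        vec = [0.0] * num_classes  # all zeros...
        vec[i] = 1.0               # ...except a single 1 at the class index
        vecs.append(vec)
    return vecs

print(one_hot([0, 2, 1], 4))
# [[1.0, 0.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0], [0.0, 1.0, 0.0, 0.0]]
```

Each row has exactly one 1, so every row sums to 1, matching the properties listed below.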
- Output type is float32: Even if the input indices are integers, the output is a float tensor whose elements are 0s and 1s.
- Values are exact 0s and 1s: Not probabilities; exact categorical representation. Use softmax output (probabilities) if you need soft targets.
- Modern code prefers indices: Most new PyTorch code uses indices directly with cross_entropy(). One-hot is mainly for legacy code or specific requirements.
- Output sums to 1: Each one-hot vector (along last dimension) sums to exactly 1. Useful for checking correctness.
- Memory overhead: One-hot requires num_classes times more memory than indices. With many classes (50K+), can be significant. Keep indices when possible.
- Any input shape works: The output appends a new last dimension of size num_classes to the input shape. Useful for batched or multi-dimensional data.
- Sparse representation: For large num_classes with mostly zeros, consider sparse tensors instead.
- Index out of bounds: Indices must be in [0, num_classes-1]; out-of-bounds indices raise an error. Always validate that indices are valid class labels.
- Negative indices unsupported: Unlike NumPy-style indexing, negative indices are not wrapped around; all indices must be non-negative.
- num_classes must be positive: Values ≤ 0 raise an error; num_classes must be a positive integer.
- Large num_classes memory: With num_classes in millions, one-hot encoding becomes impractical. Each one-hot vector has num_classes elements, so [batch=1000, num_classes=1M] = 1B elements.
- Not differentiable: one_hot output has no gradients (discrete operation). Use embedding or learnable parameters if you need differentiable discrete representations.
- Dtype conversion: Input indices converted to int32 internally; ensure indices are integer-like.
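The bounds and memory caveats above can be checked up front before encoding. A hedged pure-Python sketch follows; the helper names are illustrative and not part of the API:

```python
# Illustrative helpers for pre-flight checks before one-hot encoding.
def validate_indices(indices, num_classes):
    """Raise if num_classes or any index would be invalid for one_hot."""
    if num_classes <= 0:
        raise ValueError("num_classes must be a positive integer")
    bad = [i for i in indices if not (0 <= i < num_classes)]
    if bad:
        raise IndexError(f"indices out of range [0, {num_classes - 1}]: {bad}")

def one_hot_bytes(num_indices, num_classes, bytes_per_element=4):
    """Approximate memory for a float32 one-hot output tensor."""
    return num_indices * num_classes * bytes_per_element

validate_indices([0, 2, 1], 4)         # passes silently
print(one_hot_bytes(1000, 1_000_000))  # 4000000000 bytes (~4 GB)
```

The memory estimate shows why indices are preferred for large num_classes: [batch=1000, num_classes=1M] at 4 bytes per element is about 4 GB, versus a few kilobytes for the raw indices.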
Parameters
inputTensor- Tensor of non-negative integer indices (class labels). Each element in [0, num_classes-1]. Shape: any shape is allowed (the output gains a trailing num_classes dimension). Typical: [batch_size] or [batch_size, seq_length].
num_classesnumber- Total number of classes. One-hot vectors will have length num_classes. Must be a positive integer. Each index must be less than num_classes.
Returns
Tensor– One-hot encoded tensor with shape (*input.shape, num_classes). Output dtype: float32 (exact 0s and 1s, not probabilities). Example: input shape [3] → output shape [3, num_classes].
Examples
// Basic one-hot encoding: 3 samples, 4 classes
const labels = torch.tensor([0, 2, 1]); // Class indices for 3 samples
const one_hot_vecs = torch.nn.functional.one_hot(labels, 4);
// Output: [[1, 0, 0, 0], [0, 0, 1, 0], [0, 1, 0, 0]] shape [3, 4]
// Batch one-hot encoding for sequences
const seq_labels = torch.tensor([[1, 2, 0], [3, 1, 2]]); // [batch=2, seq=3] class indices
const one_hot_seqs = torch.nn.functional.one_hot(seq_labels, 4);
// Output shape: [2, 3, 4] - each position one-hot encoded
// Image segmentation: per-pixel class labels to one-hot
const seg_mask = torch.floor(torch.rand(1, 32, 32).mul(3)); // [1, 32, 32] with values in {0,1,2}
const one_hot_seg = torch.nn.functional.one_hot(seg_mask.to('int32'), 3);
// Output: [1, 32, 32, 3] - each pixel becomes a 3-element one-hot vector
// Using with older models that expect one-hot targets
const num_classes = 4;
const logits = model(images); // [batch, num_classes] raw predictions
const class_labels = torch.tensor([0, 2, 1, 3]); // Ground truth class indices
const one_hot_targets = torch.nn.functional.one_hot(class_labels, num_classes);
// Some older loss functions require one_hot_targets; newer ones work with indices directly
// Manual one-hot verification/visualization
const idx = torch.tensor([1, 1, 2, 0]);
const encoded = torch.nn.functional.one_hot(idx, 3); // [4, 3]
// encoded[0] = [0, 1, 0] (index was 1)
// encoded[1] = [0, 1, 0] (index was 1)
// encoded[2] = [0, 0, 1] (index was 2)
// encoded[3] = [1, 0, 0] (index was 0)
See Also
- PyTorch torch.nn.functional.one_hot
- embedding - Learnable lookup table (alternative for categorical data)
- cross_entropy - Modern loss that accepts indices directly (doesn't need one-hot)
- softmax - For soft (probabilistic) targets instead of hard one-hot
- argmax - Inverse: convert one-hot back to indices
- Categorical - Distribution version (probabilistic one-hot)
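The argmax round-trip noted above (one-hot back to indices) can be sketched in plain Python. These helpers are illustrative only, not part of the API:

```python
# Illustrative round-trip: one_hot encodes indices, argmax over the last
# dimension recovers them.
def one_hot(indices, num_classes):
    """Encode each index as a 0/1 row of length num_classes."""
    return [[1 if j == i else 0 for j in range(num_classes)] for i in indices]

def argmax_rows(vectors):
    """Recover the class index (position of the 1) from each one-hot row."""
    return [max(range(len(row)), key=row.__getitem__) for row in vectors]

labels = [1, 1, 2, 0]
encoded = one_hot(labels, 3)
assert argmax_rows(encoded) == labels  # round-trip recovers the original indices
```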