torch.Tensor.Tensor.unflatten
Tensor.unflatten(dim: number, sizes: readonly number[]): Tensor<DynamicShape, D, Dev>
Expand a dimension into multiple dimensions.
The inverse of flatten: splits a single dimension into multiple dimensions. Useful for restoring flattened data to its multi-dimensional form or for reshaping feature vectors into spatial structures.
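To make the inverse relationship with flatten concrete, here is a minimal sketch of the index arithmetic involved, assuming row-major (C-contiguous) element order as in PyTorch. `unravelIndex` and `ravelIndex` are hypothetical helpers written for illustration; they are not part of this library's API.

```typescript
// Hypothetical helpers (not part of the library). Assumes row-major order:
// the last dimension varies fastest.

// Unflatten direction: split a position inside the dimension being
// unflattened into coordinates for the new sub-dimensions.
function unravelIndex(index: number, sizes: readonly number[]): number[] {
  const coords = new Array<number>(sizes.length);
  let rest = index;
  for (let i = sizes.length - 1; i >= 0; i--) {
    coords[i] = rest % sizes[i];
    rest = Math.floor(rest / sizes[i]);
  }
  return coords;
}

// Flatten direction: combine sub-dimension coordinates back into one index.
function ravelIndex(coords: readonly number[], sizes: readonly number[]): number {
  let index = 0;
  for (let i = 0; i < sizes.length; i++) {
    index = index * sizes[i] + coords[i];
  }
  return index;
}

// Element 7 of a dimension of size 12, unflattened with sizes [3, 4]:
const coords = unravelIndex(7, [3, 4]); // [1, 3], since 7 = 1 * 4 + 3
const back = ravelIndex(coords, [3, 4]); // 7
console.log(coords, back);
```

Because the two mappings are exact inverses, unflattening with the same sizes that were flattened restores every element to its original position.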
Common use cases:
- Reshape flattened data back to original form
- Convert feature vectors to spatial: (batch, 256) -> (batch, 16, 16)
- Reshape output of operations that flattened their input
- Hierarchical reshaping of multi-dimensional data
Notes:
- Must multiply: The sizes must multiply to the size of the original dimension, or an error is thrown.
- Single dimension: Only one dimension is unflattened per call.
- Negative indexing: dim supports negative indexing (-1 refers to the last dimension).
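The rules above determine the output shape. The following is a minimal sketch of that shape computation, written as a hypothetical helper `unflattenShape` for illustration only; it is not this library's implementation, and the error types and messages are assumptions.

```typescript
// Hypothetical helper (not this library's implementation): computes the
// output shape of unflatten(dim, sizes) for an input of the given shape.
function unflattenShape(
  shape: readonly number[],
  dim: number,
  sizes: readonly number[],
): number[] {
  // Negative indexing: -1 is the last dimension, -2 the one before it, etc.
  const d = dim < 0 ? dim + shape.length : dim;
  if (d < 0 || d >= shape.length) {
    throw new RangeError(`dim ${dim} is out of range for a ${shape.length}-D shape`);
  }
  // The new sizes must multiply to the size of the dimension being split.
  const product = sizes.reduce((acc, s) => acc * s, 1);
  if (product !== shape[d]) {
    throw new Error(`sizes [${sizes.join(", ")}] multiply to ${product}, expected ${shape[d]}`);
  }
  // Only dimension d is replaced; all other dimensions are kept unchanged.
  return [...shape.slice(0, d), ...sizes, ...shape.slice(d + 1)];
}

console.log(unflattenShape([3, 12, 5], 1, [3, 4]));   // [3, 3, 4, 5]
console.log(unflattenShape([32, 256], -1, [16, 16])); // [32, 16, 16]
```

In PyTorch, unflatten behaves like a reshape to exactly this shape; assuming this library follows the same semantics, that is why reshape is the more general alternative.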
Parameters
dim: number - Dimension to unflatten
sizes: readonly number[] - Sizes of the unflattened dimensions. Must satisfy: product(sizes) == original_shape[dim]
Returns
Tensor<DynamicShape, D, Dev> - Tensor with dimension dim split into multiple dimensions
Examples
// Basic unflatten
const x = torch.zeros(3, 12, 5);
const unflat = x.unflatten(1, [3, 4]); // Shape [3, 3, 4, 5]
// Reshape feature vector to spatial
const features = torch.randn(32, 256); // 32 samples, 256 features
const spatial = features.unflatten(1, [16, 16]); // [32, 16, 16]
// Undo flattening from CNN
const batchSize = 8; // example batch size
const flattened = torch.randn(batchSize, 512 * 7 * 7);
const featureMaps = flattened.unflatten(1, [512, 7, 7]); // [batchSize, 512, 7, 7]
// Hierarchical reshape
const v = torch.randn(24);
const reshaped = v.unflatten(0, [2, 3, 4]); // [2, 3, 4]
See Also
- PyTorch tensor.unflatten()
- flatten - Opposite: collapse multiple dimensions into one
- reshape - More general dimension reshaping