torch.unflatten
Expands a dimension of the input tensor over multiple dimensions.
Inverse of ravel/flatten: takes a tensor and expands one of its dimensions into multiple dimensions. The size of that dimension must equal the product of the new sizes. Useful for:
- Shape transformation: converting flattened features back to multi-D
- Reconstruction: reshaping model outputs to original image/sequence shapes
- Feature unrolling: expanding a single dimension into spatial dimensions
- Deconvolution prep: reshaping flat vectors into spatial shapes for conv layers
- Batch processing: expanding the batch dimension into sub-batches
- TensorBoard visualization: reshaping features for viewing
As the reverse of ravel(), unflatten takes one dimension of size N and expands it into multiple dimensions whose sizes multiply to N.
- Product constraint: Product of sizes MUST equal input.shape[dim]
- Negative indexing: dim=-1 refers to last dimension, dim=-2 to second-last, etc.
- Order preserved: Elements maintain row-major ordering from input
- Inverse of ravel: expanding with the original sizes undoes a ravel/flatten of those dimensions
- No copy if possible: May return a view instead of a copy when the memory layout allows
- Size mismatch error: An error is raised if the product of sizes does not exactly equal the dimension size
- Dimension bounds: dim must be valid (0 to rank-1 or -rank to -1)
- Shape inference: All sizes must be positive integers
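The shape rules above can be sketched as a pure shape computation that never touches tensor data. This is a minimal TypeScript sketch under stated assumptions, not the library's implementation: `unflattenShape` is a hypothetical helper, and the real function's error types and messages may differ.

```typescript
// Hypothetical helper (not part of the library): computes the output shape of
// unflatten from the input shape, following the documented rules.
function unflattenShape(shape: number[], dim: number, sizes: number[]): number[] {
  const rank = shape.length;

  // Dimension bounds: dim must lie in [-rank, rank - 1]
  if (!Number.isInteger(dim) || dim < -rank || dim >= rank) {
    throw new RangeError(`dim ${dim} is out of range for a rank-${rank} tensor`);
  }

  // Negative indexing: -1 is the last dimension, -2 the second-last, ...
  const d = dim < 0 ? dim + rank : dim;

  // Every target size must be a positive integer
  for (const s of sizes) {
    if (!Number.isInteger(s) || s <= 0) {
      throw new RangeError(`invalid size ${s}: sizes must be positive integers`);
    }
  }

  // Product constraint: the sizes must multiply to exactly shape[d]
  const product = sizes.reduce((acc, s) => acc * s, 1);
  if (product !== shape[d]) {
    throw new RangeError(
      `product of sizes (${product}) does not equal shape[${d}] (${shape[d]})`,
    );
  }

  // Replace the expanded dimension with the new sizes; other dims are unchanged
  return [...shape.slice(0, d), ...sizes, ...shape.slice(d + 1)];
}

const middle = unflattenShape([16, 256, 50], 1, [64, 4]); // [16, 64, 4, 50]
const last = unflattenShape([32, 128], -1, [8, 16]);      // [32, 8, 16]
// unflattenShape([32, 128], 1, [8, 15]) throws, because 8 * 15 = 120, not 128
```

Splicing the sizes into the shape array makes it explicit that only the chosen dimension changes; every dimension before and after it is carried over as-is.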
Parameters
input (Tensor) - The input tensor of any shape
dim (number) - The dimension to expand (supports negative indexing: -1 for the last dimension)
sizes (number[]) - Target shape for the expanded dimension. The product must equal input.shape[dim]
Returns
Tensor – Tensor with dim expanded into multiple dimensions according to sizes
Examples
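The returned tensor holds the same elements in the same row-major order; only its shape (and, for a strided view, its strides) changes. The following library-independent TypeScript sketch models this for a contiguous [2, 12] tensor split into [2, 3, 4]. It assumes a simple strided-memory model with strides counted in elements; `contiguousStrides`, `unflattenStrides`, `offsetOf`, and `rowMajorIndices` are hypothetical helpers, not part of the library, and this is not a description of how the library stores tensors internally.

```typescript
// Hypothetical helpers modelling a strided tensor (strides in elements, row-major).

// Strides of a contiguous row-major tensor: the last dim has stride 1
function contiguousStrides(shape: number[]): number[] {
  const strides = new Array<number>(shape.length);
  let step = 1;
  for (let i = shape.length - 1; i >= 0; i--) {
    strides[i] = step;
    step *= shape[i];
  }
  return strides;
}

// Split dimension `dim` (non-negative) into `sizes`: the innermost new dim keeps
// the old stride, and each outer new dim skips over the dims to its right.
function unflattenStrides(strides: number[], dim: number, sizes: number[]): number[] {
  const split = new Array<number>(sizes.length);
  let step = strides[dim];
  for (let i = sizes.length - 1; i >= 0; i--) {
    split[i] = step;
    step *= sizes[i];
  }
  return [...strides.slice(0, dim), ...split, ...strides.slice(dim + 1)];
}

// Buffer offset of a multi-index under the given strides
function offsetOf(index: number[], strides: number[]): number {
  return index.reduce((acc, i, k) => acc + i * strides[k], 0);
}

// All multi-indices of `shape`, listed in row-major order
function rowMajorIndices(shape: number[]): number[][] {
  const total = shape.reduce((a, s) => a * s, 1);
  const out: number[][] = [];
  for (let flat = 0; flat < total; flat++) {
    const idx = new Array<number>(shape.length);
    let rem = flat;
    for (let k = shape.length - 1; k >= 0; k--) {
      idx[k] = rem % shape[k];
      rem = Math.floor(rem / shape[k]);
    }
    out.push(idx);
  }
  return out;
}

// Contiguous [2, 12] tensor; unflatten dim 1 into [3, 4]
const oldStrides = contiguousStrides([2, 12]);              // [12, 1]
const newShape = [2, 3, 4];
const newStrides = unflattenStrides(oldStrides, 1, [3, 4]); // [12, 4, 1]

// Walking the new view in row-major order reads buffer offsets 0, 1, ..., 23,
// so in this model no element moves and no copy is needed.
const offsets = rowMajorIndices(newShape).map((idx) => offsetOf(idx, newStrides));
```

In this contiguous case the computed strides match `contiguousStrides([2, 3, 4])`, which is why the result can share the input's memory.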
// Inverse of ravel: expand flat tensor back to image shape
const flat = torch.randn(32*3*224*224); // [4816896]
const images = torch.unflatten(flat, 0, [32, 3, 224, 224]); // [32, 3, 224, 224]
// Reshape conv features for subsequent layers
const conv_flat = torch.randn(4*64*14*14); // [50176]
const conv_shaped = torch.unflatten(conv_flat, 0, [4, 64, 14, 14]); // [4, 64, 14, 14]
// Model output processing: split a flat output into per-sample rows
const output = torch.randn(1000); // Flat output
const batched = torch.unflatten(output, 0, [100, 10]); // [100, 10] - 100 samples, 10 features
// Negative indexing: expand last dimension
const vectors = torch.randn(32, 128);
const expanded = torch.unflatten(vectors, -1, [8, 16]); // [32, 8, 16]
// Middle dimension expansion
const sequence = torch.randn(16, 256, 50); // batch=16, hidden=256, time=50
const hidden_split = torch.unflatten(sequence, 1, [64, 4]); // [16, 64, 4, 50]
// Expands hidden=256 into 64*4=256
// Reconstruction from encoder
const encoded = torch.randn(100, 256); // 100 samples, 256-dim encoding
const reconstructed = torch.unflatten(encoded, 1, [16, 16]); // [100, 16, 16]
// Prepare for a conv decoder expecting spatial input
See Also
- PyTorch torch.unflatten()
- ravel - Inverse: flatten to 1D
- flatten - Flatten multiple dimensions
- reshape - Reshape arbitrary number of dimensions
- view - Lower-level shape change