torch.Tensor.Tensor.flatten
Tensor.flatten(): Tensor<readonly [number], D, Dev>
Tensor.flatten(startDimOrOptions?: number | FlattenOptions, endDimOrOptions?: number | FlattenOptions, options?: FlattenOptions): Tensor<DynamicShape, D, Dev>
Flatten dimensions of a tensor into a single dimension.
Collapses dimensions startDim through endDim (inclusive) into a single dimension whose size is the product of their sizes. If no parameters are given, flattens all dimensions into a 1D tensor. Commonly used before fully connected layers to convert multi-dimensional features into vectors.
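The shape rule can be sketched in plain TypeScript, independent of the library: the sizes of dimensions startDim through endDim are multiplied together and every other dimension is kept. `flattenShape` is a hypothetical helper. The defaults (startDim = 0, endDim = -1), the error on startDim > endDim, and the 0-d case follow PyTorch's conventions, which this library is assumed to share.

```typescript
// Hypothetical helper: output shape of flatten(startDim, endDim) for a given
// input shape. A sketch of the shape rule, not the library's implementation.
function flattenShape(
  shape: readonly number[],
  startDim = 0,
  endDim = -1,
): number[] {
  const ndim = shape.length;
  // A 0-d (scalar) tensor flattens to a 1-D tensor with one element (PyTorch behavior).
  if (ndim === 0) return [1];
  // Normalize negative indices: -1 refers to the last dimension.
  const start = startDim < 0 ? startDim + ndim : startDim;
  const end = endDim < 0 ? endDim + ndim : endDim;
  if (start < 0 || start >= ndim || end < 0 || end >= ndim) {
    throw new RangeError(`dimension out of range for a ${ndim}-d shape`);
  }
  if (start > end) {
    throw new RangeError("startDim must not come after endDim");
  }
  // Multiply the sizes of the collapsed dimensions start..end (inclusive).
  let collapsed = 1;
  for (let i = start; i <= end; i++) {
    collapsed *= shape[i];
  }
  return [...shape.slice(0, start), collapsed, ...shape.slice(end + 1)];
}

console.log(JSON.stringify(flattenShape([2, 3, 4])));            // [24]
console.log(JSON.stringify(flattenShape([32, 3, 224, 224], 1))); // [32,150528]
console.log(JSON.stringify(flattenShape([2, 3, 4, 5], 1, 2)));   // [2,12,5]
console.log(JSON.stringify(flattenShape([2, 3, 4, 5], 1, -1)));  // [2,60]
```

The same four calls reproduce the shapes claimed in the examples on this page, which is a quick way to sanity-check a startDim/endDim pair before running it on real data.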
Common use cases:
- Prepare CNN output for fully connected layers: (batch, channels, height, width) -> (batch, channels*height*width)
- Flatten spatial dimensions: (batch, height, width) -> (batch, height*width)
- Reshape image data for processing
- Vectorize multi-dimensional features
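Notes:
- Preserving the batch dimension: the common pattern is flatten(1), which keeps dim 0 and collapses all remaining dimensions.
- Negative indices: both startDim and endDim support negative indexing (-1 is the last dimension).
- Typically cheap: for a contiguous tensor only the shape changes and no data is copied. In PyTorch, flatten may return a copy when the input's memory layout does not allow a view.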
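Whether flatten can avoid a copy depends on memory layout. Below is a simplified sketch assuming strided, row-major storage as in PyTorch (whether this library uses the same model is an assumption). Dimensions start..end can be merged into a view when each collapsed dimension's stride equals the next dimension's stride times that next dimension's size. `contiguousStrides` and `canFlattenAsView` are hypothetical helpers, and the check is deliberately conservative.

```typescript
// Row-major (C-contiguous) strides: the last dimension has stride 1.
function contiguousStrides(shape: readonly number[]): number[] {
  const strides = new Array<number>(shape.length);
  let acc = 1;
  for (let i = shape.length - 1; i >= 0; i--) {
    strides[i] = acc;
    acc *= shape[i];
  }
  return strides;
}

// Simplified view check for already-normalized, non-negative start <= end.
// Conservative: it ignores the fact that size-1 dimensions may have any
// stride, so it can report false for some layouts that are actually viewable.
function canFlattenAsView(
  shape: readonly number[],
  strides: readonly number[],
  start: number,
  end: number,
): boolean {
  for (let i = start; i < end; i++) {
    if (strides[i] !== strides[i + 1] * shape[i + 1]) return false;
  }
  return true;
}

// Contiguous [2, 3, 4] tensor: strides [12, 4, 1], full flatten is a view.
const contigShape = [2, 3, 4];
const contigStrides = contiguousStrides(contigShape);
console.log(canFlattenAsView(contigShape, contigStrides, 0, 2)); // true

// Same data transposed on dims 0 and 1: shape [3, 2, 4], strides [4, 12, 1].
console.log(canFlattenAsView([3, 2, 4], [4, 12, 1], 0, 2)); // false (needs a copy)

// Slicing dim 1 to 0:2 keeps strides [12, 4, 1] with shape [2, 2, 4]:
// merging dims 1..2 is still a view, merging dims 0..1 is not.
console.log(canFlattenAsView([2, 2, 4], [12, 4, 1], 1, 2)); // true
console.log(canFlattenAsView([2, 2, 4], [12, 4, 1], 0, 1)); // false
```

The practical takeaway is that flatten(1) on the output of a convolution is normally free, while flattening after a transpose or permute usually forces a copy.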
Returns
Tensor<readonly [number], D, Dev> when called with no arguments, otherwise Tensor<DynamicShape, D, Dev> – a tensor with dimensions startDim through endDim (inclusive) collapsed into one.
Examples
// Flatten all dimensions to 1D
const x = torch.randn(2, 3, 4);
const flat = x.flatten(); // Shape [24]
// Flatten specific range
const x = torch.randn(32, 3, 224, 224); // Batch of images
const flat = x.flatten(1); // Shape [32, 150528] - flatten channels & spatial
// Flatten middle dimensions
const x = torch.randn(2, 3, 4, 5);
const flat = x.flatten(1, 2); // Shape [2, 12, 5] - flatten dims 1,2
// CNN to FC layer pattern
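const batch_size = 8; // example value
const conv_output = torch.randn(batch_size, 512, 7, 7);
const flattened = conv_output.flatten(1); // Shape [batch_size, 512*7*7] = [8, 25088]
const logits = fc_layer(flattened); // fc_layer: a fully connected layer with 25088 input features, defined elsewhere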
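After flatten(1), each row holds 512*7*7 = 25088 features, and their order follows row-major layout: element (c, h, w) of one sample lands at position (c*H + h)*W + w. The plain TypeScript sketch below checks this arithmetic without the library. It assumes PyTorch's logical row-major ordering for flatten, and `flatIndex` is a hypothetical helper.

```typescript
// Hypothetical helper: position of element (c, h, w) inside one flattened
// feature row, assuming row-major (C-contiguous) ordering as in PyTorch.
function flatIndex(c: number, h: number, w: number, H: number, W: number): number {
  return (c * H + h) * W + w;
}

const C = 512;
const H = 7;
const W = 7;
const inFeatures = C * H * W;
console.log(inFeatures); // 25088

// Every (c, h, w) maps to a distinct index in 0..inFeatures-1.
const seen = new Set<number>();
let inRange = true;
for (let c = 0; c < C; c++) {
  for (let h = 0; h < H; h++) {
    for (let w = 0; w < W; w++) {
      const i = flatIndex(c, h, w, H, W);
      if (i < 0 || i >= inFeatures) inRange = false;
      seen.add(i);
    }
  }
}
console.log(inRange && seen.size === inFeatures); // true
console.log(flatIndex(1, 0, 0, H, W)); // 49: channel 1 starts after the 7*7 values of channel 0
console.log(flatIndex(C - 1, H - 1, W - 1, H, W)); // 25087: last feature
```

Because all 49 spatial values of a channel sit next to each other, the fully connected layer's weights are effectively grouped per channel, which matters when porting weights between frameworks that flatten in a different order.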
// Negative indexing
const x = torch.randn(2, 3, 4, 5);
const flat = x.flatten(1, -1); // Shape [2, 60] - flatten from dim 1 to the last dim
See Also
- PyTorch tensor.flatten()
- reshape - More general dimension reshaping
- unflatten - Opposite: split a dimension back into multiple
- ravel - Flattens all dimensions into a 1D tensor, like flatten() with no arguments