torch.flatten
function flatten<D extends DType = DType, Dev extends DeviceType = DeviceType>(input: Tensor<Shape, D, Dev>, options?: FlattenOptions): Tensor<Shape, D, Dev>
function flatten<D extends DType = DType, Dev extends DeviceType = DeviceType>(input: Tensor<Shape, D, Dev>, start_dim: number, end_dim: number, options?: FlattenOptions): Tensor<Shape, D, Dev>
Flattens a contiguous range of dimensions into a single dimension.
Merges consecutive dimensions into one dimension whose size is the product of the merged sizes. Called with default parameters, it flattens all dimensions into a 1D tensor. When start_dim and end_dim are specified, only that range is flattened. Essential for:
- Neural network pipelines: Converting images to vectors for dense layers
- Batch processing: Reshaping feature maps from convolutional layers
- Data preparation: Collapsing multi-dimensional data for processing
- Dimension manipulation: Selective flattening of specific dimension ranges
- Model architecture: Bridging between convolutional and fully-connected layers
- Tensor reformatting: Preparing data for operations requiring specific dimensions
Notes
- Dimension preservation: Dimensions before start_dim and after end_dim are unchanged
- Negative indices: start_dim=-2 and end_dim=-1 flatten the last two dimensions
- Default behavior: Called with no range arguments, flatten(x) is equivalent to flatten(x, 0, -1), which flattens all dimensions into a single 1D tensor
- Batch dimension pattern: flatten(x, 1) is the standard way to preserve the batch dimension while flattening all other dimensions
- Memory efficient: Typically returns a view of the data rather than a copy, reusing the existing memory with new strides
- Shape must be compatible: The product of flattened dimensions must not exceed the maximum tensor size. Very large flattening operations may fail.
- Dimension order matters: Flattening different dimension ranges produces different results. flatten(x, 0, 1) != flatten(x, 1, 2) for most multi-dimensional tensors
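The points about views and dimension order can be illustrated without the library. The sketch below assumes a contiguous, row-major layout (as in PyTorch); whether this library stores data the same way is an assumption. `rowMajorOffset` is a hypothetical helper, not part of this API. It shows that merging dims 0 and 1, or dims 1 and 2, changes the index an element is addressed by but not its position in memory.

```typescript
// Hypothetical helper (not part of the library): linear offset of a
// multi-index in a contiguous, row-major tensor of the given shape.
function rowMajorOffset(index: number[], shape: number[]): number {
  let offset = 0;
  for (let d = 0; d < shape.length; d++) {
    offset = offset * shape[d] + index[d];
  }
  return offset;
}

// Element [i, j, k] of a contiguous tensor with shape [2, 3, 4].
const [i, j, k] = [1, 2, 3];
const offsetOriginal = rowMajorOffset([i, j, k], [2, 3, 4]);

// After merging dims 0 and 1 the shape is [6, 4];
// the same element is addressed as [i * 3 + j, k].
const offsetMerged01 = rowMajorOffset([i * 3 + j, k], [6, 4]);

// After merging dims 1 and 2 the shape is [2, 12];
// the same element is addressed as [i, j * 4 + k].
const offsetMerged12 = rowMajorOffset([i, j * 4 + k], [2, 12]);

// All three offsets are equal: only the shape and indexing differ.
console.log(offsetOriginal, offsetMerged01, offsetMerged12);
```

Because the linear offset is the same in every case, a flatten over a contiguous tensor can reinterpret the existing buffer instead of copying it. The two results still differ in shape ([6, 4] versus [2, 12]), which is why the choice of range matters.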
Parameters
options (FlattenOptions, optional) - Optional settings for flatten
Returns
Tensor<Shape, D, Dev> – Tensor where dimensions [start_dim:end_dim+1] are merged into one. Dimensions before start_dim and after end_dim remain unchanged.
Examples
// Full flatten: collapse all dimensions to 1D
const x = torch.randn(2, 3, 4, 5);
const flat = torch.flatten(x); // Shape [120]

// Preserve batch dimension: flatten spatial dimensions
const images = torch.randn(32, 3, 28, 28); // Batch of images
const flat = torch.flatten(images, 1); // Shape [32, 2352]
// Now ready for fully-connected layer

// Partial flatten: flatten middle dimensions
const x = torch.randn(2, 3, 4, 5);
const partial = torch.flatten(x, { start_dim: 1, end_dim: 2 }); // Shape [2, 12, 5]
// Only flattens dims 1 and 2 (3*4=12)

// CNN feature extraction: flatten conv output for FC input
const conv_out = torch.randn(16, 64, 7, 7); // [batch, filters, height, width]
const fc_input = torch.flatten(conv_out, 1); // [16, 3136]
// Can now pass to fully-connected layer

// Batched operation: separate batch from features
const batch_data = torch.randn(32, 10, 20, 30);
const features = torch.flatten(batch_data, 1); // [32, 6000]
// Batch dimension (32) preserved, features (10*20*30) in one dimension
See Also
- PyTorch torch.flatten(input, start_dim=0, end_dim=-1)
- reshape - More general shape transformation for arbitrary reshaping
- unflatten - Inverse operation: expands one dimension into multiple dimensions
- ravel - Alias for flattening all dimensions (ravel(x) == flatten(x))
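
The relationship to reshape can be sketched at the shape level: flatten behaves like a reshape whose target shape is computed from the range. The helper below, `flattenShape`, is hypothetical and not part of this library. It assumes PyTorch-style semantics (negative indices count from the end, defaults of 0 and -1) and a tensor with at least one dimension.

```typescript
// Hypothetical helper (not part of the library): computes the shape that
// flattening dims [startDim, endDim] would produce for a given input shape.
function flattenShape(shape: number[], startDim = 0, endDim = -1): number[] {
  const rank = shape.length;
  // Negative indices count from the last dimension.
  const start = startDim < 0 ? startDim + rank : startDim;
  const end = endDim < 0 ? endDim + rank : endDim;
  if (start < 0 || end >= rank || start > end) {
    throw new RangeError(`invalid range [${startDim}, ${endDim}] for rank ${rank}`);
  }
  // The merged dimension's size is the product of the sizes in the range.
  const merged = shape.slice(start, end + 1).reduce((a, b) => a * b, 1);
  // Dimensions before start and after end are carried over unchanged.
  return [...shape.slice(0, start), merged, ...shape.slice(end + 1)];
}

console.log(flattenShape([2, 3, 4, 5]));         // [120]
console.log(flattenShape([32, 3, 28, 28], 1));   // [32, 2352]
console.log(flattenShape([2, 3, 4, 5], 1, 2));   // [2, 12, 5]
console.log(flattenShape([2, 3, 4, 5], -2, -1)); // [2, 3, 20]
```

The total element count is the same before and after, which is what makes the result a valid reshape target.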