torch.tensor_split
function tensor_split<S extends Shape>(input: Tensor<S>, indices_or_sections: number | number[], options?: SplitOptions): Tensor<DynamicShape>[]
function tensor_split<S extends Shape, D extends number>(input: Tensor<S>, indices_or_sections: number | number[], dim: D, options?: SplitOptions): ValidateSplitDim<S, D> extends true ? Tensor<DynamicShape>[] : ValidateSplitDim<S, D>
Splits a tensor into multiple sub-tensors along a dimension.
Splits the input tensor into chunks, either by giving a number of sections or by giving explicit split indices. Unlike split(), tensor_split returns exactly N tensors when asked for N sections, even when the dimension size is not divisible by N. Useful for:
- Batch processing: dividing data into mini-batches
- Cross-validation: splitting data into train/val folds
- Gradient accumulation: splitting batches for memory efficiency
- Data parallelism: distributing data across devices
- Feature extraction: splitting feature groups
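Both forms of indices_or_sections can be pictured as producing half-open [start, end) ranges along the split dimension. The sketch below works on a plain array rather than a tensor. The helper names sectionRanges and splitArray are hypothetical and not part of the library. The integer case assumes PyTorch's rule that the first (size % n) sections get one extra element.

```typescript
// Hypothetical helpers (not library API) that mimic tensor_split on a plain
// array. Integer sections assume PyTorch's rule: the first (size % n)
// sections receive one extra element.
type SectionRange = [number, number]; // half-open [start, end)

function sectionRanges(size: number, indicesOrSections: number | number[]): SectionRange[] {
  let bounds: number[];
  if (typeof indicesOrSections === "number") {
    const n = indicesOrSections;
    if (!Number.isInteger(n) || n < 1) {
      throw new RangeError(`sections must be a positive integer, got ${n}`);
    }
    const base = Math.floor(size / n);
    const extra = size % n;
    bounds = [0];
    for (let i = 0; i < n; i++) {
      // Always produces exactly n sections; some are empty if n > size.
      bounds.push(bounds[i] + base + (i < extra ? 1 : 0));
    }
  } else {
    // k split indices always yield k + 1 sections.
    bounds = [0, ...indicesOrSections, size];
  }
  const ranges: SectionRange[] = [];
  for (let i = 0; i + 1 < bounds.length; i++) {
    ranges.push([bounds[i], bounds[i + 1]]);
  }
  return ranges;
}

function splitArray<T>(values: T[], indicesOrSections: number | number[]): T[][] {
  return sectionRanges(values.length, indicesOrSections).map(([s, e]) => values.slice(s, e));
}

const xs = Array.from({ length: 10 }, (_, i) => i);
console.log(splitArray(xs, 3));      // [[0,1,2,3], [4,5,6], [7,8,9]]
console.log(splitArray(xs, [2, 5])); // [[0,1], [2,3,4], [5,6,7,8,9]]
```

Reducing both forms to a list of boundaries keeps the slicing step identical. Only the way boundaries are computed differs.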
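Notes
- Output count: Splitting into N sections returns exactly N tensors; splitting at k indices returns k + 1 tensors
- Views with shared storage: The returned tensors are views that share storage with the input, so in-place changes to one are visible through the other
- Dimension range: dim must be a valid dimension of the input, in the range [0, rank)
- Indices must be sorted: Split indices should be in ascending order
- Unequal chunks: With an integer N that does not divide the dimension size, chunk sizes differ by at most one; PyTorch's tensor_split makes the first (size % N) chunks the larger ones
- Index bounds: Split indices must lie in the range (0, dim_size)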
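The constraints on dim and on split indices can be checked on shape metadata before any data is touched. The following is a minimal sketch of such checks. validateSplitArgs is a hypothetical helper, not part of the library's API, and its error messages are illustrative.

```typescript
// Hypothetical pre-split validation mirroring the documented constraints.
function validateSplitArgs(
  shape: readonly number[],
  indicesOrSections: number | number[],
  dim: number,
): void {
  // dim must lie in [0, rank).
  if (!Number.isInteger(dim) || dim < 0 || dim >= shape.length) {
    throw new RangeError(`dim ${dim} is out of range for a rank-${shape.length} tensor`);
  }
  if (typeof indicesOrSections === "number") {
    if (!Number.isInteger(indicesOrSections) || indicesOrSections < 1) {
      throw new RangeError(`sections must be a positive integer, got ${indicesOrSections}`);
    }
    return;
  }
  const indices = indicesOrSections;
  const dimSize = shape[dim];
  for (let i = 0; i < indices.length; i++) {
    const idx = indices[i];
    // Each index must fall strictly inside (0, dimSize).
    if (!Number.isInteger(idx) || idx <= 0 || idx >= dimSize) {
      throw new RangeError(`index ${idx} is outside (0, ${dimSize})`);
    }
    // Indices must not decrease.
    if (i > 0 && idx < indices[i - 1]) {
      throw new RangeError(`indices must be ascending, but ${idx} follows ${indices[i - 1]}`);
    }
  }
}

validateSplitArgs([32, 256], [100, 200], 1); // passes silently
try {
  validateSplitArgs([32, 256], [200, 100], 1);
} catch (err) {
  console.log((err as Error).message); // indices must be ascending, but 100 follows 200
}
```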
Parameters
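input: Tensor<S> – The input tensor
indices_or_sections: number | number[] – Either the number of sections (an integer) or an array of split indices
dim: D (optional) – Dimension along which to split; must be in the range [0, rank)
options: SplitOptions (optional) – Additional options for the split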
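The second overload's return type, ValidateSplitDim<S, D> extends true ? Tensor<DynamicShape>[] : ValidateSplitDim<S, D>, suggests that an out-of-range dim is reported at compile time rather than at run time. The library's actual definition of ValidateSplitDim is not shown here. The following is a hypothetical reconstruction of the pattern, assuming shapes are tuple types of dimension sizes. The names ShapeTuple, DimKeys and ValidateDimSketch are invented for illustration.

```typescript
// Hypothetical reconstruction of a compile-time dim check; the library's
// real ValidateSplitDim may differ.
type ShapeTuple = readonly number[];

// Index keys of a tuple type, e.g. "0" | "1" for a rank-2 shape.
type DimKeys<S extends ShapeTuple> = Exclude<keyof S, keyof unknown[]>;

// true for a valid dim, otherwise a descriptive error string type.
type ValidateDimSketch<S extends ShapeTuple, D extends number> =
  `${D}` extends DimKeys<S>
    ? true
    : `Error: dim ${D} is out of range for rank ${S["length"]}`;

// Resolves to true: dim 1 exists on a [32, 256] shape.
const okCheck: ValidateDimSketch<[32, 256], 1> = true;
// Resolves to an error string: a rank-2 shape has no dim 2.
const badCheck: ValidateDimSketch<[32, 256], 2> = "Error: dim 2 is out of range for rank 2";
```

Returning a string literal type instead of never makes the compiler's error message point at the cause. A call site expecting an array receives a readable error type instead.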
Returns
Tensor<DynamicShape>[] – Array of tensors split along the dimension
Examples
// Split into equal chunks
const x = torch.arange(10);
const chunks = torch.tensor_split(x, 3); // 10 is not divisible by 3, so section sizes are 4, 3, 3
chunks.map(c => c.toArray()); // [[0,1,2,3], [4,5,6], [7,8,9]]
// Split at specific indices
const y = torch.arange(12).reshape(3, 4);
const parts = torch.tensor_split(y, [1, 2], 0); // Split rows at indices 1, 2
// Returns 3 tensors: [0:1], [1:2], [2:3]
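For a contiguous row-major tensor, each row block from a dim-0 split like the one above covers a contiguous slice of the underlying buffer. JavaScript typed arrays offer an analogous mechanism: subarray returns a view over the same ArrayBuffer, so writes through a view show up in the original. This sketch uses typed arrays only as an analogy for the shared-storage note. It does not call the library, and it assumes a contiguous row-major layout.

```typescript
// A 3x4 row-major matrix stored in one flat buffer (values 0..11).
const rows = 3;
const cols = 4;
const matrixData = Float32Array.from({ length: rows * cols }, (_, i) => i);

// Split dim 0 at indices [1, 2]: row block [start, end) maps to the flat
// range [start * cols, end * cols) of the shared buffer.
const rowBounds = [0, 1, 2, rows];
const blocks: Float32Array[] = [];
for (let i = 0; i + 1 < rowBounds.length; i++) {
  // subarray creates a view; no data is copied.
  blocks.push(matrixData.subarray(rowBounds[i] * cols, rowBounds[i + 1] * cols));
}

blocks[1][0] = 100;         // write through the middle view
console.log(matrixData[4]); // 100: the original buffer sees the change
```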
// Batch splitting for memory efficiency
const batch = torch.randn(1000, 256); // Large batch
const mini_batches = torch.tensor_split(batch, 10, 0); // 10 chunks of 100
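With equal chunks (1000 rows into 10 chunks of 100), averaging per-chunk means during gradient accumulation recovers the full-batch mean. With unequal sections, for example 1000 rows into 7 chunks (a count chosen here for illustration), each chunk's mean must be weighted by its size. The sketch below shows this arithmetic on plain numbers that stand in for per-sample losses. The losses are synthetic, chunkSizes is a hypothetical helper, and the sizes assume PyTorch's remainder rule.

```typescript
// Sizes from splitting `total` items into `n` sections, assuming the
// first (total % n) sections get one extra item.
function chunkSizes(total: number, n: number): number[] {
  const base = Math.floor(total / n);
  return Array.from({ length: n }, (_, i) => base + (i < total % n ? 1 : 0));
}

// Synthetic per-sample losses standing in for a real model's output.
const losses = Array.from({ length: 1000 }, (_, i) => (i % 17) / 17);
const fullMean = losses.reduce((a, b) => a + b, 0) / losses.length;

// Accumulate over 7 unequal chunks, weighting each chunk mean by its share.
let offset = 0;
let accumulated = 0;
for (const size of chunkSizes(losses.length, 7)) {
  const chunk = losses.slice(offset, offset + size);
  const chunkMean = chunk.reduce((a, b) => a + b, 0) / size;
  accumulated += chunkMean * (size / losses.length);
  offset += size;
}

console.log(Math.abs(accumulated - fullMean) < 1e-12); // true
```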
// Cross-validation split
const dataset = torch.randn(100, 10); // 100 samples, 10 features
const folds = torch.tensor_split(dataset, 5, 0); // 5-fold split
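A 5-fold split is typically used by holding out one fold for validation and training on the rest. This sketch computes, for each fold, the validation row range and the training row ranges. foldPlan is a hypothetical helper that works on row indices only. For sample counts that do not divide evenly, it assumes the first (n % k) folds get one extra row.

```typescript
type RowRange = [number, number]; // half-open [start, end)

interface Fold {
  val: RowRange;     // rows held out for validation
  train: RowRange[]; // all other folds
}

// Plan k folds over n samples: fold i is held out, the others train.
function foldPlan(n: number, k: number): Fold[] {
  const base = Math.floor(n / k);
  const bounds = [0];
  for (let i = 0; i < k; i++) {
    bounds.push(bounds[i] + base + (i < n % k ? 1 : 0));
  }
  const ranges = bounds.slice(1).map((end, i): RowRange => [bounds[i], end]);
  return ranges.map((val, i) => ({
    val,
    train: ranges.filter((_, j) => j !== i),
  }));
}

const plan = foldPlan(100, 5);
console.log(plan[0].val);   // [0, 20]
console.log(plan[2].train); // [[0, 20], [20, 40], [60, 80], [80, 100]]
```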
// Feature group splitting
const features = torch.randn(32, 256); // 256 features
const [early, middle, late] = torch.tensor_split(features, [100, 200], 1);
// Split features: [0:100], [100:200], [200:256]
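The shapes returned by an index-based split follow directly from the indices. Only the split dimension changes, and its sizes are the gaps between consecutive boundaries. The sketch below computes them for the feature-group example. splitShapes is a hypothetical shape-only helper, not a library function.

```typescript
// Output shapes of an index-based split along `dim` (shape arithmetic only).
function splitShapes(shape: readonly number[], indices: number[], dim: number): number[][] {
  const bounds = [0, ...indices, shape[dim]];
  const shapes: number[][] = [];
  for (let i = 0; i + 1 < bounds.length; i++) {
    const out = [...shape];
    out[dim] = bounds[i + 1] - bounds[i]; // size of this section along dim
    shapes.push(out);
  }
  return shapes;
}

console.log(splitShapes([32, 256], [100, 200], 1)); // [[32, 100], [32, 100], [32, 56]]
```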