torch.index_select
function index_select<S extends Shape, D extends number, I extends Shape>(input: Tensor<S>, dim: D, index: Tensor<I>): Tensor<IndexSelectShape<S, D, I extends readonly [infer Len extends number] ? Len : number>>
Selects elements from the input tensor along a dimension using indices.
Returns a new tensor containing the slices of input selected along dim by a 1-D index tensor. All other dimensions are kept intact, so the output has the same rank as input and its size along dim equals the length of index. Useful for:
- Selecting rows/columns from matrices
- Filtering batch elements based on computed indices
- Reordering elements along a dimension (with a permutation index)
- Feature/class selection from high-dimensional data
- Index-based data shuffling or sorting
Notes
- Index is 1-D: Selection always happens along a single dimension, driven by a 1-D index
- Output rank: Same as input rank
- Order preserved: Output order matches index order (so index_select can reorder)
- Bounds checking: All values in index must lie in [0, input.shape[dim])
- Duplicates allowed: Index can contain repeated values (the same slice is then selected multiple times)
- Out-of-bounds: Invalid index values cause errors
- 1-D indices only: Use gather() for higher-dimensional indexing patterns
- Dimension validity: dim must be in the range [0, input.rank)
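To make the shape and ordering rules above concrete, here is a minimal reference sketch in plain TypeScript. It does not use the library: PlainTensor and indexSelectRef are hypothetical names, and the sketch assumes a contiguous row-major buffer, which may not match how the library actually stores tensors. It shows the size along dim becoming index.length, output order following index order, duplicates being allowed, and out-of-range values being rejected.

```typescript
// Hypothetical minimal tensor: a contiguous row-major buffer plus its shape.
// This is an illustration, not the library's Tensor type.
interface PlainTensor {
  data: number[];
  shape: number[];
}

// Reference sketch of index_select semantics (assumes row-major layout).
function indexSelectRef(input: PlainTensor, dim: number, index: number[]): PlainTensor {
  const rank = input.shape.length;
  if (Math.floor(dim) !== dim || dim < 0 || dim >= rank) {
    throw new RangeError(`dim ${dim} is outside [0, ${rank})`);
  }
  const size = input.shape[dim];
  for (const i of index) {
    if (Math.floor(i) !== i || i < 0 || i >= size) {
      throw new RangeError(`index value ${i} is outside [0, ${size})`);
    }
  }
  // View the buffer as [outer, size, inner]: outer is the product of the
  // dimensions before dim, inner the product of the dimensions after it.
  const outer = input.shape.slice(0, dim).reduce((a, b) => a * b, 1);
  const inner = input.shape.slice(dim + 1).reduce((a, b) => a * b, 1);
  const data: number[] = [];
  for (let o = 0; o < outer; o++) {
    // Output order follows index order; a repeated value copies the same slice again.
    for (const i of index) {
      const start = (o * size + i) * inner;
      for (let k = 0; k < inner; k++) {
        data.push(input.data[start + k]);
      }
    }
  }
  // Same rank as the input; only the size along dim changes.
  const shape = input.shape.slice();
  shape[dim] = index.length;
  return { data, shape };
}

// Usage: a 2 x 3 matrix, selecting column 2, then 0, then 2 again.
const m: PlainTensor = { data: [0, 1, 2, 3, 4, 5], shape: [2, 3] };
const cols = indexSelectRef(m, 1, [2, 0, 2]);
// cols.shape is [2, 3]; cols.data is [2, 0, 2, 5, 3, 5]
console.log(cols.shape, cols.data);
```

Viewing the buffer as outer × size × inner blocks is what keeps every other dimension intact: only the middle axis is re-indexed, and each selected slice is copied as one contiguous run of inner elements.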
Parameters
input: Tensor<S> - The input tensor to select from
dim: D - The dimension along which to select (0 to rank - 1)
index: Tensor<I> - A 1-D tensor containing the indices to select (all values must be in [0, input.shape[dim]))
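Together, these parameters fully determine the output shape. The sketch below is a hypothetical reconstruction, not the library's actual IndexSelectShape definition: at the type level, IndexSelectShapeSketch replaces entry D of the shape tuple S with the index length, falling back to number[] when dim is not a literal or the shape is not a fixed tuple. The runtime helper indexSelectShape (also a hypothetical name) applies the same rule and enforces the dim range. The `infer ... extends` syntax assumes TypeScript 4.7 or later, as the signature above already does.

```typescript
// Hypothetical type-level sketch (not the library's IndexSelectShape):
// replace entry D of the shape tuple S with Len.
type IndexSelectShapeSketch<
  S extends readonly number[],
  D extends number,
  Len extends number,
  Acc extends number[] = []
> = number extends D
  ? number[] // dim not known at compile time
  : S extends readonly [infer Head extends number, ...infer Tail extends readonly number[]]
    ? Acc["length"] extends D
      ? [...Acc, Len, ...Tail] // reached position D: swap in the index length
      : IndexSelectShapeSketch<Tail, D, Len, [...Acc, Head]>
    : number[]; // S is not a fixed tuple, or D is past its end

// Runtime counterpart: same rule, plus the dim range check.
function indexSelectShape(shape: readonly number[], dim: number, indexLength: number): number[] {
  if (Math.floor(dim) !== dim || dim < 0 || dim >= shape.length) {
    throw new RangeError(`dim must be in [0, ${shape.length}), got ${dim}`);
  }
  const out = shape.slice();
  out[dim] = indexLength;
  return out;
}

// Compile-time check: a [10, 5] input, dim 0, length-3 index gives [3, 5].
const rowsShape: IndexSelectShapeSketch<[10, 5], 0, 3> = [3, 5];
// Logs the arrays [3, 5] and [5, 64, 28, 28].
console.log(rowsShape, indexSelectShape([32, 64, 28, 28], 0, 5));
```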
Returns
Tensor<IndexSelectShape<S, D, I extends readonly [infer Len extends number] ? Len : number>> - A tensor of the same rank as input, with shape[dim] = index.shape[0] and all other dimensions unchanged
Examples
// Select specific rows
const x = torch.randn(10, 5); // [10, 5]
const row_indices = torch.tensor([0, 2, 4]);
torch.index_select(x, 0, row_indices); // [3, 5] - select rows 0, 2, 4
// Select specific columns
const col_indices = torch.tensor([1, 3]);
torch.index_select(x, 1, col_indices); // [10, 2] - select columns 1, 3
// Filter batch samples
const batch = torch.randn(32, 64, 28, 28); // [batch_size, channels, height, width]
const keep_indices = torch.tensor([0, 5, 10, 15, 20]); // Keep 5 samples
const filtered = torch.index_select(batch, 0, keep_indices); // [5, 64, 28, 28]
// Select a subset of feature columns
const features = torch.randn(100, 200); // 100 samples, 200 features
const important_features = torch.tensor([5, 15, 25, 35, 45]);
const selected = torch.index_select(features, 1, important_features); // [100, 5]
// Permute based on indices (like argsort)
const values = torch.tensor([5.0, 2.0, 8.0, 1.0]);
const sorted_idx = torch.tensor([3, 1, 0, 2]); // Indices for sorted order
torch.index_select(values, 0, sorted_idx); // [1.0, 2.0, 5.0, 8.0]
See Also
- PyTorch torch.index_select()
- gather - More flexible N-D indexing along a dimension
- select - Select a single index along a dimension, removing that dimension (reduces rank by one)
- take - Select using flat indices
- take_along_dim - Gather with broadcasting support
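The difference between index_select and its gather and take relatives can be seen on plain nested arrays. This is a hypothetical sketch that follows PyTorch's documented semantics for gather and take (gather fixed to dim 1, no bounds checks); the function names are illustrative and not part of the library.

```typescript
type Matrix = number[][];

// index_select along dim 1: one shared 1-D index is applied to every row.
function indexSelectCols(m: Matrix, idx: number[]): Matrix {
  return m.map((row) => idx.map((j) => row[j]));
}

// gather along dim 1 (PyTorch semantics): out[i][j] = m[i][index[i][j]].
// The index has the same rank as the input, so each row may pick different columns.
function gatherCols(m: Matrix, index: Matrix): Matrix {
  return index.map((rowIdx, i) => rowIdx.map((j) => m[i][j]));
}

// take (PyTorch semantics): index into the input as if it were flattened to 1-D.
function takeFlat(m: Matrix, idx: number[]): number[] {
  const flat = ([] as number[]).concat(...m);
  return idx.map((k) => flat[k]);
}

// Bounds checks are omitted for brevity.
const scores: Matrix = [
  [10, 11, 12],
  [20, 21, 22],
];
console.log(indexSelectCols(scores, [2, 0]));      // values [[12, 10], [22, 20]]
console.log(gatherCols(scores, [[2, 0], [2, 0]])); // same values as index_select
console.log(gatherCols(scores, [[0, 0], [1, 2]])); // values [[10, 10], [21, 22]]
console.log(takeFlat(scores, [0, 4]));             // values [10, 21]
```

index_select along dim 1 is the special case of gather in which every row of the index repeats the same 1-D index; a gather whose index rows differ, as in the third call, has no single index_select equivalent.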