torch.take_along_dim
function take_along_dim<S extends Shape, Dt extends DType, Dev extends DeviceType>(input: Tensor<S, Dt, Dev>, index: Tensor, options?: TakeAlongDimOptions): Tensor<DynamicShape, Dt, Dev>

function take_along_dim<S extends Shape, Dt extends DType, Dev extends DeviceType>(input: Tensor<S, Dt, Dev>, index: Tensor, dim: number | null): Tensor<DynamicShape, Dt, Dev>

Gathers values using indices, with automatic broadcasting.
Similar to gather(), but with more flexible indexing: the index tensor doesn't need to have the same number of dimensions as the input. The indices are broadcast against input along all dimensions other than the specified one. Useful for:
- Top-k selection: getting top values from predictions
- Argsort reconstruction: gathering sorted values
- Broadcasting indices: when indices have different dimensionality
- Flattened indexing: treating input as flattened along dim
If dim is null, both input and index are flattened to 1D and standard gathering is performed over the flattened values.
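The flattened case can be sketched with NumPy's take_along_axis, which has the same semantics as this function and also accepts a null axis (axis=None in Python) to mean "flatten, then gather":

```python
import numpy as np

# Sketch of the dim=null case using NumPy's equivalent API.
# With axis=None, both the array and the indices are treated as 1D.
data = np.arange(24).reshape(4, 6)    # values 0..23 laid out as (4, 6)
flat_idx = np.array([0, 5, 10, 23])   # indices into the flattened array
gathered = np.take_along_axis(data, flat_idx, axis=None)
print(gathered)  # [ 0  5 10 23]
```

The result is always 1D in this mode, regardless of the input's original shape.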
- Broadcasting: Index tensor automatically broadcasts against input
- Flexible rank: Index doesn't need same rank as input
- Flattening with null: Using dim=null treats both as 1D
- Output shape: Matches index shape after broadcasting
- Broadcasting rules: Standard numpy-style broadcasting applied
- Dimension validity: dim must be valid or null
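The broadcasting and output-shape rules above can be illustrated with NumPy's take_along_axis, which follows the same semantics. Note that NumPy requires the index to already have the same rank as the input, so the 1D index is expanded by hand here; the automatic rank expansion described above is an extra convenience of this library:

```python
import numpy as np

# Broadcasting sketch: a (1, 2) index broadcasts across all 3 rows,
# and the output shape is the broadcast shape, (3, 2).
x = np.arange(12).reshape(3, 4)           # shape (3, 4)
idx = np.array([[2, 0]])                  # shape (1, 2)
out = np.take_along_axis(x, idx, axis=1)  # shape (3, 2)
print(out)
# [[ 2  0]
#  [ 6  4]
#  [10  8]]
```

Each row of the input is indexed by the same pair of column positions, exactly the standard numpy-style broadcast of (1, 2) against (3, 4) along every dimension except the gather dimension.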
Parameters
input: Tensor<S, Dt, Dev> - The source tensor
index: Tensor - The indices tensor (can be 1D or broadcastable to input)
options: TakeAlongDimOptions (optional)
Returns
Tensor<DynamicShape, Dt, Dev> – A tensor with gathered values

Examples
// Gather with 1D indices (automatic broadcasting)
const input = torch.randn(10, 20);
const indices = torch.tensor([2, 5, 1, 8, 15]); // 1D indices
const result = torch.take_along_dim(input, indices, 1); // Auto-broadcast
// Top-k values gathering
const scores = torch.randn(32, 100); // Scores for 100 classes
const top_indices = torch.topk(scores, 5, 1)[1]; // Get top 5 indices
const top_values = torch.take_along_dim(scores, top_indices, 1); // Get top 5 values
// Flattened gathering (dim=null)
const data = torch.randn(4, 5, 6);
const flat_indices = torch.tensor([0, 5, 10, 50, 100]);
const gathered = torch.take_along_dim(data, flat_indices, null); // Flatten and gather

See Also
- PyTorch torch.take_along_dim()
- gather - Stricter: index must match input's rank; no broadcasting
- index_select - Select along single dimension
- topk - Get top-k values and indices
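The practical difference from gather is that a gather-style call needs a full-rank, full-size index, while take_along_dim accepts a smaller one and broadcasts it. A sketch using NumPy's take_along_axis (same semantics) shows that the broadcast form and the explicitly expanded form agree:

```python
import numpy as np

# gather would require an index of shape (3, 2); take_along_dim
# accepts the (1, 2) index and broadcasts it across the rows.
x = np.arange(12).reshape(3, 4)
small = np.array([[3, 1]])                # shape (1, 2)
full = np.broadcast_to(small, (3, 2))     # what gather would require
a = np.take_along_axis(x, small, axis=1)  # broadcasting form
b = np.take_along_axis(x, full, axis=1)   # explicit, gather-style form
print((a == b).all())  # True
```

In other words, take_along_dim is gather plus an implicit broadcast_to on the index.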