torch.unravel_index
Converts flat linear indices into multi-dimensional coordinate tensors.
Given flat indices (as if the tensor were flattened), returns the corresponding multi-dimensional coordinates. It is the inverse of ravel_multi_index. Useful for:
- Finding element locations: where does an argmax/argmin result point in the original shape?
- Sparse tensor conversion: converting linear indices to (row, col) coordinates
- Debugging: converting flat indices back to readable positions
- Image processing: finding which pixel (y, x) corresponds to a flat index
- Tensor indexing: reconstructing multi-dimensional indices from flattened search results
- Data extraction: getting coordinates of specific flat positions
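For the image-processing case, a 2D shape [H, W] reduces the conversion to one division and one remainder: y = floor(i / W) and x = i mod W. The sketch below locates the maximum of an H×W image stored row-major in a plain array. `locateMax` is a hypothetical helper that only illustrates the arithmetic; it is not part of the library.

```typescript
// Hypothetical helper (not the library API): find the (y, x) position of the
// largest value in an H×W image stored row-major in a flat array.
function locateMax(pixels: number[], H: number, W: number): [number, number] {
  if (pixels.length === 0 || pixels.length !== H * W) {
    throw new RangeError("pixels must contain exactly H * W > 0 values");
  }
  // argmax over the flattened image (first occurrence wins on ties)
  let best = 0;
  for (let i = 1; i < pixels.length; i++) {
    if (pixels[i] > pixels[best]) best = i;
  }
  // Row-major unravel for the 2D shape [H, W]
  const y = Math.floor(best / W);
  const x = best % W;
  return [y, x];
}

// 2×3 image whose brightest pixel is at row 1, column 0
const image = [0.1, 0.5, 0.2,
               0.9, 0.3, 0.4];
console.log(locateMax(image, 2, 3)); // [ 1, 0 ]
```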
Given a tensor of flat indices and a target shape, returns an array of coordinate tensors, one for each dimension of shape; each tensor holds one coordinate per input index. E.g., unravel_index([5], [3, 4]) returns [[1], [1]] (row tensor [1], column tensor [1]) because in a 3x4 array (12 elements total), flat index 5 = 1 × 4 + 1, i.e. row 1, col 1.
- Row-major ordering: Assumes row-major (C-contiguous) layout, where the last dimension varies fastest; a column-major layout would give different coordinates
- Return format: Array with one tensor per dimension of shape (destructure with [d0, d1, d2, ...])
- Same shape: Each returned tensor has the same shape as the input indices
- Non-negative indices: All indices must satisfy 0 ≤ index < numel(shape)
- Inverse operation: Inverse of ravel_multi_index, which flattens coordinates back to indices
- Out-of-bounds: Indices must be < product(shape); otherwise the returned coordinates are wrong
- Integer indices required: Indices must be an integer type (no floats)
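The arithmetic in that example generalizes to any number of dimensions: compute row-major strides (each stride is the product of all later dimension sizes), then peel off one coordinate per dimension with integer division and remainder. The following is a minimal sketch on plain number arrays under that row-major assumption, not the library's implementation; `unravelIndex` is a hypothetical name, and it mirrors the documented return format of one array per dimension.

```typescript
// Hypothetical helper: a plain-array sketch of row-major unravelling,
// not the library's own implementation.
function unravelIndex(flat: number[], shape: number[]): number[][] {
  const numel = shape.reduce((a, b) => a * b, 1);

  // Row-major strides: strides[d] = product of shape[d + 1 ..]
  const strides: number[] = new Array(shape.length).fill(1);
  for (let d = shape.length - 2; d >= 0; d--) {
    strides[d] = strides[d + 1] * shape[d + 1];
  }

  // One output array per dimension, each as long as the input
  const coords: number[][] = shape.map(() => []);
  for (const idx of flat) {
    // This sketch rejects invalid indices instead of returning wrong coordinates
    if (!Number.isInteger(idx) || idx < 0 || idx >= numel) {
      throw new RangeError(`index ${idx} out of range for shape [${shape}]`);
    }
    let rem = idx;
    for (let d = 0; d < shape.length; d++) {
      coords[d].push(Math.floor(rem / strides[d]));
      rem %= strides[d];
    }
  }
  return coords;
}

const [rows, cols] = unravelIndex([5], [3, 4]);
console.log(rows, cols); // [ 1 ] [ 1 ]
```

Throwing a RangeError on invalid input is a design choice of this sketch; the documented behavior for out-of-range indices is only that the coordinates are wrong.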
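The inverse relationship can be checked directly: ravelling coordinates back uses the same row-major strides, accumulating flat = Σ coord[d] × stride[d]. The sketch below uses hypothetical plain-array helpers `ravel` and `unravel` (stand-ins, not the library's ravel_multi_index or unravel_index) and confirms that every in-range flat index survives a round trip. Because this `unravel` reduces each coordinate modulo its dimension size, an out-of-range index silently wraps around, which is one concrete way "wrong coords" can arise; whether the library behaves the same way is an assumption not confirmed here.

```typescript
// Hypothetical plain-array helpers, not the library API.

// Row-major strides: the last dimension varies fastest.
function rowMajorStrides(shape: number[]): number[] {
  const strides = new Array<number>(shape.length).fill(1);
  for (let d = shape.length - 2; d >= 0; d--) {
    strides[d] = strides[d + 1] * shape[d + 1];
  }
  return strides;
}

// Flat index -> one coordinate per dimension (modulo-based, so no bounds check).
function unravel(flat: number, shape: number[]): number[] {
  const strides = rowMajorStrides(shape);
  return strides.map((s, d) => Math.floor(flat / s) % shape[d]);
}

// Coordinates -> flat index: sum of coord[d] * stride[d].
function ravel(coords: number[], shape: number[]): number {
  const strides = rowMajorStrides(shape);
  return coords.reduce((acc, c, d) => acc + c * strides[d], 0);
}

const shape = [3, 4, 2];
const numel = shape.reduce((a, b) => a * b, 1); // 24

// Every in-range flat index survives the round trip.
for (let i = 0; i < numel; i++) {
  if (ravel(unravel(i, shape), shape) !== i) throw new Error(`round trip failed at ${i}`);
}

console.log(unravel(13, shape)); // [ 1, 2, 1 ]
// Out of range: 24 wraps around to the origin instead of failing.
console.log(unravel(24, shape)); // [ 0, 0, 0 ]
```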
Parameters
indices: Tensor - Tensor of flat indices (non-negative integers; any shape, see the batch example)
shape: number[] - Target shape to interpret indices with respect to
Returns
Tensor[] - Array of tensors with the same shape as indices, one tensor per dimension in shape
Examples
// Simple 2D example
const flat = torch.tensor([3, 5, 8]);
const [rows, cols] = torch.unravel_index(flat, [3, 4]);
// rows: [0, 1, 2], cols: [3, 1, 0]
// Flat indices: 3 → (0,3), 5 → (1,1), 8 → (2,0)
// Finding pixel positions from a flat index
// Assumes image_flat is an H x W image flattened to 1D, with H and W defined
const max_index = torch.argmax(image_flat); // Single flat index
const [y, x] = torch.unravel_index(max_index, [H, W]);
// y, x are now the pixel coordinates of the maximum value
// Sparse matrix format conversion
const sparse_indices = torch.tensor([0, 5, 10, 15]); // Flat indices into a 5x5 matrix
const [sp_rows, sp_cols] = torch.unravel_index(sparse_indices, [5, 5]);
// Converts flat sparse format to (row, col) coordinates
// sp_rows: [0, 1, 2, 3], sp_cols: [0, 0, 0, 0]
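// 3D example: converting flat indices to 3D coordinates
const flat3d = torch.tensor([0, 13, 23]); // Indices into a 3x4x2 array (24 elements, so valid indices are 0-23)
const [i, j, k] = torch.unravel_index(flat3d, [3, 4, 2]);
// i: [0, 1, 2], j: [0, 2, 3], k: [0, 1, 1]
// 0 → (0,0,0), 13 → (1,2,1), 23 → (2,3,1)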
// Batch processing multiple indices
const flat_batch = torch.tensor([[5, 10], [2, 7]]); // 2x2 batch of indices
const [batch_rows, batch_cols] = torch.unravel_index(flat_batch, [4, 4]);
// Each result has the same 2x2 shape as the input:
// batch_rows: [[1, 2], [0, 1]], batch_cols: [[1, 2], [2, 3]]
// Finding all positions of values greater than a threshold (here 2)
const flat_tensor = torch.randn(1000);
const threshold_indices = torch.where(flat_tensor.gt(2))[0]; // Flat indices where the condition holds
const coords = torch.unravel_index(threshold_indices, [10, 100]); // Map back to a 10x100 shape
See Also
- PyTorch torch.unravel_index()
- ravel_multi_index - Inverse: convert coordinates to flat indices
- argmax - Get max index (often fed to unravel_index)
- argmin - Get min index (often fed to unravel_index)
- where - Get indices of elements matching condition
- nonzero - Get coordinates of non-zero elements directly