torch.searchsorted
function searchsorted(sorted_sequence: Tensor, values: Tensor): Tensor
function searchsorted(sorted_sequence: Tensor, values: Tensor, out_int32: boolean, right: boolean, side: 'left' | 'right', sorter: Tensor, options: SearchSortedOptions): Tensor
Find the indices into a sorted sequence at which values would be inserted to maintain order.
Performs binary search to find insertion points that would maintain sorted order. Returns the indices where each value should be inserted. Useful for bucketing, sorting, and range queries.
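The sketch below makes the per-value binary search concrete on plain number arrays. It is not the library's implementation. `lowerBound`, `upperBound` and `searchSortedArray` are hypothetical helpers, and the input sequence is assumed to be sorted in ascending order. The "left" search returns the first index whose element is `>= x`. The "right" search returns the first index whose element is `> x`.

```typescript
// Minimal sketch of searchsorted-style binary search on plain arrays.
// Assumption: `seq` is sorted ascending (this is not checked).

/** First index i with seq[i] >= x ("left" side). */
function lowerBound(seq: readonly number[], x: number): number {
  let lo = 0;
  let hi = seq.length; // the answer may be seq.length (insert at the end)
  while (lo < hi) {
    const mid = (lo + hi) >>> 1;
    if (seq[mid] < x) {
      lo = mid + 1; // seq[0..mid] are all < x, so the answer is after mid
    } else {
      hi = mid; // seq[mid] >= x, so the answer is at or before mid
    }
  }
  return lo;
}

/** First index i with seq[i] > x ("right" side). */
function upperBound(seq: readonly number[], x: number): number {
  let lo = 0;
  let hi = seq.length;
  while (lo < hi) {
    const mid = (lo + hi) >>> 1;
    if (seq[mid] <= x) {
      lo = mid + 1; // also skip elements equal to x
    } else {
      hi = mid;
    }
  }
  return lo;
}

/** Hypothetical array-only analogue of searchsorted: one search per value. */
function searchSortedArray(
  seq: readonly number[],
  values: readonly number[],
  side: "left" | "right" = "left",
): number[] {
  const search = side === "left" ? lowerBound : upperBound;
  return values.map((v) => search(seq, v));
}

console.log(searchSortedArray([1, 3, 5, 7, 9], [2, 4, 6])); // [1, 2, 3]
console.log(searchSortedArray([1, 3, 3, 3, 5], [3], "right")); // [4]
```

Each value costs O(log m) comparisons, which is where the O(n log m) total comes from.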
Commonly used for:
- Bucketing/binning operations on sorted boundaries
- Building histograms with specific bin edges
- Finding indices for insertion in sorted arrays
- Percentile and quantile computations
- Data discretization and categorization
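Notes
- Complexity is O(n log m), where n is the number of elements in values and m is the length of sorted_sequence (one binary search per value)
- sorted_sequence must be sorted in ascending order for correct results; this is not validated, so unsorted input silently yields meaningless indices
- Every result lies between 0 and len(sorted_sequence), inclusive: 0 means "before the first element" and len(sorted_sequence) means "after the last element"
- Duplicates: the left side (default) returns the index of the first element equal to the value; the right side (right = true or side = 'right') returns the index just past the last equal element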
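Because the left and right results bracket the run of equal elements, two searches on the same sorted data answer counting and range queries directly. The sketch below assumes plain ascending arrays. `insertionPoint`, `countEqual` and `countInRange` are hypothetical helpers, not library functions. It shows that right minus left counts the duplicates of a value, and left(hi) minus left(lo) counts the elements in the half-open range [lo, hi).

```typescript
// Sketch: counting with two binary searches on an ascending array.
// insertionPoint mirrors searchsorted's left (right = false) and right (right = true) sides.
function insertionPoint(sorted: readonly number[], x: number, right: boolean): number {
  let lo = 0;
  let hi = sorted.length;
  while (lo < hi) {
    const mid = (lo + hi) >>> 1;
    // Left side moves past elements < x; right side also moves past elements == x.
    const goRight = right ? sorted[mid] <= x : sorted[mid] < x;
    if (goRight) lo = mid + 1;
    else hi = mid;
  }
  return lo;
}

/** Number of elements equal to v: right insertion point minus left insertion point. */
function countEqual(sorted: readonly number[], v: number): number {
  return insertionPoint(sorted, v, true) - insertionPoint(sorted, v, false);
}

/** Number of elements x with lo <= x < hi (half-open range query). */
function countInRange(sorted: readonly number[], lo: number, hi: number): number {
  return Math.max(0, insertionPoint(sorted, hi, false) - insertionPoint(sorted, lo, false));
}

const sorted = [1, 3, 3, 3, 5, 8, 13];
console.log(countEqual(sorted, 3)); // 3
console.log(countEqual(sorted, 4)); // 0
console.log(countInRange(sorted, 3, 8)); // 4 (the elements 3, 3, 3, 5)
```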
Parameters
sorted_sequence (Tensor) - 1D tensor, sorted in ascending order, containing the reference sequence
values (Tensor) - Tensor of values to search for (any shape)
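The second overload also accepts a `sorter` tensor, which the parameter list does not describe. The sketch below assumes it follows PyTorch/NumPy semantics. Under that assumption, `sorter` holds the indices that sort an unsorted sequence ascending (typically the output of argsort), and the returned positions refer to that sorted order, not to positions in the original sequence. It uses plain arrays, and `argsort` and `searchSortedWithSorter` here are hypothetical helpers.

```typescript
// Sketch of `sorter` semantics, assuming they match PyTorch/NumPy:
// sorter[k] is the position in `seq` of the k-th smallest element.

/** Indices that sort `seq` ascending (plain-array stand-in for argsort). */
function argsort(seq: readonly number[]): number[] {
  return seq.map((_, i) => i).sort((a, b) => seq[a] - seq[b]);
}

/** Left-side insertion point of x in the sorted view seq[sorter[0]], seq[sorter[1]], ... */
function searchSortedWithSorter(
  seq: readonly number[],
  sorter: readonly number[],
  x: number,
): number {
  let lo = 0;
  let hi = sorter.length;
  while (lo < hi) {
    const mid = (lo + hi) >>> 1;
    // Read the mid-th smallest element through the permutation; seq itself is never reordered.
    if (seq[sorter[mid]] < x) lo = mid + 1;
    else hi = mid;
  }
  return lo;
}

const unsorted = [9, 1, 7, 3, 5];
const sorter = argsort(unsorted); // [1, 3, 4, 2, 0]
console.log([2, 6, 10].map((x) => searchSortedWithSorter(unsorted, sorter, x))); // [1, 3, 5]
```

The result matches searching the explicitly sorted copy [1, 3, 5, 7, 9], without materializing it.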
Returns
Tensor - Tensor with the same shape as values, containing the insertion indices
Examples
```typescript
// Basic binary search: find insertion points
const seq = torch.tensor([1, 3, 5, 7, 9]);
const vals = torch.tensor([2, 4, 6]);
torch.searchsorted(seq, vals); // [1, 2, 3]

// Left vs right insertion with duplicates
// (third and fourth arguments are out_int32 and right)
const dupSeq = torch.tensor([1, 3, 3, 3, 5]);
const dupVals = torch.tensor([3, 3]);
torch.searchsorted(dupSeq, dupVals, true, false); // [1, 1] - leftmost positions
torch.searchsorted(dupSeq, dupVals, true, true);  // [4, 4] - just past the last 3

// Bucketing: result i means bins[i - 1] < x <= bins[i]
const bins = torch.tensor([0, 10, 20, 30, 40, 50]);
const data = torch.tensor([5, 12, 18, 35, 42]);
const buckets = torch.searchsorted(bins, data); // [1, 2, 2, 4, 5]

// 2D values: every element is searched in the same 1D sequence
const vals2d = torch.tensor([[2, 4], [6, 8]]); // shape [2, 2]
const result = torch.searchsorted(seq, vals2d); // shape [2, 2], values [[1, 2], [3, 4]]
```
See Also
- PyTorch torch.searchsorted()
- bincount - Count occurrences of each non-negative integer value
- bucketize - Simpler bucketing operation
- sort - Sort a tensor
- argsort - Get indices that sort a tensor