torch.index_reduce
function index_reduce<S extends Shape, Dt extends DType, Dev extends DeviceType>(input: Tensor<S, Dt, Dev>, dim: number, index: Tensor, source: Tensor, reduce: 'prod' | 'mean' | 'amax' | 'amin', options?: ScatterReduceOptions): Tensor<S, Dt, Dev>
function index_reduce<S extends Shape, Dt extends DType, Dev extends DeviceType>(input: Tensor<S, Dt, Dev>, dim: number, index: Tensor, source: Tensor, reduce: 'prod' | 'mean' | 'amax' | 'amin', include_self: boolean, options?: ScatterReduceOptions): Tensor<S, Dt, Dev>
Applies a reduction operation to the slices of input selected by a 1D index.
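Reduces the slices of source into input at the positions along dim given by a 1D index: slice i of source (along dim) is combined with slice index[i] of input. It is similar to scatter_reduce(), but where scatter_reduce() takes an index with the same number of dimensions as source and works element by element, index_reduce() takes a 1D index and moves whole slices. Supports several reduction operations. Useful for: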
- Custom pooling: mean or max pooling over groups of rows
- Extremal aggregation: keeping the per-position minimum or maximum
- Product aggregation: multiplying factors at positions
- Flexible histograms: reducing with operations other than sum
- Flexible message passing: aggregating messages per node with mean, max, min, or product
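- Reduction types: 'prod', 'mean', 'amax', 'amin'
- include_self: Controls whether the values already in input participate in the reduction
- 1D indices: index must be 1D
- Multiple contributions: Several entries of index can map to the same position; all of their source slices are reduced together, and positions that no entry maps to keep their input values
- Reduction type: Must be one of the four operations above
- Numerical properties: 'prod' can overflow or underflow when many factors land on one position; 'mean' divides by the number of contributions, plus one when include_self is true; 'amax' and 'amin' always return one of the contributing values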
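The behavior described above can be modeled in plain TypeScript. This is a minimal sketch, assuming the library follows PyTorch's index_reduce semantics; it handles only dim 0 of a 2D array, and indexReduceDim0 and IDENTITY are hypothetical names, not part of the library:

```typescript
// The four reduction modes listed in the signature.
type Reduce = 'prod' | 'mean' | 'amax' | 'amin';

// Identity element for each mode: reducing with it leaves a value unchanged.
const IDENTITY: Record<Reduce, number> = { prod: 1, mean: 0, amax: -Infinity, amin: Infinity };

// Hypothetical reference helper (not a library function): index_reduce along
// dim 0 of a 2D array, assuming PyTorch semantics. Row i of source is reduced
// into row index[i] of the result; rows that no index points at are copied as-is.
function indexReduceDim0(
  input: number[][],
  index: number[],
  source: number[][],
  reduce: Reduce,
  includeSelf = true,
): number[][] {
  if (index.length !== source.length) {
    throw new Error('index.length must equal source.length (the size of source along dim 0)');
  }
  // Copy the input so the caller's array is not mutated (out-of-place result).
  const out = input.map((row) => row.slice());
  const counts: number[] = new Array(input.length).fill(0);

  for (let i = 0; i < index.length; i++) {
    const j = index[i];
    if (!Number.isInteger(j) || j < 0 || j >= input.length) {
      throw new RangeError(`index[${i}] = ${j} is out of range for ${input.length} rows`);
    }
    // First hit on row j without include_self: discard the input values.
    if (counts[j] === 0 && !includeSelf) out[j].fill(IDENTITY[reduce]);
    counts[j]++;
    for (let k = 0; k < out[j].length; k++) {
      const v = source[i][k];
      if (reduce === 'prod') out[j][k] *= v;
      else if (reduce === 'mean') out[j][k] += v; // summed here, divided below
      else if (reduce === 'amax') out[j][k] = Math.max(out[j][k], v);
      else out[j][k] = Math.min(out[j][k], v);
    }
  }

  if (reduce === 'mean') {
    for (let j = 0; j < out.length; j++) {
      if (counts[j] === 0) continue; // untouched rows keep their input values
      // With include_self the input value counts as one extra element.
      const n = counts[j] + (includeSelf ? 1 : 0);
      out[j] = out[j].map((s) => s / n);
    }
  }
  return out;
}
```

For instance, indexReduceDim0([[1], [1], [1]], [0, 1, 0, 2], [[2], [3], [2], [1.5]], 'prod') gives [[4], [3], [1.5]]. With 'mean' and a zero-filled input, include_self = true divides each row's sum by one more than its number of contributions, which is why zero-initialized mean pooling should pass include_self = false.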
Parameters
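input: Tensor<S, Dt, Dev> - The input tensor
dim: number - The dimension along which to index
index: Tensor - 1D tensor of integer indices into input.shape[dim]; its length must match source.shape[dim]
source: Tensor - The tensor whose slices are reduced into input; it must match input's shape in every dimension except dim
reduce: 'prod' | 'mean' | 'amax' | 'amin' - The reduction operation
include_self: boolean (optional, second overload) - Whether the values already in input take part in the reduction; PyTorch defaults this to true
options: ScatterReduceOptions (optional)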
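The shape constraints in the parameter list can be checked before calling the function. A minimal sketch, assuming this port enforces the same rules as PyTorch (including accepting a negative dim counted from the end); checkIndexReduceShapes is a hypothetical helper that works on plain shape arrays and is not part of the library:

```typescript
// Hypothetical helper (not a library function): checks the shape rules assumed
// for index_reduce, with arguments in the same order as the real call.
function checkIndexReduceShapes(
  inputShape: number[],
  dim: number,
  indexShape: number[],
  sourceShape: number[],
): void {
  const rank = inputShape.length;
  // Negative dims count from the end, as in PyTorch (assumed for this port).
  const d = dim < 0 ? dim + rank : dim;
  if (!Number.isInteger(d) || d < 0 || d >= rank) {
    throw new RangeError(`dim ${dim} is out of range for a rank-${rank} input`);
  }
  if (indexShape.length !== 1) {
    throw new Error(`index must be 1D, got rank ${indexShape.length}`);
  }
  if (sourceShape.length !== rank) {
    throw new Error(`source must have rank ${rank}, got ${sourceShape.length}`);
  }
  // Along dim, source holds exactly one slice per index entry.
  if (sourceShape[d] !== indexShape[0]) {
    throw new Error(`source.shape[${d}] = ${sourceShape[d]} must equal index length ${indexShape[0]}`);
  }
  // Every other dimension must match input exactly.
  for (let k = 0; k < rank; k++) {
    if (k !== d && sourceShape[k] !== inputShape[k]) {
      throw new Error(`source.shape[${k}] = ${sourceShape[k]} must equal input.shape[${k}] = ${inputShape[k]}`);
    }
  }
}
```

The mean-pooling shapes on this page (input [5, 64], index [20], source [20, 64], dim 0) pass these checks. A shape check cannot see index values, which must also lie in [0, input.shape[dim]).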
Returns
Tensor<S, Dt, Dev> - A new tensor with the same shape, dtype, and device as input, holding the reduced values
Examples
// Mean pooling at indices
const output = torch.zeros(5, 64);
const features = torch.randn(20, 64);
const pool_indices = torch.tensor([0, 1, 0, 2, 1, 0, 1, 2, 0, 1, 3, 4, 3, 4, 2, 2, 3, 4, 0, 1]);
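// index_reduce returns a new tensor, so capture the result;
// include_self = false keeps the initial zeros out of each mean
const pooled = torch.index_reduce(output, 0, pool_indices, features, 'mean', false);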
// Max pooling: -Infinity is the identity for 'amax', so the initial values never win
const max_init = torch.full([5, 64], -Infinity);
const max_pooled = torch.index_reduce(max_init, 0, pool_indices, features, 'amax');
// Product aggregation
const product = torch.ones(3);
const factors = torch.tensor([2, 3, 2, 1.5]); // 4 factors to 3 positions
const factor_indices = torch.tensor([0, 1, 0, 2]);
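const products = torch.index_reduce(product, 0, factor_indices, factors, 'prod');
// products holds [4, 3, 1.5]: position 0 gets 2 * 2, position 1 gets 3, position 2 gets 1.5
See Also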
- PyTorch torch.index_reduce()
- index_add - Sum reduction along a 1D index, which index_reduce does not offer as a mode
- scatter_reduce - Element-wise variant whose index has the same number of dimensions as source