torch.scatter_reduce
function scatter_reduce<S extends Shape, D extends number>(input: Tensor<S>, dim: D, index: Tensor, src: Tensor, reduce: 'sum' | 'prod' | 'mean' | 'amax' | 'amin', options?: ScatterReduceOptions): Tensor<ScatterShape<S, D>>
function scatter_reduce<S extends Shape, D extends number>(input: Tensor<S>, dim: D, index: Tensor, src: Tensor, reduce: 'sum' | 'prod' | 'mean' | 'amax' | 'amin', include_self: boolean, options?: ScatterReduceOptions): Tensor<ScatterShape<S, D>>
Scatters values at specified indices, combining them with a custom reduction operation.
Generalizes scatter() and scatter_add() to other reductions. For each output position targeted by index, the specified reduction (sum, prod, mean, amax, amin) combines all src values mapped to that position. Useful for:
- Custom aggregation: mean/max pooling at positions
- Graph networks: flexible message aggregation
- Flexible histograms: reduce with operations other than sum
- Multi-source fusion: combining values with custom rules
- Extremum aggregation: using amin/amax to keep the smallest or largest contribution at each position
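Notes:
- Reduction types: reduce must be one of 'sum', 'prod', 'mean', 'amax' (maximum), 'amin' (minimum)
- include_self: If true, the original input value at each scattered position also participates in the reduction; if false, only the src values are reduced there. Positions that no index points to keep their input value either way
- Duplicate indices: When several index entries map to the same position, all corresponding src values are combined by the reduction
- Numerical behavior: 'mean' divides by the number of contributing values (counting the input value when include_self is true), so a zero-initialized input pulls the mean toward 0 unless include_self is false; 'prod' can overflow or underflow when many values are combined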
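To make the reduction and include_self rules concrete, here is a minimal sketch of the 1-D case on plain number arrays. It assumes PyTorch's documented scatter_reduce semantics (which this API mirrors) and uses a hypothetical helper name, scatterReduce1d; it is an illustration, not this library's implementation.

```typescript
// Hypothetical sketch: 1-D scatter_reduce on plain arrays, assuming PyTorch semantics.
// Not this library's implementation.
type Reduce = 'sum' | 'prod' | 'mean' | 'amax' | 'amin';

function scatterReduce1d(
  input: number[],
  index: number[],
  src: number[],
  reduce: Reduce,
  includeSelf = true,
): number[] {
  // Collect the src values routed to each output position.
  const groups: number[][] = input.map((): number[] => []);
  index.forEach((target, i) => {
    if (!Number.isInteger(target) || target < 0 || target >= input.length) {
      throw new RangeError(`index ${target} is out of range`);
    }
    groups[target].push(src[i]);
  });

  return input.map((self, pos) => {
    const vals = groups[pos];
    // Positions that no index points to keep their input value.
    if (vals.length === 0) return self;
    // include_self decides whether the original value joins the reduction.
    const all = includeSelf ? [self, ...vals] : vals;
    switch (reduce) {
      case 'sum':
        return all.reduce((a, b) => a + b, 0);
      case 'prod':
        return all.reduce((a, b) => a * b, 1);
      case 'mean':
        return all.reduce((a, b) => a + b, 0) / all.length;
      case 'amax':
        return Math.max(...all);
      case 'amin':
        return Math.min(...all);
      default:
        throw new Error(`unsupported reduce: ${String(reduce)}`);
    }
  });
}
```

For instance, scatterReduce1d([0, 0, 0], [0, 1, 0, 2], [1, 2, 3, 4], 'mean', false) averages only the src values and gives [2, 2, 4], while leaving include_self at true folds the zero into each scattered position (position 0 becomes 4/3).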
Parameters
input: Tensor<S> - The input tensor; the output starts from its shape and values
dim: D - The dimension along which to index
index: Tensor - The indices tensor specifying where along dim each src value is reduced
src: Tensor - The source tensor with values to reduce
reduce: 'sum' | 'prod' | 'mean' | 'amax' | 'amin' - The reduction operation to apply
include_self: boolean (second overload) - Whether the input values at scattered positions participate in the reduction
options: ScatterReduceOptions (optional)
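The indexing rule that dim controls is easiest to see on nested arrays. The sketch below is a hypothetical helper (scatterSum2d, not part of this library) that assumes PyTorch's documented semantics with reduce = 'sum' and include_self = true: for dim = 0 the target row comes from index and the column from the loop position, and for dim = 1 it is the other way around.

```typescript
// Hypothetical sketch of the dim rule for 2-D arrays, assuming PyTorch semantics,
// reduce = 'sum' and include_self = true (so src values are added onto input).
function scatterSum2d(
  input: number[][],
  dim: 0 | 1,
  index: number[][],
  src: number[][],
): number[][] {
  // Copy so the caller's input is not modified.
  const out = input.map((row) => row.slice());
  // Iterate over index's shape; src is assumed at least as large in every dim.
  for (let i = 0; i < index.length; i++) {
    for (let j = 0; j < index[i].length; j++) {
      const t = index[i][j];
      const limit = dim === 0 ? out.length : out[i].length;
      if (!Number.isInteger(t) || t < 0 || t >= limit) {
        throw new RangeError(`index ${t} is out of range for dim ${dim}`);
      }
      if (dim === 0) {
        out[t][j] += src[i][j]; // dim 0: the row comes from index
      } else {
        out[i][t] += src[i][j]; // dim 1: the column comes from index
      }
    }
  }
  return out;
}
```

Because the loops run over index's shape, an index with fewer columns than src only touches the columns it covers; under PyTorch's rules, an index of shape [100, 1] against a [100, 10] src scatters only column 0.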
Returns
Tensor<ScatterShape<S, D>> - A new tensor with the reduced values
Examples
// Mean pooling by position
const values = torch.randn(10, 64);
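const pool_indices = torch.tensor([0, 1, 0, 2, 1, 0, 1, 2, 0, 1]); // Target row (0-2) for each row of values
const output = torch.zeros(3, 64);
// include_self = false so the zeros in output are not counted in the mean
const mean_pooled = torch.scatter_reduce(output, 0, pool_indices, values, 'mean', false);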
// Max pooling per group
const logits = torch.randn(100, 10);
const group_indices = torch.randint(0, 20, [100, 1]); // 100 logits to 20 groups
const max_logits = torch.scatter_reduce(
torch.full([20, 10], -Infinity), 0, group_indices, logits, 'amax'
);
// Product reduction
const factors = torch.randn(50);
const indices = torch.randint(0, 5, [50]); // 50 values to 5 positions
const products = torch.scatter_reduce(torch.ones(5), 0, indices, factors, 'prod');
See Also
- PyTorch torch.scatter_reduce()
- scatter - Simple replacement scatter
- scatter_add - Sum-specific scatter (equivalent to reduce='sum')
- scatter_add_ - In-place scatter_add
- index_reduce - Simpler variant that takes a 1-D index selecting whole slices along dim