torch.scatter_add
function scatter_add<S extends Shape, D extends number>(input: Tensor<S>, dim: D, index: Tensor, src: Tensor): Tensor<ScatterShape<S, D>>
Scatters values into a tensor by adding them at the specified indices.
Like scatter(), but instead of replacing values it accumulates (adds) src values at the positions specified by index. Useful for:
- Histogram computation: accumulating counts at specific bins
- Attention accumulation: summing attention-weighted values
- Feature aggregation: combining values from multiple sources
- Graph message passing: aggregating node/edge messages
- Reduce operations: summing values at repeated indices
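The histogram use case amounts to scatter-adding a 1 into the bin named by each sample. The sketch below shows that on plain arrays, with no tensor library involved. `histogram` is a hypothetical helper written for illustration, not part of this API; the zero-filled array plays the role of the input tensor and each `+= 1` plays the role of a src value of one.

```typescript
// Hypothetical helper (not part of the torch API): count bin ids by
// scatter-adding a 1 at each sample's bin, starting from an all-zero "input".
function histogram(binIds: number[], numBins: number): number[] {
  const counts = new Array<number>(numBins).fill(0); // like torch.zeros(numBins)
  for (const b of binIds) {
    if (!Number.isInteger(b) || b < 0 || b >= numBins) {
      throw new RangeError(`bin ${b} out of range`);
    }
    counts[b] += 1; // repeated bin ids accumulate, as in scatter_add with src = ones
  }
  return counts;
}

const counts = histogram([1, 2, 2, 3, 3, 3, 4], 5);
// counts is [0, 1, 2, 3, 1]
```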
If multiple source values map to the same index position, they are accumulated (summed).
- Accumulation: values that target the same index are summed
- Multiple contributions: many-to-one mappings are supported
- Initialization: the result starts from a copy of input, and src values are added on top (input itself is not modified)
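For the 1-D case, the three rules above (accumulation, many-to-one, initialization from input) reduce to a short loop. This is a minimal reference sketch on plain arrays, assuming PyTorch-style semantics; `scatterAdd1d` is a hypothetical name, not the library's implementation.

```typescript
// Hypothetical 1-D reference for scatter_add semantics on plain arrays.
function scatterAdd1d(input: number[], index: number[], src: number[]): number[] {
  if (index.length !== src.length) {
    throw new Error("index and src must have the same length");
  }
  const out = input.slice(); // start from a copy; input itself is left untouched
  for (let i = 0; i < index.length; i++) {
    const k = index[i];
    if (!Number.isInteger(k) || k < 0 || k >= out.length) {
      throw new RangeError(`index ${k} out of bounds`);
    }
    out[k] += src[i]; // repeated k values accumulate instead of overwriting
  }
  return out;
}

const base = [10, 20, 30];
const result = scatterAdd1d(base, [0, 2, 0], [1, 2, 3]);
// result is [14, 20, 32]; base is still [10, 20, 30]
```

Position 0 is hit twice, so it receives 10 + 1 + 3; position 1 is never indexed and keeps its input value.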
Parameters
input: Tensor<S> - The input tensor (base values that src is added to)
dim: D - The dimension along which to index
index: Tensor - The indices tensor specifying where to add src values; as in PyTorch, it must have the same number of dimensions as src
src: Tensor - The source tensor with values to add
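How `dim` interacts with `index` is easiest to see in two dimensions. The sketch below assumes PyTorch's rule: for every position (i, j) of index, dim 0 adds src[i][j] into out[index[i][j]][j], and dim 1 adds it into out[i][index[i][j]]. `scatterAdd2d` and `expandRowIndex` are hypothetical helpers on plain nested arrays, not functions of this library. The last call shows why summing whole rows needs the per-row positions repeated across every column: index must have the same rank as src.

```typescript
// Hypothetical 2-D reference for scatter_add on nested arrays (not the library's code).
type Matrix = number[][];

function scatterAdd2d(input: Matrix, dim: 0 | 1, index: Matrix, src: Matrix): Matrix {
  const out = input.map((row) => row.slice()); // copy so input is not mutated
  for (let i = 0; i < index.length; i++) {
    for (let j = 0; j < index[i].length; j++) {
      const k = index[i][j];
      const limit = dim === 0 ? out.length : out[i].length;
      if (!Number.isInteger(k) || k < 0 || k >= limit) {
        throw new RangeError(`index ${k} out of bounds for dim ${dim}`);
      }
      if (dim === 0) {
        out[k][j] += src[i][j]; // dim 0: index chooses the output row
      } else {
        out[i][k] += src[i][j]; // dim 1: index chooses the output column
      }
    }
  }
  return out;
}

// Hypothetical helper: repeat each row's target position across all columns,
// turning a per-row position list into an index shaped like src.
function expandRowIndex(rows: number[], width: number): Matrix {
  return rows.map((r) => new Array<number>(width).fill(r));
}

const zeros3x2: Matrix = [[0, 0], [0, 0], [0, 0]];
const dim0 = scatterAdd2d(zeros3x2, 0, [[0, 1], [2, 0]], [[1, 2], [3, 4]]);
// dim0 is [[1, 4], [0, 2], [3, 0]]

const zeros2x3: Matrix = [[0, 0, 0], [0, 0, 0]];
const dim1 = scatterAdd2d(zeros2x3, 1, [[0, 2, 0], [1, 1, 1]], [[1, 2, 3], [4, 5, 6]]);
// dim1 is [[4, 0, 2], [0, 15, 0]]

// Summing whole rows: rows 0 and 2 go to position 0, row 1 goes to position 1.
const rowSums = scatterAdd2d(zeros3x2, 0, expandRowIndex([0, 1, 0], 2), [[1, 1], [2, 0], [3, 5]]);
// rowSums is [[4, 6], [2, 0], [0, 0]]
```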
Returns
Tensor<ScatterShape<S, D>> - A new tensor with src values accumulated into input
Examples
// Histogram computation
const bins = torch.tensor([1, 2, 2, 3, 3, 3, 4]); // Bin of each sample; used as index (1-D, like src)
const histogram = torch.zeros(5); // Bins 0-4
const ones = torch.ones(7); // One count per sample
const counts = torch.scatter_add(histogram, 0, bins, ones); // Counts per bin: 0, 1, 2, 3, 1
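// Attention accumulation
const values = torch.randn(10, 64); // 10 values, 64 dimensions
const positions = [0, 1, 0, 2, 1, 0, 1, 2, 0, 1]; // Target position of each value
// index must have the same rank as src, so repeat each row's position across all 64 columns
const rowIndex = torch.tensor(positions.map((p) => new Array(64).fill(p)));
const result = torch.scatter_add(torch.zeros(3, 64), 0, rowIndex, values);
// result[p] is the sum of all value rows whose position is p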
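// Aggregation with weights (e.g. graph message passing)
const features = torch.randn(100, 64);
const weights = torch.randn(100, 1);
const weighted = torch.mul(features, weights); // Assumes a PyTorch-style torch.mul that broadcasts [100, 1] over [100, 64]
const nodeOf = Array.from({ length: 100 }, () => Math.floor(Math.random() * 10)); // Map 100 features to 10 nodes
const node_indices = torch.tensor(nodeOf.map((n) => new Array(64).fill(n))); // [100, 64] index, same shape as src
const aggregated = torch.scatter_add(torch.zeros(10, 64), 0, node_indices, weighted);
See Also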
- PyTorch torch.scatter_add()
- scatter - Replace instead of accumulate
- scatter_reduce - Custom reduction (sum, mean, max, etc.)
- index_add - Accumulate whole slices selected by a 1-D index (no per-element index needed)
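The weighted aggregation example computes, for each feature row e, out[node[e]] += weight[e] * feature[e]. The sketch below checks that arithmetic on small numbers with plain arrays. `weightedAggregate` is a hypothetical helper written for illustration, not part of this API; it takes a 1-D node list directly instead of building an expanded index tensor.

```typescript
// Hypothetical sketch: weighted sum of feature rows into their target nodes,
// i.e. out[nodeIndex[e]] += weights[e] * features[e] (a dim-0 scatter_add).
function weightedAggregate(
  features: number[][],
  weights: number[],
  nodeIndex: number[],
  numNodes: number,
): number[][] {
  const width = features[0]?.length ?? 0;
  const out = Array.from({ length: numNodes }, () => new Array<number>(width).fill(0));
  features.forEach((row, e) => {
    const target = nodeIndex[e];
    if (!Number.isInteger(target) || target < 0 || target >= numNodes) {
      throw new RangeError(`node ${target} out of range`);
    }
    for (let j = 0; j < width; j++) {
      out[target][j] += weights[e] * row[j]; // several rows may land on one node
    }
  });
  return out;
}

const nodeSums = weightedAggregate(
  [[1, 2], [3, 4], [5, 6]], // three feature rows
  [0.5, 1, 2],              // one weight per row
  [1, 0, 1],                // target node of each row
  2,
);
// nodeSums is [[3, 4], [10.5, 13]]
```

Node 1 receives 0.5 * [1, 2] + 2 * [5, 6], while node 0 receives only the unweighted-by-scaling row [3, 4].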