torch.Tensor.scatter
Tensor.scatter<D extends number>(dim: D, index: Tensor, src: Tensor): Tensor<DynamicShape>

Scatters values from src into this tensor at indices specified by index along a dimension.
Places src values into this tensor at positions indicated by index. This is the inverse of gather(): where gather collects elements, scatter places them. Essential for:
- Accumulating updates to selected positions (attention mechanisms)
- Reversing gather operations
- Building tensors from scattered components
- Implementing custom indexing-based operations
- One-hot encoding and similar operations
- Output is a copy: the result is a new tensor based on this tensor, with src values scattered into it at the indexed positions.
- In-place variant: use scatter_() to modify this tensor directly.
- Accumulation variant: use scatter_add() to accumulate src values instead of replacing them.
- Index constraints: index values along dim must be in range (0 to this.shape[dim] - 1); out-of-bounds indices throw an error.
- The index tensor must have the same rank as this tensor and src.
- Duplicate indices: if index contains duplicate values, later src values overwrite earlier ones.
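Since this API mirrors PyTorch's torch.Tensor.scatter (linked under See Also), the replace-vs-accumulate distinction above can be sketched in Python; this is an illustration of the shared semantics, not this library's own code:

```python
import torch

base = torch.zeros(2, 4)
index = torch.tensor([[0, 1], [2, 3]])
src = torch.tensor([[1.0, 2.0], [3.0, 4.0]])

# scatter replaces: result[i][index[i][j]] = src[i][j] along dim 1
replaced = base.scatter(1, index, src)
# -> [[1, 2, 0, 0],
#     [0, 0, 3, 4]]

# scatter_add accumulates into the existing values instead of overwriting
accumulated = torch.ones(2, 4).scatter_add(1, index, src)
# -> [[2, 3, 1, 1],
#     [1, 1, 4, 5]]
```

Note that accumulation starts from the base tensor's values, which is why every entry of the scatter_add result is at least 1.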
Parameters
dim: D - The dimension along which to scatter (supports negative indexing)
index: Tensor - Tensor of indices indicating where to place each src value. Must have the same rank as this tensor. Values specify positions along this tensor's dim.
src: Tensor - Source tensor containing values to scatter. Must have the same rank as this tensor.
Returns
Tensor<DynamicShape> – New tensor with shape matching this tensor, with src values scattered into it
Examples
// Scatter one-hot encoded values
const output = torch.zeros(3, 4); // [3, 4]
const index = torch.tensor([[1], [0], [3]]); // [3, 1] - positions to scatter
const src = torch.ones(3, 1); // [3, 1] - values to scatter
const result = output.scatter(1, index, src); // Scatter along dim 1
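The same scatter in Python with PyTorch (whose semantics this method mirrors) makes the resulting layout concrete; this is an equivalent sketch, not this library's code:

```python
import torch

output = torch.zeros(3, 4)
index = torch.tensor([[1], [0], [3]])  # one target column per row
src = torch.ones(3, 1)

result = output.scatter(1, index, src)
# -> [[0, 1, 0, 0],
#     [1, 0, 0, 0],
#     [0, 0, 0, 1]]
```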
// Scatter class indices back to class dimensions
const batch_size = 32;
const num_classes = 10;
const class_indices = torch.randint(0, num_classes, [batch_size]); // [32]
const ones = torch.ones(batch_size); // [32]
const one_hot_base = torch.zeros(batch_size, num_classes);
const class_indices_2d = class_indices.unsqueeze(1); // [32, 1]
const ones_2d = ones.unsqueeze(1); // [32, 1]
const one_hot = one_hot_base.scatter(1, class_indices_2d, ones_2d); // [32, 10]
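In PyTorch, the same scatter-based one-hot construction can be checked against the built-in encoder torch.nn.functional.one_hot; this Python sketch illustrates the equivalent semantics, not this library's code:

```python
import torch
import torch.nn.functional as F

batch_size, num_classes = 32, 10
class_indices = torch.randint(0, num_classes, (batch_size,))

one_hot = torch.zeros(batch_size, num_classes).scatter(
    1, class_indices.unsqueeze(1), torch.ones(batch_size, 1)
)

# Agrees with the built-in one-hot encoder
assert torch.equal(one_hot, F.one_hot(class_indices, num_classes).float())
```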
// Inverse of gather
const original = torch.randn(4, 5);
const index = torch.tensor([[0, 2, 1, 3]]);
const gathered = original.gather(1, index);
const scattered_back = torch.zeros(4, 5).scatter(1, index, gathered);
// scattered_back matches original only at the gathered positions (row 0,
// columns 0-3) and is zero elsewhere; scatter restores exactly the values
// that gather collected, not the full tensor
See Also
- PyTorch tensor.scatter()
- gather - Opposite operation: gather values using indices
- scatter_add - Similar but accumulates values instead of replacing
- scatter_ - In-place variant
- scatter_add_ - In-place accumulation variant