torch.scatter
function scatter<S extends Shape, D extends number>(input: Tensor<S>, dim: D, index: Tensor, src: Tensor): Tensor<ScatterShape<S, D>>
Scatters values into a tensor at specified indices along a dimension.
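The reverse of gather(): returns a copy of input in which values from src have been written at the positions given by index. Each element of index says where along dim the src element at the same position goes, and all other coordinates stay the same. For a 2-D input, dim = 0 gives output[index[i][j]][j] = src[i][j], and dim = 1 gives output[i][index[i][j]] = src[i][j]. Positions that index never points to keep their values from input. Useful for: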
- One-hot encoding: writing 1s at class-index positions to build one-hot vectors
- Attention output: scattering attention values to positions
- Sparse updates: updating specific positions with new values
- Conditional assignment: setting values at computed positions
- Graph operations: updating node/edge features
Notes
- In-place vs. copy: scatter returns a new tensor and leaves input unchanged; use scatter_ to modify input in place.
- Index rank: index must have the same rank (number of dimensions) as input.
- Output shape: same as the shape of input.
- Values from src: src must be broadcastable to the shape of index.
- Index bounds: every value in index must lie between 0 and input.shape[dim] - 1, inclusive.
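The constraints above can be checked before calling scatter. The following is a minimal sketch of such a check, assuming plain shape arrays and a flat list of index values as inputs; checkScatterArgs is a hypothetical helper, not part of this library. It applies the rules as this page states them (including broadcasting src to the shape of index) and, for simplicity, rejects negative dim values.

```typescript
// Hypothetical helper (not part of the library): validates scatter arguments
// against the rules listed in the notes, using plain shape arrays.
type Shape = number[];

function checkScatterArgs(
  inputShape: Shape,
  dim: number,
  indexShape: Shape,
  indexValues: number[],
  srcShape: Shape,
): void {
  const rank = inputShape.length;
  // This sketch only accepts non-negative dim values.
  if (!Number.isInteger(dim) || dim < 0 || dim >= rank) {
    throw new RangeError(`dim ${dim} is out of range for rank ${rank}`);
  }
  // Rank match: index must have as many dimensions as input.
  if (indexShape.length !== rank) {
    throw new Error(`index rank ${indexShape.length} != input rank ${rank}`);
  }
  // Index bounds: every index value must be an integer in [0, inputShape[dim] - 1].
  for (const v of indexValues) {
    if (!Number.isInteger(v) || v < 0 || v >= inputShape[dim]) {
      throw new RangeError(`index value ${v} is out of bounds for size ${inputShape[dim]}`);
    }
  }
  // Broadcasting: align src with index from the right; each src size must be 1 or equal.
  if (srcShape.length > indexShape.length) {
    throw new Error("src has more dimensions than index");
  }
  const offset = indexShape.length - srcShape.length;
  srcShape.forEach((size, i) => {
    if (size !== 1 && size !== indexShape[offset + i]) {
      throw new Error(`src shape [${srcShape}] does not broadcast to [${indexShape}]`);
    }
  });
}
```

For example, input [100, 64] with an index of shape [3, 1] and src of shape [3, 64] fails the last check, since [3, 64] cannot broadcast to [3, 1].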
Parameters
- input: Tensor<S> - The input tensor (acts as base/template)
- dim: D - The dimension along which to scatter
- index: Tensor - The indices tensor specifying where to place src values
- src: Tensor - The source tensor with values to scatter
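How the four parameters interact can be seen in a minimal plain-array sketch of the 2-D case. It is not this library's implementation: scatter2d and the Matrix type are hypothetical, and the sketch assumes index and src have the same shape (no broadcasting) and that every index value is in bounds.

```typescript
// Hypothetical plain-array sketch of 2-D scatter semantics (not the library code).
// Assumes index and src have identical shapes and all index values are valid.
type Matrix = number[][];

function scatter2d(input: Matrix, dim: 0 | 1, index: Matrix, src: Matrix): Matrix {
  // Copy input so the caller's data is left untouched (scatter is not in-place).
  const out = input.map((row) => row.slice());
  for (let i = 0; i < index.length; i++) {
    for (let j = 0; j < index[i].length; j++) {
      const k = index[i][j];
      if (dim === 0) {
        out[k][j] = src[i][j]; // dim 0: the index value replaces the row coordinate
      } else {
        out[i][k] = src[i][j]; // dim 1: the index value replaces the column coordinate
      }
    }
  }
  return out;
}

// Usage: a smaller 2 x 6 version of the "Update specific positions" example.
const base: Matrix = [
  [0, 0, 0, 0, 0, 0],
  [0, 0, 0, 0, 0, 0],
];
const result = scatter2d(base, 1, [[0, 2, 4], [1, 3, 5]], [[1, 1, 1], [1, 1, 1]]);
// result is [[1, 0, 1, 0, 1, 0], [0, 1, 0, 1, 0, 1]]; base is unchanged
```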
Returns
Tensor<ScatterShape<S, D>> - A new tensor with src values scattered into input at positions specified by index
Examples
```typescript
// One-hot encoding with scatter
const batch_size = 32;
const num_classes = 10;
const indices = torch.randint(0, num_classes, [batch_size, 1]); // Class indices, shape [32, 1]
const one_hot = torch.zeros([batch_size, num_classes]);
// scatter returns a new tensor; one_hot itself stays all zeros
const encoded = torch.scatter(one_hot, 1, indices, torch.ones([batch_size, 1])); // [32, 10]

// Update specific positions
const data = torch.zeros(10, 10);
const indices_t = torch.tensor([[0, 2, 4], [1, 3, 5]]);
const values = torch.ones(2, 3);
// Row 0 gets 1s in columns 0, 2, 4; row 1 gets 1s in columns 1, 3, 5
const result = torch.scatter(data, 1, indices_t, values);

// Graph node updates
const nodes = torch.randn(100, 64); // 100 nodes, 64 features
// Which nodes to update: each index row repeats the node id across all 64
// feature columns so whole rows are replaced (a [3, 1] index would not cover them)
const update_indices = torch.tensor([10, 20, 30].map((node) => new Array<number>(64).fill(node))); // [3, 64]
const new_values = torch.randn(3, 64); // New features
const updated = torch.scatter(nodes, 0, update_indices, new_values);
```
See Also
- PyTorch torch.scatter()
- gather - Inverse operation (gather from input)
- scatter_add - Scatter with addition instead of replacement
- scatter_reduce - Scatter with custom reduction operation
- index_copy - Simpler 1D scattering
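The difference between scatter and scatter_add only shows up when index contains repeated positions. The following minimal 1-D sketch on plain arrays makes the contrast concrete; scatter1d and scatterAdd1d are hypothetical helpers, not this library's API. For the replacing variant the sketch simply lets the last write win. In PyTorch, which duplicate wins for scatter is not guaranteed, and this page does not say how this library resolves duplicates.

```typescript
// Hypothetical 1-D helpers (not the library API) contrasting replacement and
// addition when index contains duplicate positions.
function scatter1d(input: number[], index: number[], src: number[]): number[] {
  const out = input.slice();
  index.forEach((k, i) => {
    out[k] = src[i]; // replace: in this sketch a later duplicate overwrites an earlier one
  });
  return out;
}

function scatterAdd1d(input: number[], index: number[], src: number[]): number[] {
  const out = input.slice();
  index.forEach((k, i) => {
    out[k] += src[i]; // add: every duplicate contributes to the same slot
  });
  return out;
}

const zeros = [0, 0, 0, 0];
const idx = [1, 3, 1]; // position 1 appears twice
const vals = [10, 20, 30];
const replaced = scatter1d(zeros, idx, vals); // [0, 30, 0, 20]
const summed = scatterAdd1d(zeros, idx, vals); // [0, 40, 0, 20]
```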