torch.select_scatter
function select_scatter<S extends Shape, Dt extends DType, Dev extends DeviceType>(input: Tensor<S, Dt, Dev>, src: Tensor, dim: number, index: number): Tensor<S, Dt, Dev>
Scatters a tensor at a single index position along a dimension.
Inverse of select(): embeds a lower-dimensional tensor at a specific index along a dimension of input. The source tensor must have exactly one fewer dimension than input. Useful for:
- Element replacement: replacing a single element or slice of a tensor
- Index assignment: tensor[index] = value style operations
- Batch item updates: updating one item in a batch
- Structural updates: updating one row, column, or other slice
Notes:
- Dimension reduction: src has one fewer dimension than input
- Inverse of select: the opposite operation to select()
- Single position: src is embedded at exactly one index
Constraints:
- Shape checking: src.shape must equal input.shape with dim removed
- Rank mismatch: src must have rank input.rank - 1
- Index bounds: index must be in the range [0, input.shape[dim])
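The shape, rank, and bounds rules above can be sketched as a standalone check on plain shape arrays. This is only an illustration: expectedSrcShape and checkSrcShape are hypothetical helpers, not functions of the library, and the sketch assumes a non-negative dim (the source does not say whether negative dimensions are accepted).

```typescript
// Hypothetical helper, not part of the library: computes the shape src must have.
function expectedSrcShape(inputShape: number[], dim: number, index: number): number[] {
  // Assumption: dim is a non-negative axis number below the input rank.
  if (!Number.isInteger(dim) || dim < 0 || dim >= inputShape.length) {
    throw new RangeError(`dim ${dim} is out of range for rank ${inputShape.length}`);
  }
  const size = inputShape[dim] as number;
  // Index bounds: index must lie in [0, input.shape[dim]).
  if (!Number.isInteger(index) || index < 0 || index >= size) {
    throw new RangeError(`index ${index} is out of range [0, ${size})`);
  }
  // Remove the scattered dimension, e.g. [32, 64, 64] with dim 0 gives [64, 64].
  return inputShape.filter((_, d) => d !== dim);
}

// Hypothetical helper: throws if srcShape is not inputShape with dim removed.
function checkSrcShape(inputShape: number[], srcShape: number[], dim: number, index: number): void {
  const expected = expectedSrcShape(inputShape, dim, index);
  const matches =
    expected.length === srcShape.length && expected.every((n, i) => n === srcShape[i]);
  if (!matches) {
    throw new Error(
      `src shape [${srcShape.join(", ")}] does not match expected [${expected.join(", ")}]`
    );
  }
}
```

For a batch of shape [32, 64, 64] and dim 0, the check accepts only a src of shape [64, 64]; for a [10, 5] matrix and dim 1, only a src of shape [10].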
Parameters
input: Tensor<S, Dt, Dev> - The input tensor
src: Tensor - Source tensor to embed (rank = input.rank - 1)
dim: number - The dimension along which to scatter
index: number - The single index position where to embed src
Returns
Tensor<S, Dt, Dev> - A new tensor with src embedded at position index along dim
Examples
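To make the return value concrete, here is a minimal sketch of the same behaviour on plain nested arrays, restricted to the 2-D case. select2d and selectScatter2d are hypothetical stand-ins written for illustration, not the library's implementation, and they skip the shape and bounds checks for brevity. The sketch shows that the result is a copy and that selecting at the same dim and index recovers src.

```typescript
type Matrix = number[][];

// Hypothetical plain-array analogue of select() for a 2-D input.
function select2d(input: Matrix, dim: 0 | 1, index: number): number[] {
  return dim === 0 ? [...input[index]] : input.map((row) => row[index]);
}

// Hypothetical plain-array analogue of select_scatter() for a 2-D input.
function selectScatter2d(input: Matrix, src: number[], dim: 0 | 1, index: number): Matrix {
  // Copy every row so the caller's input is left untouched.
  const out = input.map((row) => [...row]);
  if (dim === 0) {
    // Replace one whole row with a copy of src.
    out[index] = [...src];
  } else {
    // Replace one column: row r receives src[r].
    out.forEach((row, r) => {
      row[index] = src[r];
    });
  }
  return out;
}

const data: Matrix = [[1, 2, 3], [4, 5, 6]];
const zeroed = selectScatter2d(data, [0, 0], 1, 2);
// zeroed is [[1, 2, 0], [4, 5, 0]]; data still holds [[1, 2, 3], [4, 5, 6]]
const column = select2d(zeroed, 1, 2);
// column is [0, 0]: selecting at the same dim and index recovers src
```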
// Replace single row
const matrix = torch.randn(5, 3);
const new_row = torch.ones(3);
const updated = torch.select_scatter(matrix, new_row, 0, 1);
// matrix with row 1 replaced by all ones
// Update single batch element
const batch = torch.randn(32, 64, 64); // 32 images
const new_image = torch.randn(64, 64);
const updated_batch = torch.select_scatter(batch, new_image, 0, 5);
// batch with element 5 replaced
// Column replacement
const data = torch.randn(10, 5);
const new_column = torch.zeros(10);
const zeroed_col = torch.select_scatter(data, new_column, 1, 2);
// Column 2 set to zero
See Also
- PyTorch torch.select_scatter()
- select - Extract single index (inverse operation)
- slice_scatter - Scatter at slice range
- diagonal_scatter - Scatter along diagonal
- index_copy - Copy at 1D indices