torch.diagonal_scatter
function diagonal_scatter<S extends Shape, Dt extends DType, Dev extends DeviceType>(input: Tensor<S, Dt, Dev>, src: Tensor, options?: DiagonalScatterOptions): Tensor<S, Dt, Dev>
function diagonal_scatter<S extends Shape, Dt extends DType, Dev extends DeviceType>(input: Tensor<S, Dt, Dev>, src: Tensor, offset: number, dim1: number, dim2: number): Tensor<S, Dt, Dev>
Scatters values along a diagonal in the tensor.
Embeds a 1D source tensor into the main diagonal (or an offset diagonal) of input and returns the result as a new tensor. The diagonal is formed by the two dimensions dim1 and dim2. Useful for:
- Matrix manipulation: setting diagonal elements
- Identity matrices: creating identity patterns
- Diagonal pattern creation: building diagonal tensors
- Higher-D generalization: diagonals in 3D+ tensors
- Matrix conditioning: modifying diagonal elements
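To make the semantics concrete, here is a minimal sketch on plain `number[][]` arrays. The helper `diagonalScatter2D` is hypothetical and not part of the library. It assumes a 2D input with dim1 = 0 and dim2 = 1, and it returns a copy instead of modifying `input`.

```typescript
// Hypothetical helper (not a library API): diagonal_scatter semantics for a
// 2D array, assuming dim1 = 0 (rows) and dim2 = 1 (columns).
function diagonalScatter2D(input: number[][], src: number[], offset = 0): number[][] {
  const rows = input.length;
  const cols = rows > 0 ? input[0].length : 0;
  // offset > 0 starts the diagonal to the right, offset < 0 starts it lower down.
  const rowStart = offset < 0 ? -offset : 0;
  const colStart = offset > 0 ? offset : 0;
  const diagLen = Math.max(0, Math.min(rows - rowStart, cols - colStart));
  if (src.length !== diagLen) {
    throw new Error(`src length ${src.length} does not match diagonal length ${diagLen}`);
  }
  // Copy every row so the caller's input stays unchanged.
  const out = input.map((row) => row.slice());
  for (let i = 0; i < diagLen; i++) {
    out[rowStart + i][colStart + i] = src[i];
  }
  return out;
}

// Builds an n x n matrix of zeros.
const zeros = (n: number): number[][] =>
  Array.from({ length: n }, () => new Array<number>(n).fill(0));

console.log(JSON.stringify(diagonalScatter2D(zeros(3), [1, 2, 3])));
// [[1,0,0],[0,2,0],[0,0,3]]
console.log(JSON.stringify(diagonalScatter2D(zeros(3), [10, 20], 1)));
// [[0,10,0],[0,0,20],[0,0,0]]
```

Copying the rows before writing mirrors the out-of-place behavior of diagonal_scatter: the result is a fresh tensor and the original is left intact.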
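Behavior and constraints:
- 1D source: src must be a 1D tensor.
- Size validation: the length of src must equal the length of the selected diagonal.
- Offset: offset = 0 selects the main diagonal, offset > 0 selects a diagonal above it, and offset < 0 selects one below it. For dimension sizes n1 (dim1) and n2 (dim2), the diagonal length is min(n1, n2 - offset) when offset >= 0 and min(n1 + offset, n2) when offset < 0. The offset must be within the valid range for the tensor.
- Dimension flexibility: any two distinct dimensions can form the diagonal.
- Output shape: the result always has the same shape as input.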
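The offset and size rules above reduce to a single length formula. The sketch below computes it for an arbitrary shape. `diagonalLength` and `checkSrcLength` are hypothetical helpers that assume PyTorch-style conventions: negative dims count from the end, and dim1 and dim2 must be distinct. One further assumption: an offset past the edge of the tensor yields length 0 here, which is what `torch.diagonal` does in PyTorch. This library may instead reject such an offset.

```typescript
// Hypothetical helper: length of the diagonal over dims (dim1, dim2) of `shape`,
// assuming PyTorch-style semantics. offset > 0 shifts toward dim2, offset < 0 toward dim1.
function diagonalLength(shape: number[], offset: number, dim1: number, dim2: number): number {
  const nd = shape.length;
  // Normalize negative dimension indices (e.g. -1 is the last dimension).
  const d1 = dim1 < 0 ? dim1 + nd : dim1;
  const d2 = dim2 < 0 ? dim2 + nd : dim2;
  if (d1 < 0 || d1 >= nd || d2 < 0 || d2 >= nd || d1 === d2) {
    throw new Error(`invalid dims (${dim1}, ${dim2}) for a ${nd}-D shape`);
  }
  const n1 = shape[d1];
  const n2 = shape[d2];
  const len = offset >= 0 ? Math.min(n1, n2 - offset) : Math.min(n1 + offset, n2);
  // An out-of-range offset gives an empty diagonal in this sketch (assumption).
  return Math.max(0, len);
}

// Hypothetical check mirroring the "Size validation" rule; defaults follow PyTorch.
function checkSrcLength(shape: number[], srcLength: number, offset = 0, dim1 = 0, dim2 = 1): void {
  const expected = diagonalLength(shape, offset, dim1, dim2);
  if (srcLength !== expected) {
    throw new Error(`src length ${srcLength} != diagonal length ${expected}`);
  }
}

console.log(diagonalLength([3, 3], 1, 0, 1));  // 2
console.log(diagonalLength([4, 4], -1, 0, 1)); // 3
console.log(diagonalLength([2, 5], 2, 0, 1));  // 2
console.log(diagonalLength([3, 3], 5, 0, 1));  // 0
```

The `[2, 5]` case shows why the formula needs both sizes. On a non-square matrix, the diagonal can be limited by either dimension, depending on the offset.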
Parameters
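input: Tensor<S, Dt, Dev> - The input tensor; it serves as the template for the output shape.
src: Tensor - 1D tensor of values to scatter along the diagonal.
options?: DiagonalScatterOptions - Optional settings (first overload).
offset, dim1, dim2: number - Diagonal offset and the two dimensions that form the diagonal (second overload).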
Returns
Tensor<S, Dt, Dev> - A new tensor with the same shape as input, with src scattered along the specified diagonal.
Examples
// Set main diagonal
const base = torch.zeros(3, 3);
const diag_vals = torch.tensor([1, 2, 3]);
const result = torch.diagonal_scatter(base, diag_vals);
// [[1, 0, 0], [0, 2, 0], [0, 0, 3]]
// Upper diagonal (offset=1)
const upper = torch.zeros(3, 3);
const upper_vals = torch.tensor([10, 20]);
const upper_result = torch.diagonal_scatter(upper, upper_vals, 1, 0, 1);
// [[0, 10, 0], [0, 0, 20], [0, 0, 0]]
// Lower diagonal (offset=-1)
const lower = torch.zeros(4, 4);
const lower_vals = torch.tensor([5, 5, 5]);
const lower_result = torch.diagonal_scatter(lower, lower_vals, -1, 0, 1);
// [[0, 0, 0, 0], [5, 0, 0, 0], [0, 5, 0, 0], [0, 0, 5, 0]]
// 3D tensor diagonal
const cube = torch.zeros(3, 3, 3);
const cube_vals = torch.tensor([1, 2, 3]);
const cube_result = torch.diagonal_scatter(cube, cube_vals, 0, 0, 1);
See Also
- PyTorch torch.diagonal_scatter()
- select_scatter - Scatter at single position
- slice_scatter - Scatter at slice range
- scatter - More flexible scattering