torch.diagflat
function diagflat(input: Tensor, options?: DiagFlatOptions): Tensor
function diagflat(input: Tensor, offset: number, options?: DiagFlatOptions): Tensor
Creates a 2D tensor with the flattened input on the diagonal.
Flattens the input to 1D and places it on the diagonal of a 2D output matrix, with zeros filling all off-diagonal positions. Useful for:
- Converting arbitrary-shaped tensors to diagonal matrices
- Creating block structures for independent operations
- Preparing data for matrix multiplication with diagonal matrices
- Scaling operations (diagonal matrix multiplication)
- Constructing matrices with a sparse (mostly zero) pattern
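The scaling use cases above rely on a standard identity: left-multiplying by a diagonal matrix scales the rows of the other operand, and right-multiplying scales its columns. The sketch below checks this on plain `number[][]` arrays. `diagMatrix`, `matmulDense`, `scaleRows`, and `scaleCols` are hypothetical helpers written for illustration; they are not part of this library's API.

```typescript
// Illustration on plain arrays, not the library's Tensor API.
type Matrix = number[][];

// Dense n x n matrix with `values` on the main diagonal and zeros elsewhere.
function diagMatrix(values: number[]): Matrix {
  return values.map((v, i) => values.map((_, j) => (i === j ? v : 0)));
}

// Naive dense product of (rows x inner) and (inner x cols) matrices.
function matmulDense(a: Matrix, b: Matrix): Matrix {
  const inner = b.length;
  const cols = b[0].length;
  return a.map((row) => {
    const out: number[] = [];
    for (let j = 0; j < cols; j++) {
      let sum = 0;
      for (let k = 0; k < inner; k++) sum += row[k] * b[k][j];
      out.push(sum);
    }
    return out;
  });
}

// Same result as diag(scales) * data, without building the n x n matrix.
function scaleRows(scales: number[], data: Matrix): Matrix {
  return data.map((row, i) => row.map((x) => x * scales[i]));
}

// Same result as data * diag(scales), without building the n x n matrix.
function scaleCols(data: Matrix, scales: number[]): Matrix {
  return data.map((row) => row.map((x, j) => x * scales[j]));
}

const scales = [2, 3, 4];
const data: Matrix = [
  [1, 2],
  [3, 4],
  [5, 6],
];
console.log(JSON.stringify(matmulDense(diagMatrix(scales), data))); // [[2,4],[9,12],[20,24]]
console.log(JSON.stringify(scaleRows(scales, data)));                // [[2,4],[9,12],[20,24]]
```

When only the scaled result is needed, scaling rows or columns directly costs O(rows × cols) instead of allocating the quadratic-size diagonal matrix.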
Notes:
- Flattening: The input is always flattened to 1D (row-major order) before being placed on the diagonal
- Square output: Always returns a square matrix of shape [n+|offset|, n+|offset|], where n = numel(input)
- Offset support: A positive offset places the values above the main diagonal, a negative offset below it
- Zero-filled: All non-diagonal positions are zero
- Memory and size: The output has (n+|offset|)² elements, so its size grows quadratically (O(n²)) with the number of input elements; large inputs produce very large matrices
- Large offsets: A very large |offset| enlarges the output to (n+|offset|)² elements, almost all of them zero
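- Different from diag(): diag() does not flatten (on a 2D input it extracts the diagonal instead), while diagflat() always flattens first; for a 1D input both build the same matrix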
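The notes above amount to "flatten, then build a diagonal matrix shifted by the offset". The following reference sketch reproduces those semantics on plain nested arrays so the shape and offset rules can be checked by hand. `NestedArray`, `flattenRowMajor`, `zeroMatrix`, and `diagflatRef` are hypothetical names for illustration only; this is not the library's implementation.

```typescript
// Plain-array model of a tensor of any rank (the real API uses Tensor objects).
type NestedArray = number | NestedArray[];

// Row-major flatten, matching the element order shown in the examples.
function flattenRowMajor(input: NestedArray, out: number[] = []): number[] {
  if (typeof input === "number") {
    out.push(input);
  } else {
    for (const item of input) flattenRowMajor(item, out);
  }
  return out;
}

// size x size matrix filled with zeros.
function zeroMatrix(size: number): number[][] {
  const m: number[][] = [];
  for (let r = 0; r < size; r++) {
    const row: number[] = [];
    for (let c = 0; c < size; c++) row.push(0);
    m.push(row);
  }
  return m;
}

function diagflatRef(input: NestedArray, offset = 0): number[][] {
  const values = flattenRowMajor(input);
  const out = zeroMatrix(values.length + Math.abs(offset)); // side n + |offset|
  values.forEach((v, i) => {
    // offset >= 0: element i goes to (i, i + offset), above the main diagonal
    // offset < 0:  element i goes to (i - offset, i), below the main diagonal
    const row = offset >= 0 ? i : i - offset;
    const col = offset >= 0 ? i + offset : i;
    out[row][col] = v;
  });
  return out;
}

console.log(JSON.stringify(diagflatRef([1, 2, 3], -1)));
// [[0,0,0,0],[1,0,0,0],[0,2,0,0],[0,0,3,0]]
```

Called as `diagflatRef([[1, 2], [3, 4]])` it yields the same 4x4 matrix as the multi-dimensional example, and `diagflatRef([1, 2, 3], 1)` matches the upper-diagonal example.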
Parameters
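input: Tensor - Input tensor of any shape; it is flattened to 1D before use
offset: number (optional) - Which diagonal to fill: 0 is the main diagonal (used when offset is omitted), positive values select diagonals above it, negative values below it
options: DiagFlatOptions (optional)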
Returns
Tensor - 2D square tensor with the flattened input on the diagonal and zeros elsewhere. Shape is [n+|offset|, n+|offset|], where n = numel(input).
Examples
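The shape formula in the Returns description makes the quadratic memory cost easy to estimate before calling diagflat. The hypothetical helper below computes the output shape, element count, and dense byte size; `diagflatFootprint` is illustrative only, and the default of 4 bytes per element (float32) is an assumption, since the library's default dtype may differ.

```typescript
// Hypothetical helper: output shape and dense size of diagflat for a given input shape.
// bytesPerElement = 4 assumes float32; adjust for the dtype actually in use.
function diagflatFootprint(
  inputShape: number[],
  offset = 0,
  bytesPerElement = 4,
): { shape: [number, number]; elements: number; bytes: number } {
  // n = numel(input): product of all dimensions (1 for a scalar shape [])
  const n = inputShape.reduce((acc, d) => acc * d, 1);
  const side = n + Math.abs(offset);
  const elements = side * side;
  return { shape: [side, side], elements, bytes: elements * bytesPerElement };
}

console.log(JSON.stringify(diagflatFootprint([2, 2])));
// {"shape":[4,4],"elements":16,"bytes":64}
console.log(JSON.stringify(diagflatFootprint([1000, 1000])));
// {"shape":[1000000,1000000],"elements":1000000000000,"bytes":4000000000000}
```

A 1000x1000 input has only one million elements, yet its diagflat output would need about 4 TB at 4 bytes per element.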
// Simple 1D input
const x = torch.tensor([1, 2, 3]);
torch.diagflat(x);
// [[1, 0, 0],
// [0, 2, 0],
// [0, 0, 3]]
// Multi-dimensional input (gets flattened)
const x = torch.tensor([[1, 2], [3, 4]]);
torch.diagflat(x);
// [[1, 0, 0, 0],
// [0, 2, 0, 0],
// [0, 0, 3, 0],
// [0, 0, 0, 4]]
// With offset for upper diagonal
const x = torch.tensor([1, 2, 3]);
torch.diagflat(x, 1);
// [[0, 1, 0, 0],
// [0, 0, 2, 0],
// [0, 0, 0, 3],
// [0, 0, 0, 0]]
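// Scaling vector for matrix multiplication (row i of data is scaled by scales[i])
const scales = torch.tensor([2, 3, 4]);
const data = torch.tensor([[1, 2], [3, 4], [5, 6]]);
const scaled_data = torch.matmul(torch.diagflat(scales), data);
// [[2, 4],
// [9, 12],
// [20, 24]]
See Also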
- PyTorch torch.diagflat()
- diag - Create diagonal matrix without flattening (for 1D) or extract diagonal (for 2D)
- block_diag - Create block-diagonal from multiple tensors
- flatten - Flatten tensor to 1D
- matmul - Matrix multiplication (use with diagonal matrices for scaling)