torch.diag_embed
function diag_embed(input: Tensor, options?: DiagEmbedOptions): Tensor
function diag_embed(input: Tensor, offset: number, dim1: number, dim2: number, options?: DiagEmbedOptions): Tensor
Creates a tensor whose diagonals of certain 2D planes are filled by input.
Takes the last dimension of the input and places its elements along a diagonal of a new 2D plane, adding one extra dimension to the output. The 2D planes are spanned by the output dimensions dim1 and dim2. For a 1D input the result matches diag(); for batched input, one such matrix is built for every vector along the last dimension.
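The indexing rule behind this can be sketched on plain arrays. The helpers below (diagEmbed1d and diagEmbedBatched) are hypothetical and not part of the library; they assume PyTorch's convention that element i of the last dimension lands at row i, column i + offset when offset >= 0, and at row i - offset, column i when offset < 0. They only cover the default plane (the last two output dimensions).

```typescript
// Hypothetical helper, not part of the library: embeds a 1-D array on the
// `offset`-th diagonal of a square (n + |offset|) x (n + |offset|) matrix.
function diagEmbed1d(v: number[], offset = 0): number[][] {
  const size = v.length + Math.abs(offset);
  // Start from an all-zero size x size matrix.
  const out: number[][] = Array.from({ length: size }, () =>
    new Array<number>(size).fill(0),
  );
  for (let i = 0; i < v.length; i++) {
    // offset >= 0: shift right (above the main diagonal);
    // offset < 0: shift down (below the main diagonal).
    const row = offset >= 0 ? i : i - offset;
    const col = offset >= 0 ? i + offset : i;
    out[row][col] = v[i];
  }
  return out;
}

// Hypothetical batched variant for a 2-D input of shape [B, n]: each row
// becomes its own matrix, giving shape [B, n + |offset|, n + |offset|].
function diagEmbedBatched(batch: number[][], offset = 0): number[][][] {
  return batch.map((row) => diagEmbed1d(row, offset));
}

console.log(diagEmbed1d([1, 2], -1));
console.log(diagEmbedBatched([[1, 2], [3, 4]]));
```

With offset = -1 the vector [1, 2] lands just below the main diagonal of a 3 x 3 matrix, the mirror image of the offset = 1 example.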
Commonly used for:
- Batched diagonal matrix operations
- Creating diagonal components of covariance matrices
- Batched diagonal transformations in neural networks
- Efficient representation of diagonal linear operators
- Transforming vectors to diagonal matrices in batches
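Uses such as diagonal linear operators rest on a simple identity: multiplying a vector by a diagonal matrix gives the same result as multiplying it elementwise by the diagonal's entries. A minimal plain-TypeScript check follows; diagMatrix and matVec are hypothetical helpers, not library functions.

```typescript
// Hypothetical helper: dense n x n matrix with v on the main diagonal.
function diagMatrix(v: number[]): number[][] {
  return v.map((vi, i) => v.map((_, j) => (i === j ? vi : 0)));
}

// Hypothetical helper: dense matrix-vector product m * x.
function matVec(m: number[][], x: number[]): number[] {
  return m.map((row) => row.reduce((acc, mij, j) => acc + mij * x[j], 0));
}

const v = [2, 3, 4];
const x = [1, -1, 0.5];
// O(n^2) work, mostly multiplying by zeros.
const dense = matVec(diagMatrix(v), x);
// O(n) work, same values.
const elementwise = v.map((vi, i) => vi * x[i]);
console.log(dense, elementwise);
```

Keeping only the vector is the cheaper representation; materializing the full matrix (as diag_embed does) pays off when a dense matrix is actually needed, for example to add it to another matrix.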
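Notes:
- Output has one more dimension than input
- If dim1 and dim2 are not specified, they default to the last two dimensions of the output (dim1 = -2, dim2 = -1)
- Each side of the 2D plane has size input.shape[-1] + |offset|
- When offset ≠ 0, the plane is still square; it grows by |offset| on each side so the shifted diagonal fits
- diag_embed is the batched counterpart of torch.diag() for 1D inputs; torch.diagonal() with the same offset, dim1 and dim2 is its inverse and recovers input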
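The inverse relationship can be sketched for a single matrix. diagonalOf is a hypothetical helper (not torch.diagonal itself) that reads the offset-th diagonal using the same indexing rule as diag_embed; applied to the 3 x 3 output of the offset = 1 example, it returns the original vector.

```typescript
// Hypothetical helper, not part of the library: reads the `offset`-th
// diagonal of a square matrix, undoing a diag_embed-style embedding.
// offset >= 0 reads m[i][i + offset]; offset < 0 reads m[i - offset][i].
function diagonalOf(m: number[][], offset = 0): number[] {
  const n = m.length - Math.abs(offset); // length of the embedded vector
  const out: number[] = [];
  for (let i = 0; i < n; i++) {
    out.push(offset >= 0 ? m[i][i + offset] : m[i - offset][i]);
  }
  return out;
}

// Output of diag_embed([1, 2]) with offset = 1: square, 2 + |1| = 3 per side.
const embedded = [
  [0, 1, 0],
  [0, 0, 2],
  [0, 0, 0],
];
console.log(diagonalOf(embedded, 1)); // the original vector [1, 2]
```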
Parameters
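input (Tensor) - Input tensor of any shape (at least 1D). Elements of its last dimension are placed on the diagonal.
offset (number) - Which diagonal to fill: 0 is the main diagonal, positive values are above it, negative values below it. Defaults to 0 when omitted.
dim1 (number) - First output dimension of the 2D plane. Defaults to -2 when omitted.
dim2 (number) - Second output dimension of the 2D plane; must differ from dim1. Defaults to -1 when omitted.
options (DiagEmbedOptions, optional)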
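How offset, dim1 and dim2 determine the output shape can be sketched as follows. diagEmbedShape is a hypothetical helper, not part of the library, and it assumes PyTorch's convention: dim1 and dim2 index the output (whose rank is the input rank + 1, so negative values count from the end of the output), both get size input.shape[-1] + |offset|, and the remaining output dimensions take input.shape[:-1] in order.

```typescript
// Hypothetical helper, not part of the library: computes the output shape
// of diag_embed under PyTorch's dim1/dim2 convention.
function diagEmbedShape(
  inputShape: number[],
  offset = 0,
  dim1 = -2,
  dim2 = -1,
): number[] {
  if (inputShape.length < 1) throw new Error("input must be at least 1-D");
  const rank = inputShape.length + 1; // output gains one dimension
  // Normalize a possibly negative dim against the OUTPUT rank.
  const norm = (d: number): number => {
    if (d < -rank || d >= rank) {
      throw new RangeError(`dim ${d} out of range for rank ${rank}`);
    }
    return d < 0 ? d + rank : d;
  };
  const d1 = norm(dim1);
  const d2 = norm(dim2);
  if (d1 === d2) throw new Error("dim1 and dim2 must differ");
  const size = inputShape[inputShape.length - 1] + Math.abs(offset);
  const batch = inputShape.slice(0, -1); // leading (batch) dimensions
  const out: number[] = [];
  let b = 0;
  for (let i = 0; i < rank; i++) {
    // Plane dimensions get `size`; all others consume batch dims in order.
    out.push(i === d1 || i === d2 ? size : batch[b++]);
  }
  return out;
}

console.log(diagEmbedShape([2, 2, 2], 0, 0, 1));
console.log(diagEmbedShape([2, 3], 1, 0, 2));
```

For the [2, 2, 2] example with dim1 = 0 and dim2 = 1 this gives [2, 2, 2, 2]; for a [2, 3] input with offset = 1, dim1 = 0 and dim2 = 2 it gives [4, 2, 4].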
Returns
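Tensor – Tensor with one more dimension than input. With the default dim1 and dim2 its shape is input.shape[:-1] + [size, size], where size = input.shape[-1] + |offset|; otherwise the two dimensions of length size sit at positions dim1 and dim2.
Examples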
// Basic 1D to diagonal matrix
const x = torch.tensor([1, 2, 3]);
torch.diag_embed(x);
// [[1, 0, 0],
// [0, 2, 0],
// [0, 0, 3]]
// Offset diagonal (above main diagonal)
const x = torch.tensor([1, 2]);
torch.diag_embed(x, 1, -2, -1); // offset = 1 in the default plane (dim1 = -2, dim2 = -1)
// [[0, 1, 0],
// [0, 0, 2],
// [0, 0, 0]]
// Batched diagonal matrices (2D input)
const x = torch.tensor([[1, 2], [3, 4]]); // Shape [2, 2]
const result = torch.diag_embed(x); // Shape [2, 2, 2]
// result[0] = [[1, 0], [0, 2]]
// result[1] = [[3, 0], [0, 4]]
// Specify which dimensions form the 2D plane
const x = torch.tensor([[[1, 2], [3, 4]], [[5, 6], [7, 8]]]); // Shape [2, 2, 2]
const result = torch.diag_embed(x, 0, 0, 1); // Diagonal in dims 0 and 1; shape [2, 2, 2, 2]
See Also
- PyTorch torch.diag_embed()
- diag - Create a diagonal matrix from a 1D tensor, or extract the diagonal of a 2D matrix
- diagonal - Extract diagonal from 2D planes
- block_diag - Create block diagonal matrix from multiple tensors