torch.trace
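Returns the trace (the sum of the diagonal elements) of a 2D matrix.
For tensors with 3 or more dimensions, the leading dimensions are treated as batch dimensions, and the trace is computed for each matrix formed by the last two dimensions.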
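To make the definition concrete, here is a minimal sketch on plain nested arrays. It assumes row-major `number[][]` matrices and a single batch dimension. `trace2d` and `batchTrace` are hypothetical names used only for illustration. They are not part of this library's API and do not reflect how `torch.trace` is implemented internally.

```typescript
// Hypothetical plain-array sketch of trace semantics (not the library's implementation).
type Matrix = number[][];

// Trace of one square matrix: the sum of a[i][i] over all rows i.
function trace2d(a: Matrix): number {
  const n = a.length;
  let sum = 0;
  for (let i = 0; i < n; i++) {
    if (a[i].length !== n) {
      throw new Error("trace2d: matrix must be square");
    }
    sum += a[i][i];
  }
  return sum;
}

// Batched trace (one batch dimension assumed): one trace per matrix in the batch.
function batchTrace(batch: Matrix[]): number[] {
  return batch.map(trace2d);
}

console.log(trace2d([[1, 2], [3, 4]])); // 5
console.log(JSON.stringify(batchTrace([[[1, 2], [3, 4]], [[5, 0], [0, 5]]]))); // [5,10]
```

Each batch entry reduces to a single number, so a batch of N matrices yields N traces. This matches the batch-shaped result described under Returns.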
Parameters
inputTensor - Input tensor. Must have at least 2 dimensions, and its last two dimensions must be equal (square matrices).
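The shape rule for `inputTensor` determines the shape of the result. The following is a hedged sketch of that check on a plain shape array. `traceOutputShape` is a hypothetical helper written for illustration only. The library's actual validation and error messages may differ.

```typescript
// Hypothetical helper: validates the input-shape rule and returns the output shape.
function traceOutputShape(shape: readonly number[]): number[] {
  if (shape.length < 2) {
    throw new Error(`trace: expected at least 2 dimensions, got ${shape.length}`);
  }
  const rows = shape[shape.length - 2];
  const cols = shape[shape.length - 1];
  if (rows !== cols) {
    throw new Error(`trace: last two dimensions must be equal, got ${rows}x${cols}`);
  }
  // All dimensions except the last two are batch dimensions.
  // A 2D input therefore yields [] (the shape of a scalar tensor).
  return shape.slice(0, -2);
}

console.log(JSON.stringify(traceOutputShape([3, 3])));       // []
console.log(JSON.stringify(traceOutputShape([4, 2, 5, 5]))); // [4,2]
```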
Returns
Tensor - A scalar tensor for 2D input, or, for 3D+ input, a tensor whose shape is the batch dimensions (all dimensions except the last two)
Examples
const A = torch.tensor([[1, 2], [3, 4]]);
torch.trace(A); // 5 (1 + 4)
const I = torch.eye(3);
torch.trace(I); // 3
See Also
- PyTorch torch.trace()
- diagonal - Extract diagonal elements