torch.linalg.matrix_exp
function matrix_exp<S extends Shape, D extends DType, Dev extends DeviceType>(A: Tensor<S, D, Dev>): Tensor<S, D, Dev>

Computes the matrix exponential e^A using scaling and squaring with Padé approximation.
Calculates the exponential of a square matrix, defined by the matrix power series: e^A = I + A + A²/2! + A³/3! + ... = ∑(A^n/n!). Computed efficiently using scaling-and-squaring and rational Padé approximation for numerical stability. Essential for:
- Solving differential equations: Computing solution to dX/dt = AX as X(t) = e^(At)X(0)
- Rotation representations: Computing rotation matrices from angular velocity matrices
- Continuous control: Discretizing continuous-time systems to discrete-time
- Lie groups: Computing group exponentials for matrix Lie groups
- Quantum mechanics: Time evolution of quantum systems under Hamiltonian evolution
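The discretization use case above can be sketched with plain number arrays (an illustration, not this library's API). For a nilpotent A the series terminates, so e^(A·dt) has an exact closed form:

```typescript
// Zero-order-hold discretization of dx/dt = Ax for the double integrator
// A = [[0, 1], [0, 0]]. Since A² = 0, the power series terminates:
// e^(A·dt) = I + A·dt = [[1, dt], [0, 1]] exactly.
const dt = 0.1;
const Ad = [[1, dt], [0, 1]]; // exact matrix exponential e^(A·dt)

// One discrete step from state x = [position, velocity]:
const x0 = [2, 3];
const x1 = [
  Ad[0][0] * x0[0] + Ad[0][1] * x0[1], // position advances by velocity·dt
  Ad[1][0] * x0[0] + Ad[1][1] * x0[1], // velocity unchanged
];
// x1 ≈ [2.3, 3]
```

For a general (non-nilpotent) A the series does not terminate, which is where matrix_exp comes in: the discrete transition matrix is matrix_exp(A.mul(dt)).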
Uses scaling and squaring: e^A = (e^(A/2^k))^(2^k), where e^(A/2^k) is approximated with a rational Padé approximant. This remains numerically stable even for matrices with large norms.
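The scaling-and-squaring idea can be sketched with plain arrays. This is an illustration only, not the library's implementation: a truncated Taylor series stands in for the Padé step used internally.

```typescript
// Sketch of scaling and squaring for small dense matrices.
type Mat = number[][];

function matmul(a: Mat, b: Mat): Mat {
  return a.map((row) =>
    b[0].map((_, j) => row.reduce((s, v, k) => s + v * b[k][j], 0))
  );
}

function matAdd(a: Mat, b: Mat): Mat {
  return a.map((row, i) => row.map((v, j) => v + b[i][j]));
}

function matScale(a: Mat, c: number): Mat {
  return a.map((row) => row.map((v) => v * c));
}

function eye(n: number): Mat {
  return Array.from({ length: n }, (_, i) =>
    Array.from({ length: n }, (_, j) => (i === j ? 1 : 0))
  );
}

function matrixExp(A: Mat, squarings = 10, terms = 12): Mat {
  // 1. Scale A down by 2^squarings so the series converges quickly.
  const scaled = matScale(A, 1 / 2 ** squarings);
  // 2. Truncated Taylor series e^B = I + B + B²/2! + ... (Padé in the real thing).
  let result = eye(A.length);
  let term = eye(A.length);
  for (let n = 1; n <= terms; n++) {
    term = matScale(matmul(term, scaled), 1 / n);
    result = matAdd(result, term);
  }
  // 3. Square back up: e^A = (e^(A/2^k))^(2^k).
  for (let k = 0; k < squarings; k++) result = matmul(result, result);
  return result;
}

// Example: e^([[0, -θ], [θ, 0]]) is the rotation by θ
const R = matrixExp([[0, -Math.PI / 2], [Math.PI / 2, 0]]);
// R ≈ [[0, -1], [1, 0]]
```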
- Scaling and squaring: Avoids summing the power series directly; A is scaled down, approximated with Padé, then repeatedly squared back up. This is more numerically stable than direct summation.
- Convergence: Converges for all square matrices, including those with negative eigenvalues.
- Eigenvalues: If A has eigenvalues λ, then e^A has eigenvalues e^λ. If A diagonalizes as PDP⁻¹, then e^A = Pe^D P⁻¹.
- Commutativity: e^A e^B ≠ e^(A+B) in general; the identity holds only when A and B commute ([A, B] = 0).
- Determinant: det(e^A) = e^(trace(A)) (always true). The product of the exponentials of the eigenvalues equals the exponential of their sum. Note that trace(e^A) is the sum of the e^λᵢ, not e^(trace(A)).
- Numerical stability: Uses Padé approximation which has good stability, but very large matrix norms can still cause issues. Keep ||A|| reasonable if possible.
- Square matrices only: Requires input to be square (n, n). Throws error for non-square matrices.
- Computational cost: O(n³) per matrix exponential due to matrix multiplications. For large n, this can be slow. Batch processing multiple matrices together is more efficient.
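As a numerical check of the eigenvalue mapping (and the related identity det(e^A) = e^(trace(A))), here is a symmetric 2x2 matrix whose eigendecomposition has a closed form. This is an illustration only, not how matrix_exp computes its result:

```typescript
// Check trace(e^A) = Σ e^λᵢ and det(e^A) = e^(trace(A)) for a symmetric
// 2x2 matrix via its closed-form eigendecomposition A = P D Pᵀ.
const A = [[2, 1], [1, 2]]; // eigenvalues λ₁ = 1, λ₂ = 3; trace(A) = 4
const [e1, e2] = [Math.exp(1), Math.exp(3)];
// Eigenvectors are (1, ∓1)/√2, so e^A = P diag(e^λ₁, e^λ₂) Pᵀ works out to:
const expA = [
  [(e2 + e1) / 2, (e2 - e1) / 2],
  [(e2 - e1) / 2, (e2 + e1) / 2],
];
const traceExpA = expA[0][0] + expA[1][1]; // ≈ e¹ + e³ (sum of e^λᵢ)
const detExpA =
  expA[0][0] * expA[1][1] - expA[0][1] * expA[1][0]; // ≈ e⁴ = e^trace(A)
```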
Parameters
A: Tensor<S, D, Dev> - Square matrix of shape (n, n) or a batch of shape (..., n, n)
Returns
Tensor<S, D, Dev> – Matrix exponential e^A with the same shape as the input
Examples
// Rotation matrix: e^(skew-symmetric) gives a rotation
const theta = Math.PI / 2; // 90° in radians
const omega = torch.tensor([[0, -theta], [theta, 0]]); // 90° rotation generator
const rot90 = torch.linalg.matrix_exp(omega);
// [[0, -1], [1, 0]] - 90° counterclockwise rotation

// Solve dX/dt = AX with X(0) = I for t=1
const A = torch.tensor([[0, 1], [-1, 0]]); // Oscillator
const X = torch.linalg.matrix_exp(A); // Solution at t=1
// X evolves the oscillator by 1 unit of time

// Small matrices: e^A ≈ I + A
const small = torch.tensor([[0.01, 0.02], [0.03, 0.01]]).mul(0.01);
const exp_A = torch.linalg.matrix_exp(small);
const approx = torch.eye(2).add(small); // I + A approximation
// exp_A ≈ approx for small matrices

// Larger matrix: matrix exponential handles large norms
const large = torch.tensor([[10, -3], [5, -2]]);
const exp_large = torch.linalg.matrix_exp(large);
// Numerically stable despite large matrix norm

// Batch computation
const batch = torch.randn([32, 4, 4]); // 32 random 4x4 matrices
const batch_exp = torch.linalg.matrix_exp(batch); // 32 matrix exponentials

// Diagonal matrices: e^diag(a, b, ...) = diag(e^a, e^b, ...)
const diag = torch.tensor([[1, 0], [0, 2]], { dtype: 'float32' });
const exp_diag = torch.linalg.matrix_exp(diag);
// [[e, 0], [0, e^2]] ≈ [[2.718, 0], [0, 7.389]]
See Also
- PyTorch torch.linalg.matrix_exp()
- matrix_power - Computes integer powers A^n more efficiently than repeated multiplication
- Tensor.matrix_exp - Tensor method form
- linalg.svd - Singular value decomposition (useful for analyzing matrix structure)