torch.linalg.tensorsolve
function tensorsolve<S extends Shape, D extends DType, Dev extends DeviceType>(A: Tensor<S, D, Dev>, B: Tensor<Shape, D, Dev>, options?: TensorsolveOptions): Tensor<DynamicShape, D, Dev>
function tensorsolve<S extends Shape, D extends DType, Dev extends DeviceType>(A: Tensor<S, D, Dev>, B: Tensor<Shape, D, Dev>, dims: number[], options?: TensorsolveOptions): Tensor<DynamicShape, D, Dev>
Solves a multi-linear system using tensor inversion.
Computes the solution X to the tensor equation tensordot(A, X, X.ndim) = B, in which the trailing dimensions of A are contracted against all dimensions of X.
This generalizes matrix solving (torch.linalg.solve) to higher-dimensional tensors. Typical uses include:
- Solving multi-linear equations and systems
- Inverting bilinear and multi-linear forms
- Tensor decomposition methods requiring system solve
- Multi-way array analysis (HOSVD, Tucker decomposition)
- Higher-order generalizations of linear least-squares
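Core idea: A is reshaped into a square matrix whose rows are indexed by its leading B.ndim dimensions and whose columns are indexed by its remaining dimensions, and B is flattened into a vector. The resulting linear system is solved with a matrix solver, and the solution vector is reshaped back to tensor form. The optional dims parameter does not choose the contraction directly: following PyTorch, it lists dimensions of A that are moved to the end before this reshape.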
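Written out with indices, the reshape turns the tensor equation into an ordinary square system. This is a sketch using PyTorch's convention that the first k = B.ndim dimensions of A index the equations and its remaining p dimensions index the unknowns (after any dims permutation):

```latex
% A: (a_1,\dots,a_k,\,c_1,\dots,c_p),  B: (a_1,\dots,a_k),  X: (c_1,\dots,c_p)
\sum_{j_1=1}^{c_1}\cdots\sum_{j_p=1}^{c_p} A_{i_1\dots i_k\,j_1\dots j_p}\,X_{j_1\dots j_p} = B_{i_1\dots i_k}
\quad\Longleftrightarrow\quad
\tilde{A}\,\mathbf{x} = \mathbf{b},\qquad
\tilde{A}\in\mathbb{R}^{m\times n},\quad
m=\prod_{r=1}^{k} a_r,\quad
n=\prod_{s=1}^{p} c_s
```

Here x and b are X and B flattened in row-major order, and Ã is A reshaped with its first k dimensions as rows. A unique solution x = Ã⁻¹b exists only when m = n and Ã is nonsingular.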
How it works:
- If dims is given, move those dimensions of A to the end (in the given order)
- Reshape A into an n x n matrix A_mat (leading B.ndim dimensions as rows, remaining dimensions as columns) and flatten B into a vector b
- Solve the resulting linear system: A_mat @ x = b
- Reshape x to the trailing dimensions of A, A.shape[B.ndim:]
- Return solution tensor
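The steps above can be sketched on plain row-major arrays. This is a minimal, self-contained illustration, not the library's implementation: `tensorsolveFlat` is a hypothetical name, it assumes any dims permutation has already been applied to A, and it uses Gaussian elimination with partial pivoting rather than whatever solver the library calls internally.

```typescript
type Shape = number[];

const prod = (s: Shape): number => s.reduce((p, d) => p * d, 1);

// Hypothetical sketch of reshape-solve-reshape for tensordot(A, X, X.ndim) = B.
// A and B are flat row-major arrays; any `dims` permutation is assumed already applied.
function tensorsolveFlat(
  a: number[], aShape: Shape, b: number[], bShape: Shape,
): { x: number[]; xShape: Shape } {
  // Step 1: leading B.ndim dims of A index equations, trailing dims index unknowns.
  const m = prod(aShape.slice(0, bShape.length));
  const xShape = aShape.slice(bShape.length);
  const n = prod(xShape);
  if (m !== n || prod(bShape) !== n || b.length !== n || a.length !== n * n) {
    throw new Error("shapes do not describe a square system");
  }

  // Step 2: row-major A is already the n x n matrix; build the augmented matrix [A | b].
  const M: number[][] = [];
  for (let i = 0; i < n; i++) M.push([...a.slice(i * n, (i + 1) * n), b[i]]);

  // Step 3: Gaussian elimination with partial pivoting.
  for (let col = 0; col < n; col++) {
    let piv = col;
    for (let r = col + 1; r < n; r++) {
      if (Math.abs(M[r][col]) > Math.abs(M[piv][col])) piv = r;
    }
    // Crude absolute threshold, good enough for a sketch.
    if (Math.abs(M[piv][col]) < 1e-12) throw new Error("singular operator");
    [M[col], M[piv]] = [M[piv], M[col]];
    for (let r = col + 1; r < n; r++) {
      const f = M[r][col] / M[col][col];
      for (let c = col; c <= n; c++) M[r][c] -= f * M[col][c];
    }
  }

  // Back substitution yields the flat solution vector.
  const x = new Array<number>(n).fill(0);
  for (let i = n - 1; i >= 0; i--) {
    let s = M[i][n];
    for (let j = i + 1; j < n; j++) s -= M[i][j] * x[j];
    x[i] = s / M[i][i];
  }

  // Step 4: the flat x, read in row-major order with shape xShape, is X.
  return { x, xShape };
}
```

For example, with A of shape [2, 2] holding [[2, 1], [1, 3]] and B = [3, 5] of shape [2], the sketch returns x ≈ [0.8, 1.4] with xShape [2]. Because row-major storage makes the reshape free, the only real work is the O(n³) elimination.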
Requirements:
- Square system: the product of A's first B.ndim dimensions (after any dims permutation) must equal the product of its remaining dimensions, so that A reshapes to a square matrix
- Matching right-hand side: B must have as many elements as that product; normally B.shape equals A's first B.ndim dimensions
- Invertibility: the reshaped matrix must be invertible (full rank, non-singular), i.e. A must represent an invertible linear map
Notes:
- Default dims: if dims is omitted, no dimensions of A are moved; the split between equation and unknown dimensions comes from B.ndim
- Output shape: X has the shape of A's trailing dimensions, A.shape[B.ndim:] (after any dims permutation)
- Numerical stability: subject to the same conditioning issues as a matrix solve
- Computational cost: O(n³), where n = prod(X.shape), the number of elements in B
- Related to tensorinv: uses a similar reshape-solve-reshape strategy
Pitfalls:
- Dimension compatibility: throws if the dimension products don't match
- Singular operator: throws if the linear operator represented by A is singular
- Default dims behavior: without dims, A's dimensions are used in their existing order; pass dims if A's unknown dimensions are not already last
- Numerical precision: ill-conditioned operators produce inaccurate solutions
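The shape rules above can be checked before calling the solver. The helper below is a hypothetical sketch, not part of the library: it assumes PyTorch's documented behavior, where dims moves the listed dimensions of A to the end (like movedim) and the square check compares products rather than individual sizes.

```typescript
type Shape = number[];

const prod = (s: Shape): number => s.reduce((p, d) => p * d, 1);

// Hypothetical helper: predicts X's shape and the size n of the n x n system
// that tensorsolve(A, B, dims) reduces to, assuming PyTorch semantics for dims.
function tensorsolveShape(aShape: Shape, bShape: Shape, dims?: number[]): { xShape: Shape; n: number } {
  let shape = aShape;
  if (dims !== undefined) {
    const nd = aShape.length;
    // Normalize negative indices, then validate them.
    const moved = dims.map((d) => (d < 0 ? d + nd : d));
    if (moved.some((d) => d < 0 || d >= nd) || new Set(moved).size !== moved.length) {
      throw new Error("dims must list distinct, in-range dimensions of A");
    }
    // Like movedim: listed dims go to the end in the given order; the rest keep their order.
    const kept = aShape.filter((_, i) => !moved.includes(i));
    shape = [...kept, ...moved.map((d) => aShape[d])];
  }
  if (bShape.length > shape.length) throw new Error("B has more dimensions than A");
  const m = prod(shape.slice(0, bShape.length)); // equations: leading dims of A
  const xShape = shape.slice(bShape.length);      // unknowns: trailing dims of A
  const n = prod(xShape);
  if (m !== n) throw new Error(`A reshapes to ${m}x${n}, which is not square`);
  if (prod(bShape) !== m) throw new Error(`B has ${prod(bShape)} elements, expected ${m}`);
  return { xShape, n };
}
```

For instance, `tensorsolveShape([5, 6, 5, 6], [5, 6], [0, 1])` returns `{ xShape: [5, 6], n: 30 }`, so the solve costs on the order of 30³ operations.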
Parameters
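A: Tensor<S, D, Dev> - The operator tensor (left-hand side), typically 3D, 4D, or higher. Its first B.ndim dimensions index the equations and its remaining dimensions index the unknowns (after any dims permutation)
B: Tensor<Shape, D, Dev> - The right-hand side tensor; normally its shape equals the first B.ndim dimensions of A
dims: number[] (optional) - Dimensions of A to move to the end before solving, mirroring PyTorch's dims argument; if omitted, no dimensions are moved
options: TensorsolveOptions (optional)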
Returns
Tensor<DynamicShape, D, Dev> - The solution tensor X such that tensordot(A, X, X.ndim) ≈ B. Its shape is A.shape[B.ndim:] (after any dims permutation)
Examples
// Simple 4D operator with a 2D right-hand side
const A = torch.randn([3, 4, 3, 4]); // Linear operator mapping 3x4 tensors to 3x4 tensors
const B = torch.randn([3, 4]); // Target: 3x4
// Solve: tensordot(A, X, 2) = B
const X = torch.linalg.tensorsolve(A, B);
console.log(X.shape); // expected: [3, 4] (the trailing dimensions of A)
// Verify (approximately): tensordot(A, X, 2) ≈ B
// Floating-point rounding means the two won't be exactly equal

// Solving with explicit dimension specification
const A = torch.randn([5, 6, 5, 6]); // Operator on 5x6 tensors
const B = torch.randn([5, 6]); // 5x6 target
// dims lists dimensions of A to move to the end before solving (PyTorch semantics).
// Moving dims 0 and 1 pairs them with X, so this solves the "transposed" system.
const X = torch.linalg.tensorsolve(A, B, [0, 1]);
console.log(X.shape); // expected: [5, 6]

// Multi-way array solving (Tucker-like structures)
const A = torch.randn([2, 3, 4, 2, 3, 4]); // Larger operator: a 24x24 system after reshaping
const B = torch.randn([2, 3, 4]); // Right-hand side
const X = torch.linalg.tensorsolve(A, B);
// X has shape [2, 3, 4]: contracting A's last three dimensions with X gives B

// Solving multiple right-hand sides
// Solve several right-hand sides against the same operator
const A = torch.randn([3, 4, 3, 4]);
const B1 = torch.randn([3, 4]);
const B2 = torch.randn([3, 4]);
// One call per right-hand side (no batched variant is documented here)
const X1 = torch.linalg.tensorsolve(A, B1);
const X2 = torch.linalg.tensorsolve(A, B2);

See Also
- PyTorch torch.linalg.tensorsolve
- torch.linalg.solve - Standard matrix solve (2D case)
- torch.linalg.tensorinv - Compute tensor inverse directly
- torch.tensordot - Tensor contraction operation (defines the equation)
- torch.linalg.lstsq - Least-squares solution (overdetermined systems)