torch.chain_matmul
function chain_matmul<D extends DType = DType, Dev extends DeviceType = DeviceType>(...matrices: Tensor<Shape, D, Dev>[]): Tensor<Shape, D, Dev>

Efficiently multiplies a sequence of matrices, choosing the multiplication order that minimizes computation cost.
Performs left-to-right matrix multiplication (A @ B @ C @ ...) with automatic reordering to minimize the total number of scalar multiplications. Matrix chain multiplication can have dramatically different costs depending on parenthesization: (A @ B) @ C vs A @ (B @ C). For example, multiplying [10×100] × [100×10] × [10×1000] costs 110,000 scalar multiplications in one order and 2,000,000 in the other. This function finds the optimal multiplication order without changing the mathematical result. Essential for:
- Neural network layers: Efficiently combining weight transformations
- Geometric transformations: Composing rotation, scaling, translation matrices
- Graph algorithms: Matrix powers and path computations (A³ = A @ A @ A)
- Physical simulations: State transitions and force accumulation
- Tensor networks: Contracting sequences of tensors (when all are 2D)
- Optimization algorithms: Computing product moments and accumulated gradients
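The cost gap described above is easy to check by hand: multiplying an [m×k] matrix by a [k×n] matrix takes m·k·n scalar multiplications. A short self-contained sketch (plain TypeScript, independent of this library) comparing the two orders for the [10×100] × [100×10] × [10×1000] chain:

```typescript
// Scalar-multiplication cost of a single [m×k] @ [k×n] product.
const matmulCost = (m: number, k: number, n: number): number => m * k * n;

// Chain [10×100] @ [100×10] @ [10×1000]:
// (A @ B) @ C — [10×100]@[100×10] gives [10×10], then [10×10]@[10×1000]
const leftFirst = matmulCost(10, 100, 10) + matmulCost(10, 10, 1000);
// A @ (B @ C) — [100×10]@[10×1000] gives [100×1000], then [10×100]@[100×1000]
const rightFirst = matmulCost(100, 10, 1000) + matmulCost(10, 100, 1000);

console.log(leftFirst);  // 110000
console.log(rightFirst); // 2000000
```

Same result either way; an ~18× difference in work.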
Uses dynamic programming to find the optimal parenthesization that minimizes total scalar multiplications needed. For n matrices, this is O(n³) preprocessing cost with potentially huge speedup in actual multiplication (sometimes 10x-100x faster).
- Optimal ordering: Uses dynamic programming to find parenthesization minimizing scalar multiplications (the most expensive operation in matrix multiplication)
- Left-to-right evaluation: Despite optimization, numerically equivalent to left-to-right evaluation due to matrix multiplication associativity
- 2D only: All matrices must be exactly 2D (batched multiplication not supported)
- Numerical stability: Matrix multiplication is associative, so any parenthesization yields the same mathematical result; floating-point rounding may introduce tiny differences between orders
- Dimension requirements: Each consecutive pair must satisfy m[i].shape[1] === m[i+1].shape[0]
- At least 2 matrices: chain_matmul requires 2+ matrices (for a single pair, matmul is equivalent)
- 2D matrices only: Does not support batched matrices (use matmul for [..., m, k] @ [..., k, n])
- Not commutative: Matrix multiplication is not commutative, so order matters
- Preprocessing overhead: For short chains (a handful of matrices with similar dimensions), the O(n³) cost of finding the optimal order may exceed its benefit; plain sequential matmul calls can be just as fast
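A minimal sketch of the dynamic program behind this kind of optimization (the classic matrix-chain-order algorithm; plain TypeScript, not this library's internal code). It computes the minimal number of scalar multiplications without performing any actual matrix products:

```typescript
// A chain of n matrices is described by n+1 dimensions: matrix i is
// [dims[i] × dims[i+1]]. Returns the minimal scalar-multiplication count.
function chainOrderCost(dims: number[]): number {
  const n = dims.length - 1; // number of matrices
  // cost[i][j] = minimal cost to multiply matrices i..j (inclusive)
  const cost: number[][] = Array.from({ length: n }, () => new Array(n).fill(0));
  for (let len = 2; len <= n; len++) {          // subchain length
    for (let i = 0; i + len - 1 < n; i++) {
      const j = i + len - 1;
      cost[i][j] = Infinity;
      for (let k = i; k < j; k++) {             // try every split point
        const c = cost[i][k] + cost[k + 1][j] +
                  dims[i] * dims[k + 1] * dims[j + 1];
        if (c < cost[i][j]) cost[i][j] = c;
      }
    }
  }
  return n > 0 ? cost[0][n - 1] : 0;
}

// [10×100] @ [100×10] @ [10×1000]: optimal order costs 110,000 multiplications
console.log(chainOrderCost([10, 100, 10, 1000])); // 110000
```

The triple loop is the O(n³) preprocessing cost mentioned above; it is trivial next to the matrix products themselves once any dimension is large.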
Parameters
matrices: Tensor<Shape, D, Dev>[] – Two or more 2D tensors; each consecutive pair must have matching inner dimensions
Returns
Tensor<Shape, D, Dev> – Single tensor with the result of the optimized chain multiplication. Shape is [matrices[0].shape[0], matrices[n-1].shape[1]]
Examples
// Simple chain of 3 matrices
const a = torch.randn(2, 3); // [2, 3]
const b = torch.randn(3, 4); // [3, 4]
const c = torch.randn(4, 5); // [4, 5]
const result = torch.chain_matmul(a, b, c); // [2, 5]
// Equivalent to: (a @ b) @ c or a @ (b @ c), with the cheaper order chosen

// Composing geometric transformations
// Rotation, then scale, then translation
const rotation = torch.tensor([
[Math.cos(Math.PI/4), -Math.sin(Math.PI/4)],
[Math.sin(Math.PI/4), Math.cos(Math.PI/4)]
]); // [2, 2] - rotate 45°
const scale = torch.tensor([[2, 0], [0, 3]]); // [2, 2] - scale x by 2, y by 3
const points = torch.randn(2, 1000); // [2, 1000] - 1000 2D points
// Apply all transformations optimally
const transformed = torch.chain_matmul(rotation, scale, points); // [2, 1000]

// Matrix power computation: compute A @ A @ A efficiently
const a = torch.randn(100, 100); // [100, 100]
const a3 = torch.chain_matmul(a, a, a); // A³ efficiently
// Equivalent to a @ a @ a (for square matrices, every order costs the same)

// Neural network: chain of weight matrices
// Transform input through multiple linear layers
const w1 = torch.randn(256, 512); // Input [512] -> hidden [256]
const w2 = torch.randn(128, 256); // Hidden [256] -> hidden2 [128]
const w3 = torch.randn(10, 128); // Hidden2 [128] -> output [10]
const input = torch.randn(512, 32); // [512, 32] - batch of 32 column vectors
// Optimal order chosen automatically
const output = torch.chain_matmul(w3, w2, w1, input); // [10, 32]

// Demonstrating cost difference (when to use chain_matmul)
// Matrices: `[10×100]`, `[100×10]`, `[10×1000]`
// Order `(M₁ @ M₂) @ M₃`:
// = `10·100·10 + 10·10·1000 = 10,000 + 100,000 = 110,000` ops
// Order `M₁ @ (M₂ @ M₃)`:
// = `100·10·1000 + 10·100·1000 = 1,000,000 + 1,000,000 = 2,000,000` ops
// chain_matmul automatically picks the cheaper `(M₁ @ M₂) @ M₃` order
const m1 = torch.randn(10, 100);
const m2 = torch.randn(100, 10);
const m3 = torch.randn(10, 1000);
const result = torch.chain_matmul(m1, m2, m3); // Automatically optimized
See Also
- PyTorch torch.chain_matmul()
- matmul - For flexible multi-dimensional multiplication with batching
- mm - For basic 2D × 2D multiplication without optimization
- bmm - For batched 3D × 3D multiplication