torch.linalg.multi_dot
Efficiently multiplies two or more matrices.
Chains multiple matrix products (A @ B @ C @ ...) in a single operation. While the result is the same under any grouping (matrix multiplication is associative), the order in which the products are computed significantly affects performance and memory usage. Essential for:
- Neural networks: Composing multiple linear layers efficiently
- Matrix chain multiplication: Optimizing computational complexity
- Coordinate transformations: Composing multiple transformation matrices
- Batch operations: Multiple matrix products across batches
- Attention mechanisms: Chaining the matrix products that appear in attention (e.g. applying query, key, and value projections)
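Why grouping matters can be sketched with plain FLOP counts, independent of any tensor library (the shapes below are illustrative):

```javascript
// Multiplying an (a x b) matrix by a (b x c) matrix costs roughly
// a*b*c scalar multiplications. Take A: 10x100, B: 100x5, C: 5x50.
const cost = (a, b, c) => a * b * c;

// (A @ B) @ C: 10*100*5 + 10*5*50 = 5000 + 2500 = 7500 ops
const leftFirst = cost(10, 100, 5) + cost(10, 5, 50);

// A @ (B @ C): 100*5*50 + 10*100*50 = 25000 + 50000 = 75000 ops
const rightFirst = cost(100, 5, 50) + cost(10, 100, 50);

console.log(leftFirst, rightFirst); // 7500 75000
```

Same result either way, but a 10x difference in work; this is why parenthesization matters for long chains.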
Note: The current implementation multiplies left to right without reordering. PyTorch's native multi_dot instead uses dynamic programming to find an optimal parenthesization (the classic matrix chain multiplication problem). For very long chains with widely varying shapes, grouping matrices manually may improve performance.
- Left-to-right evaluation: Products are computed sequentially from the left, with no cost-based reordering
- Dimension compatibility: Adjacent matrices must have compatible dimensions (for m @ n, m's last dimension must equal n's second-to-last)
- Batching: Batches must be broadcastable across all matrices
- Performance: For very long chains with diverse shapes, manual grouping may help
- Gradient support: Fully differentiable with respect to all input matrices
- Single tensor: Passing a single matrix returns a clone, not the original
- Empty array: Will throw error if tensors array is empty
- Dimension matching: Will error if matrices have incompatible dimensions
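The dynamic-programming parenthesization mentioned in the note above can be sketched in plain JavaScript. This is a library-independent sketch of the textbook algorithm, not this package's internals:

```javascript
// Matrix-chain-order DP: given dims [d0, d1, ..., dn] for a chain of
// n matrices where matrix i has shape d[i] x d[i+1], return the minimum
// number of scalar multiplications needed for the whole product.
function minChainCost(dims) {
  const n = dims.length - 1; // number of matrices in the chain
  if (n < 2) return 0;
  // cost[i][j] = cheapest cost to multiply matrices i..j (inclusive)
  const cost = Array.from({ length: n }, () => new Array(n).fill(0));
  for (let len = 2; len <= n; len++) {          // chain length
    for (let i = 0; i + len - 1 < n; i++) {
      const j = i + len - 1;
      cost[i][j] = Infinity;
      for (let k = i; k < j; k++) {             // try every split point
        const c = cost[i][k] + cost[k + 1][j] +
                  dims[i] * dims[k + 1] * dims[j + 1];
        if (c < cost[i][j]) cost[i][j] = c;
      }
    }
  }
  return cost[0][n - 1];
}

// Shapes 10x100, 100x5, 5x50: the optimum is (A @ B) @ C at 7500 ops
console.log(minChainCost([10, 100, 5, 50])); // 7500
```

Left-to-right evaluation happens to be optimal here, but for other shape sequences the gap can be arbitrarily large.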
Parameters
tensors: Tensor[] - Array of matrices to multiply: 2D or batched tensors whose adjacent dimensions are compatible; the first and last may also be 1D (treated as a row and column vector, respectively)
Returns
Tensor – Product of all matrices in sequence: tensors[0] @ tensors[1] @ ... @ tensors[n-1]
Examples
// Simple chain of 3 matrices
const A = torch.randn(4, 3);
const B = torch.randn(3, 5);
const C = torch.randn(5, 2);
const result = torch.linalg.multi_dot([A, B, C]); // Shape [4, 2]
// Equivalent to: A.matmul(B).matmul(C)
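The left-to-right evaluation is just a fold of matmul over the list. A minimal plain-array sketch (independent of this library, with a small hand-checkable example):

```javascript
// matmul multiplies an (m x k) by a (k x n) array-of-arrays.
function matmul(A, B) {
  return A.map(row =>
    B[0].map((_, j) => row.reduce((sum, v, k) => sum + v * B[k][j], 0))
  );
}

// multiDot([A, B, C]) is matmul(matmul(A, B), C): a left fold.
const multiDot = (mats) => mats.reduce(matmul);

const A = [[1, 2], [3, 4]];
const B = [[5, 6], [7, 8]];
const C = [[1, 0], [0, 1]]; // identity, leaves A @ B unchanged
console.log(multiDot([A, B, C])); // [[19, 22], [43, 50]]
```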
// Neural network: composing linear layers
const x = torch.randn(32, 128); // Batch of 32 vectors
const W1 = torch.randn(128, 256);
const W2 = torch.randn(256, 128);
const W3 = torch.randn(128, 64);
const output = torch.linalg.multi_dot([x, W1, W2, W3]); // Shape [32, 64]
// Coordinate transformations: composing multiple rotations/scales
const scale = torch.tensor([[2, 0], [0, 2]]); // 2x scale
const rotate = torch.tensor([[0, -1], [1, 0]]); // 90° rotation
const shear = torch.tensor([[1, 1], [0, 1]]); // Shear (a 2x2 matrix cannot encode translation)
const combined = torch.linalg.multi_dot([scale, rotate, shear]);
const point = torch.tensor([1, 1]);
const transformed = combined.mv(point); // Apply combined transformation
// Quadratic form: x^T A B x (useful in optimization)
const A = torch.randn(5, 5);
const B = torch.randn(5, 5);
const x = torch.randn(5);
const quad_form = torch.linalg.multi_dot([x, A, B, x]); // Returns scalar
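As a sanity check on the 1D-endpoints convention (x treated as a row vector on the left, a column vector on the right), the quadratic form can be expanded with plain numbers in a small hand-computable sketch:

```javascript
// x^T A x = sum over i, j of x[i] * A[i][j] * x[j] (2-element example)
const x = [1, 2];
const A = [[3, 0], [0, 4]];
let quad = 0;
for (let i = 0; i < 2; i++)
  for (let j = 0; j < 2; j++)
    quad += x[i] * A[i][j] * x[j];
console.log(quad); // 1*3*1 + 2*4*2 = 19, a scalar
```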
// Batched multi-matrix multiplication
const batch = torch.randn(32, 4, 5);
const W1 = torch.randn(32, 5, 6);
const W2 = torch.randn(32, 6, 3);
const result = torch.linalg.multi_dot([batch, W1, W2]); // Broadcast across batch
See Also
- PyTorch torch.linalg.multi_dot()
- matmul - Binary matrix multiplication
- bmm - Batched matrix multiplication
- mm - 2D matrix multiplication
- dot - Vector dot product