torch.baddbmm
function baddbmm<D extends DType = DType, Dev extends DeviceType = DeviceType>(input: Tensor<Shape, D, Dev>, batch1: Tensor<Shape, D, Dev>, batch2: Tensor<Shape, D, Dev>, options?: BaddbmmOptions<any>): Tensor<Shape, D, Dev>
Performs fused batched matrix multiplication and addition: beta * input + alpha * (batch1 @ batch2).
Efficiently combines batched matrix multiplication with batched addition in a single operation.
Computes out[b] = beta * input[b] + alpha * (batch1[b] @ batch2[b]) for each batch element.
More efficient than separate bmm() and add() operations due to fused GPU kernels.
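The fused/unfused equivalence can be checked against PyTorch's torch.baddbmm, which this API mirrors (Python sketch; shapes here are illustrative, not prescribed by this API):

```python
import torch

# Arbitrary small shapes: [B, M, K] @ [B, K, N] -> [B, M, N]
B, M, K, N = 4, 3, 5, 2
inp = torch.randn(B, M, N)
b1 = torch.randn(B, M, K)
b2 = torch.randn(B, K, N)

# Fused: beta * inp + alpha * (b1 @ b2) in one kernel
fused = torch.baddbmm(inp, b1, b2, beta=0.5, alpha=2.0)

# Unfused reference: separate bmm() and scaled add
manual = 0.5 * inp + 2.0 * torch.bmm(b1, b2)

print(torch.allclose(fused, manual, atol=1e-6))  # True
```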
Essential for:
- Neural networks: Batch linear layers with bias (attention layers, etc.)
- Batch processing: Parallel matrix operations with accumulation
- Iterative algorithms: Accumulating batch matrix products
- Batch transformations: Applying transformations to batches of matrices
- Physical simulations: Batch state updates and force accumulation
All inputs must have matching batch dimensions: input.shape[0] == batch1.shape[0] == batch2.shape[0]. Unlike matmul, no broadcasting is performed.
- Batch dimension: input.shape[0] must equal batch1.shape[0] and batch2.shape[0]
- Output shape: Always [B, M, N] (3D batched matrix)
- Efficiency: Fused is faster than separate bmm() + add()
- No broadcasting: Unlike matmul, batch dimensions must match exactly
- GPU optimized: Parallel computation across all batch elements
- Batch mismatch error: Will error if batch dimensions don't match
- Inner dimensions: batch1.shape[2] must equal batch2.shape[1]
- 3D only: All inputs must be exactly 3D (no batched batches)
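These shape rules can be exercised in PyTorch (Python), which enforces the same exact-match requirement between batch1 and batch2:

```python
import torch

inp = torch.randn(8, 4, 6)  # [B, M, N]
b1 = torch.randn(8, 4, 5)   # [B, M, K]
b2 = torch.randn(8, 5, 6)   # [B, K, N]

# Output is always [B, M, N]
out = torch.baddbmm(inp, b1, b2)
print(out.shape)  # torch.Size([8, 4, 6])

# A mismatched batch dimension between batch1 and batch2 is
# rejected rather than broadcast:
try:
    torch.baddbmm(inp, torch.randn(4, 4, 5), b2)
except RuntimeError:
    print("batch mismatch rejected")
```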
Parameters
options: BaddbmmOptions<any> (optional) - Optional parameters:
- beta: Scaling factor for input (default: 1)
- alpha: Scaling factor for the batch matrix product (default: 1)
- out: Pre-allocated output tensor
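The effect of the defaults and of a pre-allocated output can be sketched in PyTorch (Python), assuming the TS options map onto PyTorch's beta/alpha/out keyword arguments:

```python
import torch

inp = torch.randn(2, 3, 3)
b1 = torch.randn(2, 3, 4)
b2 = torch.randn(2, 4, 3)

# Defaults beta=1, alpha=1: plain input + batch1 @ batch2
d = torch.baddbmm(inp, b1, b2)
print(torch.allclose(d, inp + torch.bmm(b1, b2), atol=1e-5))  # True

# Pre-allocated output buffer avoids a fresh allocation;
# beta=0 drops the input term entirely
buf = torch.empty(2, 3, 3)
torch.baddbmm(inp, b1, b2, beta=0.0, alpha=1.0, out=buf)
print(torch.allclose(buf, torch.bmm(b1, b2), atol=1e-5))  # True
```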
Returns
Tensor<Shape, D, Dev> - Output tensor of shape [B, M, N], where out[b] = beta * input[b] + alpha * (batch1[b] @ batch2[b])
Examples
// Batch linear layer with bias
const batch_size = 32;
const batch_input = torch.randn(batch_size, 10, 5); // [32, 10, 5]
const batch_weight = torch.randn(batch_size, 8, 5); // [32, 8, 5]
const batch_bias = torch.randn(batch_size, 10, 8); // [32, 10, 8]
const output = torch.baddbmm(batch_bias, batch_input, batch_weight.transpose(-2, -1));
// output[b] = bias[b] + input[b] @ weight[b].T -> [32, 10, 8]
// Batch covariance update in iterative algorithm
let cov = torch.eye(10).unsqueeze(0).expand(32, 10, 10); // [32, 10, 10]
const X = torch.randn(32, 10, 100); // Batch of data [32, 10, 100]
cov = torch.baddbmm(cov, X, X.transpose(-2, -1), {alpha: 0.01, beta: 0.99});
// Exponential moving average: cov = 0.99*cov + 0.01*(X @ X.T)
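The same moving-average update, cov = 0.99*cov + 0.01*(X @ X.T), can be verified against the unfused form in PyTorch (Python):

```python
import torch

cov = torch.eye(10).unsqueeze(0).expand(32, 10, 10)  # [32, 10, 10]
X = torch.randn(32, 10, 100)                          # [32, 10, 100]

# Fused: 0.99 * cov + 0.01 * (X @ X.T), product computed inside baddbmm
updated = torch.baddbmm(cov, X, X.transpose(-2, -1), beta=0.99, alpha=0.01)

# Unfused reference
manual = 0.99 * cov + 0.01 * torch.bmm(X, X.transpose(-2, -1))

print(torch.allclose(updated, manual, atol=1e-4))  # True
```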
// Multiple batch accumulations
let result = torch.zeros(16, 4, 4); // Accumulator [16, 4, 4]
const A = torch.randn(16, 4, 3);
const B = torch.randn(16, 3, 4);
const C = torch.randn(16, 4, 3);
const D = torch.randn(16, 3, 4);
result = torch.baddbmm(result, A, B, {alpha: 1, beta: 1}); // result += A @ B
result = torch.baddbmm(result, C, D, {alpha: 1, beta: 1}); // result += C @ D
See Also
- PyTorch torch.baddbmm()
- baddbmm_ - In-place version
- bmm - Batched matrix multiplication without addition
- addmm - Matrix-matrix version
- matmul - General tensor multiplication