torch.Tensor.baddbmm
Tensor.baddbmm(batch1: Tensor, batch2: Tensor, options?: AlphaBetaOptions): Tensor<DynamicShape, D, Dev>
Tensor.baddbmm(batch1: Tensor, batch2: Tensor, beta: number, alpha: number, options?: AlphaBetaOptions): Tensor<DynamicShape, D, Dev>
Batched matrix-matrix multiplication with addition (BLAS operation).
Applies matrix multiplication to each batch independently, then adds to self with scaling. Used for batch processing in deep learning, especially in multi-head attention and group convolutions.
Formula: self = beta * self + alpha * (batch1 @ batch2), applied per batch
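The per-batch formula can be sketched as a plain TypeScript reference implementation over nested arrays (a hypothetical helper for illustration, not part of the library API):

```typescript
// Reference semantics of baddbmm over plain nested arrays.
// self: [B, M, N], batch1: [B, M, K], batch2: [B, K, N]
function baddbmmRef(
  self: number[][][],
  batch1: number[][][],
  batch2: number[][][],
  beta = 1,
  alpha = 1
): number[][][] {
  return self.map((mat, b) =>
    mat.map((row, m) =>
      row.map((v, n) => {
        // Inner product over the shared dimension K for batch b
        let dot = 0;
        for (let k = 0; k < batch2[b].length; k++) {
          dot += batch1[b][m][k] * batch2[b][k][n];
        }
        return beta * v + alpha * dot;
      })
    )
  );
}
```

Each batch index b is handled independently, which is exactly what makes the operation easy to parallelize on GPU.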
Use Cases:
- Multi-head attention computations
- Batched transformations in neural networks
- Fused batch matrix operations
- Efficient batched linear algebra
- 3D tensor manipulations in CNNs/RNNs
Notes
- Batch independence: Each batch is processed independently
- Shape consistency: All batches must have compatible inner dimensions
- Efficiency: GPU-optimized for batched operations
- In-place version: Use baddbmm_() to modify self in-place
- Dimension requirements: All inputs must be 3D tensors
- Batch alignment: First dimension (batch) must match across all tensors
- Shape requirement: batch1's last dimension (K) must equal batch2's second dimension (K)
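The 3D and batch-alignment requirements above can be checked up front. A minimal sketch, assuming shapes are available as plain number arrays (the helper name is hypothetical):

```typescript
// Validate baddbmm shape rules: all inputs 3D, matching batch dim B,
// self [B, M, N], batch1 [B, M, K], batch2 [B, K, N].
function checkBaddbmmShapes(
  self: number[],
  batch1: number[],
  batch2: number[]
): void {
  if (self.length !== 3 || batch1.length !== 3 || batch2.length !== 3) {
    throw new Error("baddbmm: all inputs must be 3D tensors");
  }
  const [B, M, N] = self;
  const [B1, M1, K1] = batch1;
  const [B2, K2, N2] = batch2;
  if (B1 !== B || B2 !== B) {
    throw new Error("baddbmm: batch dimensions must match");
  }
  if (K1 !== K2) {
    throw new Error("baddbmm: batch1's last dim must equal batch2's second dim");
  }
  if (M1 !== M || N2 !== N) {
    throw new Error("baddbmm: result shape [B, M, N] must match self");
  }
}
```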
Parameters
batch1: Tensor - Batch of left matrices (shape: [B, M, K])
batch2: Tensor - Batch of right matrices (shape: [B, K, N])
beta: number (optional) - Scaling factor applied to self
alpha: number (optional) - Scaling factor applied to batch1 @ batch2
options: AlphaBetaOptions (optional)
Returns
Tensor<DynamicShape, D, Dev> - New tensor with shape [B, M, N]
Examples
// Basic batched operation
const C = torch.randn(4, 2, 3); // 4 batches of 2x3 matrices
const A = torch.randn(4, 2, 5); // 4 batches of 2x5 matrices
const B = torch.randn(4, 5, 3); // 4 batches of 5x3 matrices
const result = C.baddbmm(A, B); // C + (A @ B) per batch
// Multi-head attention scores
const batch_size = 32, num_heads = 8, seq_len = 10;
const Q = torch.randn(batch_size, num_heads, seq_len, 64); // Queries
const K = torch.randn(batch_size, num_heads, seq_len, 64); // Keys
// Flatten batch and head dims so each (batch, head) pair is one bmm batch: [B*heads, seq, dim]
const Q2 = Q.reshape(batch_size * num_heads, seq_len, 64);
const K2 = K.reshape(batch_size * num_heads, seq_len, 64);
const bias = torch.zeros(batch_size * num_heads, seq_len, seq_len);
const attention_weights = torch.softmax(bias.baddbmm(Q2, K2.transpose(-1, -2)), -1);
// Scaled batched update
const base = torch.zeros(16, 10, 10);
const updates_a = torch.randn(16, 10, 5);
const updates_b = torch.randn(16, 5, 10);
const result2 = base.baddbmm(updates_a, updates_b, 0.5, 0.5); // 0.5*base + 0.5*(A @ B)
// Fused batch linear layer
const input_batch = torch.randn(32, 100, 64); // [batch, in_features, length]
const weight_batch = torch.randn(32, 128, 100); // [batch, out_features, in_features]
const bias_batch = torch.zeros(32, 128, 64);
const output = bias_batch.baddbmm(weight_batch, input_batch); // [32, 128, 64]
See Also
- PyTorch torch.baddbmm()
- baddbmm_ - In-place version modifying self
- bmm - Batched matrix multiplication without addition
- addmm - Non-batched matrix-matrix version
- matmul - General matrix multiplication (handles batching)