torch.nn.functional.scaled_mm
function scaled_mm(input: Tensor, mat2: Tensor, options?: ScaledMMFunctionalOptions): Tensor
Performs scaled matrix multiplication with optional quantization and mixed-precision support.
Computes the matrix multiplication of two tensors with optional per-tensor scaling and bias addition. This operation is fundamental for:
- Quantized inference: Applying scales that were learned during quantization-aware training
- Mixed-precision computation: Operating on int8 or bfloat16 inputs with float32 scales for accuracy
- Inference optimization: Reducing memory bandwidth and computation while maintaining accuracy
- Fine-grained scaling control: Independent scales for inputs, weights, and results
- Neural network inference: Efficient deployment of quantized models (INT8-FP32 operations)
- Transformer acceleration: Common in optimized attention and feedforward layers
The operation computes: output = ((input * scale_a) @ (mat2 * scale_b) + bias) * scale_result, where
scales can be either scalar values or per-channel/per-tensor tensors. This flexibility enables both
uniform and fine-grained quantization schemes.
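As a reference for the formula above, here is a plain-number sketch of the same math using nested loops instead of tensors (the helper name and array representation are illustrative only, not part of the torch API):

```typescript
// Reference implementation of output = ((input * scale_a) @ (mat2 * scale_b) + bias) * scale_result
// using plain number arrays; scales are scalars here for simplicity.
function scaledMatmulRef(
  a: number[][], b: number[][],
  scaleA: number, scaleB: number,
  bias: number[] | null = null,
  scaleResult: number = 1.0,
): number[][] {
  const M = a.length, K = b.length, N = b[0].length;
  const out: number[][] = [];
  for (let i = 0; i < M; i++) {
    out.push([]);
    for (let j = 0; j < N; j++) {
      // Accumulate the dequantized product in full double precision.
      let acc = 0;
      for (let k = 0; k < K; k++) {
        acc += (a[i][k] * scaleA) * (b[k][j] * scaleB);
      }
      if (bias !== null) acc += bias[j];
      out[i].push(acc * scaleResult); // final adjustment
    }
  }
  return out;
}

// Example: int8-style values with scalar dequantization scales.
const ref = scaledMatmulRef([[10, 20], [30, 40]], [[5, 10], [15, 20]], 0.1, 0.05);
// ref[0][0] = (10*0.1)*(5*0.05) + (20*0.1)*(15*0.05) = 0.25 + 1.5 = 1.75
```

This mirrors the order of operations described above: both matrices are dequantized before the matrix product, the bias is added to the accumulated result, and scale_result is applied last.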
- Quantization context: This operation is specifically designed for quantized models where inputs and weights are stored in low-precision (int8, bfloat16) but scales are kept in float32 for accuracy.
- Scale multiplication order: Scales are applied multiplicatively: result = ((input * scale_a) @ (mat2 * scale_b) + bias) * scale_result. The order matters: scale_a dequantizes the input, scale_b dequantizes the weight, and scale_result is the final adjustment.
- Broadcasting with scales: If scale_a or scale_b are tensors, they must be broadcastable with their respective matrices. For neural networks, scale_b is often of shape [output_dim] for per-channel quantization.
- Bias dtype matters: For numerical stability, bias should ideally be float32 even if the inputs are int8. Accumulation happens internally in higher precision, so providing a float bias preserves that precision.
- Output dtype specification: When out_dtype is specified, the result is cast after all operations. This is essential for re-quantization pipelines where you want int8 output from float32 intermediate.
- Inference optimization: This operation enables efficient INT8-FP32 inference, which is 4x more memory-efficient than float32 inference while maintaining nearly identical accuracy when combined with proper quantization-aware training.
- Gradient computation: For training scenarios (QAT), gradients flow through scales but not typically through quantized values directly. The scales are what get optimized during quantization-aware training.
- Precision with int8: Intermediate accumulation may overflow if inputs have large values or scales are large. Ensure scales are properly calibrated during quantization to avoid this. Typical scales are in range [0.001, 0.1].
- Shape validation: Input must be 2D (M, K) and mat2 must be 2D (K, N). Broadcasting is only applied to scale tensors and bias, not to the main matrices. Both matrices must have matching inner dimension K.
- scale_result for requantization: If you're re-quantizing the output, scale_result should typically be != 1.0 (for example 1/127 to map back into int8 range). Leaving scale_result at 1.0 can cause overflow when converting to low-precision output dtypes.
- Null vs undefined scales: Pass null (not undefined) for scales you want to skip: undefined is treated as "not provided", while null means "explicitly no scaling". Both result in an effective scale of 1.0, but the semantics differ.
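The re-quantization behavior described in the notes can be sketched with plain numbers (the helper name is illustrative only, not part of the torch API): scale the accumulated float result by scale_result, then round and clamp into the int8 range.

```typescript
// Sketch of the re-quantization step that scale_result enables when
// out_dtype is a low-precision type: scale, round, then clamp to int8.
function requantizeToInt8(values: number[], scaleResult: number): number[] {
  return values.map(v => {
    const scaled = Math.round(v * scaleResult);
    return Math.max(-128, Math.min(127, scaled)); // clamp to int8 range
  });
}

// A float accumulator value of 200 fits after scaling by 1/127;
// with scale_result left at 1.0 it saturates at 127, illustrating the
// overflow warning in the notes above.
const ok = requantizeToInt8([200, -90.5], 1 / 127);  // [2, -1]
const clipped = requantizeToInt8([200], 1.0);        // [127]
```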
Parameters
- input: Tensor – The input matrix of shape (M, K). Typically an int8 or bfloat16 quantized tensor
- mat2: Tensor – The weight matrix of shape (K, N). Typically quantized
- options: ScaledMMFunctionalOptions (optional)
Returns
Tensor – Tensor of shape (M, N) containing the scaled matrix multiplication result
Examples
// Basic quantized inference: int8 inputs with learned scales
const input_quantized = torch.tensor([[10, 20], [30, 40]], { dtype: 'int8' });
const weight_quantized = torch.tensor([[5, 10], [15, 20]], { dtype: 'int8' });
const scale_a = 0.1; // Input quantization scale (int8 range [-128, 127] -> about [-12.8, 12.7])
const scale_b = 0.05; // Weight quantization scale
const result = torch.nn.functional.scaled_mm(input_quantized, weight_quantized, scale_a, scale_b);
// Result approximates float32 matmul with minimal memory

// Mixed-precision inference with bias in a neural network layer
const batch_size = 32, input_dim = 768, output_dim = 3072;
const x_int8 = torch.randint(-128, 127, [batch_size, input_dim], { dtype: 'int8' });
const w_int8 = torch.randint(-128, 127, [input_dim, output_dim], { dtype: 'int8' });
const bias = torch.randn([output_dim]);
// These scales come from QAT (Quantization-Aware Training)
const x_scale = 0.008; // For symmetric int8, typically on the order of 1/127
const w_scale = 0.01;
const output = torch.nn.functional.scaled_mm(x_int8, w_int8, x_scale, w_scale, bias);
// [32, 3072] - approximates full-precision inference with 4x memory savings

// Fine-grained per-channel scaling (per-output-channel quantization)
const input = torch.tensor([[1, 2, 3], [4, 5, 6]]);
const weight = torch.tensor([[2, 3], [4, 5], [6, 7]]);
const w_scale_per_channel = torch.tensor([0.1, 0.15]); // Different scale per output channel
const result = torch.nn.functional.scaled_mm(input, weight, 1.0, w_scale_per_channel);
// Each output column has its own scaling factor

// Complete quantized inference pipeline with all components
const x_q8 = torch.randint(-128, 128, [64, 512], { dtype: 'int8' });
const w_q8 = torch.randint(-128, 128, [512, 256], { dtype: 'int8' });
const bias = torch.randn([256]);
// Quantization parameters learned during training
const scale_x = 0.008;
const scale_w = 0.01;
const scale_output = 1.0; // Can be != 1.0 for symmetric requantization
const y = torch.nn.functional.scaled_mm(x_q8, w_q8, scale_x, scale_w, bias, scale_output, 'float32');
// Even if inputs are int8, result is guaranteed float32 precision

// Symmetric quantization with post-dequantization scaling
const input_sym = torch.tensor([[-100, -50], [50, 100]]);
const weight_sym = torch.tensor([[-80, 80], [100, -100]]);
// Symmetric scales for -127 to 127 range
const dequant_x = torch.tensor([1.0 / 127]); // Scale to [-1, 1]
const dequant_w = torch.tensor([1.0 / 127]);
const result = torch.nn.functional.scaled_mm(input_sym, weight_sym, dequant_x, dequant_w);
// Result in typical float32 range, not -128..127

// Inference with output re-quantization to int8
const x = torch.tensor([[1, 2], [3, 4]], { dtype: 'int8' });
const w = torch.tensor([[5, 6], [7, 8]], { dtype: 'int8' });
const scale_x = 0.02;
const scale_w = 0.025;
const scale_out = 1.0 / 127; // Re-quantize back to [-1, 1] range
// Chain multiple quantized operations without loss of precision
const intermediate = torch.nn.functional.scaled_mm(x, w, scale_x, scale_w, null, scale_out, 'int8');
// [2, 2] int8 tensor - ready for next layer without float32 overhead
See Also
- [PyTorch torch._scaled_mm() - Corresponding PyTorch internal operation](https://pytorch.org/docs/stable/generated/torch._scaled_mm.html)
- scaled_grouped_mm - Batched version for multi-sample scaling
- grouped_mm - General grouped matrix multiplication without per-tensor scaling
- Tensor.matmul - Standard matrix multiplication without quantization
- Tensor.to - For dtype conversion if you need to manually control output type