torch.set_float32_matmul_precision
function set_float32_matmul_precision(precision: 'highest' | 'high' | 'medium'): void

Set the precision level for float32 matrix multiplications.
Controls the accuracy vs. speed tradeoff for matrix multiplication on float32 tensors. Higher precision uses more computation and memory, while lower precision runs faster. The setting matters for:
- Training stability (higher = more stable)
- Inference speed (lower = faster)
- Numerical accuracy requirements
Precision Levels:
- 'highest': Maximum accuracy using high-precision intermediate computations (default)
  - Slowest but most accurate
  - Use for training models, especially large models where precision matters
  - May require more memory for intermediate results
- 'high': Balanced accuracy and speed
  - Typical choice for most applications
  - Good for production inference with acceptable tolerance
- 'medium': Fastest but potentially least accurate
  - Useful for real-time applications with tight latency budgets
  - May lose precision in edge cases or with large matrices
  - Best for mobile or low-power inference
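To make the accuracy difference between levels concrete, the self-contained sketch below emulates a TF32-style reduced-precision dot product by truncating each float32 mantissa to 10 bits before multiplying (a common strategy behind non-'highest' modes on GPU hardware). The truncation scheme is illustrative only; this library's actual internal precision strategy is not specified in this document.

```typescript
// Emulate TF32-style input rounding: keep only the top 10 of float32's
// 23 explicit mantissa bits by zeroing the low 13 bits.
function toTF32(x: number): number {
  const dv = new DataView(new ArrayBuffer(4));
  dv.setFloat32(0, x);
  const bits = (dv.getUint32(0) & 0xffffe000) >>> 0;
  dv.setUint32(0, bits);
  return dv.getFloat32(0);
}

// Dot product where each input passes through `round` first.
// (Real hardware also fixes the accumulator precision; here the
// accumulation stays in float64 for simplicity.)
function dot(a: number[], b: number[], round: (x: number) => number): number {
  let acc = 0;
  for (let i = 0; i < a.length; i++) acc += round(a[i]) * round(b[i]);
  return acc;
}

const a = Array.from({ length: 1000 }, (_, i) => Math.sin(i));
const b = Array.from({ length: 1000 }, (_, i) => Math.cos(i));

const exact = dot(a, b, (x) => x);
const tf32 = dot(a, b, toTF32);
console.log('absolute difference:', Math.abs(exact - tf32));
```

The difference is tiny per element but accumulates with the inner dimension of the matmul, which is why the accuracy gap between 'highest' and 'medium' grows with matrix size.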
Platform Notes:
- CPU: Fully respects the precision setting in computations
- WebGPU: The setting is advisory; GPU hardware may compute float32 matmuls at a fixed native precision regardless of the requested level
- Default: 'highest', for maximum accuracy and compatibility
- Scope: Changes apply globally to all subsequent operations and can be adjusted at any time during execution
- Accuracy: 'medium' may produce different results than 'highest', and the differences become more noticeable with large matrices
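Because the setting is global and mutable at any time, one phase of a program can silently change the precision seen by another. A scoped helper that restores the previous level avoids this. The sketch below is hypothetical, not part of the library; the local setPrecision/getPrecision stand in for torch.set_float32_matmul_precision and torch.get_float32_matmul_precision so it runs standalone.

```typescript
type MatmulPrecision = 'highest' | 'high' | 'medium';

// Stand-ins for the library's global setting, so the sketch is self-contained.
let current: MatmulPrecision = 'highest';
const setPrecision = (p: MatmulPrecision): void => { current = p; };
const getPrecision = (): MatmulPrecision => current;

// Run `fn` under a temporary precision, restoring the previous level
// even if `fn` throws.
function withMatmulPrecision<T>(p: MatmulPrecision, fn: () => T): T {
  const prev = getPrecision();
  setPrecision(p);
  try {
    return fn();
  } finally {
    setPrecision(prev);
  }
}

const inside = withMatmulPrecision('medium', () => getPrecision());
console.log(inside, getPrecision()); // 'medium' inside, 'highest' restored
```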
Parameters
- precision ('highest' | 'high' | 'medium'): one of 'highest' (default), 'high', or 'medium'
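The union type is only enforced at compile time; plain-JavaScript callers can pass any string. A small runtime guard catches typos before they silently select an unexpected mode. assertPrecision is a hypothetical helper for illustration, not part of the library.

```typescript
const VALID_PRECISIONS = ['highest', 'high', 'medium'] as const;
type MatmulPrecision = typeof VALID_PRECISIONS[number];

// Narrow an arbitrary string to MatmulPrecision, throwing on anything else.
function assertPrecision(p: string): asserts p is MatmulPrecision {
  if (!(VALID_PRECISIONS as readonly string[]).includes(p)) {
    throw new RangeError(`invalid matmul precision: ${p}`);
  }
}

// After the guard, `p` can be passed where MatmulPrecision is expected.
const p: string = 'high';
assertPrecision(p);
console.log(`validated precision: ${p}`);
```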
Examples
```typescript
// Maximum accuracy for training
torch.set_float32_matmul_precision('highest');
const loss = model.forward(x).sum();
loss.backward(); // Stable gradients with high precision

// Balanced speed/accuracy for inference
torch.set_float32_matmul_precision('high');
const predictions = model.forward(test_data);

// Fast inference on an edge device
torch.set_float32_matmul_precision('medium');
const fast_predictions = model.forward(mobile_input);

// Check the current setting
const current = torch.get_float32_matmul_precision();
console.log(`Using ${current} precision for matmul`);
```

See Also
- PyTorch torch.set_float32_matmul_precision()
- get_float32_matmul_precision - Query current precision setting