torch.nn.functional.rms_norm
function rms_norm(input: Tensor, normalized_shape: number[], options?: RmsNormFunctionalOptions): Tensor

RMS Normalization: normalizes by the root mean square of activations (used in LLaMA and other modern LLMs).
Applies RMSNorm: RMSNorm(x) = x / sqrt(mean(x²) + ε) * γ, where γ is a learnable weight and mean(x²) is computed over the specified dimensions. A simplified variant of layer normalization that normalizes only by the RMS (root mean square) of the input without subtracting the mean, making it computationally cheaper than LayerNorm while remaining similarly effective. Essential for:
- Modern large language models (LLaMA, Alpaca, Falcon use RMSNorm)
- Efficient inference (simpler than LayerNorm, fewer operations)
- Fine-tuned transformer models (increasingly replacing LayerNorm)
- Distributed training (better numerical properties than LayerNorm)
- Memory-efficient models (simpler computation reduces memory overhead)
- Edge deployment where compute is limited
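The formula above can be sketched on a plain number array (a minimal illustration with a hypothetical helper; the actual function operates on Tensors):

```typescript
// From-scratch sketch of RMSNorm(x) = x / sqrt(mean(x²) + ε) * γ.
// Illustrative only — not the library implementation.
function rmsNorm(x: number[], gamma: number[], eps = 1e-6): number[] {
  // RMS over the feature dimension: sqrt(mean(x²) + ε)
  const meanSq = x.reduce((s, v) => s + v * v, 0) / x.length;
  const rms = Math.sqrt(meanSq + eps);
  // Normalize and apply the learnable scale γ
  return x.map((v, i) => (v / rms) * gamma[i]);
}

const x = [1, 2, 3, 4];
const gamma = [1, 1, 1, 1]; // identity scale
console.log(rmsNorm(x, gamma)); // each element is x[i] / sqrt(mean(x²) + ε)
```

Note there is no mean subtraction and no β term anywhere: the only statistics needed are the squared values, which is where the savings over LayerNorm come from.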
Comparison with LayerNorm:
- Mean subtraction: LayerNorm subtracts mean; RMSNorm skips this step
- Computation: RMSNorm faster (no mean calculation, fewer operations)
- Effectiveness: RMSNorm comparable to LayerNorm in practice (empirically proven)
- Memory: RMSNorm uses less memory (simpler computation graph)
- Modern trend: RMSNorm increasingly standard in new LLMs (LLaMA, Falcon)
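One way to see the relationship: for zero-mean input, the variance equals mean(x²), so RMSNorm and LayerNorm (with β = 0) produce identical outputs. A plain-array sketch with hypothetical helpers (not the library API):

```typescript
// Illustrative plain-array versions of both normalizations.
function rmsNorm(x: number[], eps = 1e-6): number[] {
  const meanSq = x.reduce((s, v) => s + v * v, 0) / x.length;
  return x.map((v) => v / Math.sqrt(meanSq + eps));
}

function layerNorm(x: number[], eps = 1e-6): number[] {
  const mean = x.reduce((s, v) => s + v, 0) / x.length;
  const variance = x.reduce((s, v) => s + (v - mean) ** 2, 0) / x.length;
  return x.map((v) => (v - mean) / Math.sqrt(variance + eps));
}

// Zero-mean input: variance == mean(x²), so the two agree
const x = [-3, -1, 1, 3];
console.log(rmsNorm(x));   // ≈ [-1.342, -0.447, 0.447, 1.342]
console.log(layerNorm(x)); // same values
```

For inputs with nonzero mean the outputs differ, but in trained transformers the learnable scale γ compensates well enough that quality is comparable in practice.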
When to use RMSNorm:
- Large language models (LLaMA-style architectures)
- When memory efficiency matters
- As a LayerNorm replacement in new architectures (note: existing LayerNorm weights do not transfer directly, since RMSNorm has no bias)
- Fine-tuned transformers where RMSNorm is already used
- Edge deployment or inference where compute is limited
- Training at scale where efficiency gains matter
Why RMSNorm works well: The key insight is that the mean contains redundant information when combined with the scaling parameter γ. By learning γ, the network can effectively skip mean centering without loss of expressiveness. In practice, RMSNorm achieves comparable or better perplexity than LayerNorm while being more efficient.
- Efficient: Simpler than LayerNorm (no mean subtraction)
- Modern standard: Default in LLaMA, Falcon, and recent LLMs
- Learnable weight: γ parameter trained like LayerNorm
- Empirically proven: RMSNorm ≈ LayerNorm in practice but faster
- No mean parameter: Unlike LayerNorm, no β parameter (mean centering skipped)
- Per-sample normalization: Like LayerNorm, statistics are computed independently for each sample over its feature dimensions
- RMS computation: Uses root-mean-square instead of z-score
- No bias parameter: Only weight, no additive bias (unlike LayerNorm)
- Different from LayerNorm: Not mean-centered; empirical equivalence in practice
- Pre-norm architecture: Best used with pre-norm residual connections (like LLaMA)
- Not a drop-in replacement: RMSNorm weight only; LayerNorm has weight + bias
Parameters
input: Tensor - Input tensor of any shape, typically [batch, ..., feature_dims...]
normalized_shape: number[] - Shape of the features to normalize over (the last N dimensions). Example: input=[batch, seq, hidden] with normalized_shape=[hidden] → normalize over hidden
options: RmsNormFunctionalOptions (optional)
Returns
Tensor – Normalized tensor with the same shape as input, normalized by the RMS over the specified dimensions.

Examples
// LLaMA-style normalization: RMSNorm over hidden dimension
const batch_size = 32, seq_len = 128, hidden_dim = 4096;
const x = torch.randn([batch_size, seq_len, hidden_dim]);
const weight = torch.ones([hidden_dim]); // Learnable scale
const normalized = torch.nn.functional.rms_norm(x, [hidden_dim], weight);
// Normalizes over the last dimension independently for each (batch, seq) position

// Transformer block using RMSNorm (LLaMA pattern)
class LLaMATransformerBlock extends torch.nn.Module {
private ln1: torch.nn.Parameter; // RMSNorm weight
private attention: torch.nn.MultiheadAttention;
private ln2: torch.nn.Parameter; // RMSNorm weight
private mlp: torch.nn.Sequential;
forward(x: torch.Tensor): torch.Tensor {
// Pre-norm pattern with RMSNorm
const lastDim = x.shape[x.shape.length - 1]; // negative indexing is not valid on JS arrays
const x_norm = torch.nn.functional.rms_norm(x, [lastDim], this.ln1);
x = x + this.attention.forward(x_norm); // Residual
const x_norm2 = torch.nn.functional.rms_norm(x, [lastDim], this.ln2);
x = x + this.mlp.forward(x_norm2); // Residual
return x;
}
}

// Comparison: RMSNorm vs LayerNorm
const x = torch.randn([32, 768]);
const weight = torch.ones([768]);
const rms_out = torch.nn.functional.rms_norm(x, [768], weight); // Simpler, faster
const layer_out = torch.nn.functional.layer_norm(x, [768], weight); // More complex
// Both produce normalized outputs; RMSNorm is faster in practice
// Modern models prefer RMSNorm for efficiency

// Fine-tuning LLaMA: RMSNorm already trained, optimize with LoRA
const model = load_llama_model(); // Pretrained model with RMSNorm layers
const lora_layer = new LoRA(768, 8); // Low-rank adapter (rank 8)
const hidden = model.forward(x);
const adapted = hidden + lora_layer.forward(hidden); // LoRA residual update
// RMSNorm is already applied inside the model; normalization stays efficient throughout

See Also
- [PyTorch torch.nn.functional.rms_norm (PyTorch 2.4+)](https://pytorch.org/docs/stable/generated/torch.nn.functional.rms_norm.html)
- torch.nn.functional.layer_norm - More complex normalization with mean + variance
- torch.nn.functional.group_norm - Groups channels for normalization
- torch.nn.functional.instance_norm - Per-instance channel normalization
- torch.nn.LayerNorm - Module version of LayerNorm (not RMSNorm currently)