torch.nn.RMSNorm
new RMSNorm(normalized_shape: number | number[], options?: RMSNormOptions)
Properties:
- readonly normalized_shape: number[]
- readonly eps: number
- readonly elementwise_affine: boolean
- weight: Parameter | null
Root Mean Square Layer Normalization: simplified LayerNorm without mean subtraction.
Applies RMS normalization: inputs are scaled by their root mean square without mean subtraction. This is more compute-efficient than LayerNorm with similar effectiveness. Commonly used for:
- Large language models (LLaMA, Falcon, others prefer RMSNorm over LayerNorm)
- Compute-efficient training (fewer operations than LayerNorm)
- Attention-based models (Transformers)
- When mean subtraction is not critical
- Modern LLMs requiring reduced computation
Why RMSNorm Over LayerNorm: LayerNorm subtracts the mean and divides by the standard deviation; RMSNorm only rescales by the RMS magnitude. This is simpler (no mean subtraction) and slightly faster. Empirically, RMSNorm performs similarly to LayerNorm on most tasks, and many modern LLMs (LLaMA 2, Falcon, etc.) use it instead of LayerNorm for efficiency.
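To make the relationship concrete, here is a small plain-TypeScript sketch (not this library's API): on zero-mean input, LayerNorm (with γ=1, β=0) and RMSNorm coincide, because the variance equals mean(x²) when the mean is zero.

```typescript
// For zero-mean input, Var(x) = mean(x²), so LayerNorm (without bias)
// and RMSNorm produce identical outputs. eps omitted for clarity.
const xs = [2, -2, 4, -4]; // mean is exactly 0

const mean = xs.reduce((s, v) => s + v, 0) / xs.length;
const variance = xs.reduce((s, v) => s + (v - mean) ** 2, 0) / xs.length;
const meanSq = xs.reduce((s, v) => s + v * v, 0) / xs.length;

const layerNormed = xs.map((v) => (v - mean) / Math.sqrt(variance));
const rmsNormed = xs.map((v) => v / Math.sqrt(meanSq));

console.log(variance === meanSq); // true: zero mean, so the two agree
```

With a nonzero mean the two diverge; that centering step is exactly what RMSNorm drops.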
When to use RMSNorm:
- Building modern language models (LLMs)
- When computation efficiency is critical
- Transformer models (especially large ones)
- Performance-sensitive applications
- When you want a near-equivalent to LayerNorm with less compute
- Pre-normalization in deep architectures
Trade-offs:
- vs LayerNorm: No mean subtraction (simpler and faster; stability is comparable in practice)
- Computational cost: Fewer operations (no mean computation/subtraction)
- Stability: Slightly different numerical behavior but usually equivalent
- Expressiveness: Both learn weight (γ) and are highly expressive
- Gradient flow: Similar to LayerNorm, good for deep networks
- Empirical performance: Very similar to LayerNorm on most tasks
Algorithm: For an input tensor with shape [..., normalized_shape]:
- Compute the RMS (root mean square) across the trailing normalized_shape dimensions: RMS = √(mean(x²))
- Normalize: x_norm = x / (RMS + eps)
- Apply the learned scale: y = γ * x_norm
No learned bias (β) parameter - only the weight (γ) for scaling. The eps parameter (default 1e-5) prevents division by zero when the RMS is small.
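The steps above can be sketched in plain TypeScript (plain arrays rather than this library's tensors; the eps placement follows the formula as stated here):

```typescript
// Minimal RMSNorm sketch: normalize by RMS, then scale by γ (no β).
function rmsNorm(input: number[], weight: number[], eps = 1e-5): number[] {
  // Step 1: RMS = sqrt(mean(x²)) over the normalized dimension
  const meanSq = input.reduce((s, v) => s + v * v, 0) / input.length;
  const rms = Math.sqrt(meanSq);
  // Steps 2-3: normalize by RMS, then apply the learned scale
  return input.map((v, i) => (v / (rms + eps)) * weight[i]);
}

const gamma = [1, 1, 1, 1]; // identity initialization, as for the real layer
const out = rmsNorm([3, -4, 12, 0], gamma);
// The output's own RMS is ≈ 1
const outRms = Math.sqrt(out.reduce((s, v) => s + v * v, 0) / out.length);
console.log(outRms.toFixed(4)); // "1.0000"
```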
Key characteristics:
- LLM standard: RMSNorm preferred in modern LLMs (LLaMA, Falcon, Mistral)
- Computational efficiency: Fewer operations than LayerNorm (no mean computation)
- No bias parameter: Only weight (γ); unlike LayerNorm, there is no bias (β)
- RMS-only normalization: Normalizes magnitude without centering on mean
- Batch size invariant: Works with any batch size, including batch_size=1
- No training/eval mode: Behavior identical in train() and eval() modes
- No running statistics: Unlike BatchNorm, no moving average to maintain
- Initialization: Weight initialized to 1 for identity transformation initially
- Stability equivalent: Similar numerical stability to LayerNorm in practice
- Pre-norm preferred: Usually applied before attention/FFN (pre-norm architecture)
- Empirical equivalence: Performs similarly to LayerNorm on most benchmarks
Notes:
- Features in normalized_shape are normalized together (not independently)
- Currently only supports 1D normalized_shape (single trailing dimension)
- Input's last dimension must match normalized_shape
- eps parameter is critical for numerical stability with small RMS values
- No bias parameter (unlike LayerNorm) - if needed, add separate linear layer
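Since the layer has no β, the last note can be illustrated with a hypothetical plain-TypeScript sketch (not this library's API): the simplest learned offset after RMSNorm is an elementwise per-feature bias, i.e. a following linear layer reduced to a shift.

```typescript
// RMSNorm followed by a learned per-feature bias β, recovering
// LayerNorm-style affine output (minus the mean centering).
function rmsNormWithBias(
  input: number[], weight: number[], bias: number[], eps = 1e-5
): number[] {
  const rms = Math.sqrt(input.reduce((s, v) => s + v * v, 0) / input.length);
  return input.map((v, i) => (v / (rms + eps)) * weight[i] + bias[i]);
}

// RMS of [1, 2, 2] is sqrt(9/3) = sqrt(3)
const shifted = rmsNormWithBias([1, 2, 2], [1, 1, 1], [0.5, 0.5, 0.5]);
```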
Examples
// Simple RMSNorm layer
const rmsnorm = new torch.nn.RMSNorm(512); // Normalize 512-dim features
const x = torch.randn([16, 512]); // Batch of 16 samples, 512 features
const normalized = rmsnorm.forward(x); // Same shape [16, 512]
// Each sample's 512 features normalized by RMS

// Transformer decoder block with RMSNorm (LLaMA style)
class TransformerDecoderBlock extends torch.nn.Module {
attention: torch.nn.MultiheadAttention;
rms_norm1: torch.nn.RMSNorm;
ff_linear1: torch.nn.Linear;
ff_linear2: torch.nn.Linear;
rms_norm2: torch.nn.RMSNorm;
constructor(d_model: number = 512, num_heads: number = 8) {
super();
this.attention = new torch.nn.MultiheadAttention(d_model, num_heads);
this.rms_norm1 = new torch.nn.RMSNorm(d_model); // Pre-norm attention
this.ff_linear1 = new torch.nn.Linear(d_model, d_model * 4);
this.ff_linear2 = new torch.nn.Linear(d_model * 4, d_model);
this.rms_norm2 = new torch.nn.RMSNorm(d_model); // Pre-norm FFN
}
forward(x: torch.Tensor): torch.Tensor {
// Pre-norm attention (normalize before attention)
const norm1 = this.rms_norm1.forward(x);
const attn_out = this.attention.forward(norm1, norm1, norm1);
x = x.add(attn_out); // Residual connection
// Pre-norm feed-forward (normalize before FFN)
const norm2 = this.rms_norm2.forward(x);
let ff = torch.gelu(this.ff_linear1.forward(norm2)); // GELU activation
let ff_out = this.ff_linear2.forward(ff);
x = x.add(ff_out); // Residual connection
return x;
}
}
// Usage in LLM
const model = new TransformerDecoderBlock(768, 12); // 768 dims, 12 heads
const tokens = torch.randn([4, 128, 768]); // [batch=4, seq=128, features=768]
const output = model.forward(tokens);

// LLaMA-style language model with RMSNorm
class LlamaLM extends torch.nn.Module {
embed: torch.nn.Embedding;
decoder_blocks: TransformerDecoderBlock[];
final_norm: torch.nn.RMSNorm;
output_proj: torch.nn.Linear;
constructor(vocab_size: number, d_model: number, num_layers: number) {
super();
this.embed = new torch.nn.Embedding(vocab_size, d_model);
this.decoder_blocks = [];
for (let i = 0; i < num_layers; i++) {
this.decoder_blocks.push(new TransformerDecoderBlock(d_model));
}
this.final_norm = new torch.nn.RMSNorm(d_model); // Final layer norm
this.output_proj = new torch.nn.Linear(d_model, vocab_size);
}
forward(tokens: torch.Tensor): torch.Tensor {
let x = this.embed.forward(tokens); // Token embeddings
for (const block of this.decoder_blocks) {
x = block.forward(x);
}
x = this.final_norm.forward(x); // Final normalization
const logits = this.output_proj.forward(x); // Output logits
return logits;
}
}
// LLaMA uses RMSNorm throughout for computational efficiency

// Multi-dimensional input (sequence of features)
const x = torch.randn([32, 100, 512]); // [batch=32, seq=100, features=512]
const rms_norm = new torch.nn.RMSNorm(512); // Normalize 512-dim features
const normalized = rms_norm.forward(x); // [32, 100, 512]
// Each of 32*100=3200 positions gets RMS-normalized independently
// across the 512 features
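The per-position independence can be checked with a plain-TypeScript sketch (not this library's API): each row (position) is normalized by its own RMS, so one row's scale never affects another.

```typescript
// Each row normalized independently across its feature dimension.
function rmsNormRows(batch: number[][], eps = 1e-5): number[][] {
  return batch.map((row) => {
    const rms = Math.sqrt(row.reduce((s, v) => s + v * v, 0) / row.length);
    return row.map((v) => v / (rms + eps));
  });
}

// Second row is the first scaled by 10x; both normalize to the same values.
const rows = rmsNormRows([[3, 4], [30, 40]]);
// rows[0] ≈ rows[1] ≈ [0.8485, 1.1314]
```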