torch.nn.functional.batch_norm
function batch_norm<S extends Shape, D extends DType = DType, Dev extends DeviceType = DeviceType>(input: Tensor<S, D, Dev>, running_mean: Tensor | undefined, running_var: Tensor | undefined, options?: BatchNormFunctionalOptions): Tensor<S, D, Dev>
function batch_norm<S extends Shape, D extends DType = DType, Dev extends DeviceType = DeviceType>(input: Tensor<S, D, Dev>, running_mean: Tensor | undefined, running_var: Tensor | undefined, training: boolean, momentum: number, eps: number, weight: Tensor | undefined, bias: Tensor | undefined, options?: BatchNormFunctionalOptions): Tensor<S, D, Dev>
Batch Normalization: normalize activations using batch statistics during training and running statistics during inference.
Normalizes inputs by subtracting the batch mean and dividing by the batch standard deviation, then applies learnable affine parameters. It is one of the most important techniques in modern deep learning, enabling much faster training and deeper networks. During training, it uses statistics computed from the current batch; during inference, it uses exponential-moving-average statistics accumulated during training. Essential for:
- Stabilizing training (reducing internal covariate shift)
- Enabling higher learning rates
- Allowing deeper networks without vanishing/exploding gradients
- Regularization effect reducing need for dropout
- Standard component in most modern architectures (ResNets, batch-normalized VGG variants, etc.)
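Training vs Inference: During training, batch_norm normalizes with the current batch mean/variance. During inference, it uses the running statistics (an exponential moving average accumulated during training), so each sample's output no longer depends on the rest of the batch. With modules, switch modes using model.train() / model.eval(); with this functional form, pass the training flag explicitly.
Learnable Parameters: After normalizing to x̂ = (x - mean) / sqrt(var + eps), where mean and var are the batch statistics in training and the running statistics in inference, batch norm applies a learnable per-channel affine transform: y = γ * x̂ + β. This allows the model to undo the normalization if needed.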
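The two modes and the affine step above can be sketched in plain TypeScript for a 2D [batch, features] input. This is a minimal illustration under stated assumptions, not the library's implementation: batchNormForward is a hypothetical helper, number[][] arrays stand in for tensors, and the batch variance is the biased estimate (divided by the batch size).

```typescript
// Hypothetical helper: batch norm forward pass for a [batch, features] input.
// Rows are samples, columns are features (channels).
type Matrix = number[][];

function batchNormForward(
  x: Matrix,
  runningMean: number[],
  runningVar: number[],
  training: boolean,
  eps: number,
  gamma: number[],
  beta: number[],
): Matrix {
  const n = x.length;
  const c = x[0].length;
  const mean = new Array<number>(c).fill(0);
  const variance = new Array<number>(c).fill(0);
  if (training) {
    // Training: per-feature statistics of the current batch (biased variance).
    for (const row of x) row.forEach((v, j) => (mean[j] += v / n));
    for (const row of x) row.forEach((v, j) => (variance[j] += (v - mean[j]) ** 2 / n));
  } else {
    // Inference: reuse the accumulated running statistics.
    for (let j = 0; j < c; j++) {
      mean[j] = runningMean[j];
      variance[j] = runningVar[j];
    }
  }
  // y = gamma * (x - mean) / sqrt(var + eps) + beta, applied per feature.
  return x.map((row) =>
    row.map((v, j) => gamma[j] * ((v - mean[j]) / Math.sqrt(variance[j] + eps)) + beta[j]),
  );
}

const example = batchNormForward([[1, 10], [3, 30]], [0, 0], [1, 1], true, 0, [1, 1], [0, 0]);
console.log(JSON.stringify(example)); // [[-1,-1],[1,1]]
```

In training mode each feature is centered on its batch mean (2 and 20) and divided by its batch standard deviation (1 and 10), so both columns become [-1, 1]. With training=false and running statistics of mean 0 / variance 1, the same call returns the input unchanged.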
- Critical for deep learning: One of the most impactful techniques; enables training much deeper networks
- Training vs inference: MUST use training=true during training, training=false during inference. Forgetting this is a common bug causing performance degradation at inference time.
- Batch size matters: Batch norm works best with reasonably large batches (32+). Very small batches (< 8) can cause poor normalization because the batch statistics are unreliable.
- Running statistics: During training, running_mean/running_var are updated as an exponential moving average. After training completes, they hold the accumulated statistics used at inference.
- Reduces internal covariate shift: Keeps activations in reasonable range, enabling higher learning rates.
- Regularization effect: Acts as implicit regularization, reducing generalization gap (like dropout).
- Normalize over batch and spatial dims: For CNNs, statistics are computed across the batch dimension AND the spatial dims, independently for each channel. This differs from layer_norm, which normalizes each sample over its own features.
- Training flag critical: Forgetting to set training=false at inference causes poor model performance. Always ensure model.eval() is called before inference (or manually set training=false).
- Small batch sizes are unstable: With batch_size < 8, batch statistics are unreliable. Consider group_norm or layer_norm for small batches, or synchronized batch norm for multi-GPU training with small per-GPU batches.
- Running statistics must persist: running_mean/running_var must be stored and updated during training, then used at inference. If they are not saved and loaded with the checkpoint, inference will normalize with incorrect (e.g., freshly initialized) statistics and produce poor results.
- Slightly different training/inference: Model behavior differs between training and inference due to using batch vs running statistics. This is intentional but can cause subtle bugs if forgotten.
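- Momentum direction can confuse: momentum=0.1 means each update is 10% new batch statistic and 90% old running value (running = (1 - momentum) * running + momentum * batch_stat), which is the opposite of the usual optimizer sense of momentum. Higher momentum = faster adaptation to new data.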
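The momentum convention in the last note can be checked with a small simulation. This sketch assumes the PyTorch-style update running = (1 - momentum) * running + momentum * batch_stat; updateRunning and simulate are hypothetical helpers. PyTorch additionally stores the unbiased batch variance in running_var, which this sketch does not model, and whether this library does the same is not stated here.

```typescript
// Running-statistics update, assuming the PyTorch convention:
//   running = (1 - momentum) * running + momentum * batchStat
// momentum is the weight given to the NEW batch statistic.
function updateRunning(running: number[], batchStat: number[], momentum: number): number[] {
  return running.map((r, j) => (1 - momentum) * r + momentum * batchStat[j]);
}

// Start the running value at 0 and feed the same batch statistic (10) for `steps` updates.
function simulate(momentum: number, steps: number): number {
  let running = [0];
  for (let i = 0; i < steps; i++) running = updateRunning(running, [10], momentum);
  return running[0];
}

console.log(simulate(0.1, 10).toFixed(3)); // 6.513 (slow adaptation)
console.log(simulate(0.5, 10).toFixed(3)); // 9.990 (fast adaptation)
```

After k updates with a constant batch statistic b, the running value is b * (1 - (1 - momentum)^k), so a larger momentum approaches the new statistic faster.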
Parameters
input: Tensor<S, D, Dev> - Input tensor. Shape [batch_size, channels, ...spatial_dims] for CNNs, or [batch_size, features] for fully connected layers. The batch dimension must be dim 0.
running_mean: Tensor | undefined - (Optional) Running mean statistics, shape [channels]. Used for normalization during inference and updated during training. If undefined, no running mean is tracked and only the batch mean is available, so this is intended for training=true.
running_var: Tensor | undefined - (Optional) Running variance statistics, shape [channels]. Used for normalization during inference and updated during training. If undefined, no running variance is tracked and only the batch variance is available, so this is intended for training=true.
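training: boolean - (Second overload) If true, normalize with batch statistics and update the running statistics; if false, normalize with running_mean/running_var.
momentum: number - (Second overload) Weight given to the new batch statistic when updating the running statistics (e.g., 0.1).
eps: number - (Second overload) Small constant added to the variance for numerical stability (e.g., 1e-5).
weight: Tensor | undefined - (Second overload) Optional per-channel scale γ, shape [channels].
bias: Tensor | undefined - (Second overload) Optional per-channel shift β, shape [channels].
options: BatchNormFunctionalOptions (optional) - Options for the operation. See BatchNormFunctionalOptions.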
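To make the input shape conventions concrete: for a [batch_size, channels, height, width] input, each channel's statistics pool batch_size * height * width values. The sketch below assumes a flat, row-major NCHW array; channelStats is a hypothetical helper, not part of the library, and it returns the biased variance.

```typescript
// Hypothetical helper: per-channel batch statistics for an NCHW input
// stored as a flat row-major array of length n * c * h * w.
function channelStats(
  data: number[],
  n: number,
  c: number,
  h: number,
  w: number,
): { mean: number[]; variance: number[] } {
  const spatial = h * w;
  const count = n * spatial; // values pooled per channel: batch AND spatial positions
  const mean = new Array<number>(c).fill(0);
  const variance = new Array<number>(c).fill(0);
  // Value of element (b, ch, s), where s indexes a position within one H x W plane.
  const at = (b: number, ch: number, s: number): number => data[(b * c + ch) * spatial + s];
  for (let ch = 0; ch < c; ch++) {
    for (let b = 0; b < n; b++) {
      for (let s = 0; s < spatial; s++) mean[ch] += at(b, ch, s);
    }
    mean[ch] /= count;
    for (let b = 0; b < n; b++) {
      for (let s = 0; s < spatial; s++) variance[ch] += (at(b, ch, s) - mean[ch]) ** 2;
    }
    variance[ch] /= count; // biased estimate (divide by count)
  }
  return { mean, variance };
}

// 2 samples, 2 channels, a 1 x 2 spatial plane per channel.
const stats = channelStats([1, 3, 10, 10, 5, 7, 20, 20], 2, 2, 1, 2);
console.log(stats.mean, stats.variance); // mean [4, 15], variance [5, 25]
```

Channel 0 pools the values 1, 3, 5, 7 from both samples and both spatial positions, and channel 1 pools 10, 10, 20, 20; the two channels never share statistics.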
Returns
Tensor<S, D, Dev> - Normalized output tensor with the same shape as input.
Examples
// Simple batch normalization: normalize feature activations
const input = torch.randn(32, 64); // [batch=32, features=64]
const weight = torch.ones(64); // γ parameter (learnable scale)
const bias = torch.zeros(64); // β parameter (learnable shift)
const output = torch.nn.functional.batch_norm(
input, undefined, undefined, true, 0.1, 1e-5, weight, bias // training, momentum, eps, weight, bias
);
// Normalized: zero mean, unit variance per feature (approximately)
// CNN batch norm: normalize per-channel activations (ResNet, VGG)
const x = torch.randn(16, 64, 32, 32); // [batch=16, channels=64, height=32, width=32]
const gamma = torch.ones(64);
const beta = torch.zeros(64);
const bn_output = torch.nn.functional.batch_norm(x, undefined, undefined, true, 0.1, 1e-5, gamma, beta);
// Normalizes over batch (16) and spatial dims (32x32), independently per channel (64)
// Training vs inference switching
class ResNetBlock {
forward(x, training) {
x = this.conv1(x);
// During training: normalize using current batch stats
// During inference: normalize using accumulated running stats
x = torch.nn.functional.batch_norm(
x, this.running_mean, this.running_var,
training, 0.1, 1e-5, // KEY: training flag controls behavior (momentum 0.1, eps 1e-5)
this.bn_weight, this.bn_bias
);
x = torch.nn.functional.relu(x);
x = this.conv2(x);
x = torch.nn.functional.batch_norm(
x, this.running_mean2, this.running_var2,
training, 0.1, 1e-5,
this.bn_weight2, this.bn_bias2
);
return x;
}
}
// With momentum: running statistics exponential moving average
// Small momentum (0.01): slow moving average (more stable, but adapts slowly to new data)
// Large momentum (0.1-0.5): faster adaptation
const running_mean = torch.zeros(64);
const running_var = torch.ones(64);
const y = torch.nn.functional.batch_norm(
x, running_mean, running_var, true, 0.01, 1e-5, gamma, beta // momentum = 0.01
);
// running_mean and running_var are updated in-place (exponential moving average)
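// Comparison: computing batch norm statistics by hand for the 4D input x
// Batch norm reduces over batch AND spatial dims (0, 2, 3): one mean/variance per channel.
// Assumption: torch.mean accepts a keepdim flag and Tensor exposes PyTorch-style
// .sub/.mul/.add/.div methods (JS has no operator overloading, so `x - batch_mean` is invalid).
const batch_mean = torch.mean(x, [0, 2, 3], true); // shape [1, 64, 1, 1]
const centered = x.sub(batch_mean);
const batch_var = torch.mean(centered.mul(centered), [0, 2, 3], true); // biased variance, [1, 64, 1, 1]
const normalized_bn = centered.div(torch.sqrt(batch_var.add(1e-5)));
// Batch norm: statistics shared across samples and spatial positions within each channel
// Layer norm: statistics computed per sample, over that sample's features
See Also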
- PyTorch torch.nn.functional.batch_norm
- torch.nn.BatchNorm1d - Module for 1D batch norm (fully connected layers)
- torch.nn.BatchNorm2d - Module for 2D batch norm (CNNs)
- torch.nn.BatchNorm3d - Module for 3D batch norm (volumetric CNNs)
- torch.nn.functional.layer_norm - Per-sample normalization (works with small batches)
- torch.nn.functional.group_norm - Group-based normalization (middle ground)
- torch.nn.functional.instance_norm - Per-sample, per-channel normalization over spatial dims
- torch.nn.functional.dropout - Complementary regularization technique