torch.nn.BatchNorm1d
class BatchNorm1d extends _BatchNorm

Batch Normalization over 2D or 3D inputs (shape [N, C] or [N, C, L]): normalizes each channel using mini-batch statistics, then applies an optional learned affine transform.
Applies batch normalization to normalize activations along the channel dimension, reducing internal covariate shift and allowing higher learning rates. Essential for:
- Stabilizing training of deep networks
- Accelerating convergence (enables higher learning rates)
- Reducing sensitivity to weight initialization
- Acting as implicit regularizer
- Processing sequential data (1D convolutional models, time series)
- Dense layer networks (MLPs, fully connected architectures)
During training, normalizes using mini-batch statistics and updates running statistics. During evaluation, uses accumulated running mean/variance for stable predictions. Optional affine transform (scale γ and shift β per channel) learned from data.
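The per-channel training-time computation described above can be sketched with plain arrays (a minimal, library-free illustration; the helper name is not part of the API):

```typescript
// Batch-normalize one channel's values during training (illustrative sketch).
// y_i = ((x_i - mean) / sqrt(var + eps)) * gamma + beta
function batchNorm1dChannel(
  xs: number[],
  gamma = 1,
  beta = 0,
  eps = 1e-5
): number[] {
  const mean = xs.reduce((a, b) => a + b, 0) / xs.length;
  // Biased (population) variance is used for normalization
  const variance = xs.reduce((a, b) => a + (b - mean) ** 2, 0) / xs.length;
  return xs.map((x) => ((x - mean) / Math.sqrt(variance + eps)) * gamma + beta);
}

const out = batchNorm1dChannel([1, 2, 3, 4]);
// Output has approximately zero mean and unit variance
```

With the default γ=1, β=0 this is pure normalization; the learned affine parameters then rescale and shift each channel.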
- Covariate shift: BN reduces "internal covariate shift" - the change in the distribution of a layer's inputs as earlier layers update during training
- Training vs eval: Must call model.train() and model.eval() for correct behavior
- Running statistics: Running mean/var accumulated during training, used during inference for stability
- Momentum: Weight given to the current batch statistic in the running-average update; a small value (e.g., 0.1) updates running stats slowly and smoothly, while larger values let them track recent batches closely
- Epsilon: Prevents division by zero when variance is very small (typically 1e-5)
- Affine optional: Set affine=false to save parameters if scale/shift is handled elsewhere
- Gradient flow: Gradients flow through the batch statistics (mean/variance) as well as the affine parameters γ and β
- GPU consideration: Batch normalization is more effective with larger batch sizes (e.g., 16 or more); very small batches produce noisy statistics
- Initialization: γ (weight) initialized to 1, β (bias) to 0
- Position in architecture: Typically applied after linear/conv, before activation
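The running-statistics update mentioned above (PyTorch-style convention, where momentum weights the new batch statistic) can be sketched as:

```typescript
// Exponential moving-average update for running statistics (sketch).
// running = (1 - momentum) * running + momentum * batchStat
function updateRunningStat(
  running: number,
  batchStat: number,
  momentum = 0.1
): number {
  return (1 - momentum) * running + momentum * batchStat;
}

// With momentum = 0.1, the running mean moves 10% toward each batch mean
let runningMean = 0;
runningMean = updateRunningStat(runningMean, 5.0); // → 0.5
```

The same update is applied to the running variance; at evaluation time these accumulated values replace the batch statistics.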
Examples
// Simple MLP with batch normalization
class MLPWithBN extends torch.nn.Module {
fc1: torch.nn.Linear;
bn1: torch.nn.BatchNorm1d;
fc2: torch.nn.Linear;
bn2: torch.nn.BatchNorm1d;
fc3: torch.nn.Linear;
constructor() {
super();
this.fc1 = new torch.nn.Linear(784, 256);
this.bn1 = new torch.nn.BatchNorm1d(256);
this.fc2 = new torch.nn.Linear(256, 128);
this.bn2 = new torch.nn.BatchNorm1d(128);
this.fc3 = new torch.nn.Linear(128, 10);
}
forward(x: torch.Tensor): torch.Tensor {
x = x.view(x.shape[0], -1); // Flatten
x = this.fc1.forward(x); // [batch, 256]
x = this.bn1.forward(x); // Normalize channels
x = torch.nn.functional.relu(x);
x = this.fc2.forward(x); // [batch, 128]
x = this.bn2.forward(x);
x = torch.nn.functional.relu(x);
x = this.fc3.forward(x); // [batch, 10]
return x;
}
}
const model = new MLPWithBN();
model.train(); // Enable training mode
const x = torch.randn([32, 784]);
const y = model.forward(x); // Uses batch statistics for normalization

// Sequence processing with BatchNorm1d
const batch_size = 32;
const seq_len = 100;
const hidden_dim = 128;
const bn = new torch.nn.BatchNorm1d(hidden_dim);
const x = torch.randn([batch_size, hidden_dim, seq_len]); // [B, C, L]
// BatchNorm1d normalizes along the channel dimension (dimension 1)
// Shape [32, 128, 100] - normalizes 128 channels independently
const y = bn.forward(x); // [32, hidden_dim, seq_len]

// Training vs evaluation mode
const bn = new torch.nn.BatchNorm1d(64);
// Training mode: uses batch statistics
bn.train();
const train_x = torch.randn([32, 64]);
const train_y = bn.forward(train_x); // Normalizes using batch mean/var
// Evaluation mode: uses running statistics
bn.eval();
const test_x = torch.randn([1, 64]);
const test_y = bn.forward(test_x); // Normalizes using accumulated running stats

// Disabling affine transform (normalization only, no scale/shift)
const bn_no_affine = new torch.nn.BatchNorm1d(128, 1e-5, 0.1, false);
const x = torch.randn([32, 128]);
const y = bn_no_affine.forward(x); // Only normalizes, γ=1 and β=0 (fixed)

// Batch normalization with different momentum
// Higher momentum: running stats follow recent batches more closely
const bn_fast = new torch.nn.BatchNorm1d(64, 1e-5, 0.5); // Momentum 0.5
const bn_slow = new torch.nn.BatchNorm1d(64, 1e-5, 0.01); // Momentum 0.01 (smoother)
// Lower momentum produces smoother running statistics (better for inference)
// Higher momentum tracks recent data distribution better (better for training)
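The trade-off above can be illustrated with a small library-free sketch of how quickly each momentum setting tracks a shifted batch mean (illustrative only):

```typescript
// Repeatedly apply the running-average update toward a fixed batch mean.
function track(momentum: number, steps: number, batchMean: number): number {
  let running = 0;
  for (let i = 0; i < steps; i++) {
    running = (1 - momentum) * running + momentum * batchMean;
  }
  return running;
}

const fast = track(0.5, 5, 10); // close to 10 after only 5 steps
const slow = track(0.01, 5, 10); // still far from 10 (smoother, slower)
```

After five batches the high-momentum estimate has nearly reached the true mean, while the low-momentum estimate has barely moved; over many batches the low-momentum estimate averages out batch-to-batch noise.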