torch.nn.BatchNorm2d
class BatchNorm2d extends _BatchNorm — Batch Normalization for 2D inputs (images): normalizes each channel independently and accelerates training in CNNs.
Applies batch normalization to 4D image tensors (NCHW format), normalizing feature maps independently per channel. Fundamental building block of modern CNNs that:
- Stabilizes training and allows higher learning rates
- Reduces sensitivity to weight initialization
- Enables training of very deep networks
- Acts as implicit regularizer
- Accelerates convergence dramatically
During training, computes mean/variance over mini-batch and spatial dimensions (batch × height × width), normalizes, then applies learned per-channel scale γ and shift β. During evaluation, uses accumulated running statistics for consistent predictions.
- NCHW format: Operates on channels (dimension 1), normalizes across batch and spatial dimensions
- Training vs eval mode: Behavior changes significantly - train() updates running stats, eval() uses them
- Running statistics: Accumulated during training, crucial for stable inference. Don't forget model.eval()!
- Momentum: Typical value is 0.1. Smaller = smoother running stats (better for inference); larger = more responsive to the current batch
- Batch size sensitivity: More effective with larger batches (≥16). Small batches (1–4) can destabilize training
- Affine transform: Per-channel learnable scale γ and shift β. Disable (affine=false) only if normalization is handled elsewhere
- Epsilon: Prevents division by zero. 1e-5 is standard; increase if instability occurs
- Architecture position: Conv → BatchNorm → Activation is the standard pattern for CNNs
- Computational cost: Adds ~15-20% overhead but enables much faster convergence with higher learning rates
- Parameter count: num_features * 2 parameters (γ and β) plus 2 buffers (running mean/var)
- Gradient flow: Gradients flow through the normalization operation and are computed efficiently on GPU
Examples
// ResNet-style CNN architecture with BatchNorm2d
class ResNetBlock extends torch.nn.Module {
  conv1: torch.nn.Conv2d;
  bn1: torch.nn.BatchNorm2d;
  conv2: torch.nn.Conv2d;
  bn2: torch.nn.BatchNorm2d;

  constructor(in_channels: number, out_channels: number) {
    super();
    // Conv → BatchNorm is the standard pairing; padding 1 with a 3×3 kernel
    // keeps the spatial dimensions unchanged so the skip connection lines up.
    this.conv1 = new torch.nn.Conv2d(in_channels, out_channels, 3, { padding: 1 });
    this.bn1 = new torch.nn.BatchNorm2d(out_channels);
    this.conv2 = new torch.nn.Conv2d(out_channels, out_channels, 3, { padding: 1 });
    this.bn2 = new torch.nn.BatchNorm2d(out_channels);
  }

  forward(x: torch.Tensor): torch.Tensor {
    const skip = x; // saved for the residual connection
    let out = this.bn1.forward(this.conv1.forward(x)); // [B, out_channels, H, W]
    out = torch.nn.functional.relu(out);
    out = this.bn2.forward(this.conv2.forward(out));
    out = out.add(skip); // skip connection
    return torch.nn.functional.relu(out);
  }
}
// Build a 64→64 residual block and run one training-mode forward pass.
const model = new ResNetBlock(64, 64);
model.train();
const batch = torch.randn([32, 64, 224, 224]); // ImageNet-style batch: [B, C, H, W]
const out = model.forward(batch);

// VGG-style deep network with BatchNorm2d
class VGGBlock extends torch.nn.Module {
layers: (torch.nn.Conv2d | torch.nn.BatchNorm2d | torch.nn.ReLU)[] = [];
constructor(in_channels: number, out_channels: number, num_convs: number) {
super();
for (let i = 0; i < num_convs; i++) {
const cin = i === 0 ? in_channels : out_channels;
this.layers.push(new torch.nn.Conv2d(cin, out_channels, 3, { padding: 1 }));
this.layers.push(new torch.nn.BatchNorm2d(out_channels));
this.layers.push(new torch.nn.ReLU());
}
}
forward(x: torch.Tensor): torch.Tensor {
for (const layer of this.layers) {
if (layer instanceof torch.nn.ReLU) {
x = layer.forward(x);
} else if (layer instanceof torch.nn.BatchNorm2d) {
x = layer.forward(x);
} else {
x = layer.forward(x);
}
}
return x;
}
}// Image classification pipeline with training/eval modes
class ImageClassifier extends torch.nn.Module {
  conv1: torch.nn.Conv2d;
  bn1: torch.nn.BatchNorm2d;
  maxpool: torch.nn.MaxPool2d;
  fc: torch.nn.Linear;

  constructor() {
    super();
    // 224×224 RGB input → stride-2 7×7 conv halves spatial dims to 112×112.
    this.conv1 = new torch.nn.Conv2d(3, 64, 7, { stride: 2, padding: 3 });
    this.bn1 = new torch.nn.BatchNorm2d(64);
    this.maxpool = new torch.nn.MaxPool2d(2);
    // BUG FIX: after the maxpool the feature map is [B, 64, 56, 56] (see
    // forward), so the flattened size is 64 * 56 * 56. The original used
    // 64 * 112 * 112, which ignored the pooling step and would mismatch the
    // flattened input at runtime.
    this.fc = new torch.nn.Linear(64 * 56 * 56, 1000);
  }

  forward(x: torch.Tensor): torch.Tensor {
    x = this.conv1.forward(x); // [B, 64, 112, 112]
    x = this.bn1.forward(x); // Normalize: crucial for training
    x = torch.nn.functional.relu(x);
    x = this.maxpool.forward(x); // [B, 64, 56, 56]
    x = x.view(x.shape[0], -1); // Flatten → [B, 64 * 56 * 56]
    x = this.fc.forward(x); // [B, 1000]
    return x;
  }
}
const model = new ImageClassifier();
// Training: normalizes using batch statistics
model.train();
const train_batch = torch.randn([32, 3, 224, 224]);
const train_logits = model.forward(train_batch); // Uses batch mean/var
// Evaluation: normalizes using accumulated running statistics
model.eval();
const test_image = torch.randn([1, 3, 224, 224]);
const test_logits = model.forward(test_image); // Uses running mean/var// Fine-tuning with frozen BatchNorm (common in transfer learning)
const frozenBn = new torch.nn.BatchNorm2d(64);

// Option 1: leave the layer in eval mode — running stats are used but
// never updated, so the pretrained statistics stay frozen.
frozenBn.eval();
const features = torch.randn([32, 64, 56, 56]);
const normalized = frozenBn.forward(features); // uses frozen running stats

// Option 2: disable stat tracking entirely (no running stats to update).
// Positional args: (num_features, eps, momentum, affine, track_running_stats)
// — presumably matching the PyTorch signature; confirm against the API docs.
const untrackedBn = new torch.nn.BatchNorm2d(64, 1e-5, 0.1, true, false);
untrackedBn.train();
const normalized2 = untrackedBn.forward(features); // always uses batch stats

// Analyzing running statistics and parameters
const norm = new torch.nn.BatchNorm2d(64);

// Learnable per-channel affine parameters.
const gamma = norm.weight; // [64] — per-channel scale γ
const beta = norm.bias; // [64] — per-channel shift β

// Buffers accumulated in train() mode and consumed in eval() mode.
const running_mean = norm.running_mean; // [64]
const running_var = norm.running_var; // [64]

// Re-initialize γ and β and reset the running statistics.
norm.reset_parameters();