torch.nn.BatchNorm3d
class BatchNorm3d extends _BatchNorm — Batch Normalization for 3D inputs (volumetric data): normalizes feature maps for video and 3D CNNs.
Applies batch normalization to 5D volumetric tensors (NCDHW format), normalizing channels across batch and spatial (depth × height × width) dimensions. Essential for:
- Video understanding and action recognition (3D CNNs)
- Medical imaging (CT scans, MRI volumetric analysis)
- Point cloud processing (volumetric representations)
- Volumetric deep learning (3D object detection, reconstruction)
- Accelerating training of 3D networks
Similar to BatchNorm2d but operates on 5D spatial data. During training, computes statistics over mini-batch and spatial volume (batch × depth × height × width), then applies learned per-channel affine transform. During evaluation, uses accumulated running statistics.
- NCDHW format: Dimension 0=batch, 1=channel, 2=depth, 3=height, 4=width
- Spatial normalization: Normalizes across D×H×W dimensions, keeping each channel independent
- Training vs eval mode: Critical difference - training updates running stats, eval uses them
- Memory intensive: 3D operations use more memory than 2D. Reduce batch size or spatial dimensions if needed
- Computation cost: Significant overhead on 3D data. Still worth it for convergence speed
- Video processing: T (temporal) dimension treated as part of spatial volume, not as sequence
- Momentum default: 0.1 is standard but may need adjustment for small batch sizes (use 0.01-0.05)
- Architecture pattern: Conv3d → BatchNorm3d → Activation is standard for 3D CNNs
- Parameter count: 2 × num_features learned parameters (γ and β) plus 2 buffers (running mean/var)
- Batch size: More effective with batch size ≥ 4. Very small batches may cause training instability
- Gradient computation: Backprop through normalization is efficient on modern GPUs
Examples
// 3D CNN for video action recognition
class VideoActionRecognizer extends torch.nn.Module {
  conv1: torch.nn.Conv3d;
  bn1: torch.nn.BatchNorm3d;
  maxpool: torch.nn.MaxPool3d;
  conv2: torch.nn.Conv3d;
  bn2: torch.nn.BatchNorm3d;
  avgpool: torch.nn.AdaptiveAvgPool3d;
  fc: torch.nn.Linear;

  constructor() {
    super();
    // Stem: large receptive field with immediate spatio-temporal downsampling.
    this.conv1 = new torch.nn.Conv3d(3, 64, 7, { stride: 2, padding: 3 });
    this.bn1 = new torch.nn.BatchNorm3d(64);
    this.maxpool = new torch.nn.MaxPool3d(2);
    // Second stage: widen channels, keep the volume size (padding=1).
    this.conv2 = new torch.nn.Conv3d(64, 128, 3, { padding: 1 });
    this.bn2 = new torch.nn.BatchNorm3d(128);
    // Collapse whatever T×H×W remains into a single vector per clip.
    this.avgpool = new torch.nn.AdaptiveAvgPool3d([1, 1, 1]);
    this.fc = new torch.nn.Linear(128, 400); // 400 action classes
  }

  forward(x: torch.Tensor): torch.Tensor {
    // Expects x of shape [B, 3, T, H, W], where T is the number of frames.
    const stem = torch.nn.functional.relu(
      this.bn1.forward(this.conv1.forward(x)) // normalize 3D features
    ); // [B, 64, T/2, H/2, W/2]
    const pooled = this.maxpool.forward(stem); // [B, 64, T/4, H/4, W/4]
    const stage2 = torch.nn.functional.relu(
      this.bn2.forward(this.conv2.forward(pooled))
    ); // [B, 128, T/4, H/4, W/4]
    const features = this.avgpool.forward(stage2); // [B, 128, 1, 1, 1]
    const flat = features.view(features.shape[0], -1); // [B, 128]
    return this.fc.forward(flat); // [B, 400] class logits
  }
}
// Train on a batch of video clips.
const model = new VideoActionRecognizer();
model.train();

// Clip batch layout: [batch_size=8, channels=3, frames=32, height=224, width=224]
const video = torch.randn([8, 3, 32, 224, 224]);
const logits = model.forward(video); // [8, 400]

// 3D CNN for medical imaging (CT scan analysis)
class MedicalImageAnalyzer extends torch.nn.Module {
  conv1: torch.nn.Conv3d;
  bn1: torch.nn.BatchNorm3d;
  conv2: torch.nn.Conv3d;
  bn2: torch.nn.BatchNorm3d;
  fc: torch.nn.Linear;

  constructor() {
    super();
    // Single-channel input: CT volumes carry one intensity channel.
    this.conv1 = new torch.nn.Conv3d(1, 32, 3, { padding: 1 });
    this.bn1 = new torch.nn.BatchNorm3d(32);
    this.conv2 = new torch.nn.Conv3d(32, 64, 3, { padding: 1 });
    this.bn2 = new torch.nn.BatchNorm3d(64);
    // NOTE(review): this flatten size hard-codes a 64×64×64 input volume
    // (padding=1 convs preserve spatial size, so features = 64 channels × 64³
    // voxels). Inputs of any other D/H/W will not match `fc` — confirm callers.
    this.fc = new torch.nn.Linear(64 * 64 * 64 * 64, 2); // Binary classification
  }

  forward(x: torch.Tensor): torch.Tensor {
    // x: [B, 1, D, H, W] medical volume (D=H=W=64 assumed by `fc` above).
    let h = torch.nn.functional.relu(
      this.bn1.forward(this.conv1.forward(x)) // normalize across the 3D volume
    );
    h = torch.nn.functional.relu(this.bn2.forward(this.conv2.forward(h)));
    h = h.view(h.shape[0], -1); // flatten per sample
    return this.fc.forward(h); // [B, 2] classification logits
  }
}

// Using BatchNorm3d with different configurations
// Standard: full batch normalization with statistics tracking
const bn_standard = new torch.nn.BatchNorm3d(64);
// No affine transform: normalization only
const bn_no_affine = new torch.nn.BatchNorm3d(64, 1e-5, 0.1, false);
// No statistics tracking: useful for small batch sizes or special architectures
const bn_no_track = new torch.nn.BatchNorm3d(64, 1e-5, 0.1, true, false);

const volume = torch.randn([4, 64, 32, 32, 32]); // Volumetric data

// Training: all use batch statistics
bn_standard.train();
const y1 = bn_standard.forward(volume);

// Evaluation: only bn_standard has stable running statistics
bn_standard.eval();
const y2 = bn_standard.forward(volume);

// Batch normalization with custom momentum for fine-tuning
const bn = new torch.nn.BatchNorm3d(64, 1e-5, 0.01); // Low momentum (0.01)

// Freeze batch norm for transfer learning
bn.eval(); // Use pre-trained running statistics
// Bug fix: the original redeclared `const volume` in the same module scope,
// which is a TypeScript compile error — the second binding is renamed.
const ftVolume = torch.randn([2, 64, 64, 64, 64]);
const y = bn.forward(ftVolume); // No parameter updates, no stat changes