torch.nn.LazyBatchNorm3d
class LazyBatchNorm3d extends _LazyBatchNorm
Lazy Batch Normalization 3D: automatically infers the number of channels for 3D convolutions and volumetric data.
Behaves like BatchNorm3d, but its parameters are initialized lazily: the number of channels is inferred from the first input rather than passed to the constructor. Useful for:
- Building 3D convolution networks (medical imaging, video, volumetric data)
- Models where 3D channel dimensions aren't known in advance
- Reducing boilerplate in 3D CNN architectures
Same behavior as LazyBatchNorm1d/2d, but for 3D inputs (5D tensors of shape [batch, channels, depth, height, width]). num_features is determined from input.shape[1] on the first forward pass.
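To make the lazy step concrete, here is a minimal, dependency-free sketch in plain TypeScript. It does not use the library: `SketchLazyBatchNorm3d` and `Tensor5D` are hypothetical names, tensors are assumed to be flat row-major `Float64Array`s, only the training-mode forward pass is shown, and `eps = 1e-5` / `momentum = 0.1` are assumed to mirror PyTorch's defaults. On the first call it sets `numFeatures` from `shape[1]` and allocates per-channel weight, bias, and running statistics; every call then normalizes each channel over (N, D, H, W).

```typescript
// Hypothetical sketch, not the library's implementation.
// A 5D tensor as a flat row-major buffer plus its shape [N, C, D, H, W].
interface Tensor5D {
  data: Float64Array;
  shape: [number, number, number, number, number];
}

class SketchLazyBatchNorm3d {
  numFeatures: number | null = null; // unknown until the first forward pass
  weight = new Float64Array(0);
  bias = new Float64Array(0);
  runningMean = new Float64Array(0);
  runningVar = new Float64Array(0);
  eps: number;
  momentum: number;

  constructor(eps = 1e-5, momentum = 0.1) {
    this.eps = eps;
    this.momentum = momentum;
  }

  // Training-mode forward pass only.
  forward(x: Tensor5D): Tensor5D {
    const [n, c, d, h, w] = x.shape;
    if (this.numFeatures === null) {
      // Lazy initialization: num_features := input.shape[1].
      this.numFeatures = c;
      this.weight = new Float64Array(c).fill(1);
      this.bias = new Float64Array(c); // zeros
      this.runningMean = new Float64Array(c); // zeros
      this.runningVar = new Float64Array(c).fill(1);
    } else if (c !== this.numFeatures) {
      throw new Error(`expected ${this.numFeatures} channels, got ${c}`);
    }

    const spatial = d * h * w;
    const count = n * spatial; // values per channel across batch and space
    const out = new Float64Array(x.data.length);

    for (let ch = 0; ch < c; ch++) {
      // Per-channel mean over (N, D, H, W).
      let sum = 0;
      for (let b = 0; b < n; b++) {
        const base = (b * c + ch) * spatial;
        for (let i = 0; i < spatial; i++) sum += x.data[base + i];
      }
      const mean = sum / count;

      let sqDiff = 0;
      for (let b = 0; b < n; b++) {
        const base = (b * c + ch) * spatial;
        for (let i = 0; i < spatial; i++) sqDiff += (x.data[base + i] - mean) ** 2;
      }
      const varBiased = sqDiff / count; // biased variance normalizes the batch
      const invStd = 1 / Math.sqrt(varBiased + this.eps);

      for (let b = 0; b < n; b++) {
        const base = (b * c + ch) * spatial;
        for (let i = 0; i < spatial; i++) {
          out[base + i] = (x.data[base + i] - mean) * invStd * this.weight[ch] + this.bias[ch];
        }
      }

      // Running statistics use the unbiased variance, as PyTorch does.
      const varUnbiased = count > 1 ? sqDiff / (count - 1) : varBiased;
      const m = this.momentum;
      this.runningMean[ch] = (1 - m) * this.runningMean[ch] + m * mean;
      this.runningVar[ch] = (1 - m) * this.runningVar[ch] + m * varUnbiased;
    }
    return { data: out, shape: x.shape };
  }
}
```

The channel-mismatch check after initialization mirrors the rule that all later forward passes must keep the same channel dimension.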
- Input format: Expects 5D input (batch, channels, depth, height, width)
- Channel inference: num_features is inferred from input.shape[1]
- 3D spatial dims: depth, height, and width are not fixed; statistics are computed per channel over the batch and all three spatial dimensions
- Memory intensive: 5D activations are large, so 3D models often need smaller batches
- The first forward pass initializes the parameters (a small one-time overhead)
- Input must be 5D (batch + channels + 3 spatial dimensions)
- After initialization, every forward pass must have the same channel dimension
- 3D convolutions are computationally expensive; use with caution
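The input constraints above can be checked before calling the module. The following is a small sketch with hypothetical helper names, `checkBatchNorm3dInput` and `float32Bytes`, which are not part of the library: the first enforces the 5D requirement and the fixed channel count after initialization, and the second estimates the size of a single float32 activation tensor to show why the memory note matters.

```typescript
// Hypothetical helpers (not library API).
type Shape = readonly number[];

// Returns the channel count a lazy module would adopt as num_features.
function checkBatchNorm3dInput(shape: Shape, numFeatures: number | null): number {
  if (shape.length !== 5) {
    throw new Error(`expected 5D input (N, C, D, H, W), got ${shape.length}D`);
  }
  const channels = shape[1];
  if (numFeatures !== null && channels !== numFeatures) {
    // Once initialized, the channel dimension is fixed.
    throw new Error(`expected ${numFeatures} channels, got ${channels}`);
  }
  return channels;
}

// Bytes needed to store one float32 tensor of the given shape.
function float32Bytes(shape: Shape): number {
  return shape.reduce((acc, dim) => acc * dim, 1) * 4; // 4 bytes per float32
}
```

For a conv3d output of shape [4, 32, 64, 64, 64], `float32Bytes` gives 134,217,728 bytes (128 MiB) for that one tensor; training typically keeps several such activations alive for the backward pass.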
Examples
// Lazy BatchNorm for 3D convolutions (medical imaging)
const conv3d = new torch.nn.Conv3d(1, 32, 3, { padding: 1 });
const lazy_bn = new torch.nn.LazyBatchNorm3d(); // num_features (32) is inferred on the first forward pass
const volume = torch.randn([4, 1, 64, 64, 64]); // 3D medical scan
let x = conv3d.forward(volume); // [4, 32, 64, 64, 64]
x = lazy_bn.forward(x); // Initializes with 32 channels, then normalizes

// 3D CNN for video classification
class Video3DCNN extends torch.nn.Module {
conv1: torch.nn.Conv3d;
bn1: torch.nn.LazyBatchNorm3d;
conv2: torch.nn.Conv3d;
bn2: torch.nn.LazyBatchNorm3d;
constructor() {
super();
this.conv1 = new torch.nn.Conv3d(3, 64, 3, { padding: 1 });
this.bn1 = new torch.nn.LazyBatchNorm3d();
this.conv2 = new torch.nn.Conv3d(64, 128, 3, { padding: 1 });
this.bn2 = new torch.nn.LazyBatchNorm3d();
}
forward(x: torch.Tensor): torch.Tensor {
// x: [batch, 3, frames, height, width] (channels-first).
// Channels-last video [batch, frames, height, width, 3] must be permuted to this layout first.
x = torch.relu(this.bn1.forward(this.conv1.forward(x)));
x = torch.relu(this.bn2.forward(this.conv2.forward(x)));
return x;
}
}
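The Video3DCNN example assumes channels-first input. If clips arrive channels-last as [batch, frames, height, width, 3], the axes must be permuted, not merely reshaped: a plain reshape would reinterpret the memory layout and mix values across channels. The sketch below shows the index mapping on a flat `Float32Array` in plain TypeScript. `channelsLastToFirst` is a hypothetical helper; in PyTorch the same operation is `x.permute(0, 4, 1, 2, 3)`, and whether this library exposes an equivalent tensor method is not confirmed here.

```typescript
// Hypothetical helper: reorder a flat, row-major channels-last buffer
// [N, T, H, W, C] into channels-first order [N, C, T, H, W].
function channelsLastToFirst(
  data: Float32Array,
  [n, t, h, w, c]: [number, number, number, number, number],
): Float32Array {
  const out = new Float32Array(data.length);
  for (let b = 0; b < n; b++)
    for (let f = 0; f < t; f++)
      for (let y = 0; y < h; y++)
        for (let xPos = 0; xPos < w; xPos++)
          for (let ch = 0; ch < c; ch++) {
            // Source offset in [N, T, H, W, C], destination offset in [N, C, T, H, W].
            const src = (((b * t + f) * h + y) * w + xPos) * c + ch;
            const dst = (((b * c + ch) * t + f) * h + y) * w + xPos;
            out[dst] = data[src];
          }
  return out;
}
```

For a clip with 2 frames of a single RGB pixel, channels-last values [0, 1, 2, 3, 4, 5] become [0, 3, 1, 4, 2, 5]: all red values first, then green, then blue.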