torch.nn.LazyBatchNorm2d
class LazyBatchNorm2d extends _LazyBatchNorm

Lazy Batch Normalization 2D: automatically infers the number of channels for 2D convolutions.
Extends BatchNorm2d with lazy initialization of parameters. Number of channels is automatically inferred from the first input. Essential for:
- Building convolutional networks without knowing exact feature dimensions
- Sequential/functional API where layer sizes change dynamically
- Cleaner model code with less dimension bookkeeping
- Reducing errors from manual channel counting
Same behavior as LazyBatchNorm1d but for 2D inputs (4D tensors: [batch, channels, height, width]). Automatically determines num_features from input.shape[1] on first forward pass.
- Input format: Expects 4D input (batch, channels, height, width)
- Channel inference: num_features inferred from input.shape[1]
- Initialization: Parameters created on first forward, not at construction
- Spatial dimensions: Works with any spatial dimensions (adaptive)
- First forward pass initializes parameters (slight overhead)
- Input must be 4D (batch + channels + 2 spatial dimensions)
- All forward passes must have same channel dimension (after initialization)
Examples
// Lazy BatchNorm for 2D convolutions
const conv = new torch.nn.Conv2d(3, 64, 3, { padding: 1 });
const lazy_bn = new torch.nn.LazyBatchNorm2d(); // Infers 64 channels automatically
const image = torch.randn([32, 3, 224, 224]);
let x = conv.forward(image); // [32, 64, 224, 224]
x = lazy_bn.forward(x); // Initializes with 64 channels, then normalizes

// Sequential CNN with lazy normalization
class SimpleCNN extends torch.nn.Module {
  conv1: torch.nn.Conv2d;
  bn1: torch.nn.LazyBatchNorm2d;
  conv2: torch.nn.Conv2d;
  bn2: torch.nn.LazyBatchNorm2d;

  constructor() {
    super();
    // Only the convolutions need explicit channel counts; each
    // LazyBatchNorm2d infers its num_features on the first forward pass.
    this.conv1 = new torch.nn.Conv2d(3, 32, 3, { padding: 1 });
    this.bn1 = new torch.nn.LazyBatchNorm2d();
    this.conv2 = new torch.nn.Conv2d(32, 64, 3, { padding: 1 });
    this.bn2 = new torch.nn.LazyBatchNorm2d();
  }

  // Two conv -> batchnorm -> relu stages; expects 4D input
  // (batch, channels, height, width).
  forward(x: torch.Tensor): torch.Tensor {
    const stage1 = torch.relu(this.bn1.forward(this.conv1.forward(x)));
    const stage2 = torch.relu(this.bn2.forward(this.conv2.forward(stage1)));
    return stage2;
  }
}