torch.nn.LazyBatchNorm1d
class LazyBatchNorm1d extends _LazyBatchNorm
Lazy Batch Normalization 1D: automatically infers the number of features on the first forward pass.
Extends BatchNorm1d to automatically initialize parameters based on input shape. You don't need to specify num_features upfront - it's inferred from the first batch. Essential for:
- Sequential models where layer sizes are determined dynamically
- Exploratory ML where layer sizes may change during development
- Frameworks that build layers on-the-fly during first forward pass
- Reducing boilerplate when building custom architectures
The Problem Lazy Modules Solve: Regular BatchNorm1d requires specifying num_features explicitly, but sometimes you don't know the layer size until you see data. For example, after a layer whose output size is determined at runtime, you don't know how many features you will need to normalize. Lazy modules solve this by initializing their parameters on first use.
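The pattern can be sketched in plain TypeScript (a minimal illustration of lazy initialization; the class and field names here are hypothetical, not the library's actual internals):

```typescript
// Minimal sketch of the lazy-initialization pattern: parameters are
// created on the first forward call, once the input shape is known.
class LazySketch {
  numFeatures: number | null = null; // unknown until first forward
  weight: number[] | null = null;    // gamma, created lazily
  bias: number[] | null = null;      // beta, created lazily

  forward(x: number[][]): number[][] {
    if (this.numFeatures === null) {
      // First call: infer the feature count from the input shape
      this.numFeatures = x[0].length;
      this.weight = new Array(this.numFeatures).fill(1); // gamma init = 1
      this.bias = new Array(this.numFeatures).fill(0);   // beta init = 0
    }
    return x; // the normalization itself is omitted in this sketch
  }
}
```

After `forward` has run once, `numFeatures`, `weight`, and `bias` exist and every later call reuses them, which is exactly the behavior LazyBatchNorm1d adds on top of BatchNorm1d.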
When to use LazyBatchNorm1d:
- After layers with dynamic output size (e.g., after reshaping, after variable-size processing)
- Sequential architectures where layer sizes are determined at data-load time
- Prototyping when layer sizes aren't known in advance
- Reducing manual computation of intermediate dimensions
- Faster iteration during model design
- When you want cleaner code with less size bookkeeping
Trade-offs:
- vs regular BatchNorm1d: Slightly slower first forward (initialization overhead), then identical
- Lazy initialization: First forward pass is slower due to parameter creation
- After first pass: Behaves identically to BatchNorm1d (no further overhead)
- Memory: Same memory footprint as regular BatchNorm once initialized
- Parameter access: Parameters only exist after first forward pass (use carefully in custom code)
Algorithm: On first forward pass (lazy initialization):
- Infer num_features from input shape: num_features = input.shape[1]
- Create weight (γ), bias (β) if affine=true
- Create running_mean, running_var if track_running_stats=true
- Perform batch normalization (same as regular BatchNorm1d)
On subsequent passes:
- Identical to regular BatchNorm1d (use pre-initialized parameters)
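The normalization step itself can be sketched for a 2D input [batch, features] (a self-contained illustration of the standard batch-norm formula y = γ·(x − mean)/√(var + ε) + β, not the library's implementation):

```typescript
// Per-feature batch normalization over a 2D input [batch, features].
// num_features is taken from the second dimension, mirroring how the
// lazy module infers it from input.shape[1].
function batchNorm2d(
  x: number[][], gamma: number[], beta: number[], eps = 1e-5
): number[][] {
  const n = x.length;    // batch size
  const c = x[0].length; // num_features, inferred from the input shape
  const mean = new Array(c).fill(0);
  const variance = new Array(c).fill(0);
  for (const row of x) row.forEach((v, j) => (mean[j] += v / n));
  for (const row of x)
    row.forEach((v, j) => (variance[j] += Math.pow(v - mean[j], 2) / n));
  return x.map(row =>
    row.map((v, j) =>
      gamma[j] * (v - mean[j]) / Math.sqrt(variance[j] + eps) + beta[j])
  );
}
```

With γ = 1 and β = 0 (the initial values), each feature column of the output has approximately zero mean and unit variance.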
Notes:
- Parameter inference: num_features is inferred from input.shape[1]
- Initialization delay: parameters are created on the first forward pass, not at construction, so the first pass carries a small one-time overhead; after that the module behaves identically to regular BatchNorm1d
- Running statistics: once initialized, running_mean and running_var are updated exactly as in regular BatchNorm
- Train/eval modes: same behavior as regular BatchNorm once initialized
- Weight sharing: parameters belong to this module alone; they are not shared with other modules
- Don't access weight/bias before the first forward pass (they don't exist yet)
- Input must be 2D or 3D (batch + features [+ sequence])
- All forward passes after initialization must use the same feature dimension
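The running-statistics update follows the usual exponential-moving-average rule shared with regular BatchNorm (a sketch; in PyTorch the `momentum` argument defaults to 0.1):

```typescript
// Exponential moving average used to maintain running_mean / running_var:
// running = (1 - momentum) * running + momentum * batch_stat
function updateRunning(running: number, batchStat: number, momentum = 0.1): number {
  return (1 - momentum) * running + momentum * batchStat;
}
```

Because the update only touches already-created buffers, it adds no lazy-specific cost: after initialization, training steps are identical to regular BatchNorm1d.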
Examples
// No need to know the input size upfront
const lazy_bn = new torch.nn.LazyBatchNorm1d(); // No num_features needed!

// First forward: initializes with num_features=64 from the input
const x = torch.randn([32, 64]);
const output = lazy_bn.forward(x); // Initializes, then normalizes

// Second forward: uses the already-initialized parameters
const x2 = torch.randn([16, 64]);
const output2 = lazy_bn.forward(x2); // Fast, uses initialized parameters

// Useful in models with dynamic layer sizes
class DynamicNetwork extends torch.nn.Module {
  fc1: torch.nn.Linear;
  bn1: torch.nn.LazyBatchNorm1d; // Size after fc1 inferred lazily
  fc2: torch.nn.Linear;
  bn2: torch.nn.LazyBatchNorm1d; // Size after fc2 inferred lazily

  constructor(input_size: number, hidden_size: number, output_size: number) {
    super();
    this.fc1 = new torch.nn.Linear(input_size, hidden_size);
    this.bn1 = new torch.nn.LazyBatchNorm1d(); // Will infer hidden_size
    this.fc2 = new torch.nn.Linear(hidden_size, output_size);
    this.bn2 = new torch.nn.LazyBatchNorm1d(); // Will infer output_size
  }

  forward(x: torch.Tensor): torch.Tensor {
    x = this.fc1.forward(x);
    x = this.bn1.forward(x); // Initializes with hidden_size features
    x = torch.relu(x);
    x = this.fc2.forward(x);
    x = this.bn2.forward(x); // Initializes with output_size features
    return x;
  }
}