torch.nn.functional.group_norm
function group_norm&lt;S extends Shape, D extends DType = DType, Dev extends DeviceType = DeviceType&gt;(input: Tensor&lt;S, D, Dev&gt;, num_groups: number, options?: GroupNormFunctionalOptions): Tensor&lt;S, D, Dev&gt;
function group_norm&lt;S extends Shape, D extends DType = DType, Dev extends DeviceType = DeviceType&gt;(input: Tensor&lt;S, D, Dev&gt;, num_groups: number, weight: Tensor | undefined, bias: Tensor | undefined, eps: number, options?: GroupNormFunctionalOptions): Tensor&lt;S, D, Dev&gt;
Group Normalization: divides channels into groups and normalizes each group independently.
Applies group normalization where channels are split into num_groups disjoint groups, and normalization is computed independently over each group. A middle ground between layer normalization (normalizes all features together) and instance normalization (one group per channel). Particularly effective for tasks with small batch sizes or variable batch sizes. Essential for:
- Small batch size training (batch norm unstable with small batches)
- Object detection and semantic segmentation (common in detection backbones)
- Video understanding (temporal batch dimension variability)
- Adversarial training (batch norm statistics problematic)
- Fine-grained image recognition with small batches
- Models where batch normalization causes instability
- Distributed training with per-GPU batch size = 1
How Group Norm works:
- Split C channels into G groups of size C/G each
- For each sample in batch: compute mean and variance per group
- Normalize each group independently (z-score)
- Apply learnable affine transform (weight and bias)
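The four steps above can be sketched in plain TypeScript over a flattened [N, C, L] array. This is a reference implementation for illustration only; `groupNormRef` and its memory layout are assumptions, not part of this library's API:

```typescript
// Reference group norm over a flat row-major [N, C, L] array.
// (Illustrative only; not the library's actual implementation.)
function groupNormRef(
  x: number[],            // flattened input, length N * C * L
  N: number, C: number, L: number,
  G: number,              // num_groups; must divide C
  weight: number[],       // per-channel scale, length C
  bias: number[],         // per-channel shift, length C
  eps = 1e-5,
): number[] {
  if (C % G !== 0) throw new Error("num_groups must divide num_channels");
  const cpg = C / G;      // channels per group
  const out = new Array<number>(x.length).fill(0);
  for (let n = 0; n < N; n++) {
    for (let g = 0; g < G; g++) {
      // Step 1-2: gather this sample's group (cpg channels x L positions)
      // and compute its mean and (biased) variance.
      const idx: number[] = [];
      for (let c = g * cpg; c < (g + 1) * cpg; c++)
        for (let l = 0; l < L; l++) idx.push((n * C + c) * L + l);
      const mean = idx.reduce((s, i) => s + x[i], 0) / idx.length;
      const varr = idx.reduce((s, i) => s + (x[i] - mean) ** 2, 0) / idx.length;
      // Step 3-4: z-score the group, then apply the per-channel affine transform.
      const inv = 1 / Math.sqrt(varr + eps);
      for (const i of idx) {
        const c = Math.floor(i / L) % C; // recover channel index from flat offset
        out[i] = (x[i] - mean) * inv * weight[c] + bias[c];
      }
    }
  }
  return out;
}
```

Each (sample, group) pair gets its own mean and variance, which is why the result is independent of batch size.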
Group Norm variants (by num_groups):
- num_groups = 1: Equivalent to LayerNorm (one group = all channels)
- num_groups = C (channels): Equivalent to InstanceNorm (one group per channel)
- 1 &lt; num_groups &lt; C: Balance between layer norm and instance norm
- Typical: num_groups = 32 for ResNet-style architectures
When to use GroupNorm:
- Small batch size (batch_size < 8)
- Variable batch sizes in training
- Replacing batch norm for stability
- Object detection/segmentation networks
- When batch statistics are unreliable
- Video/sequence models with temporal variability
Comparison with alternatives:
- Batch Norm: Normalizes over batch; requires large batch for stability
- Layer Norm: Single group over all channels; computationally efficient
- Instance Norm: One group per channel; too aggressive for some tasks
- Group Norm: Middle ground; best empirically for small batches
Key properties:
- Batch-size independent: Statistics computed per group, not per batch
- Small batch friendly: Works well with batch_size=1 (unlike BatchNorm)
- Flexible grouping: num_groups controls group size (balance between extremes)
- Learnable parameters: weight and bias trained like LayerNorm
- Detection standard: Widely used in modern object detection and segmentation models
- Per-sample normalization: Each sample has independent group statistics
Limitations and constraints:
- Channel divisibility: num_groups must divide num_channels exactly (e.g., 256 channels → 32 groups OK)
- Minimum 2D: Requires at least 2D input (batch, channels required)
- Small groups: Very small groups (1-2 channels) may hurt performance
- Different from batch norm: Statistics computed differently; direct replacement may need tuning
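Because num_groups must divide num_channels, it can be useful to enumerate the valid choices up front. The helper below is hypothetical (not part of this library); the library itself is expected to reject invalid values at call time:

```typescript
// List every valid num_groups for a given channel count.
// (Illustrative helper; not part of the library's API.)
function validGroupCounts(numChannels: number): number[] {
  const counts: number[] = [];
  for (let g = 1; g <= numChannels; g++) {
    if (numChannels % g === 0) counts.push(g); // g divides numChannels
  }
  return counts;
}

// For 256 channels the valid choices are the divisors of 256:
// 1 (LayerNorm-like), 256 (InstanceNorm-like), and the powers of two in between.
```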
Parameters
input: Tensor&lt;S, D, Dev&gt; - Input tensor [batch, channels, spatial...]. Minimum 2D required. Example: [N, C, H, W] for images, [N, C, length] for 1D signals
num_groups: number - Number of groups to divide channels into. Must divide num_channels. Example: num_groups=32 for ResNet with 256 channels → 8 channels per group
weight: Tensor | undefined - Optional per-channel scale of shape [num_channels]
bias: Tensor | undefined - Optional per-channel shift of shape [num_channels]
eps: number - Small value added to the variance for numerical stability (e.g. 1e-5)
options: GroupNormFunctionalOptions - optional
Returns
Tensor&lt;S, D, Dev&gt; – Normalized tensor with same shape as input
Examples
// Standard GroupNorm with 32 groups (ResNet-style)
const batch = 4, channels = 256, height = 28, width = 28;
const x = torch.randn([batch, channels, height, width]); // [4, 256, 28, 28]
const normalized = torch.nn.functional.group_norm(x, 32); // 32 groups, 8 channels/group
// Each sample has 32 independent normalizations over spatial dims + channel groups

// Object detection with small batch size (GroupNorm standard)
class ResNetBackbone extends torch.nn.Module {
private layer1: torch.nn.Sequential;
private gn1: torch.Tensor;
constructor() {
super();
this.layer1 = conv_block_3x3(64); // Conv layer
this.gn1 = torch.ones([64]); // GroupNorm weight for 64 channels
}
forward(x: torch.Tensor): torch.Tensor {
x = this.layer1.forward(x); // [N, 64, H, W]
x = torch.nn.functional.group_norm(x, 32, this.gn1, undefined, 1e-5); // 32 groups → 2 channels per group
return x;
}
}

// Small batch training: where batch norm fails
const batch_size = 2; // Too small for reliable batch norm
const x = torch.randn([batch_size, 512, 14, 14]);
// Batch norm would be unstable; use group norm instead
const weight = torch.ones([512]);
const bias = torch.zeros([512]);
const normalized = torch.nn.functional.group_norm(x, 32, weight, bias, 1e-5);
// Stable normalization even with batch_size=2

// Varying batch size in distributed training
const batch_per_gpu = 1; // Each GPU has batch=1 (small!)
const x = torch.randn([batch_per_gpu, 64, 32, 32]);
// GroupNorm handles this gracefully (BatchNorm would fail)
const output = torch.nn.functional.group_norm(x, 8); // 8 groups for 64 channels
// Per-group statistics computed independently per sample

// Comparing num_groups effect
const x = torch.randn([4, 256, 28, 28]);
const weight = torch.ones([256]);
const g1 = torch.nn.functional.group_norm(x, 1, weight, undefined, 1e-5); // LayerNorm (single group)
const g32 = torch.nn.functional.group_norm(x, 32, weight, undefined, 1e-5); // 8 channels/group
const g256 = torch.nn.functional.group_norm(x, 256, weight, undefined, 1e-5); // InstanceNorm (1 channel/group)
// Different num_groups produce different normalization statistics
// g32 is typically best for small batches
See Also
- PyTorch torch.nn.functional.group_norm
- torch.nn.functional.layer_norm - Single group (all channels together)
- torch.nn.functional.instance_norm - Multiple groups (one per channel)
- torch.nn.functional.batch_norm - Normalizes over batch (requires large batches)
- torch.nn.GroupNorm - Module version with automatic parameter management