torch.nn.init.kaiming_uniform_
function kaiming_uniform_(tensor: Tensor, options?: KaimingOptions): Tensor
function kaiming_uniform_(tensor: Tensor, a: number, mode: FanMode, nonlinearity: Nonlinearity, options?: KaimingOptions): Tensor
Fill tensor with Kaiming (He) uniform initialization for ReLU-based networks.
Kaiming initialization is optimized for ReLU and its variants. Unlike Xavier initialization, which assumes roughly linear activations (sigmoid/tanh near zero), Kaiming compensates for ReLU zeroing out about half of its inputs by scaling the variance up, so activations neither vanish nor explode with depth. Essential for:
- Deep ReLU networks (ResNets, VGGs, Modern CNNs)
- ReLU and Leaky ReLU activation functions
- Convolutional neural networks with ReLU
- Training very deep networks (50+ layers) with ReLU
- Matching modern best practices for initialization
Named after Kaiming He; also called He initialization.
The method is described in "Delving deep into rectifiers: Surpassing human-level performance on ImageNet classification" - He, K. et al. (2015).
- Modern best practice: Kaiming initialization is standard for ReLU/Leaky ReLU networks
- ReLU specific: Designed for ReLU's rectifying property; Xavier better for sigmoid/tanh
- fan_in vs fan_out: fan_in preserves the variance of activations in the forward pass; fan_out preserves the variance of gradients in the backward pass
- Leaky ReLU slope: Must match the α used in training (0.01, 0.2, etc.)
- Comparison to Xavier: He scales by a single fan (fan_in or fan_out) with gain √2, whereas Xavier averages fan_in and fan_out; the larger gain compensates for ReLU zeroing half the activations (see the gain check after this list)
- Dying ReLU: Helps prevent the dying ReLU problem where units become inactive
- ResNet: Kaiming initialization (typically in fan_out mode) is the conventional scheme for ResNet-style networks
- In-place operation: Modifies tensor in-place; returns the same tensor
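The gains behind these notes can be checked directly; a minimal sketch, assuming calculate_gain (listed under See Also) takes the nonlinearity name plus an optional negative slope, as in PyTorch:
// Gain used by Kaiming init for each activation (comments show the analytic values)
const relu_gain = torch.nn.init.calculate_gain('relu'); // √2 ≈ 1.414
const leaky_gain = torch.nn.init.calculate_gain('leaky_relu', 0.2); // √(2 / (1 + 0.2²)) ≈ 1.387
const tanh_gain = torch.nn.init.calculate_gain('tanh'); // 5/3 ≈ 1.667, the gain Xavier-style init uses for tanh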
Parameters
tensor: Tensor - An n-dimensional Tensor (typically a weight matrix from a layer)
options: KaimingOptions (optional) - Optional settings for Kaiming initialization
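The exact shape of KaimingOptions is not spelled out above; the sketch below is inferred from the second overload and the option objects in the examples, not an authoritative definition:
// Assumed shape of KaimingOptions (field names taken from the second overload and the examples)
interface KaimingOptions {
  a?: number;                  // negative slope of the following leaky_relu (0 for plain relu)
  mode?: FanMode;              // 'fan_in' or 'fan_out'
  nonlinearity?: Nonlinearity; // e.g. 'relu' or 'leaky_relu'
}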
Returns
Tensor - The input tensor with Kaiming uniform initialization
Algorithm:
- Values sampled from uniform distribution U(-bound, bound)
- bound = gain × √(3 / fan)
- gain = √(2 / (1 + α²)) for leaky_relu with negative slope α
- gain = √2 for relu
- fan = fan_in or fan_out (chosen by the mode parameter)
- For conv layers: fan includes the kernel size (e.g. fan_in = in_channels × kernel_height × kernel_width)
Examples
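A quick numeric check of the bound for the 512-input ReLU layer used in the first example below (plain arithmetic, no library calls):
// fan_in = 512, relu gain = √2, so bound = √2 × √(3 / 512) ≈ 0.108
const fan_in = 512;
const gain = Math.sqrt(2);                  // relu gain
const bound = gain * Math.sqrt(3 / fan_in); // ≈ 0.108; every sampled weight lies in [-bound, bound]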
// Basic He initialization for ReLU network
const layer = torch.nn.Linear(512, 256);
torch.nn.init.kaiming_uniform_(layer.weight, { a: 0, mode: 'fan_in', nonlinearity: 'relu' });
torch.nn.init.zeros_(layer.bias);
const x = torch.randn([32, 512]);
const y = torch.nn.functional.relu(layer.forward(x));

// Leaky ReLU with custom slope
const layer = torch.nn.Linear(1024, 512);
const negative_slope = 0.2; // For Leaky ReLU
torch.nn.init.kaiming_uniform_(layer.weight, { a: negative_slope, mode: 'fan_in', nonlinearity: 'leaky_relu' });
torch.nn.init.zeros_(layer.bias);
const activation = torch.nn.LeakyReLU(negative_slope);
const x = torch.randn([64, 1024]); // input batch for this snippet
const y = activation.forward(layer.forward(x));

// Modern ResNet-style initialization
const conv = torch.nn.Conv2d(3, 64, { kernel_size: 7, stride: 2, padding: 3 });
torch.nn.init.kaiming_uniform_(conv.weight, { a: 0, mode: 'fan_out', nonlinearity: 'relu' });
torch.nn.init.zeros_(conv.bias);
const bn = torch.nn.BatchNorm2d(64);
const x = torch.randn([32, 3, 224, 224]);
let y = conv.forward(x);
y = bn.forward(y);
y = torch.nn.functional.relu(y);

// Deep ResNet initialization
class ResNetInitializer {
  static initializeModel(model: torch.nn.Module) {
    for (const [name, module] of model.named_modules()) {
      if (module instanceof torch.nn.Conv2d || module instanceof torch.nn.Linear) {
        if (module instanceof torch.nn.Conv2d) {
          torch.nn.init.kaiming_uniform_(
            module.weight,
            { a: 0, mode: 'fan_out', nonlinearity: 'relu' }
          );
        } else if (module instanceof torch.nn.Linear) {
          torch.nn.init.kaiming_uniform_(
            module.weight,
            { a: 0, mode: 'fan_in', nonlinearity: 'relu' }
          );
        }
        if (module.bias) {
          torch.nn.init.zeros_(module.bias);
        }
      } else if (module instanceof torch.nn.BatchNorm2d) {
        torch.nn.init.ones_(module.weight);
        torch.nn.init.zeros_(module.bias);
      }
    }
  }
}
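A possible call site (the model constructor here is hypothetical; any torch.nn.Module works):
const model = buildResNet50(); // hypothetical model factory
ResNetInitializer.initializeModel(model);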
See Also
- PyTorch torch.nn.init.kaiming_uniform_()
- torch.nn.init.kaiming_normal_ - Kaiming with normal distribution
- torch.nn.init.xavier_uniform_ - Xavier initialization (for sigmoid/tanh)
- torch.nn.init.xavier_normal_ - Xavier with normal distribution
- torch.nn.init.calculate_gain - Get gain for specific activation function