torch.nn.init.trunc_normal_
function trunc_normal_(tensor: Tensor, options?: TruncNormalOptions): Tensor
function trunc_normal_(tensor: Tensor, mean: number, std: number, a: number, b: number, options?: TruncNormalOptions): Tensor

Fills the tensor with values drawn from a truncated normal distribution, for bounded weight initialization.
Truncated normal sampling prevents extreme initialization values by constraining every sample to the specified bounds. Useful for initializing networks where very large weights are problematic. Essential for:
- Vision Transformers (ViTs) and modern transformer models
- Networks sensitive to initialization scale
- When you want normal distribution without extreme outliers
- Fine-tuning with strict weight bounds
- Initialization where empirical bounds matter more than variance
Values are drawn from N(mean, std²), but any values outside [mean + a×std, mean + b×std] are clamped to those bounds (a simplified implementation; a true truncated normal would resample).
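The difference between the two strategies can be sketched in plain TypeScript. This is a hypothetical illustration, not this library's implementation; `sampleNormal`, `truncNormalClamped`, and `truncNormalRejection` are names invented here, and sampling uses the Box-Muller transform:

```typescript
// One standard-normal draw via the Box-Muller transform (illustrative only).
function sampleNormal(mean: number, std: number): number {
  const u1 = 1 - Math.random(); // shift (0, 1] to avoid log(0)
  const u2 = Math.random();
  const z = Math.sqrt(-2 * Math.log(u1)) * Math.cos(2 * Math.PI * u2);
  return mean + std * z;
}

// Simplified approach (as documented): clamp out-of-range draws to the bounds.
function truncNormalClamped(n: number, mean = 0, std = 1, a = -2, b = 2): number[] {
  const lo = mean + a * std;
  const hi = mean + b * std;
  return Array.from({ length: n }, () =>
    Math.min(hi, Math.max(lo, sampleNormal(mean, std)))
  );
}

// True truncated normal: reject and resample until the draw lands in range.
function truncNormalRejection(n: number, mean = 0, std = 1, a = -2, b = 2): number[] {
  const lo = mean + a * std;
  const hi = mean + b * std;
  return Array.from({ length: n }, () => {
    let x: number;
    do { x = sampleNormal(mean, std); } while (x < lo || x > hi);
    return x;
  });
}
```

Both produce values strictly inside the bounds; the difference is where the excluded tail mass goes (onto the bound values when clamping, redistributed across the interval when resampling).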
- ViT standard: Vision Transformers use trunc_normal_ with std=0.02 by convention
- Bounds interpretation: a and b are in units of standard deviation from mean
- Default excludes: Default bounds [-2σ, +2σ] exclude ~4.6% of normal distribution
- Clamping method: Implementation uses clamping rather than true rejection sampling
- Effect on distribution: Clamping concentrates the excluded tail mass at the bounds themselves (a true truncated normal would redistribute it across the interval)
- Comparison to normal: Bounded, unlike regular normal_ initialization, which can produce arbitrarily large values
- In-place operation: Modifies tensor in-place; returns the same tensor
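The "~4.6% excluded" figure above can be checked analytically: for a standard normal, P(|Z| > 2) = 1 − erf(2/√2). The snippet below is a standalone TypeScript check using the Abramowitz–Stegun erf approximation (max error ≈ 1.5e−7); it is not part of this library's API:

```typescript
// Abramowitz & Stegun formula 7.1.26 approximation of the error function.
function erf(x: number): number {
  const sign = x < 0 ? -1 : 1;
  x = Math.abs(x);
  const t = 1 / (1 + 0.3275911 * x);
  const y = 1 - (((((1.061405429 * t - 1.453152027) * t) + 1.421413741) * t
    - 0.284496736) * t + 0.254829592) * t * Math.exp(-x * x);
  return sign * y;
}

// Fraction of N(0, 1) mass outside [-2σ, +2σ].
const excluded = 1 - erf(2 / Math.SQRT2); // ≈ 0.0455, i.e. ~4.6%
console.log((excluded * 100).toFixed(2) + "% of N(0, 1) lies outside [-2σ, +2σ]");
```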
Parameters
tensor: Tensor - An n-dimensional Tensor (typically a weight matrix from a layer)
mean: number - The mean of the normal distribution
std: number - The standard deviation of the normal distribution
a: number - The lower bound, in standard deviations from the mean
b: number - The upper bound, in standard deviations from the mean
options: TruncNormalOptions (optional) - Optional settings for truncated normal initialization
Returns
Tensor – The input tensor, filled in-place with truncated normal values.

Algorithm:
- Sample from N(mean, std²)
- Clamp values to [mean + a×std, mean + b×std]
- Default bounds [-2σ, +2σ] exclude ~4.6% of the normal distribution

Examples
// Vision Transformer initialization (common use case)
const layer = torch.nn.Linear(768, 768);
torch.nn.init.trunc_normal_(layer.weight, { mean: 0.0, std: 0.02, a: -2.0, b: 2.0 });
torch.nn.init.zeros_(layer.bias);

// Tight bounds to prevent extreme weights
const layer = torch.nn.Linear(512, 256);
torch.nn.init.trunc_normal_(layer.weight, { mean: 0.0, std: 0.01, a: -1.0, b: 1.0 });
// Stricter bounds: values outside [-0.01, +0.01] will be clamped

// Custom mean and range
const layer = torch.nn.Linear(256, 128);
torch.nn.init.trunc_normal_(layer.weight, { mean: 0.5, std: 0.1, a: -1.0, b: 2.0 });
// Mean shifted to 0.5; clamping bounds: [0.4, 0.7]
// ViT-style transformer encoder initialization
class ViTInitializer {
  static initializeTransformer(model: torch.nn.Module) {
    for (const [name, module] of model.named_modules()) {
      if (module instanceof torch.nn.Linear) {
        torch.nn.init.trunc_normal_(module.weight, {
          mean: 0.0,
          std: 0.02,
          a: -2.0, // lower bound, in stds from the mean
          b: 2.0   // upper bound, in stds from the mean
        });
        if (module.bias) {
          torch.nn.init.zeros_(module.bias);
        }
      }
    }
  }
}

See Also
- PyTorch torch.nn.init.trunc_normal_()
- torch.nn.init.normal_ - Regular normal distribution (unbounded)
- torch.nn.init.kaiming_normal_ - Normal initialization for ReLU networks
- torch.nn.init.xavier_normal_ - Xavier normal initialization