torch.nn.init.calculate_gain
function calculate_gain(nonlinearity: Nonlinearity, options?: CalculateGainOptions): number
function calculate_gain(nonlinearity: Nonlinearity, param: number, options?: CalculateGainOptions): number

Calculate the recommended gain (scaling factor) for weight initialization based on the nonlinearity.
Different activation functions benefit from different weight initialization scales. Xavier and Kaiming initialization methods use gain values to scale their variance appropriately for the activation function. Essential for:
- Scaling weights correctly for Xavier/Glorot initialization
- Adjusting initialization for He/Kaiming initialization
- Ensuring signal propagation through deep networks
- Matching initialization to specific activation functions
- Training stability in deep networks
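As rough intuition for the scaling points above: the gain is simply a multiplier on the weight standard deviation. A minimal arithmetic sketch, with no library calls and a fan_in chosen purely for illustration:

// Kaiming-style scaling: std = gain / sqrt(fan_in).
// With ReLU's gain of √2 and fan_in = 512:
const fan_in = 512;
const std = Math.sqrt(2) / Math.sqrt(fan_in); // = 0.0625 exactly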
- ReLU family: ReLU uses √2; LeakyReLU uses √(2 / (1 + α²))
- Tanh: Uses a larger gain (5/3) because tanh compresses its inputs (it saturates away from zero), so the initialization must compensate
- Linear layers: Linear/Conv layers use gain = 1 (no activation scaling)
- SELU: The special gain of 3/4 complements SELU's self-normalizing property
- Interaction with Kaiming init: the kaiming_* functions account for the nonlinearity themselves, so use gain = 1 there; with Xavier, the gain must be passed explicitly
- Parameter importance: For leaky_relu, param must match the negative slope (α) actually used during training (see the sketch after this list)
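For that last point, a minimal sketch of keeping the slope and the gain in sync (the variable name is illustrative):

// Single source of truth for the negative slope, reused wherever the
// LeakyReLU is configured and when computing the matching gain.
const negative_slope = 0.2;
const gain = torch.nn.init.calculate_gain('leaky_relu', { param: negative_slope });
// gain = √(2 / (1 + 0.2²)) ≈ 1.387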
Parameters
nonlinearity: Nonlinearity - The name of the non-linear activation function:
  - 'linear', 'conv1d', 'conv2d', 'conv3d', 'conv_transpose1d/2d/3d': gain = 1
  - 'sigmoid': gain = 1
  - 'tanh': gain = 5/3 ≈ 1.667
  - 'relu': gain = √2 ≈ 1.414
  - 'leaky_relu': gain = √(2 / (1 + α²)), where α is the negative slope
  - 'selu': gain = 3/4 = 0.75
options: CalculateGainOptions (optional) - Optional settings for the gain calculation, e.g. param for leaky_relu's negative slope (see the examples below)
Returns
number – The recommended gain value to use for weight initialization.

Mathematical basis:
- Xavier/Glorot initialization targets constant variance in the forward and backward passes
- Gain scaling depends on the activation function's derivative at zero
- Different activations have different average gradient magnitudes

Examples
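To make the mathematical basis concrete, here is the Xavier-normal standard deviation written out by hand; a sketch of the standard formula, with illustrative layer sizes:

// Xavier/Glorot normal: std = gain * sqrt(2 / (fan_in + fan_out)).
// For a 512 -> 256 layer with ReLU's gain of √2:
const fan_in = 512, fan_out = 256;
const std = Math.sqrt(2) * Math.sqrt(2 / (fan_in + fan_out)); // ≈ 0.0722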
// Getting gain for different activation functions
const linear_gain = torch.nn.init.calculate_gain('linear'); // 1.0
const relu_gain = torch.nn.init.calculate_gain('relu'); // ~1.414
const tanh_gain = torch.nn.init.calculate_gain('tanh'); // ~1.667
const leaky_gain = torch.nn.init.calculate_gain('leaky_relu', { param: 0.2 }); // ~1.387

// Using gain with Xavier initialization
const layer = torch.nn.Linear(512, 256);
const gain = torch.nn.init.calculate_gain('relu'); // Get gain for ReLU
torch.nn.init.xavier_uniform_(layer.weight, { gain }); // Apply with gain
torch.nn.init.zeros_(layer.bias);

// Custom layers with activation-aware initialization
class CustomLayer extends torch.nn.Module {
  weight: torch.nn.Parameter;
  activation: Nonlinearity;

  constructor(in_features: number, out_features: number, activation: Nonlinearity) {
    super();
    this.activation = activation;
    // Look up the recommended gain for this activation and use it to
    // scale the Xavier-normal initialization of the weight matrix.
    const gain = torch.nn.init.calculate_gain(activation);
    this.weight = torch.nn.Parameter.create(
      torch.empty([out_features, in_features])
    );
    torch.nn.init.xavier_normal_(this.weight, { gain });
  }
}
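A quick usage sketch for the class above (the shapes are illustrative):

const custom = new CustomLayer(128, 64, 'tanh'); // weight std scaled by tanh's gain of 5/3
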
See Also
- PyTorch torch.nn.init.calculate_gain()
- torch.nn.init.xavier_uniform_ - Xavier initialization (uses gain)
- torch.nn.init.xavier_normal_ - Xavier initialization (uses gain)
- torch.nn.init.kaiming_uniform_ - He/Kaiming initialization
- torch.nn.init.kaiming_normal_ - He/Kaiming initialization