torch.nn.utils.weight_norm
function weight_norm<T extends Module>(module: T, options?: WeightNormOptions): T
function weight_norm<T extends Module>(module: T, name: string, dim: number, options?: WeightNormOptions): T

Apply weight normalization to a parameter in the given module.
Weight normalization reparametrizes the weight tensor as: w = g * v / ||v||
where g is a scalar magnitude and v is the unnormalized direction. This decouples the magnitude from the direction, which can help with optimization.
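The reparametrization can be sketched numerically with plain arrays rather than library tensors (`l2Norm` and `effectiveWeight` below are illustrative helpers, not part of this API):

```typescript
// Compute the Euclidean (L2) norm of a vector.
function l2Norm(v: number[]): number {
  return Math.sqrt(v.reduce((acc, x) => acc + x * x, 0));
}

// Recompute the effective weight w = g * v / ||v|| from the
// scalar magnitude g and the unnormalized direction v.
function effectiveWeight(g: number, v: number[]): number[] {
  const norm = l2Norm(v);
  return v.map((x) => (g * x) / norm);
}

const v = [3, 4]; // unnormalized direction; ||v|| = 5
const g = 10;     // scalar magnitude
const w = effectiveWeight(g, v);
console.log(w);          // [6, 8] — v rescaled to length g
console.log(l2Norm(w));  // 10 — ||w|| always equals g
```

Because `g` alone controls `||w||`, the optimizer can adjust the weight's length and its direction independently.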
Parameters
module: T – Module containing the parameter to normalize
name: string – Name of the parameter to normalize (e.g. 'weight')
dim: number – Dimension over which to compute the norm
options: WeightNormOptions (optional) – Optional settings for weight normalization
Returns
T – The module with weight normalization applied

Examples
const linear = new torch.nn.Linear(20, 40);
torch.nn.utils.weight_norm(linear, { name: 'weight' });
// Now linear has weight_v and weight_g parameters
// The original weight is computed as g * v / ||v||