torch.nn.GLUOptions
GLU (Gated Linear Unit) activation function.
GLU applies a self-gating mechanism: the input is split into two halves, and the first half is multiplied element-wise by the sigmoid of the second half. This acts as an adaptive gate in which one half controls how much of the other passes through. GLU is particularly effective in sequence modeling and has become increasingly popular in modern NLP models (Transformer variants, language models) as an alternative to standard feed-forward network (FFN) layers. Its key advantage is dynamic, data-dependent gating that lets the network selectively amplify or suppress individual feature representations.
Core idea: GLU(a, b) = a ⊗ sigmoid(b), where a and b are obtained by splitting the input in half along the chosen dimension. sigmoid(b) acts as a "gate" that modulates how much of a passes through: when sigmoid(b) ≈ 1, a passes through almost unchanged; when sigmoid(b) ≈ 0, a is suppressed. This lets the network learn which features matter at each position in the input.
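A scalar illustration of the gate in plain TypeScript (no tensor library needed; the numbers are arbitrary):
// sigmoid(b) close to 1 lets a through; close to 0 suppresses it.
const sigmoid = (v: number) => 1 / (1 + Math.exp(-v));
const a = 2.0;
console.log(a * sigmoid(6));  // ≈ 1.995: gate open, a passes nearly unchanged
console.log(a * sigmoid(-6)); // ≈ 0.005: gate closed, a is suppressed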
When to use GLU:
- Sequence modeling: as an alternative to RNNs/LSTMs, and in modern Transformers
- Language models: GPT-style, decoder-only models use GLU variants in their feed-forward blocks
- Feed-forward replacement: gated blocks replace the dense ReLU layers in Transformer FFNs
- Computer vision: Vision Transformers and other modern vision models
- Adaptive feature selection: When you need data-dependent gating per feature
GLU variants and related activations:
- GLU: Input split in half, compute a * sigmoid(b)
- GEGLU: Gated with GELU: a * GELU(b) (popular in modern Transformers)
- SwiGLU: Gated with SiLU: a * SiLU(b) (used in PaLM and other recent models)
- ReGLU: Gated with ReLU: a * ReLU(b); hand-rolled sketches of these variants follow this list
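Only GLU itself appears as a module on this page; a hand-rolled sketch of the other variants, assuming this binding mirrors PyTorch with Tensor.chunk / Tensor.mul and torch.nn.GELU / torch.nn.SiLU modules (these are assumptions; only GLU, Linear, and ReLU are used in the documented examples):
// Hand-rolled GLU variants. Assumptions: Tensor.chunk and Tensor.mul, plus
// torch.nn.GELU and torch.nn.SiLU modules, exist in this binding as they do in PyTorch.
const geluGate = new torch.nn.GELU();
const siluGate = new torch.nn.SiLU();
const reluGate = new torch.nn.ReLU();
function geglu(x: torch.Tensor): torch.Tensor {
  const [a, b] = x.chunk(2, -1);     // split the last dimension in half
  return a.mul(geluGate.forward(b)); // a * GELU(b)
}
function swiglu(x: torch.Tensor): torch.Tensor {
  const [a, b] = x.chunk(2, -1);
  return a.mul(siluGate.forward(b)); // a * SiLU(b)
}
function reglu(x: torch.Tensor): torch.Tensor {
  const [a, b] = x.chunk(2, -1);
  return a.mul(reluGate.forward(b)); // a * ReLU(b)
}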
Algorithm
Forward: GLU(input) = a ⊗ sigmoid(b)
- Split the input in half along the specified dimension: a, b = split(input, dim)
- Compute sigmoid(b) and multiply it element-wise with a
- The result has the same shape as the input, except the split dimension is halved (see the shape check below)
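A quick shape check for these steps, using the same torch.randn and torch.nn.GLU calls as the Examples section (the sizes are arbitrary):
// The split dimension must be even; it is halved in the output.
const shapeGlu = new torch.nn.GLU(-1);       // split along the last dimension
const shapeIn = torch.randn([2, 3, 8]);      // [batch, seq_len, features]
const shapeOut = shapeGlu.forward(shapeIn);  // [2, 3, 4]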
Backward: by the chain rule, ∂GLU/∂a = sigmoid(b) and ∂GLU/∂b = a ⊗ sigmoid(b) ⊗ (1 - sigmoid(b))
- Gradients flow back through both the value path (a) and the gate path (b); a scalar check follows below
- This lets the network learn which features the gate should pass through
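A scalar finite-difference check of those partial derivatives in plain TypeScript (illustration only, not part of the API):
const sig = (v: number) => 1 / (1 + Math.exp(-v));
const gluScalar = (a: number, b: number) => a * sig(b); // scalar stand-in for one gated feature
const a0 = 1.5, b0 = -0.3, eps = 1e-6;
// Analytic gradients from the chain rule:
const dA = sig(b0);                      // ∂GLU/∂a = sigmoid(b)
const dB = a0 * sig(b0) * (1 - sig(b0)); // ∂GLU/∂b = a * sigmoid(b) * (1 - sigmoid(b))
// Central finite differences agree to ~1e-6:
console.log(dA, (gluScalar(a0 + eps, b0) - gluScalar(a0 - eps, b0)) / (2 * eps));
console.log(dB, (gluScalar(a0, b0 + eps) - gluScalar(a0, b0 - eps)) / (2 * eps));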
Definition
export interface GLUOptions {
  /** Dimension on which to split the input (default: -1) */
  dim?: number;
}
dim (number, optional) – Dimension on which to split the input (default: -1)
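The Examples section passes dim positionally (for example new torch.nn.GLU(1)); whether the constructor also accepts a GLUOptions object directly is an assumption inferred from the interface above.
// Constructing the module. The options-object form is an assumption based on GLUOptions;
// the positional form is the one used throughout the Examples below.
const gluLastDim = new torch.nn.GLU(-1);              // -1 is the documented default (last dimension)
const gluFromOptions = new torch.nn.GLU({ dim: -1 }); // assumption: GLUOptions accepted directly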
Examples
// Modern Transformer: replace the FFN with a GLU block
class TransformerFFN extends torch.nn.Module {
  private linear_gate: torch.nn.Linear;
  private glu: torch.nn.GLU;

  constructor(dim: number, hidden_dim: number) {
    super();
    // Project to 2 * hidden_dim so GLU can split it in half
    this.linear_gate = new torch.nn.Linear(dim, 2 * hidden_dim);
    this.glu = new torch.nn.GLU(-1); // split along the feature (last) dimension
  }

  forward(x: torch.Tensor): torch.Tensor {
    // x: [batch, seq_len, dim]
    const gated = this.linear_gate.forward(x); // [batch, seq_len, 2*hidden_dim]
    return this.glu.forward(gated);            // [batch, seq_len, hidden_dim]
  }
}
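A usage sketch for the module above (the sizes are arbitrary; note this block has no down-projection, so the output keeps hidden_dim features):
const ffn = new TransformerFFN(512, 2048);
const tokens = torch.randn([8, 128, 512]); // [batch, seq_len, dim]
const ffnOut = ffn.forward(tokens);        // [8, 128, 2048] = [batch, seq_len, hidden_dim]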
// GLU provides adaptive, data-dependent feature selection vs. the fixed ReLU in a standard FFN

// Language-model FFN in the style of SwiGLU (a practical modern variant)
// Note: GLU gates with sigmoid; SwiGLU gates with SiLU instead
const dim = 768, hidden_dim = 3072;
const batch_size = 4, seq_len = 128; // example sizes
const linear1 = new torch.nn.Linear(dim, 2 * hidden_dim); // projects up; GLU splits the result in half
const glu = new torch.nn.GLU(-1);                         // sigmoid gate along the feature dimension
const linear2 = new torch.nn.Linear(hidden_dim, dim);     // projects back down
// In a transformer block:
let x = torch.randn([batch_size, seq_len, dim]);
let gated = linear1.forward(x);        // [batch, seq_len, 2*hidden_dim]
gated = glu.forward(gated);            // [batch, seq_len, hidden_dim] - adaptive gating
let output = linear2.forward(gated);   // [batch, seq_len, dim]

// Comparison: GLU FFN vs ReLU FFN behavior
const x = torch.randn([32, 10]); // Batch size 32, feature dim 10
// Standard ReLU FFN
const fc1_relu = new torch.nn.Linear(10, 40);
const relu = new torch.nn.ReLU();
const fc2_relu = new torch.nn.Linear(40, 10);
const relu_output = fc2_relu.forward(relu.forward(fc1_relu.forward(x)));
// GLU FFN (adaptive gating)
const fc1_glu = new torch.nn.Linear(10, 40); // Projects to 2*20 for split in GLU
const glu = new torch.nn.GLU(1); // Gate with sigmoid
const fc2_glu = new torch.nn.Linear(20, 10); // Input is now 20 (half of 40)
const glu_output = fc2_glu.forward(glu.forward(fc1_glu.forward(x)));
// The first projections match, but the GLU path's second Linear is smaller because gating halves the
// hidden width; in exchange, each feature gets a learned, data-dependent gate instead of a fixed ReLU
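For reference, a plain-TypeScript tally of the parameter counts behind that comparison (weights plus biases of each Linear; no library calls):
const linearParams = (inF: number, outF: number) => inF * outF + outF;
const reluFFNParams = linearParams(10, 40) + linearParams(40, 10); // 440 + 410 = 850
const gluFFNParams  = linearParams(10, 40) + linearParams(20, 10); // 440 + 210 = 650
console.log(reluFFNParams, gluFFNParams); // 850 650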