torch.nn.GRUCell
new GRUCell(input_size: number, hidden_size: number, options?: RNNCellOptions)
Properties (readonly):
- input_size (number)
- hidden_size (number)
- bias (boolean)
- weight_ih (Parameter)
- weight_hh (Parameter)
- bias_ih (Parameter | null)
- bias_hh (Parameter | null)
A gated recurrent unit (GRU) cell: single timestep recurrent unit with reset/update gates.
Like LSTM but simpler: uses reset and update gates to control information flow, without separate cell state. Fewer parameters than LSTM (3 gates vs 4) but competitive performance. Computes hidden state as: h_t = (1 - z_t) ⊙ n_t + z_t ⊙ h_{t-1}, where z_t is update gate and n_t is the new hidden content. Essential for:
- Processing sequences with fewer parameters than LSTM (3 gates vs 4)
- Learning long-range dependencies without LSTM's cell state complexity
- Faster training than LSTM (less computation per step)
- Competitive stability vs RNN, simpler than LSTM
- Mobile/embedded deployment (smaller models)
GRU offers a middle ground: simpler than LSTM (no separate cell state) but more stable than RNN (gated mechanisms). Reset gate controls what history matters, update gate blends new content with previous hidden state. Often as effective as LSTM in practice.
When to use GRUCell:
- Custom sequence models preferring simplicity over LSTM's expressiveness
- Memory-constrained applications (3 gates vs LSTM's 4)
- Building bidirectional models with fine-grained control
- Teacher forcing with GRU memory
- Variable-length sequences with masking and custom logic
- When LSTM shows no performance improvement but costs more
Trade-offs:
- vs RNNCell: GRU has gates for stability; RNN simpler but less stable
- vs LSTMCell: GRU fewer params (3 gates), simpler; LSTM more expressive with cell state
- Parameters: GRU has 3x hidden_size gates vs RNN's 1x; LSTM has 4x
- Stability: GRU much better than RNN, almost as good as LSTM empirically
- Speed: GRU faster per-step than LSTM; simpler than both
- Gradient flow: GRU gates control gradient better than RNN, nearly as well as LSTM
GRU Mechanism: two gates (reset r_t, update z_t) plus a candidate activation n_t
- Reset gate: r_t = σ(W_ir @ x_t + b_ir + W_hr @ h_{t-1} + b_hr) (forget previous?)
- Update gate: z_t = σ(W_iz @ x_t + b_iz + W_hz @ h_{t-1} + b_hz) (blend with new?)
- New content: n_t = tanh(W_in @ x_t + b_in + r_t ⊙ (W_hn @ h_{t-1} + b_hn)) (candidate)
- Hidden state: h_t = (1 - z_t) ⊙ n_t + z_t ⊙ h_{t-1} (interpolate)
- Single hidden state: GRU only maintains h_t, unlike LSTM's h_t and c_t
- Reset gate mechanics: r_t controls whether to use previous hidden in new content
- Update gate mechanics: z_t controls interpolation between new and previous
- Fewer parameters: 3 weight blocks (reset, update, candidate) vs LSTM's 4 (25% fewer weights)
- Initialization: Weights Kaiming uniform, biases zero-initialized
- Stateless: GRUCell has no state - you manage h_t explicitly
- Batch size: Hidden state must match input batch size
- Gradient flow: Better than RNN via gating, nearly as good as LSTM empirically
- Single state only: Unlike LSTM, no protected cell state for long dependencies
- Long sequences: Still vulnerable to vanishing gradients like RNN, just less
- Initialization critical: Poor init → gates stuck; reset bias ~0, update bias ~1 common
- Computation order: Reset gate affects new content computation - different from LSTM
- Hidden state unbounded: h_t can grow without normalization like vanilla RNN
Examples
// Process sequence with GRU cell
const gru_cell = new torch.nn.GRUCell(10, 20); // input_size=10, hidden_size=20
const x = torch.randn([32, 50, 10]); // [batch=32, seq_len=50, input_size=10]
let h = torch.zeros([32, 20]); // Initial hidden state
// Process sequence with manual loop
const outputs: torch.Tensor[] = [];
for (let t = 0; t < 50; t++) {
h = gru_cell.forward(x.select(1, t), h); // Only one hidden state needed
outputs.push(h);
}
const output = torch.stack(outputs, 1); // [batch, seq_len, hidden_size]// Stacked GRU: multiple layers manually
const layer1_gru = new torch.nn.GRUCell(10, 64);
const layer2_gru = new torch.nn.GRUCell(64, 32);
const x = torch.randn([32, 50, 10]);
let h1 = torch.zeros([32, 64]);
let h2 = torch.zeros([32, 32]);
const outputs: torch.Tensor[] = [];
for (let t = 0; t < 50; t++) {
// First layer
h1 = layer1_gru.forward(x.select(1, t), h1);
// Second layer (input is first layer output)
h2 = layer2_gru.forward(h1, h2);
outputs.push(h2);
}// Bidirectional GRU (forward and backward)
const gru_fwd = new torch.nn.GRUCell(10, 20);
const gru_bwd = new torch.nn.GRUCell(10, 20);
const x = torch.randn([32, 50, 10]);
// Forward
let h_fwd = torch.zeros([32, 20]);
const fwd_outputs: torch.Tensor[] = [];
for (let t = 0; t < 50; t++) {
h_fwd = gru_fwd.forward(x.select(1, t), h_fwd);
fwd_outputs.push(h_fwd);
}
// Backward
let h_bwd = torch.zeros([32, 20]);
const bwd_outputs: torch.Tensor[] = [];
for (let t = 49; t >= 0; t--) {
h_bwd = gru_bwd.forward(x.select(1, t), h_bwd);
bwd_outputs.unshift(h_bwd);
}
// Concatenate: [batch, seq_len, 40] (20 fwd + 20 bwd)
const bidir: torch.Tensor[] = [];
for (let t = 0; t < 50; t++) {
bidir.push(torch.cat([fwd_outputs[t], bwd_outputs[t]], -1));
}
const bidir_output = torch.stack(bidir, 1);// Sequence classification with GRU
class GRUClassifier extends torch.nn.Module {
embedding: torch.nn.Embedding;
gru: torch.nn.GRUCell;
output: torch.nn.Linear;
constructor(vocab_size: number, embed_dim: number, hidden_dim: number, num_classes: number) {
super();
this.embedding = new torch.nn.Embedding(vocab_size, embed_dim);
this.gru = new torch.nn.GRUCell(embed_dim, hidden_dim);
this.output = new torch.nn.Linear(hidden_dim, num_classes);
}
forward(token_ids: torch.Tensor): torch.Tensor {
const embedded = this.embedding.forward(token_ids); // [batch, seq_len, embed_dim]
let h = torch.zeros([embedded.shape[0], 256]);
for (let t = 0; t < embedded.shape[1]; t++) {
h = this.gru.forward(embedded.select(1, t), h);
}
// Use final hidden state for classification
return this.output.forward(h); // [batch, num_classes]
}
}// Comparing GRU with LSTM and RNN
const rnn = new torch.nn.RNNCell(10, 20);
const gru = new torch.nn.GRUCell(10, 20);
const lstm = new torch.nn.LSTMCell(10, 20);
const x = torch.randn([32, 10]);
let h_rnn = torch.zeros([32, 20]);
let h_gru = torch.zeros([32, 20]);
let h_lstm = torch.zeros([32, 20]);
let c_lstm = torch.zeros([32, 20]);
h_rnn = rnn.forward(x, h_rnn); // Only h, simpler
h_gru = gru.forward(x, h_gru); // Only h, still gated
[h_lstm, c_lstm] = lstm.step(x, [h_lstm, c_lstm]); // h and c, more expressive