torch.nn.RNNCell
new RNNCell(input_size: number, hidden_size: number, options?: RNNCellOptions)
input_size (number) - readonly
hidden_size (number) - readonly
bias (boolean) - readonly
nonlinearity ('tanh' | 'relu') - readonly
weight_ih (Parameter)
weight_hh (Parameter)
bias_ih (Parameter | null)
bias_hh (Parameter | null)
An Elman RNN cell: single timestep recurrent unit with tanh or ReLU.
Computes one recurrent step: h_t = activation(W_ih @ x_t + b_ih + W_hh @ h_{t-1} + b_hh). This is the building block for sequence processing - process one timestep at a time, maintaining hidden state across timesteps. Essential for:
- Building custom RNN architectures
- Understanding RNN mechanics at the single-step level
- Time series forecasting (one step at a time)
- Sequence-to-sequence models with fine-grained control
- Conditional generation (sampling sequentially)
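The single-step recurrence above can be sketched in plain TypeScript on number arrays (illustrative only; the `matVec` and `rnnCellStep` helpers are hypothetical, not part of the library, and real code would use tensors):

```typescript
// One Elman step on plain arrays: h_t = tanh(W_ih @ x + b_ih + W_hh @ h + b_hh)
function matVec(W: number[][], v: number[]): number[] {
  return W.map(row => row.reduce((s, w, i) => s + w * v[i], 0));
}

function rnnCellStep(
  wIh: number[][], bIh: number[], // [hidden, input], [hidden]
  wHh: number[][], bHh: number[], // [hidden, hidden], [hidden]
  x: number[], h: number[]        // [input], [hidden]
): number[] {
  const fromInput = matVec(wIh, x);
  const fromHidden = matVec(wHh, h);
  return fromInput.map((v, i) => Math.tanh(v + bIh[i] + fromHidden[i] + bHh[i]));
}

// input_size=2, hidden_size=2, weights chosen by hand for the demo
const h1 = rnnCellStep(
  [[1, 0], [0, 1]], [0, 0],     // W_ih = identity, b_ih = 0
  [[0.5, 0], [0, 0.5]], [0, 0], // W_hh = 0.5 * identity, b_hh = 0
  [1, -1],                      // x_t
  [0, 0]                        // h_{t-1} = 0, so W_hh contributes nothing
);
// h1 = [tanh(1), tanh(-1)]; tanh keeps every component in (-1, 1)
```

Feeding `h1` back in as the next `h` is exactly the manual loop the library examples below perform with tensors.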
Unlike the high-level RNN module which processes entire sequences, RNNCell processes a single timestep and returns the next hidden state. You manage the sequence loop manually, giving complete control over hidden state, dropout, attention, etc.
When to use RNNCell:
- Building custom sequence models with complex architectures
- Implementing attention mechanisms between RNN steps
- Variable-length sequences with masking
- Conditional generation (sampling output affects next input)
- Debugging RNN behavior step-by-step
- Teacher forcing in sequence-to-sequence models
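For the variable-length case above, a common pattern is to run the cell on the whole batch every step and then keep the previous hidden state wherever the mask is 0, so padded timesteps leave finished sequences untouched. A plain-TypeScript sketch of that per-element select (the `maskedUpdate` helper is hypothetical, not a library function):

```typescript
// Masked update: h = mask * h_next + (1 - mask) * h_prev, per batch element.
// mask[b] is 1 while sequence b is still running, 0 once it has ended.
function maskedUpdate(hNext: number[][], hPrev: number[][], mask: number[]): number[][] {
  return hNext.map((row, b) =>
    row.map((v, i) => mask[b] * v + (1 - mask[b]) * hPrev[b][i])
  );
}

const hPrev = [[0.1, 0.2], [0.3, 0.4]]; // [batch=2, hidden=2]
const hNext = [[0.9, 0.9], [0.9, 0.9]]; // hypothetical cell output this step
const mask = [1, 0];                    // second sequence already ended
const h = maskedUpdate(hNext, hPrev, mask);
// h[0] takes the new state; h[1] keeps the old one
```

With tensors the same idea is an elementwise blend of the new and old hidden states using a broadcast `[batch, 1]` mask.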
Trade-offs:
- vs RNN: RNNCell gives fine control; RNN automates the loop
- Flexibility: RNNCell lets you insert custom logic between steps
- Complexity: Manual loop management vs automatic in RNN
- Performance: RNNCell slightly slower per-step due to loop overhead
- Activation: Both tanh and ReLU available (tanh more common)
Cell Computation: At each timestep, the Elman RNN cell computes:
- Combine input and hidden state: combined = W_ih @ x_t + b_ih + W_hh @ h_{t-1} + b_hh
- Apply activation: h_t = activation(combined)
- Return h_t for use as next h_{t-1} and as output for this step
- Stateless: RNNCell holds no hidden state between calls - you pass in h_{t-1} and receive h_t explicitly
- Single step: Process one timestep at a time; loop manually
- Initialization: Weights and biases initialized uniformly in [-1/sqrt(hidden_size), 1/sqrt(hidden_size)] (PyTorch convention)
- Biases: Input and hidden biases separate to match PyTorch exactly
- Activation choice: tanh (default) more stable; relu for sparsity
- Gradient flow: tanh saturates, which can shrink gradients over long sequences; relu avoids saturation but can let activations and gradients grow unbounded
- No dropout: Apply torch.nn.functional.dropout between RNNCell steps
- Batch dimension: Always [batch, features]; handles batches automatically
- Vanishing/exploding gradients: Long sequences can suffer without layer norm or careful initialization
- Manual state management: You must initialize and pass hidden state correctly
- Unbounded hidden state: Unlike LSTM, RNNCell has no cell state to constrain values
- First hidden state: Must match batch size - commonly initialize to zeros([batch, hidden_size])
- Sequence dimension: You loop over sequences - index with .select(dim, t) or similar
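Since the cell applies no dropout itself (see the notes above), recurrent dropout is applied by hand to the hidden state between steps. A plain-TypeScript sketch of inverted dropout with a fixed keep mask for determinism (the `dropoutHidden` helper is hypothetical; in practice the mask is sampled and you would use the library's functional dropout on a tensor):

```typescript
// Inverted dropout on a hidden state between RNNCell steps:
// dropped units go to 0, survivors are scaled by 1/(1-p) so the
// expected activation is unchanged between train and eval.
function dropoutHidden(h: number[], keep: boolean[], p: number): number[] {
  return h.map((v, i) => (keep[i] ? v / (1 - p) : 0));
}

const h = [0.5, -0.2, 0.8, 0.1];
const keep = [true, false, true, true]; // would normally be resampled per step
const hDropped = dropoutHidden(h, keep, 0.25);
// hDropped = [0.5/0.75, 0, 0.8/0.75, 0.1/0.75]
```

The result is what you would feed into the next `forward` call in the manual loop.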
Examples
// Process sequence one step at a time
const rnn_cell = new torch.nn.RNNCell(10, 20); // input_size=10, hidden_size=20
const x = torch.randn([32, 5, 10]); // [batch=32, seq_len=5, input_size=10]
let h = torch.zeros([32, 20]); // Initial hidden state [batch, hidden_size]
// Process sequence manually
const outputs: torch.Tensor[] = [];
for (let t = 0; t < 5; t++) {
h = rnn_cell.forward(x.select(1, t), h); // Process single timestep
outputs.push(h);
}
const output = torch.stack(outputs, 1); // [batch, seq_len, hidden_size]

// Sequence-to-sequence with attention (custom logic per step)
class AttentiveRNNDecoder extends torch.nn.Module {
rnn_cell: torch.nn.RNNCell;
attention: torch.nn.Module;
output_proj: torch.nn.Linear;
constructor() {
super();
this.rnn_cell = new torch.nn.RNNCell(256, 512);
this.attention = ...; // Custom attention mechanism
this.output_proj = new torch.nn.Linear(512, 10000); // Vocab size
}
forward(encoder_output: torch.Tensor, target_seq: torch.Tensor): torch.Tensor {
let h = torch.zeros([encoder_output.shape[0], 512]);
const outputs: torch.Tensor[] = [];
for (let t = 0; t < target_seq.shape[1]; t++) {
// RNN step
h = this.rnn_cell.forward(target_seq.select(1, t), h);
// Custom attention between steps
const context = this.attention.forward(h, encoder_output);
const combined = torch.cat([h, context], -1);
// Generate output for this step
const logits = this.output_proj.forward(combined);
outputs.push(logits);
}
return torch.stack(outputs, 1);
}
}

// Conditional generation (teacher forcing then free-running)
const rnn = new torch.nn.RNNCell(100, 256);
const vocab_proj = new torch.nn.Linear(256, 5000);
// Teacher forcing: feed ground-truth embeddings at each step
let h = torch.zeros([1, 256]);
for (let t = 0; t < target_len; t++) {
h = rnn.forward(target_embeddings.select(1, t), h); // [1, 100] ground-truth input
}
// Free-running: embed the sampled token and feed it back as the next input
// (assumes an `embedding` module mapping tokens into the 100-dim input space)
const sampled: torch.Tensor[] = [];
let x_t = start_embedding; // [1, 100] embedding of the start token
for (let t = 0; t < 100; t++) {
h = rnn.forward(x_t, h);
const logits = vocab_proj.forward(h);
const token = torch.argmax(logits, -1);
sampled.push(token);
x_t = embedding.forward(token); // previous prediction becomes next input
}

// Comparing tanh vs ReLU activation
const rnn_tanh = new torch.nn.RNNCell(10, 20, { nonlinearity: 'tanh' }); // More stable
const rnn_relu = new torch.nn.RNNCell(10, 20, { nonlinearity: 'relu' }); // Sparser
const x = torch.randn([32, 10]);
const h = torch.randn([32, 20]);
const h_tanh = rnn_tanh.forward(x, h); // Values in [-1, 1]
const h_relu = rnn_relu.forward(x, h); // Values >= 0 (sparse)