torch.nn.TransformerEncoder
class TransformerEncoder extends Module
new TransformerEncoder(options: TransformerEncoderOptions)

Properties (readonly):
- num_layers: number
- layers: ModuleList
- norm: LayerNorm | null
Stack of N Transformer encoder layers.
A TransformerEncoder applies multiple TransformerEncoderLayer modules sequentially, progressively transforming the input sequence. Each layer receives the previous layer's output, enabling deep feature extraction and modeling of long-range dependencies.
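The sequential stacking described above can be sketched in plain TypeScript. This is a toy illustration only: simple array-mapping functions stand in for TransformerEncoderLayer, and all names (`Layer`, `makeLayer`, `encode`) are illustrative, not part of the library API.

```typescript
// Toy sketch: each "layer" maps a hidden-state array to a new one.
type Layer = (x: number[]) => number[];

// Stand-in for a TransformerEncoderLayer: here just an affine map.
const makeLayer = (scale: number): Layer => (x) => x.map((v) => v * scale + 1);

// A TransformerEncoder applies its layers in order, feeding each
// layer the previous layer's output, then an optional final norm.
function encode(layers: Layer[], norm: Layer | null, src: number[]): number[] {
  let out = src;
  for (const layer of layers) out = layer(out);
  return norm ? norm(out) : out;
}

const layers = [makeLayer(2), makeLayer(2)];
const out = encode(layers, null, [1, 0]);
// [1, 0] -> [3, 1] -> [7, 3]
```

The key structural point: the stack is pure function composition, so depth adds capacity at a linear cost in compute and memory per token.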
Commonly used in:
- Standalone encoder-only models (BERT, RoBERTa, DistilBERT)
- Encoder part of seq2seq models (machine translation)
- Vision transformers (ViT) for image understanding
- Document/text representation learning
- Feature extraction before classification/tagging
- Final layer normalization: Including norm improves output stability and is recommended. The final LayerNorm(d_model) normalizes the output from the last encoder layer.
- Attention masks: The same masks are applied at every layer; they typically do not change between layers.
- Depth scaling: Increasing num_layers improves representational capacity but also increases memory/compute. Common architectures: BERT (12 layers), GPT-2 (12 layers), GPT-3 (96 layers).
- Width vs depth: You can also increase d_model or dim_feedforward to increase capacity.
- All layers share the same architecture: every one of the num_layers stacked copies has the same d_model, nhead, etc., taken from encoder_layer.
- Computational complexity: O(num_layers × seq_len² × d_model). Quadratic in seq_len!
- Very long sequences (e.g., more than ~1000 tokens) can cause out-of-memory errors due to attention's O(seq_len²) memory footprint.
- Ensure encoder_layer is constructed with the d_model you intend (rather than silently inheriting the default of 512) before stacking.
- Attention masks have shape [seq_len, seq_len] or can be broadcastable to that.
- Padding masks have shape [batch_size, seq_len] with True indicating padding positions.
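The two mask shapes above can be sketched in plain TypeScript, with boolean/number arrays standing in for tensors. The helper names (`paddingMask`, `causalMask`) are illustrative, not library functions; the additive-mask convention (0 allows attention, -Infinity blocks it) follows the usual Transformer practice.

```typescript
// Padding mask: [batch_size, seq_len], true marks padding positions.
function paddingMask(lengths: number[], seqLen: number): boolean[][] {
  return lengths.map((len) =>
    Array.from({ length: seqLen }, (_, pos) => pos >= len)
  );
}

// Additive attention mask: [seq_len, seq_len]; 0 allows attention,
// -Infinity blocks it (here: a causal mask hiding future positions).
function causalMask(seqLen: number): number[][] {
  return Array.from({ length: seqLen }, (_, i) =>
    Array.from({ length: seqLen }, (_, j) => (j > i ? -Infinity : 0))
  );
}

// paddingMask([2, 4], 4) -> [[false, false, true,  true],
//                            [false, false, false, false]]
```

A padding mask is per-sequence (one row per batch element), while an attention mask is per-position-pair and shared across the batch, which is why their shapes differ.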
Examples
// Create a 6-layer encoder (DistilBERT-style: BERT-base dimensions, half the depth)
const encoder_layer = new torch.nn.TransformerEncoderLayer({
d_model: 768,
nhead: 12,
dim_feedforward: 3072,
dropout: 0.1
});
const encoder = new torch.nn.TransformerEncoder({
encoder_layer,
num_layers: 6,
norm: new torch.nn.LayerNorm(768)
});
// Process a sequence
const src = torch.randn(10, 32, 768); // [seq_len, batch, d_model]
let encoded = encoder.encode(src); // same shape as src
// With an additive attention mask ([seq_len, seq_len]): zeros leave
// attention unrestricted; set positions to -Infinity to block them
const mask = torch.zeros(10, 10);
encoded = encoder.encode(src, { mask });
// With a padding mask for variable-length sequences
const padding_mask = torch.tensor([[false, false, true, ...], ...]); // [batch, seq_len], true = padding
encoded = encoder.encode(src, { src_key_padding_mask: padding_mask });
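The O(num_layers × seq_len² × d_model) note above can be made concrete with a rough cost estimate. This is an illustrative back-of-the-envelope helper (constant factors and the feed-forward term omitted), not a profiler:

```typescript
// Rough attention cost: num_layers * seq_len^2 * d_model.
// Doubling seq_len roughly quadruples the cost; doubling
// num_layers or d_model only doubles it.
function attnCost(numLayers: number, seqLen: number, dModel: number): number {
  return numLayers * seqLen * seqLen * dModel;
}

const base = attnCost(6, 512, 768);
const doubled = attnCost(6, 1024, 768);
// doubled / base === 4: quadratic in sequence length
```

This is why the width-vs-depth trade-off mentioned earlier matters: capacity added via d_model or num_layers scales linearly, while longer sequences scale quadratically.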