torch.nn.Transformer
class Transformer extends Module
new Transformer(options: TransformerOptions = {})
- readonly d_model (number)
- readonly nhead (number)
- readonly batch_first (boolean)
- encoder (TransformerEncoder)
- decoder (TransformerDecoder)
A complete Transformer model with encoder and decoder.
The Transformer architecture consists of an encoder-decoder stack that enables sequence-to-sequence transduction. The encoder processes the source sequence to produce context representations (memory), and the decoder generates the target sequence while attending to the encoder output.
This is the model from "Attention is All You Need" (Vaswani et al., 2017), the foundational architecture for modern large language models (BERT, GPT, T5, etc.).
Architecture:
- Encoder: Stack of encoder layers with self-attention
- Decoder: Stack of decoder layers with self-attention and cross-attention to encoder
- Uses residual connections and layer normalization throughout
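The two-stage flow, sketched with the encoder/decoder sub-modules shown in the examples below (a minimal sketch; the defaults are assumed to mirror PyTorch's d_model=512, nhead=8, 6 layers per stack):
const model = new torch.nn.Transformer(); // assumed defaults: d_model=512, nhead=8
const src = torch.randn(10, 32, 512); // source: [src_seq_len, batch_size, d_model]
const tgt = torch.randn(20, 32, 512); // target: [tgt_seq_len, batch_size, d_model]
// 1. Encoder maps the source into context representations ("memory")
const memory = model.encoder.encode(src); // [10, 32, 512]
// 2. Decoder self-attends over tgt and cross-attends to memory
const out = model.decoder.decode(tgt, memory); // [20, 32, 512]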
Commonly used for:
- Machine translation (original use case)
- Sequence-to-sequence tasks (summarization, question answering, code generation)
- Encoder-decoder models with custom training objectives
- Research and prototyping of transformer variants
- Building encoder-only (BERT-style) or decoder-only (GPT-style) variants from its sub-modules
- Encoder-only mode: You can use just the encoder part for BERT-style models (set num_decoder_layers=0 if needed); see the sketch after this list.
- Decoder-only mode: You can use just the decoder with self-attention for GPT-style models (though decoder-only models are typically implemented as a standalone stack without cross-attention).
- Input shapes: Default is [seq_len, batch_size, d_model]. Use batch_first=true for [batch_size, seq_len, d_model] (see the final example below).
- Embeddings not included: This class expects pre-embedded inputs. Add token/positional embeddings before feeding data.
- Post-LN vs Pre-LN: Original "Attention is All You Need" uses Post-LN (norm_first=false). Pre-LN (norm_first=true) is more stable for very deep models (12+ layers).
- Typical scaling: Depth and width trade-offs:
  - BERT-base: 12 layers, 768 d_model, 12 heads
  - BERT-large: 24 layers, 1024 d_model, 16 heads
  - GPT-3: 96 layers, 12288 d_model, 96 heads
- Computational complexity: O(num_layers × seq_len² × d_model). Quadratic in sequence length!
- Memory requirements: O(batch_size × seq_len × d_model) for intermediate activations, plus O(batch_size × nhead × seq_len²) for attention weights. For example, at seq_len = 1024, batch_size = 32, nhead = 12, the float32 attention weights alone take 32 × 12 × 1024² × 4 bytes ≈ 1.6 GB per layer. Large sequence lengths can cause OOM errors.
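A minimal encoder-only sketch based on the note above (num_decoder_layers: 0 comes from this page; the sizes are BERT-base's):
// BERT-style, encoder-only: build the model without a decoder stack
const encoder_only = new torch.nn.Transformer({
  d_model: 768,
  nhead: 12,
  num_encoder_layers: 12,
  num_decoder_layers: 0, // skip the decoder entirely
  dim_feedforward: 3072
});
const tokens = torch.randn(128, 8, 768); // pre-embedded input: [seq_len, batch_size, d_model]
const encoded = encoder_only.encoder.encode(tokens); // contextual representations, same shape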
- Sequence lengths over ~1024 tokens can cause memory/computation issues. Use sparse attention or hierarchical approaches for longer sequences.
- Ensure src and tgt have the same d_model dimension. Mismatches cause shape errors in attention.
- Causal masks for the decoder should be additive float masks: -Infinity at masked positions and 0 at kept positions, not 0/1 values.
- Attention mask convention for boolean masks: true = mask out (ignore), false = keep. This differs from some libraries! See the sketch after this list.
- d_model must be divisible by nhead, otherwise attention head dimension is not an integer.
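A short sketch of the two mask types referenced above, assuming the library mirrors PyTorch's mask semantics (an additive float mask for causality, a boolean key-padding mask):
// Causal mask: float, 0 on/below the diagonal (attend), -Infinity above (blocked)
const causal = torch.nn.Transformer.generate_square_subsequent_mask(5);
// row 0: [0, -inf, -inf, -inf, -inf]
// row 1: [0,    0, -inf, -inf, -inf] ... and so on down to an all-zero last row
// Key-padding mask: boolean, shape [batch_size, seq_len], true = mask out
const pad_mask = torch.zeros(4, 10); // all false/0: nothing masked yet
// ...set entries to true wherever the corresponding input token is padding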
Examples
// Create default transformer (base size, like BERT-base)
const transformer = new torch.nn.Transformer({
d_model: 768,
nhead: 12,
num_encoder_layers: 12,
num_decoder_layers: 12,
dim_feedforward: 3072,
dropout: 0.1
});
// Forward pass: encode source, decode target
const src = torch.randn(10, 32, 768); // [src_seq_len, batch_size, d_model]
const tgt = torch.randn(20, 32, 768); // [tgt_seq_len, batch_size, d_model]
const output = transformer.run(src, tgt); // [20, 32, 768]
// Machine translation: encode source sentence
const src_ids = tokenizer.encode("Hello world");
const src_emb = embedding(torch.tensor(src_ids).unsqueeze(1)); // [seq_len, 1, d_model]
const memory = transformer.encoder.encode(src_emb);
// Decode target sentence (auto-regressively)
let tgt_ids = [bos_token_id];
for (let i = 0; i < max_length; i++) {
const tgt_emb = embedding(torch.tensor(tgt_ids).unsqueeze(1));
const tgt_mask = torch.nn.Transformer.generate_square_subsequent_mask(tgt_ids.length);
const out = transformer.decoder.decode(tgt_emb, memory, { tgt_mask });
// out is [tgt_len, 1, d_model]; take the last time step and project it to vocab logits
// (a select(dim, index)-style call is assumed here; adapt to the tensor API's indexing)
const last_step = out.select(0, tgt_ids.length - 1); // [1, d_model]
const logits = output_proj(last_step); // output_proj: a Linear(d_model, vocab_size) defined elsewhere
const next_id = torch.argmax(logits).item();
tgt_ids.push(next_id);
if (next_id === eos_token_id) break;
}
// With attention masks for padding: shape [batch_size, src_seq_len], true = mask out
const src_padding = torch.zeros(32, 10); // all false/0: no positions masked yet
// set entries to true at the actual padding positions, e.g. everything past
// each sequence's real length (an all-true mask would mask every token)
const masked_output = transformer.run(src, tgt, { src_key_padding_mask: src_padding });
// Modern variant: Pre-LN (norm_first=true) for stability with deep models
const deep_transformer = new torch.nn.Transformer({
d_model: 1024,
nhead: 16,
num_encoder_layers: 24,
num_decoder_layers: 24,
norm_first: true // More stable for deep models (like GPT-3)
});
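And the batch-first layout from the input-shapes note above, as a minimal sketch (only the batch_first option itself is documented on this page; the shapes follow from it):
// batch_first=true: all tensors are [batch_size, seq_len, d_model]
const bf = new torch.nn.Transformer({
  d_model: 512,
  nhead: 8,
  batch_first: true
});
const src_bf = torch.randn(32, 10, 512); // [batch_size, src_seq_len, d_model]
const tgt_bf = torch.randn(32, 20, 512); // [batch_size, tgt_seq_len, d_model]
const out_bf = bf.run(src_bf, tgt_bf); // [32, 20, 512]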