torch.nn.TransformerDecoder
class TransformerDecoder extends Module
new TransformerDecoder(options: TransformerDecoderOptions)
Readonly properties:
- num_layers (number)
- layers (ModuleList)
- norm (LayerNorm | null)
Stack of N Transformer decoder layers.
A TransformerDecoder applies multiple TransformerDecoderLayer modules sequentially, each processing the target sequence while attending to the encoder memory. This supports auto-regressive sequence generation and encoder-decoder sequence-to-sequence models.
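The sequential layer stacking described above can be sketched in a few lines of plain JavaScript (toy stand-ins, not this library's API):

```javascript
// Minimal sketch of what a decoder stack does. Layers are stubbed as plain
// functions (x, memory) => x; hypothetical stand-ins, not the library's API.
function runDecoderStack(layers, tgt, memory, norm) {
  let output = tgt;
  for (const layer of layers) {
    // Each layer self-attends over the target and cross-attends to memory.
    output = layer(output, memory);
  }
  return norm ? norm(output) : output; // optional final LayerNorm
}

// Toy layers on scalars: each layer adds the memory value to its input.
const toyLayers = [(x, m) => x + m, (x, m) => x + m];
console.log(runDecoderStack(toyLayers, 1, 10, null)); // 21
```

Each layer's output becomes the next layer's input, while the same memory is fed to every layer, which mirrors how the real module threads tgt and memory through the stack.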
Key differences from TransformerEncoder:
- Each layer has cross-attention to encoder output (memory), not just self-attention
- Usually used with causal masking on self-attention to prevent attending to future tokens
- Designed for generation tasks where tokens are produced one at a time
Commonly used in:
- Machine translation (encoder-decoder models like BERT2BERT, mBART)
- Image captioning (visual encoder + text decoder)
- Sequence-to-sequence models (speech-to-text, abstractive summarization)
- Auto-regressive text generation (with causal masking)
- Question answering (encode question, decode answer)
- Cross-attention to memory: Each decoder layer attends to the encoder output (memory). Memory has shape [src_seq_len, batch, d_model]; src_seq_len is typically different from the target sequence length.
- Causal masking: When generating auto-regressively, use a causal mask on self-attention to prevent attending to future tokens. The same mask is required during teacher-forced training; without it, each position can see the ground-truth tokens it is supposed to predict.
- Final layer normalization: Including norm (typically LayerNorm(d_model)) stabilizes outputs. Highly recommended, especially for deep decoders.
- Target and memory shapes: Target has shape [tgt_seq_len, batch, d_model], memory has shape [src_seq_len, batch, d_model]. Sequences can be different lengths!
- Depth vs encoder: The decoder often has the same depth as the encoder (both 6 layers in the base Transformer), but the depths can differ. Deeper decoders (12+ layers) are used in large language models.
- Computational complexity: O(num_layers × tgt_seq_len × (tgt_seq_len + src_seq_len) × d_model). Depends on both target and source lengths.
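The complexity estimate above can be checked with a back-of-the-envelope count in plain JavaScript (decoderAttentionCost is a hypothetical helper for illustration, not part of this library):

```javascript
// Rough per-forward-pass cost, following the
// O(num_layers × tgt_len × (tgt_len + src_len) × d_model) estimate above.
function decoderAttentionCost(numLayers, tgtLen, srcLen, dModel) {
  const selfAttn = tgtLen * tgtLen * dModel;  // self-attention over the target
  const crossAttn = tgtLen * srcLen * dModel; // cross-attention to memory
  return numLayers * (selfAttn + crossAttn);
}

// Base-Transformer-like setting: 6 layers, tgt_len 20, src_len 10, d_model 512
console.log(decoderAttentionCost(6, 20, 10, 512)); // 6 * (400 + 200) * 512 = 1843200
```

Note the asymmetry: self-attention grows quadratically in target length, while cross-attention grows only linearly in each of the two lengths.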
- Causal mask MUST be applied to self-attention (not cross-attention). Cross-attention can freely attend to memory.
- Memory should come from the encoder the decoder was trained with; feeding memory from a different encoder will significantly degrade results.
- During auto-regressive inference, re-run the decoder on the growing target sequence at each step (KV caching is not implemented).
- Ensure tgt has shape [tgt_seq_len, batch, d_model] and memory has shape [src_seq_len, batch, d_model].
- Padding masks should indicate which positions are padding (True = ignore), not which are valid.
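The two mask conventions above (additive causal mask on self-attention, boolean padding mask) can be illustrated without any tensor library; causalMask and paddingMask below are illustrative helpers, not this library's API:

```javascript
// Additive causal mask: 0 where attention is allowed, -Infinity where it is
// not, so masked scores vanish after softmax. Row i may attend to columns j <= i.
function causalMask(size) {
  return Array.from({ length: size }, (_, i) =>
    Array.from({ length: size }, (_, j) => (j <= i ? 0 : -Infinity))
  );
}

// Boolean padding mask: true = padding position to ignore, false = valid token.
function paddingMask(tokenIds, padId) {
  return tokenIds.map((id) => id === padId);
}

const m = causalMask(3);
console.log(m[0]); // [0, -Infinity, -Infinity] — position 0 sees only itself
console.log(m[2]); // [0, 0, 0] — the last position sees everything before it
console.log(paddingMask([5, 7, 0, 0], 0)); // [false, false, true, true]
```

This matches the semantics of torch.nn.Transformer.generate_square_subsequent_mask used in the examples below, and the True-means-ignore convention for padding masks.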
Examples
// Create decoder with 6 layers
const decoder_layer = new torch.nn.TransformerDecoderLayer({
d_model: 512,
nhead: 8,
dim_feedforward: 2048,
dropout: 0.1
});
const decoder = new torch.nn.TransformerDecoder({
decoder_layer,
num_layers: 6,
norm: new torch.nn.LayerNorm(512)
});
// Get memory from encoder
const src = torch.randn(10, 32, 512); // source sequence [src_seq_len, batch, d_model]
const memory = encoder.encode(src); // [src_seq_len, batch, d_model]
// Decode from target to output
const tgt = torch.randn(20, 32, 512); // target sequence [tgt_seq_len, batch, d_model]
const output = decoder.decode(tgt, memory);
// With causal mask for auto-regressive generation
const tgt_mask = torch.nn.Transformer.generate_square_subsequent_mask(20);
const masked_output = decoder.decode(tgt, memory, { tgt_mask });
// Full encoder-decoder pipeline
const encoded = encoder.encode(src);
const decoded = decoder.decode(tgt, encoded);
const predictions = output_projection(decoded); // project to vocab
// During inference: generate one token at a time
let generated = [bos_token_id];
for (let step = 0; step < max_length; step++) {
const tgt_emb = embedding(torch.tensor(generated)); // [seq_len, d_model]; add a batch dim for the decoder
const tgt_mask = torch.nn.Transformer.generate_square_subsequent_mask(generated.length);
const out = decoder.decode(tgt_emb, memory, { tgt_mask });
const logits = output_proj(out[-1]); // logits at the last target position (assumes negative indexing on the seq dim)
const next_token = torch.argmax(logits).item();
generated.push(next_token);
if (next_token === eos_token_id) break;
}
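The loop above, reduced to a runnable sketch with the model stubbed out (fakeLogits is a made-up stand-in for embedding + decoder + projection), shows the growing-prefix pattern:

```javascript
// Runnable sketch of greedy auto-regressive decoding with a stubbed model.
// fakeLogits prefers token (last + 1) and reaches EOS after token 3.
const BOS = 0, EOS = 4, VOCAB = 5, MAX_LENGTH = 10;

function fakeLogits(generated) {
  const last = generated[generated.length - 1];
  const next = Math.min(last + 1, EOS);
  return Array.from({ length: VOCAB }, (_, t) => (t === next ? 1 : 0));
}

function argmax(xs) {
  return xs.reduce((best, x, i) => (x > xs[best] ? i : best), 0);
}

const generatedIds = [BOS];
for (let step = 0; step < MAX_LENGTH; step++) {
  // Each step re-runs the "model" on the full, growing prefix (no KV cache),
  // just as the real decoder is re-run on the growing target sequence.
  const next = argmax(fakeLogits(generatedIds));
  generatedIds.push(next);
  if (next === EOS) break;
}
console.log(generatedIds); // [0, 1, 2, 3, 4]
```

Because there is no KV cache, each step costs attention over the whole prefix, so generating n tokens costs O(n²) self-attention steps in total.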