torch.nn.Embedding
new Embedding(num_embeddings: number, embedding_dim: number, options?: EmbeddingOptions)
num_embeddings (number) - readonly
embedding_dim (number) - readonly
padding_idx (number | null) - readonly
max_norm (number | null) - readonly
norm_type (number) - readonly
scale_grad_by_freq (boolean) - readonly
sparse (boolean) - readonly
weight (Parameter)
Learnable embedding lookup table: converts token indices to dense vectors.
A fundamental layer that transforms discrete token IDs into continuous vector representations. Essential for:
- Natural language processing (word embeddings for any vocabulary)
- Sequence modeling (tokens → dense representations → transformers)
- Recommendation systems (item/user IDs → embedding vectors)
- Graph neural networks (node IDs → embeddings)
- Categorical feature embeddings in deep learning
- Position embeddings in attention-based models (combined with positional encoding)
Creates a weight matrix of shape [num_embeddings, embedding_dim] where each row is the learnable embedding vector for a vocabulary item. During forward pass, looks up embeddings for given token indices, enabling end-to-end training with gradient flow back to embedding weights.
- Vocabulary size matters: Large vocabularies need large embedding matrices (memory-intensive)
- Embedding dimension selection: Typical range: 64-768 for general tasks, 768-2048 for BERT-scale
- Initialization: Random normal initialization (N(0,1)) by default
- Padding token: Set padding_idx to have zero embedding vector (important for masking)
- Frozen embeddings: Use from_pretrained with freeze=true for transfer learning
- Memory usage: Parameters = num_embeddings × embedding_dim × 4 bytes (float32)
- Sparse vs dense gradients: Sparse updates only for accessed embeddings (not implemented)
- GPU acceleration: Embedding lookup is efficient on GPU (parallel table lookup)
- Typical architecture pattern: Token IDs → Embedding → Transformer/LSTM/CNN → Output
- Combined with position encoding: Position embeddings added/concatenated to token embeddings
Examples
// Natural language processing: word embedding lookup
const vocab_size = 10000; // Dictionary with 10k words
const embed_dim = 300; // GloVe-like embedding dimension
const embed = new torch.nn.Embedding(vocab_size, embed_dim);
// Token IDs (e.g., from tokenizer); a 1-D input of length 4 yields a [4, 300] result
const token_ids = torch.tensor([2, 145, 8, 9999], { dtype: 'int32' });
// [batch, seq_len] shape
const token_ids_batch = torch.tensor(
  [[101, 2054, 2003, 1045, 102], // "what is i"
   [101, 2234, 3431, 102, 0]], // "good day" with padding
  { dtype: 'int32' }
);
const embeddings = embed.forward(token_ids_batch); // [2, 5, 300]

// Transformer-style architecture with word + position embeddings
class PositionalEmbedding extends torch.nn.Module {
  word_embed: torch.nn.Embedding;
  pos_embed: torch.nn.Embedding;
  max_seq_len: number = 512;
  constructor(vocab_size: number, embed_dim: number) {
    super();
    this.word_embed = new torch.nn.Embedding(vocab_size, embed_dim);
    this.pos_embed = new torch.nn.Embedding(this.max_seq_len, embed_dim);
  }
  forward(token_ids: torch.Tensor): torch.Tensor {
    // NOTE: shape[-1] is undefined on JS/TS arrays — index from the end explicitly
    const seq_len = token_ids.shape[token_ids.shape.length - 1];
    // Create position IDs [0, 1, ..., seq_len - 1]
    const pos_ids = torch.arange(0, seq_len, { dtype: 'int32' });
    // Expand for batch dimension when input is [B, L]; declared in this scope
    // (a const inside an if-block would not be visible to the forward calls below)
    const pos_ids_batch =
      token_ids.shape.length > 1
        ? pos_ids.unsqueeze(0).expand_as(token_ids)
        : pos_ids;
    // Combine word and position embeddings
    const word_emb = this.word_embed.forward(token_ids); // [B, L, D]
    const pos_emb = this.pos_embed.forward(pos_ids_batch); // [B, L, D]
    return word_emb.add(pos_emb); // [B, L, D] - summed embeddings
  }
}

// Recommendation system: item embeddings
const num_items = 100000; // 100k item catalog
const embed_dim = 128; // Embedding dimension
const item_embed = new torch.nn.Embedding(num_items, embed_dim);
// User interactions as item IDs (0 used here as padding)
const user_history = torch.tensor(
  [[123, 456, 789, 0], // User 1's item history
   [234, 567, 0, 0]], // User 2's item history (padded)
  { dtype: 'int32' }
);
// Get embeddings for items
const history_embeddings = item_embed.forward(user_history); // [2, 4, 128]
// Now can aggregate embeddings (mean, attention, etc.) for user representation
const user_reps = history_embeddings.mean(1); // [2, 128]

// Using pretrained embeddings (e.g., GloVe, FastText)
const pretrained_weights = torch.randn([10000, 300]); // Pre-computed embeddings
const embed = torch.nn.Embedding.from_pretrained(pretrained_weights, { freeze: true });
// freeze: true already disables gradient updates; the explicit assignment
// below is redundant and shown only to make the frozen state visible
embed.weight.requires_grad = false;
const token_ids = torch.tensor([1, 2, 3, 4], { dtype: 'int32' });
const embeddings = embed.forward(token_ids); // Uses fixed pretrained values
// Or allow fine-tuning
const embed_finetune = torch.nn.Embedding.from_pretrained(pretrained_weights, { freeze: false });
const embeddings2 = embed_finetune.forward(token_ids); // Can be updated via backprop

// Padding token handling
const PAD_ID = 0;
const embed = new torch.nn.Embedding(
  5000,
  128,
  { padding_idx: PAD_ID } // Padding token embedding always zero
);
// Variable length sequences with padding
const sequences = torch.tensor(
  [[1, 2, 3, 0, 0], // Length 3, then padded
   [4, 5, 0, 0, 0], // Length 2, then padded
   [6, 7, 8, 9, 10]], // Length 5, no padding
  { dtype: 'int32' }
);
const embeddings = embed.forward(sequences); // [3, 5, 128]
// Padding token embeddings are guaranteed to be zero vectors

// Gradient flow through embeddings
// Demonstrates how gradients reach the embedding table
const tiny_embed = new torch.nn.Embedding(100, 50);
const lookup_ids = torch.tensor([1, 2, 3], { dtype: 'int32' });
const lookup_out = tiny_embed.forward(lookup_ids); // [3, 50]
// Gradients flow back to embedding weights
const dummy_loss = lookup_out.sum(); // Dummy scalar loss
// dummy_loss.backward(); // Would populate gradients for tiny_embed.weight
// Only accessed embeddings get gradient updates (sparse gradients concept)