torch.nn.functional.embedding
function embedding(input: Tensor, weight: Tensor, options?: EmbeddingFunctionalOptions): Tensor

Generate embeddings by looking up indices in a weight matrix (dense vector representation table).
Performs index-based lookup in a learnable embedding matrix, converting discrete tokens/IDs into continuous dense vector representations. Essential operation for all sequence models (NLP, recommendation systems). Given a tensor of indices, returns the corresponding embedding vectors stacked together. Used extensively for:
- Word/token embeddings in NLP (BERT, GPT, transformers)
- Positional embeddings (encoding position in sequence)
- Character/subword embeddings (BPE, SentencePiece)
- Categorical feature embeddings (recommendation systems, tabular models)
- Entity embeddings (knowledge graphs, link prediction)
- Learned representations for discrete objects (items, users, tags)
How it Works: Embedding is a simple dictionary/lookup table. Given index i, return embeddings[i]. Index 0 returns first row of embedding matrix, index 1 returns second row, etc. Highly efficient on GPUs. Supports batching and multi-dimensional index tensors (looks up multiple indices simultaneously).
Padding Index: Special feature where a designated index (usually 0) always embeds to zeros. Used for padding tokens in variable-length sequences - ensures padding contributes zero to model.
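The lookup and padding behavior described above amounts to a row lookup plus an optional zeroed padding row. A minimal sketch in plain JavaScript (illustrative helper names, not this library's API):

```javascript
// Minimal sketch of embedding lookup semantics (illustrative only, not the library API).
// weight: array of rows, shape [numEmbeddings][embeddingDim]; indices: integer array.
function embed(indices, weight, paddingIdx = null) {
  const dim = weight[0].length;
  return indices.map((i) => {
    if (!Number.isInteger(i) || i < 0 || i >= weight.length) {
      throw new RangeError(`index ${i} out of range [0, ${weight.length - 1}]`);
    }
    // The designated padding index always maps to the zero vector.
    return i === paddingIdx ? new Array(dim).fill(0) : weight[i];
  });
}

const weight = [[1, 1], [2, 2], [3, 3]]; // 3 embeddings, dim 2
const out = embed([2, 0, 1], weight, 0); // index 0 is padding
// out → [[3, 3], [0, 0], [2, 2]]
```

The real operation does the same thing per index element, just vectorized over arbitrary index-tensor shapes on the GPU.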
- Efficient lookup table: Embedding is just a table lookup - one of the fastest operations on GPUs. Much faster than learned dense transformation (linear layer). Perfect for large vocabularies.
- Indices must be integers: Input must contain integer values (token IDs). Fractional indices are undefined.
- Padding index stays zero: If padding_idx is set, that row of embedding matrix is never updated during training. Ensures padding tokens contribute nothing to model outputs.
- Max norm normalization: Constrains embedding magnitudes, useful to prevent unbounded growth or for numerical stability. Typically used with max_norm=1 or 2 in recommendation systems.
- Gradient computation: If padding_idx is set, gradients for that embedding row are zero (not updated). All other embeddings receive gradients and are learned during training.
- Commonly learned: Embedding matrices are typically initialized randomly and learned end-to-end during training. Pre-trained embeddings (Word2Vec, GloVe, FastText) can be used to initialize and optionally frozen.
- Output shape pattern: Output shape is always input.shape + [embedding_dim]. Each index element is replaced by its embedding vector, so this rule holds for any input shape (1D, 2D, 3D, etc.).
- Memory efficient: The embedding matrix stores one row per vocabulary entry ([num_embeddings, embedding_dim]), and lookup never materializes one-hot vectors. For large vocabularies (50K+ tokens) this is far cheaper than feeding one-hot inputs through a linear layer, even though the two are mathematically equivalent.
- Index out of bounds: Indices must be in range [0, num_embeddings-1]. Out-of-bounds indices cause runtime errors. Always validate indices match vocabulary size.
- Negative indices: Negative indices are not supported (unlike NumPy). Remap or offset IDs so they fall in [0, num_embeddings-1] before lookup.
- Type mismatch: Input must contain integer values. Behavior for floating-point index tensors is undefined; cast token IDs to an integer dtype before lookup.
- Not differentiable w.r.t. indices: Embeddings are differentiable w.r.t. the embedding matrix (weight), but NOT w.r.t. the indices themselves (indices are discrete). Can't optimize which index to use with gradient descent.
- Large embeddings memory: With large vocab (millions) and high dims (thousands), embedding matrix can consume significant GPU memory. Consider quantization or factorization for extreme scales.
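Two of the notes above, the output-shape rule and max-norm renormalization, can be sketched as small helpers in plain JavaScript (illustrative names, not this library's API; max_norm is assumed here to use the L2 norm, matching PyTorch's default norm_type=2):

```javascript
// Sketch: the output shape of an embedding lookup is the index tensor's
// shape with embedding_dim appended (illustrative helpers, not library API).
function embeddingOutputShape(inputShape, embeddingDim) {
  return [...inputShape, embeddingDim];
}

// Sketch of max_norm: rescale a row only when its L2 norm exceeds maxNorm.
function renormRow(row, maxNorm) {
  const norm = Math.sqrt(row.reduce((s, x) => s + x * x, 0));
  return norm > maxNorm ? row.map((x) => (x * maxNorm) / norm) : row;
}

embeddingOutputShape([32, 50], 300); // → [32, 50, 300]
renormRow([3, 4], 1);                // norm is 5 → rescaled to [0.6, 0.8]
```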
Parameters
input: Tensor – Tensor of integer indices (typically 1D, or 2D for batching). Each element must be in the range [0, num_embeddings-1]. Any shape is allowed; the typical NLP shape is [batch_size, seq_length].
weight: Tensor – Pre-trained or randomly initialized embedding matrix of shape [num_embeddings, embedding_dim], where num_embeddings is the vocabulary/dictionary size and embedding_dim is the desired output vector dimension (typically 64-1024).
options: EmbeddingFunctionalOptions (optional)
Returns
Tensor – Embedded tensor where each index is replaced with its corresponding embedding vector. Output shape: input.shape + [embedding_dim]. E.g., input [batch=32, seq=50] → output [32, 50, embedding_dim].

Examples
// Basic NLP embedding: 5 tokens, 300-dim word vectors
const vocab_size = 5;
const embedding_dim = 300;
const embedding_weight = torch.randn(vocab_size, embedding_dim); // Learned embeddings
const input_ids = torch.tensor([1, 2, 4, 0, 3]); // Sequence of 5 tokens
const embeddings = torch.nn.functional.embedding(input_ids, embedding_weight); // [5, 300]
// Batched sequence embedding (typical in NLP)
const batch_size = 32;
const seq_length = 128;
const vocab_size = 50000; // Large vocabulary (like GPT)
const embed_dim = 768; // Hidden dimension (like BERT)
const embed_matrix = torch.randn(vocab_size, embed_dim);
const token_ids = torch.randint(0, vocab_size, [batch_size, seq_length]); // [32, 128] integer token IDs
const token_embeddings = torch.nn.functional.embedding(token_ids, embed_matrix); // [32, 128, 768]
// With padding token (padding_idx=0, so embedding index 0 stays zero)
const vocab_size = 1000;
const embed_dim = 300;
const embed_matrix = torch.randn(vocab_size, embed_dim);
const padding_idx = 0; // Index 0 reserved for padding token
const input_ids = torch.tensor([5, 10, 0, 15, 0, 20]); // Sequence with padding (0's)
const embedded = torch.nn.functional.embedding(input_ids, embed_matrix, { paddingIdx: padding_idx }); // padding index goes in the options argument
// Result: embedding[5], embedding[10], [0,0,...,0], embedding[15], [0,0,...,0], embedding[20]
// Multi-dimensional input (e.g., batch of sequences of characters)
const char_vocab = 256; // ASCII characters
const char_embed_dim = 16;
const char_embeddings = torch.randn(char_vocab, char_embed_dim);
const char_ids = torch.randint(0, char_vocab, [4, 10, 8]); // [batch=4, seq=10, word_length=8] integer char IDs
const char_embedded = torch.nn.functional.embedding(char_ids, char_embeddings); // [4, 10, 8, 16]
// Great for character-level models or subword tokenization
// Positional embeddings in transformers
const seq_length = 512;
const embed_dim = 768;
const position_embedding = torch.randn(seq_length, embed_dim); // Learned position embeddings
const positions = torch.arange(seq_length); // [0, 1, 2, ..., 511]
const pos_embed = torch.nn.functional.embedding(positions, position_embedding); // [512, 768]

See Also
- PyTorch torch.nn.functional.embedding
- torch.nn.Embedding - Stateful class-based version (wraps embedding matrix)
- torch.nn.EmbeddingBag - Aggregate embeddings (mean/sum reduction)
- embedding_bag - Functional version of embedding aggregation
- gather - Generic index-based lookup (not limited to last dimension)
- scaled_dot_product_attention - Uses embeddings in attention mechanisms