torch.nn.EmbeddingBag
class EmbeddingBag extends Module

new EmbeddingBag(num_embeddings: number, embedding_dim: number, options?: EmbeddingBagOptions)

Properties:
- num_embeddings (number) - readonly
- embedding_dim (number) - readonly
- max_norm (number | null) - readonly
- norm_type (number) - readonly
- scale_grad_by_freq (boolean) - readonly
- mode ('sum' | 'mean' | 'max') - readonly
- sparse (boolean) - readonly
- include_last_offset (boolean) - readonly
- padding_idx (number | null) - readonly
- weight (Parameter)
Embedding lookup with aggregation: efficiently fuses embedding lookup and reduction in a single operation.
Combines an Embedding layer with aggregation (sum/mean/max) over each sequence, avoiding a separate pooling step. Highly efficient for:
- Text classification (sum word embeddings, then classify)
- Recommendation systems (aggregate item embeddings, mean pooling)
- Collaborative filtering (user/item embedding aggregation)
- Bag-of-words representations (word → embedding → mean)
- Document-level representations (all word embeddings → aggregated)
- Continuous bag-of-words (CBOW) models
Similar to Embedding but automatically aggregates embeddings within each "bag" (sequence), computing a single output vector per bag by summing, averaging, or max-pooling embeddings. More memory-efficient and faster than separate Embedding + reduction operations.
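To make the aggregation concrete, here is a plain-TypeScript reference sketch of what a bag reduction computes (no torch dependency; `embeddingBagRef` is a hypothetical name for illustration, not this library's API):

```typescript
// Reference semantics of EmbeddingBag: look up each index in a weight
// table, then reduce each bag to a single vector by sum, mean, or max.
type Mode = 'sum' | 'mean' | 'max';

function embeddingBagRef(
  weight: number[][], // [num_embeddings, embedding_dim]
  bags: number[][],   // each row is one bag of indices
  mode: Mode
): number[][] {
  const dim = weight[0].length;
  return bags.map((bag) => {
    const out = new Array(dim).fill(mode === 'max' ? -Infinity : 0);
    for (const idx of bag) {
      for (let d = 0; d < dim; d++) {
        out[d] = mode === 'max' ? Math.max(out[d], weight[idx][d]) : out[d] + weight[idx][d];
      }
    }
    if (mode === 'mean') {
      for (let d = 0; d < dim; d++) out[d] /= bag.length;
    }
    return out;
  });
}

// Two bags over a tiny 4x2 embedding table
const weight = [[1, 0], [0, 1], [2, 2], [3, -1]];
const pooled = embeddingBagRef(weight, [[0, 1], [2, 3]], 'mean');
// pooled[0] = [0.5, 0.5], pooled[1] = [2.5, 0.5]
```

The fused EmbeddingBag kernel computes exactly this, but never materializes the intermediate per-token embedding tensor.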
- Memory efficiency: More efficient than separate Embedding + reduction operations
- Speed advantage: Single optimized kernel faster than chained operations
- Mode selection: sum preserves frequency/count information, mean gives a stable average, max extracts the strongest per-feature signal
- Bag interpretation: "Bag" = sequence of tokens within an item (document, user history, etc.)
- Variable-length bags: Use offsets parameter for bags of different sizes
- Padding handling: Padding tokens contribute to the aggregation by default; set padding_idx (or mask inputs) to exclude them
- Gradient flow: Gradients only flow to used embeddings (efficient sparse updates)
- 2D vs 1D: 2D input simpler (each row is bag), 1D with offsets better for variable lengths
- Per-sample weights: Optional weights can weight each embedding differently within bag
- Computational complexity: O(total_indices × embedding_dim) regardless of aggregation mode
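The per-sample weights note above can be sketched in plain TypeScript (`weightedSumBag` is a hypothetical illustration; whether this binding's forward accepts a per-sample-weights option, as PyTorch's `per_sample_weights` does, is an assumption):

```typescript
// Weighted 'sum' aggregation: each embedding in a bag is scaled by its
// own scalar weight before summing (e.g. TF-IDF scores per token).
// Plain-TypeScript reference, not this library's API.
function weightedSumBag(
  weight: number[][],        // [num_embeddings, embedding_dim]
  bag: number[],             // indices in one bag
  perSampleWeights: number[] // one scalar per index in the bag
): number[] {
  const dim = weight[0].length;
  const out = new Array(dim).fill(0);
  bag.forEach((idx, i) => {
    for (let d = 0; d < dim; d++) out[d] += perSampleWeights[i] * weight[idx][d];
  });
  return out;
}

const table = [[1, 0], [0, 1]];
const v = weightedSumBag(table, [0, 1], [2, 3]); // [2, 3]
```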
Examples
// Text classification with EmbeddingBag
// Simpler and faster than Embedding + mean pooling
const vocab_size = 10000;
const embed_dim = 128;
const embedding_bag = new torch.nn.EmbeddingBag(vocab_size, embed_dim, { mode: 'mean' });
// Input: sequences of token IDs as bags
const token_sequences = torch.tensor(
[[101, 2054, 2003, 102], // Sentence 1: "what is"
[101, 2234, 3431, 102], // Sentence 2: "good day"
[101, 5000, 102, 0]], // Sentence 3: "hello" (padded)
{ dtype: 'int32' }
);
// Get aggregated (mean-pooled) embeddings per sentence
const sentence_embeddings = embedding_bag.forward(token_sequences); // [3, 128]
// Feed to classifier
const classifier = new torch.nn.Linear(128, 10); // 10 classes
const logits = classifier.forward(sentence_embeddings); // [3, 10]

// Recommendation system: aggregating item embeddings
const num_items = 50000;
const embed_dim = 64;
const item_embedding_bag = new torch.nn.EmbeddingBag(num_items, embed_dim, { mode: 'mean' });
// User interaction history as bags
const user_histories = torch.tensor(
[[1023, 5432, 8901, 0, 0], // User 1: 3 items
[234, 567, 0, 0, 0], // User 2: 2 items (padded)
[9000, 1111, 2222, 3333, 4444]], // User 3: 5 items
{ dtype: 'int32' }
);
// Get user representations (mean of item embeddings)
const user_embeddings = item_embedding_bag.forward(user_histories); // [3, 64]
// Now can use for similarity, recommendation, etc.
const next_item_embedding = torch.randn([64]); // Candidate item
const scores = user_embeddings.matmul(next_item_embedding.unsqueeze(0).t()); // [3, 1]

// Different aggregation modes
const vocab = 1000;
const dim = 50;
// Mean aggregation (most common for averaging)
const mean_bag = new torch.nn.EmbeddingBag(vocab, dim, { mode: 'mean' });
// Sum aggregation (preserves magnitude, better for count-based features)
const sum_bag = new torch.nn.EmbeddingBag(vocab, dim, { mode: 'sum' });
// Max aggregation (preserves strong signals)
const max_bag = new torch.nn.EmbeddingBag(vocab, dim, { mode: 'max' });
const bags = torch.tensor(
[[1, 2, 3],
[4, 5, 6]],
{ dtype: 'int32' }
);
const mean_out = mean_bag.forward(bags); // [2, 50] - averages per bag
const sum_out = sum_bag.forward(bags); // [2, 50] - sums per bag
const max_out = max_bag.forward(bags); // [2, 50] - max per bag
// Mode choice affects downstream learning:
// sum: good for frequency/count information
// mean: stable, interpretable as average feature
// max: extracts strongest signal per feature

// Continuous Bag-of-Words (CBOW) with EmbeddingBag
const vocab_size = 10000;
const embed_dim = 100;
const embedding_bag = new torch.nn.EmbeddingBag(vocab_size, embed_dim, { mode: 'mean' });
// Context words (bag of surrounding words)
const context_words = torch.tensor(
[[95, 194, 325], // Context for first position
[194, 325, 472]], // Context for second position
{ dtype: 'int32' }
);
// Get context representation
const context_reps = embedding_bag.forward(context_words); // [2, 100]
// Predict target word from context
const word_embedding = new torch.nn.Embedding(vocab_size, embed_dim);
const context_scores = context_reps.matmul(word_embedding.weight.t()); // [2, vocab]

// Using 1D input with offsets (for variable-length bags)
const vocab = 1000;
const dim = 32;
const embedding_bag = new torch.nn.EmbeddingBag(vocab, dim, { mode: 'mean' });
// Flat array of all indices
const all_indices = torch.tensor(
[1, 2, 3, 4, 5, 6, 7], // All indices from all bags
{ dtype: 'int32' }
);
// Offsets marking where each bag starts
// Offsets marking where each bag starts; the last bag runs to the end
const offsets = torch.tensor([0, 3], { dtype: 'int32' }); // 2 bags: indices [0:3] and [3:7]
const bags = embedding_bag.forward(all_indices, { offsets }); // [2, 32]
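For reference, the offsets bookkeeping can be sketched in plain TypeScript. This assumes the usual semantics: by default each offset marks where a bag starts and the final bag runs to the end of the index array, while with include_last_offset the last offset instead marks the end of the final bag (`splitBags` is a hypothetical helper, not this library's API):

```typescript
// Split a flat index array into bags according to offsets.
function splitBags(
  indices: number[],
  offsets: number[],
  includeLastOffset = false
): number[][] {
  // With include_last_offset, the final offset is an end marker, not a start.
  const starts = includeLastOffset ? offsets.slice(0, -1) : offsets;
  const ends = includeLastOffset
    ? offsets.slice(1)
    : [...offsets.slice(1), indices.length];
  return starts.map((s, i) => indices.slice(s, ends[i]));
}

const flat = [1, 2, 3, 4, 5, 6, 7];
splitBags(flat, [0, 3]);          // [[1, 2, 3], [4, 5, 6, 7]] — 2 bags
splitBags(flat, [0, 3, 7], true); // same 2 bags, end marker explicit
```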