torch.optim.SparseAdam
class SparseAdam extends Optimizer
new SparseAdam(params: Tensor[] | Iterable<Tensor>, options: SparseAdamOptions = {})
SparseAdam optimizer: an Adam variant specialized for parameters with sparse gradients.
SparseAdam is a modification of Adam designed for parameters with sparse gradients, where only a small fraction of positions have a nonzero gradient at each step. The classic example is an embedding layer in NLP, where only the few words that appear in a given batch receive gradients.
Key Optimization:
- Only updates state (first and second moments) for positions with nonzero gradients
- Positions with zero gradient don't affect state (m_t, v_t unchanged)
- This is crucial for embedding layers where each batch uses different embeddings
- Dramatically reduces computation and memory access for sparse scenarios
Motivation: With standard Adam, computing m_t and v_t for every position of a million-row embedding table is wasteful when only ~1000 rows are actually used in a batch. SparseAdam only touches the used positions, providing a massive speedup for sparse scenarios, as the sketch below illustrates.
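Conceptually, the update is just the standard Adam step restricted to rows whose gradient has at least one nonzero entry; all other rows keep their parameters and moment estimates untouched. The sketch below illustrates this on plain arrays rather than the library's Tensor type; the function name and signature are illustrative only, not the actual implementation.

// Illustrative sketch of the SparseAdam idea on plain number[][] rows.
// Rows with an all-zero gradient are skipped: m, v, and the parameters stay unchanged.
type Row = number[];

function sparseAdamStep(
  param: Row[], grad: Row[], m: Row[], v: Row[],
  t: number, lr = 1e-3, beta1 = 0.9, beta2 = 0.999, eps = 1e-8
): void {
  for (let i = 0; i < param.length; i++) {
    if (grad[i].every((g) => g === 0)) continue; // untouched row: state not updated
    for (let j = 0; j < param[i].length; j++) {
      const g = grad[i][j];
      m[i][j] = beta1 * m[i][j] + (1 - beta1) * g;
      v[i][j] = beta2 * v[i][j] + (1 - beta2) * g * g;
      const mHat = m[i][j] / (1 - Math.pow(beta1, t)); // bias-corrected first moment
      const vHat = v[i][j] / (1 - Math.pow(beta2, t)); // bias-corrected second moment
      param[i][j] -= (lr * mHat) / (Math.sqrt(vHat) + eps);
    }
  }
}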
When to use SparseAdam:
- Embedding layers (NLP, recommendation systems, graphs)
- Any layer where gradient is sparse (most positions are zero)
- Large embedding tables where sparse access is critical for performance
- Recommendation systems with massive item embeddings
- Graph neural networks with node embeddings
- Any sparse feature scenario
Trade-offs:
- Only works with sparse gradient patterns (don't use with dense layers)
- Requires manually specifying which params are sparse
- Dense layers should use standard Adam (SparseAdam adds masking overhead)
- Mixing sparse and dense parameters needs careful setup (a two-optimizer pattern is sketched below)
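A common way to handle mixed models is to give the embedding parameters to SparseAdam and the dense parameters to a standard Adam instance, then drive both optimizers each step. The sketch below assumes a torch.nn.Linear dense head and placeholder sizes purely for illustration; it is a pattern, not a prescribed API.

// Hypothetical two-optimizer setup: sparse embedding -> SparseAdam, dense head -> Adam.
const vocab_size = 50_000;
const embedding_dim = 128;
const num_classes = 10;
const embedding = new torch.nn.Embedding(vocab_size, embedding_dim);
const dense = new torch.nn.Linear(embedding_dim, num_classes); // assumed dense layer

const sparse_opt = new torch.optim.SparseAdam(embedding.parameters(), { lr: 1e-3 });
const dense_opt = new torch.optim.Adam(dense.parameters(), { lr: 1e-3 });

// Each training step: zero both optimizers, backprop once, then step both.
sparse_opt.zero_grad();
dense_opt.zero_grad();
// loss.backward();   // compute gradients for all parameters
sparse_opt.step();    // updates only embedding rows that received gradients
dense_opt.step();     // updates the dense head as usual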
Implementation Note: The current implementation works with dense tensors but only updates state where gradients are nonzero. For the most performance-critical cases, true sparse tensor support would be faster still, but the dense-with-masking approach handles the vast majority of use cases efficiently.
- Sparse gradients only: Only beneficial if gradients are actually sparse (mostly zero); a quick way to check is sketched after this list.
- Embedding layers: Primary use case - word embeddings, item embeddings, node embeddings.
- Performance critical: Can provide 10-100x speedup for large embedding tables.
- Don't mix: Use SparseAdam for sparse layers, Adam for dense layers.
- Default lr: 1e-3 same as Adam, usually doesn't need adjustment.
- Dense implementation: The current implementation uses dense tensors with masking.
- True sparse tensors: Would be faster still, but dense tensors with masking are sufficient for most workloads.
- Batch size effects: Relative savings are larger with smaller batches, since fewer positions are touched per step.
- Embedding specialization: SparseAdam designed specifically for embedding patterns.
- Not for general use: Only use if you have actual sparse gradient patterns.
- Hyperparameter tuning: Same as Adam - lr, betas rarely need adjustment.
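If it is unclear whether a layer's gradients are sparse enough to benefit, one rough check (referenced in the notes above) is to measure the fraction of rows that receive any nonzero gradient over a few batches. The helper below is a hypothetical illustration on plain arrays, and the threshold is arbitrary.

// Fraction of rows in a [rows x dims] gradient that contain any nonzero entry.
function nonzeroRowFraction(grad: number[][]): number {
  if (grad.length === 0) return 0;
  const nonzeroRows = grad.filter((row) => row.some((g) => g !== 0)).length;
  return nonzeroRows / grad.length;
}

// Rule of thumb (arbitrary threshold): if only a few percent of rows are touched
// per batch, SparseAdam is likely a win; otherwise prefer standard Adam.
// e.g. nonzeroRowFraction(embeddingGrad) < 0.05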
Examples
// SparseAdam for embedding layers (common in NLP)
const embedding = new torch.nn.Embedding(vocab_size, embedding_dim);
const sparse_adam = new torch.optim.SparseAdam(embedding.parameters(), { lr: 1e-3 });
for (const batch of train_loader) {
  // batch.token_ids: [batch_size, seq_len] - only ~batch_size*seq_len words used
  const embeddings = embedding.forward(batch.token_ids);
  const loss = model.loss(embeddings, batch.y);
  sparse_adam.zero_grad();
  loss.backward();
  sparse_adam.step(); // Only updates embeddings that appeared in this batch
}

// Recommendation system with sparse item embeddings
const item_embeddings = new torch.nn.Embedding(num_items, embedding_dim);
const user_embeddings = new torch.nn.Embedding(num_users, embedding_dim);
const sparse_adam = new torch.optim.SparseAdam([
  ...item_embeddings.parameters(),
  ...user_embeddings.parameters()
], { lr: 1e-3 });
// With millions of items, only ~100 appear per batch
// SparseAdam only updates those ~100, not all of the millions

// Graph neural network with node embeddings
const node_embeddings = new torch.nn.Embedding(num_nodes, embedding_dim);
const sparse_adam = new torch.optim.SparseAdam(
  node_embeddings.parameters(),
  { lr: 1e-3 }
);
// In subgraph batching, only ~1000 nodes per batch even with millions total
// SparseAdam handles this efficiently

// Multi-task learning with sparse task embeddings
const task_embeddings = new torch.nn.Embedding(num_tasks, task_dim);
const sparse_adam = new torch.optim.SparseAdam(
  task_embeddings.parameters(),
  { lr: 2e-3, betas: [0.9, 0.999] }
);
// Different tasks in each batch mean sparse gradient patterns

// Comparison: SparseAdam vs standard Adam for embeddings
const sparse_adam = new torch.optim.SparseAdam(
  embedding.parameters(), { lr: 1e-3 }
);
const adam = new torch.optim.Adam(
  embedding.parameters(), { lr: 1e-3 }
);
// SparseAdam: only updates used embeddings, ~100x faster
// Adam: updates all million embeddings every step, very slow
// Use SparseAdam for embedding layers, Adam for everything else