torch.nn.init.sparse_
function sparse_(tensor: Tensor, sparsity: number, options?: SparseOptions): Tensor
function sparse_(tensor: Tensor, sparsity: number, std: number, options?: SparseOptions): Tensor
Initialize a 2D tensor with sparse connectivity for parameter efficiency.
Sparse initialization creates weight matrices with structured zeros, reducing parameters and computation while maintaining expressiveness. Useful for:
- Large neural networks where parameter count is critical
- Mobile and embedded model deployment
- Efficient training on resource-constrained devices
- Research on sparse neural network learning
- Initialization for sparse optimization methods
- Reducing memory and computational footprint
Each column has a specified fraction of its weights set to zero; the non-zero weights are drawn from N(0, std).
The method is described in "Deep learning via Hessian-free optimization" - Martens, J. (2010).
- Column-wise sparsity: Sparsity is applied independently to each column
- Zero pattern: The zero pattern is random and fixed at initialization
- 2D only: Works only for 2D tensors (weight matrices); use with Linear layers
- Small std default: The default std=0.01 is small; adjust to the layer's scale
- Structural sparsity: Zeros are placed per column at init time (the tensor itself remains densely stored)
- Parameter reduction: sparsity=0.9 leaves only 10% of the weights non-zero
- Gradient computation: Gradients are not masked, so zeroed weights can become non-zero during training unless sparsity is enforced separately
- In-place operation: Modifies the tensor in-place; returns the same tensor
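The parameter-reduction note above is simple arithmetic. The helper below is a hypothetical illustration (not part of the library) that counts the non-zero weights left after sparse initialization, assuming ceil rounding for the per-column zero count as in PyTorch's reference implementation:

```typescript
// Hypothetical helper (not part of torch.nn.init): count the non-zero
// parameters in a [rows, cols] weight matrix after sparse initialization
// with the given per-column sparsity. Assumes ceil rounding per column.
function nonzeroParams(rows: number, cols: number, sparsity: number): number {
  const zerosPerColumn = Math.ceil(rows * sparsity);
  return (rows - zerosPerColumn) * cols;
}

// sparsity = 0.9 leaves exactly 10% of a 1000 x 500 layer non-zero:
console.log(nonzeroParams(1000, 500, 0.9)); // 50000 of 500000 weights
```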
Parameters
tensorTensor- A 2-dimensional Tensor (typically weight matrix: [output_features, input_features]). Must be exactly 2D; sparse_ is specifically designed for weight matrices
sparsitynumber- The fraction of weights in each column to set to zero (range: [0, 1]). - 0.1 = 10% zeros per column (90% non-zero) - 0.5 = 50% zeros per column (half zero, half non-zero) - 0.9 = 90% zeros per column (highly sparse)
optionsSparseOptionsoptional- Optional settings for sparse initialization
Returns
Tensor– The input tensor with sparse initialization (2D, with structured zeros)
Algorithm:
- Initialize all weights from N(0, std)
- For each column, randomly select ceil(rows × sparsity) elements to zero out
- Ensures structural sparsity: the zero pattern is fixed once initialization completes
Examples
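The algorithm can be sketched with plain arrays, independent of any tensor library. Everything here is illustrative: gaussian() is a hypothetical Box-Muller sampler standing in for N(0, std), and the zero count uses ceil rounding as in PyTorch's implementation:

```typescript
// Hypothetical stand-in for sampling from N(0, std) (Box-Muller).
function gaussian(std: number): number {
  const u = 1 - Math.random(); // avoid log(0)
  const v = Math.random();
  return std * Math.sqrt(-2 * Math.log(u)) * Math.cos(2 * Math.PI * v);
}

// Minimal sketch of sparse initialization on a plain number[][] matrix.
function sparseInit(
  rows: number, cols: number, sparsity: number, std = 0.01
): number[][] {
  // Step 1: fill every entry from N(0, std)
  const w = Array.from({ length: rows }, () =>
    Array.from({ length: cols }, () => gaussian(std))
  );
  const numZeros = Math.ceil(rows * sparsity);
  for (let col = 0; col < cols; col++) {
    // Step 2: Fisher-Yates shuffle of row indices, then zero the first numZeros
    const idx = Array.from({ length: rows }, (_, i) => i);
    for (let i = rows - 1; i > 0; i--) {
      const j = Math.floor(Math.random() * (i + 1));
      [idx[i], idx[j]] = [idx[j], idx[i]];
    }
    for (let k = 0; k < numZeros; k++) w[idx[k]][col] = 0;
  }
  return w; // Step 3: the zero pattern is now fixed
}
```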
// Initialize a large fully connected layer sparsely
const layer = torch.nn.Linear(1000, 500);
torch.nn.init.sparse_(layer.weight, 0.5); // 50% sparsity
torch.nn.init.zeros_(layer.bias);
// Non-zero weights from N(0, 0.01^2)
const x = torch.randn([32, 1000]);
const y = layer.forward(x); // Only 50% of the weights contribute to each output
// Highly sparse network for mobile deployment
const layer1 = torch.nn.Linear(2048, 1024);
const layer2 = torch.nn.Linear(1024, 512);
const layer3 = torch.nn.Linear(512, 10);
// Use high sparsity (90%+) for parameter efficiency
torch.nn.init.sparse_(layer1.weight, 0.9, { std: 0.02 });
torch.nn.init.sparse_(layer2.weight, 0.9, { std: 0.02 });
torch.nn.init.xavier_uniform_(layer3.weight); // Dense output layer
// Non-zero parameter count: ~90% reduction vs the dense equivalent
const dense_params = 2048 * 1024 + 1024 * 512 + 512 * 10;
const sparse_params = (2048 * 1024 + 1024 * 512) * 0.1 + 512 * 10; // Only output layer dense
// Moderate sparsity for balanced efficiency
const layer = torch.nn.Linear(784, 256);
torch.nn.init.sparse_(layer.weight, 0.3, { std: 0.02 }); // 30% sparsity
torch.nn.init.zeros_(layer.bias);
// Results in ~30% fewer non-zero parameters than dense
// while maintaining sufficient expressiveness
// Sparse embedding layer
class SparseEmbedding extends torch.nn.Module {
weight: torch.nn.Parameter;
constructor(num_embeddings: number, embedding_dim: number) {
super();
this.weight = torch.nn.Parameter.create(
torch.empty([num_embeddings, embedding_dim])
);
// Use sparsity for large embedding tables
torch.nn.init.sparse_(this.weight, 0.5, { std: 0.05 });
this.register_parameter('weight', this.weight);
}
}
const embed = new SparseEmbedding(10000, 512);
const indices = torch.randint(0, 10000, [32, 10]);
See Also
- PyTorch torch.nn.init.sparse_()
- torch.nn.init.normal_ - Regular normal distribution (dense initialization)
- torch.nn.init.uniform_ - Uniform distribution
- torch.nn.Parameter - Learn sparse weight matrices through training