WebGPU Internals
torch.js provides several APIs that don't exist in PyTorch (and won't be found in its documentation). They exist because WebGPU has a different execution model than CUDA; they control how torch.js drives the GPU and can be tuned for different use cases.
Command Batching
Unlike CUDA, WebGPU requires explicit command submission. By default, torch.js automatically batches commands together and submits them in groups for better performance.
How It Works
When you call operations like torch.add() or torch.matmul(), torch.js doesn't immediately send them to the GPU. Instead, it collects them into a batch and submits them together (see the sketch after the list below).
Batches are automatically flushed when:
- 64 commands accumulate (configurable)
- 16 milliseconds pass (configurable, one frame at 60fps)
- You call sync_device(), item(), or toArray()
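To see the queue in action, the sketch below watches the pending count as operations accumulate, then forces a flush. It uses only flush_batch() and get_pending_count() from the API reference below; torch.randn is an assumption, mirroring PyTorch's factory function:

```js
// A minimal sketch. torch.randn is assumed to follow PyTorch's API;
// flush_batch() and get_pending_count() are documented below.
const x = torch.randn([1024]);
const b = torch.randn([1024]);

// Each op is queued, not submitted, so the pending count grows.
const y = torch.add(x, b);
const z = torch.relu(y);
console.log(torch.webgpu.get_pending_count()); // a small nonzero count

// Force submission of everything queued so far.
torch.webgpu.flush_batch();
console.log(torch.webgpu.get_pending_count()); // 0
```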
Configuration
```js
// Flush more frequently for lower latency (e.g., real-time games)
torch.webgpu.set_batch_threshold(16);

// Flush less frequently for higher throughput (e.g., batch processing)
torch.webgpu.set_batch_threshold(128);

// Control time-based flushing
torch.webgpu.set_batch_delay(16); // 60fps
torch.webgpu.set_batch_delay(33); // 30fps
torch.webgpu.set_batch_delay(0);  // Disable time-based flush
```

Debugging
For debugging, you can disable batching entirely. This makes each operation submit immediately, which can help isolate issues:
```js
// Disable batching - each op submits immediately
torch.webgpu.disable_batching();

// Re-enable batching
torch.webgpu.enable_batching();

// Check current state
if (torch.webgpu.is_batching_enabled()) {
  console.log('Batching is on');
}
```

Explicit Batching
You can also explicitly batch operations using batch():
```js
// Force these operations into a single batch
const result = torch.webgpu.batch(() => {
  const y = torch.add(x, bias);
  const z = torch.relu(y);
  return z;
}); // Batch flushes here
```

Synchronization
To wait for all GPU operations to complete:
```js
// Submit pending commands and wait for completion
await torch.webgpu.sync_device();

// Now safe to read results
const values = await tensor.toArray();
```

Calling sync_device() frequently will hurt performance. Only sync when you need the actual values (for display, logging, or decisions).
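A common pattern is to keep the GPU busy and synchronize only at checkpoints. A minimal sketch, assuming a hypothetical doWork() placeholder for your own queued ops and a PyTorch-style awaitable item(); sync_device() is the only torch.js-specific call here:

```js
// Hypothetical loop; doWork() stands in for your own queued GPU ops.
for (let step = 0; step < 1000; step++) {
  const loss = doWork(); // ops are queued, not executed immediately

  // Sync only when a CPU-side value is actually needed.
  if (step % 100 === 0) {
    await torch.webgpu.sync_device();
    console.log(`step ${step}: loss = ${await loss.item()}`);
  }
}
```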
Operation Fusion
torch.js can automatically fuse sequences of operations into single GPU kernels. This reduces memory bandwidth usage and improves performance.
Built-in Fusion Patterns
```js
// These patterns are automatically detected and fused:
const y = torch.add(x, bias);
const z = torch.relu(y); // add + relu → single fused kernel

const a = torch.mul(x, gamma);
const b = torch.add(a, beta); // mul + add → scale_shift fusion

const diff = torch.sub(pred, target);
const loss = torch.square(diff); // sub + square → squared_diff fusion
```

Configuration
```js
// Enable/disable fusion
torch.webgpu.enable_fusion();
torch.webgpu.disable_fusion();

// Check if fusion is enabled
if (torch.webgpu.is_fusion_enabled()) {
  console.log('Fusion is on');
}

// Tune fusion behavior
torch.webgpu.set_fusion_window(4);      // Look back 4 operations
torch.webgpu.set_fusion_min_size(1024); // Only fuse tensors > 1024 elements
```

Statistics
```js
const stats = torch.webgpu.get_fusion_stats();
console.log(`Fusions: ${stats.fusions}`);
console.log(`Bypasses: ${stats.bypasses}`);
```

Custom Patterns
You can add your own fusion patterns:
```js
torch.webgpu.add_fusion_pattern({
  ops: ['add', 'mul', 'tanh'],
  name: 'custom_activation',
  description: 'My custom activation pattern',
});

// View all patterns
const patterns = torch.webgpu.get_fusion_patterns();
for (const p of patterns) {
  console.log(`${p.name}: ${p.ops.join(' → ')}`);
}
```

Memory Management
torch.js includes a buffer pool that reuses GPU memory allocations.
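A quick way to see the pool at work is to compare memory_stats() before and after empty_cache(). A sketch, assuming torch.randn mirrors PyTorch; exact counts depend on when the JavaScript engine actually releases the tensors:

```js
// Allocate and discard some tensors; their buffers can return to the pool.
for (let i = 0; i < 10; i++) {
  const t = torch.randn([1024, 1024]); // assumed PyTorch-style factory
  // ...t goes out of scope; its GPU buffer becomes reusable...
}

const before = torch.webgpu.memory_stats();
console.log(`Pooled before: ${before.pooled_bytes}`);

torch.webgpu.empty_cache(); // release pooled buffers back to the device
const after = torch.webgpu.memory_stats();
console.log(`Pooled after: ${after.pooled_bytes}`); // should drop toward 0
```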
Memory Statistics
```js
const stats = torch.webgpu.memory_stats();
console.log(`Active: ${stats.active_bytes / 1024 / 1024} MB`);
console.log(`Pooled: ${stats.pooled_bytes / 1024 / 1024} MB`);
console.log(`Peak: ${stats.peak_bytes / 1024 / 1024} MB`);

// Or get a formatted summary
console.log(torch.webgpu.memory_summary());
```

Freeing Memory
```js
// Free unused pooled memory
torch.webgpu.empty_cache();

// Reset peak tracking (for profiling specific sections)
torch.webgpu.reset_peak_memory_stats();
```

Device Capabilities
Query what the GPU supports:
```js
const caps = torch.webgpu.getCapabilities();
console.log(`Max buffer: ${caps.limits.maxBufferSize} bytes`);
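
// A hedged addition: use the reported limits to gate a large allocation.
// (This guard and the 512 MB figure are illustrative, not part of torch.js;
// maxBufferSize is a standard WebGPU GPUSupportedLimits field.)
const wanted = 512 * 1024 * 1024;
if (wanted > caps.limits.maxBufferSize) {
  console.warn('Requested buffer exceeds maxBufferSize; use smaller chunks');
}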
console.log(`Max workgroup: ${caps.limits.maxComputeWorkgroupSizeX}`);
```

API Reference
Batching
| Function | Description |
|---|---|
| set_batch_threshold(n) | Max commands before auto-flush (default: 64) |
| set_batch_delay(ms) | Max time before auto-flush (default: 16ms) |
| enable_batching() | Enable automatic batching |
| disable_batching() | Disable batching (immediate dispatch) |
| is_batching_enabled() | Check if batching is enabled |
| flush_batch() | Manually flush pending commands |
| get_pending_count() | Number of pending commands |
| batch(fn) | Execute fn in explicit batch scope |
| sync_device() | Wait for all GPU work to complete |
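
These calls compose into a fully manual submission scheme. A sketch of one possibility, assuming a browser environment with requestAnimationFrame; the torch.webgpu calls are all from the table above:

```js
// Disable time-based auto-flush and drive submission from the frame loop.
// (Threshold-based flushing still applies; raise the threshold if needed.)
torch.webgpu.set_batch_delay(0);

function frame() {
  // ...queue this frame's tensor ops here (placeholder)...
  if (torch.webgpu.get_pending_count() > 0) {
    torch.webgpu.flush_batch(); // one submission per frame
  }
  requestAnimationFrame(frame);
}
requestAnimationFrame(frame);
```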
Fusion
| Function | Description |
|---|---|
| enable_fusion() | Enable automatic operation fusion |
| disable_fusion() | Disable fusion |
| is_fusion_enabled() | Check if fusion is enabled |
| set_fusion_window(n) | Operations to look back for patterns (default: 4) |
| set_fusion_min_size(n) | Min tensor size for fusion (default: 1024) |
| get_fusion_stats() | Get fusion/bypass counts |
| add_fusion_pattern(p) | Add custom fusion pattern |
| get_fusion_patterns() | List all fusion patterns |
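
The stats calls make it easy to measure how much of a workload actually fuses. A sketch, assuming the counters accumulate monotonically so that deltas isolate one region (runWorkload is a hypothetical placeholder):

```js
const before = torch.webgpu.get_fusion_stats();

runWorkload(); // hypothetical placeholder for the ops you care about

const after = torch.webgpu.get_fusion_stats();
const fused = after.fusions - before.fusions;
const bypassed = after.bypasses - before.bypasses;
console.log(`fused ${fused} of ${fused + bypassed} eligible sequences`);
```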
Memory
| Function | Description |
|---|---|
| memory_stats() | Get memory usage statistics |
| memory_summary() | Get formatted memory summary |
| empty_cache() | Free unused pooled memory |
| reset_peak_memory_stats() | Reset peak memory counter |
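
Together, reset_peak_memory_stats() and memory_stats().peak_bytes let you attribute peak usage to a specific section of code. A sketch, where runForwardPass is a hypothetical placeholder for the code under test:

```js
// Profile the peak memory of one section in isolation.
torch.webgpu.reset_peak_memory_stats();

runForwardPass(); // hypothetical placeholder
await torch.webgpu.sync_device(); // ensure the queued work has actually run

const peakMB = torch.webgpu.memory_stats().peak_bytes / 1024 / 1024;
console.log(`Forward pass peak: ${peakMB.toFixed(1)} MB`);
```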
Comparison with PyTorch/CUDA
| Feature | PyTorch (CUDA) | torch.js (WebGPU) |
|---|---|---|
| Command submission | Automatic (stream) | Automatic (batched) |
| Synchronization | torch.cuda.synchronize() | torch.webgpu.sync_device() |
| Memory stats | torch.cuda.memory_stats() | torch.webgpu.memory_stats() |
| Empty cache | torch.cuda.empty_cache() | torch.webgpu.empty_cache() |
| Batching control | N/A (stream-based) | set_batch_threshold(), batch() |
| Op fusion | torch.jit, torch.compile | Automatic + add_fusion_pattern() |
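
In practice, this mapping means CUDA-era habits port nearly line for line. A hedged sketch of the torch.js side, with the PyTorch equivalents noted in comments (all torch.webgpu calls are documented on this page; the torch.cuda names are standard PyTorch):

```js
await torch.webgpu.sync_device();  // ≈ torch.cuda.synchronize()
torch.webgpu.empty_cache();        // ≈ torch.cuda.empty_cache()

const stats = torch.webgpu.memory_stats();   // ≈ torch.cuda.memory_stats()
console.log(torch.webgpu.memory_summary());  // ≈ torch.cuda.memory_summary()
```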
Next Steps
- Performance - General performance tips
- Profiling & Memory - Memory debugging
- Best Practices - Coding patterns