torch.js has not been released yet.
torch.js· 2026

Training a Model

Training a neural network in torch.js follows the same pattern as PyTorch: define a model, select a loss function, and iterate with an optimizer.

Diagram of the training loop: Zero Grad, Forward, Backward, Step

1. Defining the Architecture

You can build models using nn.Module for complex logic or nn.Sequential for simple stacks of layers.

import torch, { nn } from '@torchjsorg/torch.js';

// Using nn.Sequential for a simple MLP
const model = nn.Sequential(nn.Linear(784, 128), nn.ReLU(), nn.Linear(128, 10));
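To get a feel for the size of the MLP above, the following library-free sketch counts its trainable parameters. It assumes each Linear layer stores an (out × in) weight matrix plus one bias per output, as in PyTorch; the names LinearShape and countLinearParams are illustrative and not part of torch.js.

```typescript
// Library-free arithmetic: count the trainable parameters of a stack of
// fully connected layers. Assumption: Linear(in, out) holds out * in
// weights plus out biases, as in PyTorch.
type LinearShape = { inFeatures: number; outFeatures: number };

function countLinearParams(layers: LinearShape[]): number {
  return layers.reduce(
    (total, { inFeatures, outFeatures }) =>
      total + inFeatures * outFeatures + outFeatures,
    0,
  );
}

// The MLP above: Linear(784, 128) -> ReLU -> Linear(128, 10).
// ReLU has no parameters, so only the two Linear layers contribute.
const mlpParams = countLinearParams([
  { inFeatures: 784, outFeatures: 128 }, // 784 * 128 + 128 = 100480
  { inFeatures: 128, outFeatures: 10 }, // 128 * 10 + 10 = 1290
]);

console.log(mlpParams); // 101770
```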

2. Loss Functions and Optimizers

Loss functions are located in torch.nn.functional, and optimizers are in torch.optim.

const criterion = torch.nn.functional.cross_entropy;
const optimizer = new torch.optim.Adam(model.parameters(), { lr: 1e-3 });
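For intuition about the number that cross_entropy returns, here is a library-free sketch of the computation. It assumes torch.js follows PyTorch's default definition (log-softmax over the raw logits, then the negative log-probability of the target class, averaged over the batch); the helpers logSoftmax and crossEntropy are illustrative names, not torch.js APIs.

```typescript
// Library-free sketch of mean cross-entropy over a batch of raw logits.
// Assumption: this mirrors PyTorch's default cross_entropy; torch.js may differ.
function logSoftmax(logits: number[]): number[] {
  const max = Math.max(...logits); // subtract the max for numerical stability
  const sumExp = logits.reduce((sum, z) => sum + Math.exp(z - max), 0);
  const logSumExp = max + Math.log(sumExp);
  return logits.map((z) => z - logSumExp);
}

function crossEntropy(batchLogits: number[][], targets: number[]): number {
  if (batchLogits.length === 0 || batchLogits.length !== targets.length) {
    throw new Error('expected one target class index per row of logits');
  }
  let total = 0;
  batchLogits.forEach((logits, i) => {
    // Negative log-probability assigned to the true class of sample i
    total -= logSoftmax(logits)[targets[i]];
  });
  return total / batchLogits.length;
}

// A 10-class model that outputs identical logits is guessing uniformly,
// so its loss is ln(10).
const uniformLoss = crossEntropy([new Array<number>(10).fill(0)], [3]);
console.log(uniformLoss.toFixed(4)); // 2.3026
```

A loss near ln(10) at the start of training on a 10-class problem is therefore expected; it should fall as the model learns.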

3. The Training Loop

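The core training loop consists of four steps, performed in order. Skipping one breaks training: without zero_grad(), gradients left over from earlier iterations contaminate the new ones, and without step(), the weights are never updated.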

Step          Action                  Purpose
1. Zero Grad  optimizer.zero_grad()   Clears previous gradients
2. Forward    model.forward(input)    Computes predictions
3. Backward   loss.backward()         Computes new gradients
4. Step       optimizer.step()        Updates model weights

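Why the reset in step 1 matters can be seen in a library-free toy. The sketch below uses a single scalar weight and assumes that gradients accumulate across backward calls until they are cleared, which is the PyTorch convention torch.js is presumed to follow. ToyParam and train are hypothetical names used only for illustration.

```typescript
// Library-free toy: one scalar weight w, prediction w * x, squared-error loss.
// Assumption: gradients accumulate until zeroGrad() is called, as in PyTorch.
class ToyParam {
  value = 0;
  grad = 0;

  zeroGrad(): void {
    this.grad = 0;
  }

  // d/dw (w * x - y)^2 = 2 * x * (w * x - y), added to any existing gradient
  backward(x: number, y: number): void {
    this.grad += 2 * x * (this.value * x - y);
  }

  step(lr: number): void {
    this.value -= lr * this.grad;
  }
}

function train(steps: number, zeroEachStep: boolean): number {
  const w = new ToyParam();
  for (let i = 0; i < steps; i++) {
    if (zeroEachStep) w.zeroGrad(); // 1. zero grad
    w.backward(1, 2); // 2 + 3. forward and backward on the sample (x = 1, y = 2)
    w.step(0.1); // 4. step
  }
  return w.value;
}

const withZeroGrad = train(50, true); // settles close to the optimum w = 2
const withoutZeroGrad = train(50, false); // stale gradients pile up; w oscillates around 2
console.log({ withZeroGrad, withoutZeroGrad });
```

With the reset, the error shrinks by a constant factor each step. Without it, the accumulated gradient behaves like undamped momentum and the weight never settles.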
async function trainStep(data: Tensor, target: Tensor) {
  // 1. Reset gradients
  optimizer.zero_grad();

  // 2. Forward pass
  const output = model.forward(data);
  const loss = criterion(output, target);

  // 3. Backward pass (autograd walks the graph recorded during the forward pass)
  loss.backward();

  // 4. Update weights
  optimizer.step();

  return await loss.item(); // Read loss back to CPU for logging
}
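A step function like trainStep above is usually called once per batch, for several passes over the data. The following sketch shows one way to drive it; fit, Batch, and StepFn are hypothetical helpers, not part of torch.js, and the function only assumes that the step returns a Promise of the batch loss.

```typescript
// Hypothetical helper (not a torch.js API): runs an async per-batch step
// function over several epochs and returns the mean loss of each epoch.
type Batch<X, Y> = { data: X; target: Y };
type StepFn<X, Y> = (data: X, target: Y) => Promise<number>;

async function fit<X, Y>(
  batches: Batch<X, Y>[],
  step: StepFn<X, Y>,
  epochs: number,
): Promise<number[]> {
  if (batches.length === 0) {
    throw new Error('fit needs at least one batch');
  }
  const history: number[] = [];
  for (let epoch = 0; epoch < epochs; epoch++) {
    let total = 0;
    for (const { data, target } of batches) {
      // Await each step so weight updates are applied strictly in order
      total += await step(data, target);
    }
    history.push(total / batches.length);
  }
  return history;
}
```

Under these assumptions, a call such as fit(batches, trainStep, 5) would resolve to five mean-loss values, one per epoch, which is enough to check that the loss is trending down.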

4. Saving and Loading

Once training is complete, you can save the model's weights (the state_dict) to a file or browser storage.

// Save weights in safetensors or .pt format
const weights = model.state_dict();
await torch.save(weights, 'my_model.safetensors');

// Load weights back into a model
const loadedWeights = await torch.load('my_model.safetensors');
model.load_state_dict(loadedWeights);
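Loading only works cleanly when the checkpoint and the model agree on parameter names. As a sketch, the check below compares the two sets of names before loading. It assumes a state_dict can be treated as a plain object keyed by parameter name; diffStateDictKeys is a hypothetical helper, and the key names in the example are illustrative rather than taken from torch.js.

```typescript
// Hypothetical pre-flight check (not a torch.js API): compare the parameter
// names a model expects with the names found in a loaded checkpoint.
// Assumption: a state_dict behaves like a plain object keyed by parameter name.
function diffStateDictKeys(
  expected: Record<string, unknown>,
  loaded: Record<string, unknown>,
): { missing: string[]; unexpected: string[] } {
  const has = (obj: Record<string, unknown>, key: string): boolean =>
    Object.prototype.hasOwnProperty.call(obj, key);
  return {
    // Names the model needs but the checkpoint lacks
    missing: Object.keys(expected).filter((key) => !has(loaded, key)),
    // Names in the checkpoint that the model does not know
    unexpected: Object.keys(loaded).filter((key) => !has(expected, key)),
  };
}

// Illustrative key names; the values stand in for tensors.
const report = diffStateDictKeys(
  { '0.weight': 1, '0.bias': 1, '2.weight': 1, '2.bias': 1 },
  { '0.weight': 1, '0.bias': 1, '2.weight': 1, 'head.weight': 1 },
);
console.log(report); // { missing: ['2.bias'], unexpected: ['head.weight'] }
```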

Training in the Browser: Since the browser main thread must remain responsive, consider using Spark to run your training loop in a Web Worker. See the Spark Introduction for more details.

Next Steps

  • Autograd - Deep dive into how gradients are computed.
  • Profiling & Memory - Optimizing your training performance.
  • Best Practices - Tips for efficient WebGPU training.