torch.nn.KLDivLoss
new KLDivLoss(options?: {
reduction?: 'none' | 'batchmean' | 'sum' | 'mean';
log_target?: boolean;
})
- reduction ('none' | 'batchmean' | 'sum' | 'mean') - readonly
- log_target (boolean) - readonly
Kullback-Leibler (KL) Divergence loss: measures divergence between probability distributions.
Measures how much one probability distribution differs from another (an asymmetric divergence). Essential for probabilistic models, distribution matching, and knowledge distillation. KL divergence is always non-negative, equals 0 only when the distributions are identical, and quantifies the information lost when one distribution is used to approximate the other.
When to use KLDivLoss:
- Knowledge distillation (matching student network to teacher network outputs)
- Variational autoencoders (matching learned distribution to prior)
- Distribution matching in generative models
- Policy optimization in reinforcement learning
- Training mixture models and probabilistic classifiers
- Any task comparing two probability distributions
Important: Input should be log-probabilities (typically from log_softmax), not raw probabilities. This provides numerical stability and matches the mathematical definition.
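To see why log-probabilities matter, here is a minimal standalone sketch (plain number arrays, not the torch API) comparing a naive `log(softmax(x))` against the log-sum-exp formulation that `log_softmax` is based on. The function names are illustrative, not part of the library:

```typescript
// Naive route: softmax first, then log. For large logits the exponentials
// overflow/underflow, producing NaN and -Infinity entries.
function naiveLogSoftmax(logits: number[]): number[] {
  const exps = logits.map(Math.exp);
  const sum = exps.reduce((a, b) => a + b, 0);
  return exps.map((e) => Math.log(e / sum)); // log(0) or log(NaN) when exp misbehaves
}

// Stable route: subtract the max and use log-sum-exp, staying in log space.
function stableLogSoftmax(logits: number[]): number[] {
  const m = Math.max(...logits);
  const logSum = m + Math.log(logits.reduce((a, x) => a + Math.exp(x - m), 0));
  return logits.map((x) => x - logSum);
}

const logits = [1000, 0, -1000];
console.log(naiveLogSoftmax(logits));  // [NaN, -Infinity, -Infinity]
console.log(stableLogSoftmax(logits)); // finite: [0, -1000, -2000]
```

This is exactly why KLDivLoss expects input produced by log_softmax rather than `log(softmax(...))`.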
Trade-offs:
- Asymmetric: KL(P || Q) ≠ KL(Q || P) - direction matters (mode-seeking vs mode-covering)
- Information-theoretic: More principled than MSE for distributions
- Numerical stability: Use log-probabilities to avoid log(0)
- Zero-avoiding: forward KL heavily penalizes placing near-zero probability where the target has mass, and is 0 only for identical distributions
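The asymmetry in the first trade-off is easy to verify numerically. A minimal sketch with plain number arrays (not the torch API; `kl` is an illustrative helper):

```typescript
// Discrete KL divergence: KL(P || Q) = sum over x of P(x) * log(P(x) / Q(x)).
// Terms with P(x) = 0 contribute 0 by convention.
function kl(p: number[], q: number[]): number {
  return p.reduce((acc, pi, i) => (pi > 0 ? acc + pi * Math.log(pi / q[i]) : acc), 0);
}

const P = [0.9, 0.05, 0.05]; // peaked distribution
const Q = [1 / 3, 1 / 3, 1 / 3]; // uniform distribution
console.log(kl(P, Q)); // ≈ 0.704
console.log(kl(Q, P)); // ≈ 0.934 — different: direction matters
console.log(kl(P, P)); // 0 — identical distributions
```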
Algorithm:
- Forward: KL(P || Q) = Σ P(x) * log(P(x) / Q(x)) = Σ P(x) * (log P(x) - log Q(x))
- When input = log Q and target = P: loss = Σ P(x) * (log P(x) - input(x))
- Backward: ∂loss/∂input = -P(x) (gradient pushes input toward matching the target)
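The two forward expressions above are the same quantity; a quick standalone check with plain arrays (not the torch API) confirms that feeding log Q as the input reproduces KL(P || Q):

```typescript
const P = [0.7, 0.2, 0.1]; // target: probabilities
const Q = [0.5, 0.3, 0.2]; // model distribution
const input = Q.map(Math.log); // log-probabilities, as KLDivLoss expects

// loss = sum of P(x) * (log P(x) - input(x)), with input = log Q
const lossFromInput = P.reduce((s, p, i) => s + p * (Math.log(p) - input[i]), 0);
// direct definition: sum of P(x) * log(P(x) / Q(x))
const klDirect = P.reduce((s, p, i) => s + p * Math.log(p / Q[i]), 0);
console.log(lossFromInput, klDirect); // identical values
```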
- Always use log-probabilities: Input should be from log_softmax, not softmax
- Information-theoretic: Based on information theory, more principled than MSE
- Asymmetric divergence: KL(P||Q) ≠ KL(Q||P) - order matters
- Non-negative: KL divergence ≥ 0, equals 0 only when distributions identical
- Mode-covering: minimizing forward KL(P || Q) over Q spreads mass to cover everywhere the target has probability (reverse KL is the mode-seeking direction)
- Knowledge distillation: Key technique for model compression and transfer learning
- Temperature scaling: Use temperature T > 1 (e.g. T = 4) to soften distributions for distillation
- Gradient behavior: Pushes model toward target distribution with smooth gradients
- Input must be log-probabilities (from log_softmax), not raw probabilities
- Target should sum to 1 (a valid probability distribution), or be log-probabilities when log_target is true
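The temperature point above can be illustrated with a small standalone sketch (plain arrays, not the torch API; `softmaxT` is an illustrative helper): dividing logits by T > 1 before softmax flattens the distribution, exposing the "dark knowledge" in the non-top classes.

```typescript
// Softmax with temperature: softmax(logits / T), computed stably.
function softmaxT(logits: number[], T: number): number[] {
  const m = Math.max(...logits);
  const exps = logits.map((x) => Math.exp((x - m) / T));
  const sum = exps.reduce((a, b) => a + b, 0);
  return exps.map((e) => e / sum);
}

const logits = [6, 2, 1];
console.log(softmaxT(logits, 1)); // sharp: ≈ [0.976, 0.018, 0.007]
console.log(softmaxT(logits, 4)); // soft:  ≈ [0.604, 0.222, 0.173]
```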
Examples
// Knowledge distillation: matching student to teacher
const teacher_logits = torch.randn([32, 10]);
const student_logits = torch.randn([32, 10]);
// Temperature scaling for soft targets
const temperature = 4.0;
const teacher_probs = torch.softmax(teacher_logits.div(temperature), 1);
const student_log_probs = torch.log_softmax(student_logits.div(temperature), 1);
const kl_loss = new torch.nn.KLDivLoss({ reduction: 'batchmean' });
const loss = kl_loss.forward(student_log_probs, teacher_probs);
// Student network trained to match teacher distribution; 'batchmean' averages
// the per-sample KL over the batch (distillation losses are also commonly
// scaled by temperature^2 to keep gradient magnitudes stable)

// Variational Autoencoder (VAE): KL divergence in latent space
class VAE extends torch.nn.Module {
encoder: torch.nn.Linear;
mu_layer: torch.nn.Linear;
logvar_layer: torch.nn.Linear;
constructor() {
super();
this.encoder = new torch.nn.Linear(784, 256);
this.mu_layer = new torch.nn.Linear(256, 20); // Mean of latent
this.logvar_layer = new torch.nn.Linear(256, 20); // Log variance
}
forward(x: torch.Tensor): [torch.Tensor, torch.Tensor, torch.Tensor] {
const h = torch.nn.functional.relu(this.encoder.forward(x));
const mu = this.mu_layer.forward(h);
const logvar = this.logvar_layer.forward(h);
// KL divergence between learned N(mu, sigma^2) and standard normal N(0, 1):
// 0.5 * sum(exp(logvar) + mu^2 - 1 - logvar), averaged over the batch
const kl_div = torch.sum(
torch.exp(logvar).add(mu.pow(2)).sub(1).sub(logvar),
1
).mean().mul(0.5);
return [mu, logvar, kl_div]; // kl_div is added to the reconstruction loss
}
}

// Comparing KLDivLoss vs CrossEntropyLoss
const logits = torch.randn([32, 10]);
const targets = torch.tensor([1, 2, 5, ...]); // Integer class labels
// CrossEntropyLoss: directly from logits
const ce_loss = new torch.nn.CrossEntropyLoss();
const ce = ce_loss.forward(logits, targets);
// KLDivLoss: requires log-probabilities and one-hot targets
const log_probs = torch.log_softmax(logits, 1);
const one_hot_targets = torch.zeros([32, 10]);
for (let i = 0; i < 32; i++) {
one_hot_targets[i][targets[i]] = 1.0;
}
const kl_loss_fn = new torch.nn.KLDivLoss({ reduction: 'batchmean' });
const kl = kl_loss_fn.forward(log_probs, one_hot_targets);
// With one-hot targets the target-entropy term is 0, so KL reduces to
// cross-entropy: ce and kl match when both are averaged over the batch

// Policy distillation in reinforcement learning
const batch_size = 32;
const num_actions = 4;
const student_policy = torch.randn([batch_size, num_actions]);
const teacher_policy = torch.randn([batch_size, num_actions]);
const student_log_probs = torch.log_softmax(student_policy, 1);
const teacher_probs = torch.softmax(teacher_policy, 1);
const kl_loss = new torch.nn.KLDivLoss({ reduction: 'batchmean' });
const loss = kl_loss.forward(student_log_probs, teacher_probs);
// Student policy network trained to match teacher behavior
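As a sanity check on the VAE example, the closed-form Gaussian KL used there can be verified against a direct numerical integration of the definition. A standalone sketch with plain math (not the torch API; both function names are illustrative):

```typescript
// Closed form: KL(N(mu, sigma^2) || N(0, 1)) = 0.5 * (sigma^2 + mu^2 - 1 - log(sigma^2)),
// parameterized by logvar = log(sigma^2) as in the VAE example.
function gaussianKlClosedForm(mu: number, logvar: number): number {
  return 0.5 * (Math.exp(logvar) + mu * mu - 1 - logvar);
}

// Direct definition: integral of p(x) * log(p(x) / q(x)) dx, approximated
// with a Riemann sum over a wide grid.
function gaussianKlNumeric(mu: number, logvar: number): number {
  const sigma = Math.exp(0.5 * logvar);
  const pdf = (x: number, m: number, s: number) =>
    Math.exp(-((x - m) ** 2) / (2 * s * s)) / (s * Math.sqrt(2 * Math.PI));
  let kl = 0;
  const dx = 0.001;
  for (let x = -20; x <= 20; x += dx) {
    const p = pdf(x, mu, sigma);
    if (p > 1e-300) kl += p * Math.log(p / pdf(x, 0, 1)) * dx;
  }
  return kl;
}

console.log(gaussianKlClosedForm(1.0, 0.5)); // ≈ 0.5744
console.log(gaussianKlNumeric(1.0, 0.5));    // agrees to ~3 decimal places
console.log(gaussianKlClosedForm(0, 0));     // 0 — identical Gaussians
```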