Debugging Deep Learning Models¶
What This Is¶
Debugging deep learning is different from debugging traditional code. The model might "run" without errors but fail to learn. This topic teaches a systematic workflow for diagnosing training failures before you waste time on architecture changes or hyperparameter sweeps.
The key insight: most training failures come from a few common mistakes, and checking them in order saves hours.
When You Use It¶
- training loss doesn't decrease
- validation accuracy stays near random chance
- loss explodes or becomes NaN
- model trains on simple tasks but fails on real data
- gradient norms are suspiciously small or large
The Debugging Ladder¶
Follow this sequence when training fails:
1. Overfit a Single Batch (The Sanity Check)¶
If your model can't memorize one batch, the architecture or loss function is broken.
import torch

# Take one batch and try to drive the loss to (near) zero
model.train()
inputs, labels = next(iter(train_loader))
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

for i in range(200):
    optimizer.zero_grad()
    outputs = model(inputs)
    loss = criterion(outputs, labels)
    loss.backward()
    optimizer.step()
    if i % 20 == 0:
        print(f"Step {i}: Loss = {loss.item():.4f}")

# Loss should drop to near zero.
# If not: check the model architecture, loss function, or data shapes.
What to expect: Loss should drop close to zero within 100-200 steps.
If it doesn't:
- Check that outputs and labels have compatible shapes
- Verify the loss function matches the task (e.g., CrossEntropyLoss for classification)
- Confirm the model has enough capacity (at least a few layers)
2. Check Gradient Flow¶
Gradients should flow backward through all layers. If they vanish or explode, learning fails.
# After loss.backward(), inspect per-parameter gradient norms
def check_gradients(model):
    for name, param in model.named_parameters():
        if param.grad is not None:
            grad_norm = param.grad.norm().item()
            print(f"{name}: grad_norm = {grad_norm:.6f}")
        else:
            print(f"{name}: NO GRADIENT")

check_gradients(model)
What to look for:
- Vanishing gradients: norms < 1e-6 → add batch norm, check activation functions, reduce depth
- Exploding gradients: norms > 100 → use gradient clipping, lower learning rate
- No gradient: the layer is disconnected from the loss or frozen
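If printing every parameter is too noisy, one scalar per step is often enough. A minimal sketch: it relies on the fact that torch.nn.utils.clip_grad_norm_ returns the total norm it measured, so with max_norm=float('inf') it clips nothing and acts as a pure measurement (model is the same name as in the snippets above).

import torch

# Call after loss.backward(). An infinite max_norm makes clipping a no-op,
# so this just computes and returns the total gradient norm.
total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=float('inf'))
print(f"total grad norm: {total_norm.item():.4f}")  # log this every step to spot trends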
3. Verify Data Pipeline¶
Bad data shapes or incorrect labels silently break training.
# Check one batch
inputs, labels = next(iter(train_loader))
print(f"Input shape: {inputs.shape}")
print(f"Label shape: {labels.shape}")
print(f"Input range: [{inputs.min():.3f}, {inputs.max():.3f}]")
print(f"Unique labels: {labels.unique()}")
# Visualize a sample
import matplotlib.pyplot as plt
plt.imshow(inputs[0].permute(1, 2, 0).cpu()) # For images
plt.title(f"Label: {labels[0].item()}")
plt.show()
Common issues:
- Labels are one-hot but the loss expects class indices (or vice versa)
- Images not normalized to [0, 1] or [-1, 1]
- Channels in the wrong order (CHW vs HWC)
- Wrong number of classes in the final layer
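The first three issues have short fixes. A minimal sketch, reusing the inputs and labels from the batch above; the normalization constants and the channels-last heuristic are illustrative, not dataset-specific:

import torch

# One-hot labels -> class indices, which CrossEntropyLoss expects
if labels.dim() == 2:              # shape (batch, num_classes) suggests one-hot
    labels = labels.argmax(dim=1)

# Scale raw uint8 images into [0, 1], then shift to roughly [-1, 1]
inputs = inputs.float() / 255.0
inputs = (inputs - 0.5) / 0.5      # illustrative stats; use your dataset's mean/std

# HWC -> CHW, which PyTorch conv layers expect
if inputs.shape[-1] in (1, 3):     # crude channels-last heuristic
    inputs = inputs.permute(0, 3, 1, 2)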
4. Check Model Output Shape¶
The model's final layer must match the loss function's expectation.
model.eval()
with torch.no_grad():
    outputs = model(inputs)

print(f"Model output shape: {outputs.shape}")
print("Expected shape: (batch_size, num_classes)")

# For classification, check whether the outputs are logits or probabilities
print(f"Output range: [{outputs.min():.3f}, {outputs.max():.3f}]")
For CrossEntropyLoss: Output should be raw logits (no softmax), shape (batch, num_classes).
For BCEWithLogitsLoss: Output should be logits, shape (batch, 1) or (batch,).
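A toy shape check for both cases; every tensor here is random, purely to demonstrate which shapes and dtypes the two losses accept:

import torch

batch_size, num_classes = 8, 10

# CrossEntropyLoss: raw logits (batch, num_classes) vs class indices (batch,)
logits = torch.randn(batch_size, num_classes)
targets = torch.randint(0, num_classes, (batch_size,))
print(torch.nn.CrossEntropyLoss()(logits, targets))

# BCEWithLogitsLoss: logits (batch,) vs float targets (batch,) in {0, 1}
binary_logits = torch.randn(batch_size)
binary_targets = torch.randint(0, 2, (batch_size,)).float()
print(torch.nn.BCEWithLogitsLoss()(binary_logits, binary_targets))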
5. Monitor Learning Rate¶
If the LR is too high, loss oscillates. If too low, training is slow or stalls.
# Track the LR each epoch
for epoch in range(epochs):
    current_lr = optimizer.param_groups[0]['lr']
    print(f"Epoch {epoch}: LR = {current_lr:.6f}")
    # ... training loop ...
Rules of thumb:
- Start with 1e-3 for Adam/AdamW
- Start with 1e-1 for SGD with momentum
- Use a learning rate finder if unsure; a minimal one is sketched below
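A bare-bones LR range test in the spirit of Smith's LR finder: sweep the learning rate exponentially and watch where the loss starts to climb. The function name and bounds are illustrative, and since the sweep updates the weights, run it on a throwaway copy of the model:

import torch

def lr_range_test(model, criterion, train_loader,
                  lr_start=1e-7, lr_end=1.0, num_steps=100):
    optimizer = torch.optim.SGD(model.parameters(), lr=lr_start)
    gamma = (lr_end / lr_start) ** (1.0 / num_steps)  # per-step LR multiplier
    lrs, losses = [], []
    data_iter = iter(train_loader)                    # assumes >= num_steps batches
    for _ in range(num_steps):
        inputs, labels = next(data_iter)
        optimizer.zero_grad()
        loss = criterion(model(inputs), labels)
        loss.backward()
        optimizer.step()
        lrs.append(optimizer.param_groups[0]['lr'])
        losses.append(loss.item())
        for group in optimizer.param_groups:
            group['lr'] *= gamma
    return lrs, losses

# Plot losses vs lrs and pick an LR ~10x below where the curve blows up.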
6. Visualize Training Curves¶
Plot loss and accuracy over epochs to spot patterns.
import matplotlib.pyplot as plt
plt.figure(figsize=(12, 4))
plt.subplot(1, 2, 1)
plt.plot(train_losses, label='Train Loss')
plt.plot(val_losses, label='Val Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.subplot(1, 2, 2)
plt.plot(train_accs, label='Train Acc')
plt.plot(val_accs, label='Val Acc')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.legend()
plt.show()
Patterns to recognize:
- Flat loss: model isn't learning → check LR, data, or architecture
- Diverging train/val: overfitting → add regularization
- Oscillating loss: LR too high → reduce LR or add gradient clipping
Common Failure Modes¶
Loss is NaN¶
Causes:
- Exploding gradients
- Learning rate too high
- Numerical instability in the loss (e.g., log(0))
Fixes:
- Add gradient clipping: torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
- Lower learning rate by 10x
- Check for division by zero or invalid operations
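A sketch of where the clipping and a NaN guard sit in the loop (same assumed names as the earlier snippets; max_norm=1.0 is a common default, not a universal one):

import torch

for inputs, labels in train_loader:
    optimizer.zero_grad()
    loss = criterion(model(inputs), labels)
    if not torch.isfinite(loss):
        print("Non-finite loss; inspect this batch before continuing")
        break
    loss.backward()
    # Clip after backward(), before step()
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    optimizer.step()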
Training Accuracy ≈ Random Chance¶
Causes:
- Wrong loss function for the task
- Labels misaligned with outputs
- Model too shallow for the task
- Data not preprocessed correctly
Fixes:
- Overfit a single batch to rule out architecture issues
- Print and inspect labels vs predictions for one batch, as in the sketch below
- Check if data augmentation is too aggressive
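A quick way to do that inspection (classification assumed; names reuse the earlier snippets):

import torch

model.eval()
with torch.no_grad():
    preds = model(inputs).argmax(dim=1)  # predicted class indices

# If the predictions are all one class, suspect the data or the final layer
print("labels:", labels[:10].tolist())
print("preds: ", preds[:10].tolist())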
Validation Loss Increases While Training Loss Decreases¶
Cause: Overfitting
Fixes:
- Add dropout or weight decay, as in the sketch below
- Use data augmentation
- Reduce model capacity
- Get more training data
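The first fix in code. The p and weight_decay values are common starting points rather than tuned recommendations, and the layer sizes are made up for illustration:

import torch

# Weight decay via AdamW (decoupled weight decay)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-2)

# Dropout lives inside the model definition, between layers, e.g.:
head = torch.nn.Sequential(
    torch.nn.Linear(256, 128),
    torch.nn.ReLU(),
    torch.nn.Dropout(p=0.5),  # active in train(), disabled in eval()
    torch.nn.Linear(128, 10),
)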
Training is Extremely Slow¶
Causes:
- Inefficient data loading (no num_workers or pin_memory)
- Batch size too small
- Model too large for the task
Fixes:
- Set num_workers=4 and pin_memory=True in DataLoader
- Increase batch size (if GPU memory allows)
- Use mixed precision training (torch.cuda.amp)
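A sketch combining these fixes. It assumes a CUDA GPU, a train_dataset object, and the model/criterion/optimizer names from earlier; the batch size is illustrative (newer PyTorch releases also expose the same AMP API under torch.amp):

import torch
from torch.utils.data import DataLoader

train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True,
                          num_workers=4, pin_memory=True)

scaler = torch.cuda.amp.GradScaler()   # scales the loss so fp16 grads don't underflow
for inputs, labels in train_loader:
    inputs = inputs.cuda(non_blocking=True)
    labels = labels.cuda(non_blocking=True)
    optimizer.zero_grad()
    with torch.cuda.amp.autocast():    # run the forward pass in mixed precision
        loss = criterion(model(inputs), labels)
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()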
Debugging Checklist¶
Before changing architecture or hyperparameters, verify:
- [ ] Model can overfit a single batch
- [ ] Gradients are flowing (not vanishing/exploding)
- [ ] Data shapes and labels are correct
- [ ] Loss function matches the task
- [ ] Learning rate is in a reasonable range
- [ ] Training curves are being monitored
- [ ] train() and eval() modes are set correctly
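The last checklist item in code, since it is the easiest to forget (val_inputs stands in for a validation batch):

import torch

model.train()  # enables dropout and batch-norm running-stat updates
# ... training loop ...

model.eval()   # disables dropout, freezes batch-norm stats
with torch.no_grad():  # also skip gradient bookkeeping during evaluation
    val_outputs = model(val_inputs)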
Tools for Deep Inspection¶
1. TensorBoard¶
from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter('runs/experiment_1')

# Log scalars
writer.add_scalar('Loss/train', train_loss, epoch)
writer.add_scalar('Accuracy/train', train_acc, epoch)

# Log histograms of weights
for name, param in model.named_parameters():
    writer.add_histogram(name, param, epoch)

writer.close()
View with: tensorboard --logdir=runs
2. Weights & Biases (wandb)¶
import wandb

wandb.init(project="debugging-demo")
wandb.watch(model, log="all")  # Log gradients and parameters

for epoch in range(epochs):
    # ... training ...
    wandb.log({"loss": loss, "accuracy": acc, "epoch": epoch})
3. Print Intermediate Activations¶
import torch

def register_hooks(model):
    activations = {}
    def get_activation(name):
        def hook(module, input, output):
            activations[name] = output.detach()
        return hook
    for name, layer in model.named_modules():
        if isinstance(layer, torch.nn.ReLU):
            layer.register_forward_hook(get_activation(name))
    return activations

activations = register_hooks(model)
outputs = model(inputs)

# Check if activations are dead (all zeros)
for name, act in activations.items():
    dead = (act == 0).float().mean().item()  # fraction of exactly-zero outputs
    print(f"{name}: mean = {act.mean():.4f}, std = {act.std():.4f}, dead = {dead:.2%}")
Practice¶
- Create a deliberately broken model (e.g., wrong loss function) and debug it using the ladder.
- Train a model that overfits—diagnose it by comparing train vs. validation curves.
- Implement gradient norm tracking and identify if/when gradients vanish or explode.
- Use TensorBoard to visualize weight histograms and spot dead neurons.
Next Steps¶
After mastering debugging:
- Read Optimizers and Regularization to fix overfitting
- Read Learning Rate Schedulers to stabilize training
- Try PyTorch Training Loops for clean loop patterns
Summary¶
The debugging ladder:
1. Overfit one batch (sanity check)
2. Check gradient flow
3. Verify data pipeline
4. Check model output shapes
5. Monitor learning rate
6. Visualize training curves
Most common issues:
- Wrong loss function
- Bad learning rate
- Data shape mismatch
- Vanishing/exploding gradients
Golden rule: Fix the data and architecture before tuning hyperparameters.