Batch Normalization and Initialization¶
What This Is¶
Batch normalization stabilizes training by normalizing layer inputs, and initialization sets the starting point for learning. Together they determine whether a network trains smoothly or stalls before it starts.
When You Use It¶
- training deep networks where gradients vanish or explode
- speeding up convergence so you can use higher learning rates
- stabilizing training when adding more layers
- debugging a network that trains on one configuration but fails on another
Tooling¶
- `nn.BatchNorm1d` and `nn.BatchNorm2d` for batch normalization
- `nn.LayerNorm` for sequence or transformer models
- `nn.init.xavier_uniform_` and `nn.init.xavier_normal_` for sigmoid/tanh activations
- `nn.init.kaiming_uniform_` and `nn.init.kaiming_normal_` for ReLU activations
- `nn.init.zeros_` and `nn.init.ones_` for bias and gain terms
How Batch Norm Works¶
Batch normalization normalizes activations across the batch dimension during training, then uses running statistics during evaluation.
import torch.nn as nn
model = nn.Sequential(
    nn.Linear(128, 64),
    nn.BatchNorm1d(64),
    nn.ReLU(),
    nn.Linear(64, 10),
)
Key behaviors:
- in `model.train()` mode, it computes batch statistics and updates the running mean/variance
- in `model.eval()` mode, it uses the stored running statistics
- if you forget `model.eval()`, validation becomes noisy because it uses batch-level statistics
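A quick way to see this, reusing the model defined above (the batch of random inputs is just for illustration):
import torch

bn = model[1]                          # the BatchNorm1d(64) layer from the Sequential above
x = torch.randn(32, 128)

model.train()
model(x)                               # train-mode forward pass updates the running statistics
print(bn.running_mean[:3])             # no longer all zeros

model.eval()
model(x)                               # eval mode reads the running statistics without changing them
print(bn.running_mean[:3])             # unchanged by the eval-mode forward pass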
When To Use Which Normalization¶
| Layer | Best For | Notes |
|---|---|---|
| `BatchNorm1d` | MLPs with a batch dimension | needs batch size > 1 |
| `BatchNorm2d` | CNNs | normalizes per channel |
| `LayerNorm` | transformers, sequence models | normalizes per sample, batch-size independent |
| `GroupNorm` | small-batch training | compromise between batch and layer norm |
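For example, `LayerNorm` computes its statistics over the feature dimension of each sample, so batch size does not matter. A small sketch with arbitrary shapes:
import torch
import torch.nn as nn

x = torch.randn(4, 10, 64)             # (batch, sequence length, features)
ln = nn.LayerNorm(64)                  # normalizes over the last dimension, per sample and position
out = ln(x)
print(out.mean(-1)[0, 0].item())       # roughly 0 for every position
print(ln(x[:1]).shape)                 # batch size 1 works the same way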
Initialization Matters¶
The default PyTorch initialization works for many architectures, but when training stalls or gradients vanish, explicit initialization can help.
Xavier (Glorot) — for sigmoid/tanh¶
layer = nn.Linear(128, 64)   # any linear layer; sizes here are just an example
nn.init.xavier_uniform_(layer.weight)
nn.init.zeros_(layer.bias)
Kaiming (He) — for ReLU¶
layer = nn.Linear(128, 64)   # any linear layer, as in the Xavier example above
nn.init.kaiming_normal_(layer.weight, nonlinearity="relu")
nn.init.zeros_(layer.bias)
Custom initialization loop¶
def init_weights(module):
    if isinstance(module, nn.Linear):
        nn.init.kaiming_normal_(module.weight, nonlinearity="relu")
        if module.bias is not None:
            nn.init.zeros_(module.bias)

model.apply(init_weights)   # applies init_weights to every submodule recursively
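After `model.apply(init_weights)` runs, you can sanity-check the result: Kaiming-normal init for ReLU targets a standard deviation of sqrt(2 / fan_in), and the empirical spread of each weight matrix should land close to that. A quick check, assuming the `model` defined earlier:
import math

first_linear = model[0]                          # nn.Linear(128, 64) from the model above
fan_in = first_linear.weight.shape[1]            # 128 input features
print(first_linear.weight.std().item())          # empirical std after initialization
print(math.sqrt(2.0 / fan_in))                   # Kaiming-normal target std for ReLU, 0.125 here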
Failure Pattern¶
Training a deep network with default initialization and no normalization, then blaming the architecture when gradients collapse after a few layers.
Another trap: using batch normalization with a batch size of 1 during training, which makes the batch statistics meaningless.
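A minimal reproduction of the batch-size-1 trap (the exact error message depends on your PyTorch version):
import torch
import torch.nn as nn

bn = nn.BatchNorm1d(8)

bn.train()
try:
    bn(torch.randn(1, 8))              # a single sample: the batch variance is undefined
except ValueError as err:
    print(err)                         # PyTorch refuses to compute batch statistics here

bn.eval()
print(bn(torch.randn(1, 8)).shape)     # fine in eval mode, which uses running statistics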
Common Mistakes¶
- forgetting to switch to `eval()` mode, making batch norm use per-batch statistics during validation
- using `BatchNorm` in a model that processes single samples at inference time without switching to eval mode
- initializing all weights to zero, which makes every neuron learn the same thing
- mixing Xavier init with ReLU activations, where Kaiming is more appropriate
Practice¶
- Add batch normalization to a simple MLP and compare convergence speed.
- Swap `BatchNorm1d` for `LayerNorm` and explain when each is preferable.
- Initialize a network with Kaiming init versus the default and compare the first-epoch loss.
- Show what happens when you forget `model.eval()` with batch norm during validation.
- Apply a custom initialization function and inspect the weight distributions.
Runnable Example¶
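A minimal end-to-end sketch that ties the pieces together: Kaiming initialization, batch norm in the training loop, and the switch to eval mode for validation. The data is synthetic and the layer sizes, optimizer, and learning rate are arbitrary choices for illustration.
import torch
import torch.nn as nn

def init_weights(module):
    if isinstance(module, nn.Linear):
        nn.init.kaiming_normal_(module.weight, nonlinearity="relu")
        if module.bias is not None:
            nn.init.zeros_(module.bias)

model = nn.Sequential(
    nn.Linear(128, 64),
    nn.BatchNorm1d(64),
    nn.ReLU(),
    nn.Linear(64, 10),
)
model.apply(init_weights)

optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loss_fn = nn.CrossEntropyLoss()

# synthetic classification data: 256 samples, 128 features, 10 classes
x = torch.randn(256, 128)
y = torch.randint(0, 10, (256,))

model.train()                          # batch norm computes batch statistics
for epoch in range(5):
    optimizer.zero_grad()
    loss = loss_fn(model(x), y)
    loss.backward()
    optimizer.step()
    print(f"epoch {epoch}: loss {loss.item():.4f}")

model.eval()                           # batch norm switches to running statistics
with torch.no_grad():
    print("single-sample logits:", model(x[:1]).shape)   # batch size 1 is fine in eval mode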
Longer Connection¶
Continue with PyTorch Training Loops for the full loop structure, and Optimizers and Regularization for regularization techniques that complement batch normalization.