
Batch Normalization and Initialization

What This Is

Batch normalization stabilizes training by normalizing layer inputs, and initialization sets the starting point for learning. Together they determine whether a network trains smoothly or stalls before it starts.

When You Use It

  • training deep networks where gradients vanish or explode
  • speeding up convergence so you can use higher learning rates
  • stabilizing training when adding more layers
  • debugging a network that trains on one configuration but fails on another

Tooling

  • nn.BatchNorm1d and nn.BatchNorm2d for batch normalization
  • nn.LayerNorm for sequence or transformer models
  • nn.init.xavier_uniform_ and nn.init.xavier_normal_ for sigmoid/tanh activations
  • nn.init.kaiming_uniform_ and nn.init.kaiming_normal_ for ReLU activations
  • nn.init.zeros_ and nn.init.ones_ for bias and gain terms

How Batch Norm Works

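During training, batch normalization normalizes each feature (or channel) using the mean and variance computed across the batch dimension, then applies a learnable scale and shift. During evaluation it uses the running statistics it accumulated during training instead of the current batch's statistics.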

import torch.nn as nn

model = nn.Sequential(
    nn.Linear(128, 64),
    nn.BatchNorm1d(64),
    nn.ReLU(),
    nn.Linear(64, 10),
)
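The following is a minimal check of the training-mode computation. It compares nn.BatchNorm1d against the textbook formula: subtract the per-feature batch mean, divide by the square root of the biased batch variance plus `eps`, then scale by `weight` (gamma) and shift by `bias` (beta). The helper `manual_batch_norm` and the input shapes are hypothetical, chosen only for illustration.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

def manual_batch_norm(x, bn):
    # Hypothetical helper: per-feature statistics over the batch dimension (dim 0)
    mean = x.mean(dim=0)
    var = x.var(dim=0, unbiased=False)  # normalization uses the biased (population) variance
    x_hat = (x - mean) / torch.sqrt(var + bn.eps)
    # Learnable affine step: scale by gamma (bn.weight), shift by beta (bn.bias)
    return x_hat * bn.weight + bn.bias

bn = nn.BatchNorm1d(64)
bn.train()
x = torch.randn(32, 64) * 3 + 5  # 32 samples, 64 features, deliberately shifted and scaled

with torch.no_grad():
    expected = manual_batch_norm(x, bn)
    actual = bn(x)

print(torch.allclose(actual, expected, atol=1e-4))
print(actual.mean(dim=0).abs().max().item(), actual.std(dim=0, unbiased=False).mean().item())
```

Because `weight` starts at 1 and `bias` at 0, each output feature begins with roughly zero mean and unit variance. Training can then learn a different scale and offset if that helps.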

Key behaviors:

  • in model.train() mode, batch norm layers normalize with the current batch's mean and variance and update their running mean/variance
  • in model.eval() mode, they normalize with the stored running statistics
  • if you forget model.eval(), validation results depend on how the validation data is batched, and the running statistics get updated with validation data
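The sketch below shows the train/eval difference in isolation. It uses PyTorch's default `momentum=0.1` and the update rule `running = (1 - momentum) * running + momentum * batch_stat`. The tiny layer size and the data are illustrative assumptions.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
x = torch.randn(16, 4) + 2.0  # 16 samples, 4 features, mean shifted to about 2

# Fresh layer: running_mean starts at 0, running_var at 1, momentum=0.1 by default
bn = nn.BatchNorm1d(4)
bn.train()
with torch.no_grad():
    bn(x)

# One training step blends the old estimate with this batch's statistics
expected_mean = 0.9 * 0.0 + 0.1 * x.mean(dim=0)
expected_var = 0.9 * 1.0 + 0.1 * x.var(dim=0, unbiased=True)  # running_var tracks the unbiased variance
mean_ok = torch.allclose(bn.running_mean, expected_mean, atol=1e-5)
var_ok = torch.allclose(bn.running_var, expected_var, atol=1e-5)

# Eval mode: stored statistics are used, so a sample's output ignores the rest of the batch
bn.eval()
with torch.no_grad():
    eval_consistent = torch.allclose(bn(x)[:8], bn(x[:8]), atol=1e-6)

# Train mode: the same samples normalize differently when batched differently
bn.train()
with torch.no_grad():
    train_consistent = torch.allclose(bn(x)[:8], bn(x[:8]), atol=1e-6)

print(mean_ok, var_ok, eval_consistent, train_consistent)
```

The last comparison is the noisy-validation symptom. Without model.eval(), the same sample gets a different output depending on its batch.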

When To Use Which Normalization

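| Layer       | Best For                         | Notes                                                          |
|-------------|----------------------------------|----------------------------------------------------------------|
| BatchNorm1d | MLPs with a batch dimension      | needs batch size > 1 in training mode                          |
| BatchNorm2d | CNNs                             | normalizes per channel over the batch and spatial dimensions   |
| LayerNorm   | transformers, sequence models    | normalizes each sample over its features; batch-size independent |
| GroupNorm   | training with small batches      | normalizes groups of channels per sample; a compromise between batch and layer norm |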
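To see the batch-size behavior from the table, this sketch feeds a single sample to each layer type. The sizes (64 features, 8 groups, an 8x8 feature map) are arbitrary assumptions.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
batch = torch.randn(16, 64)

# BatchNorm1d in training mode cannot normalize a single (1, 64) sample
bn = nn.BatchNorm1d(64).train()
try:
    bn(batch[:1])
    bn_error = None
except ValueError as err:
    bn_error = str(err)

# LayerNorm normalizes each sample over its own 64 features
ln = nn.LayerNorm(64)
with torch.no_grad():
    ln_consistent = torch.allclose(ln(batch[:1])[0], ln(batch)[0], atol=1e-6)

# GroupNorm splits 64 channels into 8 groups and normalizes each group per sample
gn = nn.GroupNorm(num_groups=8, num_channels=64)
with torch.no_grad():
    gn_out = gn(torch.randn(1, 64, 8, 8))  # one CNN feature map: (N, C, H, W)

print(bn_error)
print(ln_consistent, tuple(gn_out.shape))
```

LayerNorm gives the same output for a sample whether it arrives alone or inside a batch. This is why LayerNorm and GroupNorm suit small or variable batch sizes.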

Initialization Matters

The default PyTorch initialization works for many architectures, but when training stalls or gradients vanish, explicit initialization can help.

Xavier (Glorot) — for sigmoid/tanh

import torch.nn as nn
layer = nn.Linear(128, 64)  # any layer followed by sigmoid or tanh
nn.init.xavier_uniform_(layer.weight)
nn.init.zeros_(layer.bias)
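Xavier uniform draws weights from U(-a, a) with a = gain * sqrt(6 / (fan_in + fan_out)). `xavier_uniform_` uses `gain=1.0` unless you pass one. The following sketch assumes a tanh activation and uses `nn.init.calculate_gain("tanh")` (5/3), then checks the bound. The layer sizes are illustrative.

```python
import math
import torch
import torch.nn as nn

torch.manual_seed(0)
layer = nn.Linear(128, 64)  # fan_in=128, fan_out=64
gain = nn.init.calculate_gain("tanh")  # recommended gain for tanh: 5/3
nn.init.xavier_uniform_(layer.weight, gain=gain)
nn.init.zeros_(layer.bias)

fan_in, fan_out = layer.in_features, layer.out_features
bound = gain * math.sqrt(6.0 / (fan_in + fan_out))
max_abs = layer.weight.abs().max().item()
std = layer.weight.std().item()  # a uniform U(-a, a) has std a / sqrt(3)
print(f"bound={bound:.4f}, largest |w|={max_abs:.4f}, std={std:.4f}")
```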

Kaiming (He) — for ReLU

import torch.nn as nn
layer = nn.Linear(128, 64)  # any layer followed by ReLU
nn.init.kaiming_normal_(layer.weight, nonlinearity="relu")
nn.init.zeros_(layer.bias)
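With the default `mode="fan_in"`, Kaiming normal uses std = gain / sqrt(fan_in), and the ReLU gain is sqrt(2). This doubles the variance to make up for ReLU zeroing roughly half of its inputs. The sketch below checks the empirical weight std against that target. The layer sizes are illustrative.

```python
import math
import torch
import torch.nn as nn

torch.manual_seed(0)
layer = nn.Linear(512, 256)  # weight shape (256, 512), so fan_in = 512
nn.init.kaiming_normal_(layer.weight, nonlinearity="relu")  # mode="fan_in" by default
nn.init.zeros_(layer.bias)

gain = nn.init.calculate_gain("relu")  # sqrt(2)
target_std = gain / math.sqrt(layer.in_features)
empirical_std = layer.weight.std().item()
print(f"target={target_std:.4f}, empirical={empirical_std:.4f}")
```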

Custom initialization loop

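import torch.nn as nn

def init_weights(module):
    # Kaiming init for Linear weights (assumes ReLU follows), zero biases
    if isinstance(module, nn.Linear):
        nn.init.kaiming_normal_(module.weight, nonlinearity="relu")
        if module.bias is not None:
            nn.init.zeros_(module.bias)

# apply() calls init_weights on every submodule recursively, then on model itself
model.apply(init_weights)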
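To inspect the resulting weight distributions, the sketch below rebuilds the earlier model and applies the function. It then compares each Linear weight's std with sqrt(2 / fan_in). The BatchNorm branch, which resets gain to 1 and shift to 0 with `ones_` and `zeros_`, is an addition for illustration. It matches PyTorch's defaults and is not part of the original function.

```python
import math
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Sequential(
    nn.Linear(128, 64),
    nn.BatchNorm1d(64),
    nn.ReLU(),
    nn.Linear(64, 10),
)

def init_weights(module):
    if isinstance(module, nn.Linear):
        nn.init.kaiming_normal_(module.weight, nonlinearity="relu")
        if module.bias is not None:
            nn.init.zeros_(module.bias)
    elif isinstance(module, nn.BatchNorm1d):
        # Added branch (assumption): identity affine transform, same as PyTorch's default
        nn.init.ones_(module.weight)
        nn.init.zeros_(module.bias)

model.apply(init_weights)

stats = {}
for name, param in model.named_parameters():
    if name.endswith("weight") and param.dim() == 2:  # Linear weight matrices only
        fan_in = param.shape[1]
        stats[name] = (param.std().item(), math.sqrt(2.0 / fan_in))
        print(f"{name}: std={stats[name][0]:.4f} (target {stats[name][1]:.4f})")
```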

Failure Pattern

Training a deep network with default initialization and no normalization, then blaming the architecture when gradients collapse after a few layers.
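The failure pattern can be measured directly. This sketch pushes random inputs through 20 bias-free Linear + ReLU layers of width 256 and measures the root-mean-square activation at the end. It compares PyTorch's default nn.Linear weight init (uniform in ±1/sqrt(fan_in)) with Kaiming normal init. Under these assumptions the default init shrinks the signal by roughly 1/sqrt(6) per layer. The depth, the width, dropping the bias, and the helper `activation_rms` are all illustrative choices.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

def activation_rms(depth, width, kaiming):
    # Hypothetical helper: RMS activation after `depth` Linear+ReLU layers
    x = torch.randn(512, width)
    with torch.no_grad():
        for _ in range(depth):
            layer = nn.Linear(width, width, bias=False)  # bias-free to isolate weight scaling
            if kaiming:
                nn.init.kaiming_normal_(layer.weight, nonlinearity="relu")
            # otherwise keep PyTorch's default Linear init
            x = torch.relu(layer(x))
    return x.pow(2).mean().sqrt().item()

default_rms = activation_rms(depth=20, width=256, kaiming=False)
kaiming_rms = activation_rms(depth=20, width=256, kaiming=True)
print(f"default init: {default_rms:.2e}, kaiming init: {kaiming_rms:.2e}")
```

The collapse in the default case comes from initialization scale, not from the architecture. Kaiming init keeps the forward signal near its input scale, and batch normalization would re-normalize it at every layer.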

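Another trap: using batch normalization with a batch size of 1 during training. With a single sample there is only one value per feature, so the batch variance is meaningless. For nn.BatchNorm1d on a (1, features) input, PyTorch raises a ValueError instead of training on it. Very small batches in general give noisy statistics.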
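A common way to hit the batch-size-1 case by accident is a dataset whose size leaves one sample over in the last batch. The sketch below assumes a toy dataset of 33 random samples. It shows the crash, then shows `drop_last=True` on the DataLoader as one possible mitigation. Whether dropping a few samples per epoch is acceptable depends on your data. The helper `run_epoch` is hypothetical.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)
model = nn.Sequential(nn.Linear(128, 64), nn.BatchNorm1d(64), nn.ReLU(), nn.Linear(64, 10))
model.train()

dataset = TensorDataset(torch.randn(33, 128))  # 33 = 32 + 1 leaves a final batch of size 1

def run_epoch(loader):
    # Hypothetical helper: forward every batch and record its size
    sizes = []
    for (xb,) in loader:
        model(xb)  # raises ValueError if BatchNorm1d receives a single sample
        sizes.append(xb.shape[0])
    return sizes

try:
    run_epoch(DataLoader(dataset, batch_size=32))
    crashed = False
except ValueError:
    crashed = True

sizes = run_epoch(DataLoader(dataset, batch_size=32, drop_last=True))
print(crashed, sizes)
```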

Common Mistakes

  • forgetting to switch to eval() mode, making batch norm use per-batch statistics during validation
  • using BatchNorm in a model that processes single samples at inference time without switching to eval mode
  • initializing all weights to zero, which makes every neuron learn the same thing
  • mixing Xavier init with ReLU activations, where Kaiming is more appropriate
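The zero-initialization mistake comes from symmetry, and the sketch below makes it visible on a tiny MLP. Every parameter is filled with one constant through `nn.init.constant_`; `zeros_` is the special case of 0. With all zeros, the first layer receives no gradient at all. With any shared constant, every hidden unit receives an identical gradient row, so the units never become different. The layer sizes and the helper `first_layer_grad` are hypothetical.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

def first_layer_grad(fill_value):
    # Hypothetical helper: gradient of the first Linear layer under constant init
    model = nn.Sequential(nn.Linear(8, 4), nn.ReLU(), nn.Linear(4, 1))
    for p in model.parameters():
        nn.init.constant_(p, fill_value)  # every weight and bias gets the same value
    x, y = torch.randn(16, 8), torch.randn(16, 1)
    loss = nn.functional.mse_loss(model(x), y)
    loss.backward()
    return model[0].weight.grad  # shape (4, 8): one row per hidden unit

zero_grad = first_layer_grad(0.0)
const_grad = first_layer_grad(0.5)
rows_identical = torch.allclose(const_grad, const_grad[0].expand_as(const_grad))
print(zero_grad.abs().max().item(), rows_identical)
```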

Practice

  1. Add batch normalization to a simple MLP and compare convergence speed.
  2. Swap BatchNorm1d for LayerNorm and explain when each is preferable.
  3. Initialize a network with Kaiming init versus default and compare the first-epoch loss.
  4. Show what happens when you forget model.eval() with batch norm during validation.
  5. Apply a custom initialization function and inspect the weight distributions.

Runnable Example

This example stays local-only for now because the browser runner does not yet include PyTorch.
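For running locally, a minimal script along these lines combines the pieces from this page: a BatchNorm MLP, Kaiming initialization, model.train() for training, and model.eval() for validation. The synthetic dataset (labels from a random linear rule), layer sizes, optimizer, learning rate, and epoch count are illustrative assumptions, not tuned values.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Synthetic data (assumption): labels come from a random linear rule over 20 features
X = torch.randn(1024, 20)
true_w = torch.randn(20)
y = (X @ true_w > 0).long()
X_train, y_train = X[:768], y[:768]
X_val, y_val = X[768:], y[768:]

model = nn.Sequential(
    nn.Linear(20, 64),
    nn.BatchNorm1d(64),
    nn.ReLU(),
    nn.Linear(64, 2),
)

def init_weights(module):
    if isinstance(module, nn.Linear):
        nn.init.kaiming_normal_(module.weight, nonlinearity="relu")
        nn.init.zeros_(module.bias)

model.apply(init_weights)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
loss_fn = nn.CrossEntropyLoss()

batch_size = 64
n_batches = len(X_train) // batch_size  # 768 / 64 = 12 full batches, never a batch of 1
epoch_losses = []
for epoch in range(20):
    model.train()  # batch statistics, running statistics get updated
    perm = torch.randperm(len(X_train))
    total = 0.0
    for i in range(n_batches):
        idx = perm[i * batch_size:(i + 1) * batch_size]
        loss = loss_fn(model(X_train[idx]), y_train[idx])
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total += loss.item()
    epoch_losses.append(total / n_batches)

model.eval()  # switch to running statistics before validating
with torch.no_grad():
    val_acc = (model(X_val).argmax(dim=1) == y_val).float().mean().item()
print(f"epoch 1 loss {epoch_losses[0]:.3f}, epoch 20 loss {epoch_losses[-1]:.3f}, val acc {val_acc:.3f}")
```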

Longer Connection

Continue with PyTorch Training Loops for the full loop structure, and Optimizers and Regularization for regularization techniques that complement batch normalization.