
Diffusion Models

What This Is

A diffusion model learns to generate data by training a neural network to reverse a gradual noising process. You take a clean sample (an image, an audio clip, a molecule) and progressively add Gaussian noise until it is pure noise. Then you train a network to undo one step of that corruption. Generation is sampling pure noise and running the reverse process to refine it into a clean sample.

The academy's angle is not derivations. It is three things a student must get right:

  • what the forward and reverse processes actually are, and why "just predict the noise" works
  • how sampling differs from training — the same network is used differently at inference
  • where diffusion models fail in practice (classifier-free guidance tradeoffs, evaluation blind spots, slow sampling)

When You Use It

  • you need a generative model for continuous data (images, audio, latent features)
  • you want a model that trains more stably than a GAN
  • you want controllable generation (text-to-image, inpainting, super-resolution)
  • you can afford slow inference or you can invest in distillation / few-step samplers

Do Not Use It When

  • the task is discrete (text, code) — diffusion on text is an active research frontier, and autoregressive decoding still dominates
  • you need fast inference on a tight budget and cannot distill
  • a plain discriminative model or a retrieval approach solves your actual problem (generation is often a false frame)

The Forward Process

Start with a clean sample x_0. Define a schedule β_1, ..., β_T of tiny noise increments. The forward process is a chain of Gaussians:

q(x_t | x_{t-1}) = N(x_t; sqrt(1 - β_t) x_{t-1},  β_t I)

A useful reparameterization: let α_t = 1 - β_t and ᾱ_t = Π_{s ≤ t} α_s. Then you can sample x_t directly from x_0 in one step:

x_t = sqrt(ᾱ_t) x_0 + sqrt(1 - ᾱ_t) ε,    ε ~ N(0, I)

This is load-bearing. It means you never simulate a long forward chain during training — you sample a single t and get the noisy version in one line.
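
A minimal plain-Python sketch (no PyTorch needed) that checks the one-step formula against the step-by-step chain. It assumes the linear schedule from the DDPM paper, with β rising from 1e-4 to 0.02 over T = 1000 steps. It tracks only the two numbers that define the marginal q(x_t | x_0): the coefficient on x_0 and the noise variance. Indices are 0-based, matching the training code below.

```python
import math

def linear_betas(T=1000, beta_start=1e-4, beta_end=0.02):
    """Linear schedule (assumed DDPM defaults): beta_start .. beta_end over T steps."""
    return [beta_start + (beta_end - beta_start) * i / (T - 1) for i in range(T)]

def alpha_bars(betas):
    """Cumulative products: alpha_bar[t] = prod_{s <= t} (1 - betas[s])."""
    out, prod = [], 1.0
    for b in betas:
        prod *= 1.0 - b
        out.append(prod)
    return out

def chain_marginal(betas, t):
    """Apply q(x_s | x_{s-1}) for s = 0..t; return (coef, var) with x_t ~ N(coef * x_0, var)."""
    coef, var = 1.0, 0.0
    for b in betas[: t + 1]:
        # x_s = sqrt(1 - b) * x_{s-1} + sqrt(b) * eps: the mean scales, the variance scales then grows by b
        coef *= math.sqrt(1.0 - b)
        var = (1.0 - b) * var + b
    return coef, var

betas = linear_betas()
a_bar = alpha_bars(betas)
for t in (0, 499, 999):
    coef, var = chain_marginal(betas, t)
    # the simulated chain should match the closed form: coef = sqrt(alpha_bar_t), var = 1 - alpha_bar_t
    print(f"t={t}: coef={coef:.6f} sqrt(a_bar)={math.sqrt(a_bar[t]):.6f} "
          f"var={var:.6f} 1-a_bar={1 - a_bar[t]:.6f}")
```

At the last step the remaining signal coefficient sqrt(ᾱ_T) is below 0.01 for this schedule, so x_T is essentially pure noise. That is what lets sampling start from N(0, I).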

The Reverse Process

The reverse process is a learned Markov chain whose transitions each undo one forward step:

p_θ(x_{t-1} | x_t) = N(x_{t-1}; μ_θ(x_t, t),  Σ_θ(x_t, t))

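The network ε_θ(x_t, t) is trained to predict the noise ε that was added. The reverse-step mean then follows by algebra on the forward-process coefficients:

μ_θ(x_t, t) = (1 / sqrt(α_t)) (x_t - (β_t / sqrt(1 - ᾱ_t)) ε_θ(x_t, t))

In DDPM the variance is not learned. Σ_θ is fixed to σ_t² I, and σ_t² = β_t is one standard choice (the one used by the sampler below).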
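
A quick numerical check, in plain Python, that the ε-parameterized mean agrees with the exact posterior mean of the forward process, q(x_{t-1} | x_t, x_0), when the noise prediction is perfect. The linear schedule is an assumption carried over for illustration. Because the identity is purely algebraic, any schedule would pass the same check.

```python
import math

# Assumed linear schedule, 0-based indices: betas[t] is the noise added at step t.
T = 1000
betas = [1e-4 + (0.02 - 1e-4) * i / (T - 1) for i in range(T)]

def alpha_bar(t):
    """prod_{s <= t} (1 - betas[s]); alpha_bar(-1) = 1.0 corresponds to clean data."""
    prod = 1.0
    for b in betas[: t + 1]:
        prod *= 1.0 - b
    return prod

def mean_from_eps(x_t, eps, t):
    """Reverse mean from a noise prediction: (x_t - beta_t / sqrt(1 - alpha_bar_t) * eps) / sqrt(alpha_t)."""
    b = betas[t]
    return (x_t - b / math.sqrt(1.0 - alpha_bar(t)) * eps) / math.sqrt(1.0 - b)

def posterior_mean(x_t, x0, t):
    """Mean of the exact posterior q(x_{t-1} | x_t, x_0), valid for t >= 1."""
    b, ab, ab_prev = betas[t], alpha_bar(t), alpha_bar(t - 1)
    c0 = math.sqrt(ab_prev) * b / (1.0 - ab)                 # weight on x_0
    ct = math.sqrt(1.0 - b) * (1.0 - ab_prev) / (1.0 - ab)   # weight on x_t
    return c0 * x0 + ct * x_t

x0, eps = 0.8, -1.3
for t in (1, 100, 500, 999):
    x_t = math.sqrt(alpha_bar(t)) * x0 + math.sqrt(1.0 - alpha_bar(t)) * eps
    print(t, mean_from_eps(x_t, eps, t), posterior_mean(x_t, x0, t))
```

Through this identity, training the network to predict ε is the same as training it to predict the posterior mean that each reverse step needs.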

The Training Loss (Simplified)

The DDPM paper derives a variational lower bound, but in practice training uses an unweighted "simple" form of it, which the paper found gives better samples:

L_simple = E_{t, x_0, ε} ||  ε  -  ε_θ( sqrt(ᾱ_t) x_0 + sqrt(1 - ᾱ_t) ε,  t ) ||²

In plain English: pick a random timestep t, take a clean sample, add noise of the right magnitude, and ask the network to predict the noise. The loss is a mean-squared error.

Minimal PyTorch training step:

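import torch
import torch.nn.functional as F

def train_step(x0, model, betas, optimizer):
    # x0: clean batch, e.g. (B, C, H, W); betas: 1-D tensor of length T
    B = x0.size(0)
    T = len(betas)
    t = torch.randint(0, T, (B,), device=x0.device)  # one random timestep per sample
    alphas = 1.0 - betas.to(x0.device)  # keep the schedule on the same device as t
    alpha_bar = torch.cumprod(alphas, dim=0)
    a_bar_t = alpha_bar[t].view(B, *([1] * (x0.dim() - 1)))  # broadcast over non-batch dims

    # closed-form forward process: x_t = sqrt(ᾱ_t) x0 + sqrt(1 - ᾱ_t) ε
    eps = torch.randn_like(x0)
    x_t = a_bar_t.sqrt() * x0 + (1 - a_bar_t).sqrt() * eps

    eps_pred = model(x_t, t)
    loss = F.mse_loss(eps_pred, eps)  # L_simple
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()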

Points that matter:

  • the same network is evaluated at every timestep, so the timestep embedding (usually sinusoidal + MLP) must reliably tell it which noise level it is denoising
  • the schedule choice (linear vs. cosine vs. sigmoid) materially affects quality, especially at the low-noise tail
  • a U-Net backbone is the historical default for images; much recent work uses DiTs (Diffusion Transformers, which apply a transformer to image or latent patches)
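
A minimal plain-Python sketch of one common sinusoidal timestep embedding. The exact frequency spacing and the `max_period` value of 10000 vary between implementations and are assumptions here. An even `dim` is assumed.

```python
import math

def timestep_embedding(t, dim=128, max_period=10000.0):
    """Sinusoidal embedding of an integer timestep t (one common variant; dim assumed even).

    Frequencies are spaced geometrically from 1 down toward 1 / max_period; the
    first half of the vector holds sin(t * f), the second half cos(t * f).
    """
    half = dim // 2
    freqs = [math.exp(-math.log(max_period) * i / half) for i in range(half)]
    return [math.sin(t * f) for f in freqs] + [math.cos(t * f) for f in freqs]

emb = timestep_embedding(500)
print(len(emb), emb[:3])
```

In a real U-Net or DiT this vector goes through a small MLP and is added to, or used to modulate, each block's features. If that path is broken, the network sees no signal telling it which noise level it is working at.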

Sampling (DDPM)

At inference, start from x_T ~ N(0, I) and iterate:

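import torch

@torch.no_grad()
def ddpm_sample(model, shape, betas):
    # betas: 1-D tensor of length T, the SAME schedule used in training
    T = len(betas)
    x = torch.randn(shape, device=betas.device)  # x_T ~ N(0, I)
    alphas = 1.0 - betas
    alpha_bar = torch.cumprod(alphas, dim=0)
    for t in reversed(range(T)):
        a_t = alphas[t]
        a_bar_t = alpha_bar[t]
        t_batch = torch.full((shape[0],), t, device=x.device, dtype=torch.long)
        eps_pred = model(x, t_batch)
        # μ_θ = (x_t - β_t / sqrt(1 - ᾱ_t) · ε_θ) / sqrt(α_t), with β_t = 1 - α_t
        mean = (x - (1 - a_t) / (1 - a_bar_t).sqrt() * eps_pred) / a_t.sqrt()
        if t > 0:
            # add noise with σ_t² = β_t; the final step returns the mean
            x = mean + betas[t].sqrt() * torch.randn_like(x)
        else:
            x = mean
    return x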

DDPM sampling costs T network evaluations (typically T = 1000), which is why naive diffusion is slow.

Fast Samplers

A large chunk of real-world diffusion engineering is making sampling cheap:

  • DDIM — same trained model, deterministic sampler, usable in 20–50 steps with near-DDPM quality
  • PNDM / DPM-Solver / DPM-Solver++ — higher-order ODE solvers that get to ~20 steps with little quality loss
  • Consistency models — distill a pretrained diffusion model into a one- or two-step generator
  • Latent diffusion (Stable Diffusion) — run the diffusion in the compressed latent space of a pretrained autoencoder instead of pixel space; compute drops sharply at comparable or better quality, because the latent is much smaller (Stable Diffusion encodes a 512×512×3 image as a 64×64×4 latent, 48× fewer values)
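
A scalar sketch of deterministic DDIM (η = 0) in plain Python. It assumes the linear schedule and an evenly spaced 20-step subsequence of the 1000 training timesteps. Each step predicts x_0 from the noise estimate, then re-noises that prediction to the next, less noisy level along the same ε. The `oracle_eps` model is hypothetical. Because it knows the true x_0, the sampler should land exactly on it, which makes the update rule easy to check. With tensors the same arithmetic applies elementwise.

```python
import math

T = 1000
betas = [1e-4 + (0.02 - 1e-4) * i / (T - 1) for i in range(T)]  # assumed linear schedule
a_bar, prod = [], 1.0
for b in betas:
    prod *= 1.0 - b
    a_bar.append(prod)

def ddim_timesteps(T, num_steps):
    """Evenly spaced timesteps, noisiest first: 950, 900, ..., 0 for T=1000, 20 steps."""
    return list(range(0, T, T // num_steps))[::-1]

def ddim_sample(eps_model, x, steps):
    """Deterministic DDIM (eta = 0) over the given timestep subsequence."""
    for i, t in enumerate(steps):
        ab_t = a_bar[t]
        # alpha_bar of the next, less noisy timestep; 1.0 means "clean" after the last step
        ab_prev = a_bar[steps[i + 1]] if i + 1 < len(steps) else 1.0
        eps = eps_model(x, t)
        x0_pred = (x - math.sqrt(1.0 - ab_t) * eps) / math.sqrt(ab_t)
        x = math.sqrt(ab_prev) * x0_pred + math.sqrt(1.0 - ab_prev) * eps
    return x

# Hypothetical oracle that knows the clean value; used only to test the update rule.
x0_true = 0.37
def oracle_eps(x_t, t):
    return (x_t - math.sqrt(a_bar[t]) * x0_true) / math.sqrt(1.0 - a_bar[t])

steps = ddim_timesteps(T, 20)
x_start = math.sqrt(a_bar[steps[0]]) * x0_true + math.sqrt(1.0 - a_bar[steps[0]]) * 1.5
print(len(steps), ddim_sample(oracle_eps, x_start, steps))
```

With a real network the ε estimate is imperfect, so each skipped stretch of timesteps adds error, and quality degrades as the step count shrinks. Note that the trained ε-predictor is reused unchanged; only the update rule differs from DDPM.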

Pick your sampler after training. Training and sampling are decoupled for diffusion; this is one of the reasons diffusion eats GAN territory.

Classifier-Free Guidance (CFG)

Diffusion models can be conditioned on class labels, captions, CLIP embeddings, etc. Classifier-free guidance trains the model to handle both conditional and unconditional generation (by dropping the condition with probability ~10% during training). At sample time you compute:

ε_guided = ε_uncond + w · (ε_cond - ε_uncond)

The scalar w is the guidance scale. w = 0 gives unconditional samples and w = 1 plain conditional sampling; values around 7 are typical for text-to-image (Stable Diffusion pipelines default to 7.5). Higher w sharpens adherence to the condition but reduces diversity and can introduce artifacts. Lower w is more diverse but conditions more weakly.

CFG is the single most important inference-time knob in text-to-image. Always report what w you used.
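
Both halves of classifier-free guidance can be sketched in plain Python: condition dropout at training time and the guided combination at sampling time. `NULL_COND`, the list-valued `eps_model(x_t, t, cond)` interface and `toy_model` are hypothetical stand-ins. Real implementations use a learned null embedding (or an empty-caption embedding) and tensors.

```python
import random

NULL_COND = None  # hypothetical stand-in for the learned "no condition" embedding

def maybe_drop_condition(cond, p_uncond=0.1, rng=random):
    """Training time: replace the condition with NULL_COND with probability p_uncond."""
    return NULL_COND if rng.random() < p_uncond else cond

def guided_eps(eps_model, x_t, t, cond, w):
    """Sampling time: eps_uncond + w * (eps_cond - eps_uncond), elementwise."""
    eps_uncond = eps_model(x_t, t, NULL_COND)
    eps_cond = eps_model(x_t, t, cond)
    return [u + w * (c - u) for u, c in zip(eps_uncond, eps_cond)]

# Toy model with fixed outputs, just to show how w moves the prediction.
def toy_model(x_t, t, cond):
    return [1.0, 2.0] if cond is NULL_COND else [3.0, 5.0]

for w in (0.0, 1.0, 7.0):
    print(w, guided_eps(toy_model, x_t=None, t=0, cond="a cat", w=w))
```

Each guided step evaluates the network twice (implementations usually batch the two passes), so CFG roughly doubles sampling cost. For w > 1 the prediction is pushed past the conditional estimate, away from the unconditional one, which is where the sharpness/diversity tradeoff comes from.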

Evaluation Blind Spots

Diffusion evaluation is famously messy:

  • FID (Fréchet Inception Distance) — compares distributions of generated vs. real features through an Inception network. Biased, sample-size dependent, and unreliable on data far from the ImageNet images the Inception network was trained on.
  • IS (Inception Score) — mostly historical; do not use alone
  • CLIP score — for text-to-image; measures caption adherence, not quality
  • human eval — expensive but the only thing that reliably catches "technically-good-FID-but-ugly" failures
  • targeted evaluation — generate samples conditioned on held-out captions and check for mode collapse and concept leakage

Two cross-cutting rules:

  • always report the guidance scale — FID vs w is a curve, not a point
  • always report the sample count — FID is sensitive to it
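
To see why the sample count matters, here is a one-dimensional analogue of FID in plain Python. FID fits a Gaussian to each set of Inception features and computes the Fréchet distance ||μ_r - μ_g||² + Tr(Σ_r + Σ_g - 2(Σ_r Σ_g)^{1/2}). In one dimension that reduces to (μ_r - μ_g)² + (σ_r - σ_g)². The toy setup, raw N(0, 1) samples instead of Inception features, is an illustrative assumption and not how FID is computed in practice.

```python
import random
import statistics

def frechet_1d(mu1, sd1, mu2, sd2):
    """Fréchet distance between two 1-D Gaussians: (mu1 - mu2)^2 + (sd1 - sd2)^2."""
    return (mu1 - mu2) ** 2 + (sd1 - sd2) ** 2

def frechet_from_samples(xs, ys):
    """Fit a Gaussian to each sample set, then compare the fits."""
    return frechet_1d(statistics.fmean(xs), statistics.pstdev(xs),
                      statistics.fmean(ys), statistics.pstdev(ys))

def mean_score_same_distribution(n, trials=50, seed=0):
    """Average distance between two size-n samples drawn from the SAME N(0, 1)."""
    rng = random.Random(seed)
    scores = []
    for _ in range(trials):
        xs = [rng.gauss(0.0, 1.0) for _ in range(n)]
        ys = [rng.gauss(0.0, 1.0) for _ in range(n)]
        scores.append(frechet_from_samples(xs, ys))
    return statistics.fmean(scores)

for n in (10, 100, 1000):
    print(n, mean_score_same_distribution(n))
```

Both sample sets come from the same distribution, so the true distance is zero. The estimate is nonetheless positive and shrinks as n grows. Real FID has the same kind of finite-sample bias, so scores computed with different sample counts are not comparable.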

What To Inspect

  • loss curve — should be flat after the initial descent; large swings usually mean schedule or LR problems
  • samples at different timesteps — plot x_t for t ∈ {T, 3T/4, T/2, T/4, 0} during sampling; if the middle steps look like noise the network has not learned its middle regime
  • guidance sweep — generate the same seed at w ∈ {1, 3, 5, 7, 10}; this is the honest quality/diversity curve for text conditioning
  • mode coverage — sample N images, cluster in feature space, confirm clusters cover the classes
  • timestep embedding health — a dead timestep embedding (same outputs at every t) is a common silent bug
  • training noise schedule vs. sampler schedule — mismatches are cheap to introduce and catastrophic
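
A minimal version of the timestep-embedding health check: feed the same input at several timesteps and measure how much the output moves. The `model(x, t)` interface returning a list of floats and the two toy models are hypothetical. With a real network the same idea applies to tensors under `torch.no_grad()`.

```python
def timestep_sensitivity(model, x, timesteps):
    """Largest absolute change in model output across timesteps, for a fixed input x.

    A value at (or very near) zero means the network ignores t: the
    "dead timestep embedding" bug.
    """
    outputs = [model(x, t) for t in timesteps]
    ref = outputs[0]
    return max((abs(a - b) for out in outputs[1:] for a, b in zip(out, ref)), default=0.0)

# Hypothetical toy models: one uses t, one ignores it.
def healthy_model(x, t):
    return [v * (1.0 + t / 1000.0) for v in x]

def dead_model(x, t):
    return [2.0 * v for v in x]

probe_x = [1.0, -0.5, 0.25]
probe_t = [0, 250, 500, 750, 999]
print(timestep_sensitivity(healthy_model, probe_x, probe_t))
print(timestep_sensitivity(dead_model, probe_x, probe_t))
```

It is worth running on a few real noisy inputs early in training, since a dead embedding does not show up in the loss curve, which still decreases.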

Failure Pattern

The characteristic failure is "looks good on FID, ugly to humans." A model scores well on FID by matching the Inception feature statistics but generates samples with visible anatomical errors, texture smears, or repeated artifacts. FID cannot see these. The fix is layered human evaluation; the lesson is that FID is a hint, not a verdict.

A second failure: guidance too high. Sample quality looks crisp, but diversity collapses and the model over-commits to the caption. Dial w down before changing the model.

Common Mistakes

  • using the wrong schedule at train vs. sample time
  • forgetting to drop the condition ~10% of the time during training (breaks CFG)
  • training on pixel space for too-large images instead of using a latent autoencoder
  • benchmarking on too few samples — FID below 5000 samples is noise
  • comparing diffusion to GANs on FID alone without human eval
  • shipping a 1000-step DDPM sampler to production — distill or switch sampler
  • ignoring the timestep embedding when debugging slow convergence

Decision: Diffusion vs. GAN vs. Autoregressive vs. VAE

| Option | When it wins | Trade-off |
|---|---|---|
| diffusion | stable training, controllable generation, strong quality for continuous data | slow sampling unless distilled |
| GAN | extremely fast sampling, sharp outputs when training stabilizes | training instability, mode collapse risk |
| autoregressive | discrete data (text, code), exact likelihood | slow serial sampling; continuous data needs extra machinery |
| VAE | cheap, fast, useful as a feature encoder | blurry samples when used as a final generator |

In practice as of 2024, diffusion dominates continuous-data generation, autoregressive models dominate text, and VAEs survive mainly as encoders inside latent diffusion.

Practice

  1. Train a tiny DDPM on 28×28 MNIST with a small U-Net for 20 epochs. Plot the loss curve and samples at t ∈ {T, T/2, 0}.
  2. Swap the linear schedule for a cosine schedule. Measure the effect on sample quality (eye test is fine at this scale).
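  3. Implement DDIM sampling on the trained model. Verify that, from the same starting noise, 20-step DDIM samples look nearly identical to 1000-step DDIM samples and are close in quality to 1000-step DDPM samples (DDPM injects fresh noise at every step, so its samples will not match DDIM's one-to-one).
  4. Condition on the MNIST digit class with classifier-free guidance (drop condition with probability 0.1 during training). Sweep w ∈ {1, 3, 5, 7} and pick the best by eye.
  5. Plot FID vs. w, computing FID with a small feature extractor (e.g. a classifier trained on MNIST) in place of Inception. Find the elbow.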
  6. Deliberately collapse the timestep embedding (return zeros). Train and observe that the loss still decreases but samples degrade — this is the "timestep embedding is dead" failure.
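
For exercise 2, a plain-Python sketch of the cosine schedule from Nichol & Dhariwal (2021), "Improved Denoising Diffusion Probabilistic Models". In that schedule ᾱ is defined directly as a squared cosine with a small offset s = 0.008. β is then recovered from consecutive ratios and clipped at 0.999. The constants follow that paper; treat them as defaults to tune, not fixed requirements.

```python
import math

def cosine_betas(T=1000, s=0.008, max_beta=0.999):
    """Cosine schedule: alpha_bar(t) = f(t) / f(0), f(t) = cos(((t/T + s) / (1 + s)) * pi/2)^2.

    beta for step i is 1 - f(i+1) / f(i), clipped to max_beta
    (the last steps would otherwise approach 1).
    """
    def f(t):
        return math.cos(((t / T + s) / (1 + s)) * math.pi / 2) ** 2
    return [min(1.0 - f(i + 1) / f(i), max_beta) for i in range(T)]

def alpha_bar_at(betas, t):
    """prod_{s <= t} (1 - betas[s])."""
    prod = 1.0
    for b in betas[: t + 1]:
        prod *= 1.0 - b
    return prod

cos_b = cosine_betas()
lin_b = [1e-4 + (0.02 - 1e-4) * i / (1000 - 1) for i in range(1000)]  # linear, for comparison
# Signal remaining halfway through the chain: cosine keeps much more than linear.
print(alpha_bar_at(cos_b, 499), alpha_bar_at(lin_b, 499))
```

Because ᾱ falls more gradually in the middle of the chain, fewer timesteps are spent on inputs that are already nearly pure noise, which is the motivation given in that paper.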

Runnable Example

A meaningful diffusion run needs PyTorch + GPU and some wall time, so the browser runner is not the right home. Run the training sketch above locally with pip install torch torchvision einops. For a clean reference implementation, the huggingface/diffusers library's pipelines are the easiest on-ramp; treat them as a reference for the algorithms above, not as magic.

Longer Connection

Diffusion models connect to several existing academy topics:

For the decision frame — when generation is actually the right answer — Baseline-First Task Solving is still the right first move; diffusion is a tool, not a goal.