14.3 DDPM 的反向去噪过程与训练目标

Author

jshn9515

Published

April 5, 2026

上一节我们已经把 DDPM 的前向过程讲清楚了。

我们知道,它先定义了一个固定的加噪链:

\[ x_0 \rightarrow x_1 \rightarrow x_2 \rightarrow \cdots \rightarrow x_T \]

并且随着步数增加,图像里的结构会逐渐被噪声淹没,最后 \(x_T\) 会接近标准高斯分布。

那么,既然我们能把图像一步步加噪到高斯噪声,那能不能再从高斯噪声一步步走回来?

这就是 反向扩散过程(reverse diffusion process) 的核心问题。

这一节,我们就来讲清楚三件事:

  1. 反向过程到底想学什么;
  2. 为什么它可以被建模成一步一步去噪;
  3. 为什么 DDPM 最后通常把训练目标写成预测噪声。
import random

import matplotlib.pyplot as plt
import torch
import torchvision.datasets as datasets
import torchvision.transforms.v2 as v2

%config InlineBackend.figure_format = 'retina'
print('PyTorch version:', torch.__version__)
PyTorch version: 2.11.0+xpu

14.3.1 如果前向能走,反向为什么不能走?

前向过程是我们自己设计的:

\[ q(x_t \mid x_{t-1}) = \mathcal{N}(x_t; \sqrt{1-\beta_t}x_{t-1}, \beta_t I) \]

也就是说,我们知道怎么从 \(x_{t-1}\) 构造出 \(x_t\)

但生成时,我们关心的是相反方向:

\[ x_T \rightarrow x_{T-1} \rightarrow x_{T-2} \rightarrow \cdots \rightarrow x_0 \]

第一次看到这里,很多人会有一个很自然的疑问:既然前向加噪这么简单,那直接把它倒过来不就行了吗?可惜事情没有这么简单。因为加噪本身是一个会丢信息的过程。

举个例子:如果你有一张清晰的猫图,给它加一点噪声,你仍然大概能看出这是一只猫;但如果我只给你一张带噪图,你并不能唯一确定它原来是哪一张清晰图。所以,如果用函数逆变换的方式来理解反向过程,就会发现它根本不满足单值性(single-valued)。也就是说,反向过程就是一对多的。一张带噪图像背后,可能对应很多种可能的清晰图像。

所以,反向过程不能简单理解成一个确定性求逆。它更合理的理解方式是:

给定当前的带噪图像 \(x_t\),下一步更干净的图像 \(x_{t-1}\) 应该服从某个条件概率分布。

我们把这个条件分布记作 \(q(x_{t-1} \mid x_t)\)。也就是说,已知当前第 \(t\) 步的含噪图像 \(x_t\),模型要给出上一步更干净样本 \(x_{t-1}\) 的概率分布。它描述的是真实扩散过程对应的反向单步分布。那么,这时候就会有一个问题:这个分布好求吗?

我们来看看贝叶斯公式:

\[ q(x_{t-1} \mid x_t) = \frac{q(x_t \mid x_{t-1}) q(x_{t-1})}{q(x_t)} \]

前向分布 \(q(x_t \mid x_{t-1})\) 是我们自己设计的,这个好办。但 \(q(x_{t-1})\)\(q(x_t)\) 呢?它们分别对应样本在第 \(t-1\) 步和第 \(t\) 步的边缘分布。也就是说:

\[ q(x_{t-1}) = \int q(x_{t-1} \mid x_0) q(x_0) dx_0, \qquad q(x_t) = \int q(x_t \mid x_0) q(x_0) dx_0 \]

你看,它们都涉及到真实数据分布 \(q(x_0)\),而真实数据分布是我们无法直接建模的(如果知道了我们还要 DDPM 干嘛)。所以,这个反向条件分布 \(q(x_{t-1} \mid x_t)\) 是一个非常复杂的分布,我们根本无法直接求出它。

那怎么办?别忘了我们有真实图像!我们能不能利用它们?

当然。虽然 \(q(x_{t-1} \mid x_t)\) 很复杂,但如果我们把分布写成这种形式:

\[ q(x_{t-1} \mid x_t, x_0) = \frac{q(x_t \mid x_{t-1}, x_0) q(x_{t-1} \mid x_0)}{q(x_t \mid x_0)} \]

由于前向过程是马尔可夫链,根据马尔可夫性,我们有:

\[ q(x_t \mid x_{t-1}, x_0) = q(x_t \mid x_{t-1}) \]

上式简化为:

\[ q(x_{t-1} \mid x_t, x_0) = \frac{q(x_t \mid x_{t-1}) q(x_{t-1} \mid x_0)}{q(x_t \mid x_0)} \]

你会发现,右边的三项我们都是知道的!前两项是我们自己设计的前向过程,最后一项 \(q(x_t \mid x_0)\) 也可以通过前向过程的递推关系求出。也就是说,虽然 \(q(x_{t-1} \mid x_t)\) 很复杂,但 \(q(x_{t-1} \mid x_t, x_0)\) 却是一个简单的分布,我们可以直接求出它的解析表达式,进而一步步地去求出反向过程的条件分布。

实际上,可以证明的是,在前向过程的定义下,反向条件分布 \(q(x_{t-1} \mid x_t, x_0)\) 是一个高斯分布:

\[ q(x_{t-1} \mid x_t, x_0) = \mathcal{N}(x_{t-1}; \tilde{\mu}_t(x_t, x_0), \tilde{\beta}_t I) \]

其中:

\[ \tilde{\mu}_t(x_t,x_0) = \frac{\sqrt{\bar{\alpha}_{t-1}}\,\beta_t}{1-\bar{\alpha}_t}\,x_0 + \frac{\sqrt{\alpha_t}\,(1-\bar{\alpha}_{t-1})}{1-\bar{\alpha}_t}\,x_t \]

\[ \tilde{\beta}_t = \frac{1-\bar{\alpha}_{t-1}}{1-\bar{\alpha}_t}\,\beta_t \]

完整的证明见 (Luo 2022, eq. 71-84)。注意,这里的结果和论文的结果有点不一样。论文里忽略了部分常数项,因此写的是“正比于”,而这里我们把常数项也写出来了。

现在我们来做一个实验。使用 MNIST 数据集,假设我们手里有原始图像 \(x_0\),我们先把一张图像按照前向公式加噪到高斯噪声,然后再从高斯噪声一步步走回去。我们来看看这个过程中图像的变化。

# Change the root path to your local directory if needed
root = 'D:/Workspaces/Python Project/datasets'
transform = v2.Compose([v2.ToImage(), v2.ToDtype(torch.float32, scale=True)])
ds = datasets.MNIST(root, train=False, download=True, transform=transform)

idx = random.randint(0, len(ds) - 1)
x0 = ds[idx][0].squeeze(0)  # shape: (28, 28)

T = 1000
betas = torch.linspace(0.0001, 0.02, steps=T)
alphas = 1.0 - betas
alpha_bars = alphas.cumprod(dim=0)

eps = torch.randn_like(x0)
xt = alpha_bars[-1].sqrt() * x0 + (1 - alpha_bars[-1]).sqrt() * eps
trajectory = [xt.clone()]

for t in range(T - 1, -1, -1):
    alpha_t = alphas[t]
    alpha_bar_t = alpha_bars[t]
    alpha_bar_prev_t = alpha_bars[t - 1] if t > 0 else torch.tensor(1.0)
    beta_t = betas[t]

    param1 = alpha_bar_prev_t.sqrt() * beta_t / (1 - alpha_bar_t)
    param2 = alpha_t.sqrt() * (1 - alpha_bar_prev_t) / (1 - alpha_bar_t)
    mean = param1 * x0 + param2 * xt
    variance = (1 - alpha_bar_prev_t) / (1 - alpha_bar_t) * beta_t

    if t > 0:
        z = torch.randn_like(x0)
        xt = mean + variance.sqrt() * z
    else:
        xt = mean

    trajectory.append(xt.clone())

# We use step=8 here for better visualization
idx = torch.logspace(3, 0, steps=8, dtype=torch.long)
trajectory = [trajectory[T - i] for i in idx - 1]

fig, axes = plt.subplots(1, len(trajectory), figsize=(8, 2))
for i, ax in enumerate(axes):
    ax.imshow(trajectory[i], cmap='gray')
    ax.axis('off')
    ax.set_title(f't={idx[i]}')

plt.tight_layout(pad=0.5)
plt.show()

你看,我们成功恢复出了原始图像!但是,这里又有一个问题:我们在恢复图像的过程中利用了原始图像 \(x_0\)。可是我们生成图像就是要生成 \(x_0\) 的啊,如果我们都知道了 \(x_0\),那我们还要生成什么?

这时候神经网络就派上用场了。我们定义了一个带参数的条件分布 \(p_\theta(x_{t-1} \mid x_t)\),让它去近似 \(q(x_{t-1} \mid x_t, x_0)\)

14.3.2 反向过程:学习 \(p_\theta(x_{t-1} \mid x_t)\)

本质上,反向过程要学的就是每一步的反向条件分布:

\[ p_\theta(x_{t-1} \mid x_t) \]

也就是说,已知当前第 \(t\) 步的含噪图像 \(x_t\),模型要给出上一步更干净样本 \(x_{t-1}\) 的概率分布。

于是,整个生成链就可以写成:

\[ p(x_{0:T}) = p(x_T) \prod_{t=1}^{T} p_\theta(x_{t-1} \mid x_t) \]

起点 \(p(x_T)\) 很简单,通常直接取标准高斯分布 \(\mathcal{N}(0, I)\)。这里的难点全都集中在每一步的反向条件分布 \(p_\theta(x_{t-1} \mid x_t)\) 上。那么,这个分布是什么样子的呢?我们又该如何去学习它呢?

在上一节里我们知道,反向条件分布 \(q(x_{t-1} \mid x_t, x_0)\) 本质上是一个高斯分布:

\[ q(x_{t-1} \mid x_t, x_0) = \mathcal{N}(x_{t-1}; \tilde{\mu}_t(x_t, x_0), \tilde{\beta}_t I) \]

所以我们可以大胆假设,我们要学的分布 \(p_\theta(x_{t-1} \mid x_t)\) 也是一个高斯分布:

\[ p_\theta(x_{t-1} \mid x_t) = \mathcal{N}(x_{t-1}; \mu_\theta(x_t, t), \Sigma_\theta(x_t, t)) \]

模型实际上要学的就是每一步的均值 \(\mu_\theta(x_t, t)\) 和协方差 \(\Sigma_\theta(x_t, t)\)

如果你了解一点 DDPM,可能你就会问了,不是说只要预测均值吗?

其实是的。我们观察协方差的表达式就会发现,它其实只是一个和时间步 \(t\) 相关的常数项:

\[ \tilde{\beta}_t = \frac{1-\bar{\alpha}_{t-1}}{1-\bar{\alpha}_t}\,\beta_t \]

也就是说,如果我们直接把协方差固定成 \(\tilde{\beta}_t I\),就已经足够了。所以,在实际训练中,我们通常只让模型去预测均值 \(\mu_\theta(x_t, t)\),而把协方差 \(\Sigma_\theta(x_t, t)\) 固定成 \(\tilde{\beta}_t I\)

14.3.3 为什么最后常写成预测噪声?

到这里,你可能会觉得,既然反向分布是高斯,模型主要学的是均值 \(\mu_\theta(x_t, t)\),那训练时直接预测这个均值不就可以了吗?

理论上当然可以。但 DDPM 在实践中通常采用一种更巧妙、也更稳定的参数化方式:

不直接预测均值,而是预测当前样本里混入的噪声 \(\epsilon\)

也就是让模型学习:

\[ \epsilon_\theta(x_t, t) \]

理由也很简单。我们知道,真实均值 \(\tilde{\mu}_t(x_t, x_0)\) 的表达式里其实是包含了原图 \(x_0\) 的:

\[ \tilde{\mu}_t(x_t, x_0) = \frac{\sqrt{\bar{\alpha}_{t-1}}\,\beta_t}{1-\bar{\alpha}_t}\,x_0 + \frac{\sqrt{\alpha_t}\,(1-\bar{\alpha}_{t-1})}{1-\bar{\alpha}_t}\,x_t \]

所以,如果我们想把这个均值预测的准,由于这个均值本身依赖 \(x_0\),模型其实就等于在间接恢复原图信息。与其直接去预测这样一个形式复杂、并且随时间步变化的均值,我们更希望把学习目标改写成一个更简单、更稳定的形式。

我们知道,前向过程有一个闭式采样公式:

\[ x_t = \sqrt{\bar{\alpha}_t}x_0 + \sqrt{1-\bar{\alpha}_t}\epsilon \]

也就是说,当前的带噪图像 \(x_t\) 是由原图 \(x_0\) 和噪声 \(\epsilon\) 混合而成的。我们可以把这个式子变形一下:

\[ \hat{x}_0 = \frac{1}{\sqrt{\bar{\alpha}_t}} (x_t - \sqrt{1-\bar{\alpha}_t}\epsilon) \]

这里的噪声 \(\epsilon\) 是训练时我们自己采样进去的,所以它是已知的。这样我们就可以把噪声作为模型的训练目标,让模型去预测它,从而间接得到原图 \(x_0\) 的一个估计,并最终得到均值 \(\mu_\theta(x_t, t)\)。同时,我们也避免了直接预测一个随时间步变化的复杂均值的麻烦。况且预测噪声和预测均值本质上是等价的。

所以,DDPM 的最终训练目标通常写成:

\[ L(\theta) = \mathbb{E}_{t, x_0, \epsilon} \left[ \| \epsilon - \epsilon_\theta(x_t, t) \|^2 \right] \]

Note

这里的推导是不严谨的。我们仅仅从直觉上说明预测噪声是合理的,并未严格说明其理论依据。如果从严格的概率建模角度来看,DDPM 的训练目标其实来自变分下界(ELBO),而常见的噪声预测损失,是在该目标基础上的一种等价或近似等价改写。这里先不展开完整推导,后面再详细说明。感兴趣的读者可以先看看 (Luo 2022, eq. 46-58, 115-130)

14.3.4 DDPM 的训练目标:一个非常简单的 MSE

在上一节里,我们把 DDPM 的训练目标写成了一个预测噪声的均方误差:

\[ L(\theta) = \mathbb{E}_{t, x_0, \epsilon} \left[ \| \epsilon - \epsilon_\theta(x_t, t) \|^2 \right] \]

其中,\(x_0\) 是从真实数据 \(p_{\text{data}}\) 里采样得到的,\(t\) 是从 1 到 \(T\) 中随机挑选的一个时间步,\(\epsilon\) 是我们自己采样的高斯噪声。这个损失函数看起来有点过于简单:整个 diffusion model,最后经常就是在做一个噪声回归的均方误差。

但是,虽然表面上只是 MSE,背后其实对应着对反向扩散过程的概率建模。也就是说,这个简单的 MSE 损失函数,其实是从一个严谨的概率模型出发,经过一系列等价或近似等价的变换,最终得到的一个非常易于优化的训练目标。与很多深度学习网络不同,它的背后是有理论支撑的。我们用一段伪代码来描述这个过程:

算法 1:DDPM 训练过程伪代码
算法 1:DDPM 训练过程伪代码 (Ho et al. 2020, alg. 1)

你看,是不是很简单?别被它迷惑了。我们在后面会详细说明这个训练目标是怎么来的,以及它和概率建模之间的关系。

14.3.5 生成阶段在做什么?

训练完成之后,模型已经学会了在不同时间步下预测噪声。于是生成时,我们就可以这样做:

  1. 先从纯高斯噪声 \(x_T \sim \mathcal{N}(0, I)\) 开始;
  2. 输入 \((x_T, T)\),模型预测噪声;
  3. 根据这个噪声估计,构造 \(x_{T-1}\)
  4. 再输入 \((x_{T-1}, T-1)\),继续去噪;
  5. 一直重复,直到得到 \(x_0\)

所以生成不是一次完成的,而是一个逐步去噪的过程。这也解释了 diffusion model 的两个典型特点:

  • 生成质量高:因为每一步都只做一个小修正;
  • 采样速度慢:因为它真的要走很多很多步。

这是生成阶段的伪代码,同样也很简单:

算法 2:DDPM 生成过程伪代码
算法 2:DDPM 生成过程伪代码 (Ho et al. 2020, alg. 2)

如果你仔细看这个伪代码就会发现,我们似乎在采样过程中也加入了噪声。也就是说,在每一步去噪的时候,我们并不是直接把 \(x_t\) 变成 \(x_{t-1}\),得到一个唯一答案,而是会在去噪的基础上再加入一些随机性。也就是说,生成过程其实也是一个随机过程。我们在后面会详细说明这个随机性是从哪里来的,以及它对生成质量和多样性的影响。

14.3.6 本章小结

到这里,我们可以把 14.1、14.2、14.3 三节串起来了。

第一步:定义前向加噪

我们人为设计一个固定过程,把真实数据一步步变成噪声:

\[ x_0 \rightarrow x_1 \rightarrow \cdots \rightarrow x_T \]

并且最后 \(x_T\) 接近标准高斯分布。

第二步:把生成问题转成反向恢复问题

既然前向可以把数据推向噪声,那么生成时就从噪声反向走回来:

\[ x_T \rightarrow x_{T-1} \rightarrow \cdots \rightarrow x_0 \]

第三步:把反向过程建模成条件高斯

每一步不是直接求逆,而是学习:

\[ p_\theta(x_{t-1} \mid x_t) \]

第四步:把训练目标变成预测噪声

利用前向过程的闭式公式,我们可以直接构造监督信号,让模型学习:

\[ \epsilon_\theta(x_t, t) \]

这就把一个复杂的生成建模问题,变成了一个可以稳定优化的噪声回归问题。

这条逻辑链,就是最基础的 DDPM 训练框架。

到此,我们终于了解了 DDPM 的训练和采样流程。但是,我们还有很多细节没有讲清楚。比如模型输入里的时间步 \(t\) 要怎么表示?为什么 U-Net 特别适合做去噪网络?采样时具体怎么从 \(x_t\) 算出 \(x_{t-1}\)?这就要到下一节,来看看 DDPM 的一些细节设计。

References

Ho, Jonathan, Ajay Jain, and Pieter Abbeel. 2020. Denoising Diffusion Probabilistic Models. https://arxiv.org/abs/2006.11239.
Luo, Calvin. 2022. Understanding Diffusion Models: A Unified Perspective. https://arxiv.org/abs/2208.11970.

Reuse