14.4 DDPM 的网络结构与采样过程

Author

jshn9515

Published

April 5, 2026

前面几节里,我们已经把 DDPM 最核心的想法搭起来了:

但是,我们还没有真正把这些想法落实到一个具体的网络结构和采样流程中来。比如,我们知道模型要预测噪声,但它的输入是什么?它的输出是什么?时间步 \(t\) 怎么告诉网络?为什么很多实现喜欢用 U-Net?采样时又是怎么迭代的?

这一节,我们从工程实现的角度,把 DDPM 的整体运行方式讲清楚。

import math

import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from torch import Tensor

%config InlineBackend.figure_format = 'retina'
print('PyTorch version:', torch.__version__)
PyTorch version: 2.11.0+xpu
device = torch.accelerator.current_accelerator(check_available=True)
if device is None:
    device = torch.device('cpu')
print('Using device:', device)
Using device: xpu

14.4.1 为什么网络必须知道时间步 \(t\)

我们先回忆一下训练时的任务。我们会从真实图像 \(x_0\) 出发,随机采样一个时间步 \(t\),然后按照前向过程构造出带噪图像 \(x_t\)。接着,把 \((x_t, t)\) 输入神经网络,让它预测当前这一步里的噪声:

\[ \epsilon_\theta(x_t, t) \]

所以,我们需要输入一张带噪图像 \(x_t\) 和当前时间步 \(t\),让网络输出这张图像中包含的噪声估计 \(\epsilon_\theta(x_t, t)\)。如果这个噪声估计足够准确,我们就能按照反向公式,把图像往更干净的方向推一步。

那么,假设你只把 \(x_t\) 输入给网络,而不告诉它当前是第几步,会发生什么?

答案是,网络会很困惑。因为不同时间步的图像噪声强度差别非常大。在早期时间步,图像还比较清晰,只是稍微有一点噪声;在中间时间步,图像已经很模糊,结构和噪声混在一起;在后期时间步,图像几乎就是纯噪声。这三种情况下,该去掉多少噪声,该保留多少结构是完全不同的任务。如果不告诉网络当前是第几步,它就必须仅靠输入图像本身去猜测噪声级别,这会大大增加学习难度。

所以在 DDPM 里,模型一般都写成:

\[ \epsilon_\theta(x_t, t) \]

也就是说,时间步 \(t\) 是模型输入的一部分。这是一种条件信息。它告诉网络现在噪声有多强,让网络知道当前应该做多大幅度的修正,从而让同一个网络可以处理不同噪声水平下的去噪任务。

所以 DDPM 实际上是就是一个条件去噪器(conditional denoiser),条件就是当前的时间步 \(t\)

14.4.2 时间步信息怎么送进网络?

有了时间步 \(t\),那接下来的问题就是:图像是一个张量,时间步 \(t\) 只是一个整数。怎么把这个整数变成神经网络能有效利用的信息?

最简单的方法就是把 \(t\) 归一化成一个标量,然后拼进去。但实践里,这样通常不够好。因为不同时间步之间的关系不是线性的,网络需要一种更丰富的表示,来区分不同的噪声阶段。仅用一个数字,很难表达早期、中期、后期的复杂差别。因为网络不会天然理解“100 比 10 更大”,它只能通过训练数据去学习这个关系,而这会增加学习难度。

所以很多 diffusion 模型都会使用一种和 Transformer 里位置编码很像的方法,就是把时间步 \(t\) 映射成一个高维向量,这通常叫做 时间步嵌入(time embedding)。这个向量的维度可以和图像特征的维度一样,这样就可以直接拼接在一起输入网络。

一种常见做法是使用正弦余弦形式的编码:

\[ \text{emb}(t) = [\sin(\omega_1 t), \cos(\omega_1 t), \sin(\omega_2 t), \cos(\omega_2 t), \dots] \]

然后再经过几层 MLP,把它变成适合当前网络宽度的特征向量。

这样,不用同时间步就会被映射到不同的位置,相近时间步的表示也保持一定平滑性。网络更容易学到不同噪声阶段应该采取什么策略。这和 Transformer 里的 position embedding 有一点像。在那里,我们告诉模型当前 token 在序列中的位置。而在 diffusion 里,我们告诉模型当前图像在去噪链中的阶段。

Tip

关于为什么位置编码和时间步嵌入都喜欢用正弦余弦函数,可以回顾一下 Transformer 章节里对位置编码的分析。简单来说,正弦余弦函数能让不同时间步的表示在空间上有规律地分布,相近时间步的表示也保持一定平滑性,这有助于网络学习不同噪声阶段的策略。

class SinusoidalPositionEmbeddings(nn.Module):
    def __init__(self, embedding_dim: int, max_period: int = 10000):
        """Embed time steps into a high-dimensional space using sinusoidal functions.

        Args:
            embedding_dim (int): The dimension of the output embedding vector.
            max_period (int): The maximum period for the sinusoidal functions,
                which controls the frequency of the embeddings.
        """
        super().__init__()
        self.embedding_dim = embedding_dim
        self.max_period = max_period

    def forward(self, timesteps: Tensor):
        """Embed the input time steps into a high-dimensional space.

        Args:
            timesteps (Tensor): A tensor of shape (batch_size,) containing the time
                steps to embed.

        Returns:
            emb (Tensor): A tensor of shape (batch_size, embedding_dim) containing
                the sinusoidal embeddings for the input time steps.
        """
        half_dim = self.embedding_dim // 2
        scale = -math.log(self.max_period) / (half_dim - 1)
        emb = torch.arange(half_dim, device=timesteps.device) * scale
        emb = timesteps.unsqueeze(1) * emb.exp().unsqueeze(0)
        emb = torch.concat([emb.sin(), emb.cos()], dim=-1)
        return emb
pos_emb = SinusoidalPositionEmbeddings(embedding_dim=32)

timesteps = torch.tensor([0, 10, 50, 100, 200, 500, 1000], dtype=torch.float32)
emb = pos_emb(timesteps)

fig = plt.figure(1, figsize=(8, 4))
ax = fig.add_subplot(1, 1, 1)
cbar = ax.imshow(emb, aspect='auto', vmin=-1, vmax=1)
ax.set_xlabel('Embedding Dimension')
ax.set_yticks(range(len(timesteps)))
ax.set_yticklabels([f't={int(x)}' for x in timesteps])
ax.set_title('Illustration of sinusoidal time embeddings')
fig.colorbar(cbar)
plt.show()

这张图中的每一行表示某一个时间步的编码结果,每一列表示某一个编码维度在不同时间步上的取值变化。有些列颜色变化较快,说明这些维度对应较高频率,对时间步的细微变化更敏感;有些列变化较慢,说明这些维度对应较低频率,能够提供更平滑、更粗粒度的时间信息。

因此,不同频率的编码维度共同构成了一个多尺度的时间表示,使网络既能感知时间步之间的细小差别,也能把握当前所处的整体噪声阶段。借助这种时间编码,网络就能够区分不同的时间步,并在不同噪声水平下学习采取不同的去噪策略。

14.4.3 为什么 DDPM 里最常见的是 U-Net?

到这里,我们已经知道输入输出是什么了。下一步的问题就是,去噪网络应该用什么结构?

理论上,很多网络都可以尝试。但在图像 diffusion 模型里,最经典、最常见的选择是 U-Net。因为 U-Net 非常适合做这种输入一张图,输出一张同尺寸图的任务 (Ronneberger et al. 2015)

DDPM 的输入是 \(x_t\),输出通常是一个和图像同样大小的噪声张量 \(\hat{\epsilon}\)。也就是说,模型需要对每一个像素位置给出预测,而且还要同时理解局部纹理信息和全局结构信息。U-Net 的设计正好满足了这个需求。

14.4.3.1 U-Net 的基本思想回顾

在语义分割里,我们详细讲解了 U-Net 的结构。这里我们简单回顾一下它的核心思想。

U-Net 这个名字来自它的形状。它一般分成两部分:

  • 下采样路径(encoder):不断压缩分辨率,提取更高层、更大感受野的特征;
  • 上采样路径(decoder):逐步恢复空间分辨率,把抽象特征重新变回像素级输出。

中间再加上 skip connection,把早期高分辨率特征直接传给后面的上采样层。

这样做有几个明显好处:

  1. 下采样阶段能看到更大范围的上下文,理解整体结构;
  2. 上采样阶段能恢复空间细节;
  3. skip connection 能保留浅层的局部纹理和边缘信息。

而去噪这件事,本来就同时需要看全局和看局部。全局上,我们要知道这张图大概是什么结构;在局部上,我们要知道某个像素附近的噪声该怎么修正。所以 U-Net 和 diffusion 的任务天然很契合。

14.4.3.2 时间步信息怎么融入 U-Net?

前面我们讲过,模型除了图像 \(x_t\),还必须知道当前时间步 \(t\)。那么在 U-Net 里,这个时间信息通常怎么融进去?

常见做法是:

  1. 先把 \(t\) 变成一个 time embedding 向量;
  2. 再用一个小 MLP 变换到合适维度;
  3. 把这个向量加到各层的特征中,或者用于调制各层激活。

当然,在更现代的 diffusion 模型里,除了时间 embedding,还常常会加入类别条件、文本条件,或者通过交叉注意力把外部信息融入到网络中。但在最基础的 DDPM 场景里,time embedding + U-Net 就已经是非常经典的组合了。

from ch14_utils import AttentionBlock, Block, Downsample, ResBlock, Upsample


class UNet2DModel(nn.Module):
    def __init__(
        self,
        in_channels: int = 3,
        out_channels: int = 3,
        block_out_channels: tuple[int, ...] = (64, 128, 256, 512),
        time_emb_dim: int = 256,
    ):
        """A UNet-like architecture for diffusion models that processes 2D images and
        incorporates time step embeddings.

        Args:
            in_channels (int): Number of input channels (e.g., 3 for RGB images).
            out_channels (int): Number of output channels (e.g., 3 for RGB images).
            block_out_channels (tuple[int, ...]): A tuple specifying the number of
                output channels for each block in the down and up paths.
            time_emb_dim (int): The dimension of the time step embeddings.
        """
        super().__init__()
        self.time_embedding = nn.Sequential(
            SinusoidalPositionEmbeddings(time_emb_dim),
            nn.Linear(time_emb_dim, time_emb_dim),
            nn.SiLU(),
            nn.Linear(time_emb_dim, time_emb_dim),
        )

        first_ch = block_out_channels[0]
        self.init_conv = nn.Conv2d(
            in_channels,
            first_ch,
            kernel_size=3,
            padding=1,
        )

        # Down path
        self.downs = nn.ModuleList()
        in_ch = block_out_channels[0]
        skip_channels = []

        for out_ch in block_out_channels:
            is_last_ch = out_ch == block_out_channels[-1]
            self.downs.append(
                nn.ModuleList(
                    [
                        ResBlock(in_ch, out_ch, time_emb_dim),
                        ResBlock(out_ch, out_ch, time_emb_dim),
                        AttentionBlock(out_ch),
                        Downsample(out_ch) if not is_last_ch else nn.Identity(),
                    ]
                )
            )
            in_ch = out_ch
            skip_channels.append(out_ch)

        # Middle
        last_ch = block_out_channels[-1]
        self.mid_block1 = ResBlock(last_ch, last_ch, time_emb_dim)
        self.mid_attn = AttentionBlock(last_ch)
        self.mid_block2 = ResBlock(last_ch, last_ch, time_emb_dim)

        # Up path
        self.ups = nn.ModuleList()
        in_ch = block_out_channels[-1]

        for out_ch in reversed(skip_channels):
            is_first_ch = out_ch == skip_channels[0]
            self.ups.append(
                nn.ModuleList(
                    [
                        ResBlock(in_ch + out_ch, out_ch, time_emb_dim),
                        ResBlock(out_ch, out_ch, time_emb_dim),
                        AttentionBlock(out_ch),
                        Upsample(out_ch) if not is_first_ch else nn.Identity(),
                    ]
                )
            )
            in_ch = out_ch

        self.final_block = Block(in_ch, in_ch)
        self.final_conv = nn.Conv2d(in_ch, out_channels, kernel_size=1)

    def forward(self, x: Tensor, timesteps: Tensor) -> Tensor:
        if x.size(0) != timesteps.size(0):
            raise AssertionError(
                f'Batch size of x and timesteps must match, '
                f'but got {x.size(0)} and {timesteps.size(0)}.'
            )

        t_emb = self.time_embedding(timesteps)
        x = self.init_conv(x)

        skips = []
        for block1, block2, attn, down in self.downs:  # type: ignore
            x = block1(x, t_emb)
            x = block2(x, t_emb)
            x = attn(x)
            skips.append(x)
            x = down(x)

        x = self.mid_block1(x, t_emb)
        x = self.mid_attn(x)
        x = self.mid_block2(x, t_emb)

        for block1, block2, attn, up in self.ups:  # type: ignore
            skip = skips.pop()
            x = torch.concat([x, skip], dim=1)
            x = block1(x, t_emb)
            x = block2(x, t_emb)
            x = attn(x)
            x = up(x)

        x = self.final_block(x)
        x = self.final_conv(x)
        return x
unet = UNet2DModel().to(device)

x = torch.randn(1, 3, 64, 64, device=device)
timesteps = torch.tensor([10], dtype=torch.float32, device=device)

out = unet(x, timesteps)
print(out.shape)
torch.Size([1, 3, 64, 64])

14.4.4 采样中的随机性从哪里来?

到这里,我们还有一个重要问题:每一步去噪的时候,是不是完全确定的?

在最经典的 DDPM 采样公式里,答案是:不完全是。

原因在于,反向过程一般写成一个高斯分布:

\[ p_\theta(x_{t-1}\mid x_t) = \mathcal{N}(\mu_\theta(x_t,t), \Sigma_\theta(x_t,t)) \]

这是一个概率分布。这意味着,从 \(x_t\)\(x_{t-1}\) 的过程不仅有一个均值,还会有一定随机性。所以,为了体现这种随机性,在采样时,我们往往会按照模型给出的均值方向前进的同时,再加入一点随机扰动。也就是说,每一步采样的过程是:

\[ x_{t-1} = \mu_\theta(x_t, t) + \sigma_\theta(x_t, t) \cdot z \]

如果完全没有随机性,采样路径可能会过于僵硬。适当保留一些随机扰动,则更符合原始概率模型的定义。也就是说,DDPM 的采样过程本质上也是一个随机过程,而不是一个完全确定的过程。

当然,后来也有很多变体(例如 DDIM)把 DDPM 的随机采样过程改写成了一个可以是确定性的更新过程。但在最基础的 DDPM 里,逐步采样 + 每步带一点随机性 是标准做法。

class DDPMScheduler:
    def __init__(
        self,
        num_train_timesteps: int = 1000,
        beta_start: float = 0.0001,
        beta_end: float = 0.02,
    ):
        """Scheduler for the Denoising Diffusion Probabilistic Models (DDPM) that defines
        the noise schedule and provides a method to add noise to the original samples based
        on the time steps.

        Args:
            num_train_timesteps (int): The total number of time steps used during training,
                which determines the length of the noise schedule.
            beta_start (float): The starting value of the noise variance (beta) at time step 0.
            beta_end (float): The ending value of the noise variance (beta) at the final time step.
        """
        self.num_train_timesteps = num_train_timesteps
        self.beta_start = beta_start
        self.beta_end = beta_end

        self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps)
        self.alphas = 1.0 - self.betas
        self.alphas_cumprod = self.alphas.cumprod(dim=0)

        self.num_inference_steps = num_train_timesteps
        self.timesteps = torch.arange(num_train_timesteps - 1, -1, -1, dtype=torch.long)

    def add_noise(
        self,
        original_samples: Tensor,
        noise: Tensor,
        timesteps: Tensor,
    ) -> Tensor:
        if original_samples.shape != noise.shape:
            raise AssertionError('original_samples and noise must have the same shape.')

        if timesteps.ndim != 1:
            raise AssertionError('timesteps must be a 1D tensor of shape (batch_size,)')

        self.alphas_cumprod = self.alphas_cumprod.to(original_samples.device)
        sqrt_alpha_bar = self.alphas_cumprod[timesteps].sqrt()
        sqrt_alpha_bar = sqrt_alpha_bar.view(-1, 1, 1, 1)

        sqrt_one_minus_alpha_bar = (1.0 - self.alphas_cumprod)[timesteps].sqrt()
        sqrt_one_minus_alpha_bar = sqrt_one_minus_alpha_bar.view(-1, 1, 1, 1)

        noisy_samples = (
            sqrt_alpha_bar * original_samples + sqrt_one_minus_alpha_bar * noise
        )
        return noisy_samples

    def set_timesteps(
        self,
        num_inference_steps: int,
        device: str | torch.device = 'cpu',
    ):
        if num_inference_steps > self.num_train_timesteps:
            raise AssertionError(
                f'num_inference_steps must be in the range (0, {self.num_train_timesteps}].'
            )

        self.num_inference_steps = num_inference_steps
        self.timesteps = torch.linspace(
            self.num_train_timesteps - 1,
            0,
            num_inference_steps,
            dtype=torch.long,
            device=device,
        )

    def previous_timestep(self, timestep: int) -> int:
        if self.num_inference_steps != self.num_train_timesteps:
            index = (self.timesteps == timestep).float().argmax()
            if index == len(self.timesteps) - 1:
                prev = -1
            else:
                prev = int(self.timesteps[index + 1])
        else:
            prev = timestep - 1
        return prev

    def step(self, model_output: Tensor, timestep: int, sample: Tensor) -> Tensor:
        """Perform a single reverse diffusion step to compute the previous sample given the
        model's output, the current time step, and the current sample.

        Args:
            model_output (Tensor): The output from the diffusion model, which is typically
                the predicted noise component at the current time step.
            timestep (int): The current time step in the reverse diffusion process.
            sample (Tensor): The current noisy sample at the given time step.
        """
        t = timestep
        prev_t = self.previous_timestep(t)

        alpha_t = self.alphas[t]
        alpha_bar_t = self.alphas_cumprod[t]
        beta_t = self.betas[t]

        if prev_t >= 0:
            alpha_bar_prev_t = self.alphas_cumprod[prev_t]
        else:
            alpha_bar_prev_t = torch.tensor(1.0, device=sample.device)

        pred_original_sample = (
            sample - (1 - alpha_bar_t).sqrt() * model_output
        ) / alpha_bar_t.sqrt()

        param1 = alpha_bar_prev_t.sqrt() * beta_t / (1 - alpha_bar_t)
        param2 = alpha_t.sqrt() * (1 - alpha_bar_prev_t) / (1 - alpha_bar_t)
        mean = param1 * pred_original_sample + param2 * sample

        if prev_t >= 0:
            variance = (1 - alpha_bar_prev_t) / (1 - alpha_bar_t) * beta_t
        else:
            variance = torch.tensor(0.0, device=sample.device)

        if timestep > 0:
            noise = torch.randn_like(sample)
            prev_sample = mean + variance.sqrt() * noise
        else:
            prev_sample = mean

        return prev_sample
model = UNet2DModel().to(device)
scheduler = DDPMScheduler()

num_inference_steps = 100
scheduler.set_timesteps(num_inference_steps)
x = torch.randn(1, 3, 64, 64, device=device)

model.eval()
with torch.inference_mode():
    for timestep in scheduler.timesteps:
        timestep_batch = torch.full(
            (x.size(0),),
            int(timestep),
            dtype=torch.long,
            device=device,
        )
        model_output = model(x, timestep_batch)
        x = scheduler.step(model_output, int(timestep), x)

print(x.shape)
torch.Size([1, 3, 64, 64])

14.4.5 DDPM 训练与采样的伪代码

到这里,我们可以把整个 DDPM 的实现流程用伪代码写出来。

训练过程

训练时,每次只需要随机抽一个时间步:

DDPM training
for x0 in data:
    timestep = random.randint(1, T)
    epsilon = sample_gaussian_noise()
    xt = sqrt(alpha_bar_t) * x0 + sqrt(1 - alpha_bar_t) * epsilon

    pred_epsilon = model(xt, timestep)
    loss = mse_loss(pred_epsilon, epsilon)

    update(model)

采样过程

采样时则是从纯噪声开始,一步一步往回走:

DDPM sampling
x = sample_gaussian_noise()

for timestep in range(T - 1, -1, -1):
    pred_epsilon = model(x, timestep)
    x = reverse_step(x, pred_epsilon, timestep)

return x

其实从采样过程也不难发现,为什么 DDPM 的采样通常比较慢了。因为它不是一步生成,而是要跑很多步,每一步都要做一次 U-Net 前向传播,所以整体生成成本会比较高。但是,也正是因为这个逐步的过程,DDPM 才能生成质量非常高的图像。

14.4.6 本章小结

这一节我们从实现角度看了 DDPM 的网络结构与采样过程。可以把它概括成下面几句话:

  1. DDPM 的网络并不是直接生成图像,而是在每一步预测噪声;
  2. 模型必须知道当前时间步,因此时间 embedding 是关键组成部分;
  3. U-Net 很适合这种像素级、条件式、多尺度的去噪任务;
  4. 采样时从纯噪声开始,重复调用同一个去噪网络,一步一步得到最终图像。

到这里,DDPM 的怎么做已经基本完整了。但是,到目前为止,我们还没有真正从概率建模的角度去理解它。DDPM 的训练目标只是工程上的经验技巧吗?还是说,它其实能从一个严格的概率模型目标推导出来?

下一节,我们会重新回到概率建模的视角,看看 DDPM 的目标函数到底是怎么从 ELBO 一步一步推出来的,以及为什么最后会化简成我们现在看到的噪声预测损失。

References

Ronneberger, Olaf, Philipp Fischer, and Thomas Brox. 2015. U-Net: Convolutional Networks for Biomedical Image Segmentation. https://arxiv.org/abs/1505.04597.

Reuse