import torch
import torch.nn.functional as F
from torch import Tensor
print('PyTorch version:', torch.__version__)PyTorch version: 2.11.0+xpu
jshn9515
2026-03-25
2026-04-04
上一节里,我们已经知道了 VAE 的基本建模思路:
到这里,VAE 的结构已经清楚了。但还有一个问题没有解决:
训练 VAE 时,我们到底在优化什么?
为什么最后会出现那个经典的损失函数:一项负责重建,一项负责 KL 散度正则化。它并不是凭经验拍脑袋写出来的,而是从一个很自然的目标一步一步推出来的。
这一节,我们就来回答这个问题。
注意,这一节会包含很多公式推导,请做好心理准备。
PyTorch version: 2.11.0+xpu
从生成模型的角度看,我们的最终目标始终是:让模型学到真实数据的分布。
如果一个训练样本是 \(x\),那么我们希望模型给它较大的概率,也就是希望最大化:
\[ p_\theta(x) \]
在训练中,通常写成最大化对数似然:
\[ \log p_\theta(x) \]
所以,VAE 的根本目标其实是最大化 \(\log p_\theta(x)\)。这点非常重要,它说明 VAE 并不是为了重建好看才设计损失函数的。它本质上仍然是一个概率生成模型,优化目标依然是数据似然。
在前一节里我们知道,
\[ \log p_\theta(x) = \log \int p(z)p_\theta(x\mid z)\,dz \]
这一步看起来很正常,但实际很难。因为这里有一个对所有 \(z\) 的积分,而 decoder 是神经网络,通常没有解析解。我们的目标是对的,但它本身不好直接计算。所以我们需要换个思路:不直接优化 \(\log p_\theta(x)\),而是找一个容易计算、又和它关系紧密的下界来优化。
这个下界,就是 ELBO(Evidence Lower Bound,证据下界)。
前面我们已经说过,真实后验
\[ p_\theta(z\mid x) \]
很难直接求。所以 VAE 引入了一个由 encoder 参数化的近似分布:
\[ q_\phi(z\mid x) \]
现在我们来做一个很关键的操作:在 \(\log p_\theta(x)\) 里乘一个 \(q_\phi(z\mid x)\),再除一个 \(q_\phi(z\mid x)\)。
因为它们相消,所以值不变:
\[ \log p_\theta(x) = \log \int q_\phi(z\mid x)\frac{p_\theta(x,z)}{q_\phi(z\mid x)}\,dz \]
把积分写成对 \(q_\phi(z\mid x)\) 的期望形式:
\[ \log p_\theta(x) = \log \mathbb{E}_{q_\phi(z\mid x)} \left[\frac{p_\theta(x,z)}{q_\phi(z\mid x)}\right] \]
到这里,ELBO 的推导入口就出现了。接下来用到一个经典工具:Jensen 不等式。
由于 \(\log\) 是凹函数,所以有:
\[ \log \mathbb{E}[Y] \ge \mathbb{E}[\log Y] \]
把这里的 \(Y\) 替换成
\[ \frac{p_\theta(x,z)}{q_\phi(z\mid x)} \]
就得到:
\[ \log p_\theta(x) = \log \mathbb{E}_{q_\phi(z\mid x)} \left[\frac{p_\theta(x,z)}{q_\phi(z\mid x)}\right] \ge \mathbb{E}_{q_\phi(z\mid x)} \left[\log \frac{p_\theta(x,z)}{q_\phi(z\mid x)}\right] \]
于是我们定义右边这一项为 ELBO:
\[ \mathcal{L}(\theta,\phi;x) = \mathbb{E}_{q_\phi(z\mid x)} \left[\log \frac{p_\theta(x,z)}{q_\phi(z\mid x)}\right] \]
于是就有:
\[ \mathcal{L}(\theta,\phi; x) \le \log p_\theta(x) \]
这就是“证据下界”这个名字的来源:
所以,VAE 并不是直接最大化 \(\log p_\theta(x)\),而是最大化它的一个可计算下界:
\[ \max \mathcal{L}(\theta,\phi; x) \]
你可能会问:优化一个下界,真的能帮助我们优化原来的目标吗?
答案是肯定的。因为:
\[ \mathcal{L}(\theta,\phi;x) \le \log p_\theta(x) \]
如果我们把这个下界不断抬高,那么至少说明模型给数据 \(x\) 的解释能力在增强。而且更妙的是,这个下界和真实目标之间的差距,其实正好可以写成一个 KL 散度。下面我们来推这个式子。
从贝叶斯公式出发:
\[ p_\theta(z\mid x)=\frac{p_\theta(x,z)}{p_\theta(x)} \]
取对数:
\[ \log p_\theta(z\mid x)=\log p_\theta(x,z)-\log p_\theta(x) \]
整理一下:
\[ \log p_\theta(x)=\log p_\theta(x,z)-\log p_\theta(z\mid x) \]
现在对两边同时取 \(q_\phi(z\mid x)\) 下的期望:
\[ \log p_\theta(x) = \mathbb{E}_{q_\phi(z\mid x)} [\log p_\theta(x,z)-\log p_\theta(z\mid x)] \]
再人为加上并减去 \(\log q_\phi(z\mid x)\):
\[ \log p_\theta(x) = \mathbb{E}_{q_\phi(z\mid x)} \left[\log p_\theta(x,z)-\log q_\phi(z\mid x)\right] + \mathbb{E}_{q_\phi(z\mid x)} \left[\log q_\phi(z\mid x)-\log p_\theta(z\mid x)\right] \]
前一项正是 ELBO:
\[ \mathcal{L}(\theta,\phi;x) = \mathbb{E}_{q_\phi(z\mid x)} \left[\log p_\theta(x,z)-\log q_\phi(z\mid x)\right] \]
后一项正是 KL 散度:
\[ D_{\mathrm{KL}}(q_\phi(z\mid x)\,\|\,p_\theta(z\mid x)) \]
所以得到一个非常重要的关系:
\[ \log p_\theta(x) = \mathcal{L}(\theta,\phi;x) + D_{\mathrm{KL}}(q_\phi(z\mid x)\,\|\,p_\theta(z\mid x)) \]
因为 KL 散度总是非负的,所以:
\[ \mathcal{L}(\theta,\phi;x) \le \log p_\theta(x) \]
这不仅说明 ELBO 是下界,还说明了:
最大化 ELBO,一方面是在提高对数似然,另一方面也是在让近似后验 \(q_\phi(z\mid x)\) 更接近真实后验 \(p_\theta(z\mid x)\)。
这就非常漂亮了。也就是说,VAE 一次训练,同时做了学生成模型和学后验推断两件事。
虽然上面的式子已经很漂亮,但训练时更常见的是另一种拆法。
从 ELBO 的定义出发:
\[ \mathcal{L}(\theta,\phi;x) = \mathbb{E}_{q_\phi(z\mid x)} \left[\log \frac{p_\theta(x,z)}{q_\phi(z\mid x)}\right] \]
把联合分布拆开:
\[ p_\theta(x,z)=p(z)p_\theta(x\mid z) \]
代入得到:
\[ \mathcal{L}(\theta,\phi;x) = \mathbb{E}_{q_\phi(z\mid x)} \left[\log p_\theta(x\mid z)+\log p(z)-\log q_\phi(z\mid x)\right] \]
把期望拆开:
\[ \mathcal{L}(\theta,\phi;x) = \mathbb{E}_{q_\phi(z\mid x)}[\log p_\theta(x\mid z)] + \mathbb{E}_{q_\phi(z\mid x)}[\log p(z)-\log q_\phi(z\mid x)] \]
后面这一项正好可以写成负的 KL 散度:
\[ \mathbb{E}_{q_\phi(z\mid x)}[\log p(z)-\log q_\phi(z\mid x)] = - D_{\mathrm{KL}}(q_\phi(z\mid x)\,\|\,p(z)) \]
所以最终得到 VAE 最经典的目标函数形式:
\[ \mathcal{L}(\theta,\phi;x) = \mathbb{E}_{q_\phi(z\mid x)}[\log p_\theta(x\mid z)] - D_{\mathrm{KL}}(q_\phi(z\mid x)\,\|\,p(z)) \]
这就是你在各种资料里最常见的 ELBO 表达式。也就是我们在前一节讲的 重建项 + KL 正则项 的那个式子。
在实际训练代码里,我们通常写成最小化损失,所以会取负号:
\[ \mathcal{J}_{\text{VAE}} = -\mathcal{L}(\theta,\phi;x) \]
于是:
\[ \mathcal{J}_{\text{VAE}} = -\mathbb{E}_{q_\phi(z\mid x)}[\log p_\theta(x\mid z)] + D_{\mathrm{KL}}(q_\phi(z\mid x)\,\|\,p(z)) \]
这里的第一项就是重建误差,第二项就是 KL 正则化。所以在很多实现里我们会看到:
\[ \text{loss} = \text{reconstruction loss} + \text{KL loss} \]
这就是 负 ELBO 的工程写法。
如果只看公式,VAE 的训练像是在同时优化两个项。但实际上,VAE 的训练更像是在做一种相互之间的拉扯。
一边,重建项要求不要丢掉和输入 \(x\) 有关的信息,因为我们得把它尽量准确地还原出来。另一边,KL 项要求不要把每个样本都藏到潜空间某个很偏僻的角落,我们得让它们整体贴近标准正态,保持规整和平滑。因此,如果我们只管重建,而不约束分布,那么潜空间就会变得很混乱,退回到普通 AE 的问题;如果我们只管靠近先验,而不保留与输入 \(x\) 有关的信息,那么 decoder 就无法重建输入了。
所以,VAE 在表达能力和潜空间规整性之间找到了一个平衡点。
这里还要补充一个很容易让人困惑的问题。
我们前面写的重建项是:
\[ \mathbb{E}_{q_\phi(z\mid x)}[\log p_\theta(x\mid z)] \]
它是一个 对数似然项。但代码里常看到的是 MSE 或 BCE,这两者怎么对应起来?
关键在于:我们如何假设 \(p_\theta(x\mid z)\) 的形式。
情况一:把输出看成高斯分布
如果假设:
\[ p_\theta(x\mid z)=\mathcal{N}(x;\hat{x},\sigma^2 I) \]
那么最大化对数似然基本等价于最小化 MSE:
\[ \|x-\hat{x}\|^2 \]
所以连续值重建时,经常使用 MSE。
情况二:把输出看成 Bernoulli 分布
如果假设像素值是 0 到 1 之间的概率,并令:
\[ p_\theta(x\mid z) \]
服从逐像素 Bernoulli 分布,那么最大化对数似然就对应 BCE:
\[ -\sum_i [x_i\log \hat{x}_i + (1-x_i)\log(1-\hat{x}_i)] \]
因此,代码里使用什么重建损失,并不是随便选的,而是在对应你对 \(p_\theta(x\mid z)\) 的概率建模假设。
VAE 最常见的设定是:
在这种情况下,KL 散度有闭式解:
\[ D_{\mathrm{KL}}(q_\phi(z\mid x)\,\|\,p(z)) = \frac{1}{2}\sum_{j=1}^d \left(\mu_j^2 + \sigma_j^2 - \log \sigma_j^2 - 1\right) \]
如果代码里输出的是 \(\text{logvar} = \log \sigma^2\),那么经常写成:
\[ D_{\mathrm{KL}} = -\frac{1}{2}\sum_{j=1}^d \left(1 + \log \sigma_j^2 - \mu_j^2 - \sigma_j^2\right) \]
这两个式子是等价的。
这就是为什么 VAE 训练起来比数学推导简单很多。因为重建项可以直接算,KL 项也有闭式公式,中间采样用重参数化技巧解决,于是整个模型可以端到端训练。
下面给一个常见的 PyTorch 风格损失写法,和前面 13.2 的代码可以直接接起来。
这是 BCE 版本:
def vae_bce_loss(x_hat: Tensor, x: Tensor, mu: Tensor, logvar: Tensor) -> Tensor:
# Reconstruction loss using binary cross-entropy
re_loss = F.binary_cross_entropy(x_hat, x, reduction='sum')
# KL divergence loss between the approximate posterior and the prior
kl_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
return re_loss + kl_loss如果是 MSE 版本,重建项就换成 MSE 就行了:
def vae_mse_loss(x_hat: Tensor, x: Tensor, mu: Tensor, logvar: Tensor) -> Tensor:
# Reconstruction loss using mean squared error
re_loss = F.mse_loss(x_hat, x, reduction='sum')
# KL divergence loss between the approximate posterior and the prior
kl_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
return re_loss + kl_loss可能你就会说了,我们推导了半天,怎么代码就这么几行?
其实这就是数学推导的魅力所在。虽然公式看起来很复杂,但它们背后的逻辑是非常清晰的。我们从一个很自然的目标出发,经过一系列合理的变换,最终得到了一个既有理论意义又实用的损失函数。这就是 ELBO 的力量,也是 VAE 设计的精妙之处。
现在,我们终于可以回答这一节开头的问题:VAE 的目标函数从哪里来?
答案是:
第一,VAE 本质上仍然想最大化数据的对数似然:
\[ \log p_\theta(x) \]
第二,由于这个量难以直接计算,我们引入近似后验 \(q_\phi(z\mid x)\),并通过 Jensen 不等式构造了一个可优化的下界:
\[ \mathcal{L}(\theta,\phi;x) = \mathbb{E}_{q_\phi(z\mid x)}[\log p_\theta(x\mid z)] - D_{\mathrm{KL}}(q_\phi(z\mid x)\,\|\,p(z)) \]
第三,训练时通常最小化负 ELBO,于是就得到熟悉的形式:重建损失 + KL 正则项。
到这里,VAE 最核心的数学已经齐了:
在下一节里,我们来看看,这个目标函数在训练中会带来什么现象?为什么 VAE 生成的图像往往更平滑,但有时会偏模糊?KL 太强或太弱时会发生什么?