12.1 生成对抗网络(GAN)
在前面的章节里,我们接触过很多判别式模型。无论是卷积神经网络,还是 Transformer,它们大多都在做同一类事情:给定输入,预测输出。例如,输入一张图片,判断它属于哪一类;输入一句话,预测下一个词是什么。
但深度学习里还有另一类非常重要的问题:生成。
我们不再满足于判断这是不是猫,而是希望模型能够生成一张猫的图片;不再只是识别一段语音的内容,而是希望模型能够合成一段自然的语音。换句话说,模型不仅要学会区分数据,还要学会模仿数据背后的分布。
GAN (Generative Adversarial Network) (Goodfellow et al. 2014) 就是这一方向中最经典的方法之一。它的想法很直观:不是只训练一个生成模型,而是同时训练两个彼此对抗的模型。一个模型负责生成样本,另一个模型负责判断样本是真是假。生成器想尽办法骗过判别器,而判别器则想尽办法识别出伪造样本。就在这种不断对抗的过程中,生成器逐渐学会如何生成越来越逼真的数据。
这一节我们先把 GAN 最基本的四个问题理清楚:
- 生成模型到底在学什么?
- GAN 的核心思想是什么?
- GAN 的目标函数怎么理解?
- GAN 的训练流程是怎样的?
只要把这四个问题想清楚,后面的 DCGAN、WGAN、Conditional GAN 等变体其实都是在这个基础上做改进。
12.1.1 生成模型在学什么
在理解 GAN 之前,我们先要回答一个更基础的问题:生成模型到底在学什么?
如果是分类模型,我们通常很容易描述它的目标。比如输入一张图片 \(x\),输出类别 \(y\),本质上是在学习一个映射:
\[ x \rightarrow y \]
但生成模型不一样。它不是要告诉我们这个样本属于什么类别,而是要回答:真实数据看起来大致服从一个什么分布?我们能不能从这个分布中再采样出新的样本?
例如,我们手里有很多真实人脸图片,生成模型希望学到的不是这张脸是谁,而是人脸图片整体上是怎样分布的。一旦学到了这个分布,我们就可以从中采样,生成一张之前从未出现过、但看起来很真实的新脸。
所以,从更抽象的角度看,生成模型想学的是 数据分布(data distribution)。
假设真实数据来自分布 \(p_{data}(x)\),那么生成模型的目标就是构造另一个分布 \(p_g(x)\),让它尽可能接近真实分布:
\[ p_g(x) \approx p_{data}(x) \]
这里的关键不在于逐个记住训练样本,而在于学会训练样本背后的统计规律。只有这样,模型生成出来的内容才不会只是简单复制,而是能够产生新的、合理的样本。
不过这里立刻会遇到一个现实问题:真实数据分布通常非常复杂。比如一张自然图像包含纹理、边缘、形状、光照、背景等大量因素,我们几乎不可能直接把这个分布写成一个显式公式。因此,很多生成模型会采用一种间接做法:先从一个简单分布中采样,再通过一个神经网络把它映射到复杂数据空间中。后面要讲的扩散模型本质上也是在做类似的事情。
例如,我们先从一个高斯分布中采样一个随机向量 \(z\):
\[ z \sim p(z) \]
然后通过一个神经网络 \(G\) 把它变成样本:
\[ x = G(z) \]
这里的 \(z\) 通常叫做 潜变量(latent variable)。它本身没有直接的图像语义,但经过神经网络变换之后,就可能被展开成一张图像、一段语音,或者其他复杂数据。
所以,生成模型的核心任务可以概括成一句话:从一个简单分布出发,通过可学习的映射,构造一个尽量接近真实数据分布的模型。GAN 正是实现这个目标的一种非常有代表性的方式,也是第一个被成功训练出的高质量生成模型。
12.1.2 GAN 的核心思想
理解 GAN,最好的办法不是一上来就看公式,而是先建立一个直观类比。你可以把 GAN 想成一场“造假者”和“鉴别者”的对抗。生成器(Generator) 是一个造假者,负责伪造样本;而 判别器(Discriminator) 是一个鉴别专家,负责判断样本是真是假。
于是,两个网络的目标就天然相反。判别器想把真假样本区分开,而生成器想让自己造出来的样本以假乱真,骗过判别器。这就是 GAN 名字里“对抗”的由来。不得不说,这种想法非常巧妙:通过让两个模型互相竞争,迫使生成器不断改进,最终学会生成高质量的样本。但这也为 GAN 的训练带来了一定挑战。因为它不是一个单纯的优化问题,而是一个动态博弈的过程。
从结构上看,GAN 非常简单:
- 从一个简单分布(如高斯分布)中采样噪声 \(z\);
- 用生成器得到假样本 \(x_{fake} = G(z)\);
- 把真实样本 \(x\) 和假样本 \(G(z)\) 都送入判别器;
- 判别器分别输出这两个样本是真的还是假的;
- 根据判别结果,更新两个网络。
我们可以想象一下两个模型之间的博弈:
- 一开始生成器很弱,生成的样本很粗糙;
- 判别器很容易看穿它;
- 但随着训练进行,生成器逐渐学会更像真的伪装;
- 判别器也被迫学会更细致的判断标准;
- 最终,在理想情况下,生成器生成出来的样本会和真实样本非常接近,以至于判别器几乎分不出来。
所以,GAN 的核心思想不是直接生成一张图像,而是通过生成器和判别器的对抗,让生成分布逐渐逼近真实分布。
12.1.3 GAN 的目标函数
有了前面的直觉之后,我们再来看 GAN 的目标函数就会自然很多。
GAN 最经典的目标函数写作:
\[ \min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{data}(x)}[\log D(x)] + \mathbb{E}_{z \sim p(z)}[\log(1 - D(G(z)))] \]
这个式子看起来有点复杂,但其实只是在同时描述判别器和生成器各自想做什么。
12.1.3.1 判别器的目标
先看判别器。它希望对真实样本 \(x\),输出的 \(D(x)\) 尽量接近 1;而对假样本 \(G(z)\),输出的 \(D(G(z))\) 尽量接近 0。
因此,判别器会最大化:
\[ \mathbb{E}_{x \sim p_{data}(x)}[\log D(x)] + \mathbb{E}_{z \sim p(z)}[\log(1 - D(G(z)))] \]
为什么这样写?
- 当 \(D(x)\) 越接近 1 时,\(\log D(x)\) 越大;
- 当 \(D(G(z))\) 越接近 0 时,\(\log(1 - D(G(z)))\) 越大。
所以,这个目标函数本质上就是在鼓励判别器把真样本判成真,把假样本判成假。
12.1.3.2 生成器的目标
再看生成器。生成器并不能直接看到真实分布,它只能通过判别器给出的反馈来改进自己。
从原始目标函数出发,生成器希望最小化:
\[ \mathbb{E}_{z \sim p(z)}[\log(1 - D(G(z)))] \]
这等价于希望:
\[ D(G(z)) \rightarrow 1 \]
也就是说,生成器的目标是让判别器相信它生成出来的样本是真的。
所以,GAN 的训练本质上是一个 极小极大(min-max)问题。判别器想把真假区分开,而生成器想让真假分不开。
12.1.3.3 为什么这个目标有意义
从直觉上说,如果生成器生成出来的样本和真实数据差别很大,那么判别器就会很容易把它们区分开;反过来,如果判别器已经很难分辨真假,那么就说明生成样本已经非常接近真实样本。
从理论上讲,在固定生成器时,最优判别器可以写成:
\[ D^*(x)=\frac{p_{data}(x)}{p_{data}(x)+p_g(x)} \]
其中,\(p_{data}(x)\) 是真实数据分布;\(p_g(x)\) 是生成器生成出来的分布。
当 GAN 达到理想平衡时,我们有:
\[ p_g(x) = p_{data}(x) \]
这时判别器对任意样本都只能给出 0.5 左右的判断,因为真假已经几乎无法区分。
所以,GAN 的目标函数虽然看起来只是“欺骗”和“识别”,但背后真正推动的是:让生成分布逐渐逼近真实分布。
12.1.3.4 一个实践上的小补充
在原始 GAN 中,生成器最小化的是:
\[ \log(1 - D(G(z))) \]
但在实际训练时,这个目标函数在判别器很强时容易导致梯度太弱。因此更常见的做法是改成最大化
\[ \log D(G(z))\]
或者最小化
\[ -\log D(G(z)) \]
这被称为 非饱和损失(non-saturating loss)。它不会改变“生成器想骗过判别器”这个本质,只是能让训练初期的梯度信号更强、更稳定。
12.1.4 GAN 的训练流程
理解了结构和目标函数之后,我们最后来看 GAN 是怎么训练的。
GAN 的训练不是一次同时把两个网络一起更新完,而是交替训练:
- 先训练判别器;
- 再训练生成器;
- 两者不断交替,直到达到某种平衡。
第一步:训练判别器
在这一步里,我们暂时固定生成器,只更新判别器。
具体做法是:
- 从真实数据集中采样一批真实样本 \(x\);
- 从噪声分布中采样一批随机向量 \(z\);
- 用生成器产生假样本 \(G(z)\);
- 把真实样本和假样本一起送入判别器;
- 用“真实样本标签为 1、假样本标签为 0”的方式训练判别器。
这一步的目的是让判别器学会更好地区分真假。
第二步:训练生成器
接着,我们固定判别器,只更新生成器。
这时候仍然先从噪声分布中采样一批 \(z\),生成假样本 \(G(z)\),再把这些假样本送入判别器。但这里有一个关键点:虽然判别器会参与前向传播,给出真假判断,可是真正更新的不是判别器,而是生成器。
也就是说,判别器只是提供一个“你生成得像不像真的”评分。生成器根据这个评分反向传播,更新自己的参数。从优化角度看,生成器是在借用判别器的梯度作为自己的学习信号。
交替优化的整体流程
把两步合起来,一个典型的 GAN 训练循环可以写成:
GAN training loop
for each training step:
# 1. sample real data
x_real = sample_from_dataset()
# 2. sample latent noise
z = sample_noise()
# 3. generate fake data
x_fake = G(z)
# 4. update discriminator
loss_D = - [log D(x_real) + log(1 - D(x_fake))]
optimize(D, loss_D)
# 5. sample new noise
z = sample_noise()
# 6. update generator
x_fake = G(z)
loss_G = - log D(x_fake)
optimize(G, loss_G)这段伪代码里最值得注意的是:判别器和生成器虽然连在一起,但并不是在同一个目标下同步更新。
这就导致 GAN 的训练和普通监督学习很不一样。普通模型通常只有一个固定目标,而 GAN 的目标是动态变化的。因为你刚更新完判别器,生成器面对的对手就变了;你刚更新完生成器,判别器看到的数据分布也变了。
12.1.5 为什么 GAN 难训练
GAN 的核心思想是让生成器和判别器互相对抗,推动生成分布逐渐逼近真实分布。但也正因为如此,GAN 的训练往往不稳定。
- 如果判别器太强,生成器可能拿不到足够有用的梯度;
- 如果判别器太弱,生成器学到的反馈又不可靠;
- 两者稍微失衡,就可能出现震荡甚至模式崩塌。
所以,GAN 最难的地方不是模型结构复杂,而是博弈过程难以平衡。
不过从更高的视角看,GAN 的整个训练流程始终围绕着同一件事:
通过交替训练一个生成器和一个判别器,让生成样本逐步逼近真实样本。
只要抓住这一点,后面的各种 GAN 变体本质上都只是对这个基本流程做不同方向的改进。