12.1 Generative Adversarial Network (GAN)
In the previous chapters we met many discriminative models. Whether it is a convolutional neural network or a Transformer, most of them do the same kind of thing: given an input, predict an output. For example, given an image, decide which class it belongs to; given a sentence, predict the next word.
But there is another very important type of problem in deep learning: generation.
Instead of only judging whether an image shows a cat, we want the model to generate an image of a cat. Instead of only recognizing what a speech clip says, we want the model to synthesize natural-sounding speech. In other words, the model must not only learn to distinguish data but also learn to imitate the distribution behind the data.
GAN (Generative Adversarial Network) (Goodfellow et al. 2014) is one of the classic methods in this direction. Its idea is very intuitive: instead of training a single generative model, it trains two models at the same time and makes them compete. One model generates samples, and the other judges whether samples are real or fake. The generator does its best to fool the discriminator, while the discriminator does its best to catch forged samples. Through this continuous contest, the generator gradually learns to produce increasingly realistic data.
In this section, we first clarify the four most basic questions about GAN:
- What exactly is a generative model learning?
- What is the core idea of GAN?
- How should we understand the objective function of GAN?
- How is a GAN trained?
Once these four questions are clear, later variants such as DCGAN, WGAN, and Conditional GAN become easy to follow, because they are essentially improvements built on this foundation.
12.1.1 What a generative model is learning
Before understanding GAN, we first need to answer a more basic question: what exactly is a generative model learning?
The goal of a classification model is usually easy to describe: take an image \(x\) as input and output a class \(y\). In essence, it learns a mapping:
\[ x \rightarrow y \]
A generative model is different. It does not try to tell us which class a sample belongs to. Instead, it asks: roughly what distribution does the real data follow, and can we draw new samples from that distribution?
For example, suppose we have many real face images. A generative model does not try to learn whose face each image shows. It tries to learn how face images are distributed as a whole. Once it has learned this distribution, we can sample from it and generate a new face that has never appeared before but looks real.
So, from a more abstract perspective, what a generative model wants to learn is the data distribution.
Assume the real data comes from the distribution \(p_{data}(x)\). Then the goal of the generative model is to construct another distribution \(p_g(x)\) and make it as close as possible to the real distribution:
\[ p_g(x) \approx p_{data}(x) \]
The key is not to memorize the training samples one by one but to learn the statistical regularities behind them. Only then does the model produce new, plausible samples instead of simply copying its training data.
A practical problem appears immediately: the real data distribution is usually very complex. A natural image, for example, involves textures, edges, shapes, lighting, background, and more. It is practically impossible to write its distribution down as an explicit formula. Many generative models therefore take an indirect route: first sample from a simple distribution, then use a neural network to map that sample into the complex data space. The diffusion models discussed later also start from simple Gaussian noise, although they reach the data through many small denoising steps rather than a single mapping.
For example, we first sample a random vector \(z\) from a Gaussian distribution:
\[ z \sim p(z) \]
Then we use a neural network \(G\) to turn it into a sample:
\[ x = G(z) \]
Here, \(z\) is usually called a latent variable. By itself it carries no direct image semantics. After the neural network transforms it, however, it can unfold into an image, a clip of speech, or other complex data.
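The mapping \(x = G(z)\) can be sketched concretely. This is a minimal sketch, assuming a small two-layer MLP with randomly initialized weights (untrained) and hypothetical sizes: a 16-dimensional latent vector mapped to a flattened 8x8 "image". The names `G`, `W1`, `W2`, `latent_dim`, and `data_dim` are illustrative only.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical sizes: 16-dim latent vector -> flattened 8x8 "image".
latent_dim, hidden_dim, data_dim = 16, 32, 64

# Random (untrained) weights stand in for a learned generator G.
W1 = rng.normal(0.0, 0.2, size=(latent_dim, hidden_dim))
b1 = np.zeros(hidden_dim)
W2 = rng.normal(0.0, 0.2, size=(hidden_dim, data_dim))
b2 = np.zeros(data_dim)

def G(z):
    """Map latent vectors of shape (n, latent_dim) to samples of shape (n, data_dim)."""
    h = np.maximum(0.0, z @ W1 + b1)   # ReLU hidden layer
    return np.tanh(h @ W2 + b2)        # tanh keeps outputs in [-1, 1], like normalized pixels

z = rng.standard_normal((16, latent_dim))  # z ~ N(0, I), a batch of 16 latent vectors
x = G(z)
print(z.shape, x.shape)  # (16, 16) (16, 64)
```

Training would adjust `W1`, `b1`, `W2`, and `b2` so that the distribution of `G(z)` moves toward the data distribution. The sampling step itself stays this cheap.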
So the core task of a generative model can be summarized in one sentence: starting from a simple distribution, use a learnable mapping to build a model distribution that is as close as possible to the real data distribution. GAN is a highly representative way to achieve this goal, and it was one of the first families of deep generative models to produce convincingly sharp, realistic images.
12.1.2 The core idea of GAN
Rather than starting with formulas, it helps to first build an intuitive analogy. You can think of GAN as a competition between a “counterfeiter” and an “identifier.” The Generator is the counterfeiter, responsible for forging samples. The Discriminator is the identification expert, responsible for judging whether a sample is real or fake.
The two networks therefore have opposite goals. The discriminator wants to tell real samples from fake ones, while the generator wants its samples to pass as real and fool the discriminator. This is where the word “adversarial” in GAN comes from. The idea is clever: by making two models compete, the generator is forced to keep improving until it produces high-quality samples. The same idea also makes GAN training harder, because training is no longer a simple optimization problem but a dynamic game between two players.
Structurally, GAN is very simple:
- Sample noise \(z\) from a simple distribution, such as a Gaussian distribution;
- Use the generator to obtain a fake sample \(x_{fake} = G(z)\);
- Send both the real sample \(x\) and the fake sample \(G(z)\) into the discriminator;
- The discriminator outputs, for each sample, a probability that it is real;
- Update the two networks according to the discrimination results.
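The forward pass in the steps above can be sketched as follows. This is a minimal sketch, assuming hypothetical 1D toy models: an affine generator \(G(z) = wz + b\) and a logistic-regression discriminator \(D(x) = \sigma(ux + c)\), with real data assumed to follow \(N(4, 1.25^2)\). The parameter values are arbitrary and nothing is trained here.

```python
import numpy as np

rng = np.random.default_rng(0)

def sigmoid(a):
    # Numerically stable logistic function: 1 / (1 + exp(-a)).
    return np.exp(-np.logaddexp(0.0, -a))

# Hypothetical toy parameters (not trained).
w, b = 1.0, 0.0    # generator: G(z) = w * z + b
u, c = 0.5, -1.0   # discriminator: D(x) = sigmoid(u * x + c)

def G(z):
    return w * z + b

def D(x):
    return sigmoid(u * x + c)   # probability that x is real

x_real = 4.0 + 1.25 * rng.standard_normal(8)   # real batch, assumed N(4, 1.25^2)
z = rng.standard_normal(8)                     # noise from N(0, 1)
x_fake = G(z)                                  # fake batch

p_real = D(x_real)   # discriminator scores for the real batch
p_fake = D(x_fake)   # discriminator scores for the fake batch
print(p_real.round(2))
print(p_fake.round(2))
```

The last step of the list, updating the two networks from these scores, is what the objective function and the training loop define.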
The game between the two models typically unfolds like this:
- At the beginning, the generator is very weak, and the samples it generates are very rough;
- The discriminator can easily see through it;
- But as training continues, the generator gradually learns to produce more convincing fakes;
- The discriminator is also forced to learn more detailed judgment criteria;
- Finally, in the ideal case, the samples generated by the generator will be very close to real samples, so close that the discriminator can barely tell them apart.
The core idea of GAN, then, is not to tell the generator directly what each image should look like. Instead, the adversarial interplay between the generator and the discriminator pushes the generated distribution gradually toward the real distribution.
12.1.3 The objective function of GAN
With this intuition in place, the objective function of GAN becomes much more natural.
The most classic objective function of GAN is written as:
\[ \min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{data}(x)}[\log D(x)] + \mathbb{E}_{z \sim p(z)}[\log(1 - D(G(z)))] \]
The formula looks complicated, but it simply describes, in a single expression, what the discriminator and the generator each want.
12.1.3.1 The discriminator’s objective
First look at the discriminator. For a real sample \(x\), it wants the output \(D(x)\) to be as close to 1 as possible; for a fake sample \(G(z)\), it wants the output \(D(G(z))\) to be as close to 0 as possible.
Therefore, the discriminator maximizes:
\[ \mathbb{E}_{x \sim p_{data}(x)}[\log D(x)] + \mathbb{E}_{z \sim p(z)}[\log(1 - D(G(z)))] \]
Why is it written this way?
- When \(D(x)\) is closer to 1, \(\log D(x)\) is larger;
- When \(D(G(z))\) is closer to 0, \(\log(1 - D(G(z)))\) is larger.
So, this objective function is essentially encouraging the discriminator to classify real samples as real and fake samples as fake.
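In practice, the expectations in \(V(D, G)\) are estimated by averages over a batch. This is a minimal sketch of that Monte Carlo estimate, assuming we already have the discriminator outputs on a real batch and a fake batch. The name `gan_value` and the example numbers are hypothetical.

```python
import numpy as np

def gan_value(d_real, d_fake, eps=1e-12):
    """Batch estimate of V(D, G).

    d_real: D(x) on real samples; d_fake: D(G(z)) on generated samples.
    """
    d_real = np.clip(d_real, eps, 1.0 - eps)   # guard against log(0)
    d_fake = np.clip(d_fake, eps, 1.0 - eps)
    return np.mean(np.log(d_real)) + np.mean(np.log(1.0 - d_fake))

# A discriminator that separates the two batches scores higher
# than one that always outputs 0.5.
good = gan_value(np.array([0.9, 0.95]), np.array([0.1, 0.05]))
guess = gan_value(np.array([0.5, 0.5]), np.array([0.5, 0.5]))
print(round(good, 3), round(guess, 3))  # -0.157 -1.386
```

The discriminator climbs this value, so outputs near 1 on real data and near 0 on fake data are preferred. The value \(-1.386 = -\log 4\) of the constant-0.5 guess reappears in the equilibrium analysis.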
12.1.3.2 The generator’s objective
Now look at the generator. The generator cannot directly see the real distribution. It can only improve itself through the feedback given by the discriminator.
The first term of \(V(D, G)\) does not depend on \(G\), so from the original objective the generator only needs to minimize:
\[ \mathbb{E}_{z \sim p(z)}[\log(1 - D(G(z)))] \]
Minimizing this term means pushing:
\[ D(G(z)) \rightarrow 1 \]
That is, the goal of the generator is to make the discriminator believe that the samples it generated are real.
So, GAN training is essentially a min-max problem. The discriminator wants to distinguish real from fake, while the generator wants real and fake to become indistinguishable.
12.1.3.3 Why this objective is meaningful
Intuitively speaking, if the samples generated by the generator are very different from the real data, then the discriminator can easily distinguish them; conversely, if the discriminator already has difficulty distinguishing real from fake, then it means that the generated samples are already very close to the real samples.
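Theoretically, for a fixed generator, \(V\) can be written as \(\int [p_{data}(x)\log D(x) + p_g(x)\log(1 - D(x))]\,dx\), where the second expectation has been rewritten over \(x = G(z) \sim p_g\). Maximizing the integrand separately at each \(x\), and using the fact that \(a\log y + b\log(1-y)\) peaks at \(y = a/(a+b)\), gives the optimal discriminator: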
\[ D^*(x)=\frac{p_{data}(x)}{p_{data}(x)+p_g(x)} \]
Here, \(p_{data}(x)\) is the real data distribution; \(p_g(x)\) is the distribution generated by the generator.
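\(D^*\) can be checked numerically at a single point \(x\). For fixed \(G\), the part of \(V\) that depends on \(D(x)\) is \(p_{data}(x)\log D(x) + p_g(x)\log(1 - D(x))\), so a grid search over \(D(x)\) should land on the closed-form value. This is a minimal sketch; the density values 0.3 and 0.1 are hypothetical.

```python
import numpy as np

def optimal_discriminator(p_data, p_g):
    # D*(x) = p_data(x) / (p_data(x) + p_g(x)), evaluated pointwise.
    return p_data / (p_data + p_g)

# Hypothetical density values at one point x.
p_data, p_g = 0.3, 0.1

# Integrand of V at this x, as a function of the discriminator's output.
D_grid = np.linspace(0.001, 0.999, 9999)
integrand = p_data * np.log(D_grid) + p_g * np.log(1.0 - D_grid)
D_best = D_grid[np.argmax(integrand)]

print(optimal_discriminator(p_data, p_g))  # 0.75, up to floating-point rounding
print(round(D_best, 3))                    # grid search also gives about 0.75
```

Where the generator puts more mass than the data (\(p_g > p_{data}\)), \(D^*\) drops below 0.5. This is the signal the generator then follows.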
When GAN reaches the ideal equilibrium, we have:
\[ p_g(x) = p_{data}(x) \]
At this point the optimal discriminator outputs \(D^*(x) = 1/2\) for every \(x\). It can do no better than a coin flip, because real and fake samples are indistinguishable.
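Why this equilibrium is the generator's best outcome follows from substituting \(D^*\) back into \(V\), which gives the identity from Goodfellow et al. (2014):

\[ V(D^*, G) = -\log 4 + 2\,\mathrm{JSD}(p_{data} \,\|\, p_g) \]

Here JSD is the Jensen-Shannon divergence, which is non-negative and zero only when the two distributions are equal. The minimum over \(G\) is therefore \(-\log 4\), reached exactly at \(p_g = p_{data}\). This is a minimal numeric check of the identity, assuming hypothetical discrete distributions over three outcomes so that the integrals become sums.

```python
import numpy as np

def kl(p, q):
    # KL divergence for discrete distributions with strictly positive entries.
    return np.sum(p * np.log(p / q))

def jsd(p, q):
    # Jensen-Shannon divergence via the mixture m = (p + q) / 2.
    m = 0.5 * (p + q)
    return 0.5 * kl(p, m) + 0.5 * kl(q, m)

def value_at_optimal_D(p_data, p_g):
    # Plug D* = p_data / (p_data + p_g) into V, written as a sum over discrete x.
    d_star = p_data / (p_data + p_g)
    return np.sum(p_data * np.log(d_star)) + np.sum(p_g * np.log(1.0 - d_star))

p_data = np.array([0.5, 0.3, 0.2])   # hypothetical "real" distribution
p_g = np.array([0.2, 0.3, 0.5])      # hypothetical generator distribution

lhs = value_at_optimal_D(p_data, p_g)
rhs = -np.log(4.0) + 2.0 * jsd(p_data, p_g)
print(np.isclose(lhs, rhs))                                          # True
print(np.isclose(value_at_optimal_D(p_data, p_data), -np.log(4.0)))  # True
```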
So, although the objective function of GAN looks like only “deception” and “identification,” what it is really pushing behind the scenes is: let the generated distribution gradually approach the real distribution.
12.1.3.4 A small practical supplement
In the original GAN, the generator minimizes:
\[ \log(1 - D(G(z))) \]
In practice, however, this objective saturates. Early in training, or whenever the discriminator is strong, \(D(G(z))\) is close to 0. There, \(\log(1 - D(G(z)))\) is nearly flat, so the generator receives very weak gradients. A more common practice is therefore to maximize
\[ \log D(G(z)) \]
or minimizing
\[ -\log D(G(z)) \]
This is called the non-saturating loss. It keeps the same goal, since the generator still wants to fool the discriminator, but it gives a much stronger gradient signal exactly when \(D(G(z))\) is small. That is when the generator needs it most, early in training.
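The difference between the two losses shows up directly in the gradients. Write the discriminator's output on a fake sample as \(D(G(z)) = \sigma(a)\), where \(a\) is the logit. The derivative of \(\log(1 - \sigma(a))\) with respect to \(a\) is \(-\sigma(a)\), while that of \(-\log\sigma(a)\) is \(\sigma(a) - 1\). This minimal sketch evaluates both at a few logits. The chosen logit values are arbitrary.

```python
import numpy as np

def sigmoid(a):
    # Numerically stable logistic function.
    return np.exp(-np.logaddexp(0.0, -a))

# a is the discriminator's logit on a fake sample, so D(G(z)) = sigmoid(a).
#   saturating loss      log(1 - sigmoid(a)):  d/da = -sigmoid(a)
#   non-saturating loss  -log sigmoid(a):      d/da = sigmoid(a) - 1
for a in [-6.0, -3.0, 0.0]:
    d = sigmoid(a)
    grad_sat = -d
    grad_nonsat = d - 1.0
    print(f"D(G(z))={d:.4f}  saturating={grad_sat:+.4f}  non-saturating={grad_nonsat:+.4f}")
```

At \(a = -6\), where the discriminator confidently rejects the fake (\(D(G(z)) \approx 0.0025\)), the saturating gradient is only about \(-0.0025\), while the non-saturating one is about \(-0.9975\). Both push \(a\) in the same direction; only the magnitude differs.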
12.1.4 The training process of GAN
With the structure and the objective function in place, we can finally look at how a GAN is trained.
GAN training does not update both networks in a single joint step. Instead, it trains them alternately:
- First train the discriminator;
- Then train the generator;
- The two keep alternating until they reach some kind of equilibrium.
Step 1: train the discriminator
In this step, we temporarily fix the generator and only update the discriminator.
The specific method is:
- Sample a batch of real samples \(x\) from the real dataset;
- Sample a batch of random vectors \(z\) from the noise distribution;
- Use the generator to produce fake samples \(G(z)\);
- Send the real samples and fake samples into the discriminator together;
- Train the discriminator as a binary classifier, with label 1 for real samples and label 0 for fake samples.
The purpose of this step is to let the discriminator learn to better distinguish real from fake.
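Training with labels 1 and 0 is the same thing as maximizing the discriminator's objective. This minimal sketch shows that binary cross-entropy with those labels equals \(-[\mathbb{E}\log D(x) + \mathbb{E}\log(1 - D(G(z)))]\). The output values are hypothetical, and `binary_cross_entropy` is a hand-written helper, not a library call.

```python
import numpy as np

def binary_cross_entropy(p, y, eps=1e-12):
    # Mean of -[y log p + (1 - y) log(1 - p)] over the batch.
    p = np.clip(p, eps, 1.0 - eps)
    return -np.mean(y * np.log(p) + (1.0 - y) * np.log(1.0 - p))

d_real = np.array([0.8, 0.9, 0.7])   # hypothetical D(x) on real samples
d_fake = np.array([0.3, 0.1, 0.2])   # hypothetical D(G(z)) on fake samples

# BCE with label 1 on the real batch plus BCE with label 0 on the fake batch...
loss_D = (binary_cross_entropy(d_real, np.ones_like(d_real))
          + binary_cross_entropy(d_fake, np.zeros_like(d_fake)))

# ...equals the negated discriminator objective.
neg_objective = -(np.mean(np.log(d_real)) + np.mean(np.log(1.0 - d_fake)))
print(np.isclose(loss_D, neg_objective))  # True
```

If the two batches have equal size and are concatenated into one batch with a single mean, the loss is exactly half of `loss_D`. That only rescales the gradient.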
Step 2: train the generator
Next, we fix the discriminator and only update the generator.
Again, we first sample a batch of \(z\) from the noise distribution, generate fake samples \(G(z)\), and feed them into the discriminator. The key point is that although the discriminator takes part in the forward pass and produces a real/fake judgment, it is not updated in this step. Only the generator is.
In other words, the discriminator only provides a score of “how real your generated result looks.” This score is backpropagated through the frozen discriminator into the generator, which then updates its own parameters. From an optimization perspective, the generator borrows the discriminator's gradient as its own learning signal.
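How the gradient passes through the discriminator can be sketched with the chain rule. This is a minimal sketch, assuming hypothetical 1D toy models: a trainable generator \(G(z) = wz + b\) and a frozen discriminator \(D(x) = \sigma(ux + c)\), with the non-saturating loss \(-\log D(G(z))\). The gradient is derived by hand, checked against finite differences, and then used to update only \(w\) and \(b\).

```python
import numpy as np

def sigmoid(a):
    return np.exp(-np.logaddexp(0.0, -a))

rng = np.random.default_rng(0)
z = rng.standard_normal(32)

# Hypothetical toy models: G(z) = w*z + b (trainable), D(x) = sigmoid(u*x + c) (frozen).
w, b = 1.0, 0.0
u, c = 0.8, -2.0

def loss_G(w, b):
    # Non-saturating generator loss: mean of -log D(G(z)) = log(1 + exp(-a)).
    a = u * (w * z + b) + c
    return np.mean(np.logaddexp(0.0, -a))

# Chain rule: dL/dx_fake = (sigmoid(a) - 1) * u; dx_fake/dw = z; dx_fake/db = 1.
a = u * (w * z + b) + c
dL_dx = (sigmoid(a) - 1.0) * u
grad_w = np.mean(dL_dx * z)
grad_b = np.mean(dL_dx)

# Finite-difference check of the gradients the generator receives.
h = 1e-6
fd_w = (loss_G(w + h, b) - loss_G(w - h, b)) / (2 * h)
fd_b = (loss_G(w, b + h) - loss_G(w, b - h)) / (2 * h)
print(np.isclose(grad_w, fd_w), np.isclose(grad_b, fd_b))  # True True

# Update only the generator; the discriminator's u and c stay unchanged.
lr = 0.1
w, b = w - lr * grad_w, b - lr * grad_b
```

Because \(u > 0\) here, the discriminator scores larger \(x\) as more real. The generator's update accordingly raises \(b\), moving its samples toward the region the discriminator prefers.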
The overall process of alternating optimization
Putting the two steps together, a typical GAN training loop can be written as:
GAN training loop
for each training step:
    # 1. sample real data
    x_real = sample_from_dataset()
    # 2. sample latent noise
    z = sample_noise()
    # 3. generate fake data (used as fixed inputs when updating D)
    x_fake = G(z)
    # 4. update discriminator (logs averaged over the batch)
    loss_D = - [log D(x_real) + log(1 - D(x_fake))]
    optimize(D, loss_D)
    # 5. sample new noise
    z = sample_noise()
    # 6. update generator (D is frozen, but gradients flow through it into G)
    x_fake = G(z)
    loss_G = - log D(x_fake)   # non-saturating loss
    optimize(G, loss_G)

The point most worth noticing in this pseudocode is that although the discriminator and generator are connected, they are not updated synchronously under the same objective.
This makes GAN training very different from ordinary supervised learning. An ordinary model usually optimizes a single fixed objective, while the objective of a GAN keeps shifting. Once you update the discriminator, the generator faces a different opponent. Once you update the generator, the discriminator sees a different data distribution.
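The pseudocode can be turned into a complete, runnable toy. This is a minimal sketch, not a practical GAN, under hypothetical assumptions: 1D real data from \(N(4, 1.25^2)\), a generator \(G(z) = wz + b\), a logistic-regression discriminator \(D(x) = \sigma(ux + c)\), and gradients written out by hand with plain SGD instead of an autodiff library. The function name `train_gan` and all hyperparameters are illustrative.

```python
import numpy as np

def sigmoid(a):
    # Numerically stable logistic function.
    return np.exp(-np.logaddexp(0.0, -a))

def train_gan(steps=2000, batch=64, lr=0.05, seed=0):
    """Toy 1D GAN trained with alternating updates (hypothetical setup).

    Real data: N(4, 1.25^2). Generator: G(z) = w*z + b, z ~ N(0, 1).
    Discriminator: D(x) = sigmoid(u*x + c). Gradients are derived by hand.
    """
    rng = np.random.default_rng(seed)
    w, b = 1.0, 0.0   # generator parameters
    u, c = 0.0, 0.0   # discriminator parameters
    d_losses, g_losses, fake_means = [], [], []

    for _ in range(steps):
        # 1-3. sample real data and noise, generate fake data
        x_real = 4.0 + 1.25 * rng.standard_normal(batch)
        z = rng.standard_normal(batch)
        x_fake = w * z + b   # fixed input for the D update: nothing flows back into G

        # 4. discriminator step on -[mean log D(x_real) + mean log(1 - D(x_fake))]
        a_r, a_f = u * x_real + c, u * x_fake + c
        loss_D = np.mean(np.logaddexp(0.0, -a_r)) + np.mean(np.logaddexp(0.0, a_f))
        g_r, g_f = sigmoid(a_r) - 1.0, sigmoid(a_f)   # d loss_D / d logit
        u -= lr * (np.mean(g_r * x_real) + np.mean(g_f * x_fake))
        c -= lr * (np.mean(g_r) + np.mean(g_f))

        # 5-6. fresh noise, generator step on -mean log D(G(z)) with D frozen
        z = rng.standard_normal(batch)
        a_f = u * (w * z + b) + c
        loss_G = np.mean(np.logaddexp(0.0, -a_f))
        g_x = (sigmoid(a_f) - 1.0) * u   # d loss_G / d x_fake, through the frozen D
        w -= lr * np.mean(g_x * z)
        b -= lr * np.mean(g_x)

        d_losses.append(float(loss_D))
        g_losses.append(float(loss_G))
        fake_means.append(b)   # E[G(z)] = b because z has zero mean

    return d_losses, g_losses, fake_means

d_losses, g_losses, fake_means = train_gan()
print(f"generator mean after training: {fake_means[-1]:.2f} (real mean: 4.0)")
```

Printing `fake_means[::200]` shows the generator's mean moving as the two players alternate. Because this is a game rather than a plain minimization, the mean can circle around 4 instead of settling exactly on it. Since this `D` is linear in \(x\), it cannot detect a difference in spread, so this toy can at best match the mean.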
12.1.5 Why GAN is difficult to train
The core idea of GAN is to let the generator and discriminator compete with each other, pushing the generated distribution to gradually approach the real distribution. But precisely because of this, GAN training is often unstable.
- If the discriminator is too strong, the generator may not get enough useful gradients;
- If the discriminator is too weak, the feedback learned by the generator is not reliable;
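- If the two fall even slightly out of balance, training may oscillate or even suffer mode collapse, where the generator produces only a few kinds of samples and ignores the rest of the data distribution.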
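Mode collapse is easiest to see on toy data whose modes are known. One simple diagnostic counts how many modes receive at least one generated sample. This is a minimal sketch, assuming a hypothetical target of 8 Gaussian modes on a circle of radius 2 and synthetic "healthy" and "collapsed" sample sets standing in for generator outputs. The function `modes_covered` and the radius threshold are illustrative choices.

```python
import numpy as np

def modes_covered(samples, centers, radius):
    """Count how many mode centers have at least one sample within `radius`."""
    # Pairwise distances, shape (n_samples, n_modes).
    dists = np.linalg.norm(samples[:, None, :] - centers[None, :, :], axis=-1)
    return int(np.sum(np.any(dists < radius, axis=0)))

rng = np.random.default_rng(0)
# Hypothetical target: 8 modes evenly spaced on a circle of radius 2.
angles = 2 * np.pi * np.arange(8) / 8
centers = 2.0 * np.stack([np.cos(angles), np.sin(angles)], axis=1)

# "Healthy" generator output: 100 points near every mode.
healthy = np.repeat(centers, 100, axis=0) + 0.05 * rng.standard_normal((800, 2))
# "Collapsed" generator output: all 800 points near a single mode.
collapsed = centers[0] + 0.05 * rng.standard_normal((800, 2))

print(modes_covered(healthy, centers, radius=0.3))    # 8
print(modes_covered(collapsed, centers, radius=0.3))  # 1
```

Each collapsed sample can still look perfectly realistic on its own. The failure only shows up when we look at the distribution of many samples.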
So the hardest part of GAN is not a complex model structure but a game that is difficult to keep in balance.
However, from a higher-level perspective, the entire training process of GAN always revolves around the same thing:
By alternately training a generator and a discriminator, make the generated samples gradually approach the real samples.
Once you grasp this point, the GAN variants that follow can be understood as improvements to this basic process in different directions.