11.4 ViT Encoder:让 Patch Token 之间交换信息

Author

jshn9515

Published

2026-05-13

Modified

2026-05-19

上一节里,我们已经完成了 ViT 的输入构造。给定一张图像,ViT 会先把它切成 patch,经过 patch embedding 得到视觉 token,再加入 class token 和 positional embedding:

\[ (B, C, H, W) \rightarrow (B, N, D) \rightarrow (B, N+1, D) \]

这里的 \(N\) 是 patch token 的数量,额外的 1 来自 class token。

到这一步,图像已经被翻译成了 Transformer 熟悉的序列形式。但是,问题还没有结束。

Patch embedding 只是把每个局部图像块变成了一个向量。此时每个 patch token 主要表示自己对应的局部区域,并没有充分融合其他 patch 的信息。比如,一个表示猫耳朵的 patch,只知道自己附近的纹理;一个表示猫身体的 patch,也只知道自己对应的局部内容。对于图像分类来说,我们最终需要判断整张图的语义,而不是孤立地看某个 patch。

所以,ViT Encoder 要解决的问题是:

如何让这些 patch token 之间充分交换信息,并最终得到整张图像的表示?

ViT 的答案很直接:使用标准 Transformer Encoder。每一层 encoder block 由 self-attention、MLP、residual connection 和 LayerNorm 组成。Self-attention 负责让不同 patch 之间建立关系,MLP 负责对每个 token 的表示做非线性变换,残差连接和 LayerNorm 则让深层网络更容易训练。

这一节我们就从这个问题出发,逐步实现 ViT Encoder。

import dnnl.nn as dnn
import dnnl.models.vit as vit
import torch
import torch.nn as nn
from torch import Tensor

print('PyTorch version:', torch.__version__)
PyTorch version: 2.12.0+xpu

11.4.1 ViT Encoder 的输入输出

ViT Encoder 的输入是上一节得到的 token 序列:

\[ Z_0 \in \mathbb{R}^{B \times (N+1) \times D} \]

其中,\(B\) 是 batch size,\(N\) 是 patch token 数量,\(D\) 是 embedding dimension。这个序列包含了 \(N\) 个 patch token 和 1 个 class token。

Encoder 的输出仍然是一个同样长度的序列:

\[ Z_L \in \mathbb{R}^{B \times (N+1) \times D} \]

这里 \(L\) 表示 encoder block 的层数。

也就是说,ViT Encoder 并不会改变 token 数量,也不会改变每个 token 的维度:

\[ (B, N+1, D) \rightarrow (B, N+1, D) \]

它真正改变的是每个 token 的内容。输入时,每个 patch token 更偏向表示局部图像块;经过多层 self-attention 以后,每个 token 都可以融合其他 patch 的信息,变成上下文相关的视觉表示。

对于图像分类来说,我们最终通常只取 class token 的输出:

\[ h_\mathrm{cls} = Z_L[:, 0, :] \]

然后把它送入分类头:

\[ \mathrm{logits} = h_\mathrm{cls} W + b \]

所以,从整体上看,ViT Encoder 的作用可以理解为:

\[ \text{patch token sequence} \rightarrow \text{contextualized token sequence} \rightarrow \text{image representation} \]

其中,最后的 image representation 通常就是 class token 的输出。

11.4.2 为什么这里用 Encoder,而不是 Decoder?

在 Transformer 章节中,我们已经见过 encoder 和 decoder。那为什么 ViT 使用的是 Transformer Encoder,而不是 Transformer Decoder?

原因在于,图像分类不是一个自回归生成任务。对于语言生成,模型需要从左到右预测下一个 token,所以 decoder 里的 masked self-attention 必须限制当前位置只能看到过去的 token。否则,模型就会偷看到未来答案。

但图像分类不需要这种限制。一张图像里的所有 patch 都是同时给定的,模型可以让任意 patch 直接和任意其他 patch 交互。比如,左上角的 patch 可以关注右下角的 patch,中心的 patch 也可以关注边缘的 patch。这里没有未来 token 不能看的问题。

因此,ViT 使用的是 encoder-style 的双向 self-attention:每个 token 都可以关注所有 token。

这和 BERT 处理文本理解任务很相似。BERT 的输入是一整个句子,模型要理解句子中的每个 token;ViT 的输入是一整张图像的 patch 序列,模型要理解这些 patch 共同构成的视觉语义。

所以,ViT Encoder 里通常不需要 causal mask。所有 patch token 和 class token 都可以在 self-attention 中互相交互。

11.4.3 一个 ViT Encoder Block 的结构

一个标准的 ViT Encoder Block 可以写成:

\[ H = X + \operatorname{MultiheadSelfAttention}(\operatorname{LayerNorm}(X)) \]

\[ Y = H + \operatorname{MLP}(\operatorname{LayerNorm}(H)) \]

这个形式叫做 Pre-Norm Transformer,也就是先做 layer norm,再进入 attention 或 MLP。原始 ViT 使用的就是这种结构。这和原始 Transformer 的 Post-Norm 结构不同,后者是先进入 attention 或 MLP,再做 layer norm。

从模块顺序上看,一个 encoder block 的计算流程是:

图 1:ViT Encoder 模型结构

这和我们前面讲过的 Transformer Encoder 非常接近。只不过现在 token 不再是词,而是图像 patch。

直观来说,self-attention 解决的是 token 之间如何交换信息的问题,MLP 解决的是每个 token 自己如何进一步变换表示的问题。前者在序列维度上混合信息,后者在特征维度上加工信息。两者交替堆叠,就构成了 ViT Encoder 的主体。

11.4.3.1 Self-Attention:让 Patch 之间直接交互

假设输入序列是:

\[ X = [x_\mathrm{cls}, x_1, x_2, \dots, x_N] \]

其中,\(x_\mathrm{cls}\) 是 class token,\(x_i\) 是第 \(i\) 个 patch token。

在 self-attention 中,每个 token 都会生成自己的 query、key 和 value:

\[ Q = XW_Q, \quad K = XW_K, \quad V = XW_V \]

然后计算 token 两两之间的相关性:

\[ \operatorname{Attention}(Q, K, V) = \operatorname{softmax}\left(\frac{QK^\top}{\sqrt{d_k}}\right)V \]

对于 ViT 来说,attention 矩阵的形状是:

\[ (B, H, N+1, N+1) \]

其中,\(H\) 是 attention head 的数量。这个矩阵里的每一行表示某个 token 对其他所有 token 的关注程度。

这里最重要的是 class token。因为 class token 也参与 self-attention,所以它可以在每一层中从所有 patch token 读取信息:

\[ x_\mathrm{cls} \leftarrow \text{all patch tokens} \]

经过多层 encoder block 以后,class token 就逐渐变成一个全局图像表示。这就是为什么最后可以用它来做分类。

当然,patch token 之间也会互相交互。一个 patch 不再只是孤立地表示自己的局部区域,而是可以根据整张图像的上下文调整自己的表示。比如,一个局部纹理可能在不同上下文中含义不同:同样的黄色纹理,出现在动物身上可能是毛发,出现在背景里可能是沙地或戈壁。Self-attention 让模型可以利用其他 patch 来判断当前 patch 的语义。

11.4.3.2 MLP:对每个 Token 做特征变换

Self-attention 负责在 token 之间交换信息,但它并不是 encoder block 的全部。每个 encoder block 里还有一个 MLP,也叫 feed-forward network。

ViT 中的 MLP 通常由两个线性层和一个激活函数组成:

\[ \operatorname{MLP}(x) = W_2 \sigma(W_1 x + b_1) + b_2 \]

其中,\(\sigma\) 通常使用 GELU 激活函数。

默认情况下,第一个线性层会把维度从 \(D\) 扩展到一个更大的隐藏维度,常见设置是 \(4D\);第二个线性层再把维度投影回 \(D\)

\[ D \rightarrow 4D \rightarrow D \]

注意,MLP 是独立作用在每个 token 上面的。也就是说,它不会直接把不同 token 混在一起,而是对每个 token 的特征维度做非线性变换。

如果把输入看成形状为 \((B, N+1, D)\) 的张量,那么 MLP 的形状变化是:

\[ (B, N+1, D) \rightarrow (B, N+1, 4D) \rightarrow (B, N+1, D) \]

序列长度不变,embedding 维度最后也不变,只是 token 表示变得更有表达能力。

我们用 PyTorch 来实现这个 MLP:

class ViTMLP(nn.Module):
    def __init__(
        self,
        embed_dim: int,
        hidden_dim: int | None = None,
        dropout: float = 0.0,
    ):
        super().__init__()
        hidden_dim = hidden_dim or embed_dim * 4
        self.net = nn.Sequential(
            nn.Linear(embed_dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, embed_dim),
            nn.Dropout(dropout),
        )

    def forward(self, x: Tensor) -> Tensor:
        return self.net(x)

测试一下:

batch_size = 2
num_tokens = 16 + 1
embed_dim = 64

x = torch.randn(batch_size, num_tokens, embed_dim)
mlp = ViTMLP(embed_dim)
out = mlp(x)

print('Input shape:', x.shape)
print('Output shape:', out.shape)
Input shape: torch.Size([2, 17, 64])
Output shape: torch.Size([2, 17, 64])

可以看到,MLP 不改变最终输出形状。

11.4.4 实现 ViT Encoder Block

有了 MLP 以后,我们就可以实现一个完整的 ViT Encoder Block。

这里我们直接使用第 8 章实现的 MultiheadAttention。同时,我们采用 Pre-Norm 结构,也就是先做 layer norm,再进入 attention 或 MLP。

class ViTEncoderLayer(nn.Module):
    def __init__(
        self,
        embed_dim: int,
        num_heads: int = 12,
        hidden_dim: int | None = None,
        dropout: float = 0.0,
        attn_dropout: float = 0.0,
    ):
        super().__init__()
        hidden_dim = hidden_dim or embed_dim * 4

        self.norm1 = nn.LayerNorm(embed_dim)
        self.attn = dnn.MultiheadAttention(
            embed_dim=embed_dim,
            num_heads=num_heads,
            dropout=attn_dropout,
        )
        self.dropout1 = nn.Dropout(dropout)

        self.norm2 = nn.LayerNorm(embed_dim)
        self.mlp = ViTMLP(
            embed_dim=embed_dim,
            hidden_dim=hidden_dim,
            dropout=dropout,
        )

    def forward(self, x: Tensor) -> Tensor:
        attn_input = self.norm1(x)
        attn_output, _ = self.attn(
            attn_input,
            attn_input,
            attn_input,
            need_weights=False,
        )
        x = x + self.dropout1(attn_output)
        x = x + self.mlp(self.norm2(x))
        return x
1
这里的 MultiheadAttention 返回两个值:第一个是 attention 输出,第二个是 attention weights。一般情况下,我们只需要 attention 输出,所以设置 need_weights=False。如果想要可视化 attention,可以设置 need_weights=True

这段代码对应的正是前面的公式:

\[ H = X + \operatorname{MultiheadSelfAttention}(\operatorname{LayerNorm}(X)) \]

\[ Y = H + \operatorname{MLP}(\operatorname{LayerNorm}(H)) \]

测试一下形状:

batch_size = 2
num_tokens = 16 + 1
embed_dim = 64

block = ViTEncoderLayer(
    embed_dim=embed_dim,
    hidden_dim=embed_dim * 4,
    num_heads=8,
)

x = torch.randn(batch_size, num_tokens, embed_dim)
out = block(x)

print('Input shape:', x.shape)
print('Output shape:', out.shape)
Input shape: torch.Size([2, 17, 64])
Output shape: torch.Size([2, 17, 64])

可以看到,ViT Encoder Block 的输入和输出形状完全相同。这样我们就可以把多个 block 堆叠起来。

11.4.5 堆叠多个 Encoder Block

单个 encoder block 只能进行一次 self-attention 和一次 MLP 变换。为了让模型逐层构建更复杂的视觉表示,ViT 会堆叠多个 encoder block。

假设第 \(\ell\) 层的输入是 \(Z_{\ell-1}\),输出是 \(Z_\ell\),那么整个 encoder 可以写成:

\[ Z_\ell = \operatorname{EncoderBlock}_\ell(Z_{\ell-1}), \quad \ell = 1, 2, \dots, L \]

最后得到:

\[ Z_L = \operatorname{Encoder}(Z_0) \]

在实现上,我们可以用 nn.ModuleList 保存多个 block:

class ViTEncoder(nn.Module):
    def __init__(
        self,
        embed_dim: int,
        num_heads: int = 12,
        num_layers: int = 12,
        hidden_dim: int | None = None,
        dropout: float = 0.0,
        attn_dropout: float = 0.0,
    ):
        super().__init__()
        self.layers = nn.ModuleList(
            [
                ViTEncoderLayer(
                    embed_dim=embed_dim,
                    num_heads=num_heads,
                    hidden_dim=hidden_dim,
                    dropout=dropout,
                    attn_dropout=attn_dropout,
                )
                for _ in range(num_layers)
            ]
        )
        self.norm = nn.LayerNorm(embed_dim)

    def forward(self, x: Tensor) -> Tensor:
        for layer in self.layers:
            x = layer(x)
        x = self.norm(x)
        return x

最后这个 LayerNorm 是很多 ViT 实现中的常见做法。因为每个 block 内部使用的是 pre-norm,堆叠完所有 block 以后再做一次最终归一化,可以让输出表示更稳定。

测试一下:

num_tokens = 16 + 1
embed_dim = 64

encoder = ViTEncoder(
    embed_dim=embed_dim,
    num_heads=8,
    num_layers=4,
)

x = torch.randn(2, num_tokens, embed_dim)
out = encoder(x)

print('Input shape:', x.shape)
print('Encoder output shape:', out.shape)
Input shape: torch.Size([2, 17, 64])
Encoder output shape: torch.Size([2, 17, 64])

可以看到,即使堆叠多层 encoder,形状仍然保持为 (B, N+1, D)

11.4.6 从 Encoder 输出到分类结果

Encoder 输出的是完整的 token 序列:

\[ Z_L = [z_\mathrm{cls}, z_1, z_2, \dots, z_N] \]

对于分类任务,我们通常取第一个 token,也就是 class token 的输出:

\[ z_\mathrm{cls} = Z_L[:, 0, :] \]

然后接一个线性分类头:

\[ \mathrm{logits} = \operatorname{Linear}(z_\mathrm{cls}) \]

如果类别数是 \(K\),那么分类头的输出形状就是 (B, K)

下面用 PyTorch 实现一个最简单的分类头:

class ViTClassificationHead(nn.Module):
    def __init__(self, embed_dim: int, num_classes: int):
        super().__init__()
        self.head = nn.Linear(embed_dim, num_classes)

    def forward(self, x: Tensor) -> Tensor:
        cls_token = x[:, 0]
        logits = self.head(cls_token)
        return logits

测试一下:

num_classes = 10
num_tokens = 16 + 1
embed_dim = 64

head = ViTClassificationHead(embed_dim, num_classes)

x = torch.randn(2, num_tokens, embed_dim)
logits = head(x)

print('Encoder output shape:', x.shape)
print('Logits shape:', logits.shape)
Encoder output shape: torch.Size([2, 17, 64])
Logits shape: torch.Size([2, 10])

这样,我们就完成了从 encoder 输出到分类 logits 的转换。

当然,class token 不是唯一选择。也可以对所有 patch token 做平均池化:

\[ z_\mathrm{avg} = \frac{1}{N}\sum_{i=1}^{N} z_i \]

然后用平均后的向量做分类。很多后来的 ViT 变体也会使用这种方式。但在原始 ViT 中,最经典的做法仍然是使用 class token。

11.4.7 把输入模块、Encoder 和分类头连起来

现在,我们已经有了 ViT 的三个主要部分:

  1. 输入模块:把图像转换成带有 class token 和位置编码的 token 序列;
  2. ViT Encoder:用 self-attention 和 MLP 融合 token 信息;
  3. 分类头:取 class token 输出并映射到类别 logits。

我们用 PyTorch 实现一个完整的简化版 ViT:

class ViTForImageClassification(nn.Module):
    def __init__(
        self,
        image_size: int = 224,
        patch_size: int = 16,
        in_channels: int = 3,
        num_classes: int = 1000,
        embed_dim: int = 768,
        num_heads: int = 12,
        num_layers: int = 12,
        hidden_dim: int | None = None,
        dropout: float = 0.0,
        attn_dropout: float = 0.0,
    ):
        super().__init__()
        self.embedding = vit.ViTEmbedding(
            image_size=image_size,
            patch_size=patch_size,
            in_channels=in_channels,
            embed_dim=embed_dim,
            dropout=dropout,
        )
        self.encoder = ViTEncoder(
            embed_dim=embed_dim,
            num_heads=num_heads,
            num_layers=num_layers,
            hidden_dim=hidden_dim,
            dropout=dropout,
            attn_dropout=attn_dropout,
        )
        self.head = ViTClassificationHead(
            embed_dim=embed_dim,
            num_classes=num_classes,
        )

    def forward(self, x: Tensor) -> Tensor:
        x = self.embedding(x)
        x = self.encoder(x)
        logits = self.head(x)
        return logits

测试一下完整模型的输出形状:

model = ViTForImageClassification(
    image_size=32,
    patch_size=16,
    in_channels=3,
    num_classes=10,
    embed_dim=64,
    num_heads=8,
    num_layers=4,
)

x = torch.randn(2, 3, 32, 32)
logits = model(x)

print('Input shape:', x.shape)
print('Logits shape:', logits.shape)
Input shape: torch.Size([2, 3, 32, 32])
Logits shape: torch.Size([2, 10])

可以看到,ViT 的输入是 (B, C, H, W),输出是 (B, num_classes)。这就是一个最小的 ViT 分类模型。

11.4.8 本章小结

这一节我们完成了 ViT 的核心部分:ViT Encoder

ViT Encoder 的输入是上一节构造好的 token 序列:

\[ (B, N+1, D) \]

序列里包含 \(N\) 个 patch token 和 1 个 class token。Encoder 不改变序列长度和 embedding 维度,而是通过多层 Transformer Encoder Block 不断更新 token 表示:

\[ (B, N+1, D) \rightarrow (B, N+1, D) \]

每个 encoder block 主要包含两部分:

  1. Multi-Head Self-Attention:让 patch token 和 class token 之间交换信息;
  2. MLP:对每个 token 的特征维度做非线性变换。

同时,block 内部使用 layer norm 和 residual connection 来稳定训练:

\[ H = X + \operatorname{MultiheadSelfAttention}(\operatorname{LayerNorm}(X)) \]

\[ Y = H + \operatorname{MLP}(\operatorname{LayerNorm}(H)) \]

经过多层 encoder 以后,class token 的输出表示会融合整张图像的信息。最后,分类头取出 class token:

\[ h_\mathrm{cls} = Z_L[:, 0, :] \]

并将它映射成类别 logits:

\[ (B, D) \rightarrow (B, K) \]

至此,我们已经从输入一路走到了输出,这也构成了最基本的 Vision Transformer。

不过,ViT 的意义并不只在于图像分类。在很多实际应用中,ViT 更常被看作一个通用的视觉 backbone。它把图像转换成一组高层 token 表示,后面可以接分类头、检测头、分割头,甚至多模态模型中的视觉模块。

因此,ViT 的核心价值在于学习可迁移的视觉表示。为了让这些表示足够通用,ViT 通常会先在大规模数据集上预训练,再根据具体任务进行微调。下一节,我们就来讨论为什么 ViT 特别适合作为 backbone,以及预训练和微调如何让它迁移到更多视觉任务中。