import dnnl.nn as dnn
import dnnl.models.vit as vit
import torch
import torch.nn as nn
from torch import Tensor
print('PyTorch version:', torch.__version__)PyTorch version: 2.12.0+xpu
jshn9515
2026-05-13
2026-05-19
上一节里,我们已经完成了 ViT 的输入构造。给定一张图像,ViT 会先把它切成 patch,经过 patch embedding 得到视觉 token,再加入 class token 和 positional embedding:
\[ (B, C, H, W) \rightarrow (B, N, D) \rightarrow (B, N+1, D) \]
这里的 \(N\) 是 patch token 的数量,额外的 1 来自 class token。
到这一步,图像已经被翻译成了 Transformer 熟悉的序列形式。但是,问题还没有结束。
Patch embedding 只是把每个局部图像块变成了一个向量。此时每个 patch token 主要表示自己对应的局部区域,并没有充分融合其他 patch 的信息。比如,一个表示猫耳朵的 patch,只知道自己附近的纹理;一个表示猫身体的 patch,也只知道自己对应的局部内容。对于图像分类来说,我们最终需要判断整张图的语义,而不是孤立地看某个 patch。
所以,ViT Encoder 要解决的问题是:
如何让这些 patch token 之间充分交换信息,并最终得到整张图像的表示?
ViT 的答案很直接:使用标准 Transformer Encoder。每一层 encoder block 由 self-attention、MLP、residual connection 和 LayerNorm 组成。Self-attention 负责让不同 patch 之间建立关系,MLP 负责对每个 token 的表示做非线性变换,残差连接和 LayerNorm 则让深层网络更容易训练。
这一节我们就从这个问题出发,逐步实现 ViT Encoder。
PyTorch version: 2.12.0+xpu
ViT Encoder 的输入是上一节得到的 token 序列:
\[ Z_0 \in \mathbb{R}^{B \times (N+1) \times D} \]
其中,\(B\) 是 batch size,\(N\) 是 patch token 数量,\(D\) 是 embedding dimension。这个序列包含了 \(N\) 个 patch token 和 1 个 class token。
Encoder 的输出仍然是一个同样长度的序列:
\[ Z_L \in \mathbb{R}^{B \times (N+1) \times D} \]
这里 \(L\) 表示 encoder block 的层数。
也就是说,ViT Encoder 并不会改变 token 数量,也不会改变每个 token 的维度:
\[ (B, N+1, D) \rightarrow (B, N+1, D) \]
它真正改变的是每个 token 的内容。输入时,每个 patch token 更偏向表示局部图像块;经过多层 self-attention 以后,每个 token 都可以融合其他 patch 的信息,变成上下文相关的视觉表示。
对于图像分类来说,我们最终通常只取 class token 的输出:
\[ h_\mathrm{cls} = Z_L[:, 0, :] \]
然后把它送入分类头:
\[ \mathrm{logits} = h_\mathrm{cls} W + b \]
所以,从整体上看,ViT Encoder 的作用可以理解为:
\[ \text{patch token sequence} \rightarrow \text{contextualized token sequence} \rightarrow \text{image representation} \]
其中,最后的 image representation 通常就是 class token 的输出。
在 Transformer 章节中,我们已经见过 encoder 和 decoder。那为什么 ViT 使用的是 Transformer Encoder,而不是 Transformer Decoder?
原因在于,图像分类不是一个自回归生成任务。对于语言生成,模型需要从左到右预测下一个 token,所以 decoder 里的 masked self-attention 必须限制当前位置只能看到过去的 token。否则,模型就会偷看到未来答案。
但图像分类不需要这种限制。一张图像里的所有 patch 都是同时给定的,模型可以让任意 patch 直接和任意其他 patch 交互。比如,左上角的 patch 可以关注右下角的 patch,中心的 patch 也可以关注边缘的 patch。这里没有未来 token 不能看的问题。
因此,ViT 使用的是 encoder-style 的双向 self-attention:每个 token 都可以关注所有 token。
这和 BERT 处理文本理解任务很相似。BERT 的输入是一整个句子,模型要理解句子中的每个 token;ViT 的输入是一整张图像的 patch 序列,模型要理解这些 patch 共同构成的视觉语义。
所以,ViT Encoder 里通常不需要 causal mask。所有 patch token 和 class token 都可以在 self-attention 中互相交互。
一个标准的 ViT Encoder Block 可以写成:
\[ H = X + \operatorname{MultiheadSelfAttention}(\operatorname{LayerNorm}(X)) \]
\[ Y = H + \operatorname{MLP}(\operatorname{LayerNorm}(H)) \]
这个形式叫做 Pre-Norm Transformer,也就是先做 layer norm,再进入 attention 或 MLP。原始 ViT 使用的就是这种结构。这和原始 Transformer 的 Post-Norm 结构不同,后者是先进入 attention 或 MLP,再做 layer norm。
从模块顺序上看,一个 encoder block 的计算流程是:
这和我们前面讲过的 Transformer Encoder 非常接近。只不过现在 token 不再是词,而是图像 patch。
直观来说,self-attention 解决的是 token 之间如何交换信息的问题,MLP 解决的是每个 token 自己如何进一步变换表示的问题。前者在序列维度上混合信息,后者在特征维度上加工信息。两者交替堆叠,就构成了 ViT Encoder 的主体。
假设输入序列是:
\[ X = [x_\mathrm{cls}, x_1, x_2, \dots, x_N] \]
其中,\(x_\mathrm{cls}\) 是 class token,\(x_i\) 是第 \(i\) 个 patch token。
在 self-attention 中,每个 token 都会生成自己的 query、key 和 value:
\[ Q = XW_Q, \quad K = XW_K, \quad V = XW_V \]
然后计算 token 两两之间的相关性:
\[ \operatorname{Attention}(Q, K, V) = \operatorname{softmax}\left(\frac{QK^\top}{\sqrt{d_k}}\right)V \]
对于 ViT 来说,attention 矩阵的形状是:
\[ (B, H, N+1, N+1) \]
其中,\(H\) 是 attention head 的数量。这个矩阵里的每一行表示某个 token 对其他所有 token 的关注程度。
这里最重要的是 class token。因为 class token 也参与 self-attention,所以它可以在每一层中从所有 patch token 读取信息:
\[ x_\mathrm{cls} \leftarrow \text{all patch tokens} \]
经过多层 encoder block 以后,class token 就逐渐变成一个全局图像表示。这就是为什么最后可以用它来做分类。
当然,patch token 之间也会互相交互。一个 patch 不再只是孤立地表示自己的局部区域,而是可以根据整张图像的上下文调整自己的表示。比如,一个局部纹理可能在不同上下文中含义不同:同样的黄色纹理,出现在动物身上可能是毛发,出现在背景里可能是沙地或戈壁。Self-attention 让模型可以利用其他 patch 来判断当前 patch 的语义。
Self-attention 负责在 token 之间交换信息,但它并不是 encoder block 的全部。每个 encoder block 里还有一个 MLP,也叫 feed-forward network。
ViT 中的 MLP 通常由两个线性层和一个激活函数组成:
\[ \operatorname{MLP}(x) = W_2 \sigma(W_1 x + b_1) + b_2 \]
其中,\(\sigma\) 通常使用 GELU 激活函数。
默认情况下,第一个线性层会把维度从 \(D\) 扩展到一个更大的隐藏维度,常见设置是 \(4D\);第二个线性层再把维度投影回 \(D\):
\[ D \rightarrow 4D \rightarrow D \]
注意,MLP 是独立作用在每个 token 上面的。也就是说,它不会直接把不同 token 混在一起,而是对每个 token 的特征维度做非线性变换。
如果把输入看成形状为 \((B, N+1, D)\) 的张量,那么 MLP 的形状变化是:
\[ (B, N+1, D) \rightarrow (B, N+1, 4D) \rightarrow (B, N+1, D) \]
序列长度不变,embedding 维度最后也不变,只是 token 表示变得更有表达能力。
我们用 PyTorch 来实现这个 MLP:
class ViTMLP(nn.Module):
def __init__(
self,
embed_dim: int,
hidden_dim: int | None = None,
dropout: float = 0.0,
):
super().__init__()
hidden_dim = hidden_dim or embed_dim * 4
self.net = nn.Sequential(
nn.Linear(embed_dim, hidden_dim),
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(hidden_dim, embed_dim),
nn.Dropout(dropout),
)
def forward(self, x: Tensor) -> Tensor:
return self.net(x)测试一下:
Input shape: torch.Size([2, 17, 64])
Output shape: torch.Size([2, 17, 64])
可以看到,MLP 不改变最终输出形状。
有了 MLP 以后,我们就可以实现一个完整的 ViT Encoder Block。
这里我们直接使用第 8 章实现的 MultiheadAttention。同时,我们采用 Pre-Norm 结构,也就是先做 layer norm,再进入 attention 或 MLP。
class ViTEncoderLayer(nn.Module):
def __init__(
self,
embed_dim: int,
num_heads: int = 12,
hidden_dim: int | None = None,
dropout: float = 0.0,
attn_dropout: float = 0.0,
):
super().__init__()
hidden_dim = hidden_dim or embed_dim * 4
self.norm1 = nn.LayerNorm(embed_dim)
self.attn = dnn.MultiheadAttention(
embed_dim=embed_dim,
num_heads=num_heads,
dropout=attn_dropout,
)
self.dropout1 = nn.Dropout(dropout)
self.norm2 = nn.LayerNorm(embed_dim)
self.mlp = ViTMLP(
embed_dim=embed_dim,
hidden_dim=hidden_dim,
dropout=dropout,
)
def forward(self, x: Tensor) -> Tensor:
attn_input = self.norm1(x)
attn_output, _ = self.attn(
attn_input,
attn_input,
attn_input,
need_weights=False,
)
x = x + self.dropout1(attn_output)
x = x + self.mlp(self.norm2(x))
return xMultiheadAttention 返回两个值:第一个是 attention 输出,第二个是 attention weights。一般情况下,我们只需要 attention 输出,所以设置 need_weights=False。如果想要可视化 attention,可以设置 need_weights=True。
这段代码对应的正是前面的公式:
\[ H = X + \operatorname{MultiheadSelfAttention}(\operatorname{LayerNorm}(X)) \]
\[ Y = H + \operatorname{MLP}(\operatorname{LayerNorm}(H)) \]
测试一下形状:
Input shape: torch.Size([2, 17, 64])
Output shape: torch.Size([2, 17, 64])
可以看到,ViT Encoder Block 的输入和输出形状完全相同。这样我们就可以把多个 block 堆叠起来。
单个 encoder block 只能进行一次 self-attention 和一次 MLP 变换。为了让模型逐层构建更复杂的视觉表示,ViT 会堆叠多个 encoder block。
假设第 \(\ell\) 层的输入是 \(Z_{\ell-1}\),输出是 \(Z_\ell\),那么整个 encoder 可以写成:
\[ Z_\ell = \operatorname{EncoderBlock}_\ell(Z_{\ell-1}), \quad \ell = 1, 2, \dots, L \]
最后得到:
\[ Z_L = \operatorname{Encoder}(Z_0) \]
在实现上,我们可以用 nn.ModuleList 保存多个 block:
class ViTEncoder(nn.Module):
def __init__(
self,
embed_dim: int,
num_heads: int = 12,
num_layers: int = 12,
hidden_dim: int | None = None,
dropout: float = 0.0,
attn_dropout: float = 0.0,
):
super().__init__()
self.layers = nn.ModuleList(
[
ViTEncoderLayer(
embed_dim=embed_dim,
num_heads=num_heads,
hidden_dim=hidden_dim,
dropout=dropout,
attn_dropout=attn_dropout,
)
for _ in range(num_layers)
]
)
self.norm = nn.LayerNorm(embed_dim)
def forward(self, x: Tensor) -> Tensor:
for layer in self.layers:
x = layer(x)
x = self.norm(x)
return x最后这个 LayerNorm 是很多 ViT 实现中的常见做法。因为每个 block 内部使用的是 pre-norm,堆叠完所有 block 以后再做一次最终归一化,可以让输出表示更稳定。
测试一下:
Input shape: torch.Size([2, 17, 64])
Encoder output shape: torch.Size([2, 17, 64])
可以看到,即使堆叠多层 encoder,形状仍然保持为 (B, N+1, D)。
Encoder 输出的是完整的 token 序列:
\[ Z_L = [z_\mathrm{cls}, z_1, z_2, \dots, z_N] \]
对于分类任务,我们通常取第一个 token,也就是 class token 的输出:
\[ z_\mathrm{cls} = Z_L[:, 0, :] \]
然后接一个线性分类头:
\[ \mathrm{logits} = \operatorname{Linear}(z_\mathrm{cls}) \]
如果类别数是 \(K\),那么分类头的输出形状就是 (B, K)。
下面用 PyTorch 实现一个最简单的分类头:
测试一下:
Encoder output shape: torch.Size([2, 17, 64])
Logits shape: torch.Size([2, 10])
这样,我们就完成了从 encoder 输出到分类 logits 的转换。
当然,class token 不是唯一选择。也可以对所有 patch token 做平均池化:
\[ z_\mathrm{avg} = \frac{1}{N}\sum_{i=1}^{N} z_i \]
然后用平均后的向量做分类。很多后来的 ViT 变体也会使用这种方式。但在原始 ViT 中,最经典的做法仍然是使用 class token。
现在,我们已经有了 ViT 的三个主要部分:
我们用 PyTorch 实现一个完整的简化版 ViT:
class ViTForImageClassification(nn.Module):
def __init__(
self,
image_size: int = 224,
patch_size: int = 16,
in_channels: int = 3,
num_classes: int = 1000,
embed_dim: int = 768,
num_heads: int = 12,
num_layers: int = 12,
hidden_dim: int | None = None,
dropout: float = 0.0,
attn_dropout: float = 0.0,
):
super().__init__()
self.embedding = vit.ViTEmbedding(
image_size=image_size,
patch_size=patch_size,
in_channels=in_channels,
embed_dim=embed_dim,
dropout=dropout,
)
self.encoder = ViTEncoder(
embed_dim=embed_dim,
num_heads=num_heads,
num_layers=num_layers,
hidden_dim=hidden_dim,
dropout=dropout,
attn_dropout=attn_dropout,
)
self.head = ViTClassificationHead(
embed_dim=embed_dim,
num_classes=num_classes,
)
def forward(self, x: Tensor) -> Tensor:
x = self.embedding(x)
x = self.encoder(x)
logits = self.head(x)
return logits测试一下完整模型的输出形状:
Input shape: torch.Size([2, 3, 32, 32])
Logits shape: torch.Size([2, 10])
可以看到,ViT 的输入是 (B, C, H, W),输出是 (B, num_classes)。这就是一个最小的 ViT 分类模型。
这一节我们完成了 ViT 的核心部分:ViT Encoder。
ViT Encoder 的输入是上一节构造好的 token 序列:
\[ (B, N+1, D) \]
序列里包含 \(N\) 个 patch token 和 1 个 class token。Encoder 不改变序列长度和 embedding 维度,而是通过多层 Transformer Encoder Block 不断更新 token 表示:
\[ (B, N+1, D) \rightarrow (B, N+1, D) \]
每个 encoder block 主要包含两部分:
同时,block 内部使用 layer norm 和 residual connection 来稳定训练:
\[ H = X + \operatorname{MultiheadSelfAttention}(\operatorname{LayerNorm}(X)) \]
\[ Y = H + \operatorname{MLP}(\operatorname{LayerNorm}(H)) \]
经过多层 encoder 以后,class token 的输出表示会融合整张图像的信息。最后,分类头取出 class token:
\[ h_\mathrm{cls} = Z_L[:, 0, :] \]
并将它映射成类别 logits:
\[ (B, D) \rightarrow (B, K) \]
至此,我们已经从输入一路走到了输出,这也构成了最基本的 Vision Transformer。
不过,ViT 的意义并不只在于图像分类。在很多实际应用中,ViT 更常被看作一个通用的视觉 backbone。它把图像转换成一组高层 token 表示,后面可以接分类头、检测头、分割头,甚至多模态模型中的视觉模块。
因此,ViT 的核心价值在于学习可迁移的视觉表示。为了让这些表示足够通用,ViT 通常会先在大规模数据集上预训练,再根据具体任务进行微调。下一节,我们就来讨论为什么 ViT 特别适合作为 backbone,以及预训练和微调如何让它迁移到更多视觉任务中。