import dnnl.models.vit as vit
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
print('PyTorch version:', torch.__version__)PyTorch version: 2.12.0+xpu
jshn9515
2026-05-13
2026-05-13
上一节里,我们已经把图像从原始张量
\[ (B, C, H, W) \]
转换成了 patch token 序列
\[ (B, N, D) \]
其中,\(N\) 是 patch 数量,\(D\) 是每个 token 的 embedding 维度。
到这一步,图像已经可以被看成一串视觉 token,并且可以送入 Transformer Encoder。
但是,这还不是 ViT 的完整输入。我们还有两个问题没有解决。
第一个问题是:如果要做图像分类,最后应该用哪个 token 表示整张图?Patch token 本身表示的是局部图像块,而分类任务需要的是整张图像的语义表示。
第二个问题是:Transformer 本身并不知道每个 patch 来自图像的哪个位置。Patch embedding 把图像切成序列以后,如果不额外加入位置信息,模型很难区分一个 patch 是来自左上角,还是来自右下角。
所以,在把 patch token 送入 Transformer Encoder 之前,ViT 还会做两件事:
这一节我们就来讨论这两个设计。
PyTorch version: 2.12.0+xpu
经过 patch embedding 以后,一张图像会变成一个 patch 序列:
\[ X = [x_1, x_2, \dots, x_N] \]
其中,每个 \(x_i \in \mathbb{R}^D\) 表示一个图像 patch。
如果后面接的是 Transformer Encoder,那么输出仍然是一个序列:
\[ H = [h_1, h_2, \dots, h_N] \]
这里每个 \(h_i\) 都是对应 patch 的上下文表示。它不再只包含第 \(i\) 个 patch 自己的信息,而是通过 self-attention 融合了其他 patch 的信息。
但对于图像分类来说,我们最后需要的是一个整张图像的向量表示,而不是 \(N\) 个 patch 向量。也就是说,我们需要从 (B, N, D) 的输出中得到一个 (B, D) 的图像表示。
一个很自然的做法是对所有 patch token 做平均池化:
\[ h_\mathrm{avg} = \frac{1}{N}\sum_{i=1}^N h_i \]
这当然是可行的。事实上,很多后来的视觉 Transformer 变体也会使用平均池化或者类似的聚合方式。
但原始 ViT 采用的是另一种设计:在 patch token 前面加入一个特殊的可学习 token,叫做 class token,也常写作 [CLS]。
Class token 的想法来自 BERT 这类 Transformer 模型。在文本分类中,模型会在输入序列前面加入一个特殊 token,例如 [CLS]。经过 Transformer 编码以后,这个 token 的输出表示会被用作整个序列的表示。ViT 借用了这个思路。
假设 patch token 序列是:
\[ [x_1, x_2, \dots, x_N] \]
ViT 会在它前面加一个可学习向量 \(x_\mathrm{cls}\):
\[ X = [x_\mathrm{cls}, x_1, x_2, \dots, x_N] \]
这样,序列长度就从 \(N\) 变成了 \(N + 1\)。
这个 class token 一开始只是一个普通的可学习参数。它本身不对应图像中的某个 patch,也不来自输入图像。它的作用是在 Transformer Encoder 中和所有 patch token 进行 self-attention 交互。
经过多层 Transformer Encoder 以后,class token 的输出表示可以看作整张图像的信息汇聚结果:
\[ h_\mathrm{cls} = H[:, 0, :] \]
然后把这个向量送入分类头,就可以得到图像分类 logits:
\[ \mathrm{logits} = h_\mathrm{cls} W + b \]
直观来说,class token 就像是一个全局位置。它不代表某个局部图像块,而是专门用来从所有 patch token 中收集信息,最后形成整张图像的表示。
在实现上,class token 通常是一个形状为 (1, 1, embed_dim) 的可学习参数:
class ViTAddClassToken(nn.Module):
def __init__(self, embed_dim: int):
super().__init__()
cls_token = torch.zeros(1, 1, embed_dim)
self.cls_token = nn.Parameter(cls_token)
nn.init.trunc_normal_(self.cls_token, std=0.02)
def forward(self, x: Tensor) -> Tensor:
cls_token = self.cls_token.expand(x.size(0), -1, -1)
x = torch.concat([cls_token, x], dim=1)
return x这里有两个细节值得注意。
第一个细节是 cls_token 的形状是 (1, 1, D),而不是 (B, 1, D)。因为它是模型参数,不应该为每个 batch 单独存一份。真正前向传播时,我们再用 expand 把它扩展到当前 batch size:
第二个细节是 expand 不会真正复制参数数据,它只是创建一个广播后的视图。这样所有样本使用的是同一个可学习 class token,但在计算图中会根据 batch 中所有样本的梯度一起更新它。
我们测试一下形状变化:
Input shape: torch.Size([2, 16, 64])
Output shape: torch.Size([2, 17, 64])
输入形状是 (B, N, D),输出形状变成 (B, N + 1, D)。多出来的第一个 token 就是 class token。
加入 class token 以后,序列已经可以表示整张图了。但我们还有另外一个问题:patch token 的位置信息在哪里?
Patch embedding 得到的是一串向量:
\[ [x_1, x_2, \dots, x_N] \]
虽然我们在展开 patch 时通常会按照从左到右、从上到下的顺序排列,但 self-attention 本身并不会天然理解这个顺序的含义。
这一点和文本 Transformer 中的问题是一样的。Self-attention 在计算时主要看 token 向量之间的相似度。如果没有额外的位置信息,它并不知道某个 token 是第一个位置还是第十个位置。
对于图像来说,位置甚至更加重要。一个 patch 在左上角和在右下角,语义可能完全不同。比如天空通常在图像上方,草地通常在图像下方;人脸中的眼睛、鼻子、嘴也有相对稳定的空间关系。如果模型完全不知道 patch 的位置,就很难利用这些空间结构。
因此,ViT 会给每个 token 加上一个可学习的位置向量:
\[ Z = X + P \]
其中,\(X \in \mathbb{R}^{B \times (N+1) \times D}\) 是加入 class token 之后的序列,\(P \in \mathbb{R}^{1 \times (N+1) \times D}\) 是 position embedding。
注意这里的位置数量是 \(N + 1\),因为除了 \(N\) 个 patch token 之外,class token 也需要一个自己的位置 embedding。
原始 ViT 使用的是可学习的绝对位置嵌入。也就是说,模型会直接学习一个参数矩阵:
\[ P = [p_\mathrm{cls}, p_1, p_2, \dots, p_N] \]
其中 \(p_\mathrm{cls}\) 是 class token 的位置嵌入,\(p_i\) 是第 \(i\) 个 patch 位置的位置嵌入。
然后把它加到 token embedding 上:
\[ Z = X + P \]
实现起来非常简单:
class ViTPositionalEmbedding(nn.Module):
def __init__(
self,
embed_dim: int,
num_patches: int,
use_cls_token: bool = True,
):
super().__init__()
self.use_cls_token = use_cls_token
num_tokens = num_patches + int(use_cls_token)
pos_embed = torch.zeros(1, num_tokens, embed_dim)
self.pos_embed = nn.Parameter(pos_embed)
nn.init.trunc_normal_(self.pos_embed, std=0.02)
def forward(self, x: Tensor) -> Tensor:
if x.size(1) != self.pos_embed.size(1):
raise AssertionError(
f'Expected sequence length {self.pos_embed.size(1)}, '
f'but got {x.size(1)}.'
)
return x + self.pos_embed测试一下:
num_patches = 16
embed_dim = 64
x = torch.randn(2, num_patches, embed_dim)
x = ViTAddClassToken(embed_dim)(x)
pos_embed = ViTPositionalEmbedding(
num_patches=num_patches,
embed_dim=embed_dim,
use_cls_token=True,
)
out = pos_embed(x)
print('After class token:', x.shape)
print('After positional embedding:', out.shape)After class token: torch.Size([2, 17, 64])
After positional embedding: torch.Size([2, 17, 64])
可以看到,加入 positional embedding 不会改变张量形状。它只是给每个 token 的表示加上一个和位置有关的偏移量。
也就是说:
\[ (B, N+1, D) \rightarrow (B, N+1, D) \]
维度没有变化,但 token 表示中已经包含了位置信息。
在 ViT 里,常见的输入构造顺序是:
\[ \text{image} \rightarrow \text{patch embedding} \rightarrow \text{append class token} \rightarrow \text{add positional embedding} \]
也就是:
\[ (B, C, H, W) \rightarrow (B, N, D) \rightarrow (B, N+1, D) \rightarrow (B, N+1, D) \]
为什么通常是先加 class token,再加 positional embedding?
因为 positional embedding 需要和最终输入序列中的每个 token 一一对应。如果先给 patch token 加位置,再拼接 class token,那么 class token 还需要单独处理自己的位置嵌入。先把 class token 拼进去,再统一加一个长度为 \(N+1\) 的 position embedding,会更直接。
下面我们把上一节的 patch embedding 和这一节讲的 class token、positional embedding 合在一起,得到一个完整的 ViT 输入模块。
class ViTEmbedding(nn.Module):
def __init__(
self,
image_size: int = 224,
patch_size: int = 16,
in_channels: int = 3,
embed_dim: int = 768,
dropout: float = 0.0,
):
super().__init__()
self.patch_embed = vit.ViTConvPatchEmbedding(
image_size=image_size,
patch_size=patch_size,
in_channels=in_channels,
embed_dim=embed_dim,
)
self.add_cls_token = ViTAddClassToken(embed_dim)
self.pos_embed = ViTPositionalEmbedding(
num_patches=self.patch_embed.num_patches,
embed_dim=embed_dim,
use_cls_token=True,
)
self.dropout = nn.Dropout(dropout)
def forward(self, x: Tensor) -> Tensor:
x = self.patch_embed(x)
x = self.add_cls_token(x)
x = self.pos_embed(x)
x = self.dropout(x)
return xDropout 是可选的。原始 ViT 在输入模块后面没有加 dropout,但很多后来的变体会在这里加一个 dropout 层,用来增加输入的随机性,提升模型的泛化能力。
测试一下形状:
ViT embedding output shape: torch.Size([2, 17, 64])
Number of patch tokens: 16
这里图像大小是 \(32 \times 32\),patch size 是 \(8 \times 8\),所以每张图像有 16 个 patch token。加上 class token 后,序列长度变成 17。因此输出形状是 (B, 17, D)。
到这里,可能大家还有一个容易疑惑的地方:
图像原本是二维网格,但 ViT 的 positional embedding 看起来是一维的。这样会不会丢失二维空间结构?
答案是不会。虽然 ViT 使用的是一维 position embedding,但这并不意味着它完全丢掉了图像的二维位置信息。因为 patch 序列是按照固定顺序从二维网格展平得到的。
比如一个 \(14 \times 14\) 的 patch 网格,会按照行优先顺序展平成 196 个 patch token。这样一来,序列中的第 \(i\) 个位置就总是对应原图中的某一个固定 patch 坐标。因此,position embedding 的形状虽然是 \((1, N+1, D)\),看起来只是在给一维序列加位置向量,但每个位置向量背后其实都对应着一个二维 patch 位置。
不过,这种可学习绝对位置嵌入也有一个限制:它和训练时的 patch 网格大小绑定得比较紧。比如模型训练时使用 \(224 \times 224\) 图像和 \(16 \times 16\) patch,那么 patch 网格是 \(14 \times 14\),位置嵌入长度是 197;如果推理时换成更高分辨率的图像,例如 \(384 \times 384\),patch 网格就变成 \(24 \times 24\),位置嵌入长度就变成 577。这时原来的 positional embedding 长度就对不上了。
那该怎么办呢?
一个常见做法是对 patch 部分的 positional embedding 做二维插值。Class token 的位置嵌入单独保留,patch positional embedding 先还原成二维网格,再插值到新的网格大小,最后重新展平成序列。
具体做法是:
(1, old_h, old_w, D)。(1, new_h, new_w, D)。(1, new_h * new_w, D)。下面是一个简化版实现:
def interpolate_pos_embedding(
pos_embed: Tensor,
old_grid_size: tuple[int, int],
new_grid_size: tuple[int, int],
) -> Tensor:
cls_pos_embed = pos_embed[:, :1]
patch_pos_embed = pos_embed[:, 1:]
old_h, old_w = old_grid_size
new_h, new_w = new_grid_size
embed_dim = pos_embed.size(-1)
patch_pos_embed = patch_pos_embed.reshape(1, old_h, old_w, embed_dim)
patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
patch_pos_embed = F.interpolate(
patch_pos_embed,
size=(new_h, new_w),
mode='bicubic',
align_corners=False,
)
patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1)
patch_pos_embed = patch_pos_embed.reshape(1, new_h * new_w, embed_dim)
new_embed = torch.concat([cls_pos_embed, patch_pos_embed], dim=1)
return new_embedalign_corners 控制插值时新旧网格的对齐方式。align_corners=True 把像素看成离散的点,并强行对齐新旧网格四个角上的点,因此角落位置的值会被保留;align_corners=False 则把像素看成一个个小方格,对齐的是整张图像的外围边界,而不是角落像素的中心。这样所有位置会按整体比例重新分布,更接近普通图像缩放的默认行为。
测试一下长度变化:
# First dimension must be 1 because it's a parameter,
# not a batch of embeddings.
pos_embed = torch.randn(1, 14 * 14 + 1, 64)
new_pos_embed = interpolate_pos_embedding(
pos_embed,
old_grid_size=(14, 14),
new_grid_size=(24, 24),
)
print('Old positional embedding shape:', pos_embed.shape)
print('New positional embedding shape:', new_pos_embed.shape)Old positional embedding shape: torch.Size([1, 197, 64])
New positional embedding shape: torch.Size([1, 577, 64])
可以看到,原来的 positional embedding 长度是 197(196 个 patch + 1 个 class token),新的 positional embedding 长度是 577(576 个 patch + 1 个 class token)。这样就可以适应不同分辨率的输入了。
这一节我们补上了 ViT 输入模块中的两个关键设计:class token 和 position embedding。
Patch embedding 只能把图像变成 patch token 序列:
\[ (B, C, H, W) \rightarrow (B, N, D) \]
但图像分类需要一个整张图像的表示,所以 ViT 会在序列前面加入一个可学习的 class token:
\[ (B, N, D) \rightarrow (B, N + 1, D) \]
经过 Transformer Encoder 后,class token 的输出表示会被用作整张图像的表示,并送入分类头。
同时,由于 self-attention 本身不包含位置信息,ViT 还会给每个 token 加上可学习的位置嵌入:
\[ Z = X + P \]
这样,模型不仅知道每个 token 表示哪个 patch 的内容,也能通过 positional embedding 感知这个 patch 在图像中的位置。
所以,完整的 ViT 输入模块的构建流程就是:
\[ \text{image} \rightarrow \text{patch embedding} \rightarrow \text{class token} \rightarrow \text{positional embedding} \]
到这里,我们就完成了 ViT 输入模块的构建。下一节我们将进入 ViT Encoder 的实现,利用 self-attention,把 patch token 和 class token 的信息融合在一起,得到整张图像的表示。