11.3 Class Token 与 Positional Embedding:让序列表示整张图

Author

jshn9515

Published

2026-05-13

Modified

2026-05-13

上一节里,我们已经把图像从原始张量

\[ (B, C, H, W) \]

转换成了 patch token 序列

\[ (B, N, D) \]

其中,\(N\) 是 patch 数量,\(D\) 是每个 token 的 embedding 维度。

到这一步,图像已经可以被看成一串视觉 token,并且可以送入 Transformer Encoder。

但是,这还不是 ViT 的完整输入。我们还有两个问题没有解决。

第一个问题是:如果要做图像分类,最后应该用哪个 token 表示整张图?Patch token 本身表示的是局部图像块,而分类任务需要的是整张图像的语义表示。

第二个问题是:Transformer 本身并不知道每个 patch 来自图像的哪个位置。Patch embedding 把图像切成序列以后,如果不额外加入位置信息,模型很难区分一个 patch 是来自左上角,还是来自右下角。

所以,在把 patch token 送入 Transformer Encoder 之前,ViT 还会做两件事:

  1. 在序列最前面加入一个 class token,用来汇聚整张图的信息;
  2. 给每个 token 加上 positional embedding,用来告诉模型 token 的位置信息。

这一节我们就来讨论这两个设计。

import dnnl.models.vit as vit
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor

print('PyTorch version:', torch.__version__)
PyTorch version: 2.12.0+xpu

11.3.1 Patch Token 还不能直接表示整张图

经过 patch embedding 以后,一张图像会变成一个 patch 序列:

\[ X = [x_1, x_2, \dots, x_N] \]

其中,每个 \(x_i \in \mathbb{R}^D\) 表示一个图像 patch。

如果后面接的是 Transformer Encoder,那么输出仍然是一个序列:

\[ H = [h_1, h_2, \dots, h_N] \]

这里每个 \(h_i\) 都是对应 patch 的上下文表示。它不再只包含第 \(i\) 个 patch 自己的信息,而是通过 self-attention 融合了其他 patch 的信息。

但对于图像分类来说,我们最后需要的是一个整张图像的向量表示,而不是 \(N\) 个 patch 向量。也就是说,我们需要从 (B, N, D) 的输出中得到一个 (B, D) 的图像表示。

一个很自然的做法是对所有 patch token 做平均池化:

\[ h_\mathrm{avg} = \frac{1}{N}\sum_{i=1}^N h_i \]

这当然是可行的。事实上,很多后来的视觉 Transformer 变体也会使用平均池化或者类似的聚合方式。

但原始 ViT 采用的是另一种设计:在 patch token 前面加入一个特殊的可学习 token,叫做 class token,也常写作 [CLS]

11.3.2 Class Token:让一个 Token 代表整张图

Class token 的想法来自 BERT 这类 Transformer 模型。在文本分类中,模型会在输入序列前面加入一个特殊 token,例如 [CLS]。经过 Transformer 编码以后,这个 token 的输出表示会被用作整个序列的表示。ViT 借用了这个思路。

假设 patch token 序列是:

\[ [x_1, x_2, \dots, x_N] \]

ViT 会在它前面加一个可学习向量 \(x_\mathrm{cls}\)

\[ X = [x_\mathrm{cls}, x_1, x_2, \dots, x_N] \]

这样,序列长度就从 \(N\) 变成了 \(N + 1\)

这个 class token 一开始只是一个普通的可学习参数。它本身不对应图像中的某个 patch,也不来自输入图像。它的作用是在 Transformer Encoder 中和所有 patch token 进行 self-attention 交互。

经过多层 Transformer Encoder 以后,class token 的输出表示可以看作整张图像的信息汇聚结果:

\[ h_\mathrm{cls} = H[:, 0, :] \]

然后把这个向量送入分类头,就可以得到图像分类 logits:

\[ \mathrm{logits} = h_\mathrm{cls} W + b \]

直观来说,class token 就像是一个全局位置。它不代表某个局部图像块,而是专门用来从所有 patch token 中收集信息,最后形成整张图像的表示。

11.3.3 Class Token 的 PyTorch 实现

在实现上,class token 通常是一个形状为 (1, 1, embed_dim) 的可学习参数:

class ViTAddClassToken(nn.Module):
    def __init__(self, embed_dim: int):
        super().__init__()
        cls_token = torch.zeros(1, 1, embed_dim)
        self.cls_token = nn.Parameter(cls_token)
        nn.init.trunc_normal_(self.cls_token, std=0.02)

    def forward(self, x: Tensor) -> Tensor:
        cls_token = self.cls_token.expand(x.size(0), -1, -1)
        x = torch.concat([cls_token, x], dim=1)
        return x

这里有两个细节值得注意。

第一个细节是 cls_token 的形状是 (1, 1, D),而不是 (B, 1, D)。因为它是模型参数,不应该为每个 batch 单独存一份。真正前向传播时,我们再用 expand 把它扩展到当前 batch size:

cls_token = self.cls_token.expand(batch_size, -1, -1)

第二个细节是 expand 不会真正复制参数数据,它只是创建一个广播后的视图。这样所有样本使用的是同一个可学习 class token,但在计算图中会根据 batch 中所有样本的梯度一起更新它。

我们测试一下形状变化:

num_patches = 16
embed_dim = 64

x = torch.randn(2, num_patches, embed_dim)
cls_layer = ViTAddClassToken(embed_dim)
out = cls_layer(x)

print('Input shape:', x.shape)
print('Output shape:', out.shape)
Input shape: torch.Size([2, 16, 64])
Output shape: torch.Size([2, 17, 64])

输入形状是 (B, N, D),输出形状变成 (B, N + 1, D)。多出来的第一个 token 就是 class token。

11.3.4 Positional Embedding:让模型知道 Token 的位置

加入 class token 以后,序列已经可以表示整张图了。但我们还有另外一个问题:patch token 的位置信息在哪里?

Patch embedding 得到的是一串向量:

\[ [x_1, x_2, \dots, x_N] \]

虽然我们在展开 patch 时通常会按照从左到右、从上到下的顺序排列,但 self-attention 本身并不会天然理解这个顺序的含义。

这一点和文本 Transformer 中的问题是一样的。Self-attention 在计算时主要看 token 向量之间的相似度。如果没有额外的位置信息,它并不知道某个 token 是第一个位置还是第十个位置。

对于图像来说,位置甚至更加重要。一个 patch 在左上角和在右下角,语义可能完全不同。比如天空通常在图像上方,草地通常在图像下方;人脸中的眼睛、鼻子、嘴也有相对稳定的空间关系。如果模型完全不知道 patch 的位置,就很难利用这些空间结构。

因此,ViT 会给每个 token 加上一个可学习的位置向量:

\[ Z = X + P \]

其中,\(X \in \mathbb{R}^{B \times (N+1) \times D}\) 是加入 class token 之后的序列,\(P \in \mathbb{R}^{1 \times (N+1) \times D}\) 是 position embedding。

注意这里的位置数量是 \(N + 1\),因为除了 \(N\) 个 patch token 之外,class token 也需要一个自己的位置 embedding。

11.3.5 Positional Embedding 的 PyTorch 实现

原始 ViT 使用的是可学习的绝对位置嵌入。也就是说,模型会直接学习一个参数矩阵:

\[ P = [p_\mathrm{cls}, p_1, p_2, \dots, p_N] \]

其中 \(p_\mathrm{cls}\) 是 class token 的位置嵌入,\(p_i\) 是第 \(i\) 个 patch 位置的位置嵌入。

然后把它加到 token embedding 上:

\[ Z = X + P \]

实现起来非常简单:

class ViTPositionalEmbedding(nn.Module):
    def __init__(
        self,
        embed_dim: int,
        num_patches: int,
        use_cls_token: bool = True,
    ):
        super().__init__()
        self.use_cls_token = use_cls_token

        num_tokens = num_patches + int(use_cls_token)
        pos_embed = torch.zeros(1, num_tokens, embed_dim)
        self.pos_embed = nn.Parameter(pos_embed)
        nn.init.trunc_normal_(self.pos_embed, std=0.02)

    def forward(self, x: Tensor) -> Tensor:
        if x.size(1) != self.pos_embed.size(1):
            raise AssertionError(
                f'Expected sequence length {self.pos_embed.size(1)}, '
                f'but got {x.size(1)}.'
            )
        return x + self.pos_embed

测试一下:

num_patches = 16
embed_dim = 64

x = torch.randn(2, num_patches, embed_dim)
x = ViTAddClassToken(embed_dim)(x)
pos_embed = ViTPositionalEmbedding(
    num_patches=num_patches,
    embed_dim=embed_dim,
    use_cls_token=True,
)
out = pos_embed(x)

print('After class token:', x.shape)
print('After positional embedding:', out.shape)
After class token: torch.Size([2, 17, 64])
After positional embedding: torch.Size([2, 17, 64])

可以看到,加入 positional embedding 不会改变张量形状。它只是给每个 token 的表示加上一个和位置有关的偏移量。

也就是说:

\[ (B, N+1, D) \rightarrow (B, N+1, D) \]

维度没有变化,但 token 表示中已经包含了位置信息。

11.3.6 Class Token 和 Positional Embedding 的顺序

在 ViT 里,常见的输入构造顺序是:

\[ \text{image} \rightarrow \text{patch embedding} \rightarrow \text{append class token} \rightarrow \text{add positional embedding} \]

也就是:

\[ (B, C, H, W) \rightarrow (B, N, D) \rightarrow (B, N+1, D) \rightarrow (B, N+1, D) \]

为什么通常是先加 class token,再加 positional embedding?

因为 positional embedding 需要和最终输入序列中的每个 token 一一对应。如果先给 patch token 加位置,再拼接 class token,那么 class token 还需要单独处理自己的位置嵌入。先把 class token 拼进去,再统一加一个长度为 \(N+1\) 的 position embedding,会更直接。

下面我们把上一节的 patch embedding 和这一节讲的 class token、positional embedding 合在一起,得到一个完整的 ViT 输入模块。

class ViTEmbedding(nn.Module):
    def __init__(
        self,
        image_size: int = 224,
        patch_size: int = 16,
        in_channels: int = 3,
        embed_dim: int = 768,
        dropout: float = 0.0,
    ):
        super().__init__()
        self.patch_embed = vit.ViTConvPatchEmbedding(
            image_size=image_size,
            patch_size=patch_size,
            in_channels=in_channels,
            embed_dim=embed_dim,
        )
        self.add_cls_token = ViTAddClassToken(embed_dim)
        self.pos_embed = ViTPositionalEmbedding(
            num_patches=self.patch_embed.num_patches,
            embed_dim=embed_dim,
            use_cls_token=True,
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: Tensor) -> Tensor:
        x = self.patch_embed(x)
        x = self.add_cls_token(x)
        x = self.pos_embed(x)
        x = self.dropout(x)
        return x
1
这里的 Dropout 是可选的。原始 ViT 在输入模块后面没有加 dropout,但很多后来的变体会在这里加一个 dropout 层,用来增加输入的随机性,提升模型的泛化能力。

测试一下形状:

vit_embed = ViTEmbedding(
    image_size=32,
    patch_size=8,
    in_channels=3,
    embed_dim=64,
)

x = torch.randn(2, 3, 32, 32)
out = vit_embed(x)

print('ViT embedding output shape:', out.shape)
print('Number of patch tokens:', vit_embed.patch_embed.num_patches)
ViT embedding output shape: torch.Size([2, 17, 64])
Number of patch tokens: 16

这里图像大小是 \(32 \times 32\),patch size 是 \(8 \times 8\),所以每张图像有 16 个 patch token。加上 class token 后,序列长度变成 17。因此输出形状是 (B, 17, D)

11.3.7 Positional Embedding 和二维空间结构

到这里,可能大家还有一个容易疑惑的地方:

图像原本是二维网格,但 ViT 的 positional embedding 看起来是一维的。这样会不会丢失二维空间结构?

答案是不会。虽然 ViT 使用的是一维 position embedding,但这并不意味着它完全丢掉了图像的二维位置信息。因为 patch 序列是按照固定顺序从二维网格展平得到的。

比如一个 \(14 \times 14\) 的 patch 网格,会按照行优先顺序展平成 196 个 patch token。这样一来,序列中的第 \(i\) 个位置就总是对应原图中的某一个固定 patch 坐标。因此,position embedding 的形状虽然是 \((1, N+1, D)\),看起来只是在给一维序列加位置向量,但每个位置向量背后其实都对应着一个二维 patch 位置。

不过,这种可学习绝对位置嵌入也有一个限制:它和训练时的 patch 网格大小绑定得比较紧。比如模型训练时使用 \(224 \times 224\) 图像和 \(16 \times 16\) patch,那么 patch 网格是 \(14 \times 14\),位置嵌入长度是 197;如果推理时换成更高分辨率的图像,例如 \(384 \times 384\),patch 网格就变成 \(24 \times 24\),位置嵌入长度就变成 577。这时原来的 positional embedding 长度就对不上了。

那该怎么办呢?

一个常见做法是对 patch 部分的 positional embedding 做二维插值。Class token 的位置嵌入单独保留,patch positional embedding 先还原成二维网格,再插值到新的网格大小,最后重新展平成序列。

具体做法是:

  1. 从 positional embedding 中分离出 class token 的位置嵌入和 patch token 的位置嵌入。
  2. 把 patch token 的位置嵌入 reshape 成二维网格的形状 (1, old_h, old_w, D)
  3. 使用双线性插值把位置嵌入从旧的网格大小插值到新的网格大小 (1, new_h, new_w, D)
  4. 把插值后的 patch 位置嵌入 reshape 回序列形状 (1, new_h * new_w, D)
  5. 把 class token 的位置嵌入和新的 patch 位置嵌入拼接起来,得到新的 positional embedding。

下面是一个简化版实现:

def interpolate_pos_embedding(
    pos_embed: Tensor,
    old_grid_size: tuple[int, int],
    new_grid_size: tuple[int, int],
) -> Tensor:
    cls_pos_embed = pos_embed[:, :1]
    patch_pos_embed = pos_embed[:, 1:]

    old_h, old_w = old_grid_size
    new_h, new_w = new_grid_size
    embed_dim = pos_embed.size(-1)

    patch_pos_embed = patch_pos_embed.reshape(1, old_h, old_w, embed_dim)
    patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
    patch_pos_embed = F.interpolate(
        patch_pos_embed,
        size=(new_h, new_w),
        mode='bicubic',
        align_corners=False,
    )
    patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1)
    patch_pos_embed = patch_pos_embed.reshape(1, new_h * new_w, embed_dim)

    new_embed = torch.concat([cls_pos_embed, patch_pos_embed], dim=1)
    return new_embed
1
align_corners 控制插值时新旧网格的对齐方式。align_corners=True 把像素看成离散的点,并强行对齐新旧网格四个角上的点,因此角落位置的值会被保留;align_corners=False 则把像素看成一个个小方格,对齐的是整张图像的外围边界,而不是角落像素的中心。这样所有位置会按整体比例重新分布,更接近普通图像缩放的默认行为。

测试一下长度变化:

# First dimension must be 1 because it's a parameter,
# not a batch of embeddings.
pos_embed = torch.randn(1, 14 * 14 + 1, 64)
new_pos_embed = interpolate_pos_embedding(
    pos_embed,
    old_grid_size=(14, 14),
    new_grid_size=(24, 24),
)

print('Old positional embedding shape:', pos_embed.shape)
print('New positional embedding shape:', new_pos_embed.shape)
Old positional embedding shape: torch.Size([1, 197, 64])
New positional embedding shape: torch.Size([1, 577, 64])

可以看到,原来的 positional embedding 长度是 197(196 个 patch + 1 个 class token),新的 positional embedding 长度是 577(576 个 patch + 1 个 class token)。这样就可以适应不同分辨率的输入了。

11.3.8 本章小结

这一节我们补上了 ViT 输入模块中的两个关键设计:class tokenposition embedding

Patch embedding 只能把图像变成 patch token 序列:

\[ (B, C, H, W) \rightarrow (B, N, D) \]

但图像分类需要一个整张图像的表示,所以 ViT 会在序列前面加入一个可学习的 class token:

\[ (B, N, D) \rightarrow (B, N + 1, D) \]

经过 Transformer Encoder 后,class token 的输出表示会被用作整张图像的表示,并送入分类头。

同时,由于 self-attention 本身不包含位置信息,ViT 还会给每个 token 加上可学习的位置嵌入:

\[ Z = X + P \]

这样,模型不仅知道每个 token 表示哪个 patch 的内容,也能通过 positional embedding 感知这个 patch 在图像中的位置。

所以,完整的 ViT 输入模块的构建流程就是:

\[ \text{image} \rightarrow \text{patch embedding} \rightarrow \text{class token} \rightarrow \text{positional embedding} \]

到这里,我们就完成了 ViT 输入模块的构建。下一节我们将进入 ViT Encoder 的实现,利用 self-attention,把 patch token 和 class token 的信息融合在一起,得到整张图像的表示。