上一节里,我们已经把一个最基本的 Vision Transformer 串了起来:
\[
\begin{aligned}
\text{image}
& \rightarrow \text{patch embedding} \\
& \rightarrow \text{class token} \\
& \rightarrow \text{positional embedding} \\
& \rightarrow \text{Transformer Encoder} \\
& \rightarrow \text{classification head}
\end{aligned}
\]
如果目标只是图像分类,到这里似乎已经结束了。输入一张图像,模型输出类别 logits,然后用交叉熵训练即可。
但在实际应用中,ViT 的价值并不仅仅是做分类。更常见的做法是把 ViT 看成一个通用的视觉 backbone :它负责把图像编码成一组高层视觉特征,后面再接不同的任务头,用于分类、检测、分割,甚至多模态任务。
这就带来一个新的问题:
为什么 ViT 特别适合作为 backbone?如果把它当作 backbone,预训练和微调又分别在做什么?
这一节我们就从这个问题出发,讨论 ViT 作为视觉 backbone 的使用方式。
import math
from pprint import pprint
import evaluate
import IPython.display as ipy
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from datasets import load_dataset
from torch import Tensor
from transformers import Trainer, TrainingArguments
print ('PyTorch version:' , torch.__version__)
PyTorch version: 2.12.0+xpu
torch.manual_seed(42 )
torch.use_deterministic_algorithms(False )
torch.backends.cudnn.deterministic = False
torch.backends.cudnn.benchmark = True
device = torch.accelerator.current_accelerator(check_available= True )
if device is None :
device = torch.device('cpu' )
print (f'Using device: { device} ' )
11.5.1 从分类模型到视觉 Backbone
我们先回到上一节的完整 ViT 分类模型。它可以拆成两部分:
\[
\text{ViT} = \text{backbone} + \text{task head}
\]
其中,backbone 包括:
\[
\text{patch embedding}
+ \text{class token}
+ \text{positional embedding}
+ \text{Transformer Encoder}
\]
Task head 则是最后的分类头:
\[
\operatorname{Linear}: \mathbb{R}^{D} \rightarrow \mathbb{R}^{K}
\]
对于图像分类来说,分类头取出 class token 的输出表示:
\[
h_\mathrm{cls} = Z_L[:, 0, :]
\]
然后映射成类别 logits:
\[
\mathrm{logits} = h_\mathrm{cls} W + b
\]
但是,真正比较通用的部分不是这个分类头,而是前面的 ViT Encoder。分类头只和某个具体数据集的类别有关。比如 ImageNet 有 1000 类,CIFAR-10 有 10 类,医学图像数据集可能只有 2 类或 5 类。类别一变,最后的分类头就要换。
相比之下,ViT Encoder 学到的是更通用的视觉表示。它把图像转换成一组上下文相关的 token 表示:
\[
Z_L \in \mathbb{R}^{B \times (N+1) \times D}
\]
这些 token 表示可以被不同任务继续使用。因此,我们通常把 ViT 的主体部分称为 backbone,把后面针对具体任务的小模块称为 head。
直观来说:
视觉 backbone 负责看懂图像,head 负责回答具体问题。
11.5.2 Backbone 为下游任务提供了什么?
既然 ViT backbone 的输出是一个 token 序列,那么它到底给下游任务提供了什么?
假设经过 ViT Encoder 以后,输出是:
\[
Z_L = [z_\mathrm{cls}, z_1, z_2, \dots, z_N]
\]
这里有两类表示:
\(z_\mathrm{cls}\) :class token 的输出,通常被看作整张图像的全局表示;
\(z_1, z_2, \dots, z_N\) :patch token 的输出,保留了不同图像区域的局部表示。
如果任务是图像分类,我们通常只需要全局语义,所以可以取 class token:
\[
h_\mathrm{image} = z_\mathrm{cls}
\]
如果任务是语义分割,我们需要对每个空间位置做预测,那么 patch tokens 就更重要。因为每个 patch token 仍然对应图像中的一个区域。只要把 patch token 重新排列回二维网格,就可以得到类似 feature map 的表示。
假设图像被切成 \(H_p \times W_p\) 个 patch,那么 patch token 的数量就是:
\[
N = H_p \times W_p
\]
Patch token 的形状是:
\[
(B, N, D)
\]
我们可以把它还原成二维特征图:
\[
(B, N, D)
\rightarrow
(B, H_p, W_p, D)
\rightarrow
(B, D, H_p, W_p)
\]
这样,ViT backbone 的输出就可以接入各种 dense prediction 模块,用于语义分割、实例分割、目标检测等任务。
我们可以写一个简单的函数,把 patch token 还原成 feature map:
def patch_tokens_to_feature_map(
patch_tokens: Tensor,
grid_size: tuple [int , int ],
) -> Tensor:
batch_size, num_patches, embed_dim = patch_tokens.size()
grid_h, grid_w = grid_size
if num_patches != grid_h * grid_w:
raise AssertionError ('`num_patches` must be equal to `grid_h * grid_w`.' )
x = patch_tokens.reshape(batch_size, grid_h, grid_w, embed_dim)
x = x.permute(0 , 3 , 1 , 2 )
return x
测试一下:
embed_dim = 64
grid_size = (4 , 4 )
num_patches = math.prod(grid_size)
patch_tokens = torch.randn(2 , num_patches, embed_dim)
feature_map = patch_tokens_to_feature_map(patch_tokens, grid_size)
print ('Patch tokens shape:' , patch_tokens.shape)
print ('Feature map shape:' , feature_map.shape)
Patch tokens shape: torch.Size([2, 16, 64])
Feature map shape: torch.Size([2, 64, 4, 4])
可以看到,patch tokens 从 (B, N, D) 变成了 (B, D, H_p, W_p)。这就是 ViT 作为 dense prediction backbone 时非常常见的一步。
11.5.3 为什么 ViT 适合作为 Backbone?
现在我们可以更具体地回答:为什么 ViT 适合作为 backbone?
第一个原因是,ViT 的接口非常统一。无论输入是图像、文本还是其他模态,只要输入可以被表示成 token 序列,就可以送入 Transformer。对于图像来说,patch embedding 把图像翻译成视觉 token;对于文本来说,word embedding 把词翻译成文本 token。它们最后都变成了 (B, N, D) 这种统一的 token 表示,让 ViT 很容易和其他 Transformer 模块连接起来。
第二个原因是,ViT 的输出既包含全局表示,也包含局部表示。Class token 可以用于图像级任务,patch tokens 可以用于位置相关任务。因此,同一个 backbone 可以服务于不同下游任务。
第三个原因是,self-attention 很适合建模长距离关系。在 CNN 中,远距离区域的信息交互通常需要经过多层卷积逐渐传播;而在 ViT 中,任意两个 patch 在一层 self-attention 中就可以直接交互。这使得 ViT 在建模全局结构时非常自然。
不过,这也带来一个代价:标准 ViT 的视觉先验比较弱。CNN 天然带有局部性和平移等变性,而 ViT 更多依赖数据自己学出视觉结构。所以,ViT 往往更依赖大规模预训练。也就是说,ViT 适合作为 backbone,并不是因为它从一开始就比 CNN 更懂图像,而是因为它提供了一个非常通用、可扩展、易迁移的视觉表示框架。
11.5.4 预训练:先学通用视觉表示
前面我们从结构上理解了 ViT backbone。当然,我们不可能用 CIFAR-10 这种小数据集来训练一个 ViT backbone。因为 ViT 的参数量很大,训练它需要大量数据来避免过拟合。因此,在实际使用时,更常见的做法是直接加载别人已经训练好的 checkpoint。这个已经训练好的模型通常叫作预训练模型(pretrained model) 。
所谓预训练(pre-training) ,就是让模型先去见很多很多图像,提前积累一些通用的视觉经验。
比如,我们可以先让 ViT 在 ImageNet-1K 上学习区分 1000 个类别,其中包括动物、植物、家具、食物、工具等各种常见的东西。这时候,模型虽然还没有见过我们具体的花卉分类数据集,但它已经学到了很多通用的视觉线索,例如边缘、纹理、颜色组合、物体轮廓,以及主体和背景之间的关系。
在此基础上,我们只需要把这个 ViT 模型放到自己的花卉数据集上继续训练一小段时间,让它适应新的类别。由于模型已经具备了一定的视觉理解能力,这个过程通常会比从零开始训练容易很多。这个过程就叫微调(fine-tuning) 。
Hugging Face Transformers 提供了许多预训练好的 ViT checkpoint,我们可以直接加载使用。为了区分不同配置,这些模型名称通常会包含模型规模、patch 大小、输入图像分辨率、预训练数据集等关键信息。例如:
vit-base-patch16-224
我们可以拆成三部分理解:
base:表示模型规模,通常有 tiny、small、base、large 等不同大小;
patch16:表示 patch size 是 \(16 \times 16\) ,也就是每个 patch 包含 \(16 \times 16\) 个像素;
224:表示预训练时的输入分辨率是 \(224 \times 224\) 。
也就是说,vit-base-patch16-224 对应的是一个 ViT-base 模型,输入图像通常被缩放到 \(224 \times 224\) ,然后切成 \(16 \times 16\) 的 patch。
常见的 ViT 配置可以整理成下面这样:
表 1:常见 ViT 模型配置
google/vit-base-patch16-224
16
224
768
12
12
86M
最常用的 ViT-Base/16 配置
google/vit-base-patch32-224
32
224
768
12
12
86M
Patch 更大,token 更少,计算更省
google/vit-large-patch16-224
16
224
1024
24
16
307M
更大的 ViT-Large 配置,性能更好但更重
google/vit-large-patch32-384
32
384
1024
24
16
307M
更大的输入分辨率,适合细粒度任务
这里的 base 和 large 主要控制 Transformer Encoder 的宽度和深度:
base 通常使用 hidden_size=768、hidden_layers=12、attention_heads=12;
large 通常使用 hidden_size=1024、hidden_layers=24、attention_heads=16。
而 patch16 和 patch32 主要影响 token 数量。patch 越小,token 越多,空间细节保留得越多,但 self-attention 的计算量也越大。因为 self-attention 的复杂度和序列长度近似成平方关系,所以在相同输入分辨率下,patch16 会比 patch32 更细,但也更贵。
如果我们想加载一个预训练 ViT 做图像分类,可以直接使用 Hugging Face 的 Transformers 库:
from transformers import AutoImageProcessor, AutoModelForImageClassification
model_id = 'google/vit-base-patch16-224'
processor = AutoImageProcessor.from_pretrained(model_id)
model = AutoModelForImageClassification.from_pretrained(model_id)
ipy.clear_output()
print ('Image size:' , model.config.image_size)
print ('Patch size:' , model.config.patch_size)
print ('Hidden size:' , model.config.hidden_size)
print ('Number of hidden layers:' , model.config.num_hidden_layers)
print ('Number of attention heads:' , model.config.num_attention_heads)
print ('Intermediate size:' , model.config.intermediate_size)
Image size: 224
Patch size: 16
Hidden size: 768
Number of hidden layers: 12
Number of attention heads: 12
Intermediate size: 3072
这里的 processor 负责把 PIL image 或 NumPy image 转成模型需要的 pixel_values,包括图像缩放、归一化等预处理;model 则是一个带分类头的 ViT 模型。
如果我们只关心 backbone 特征,也可以加载不带具体分类头的 ViTModel:
from transformers import ViTModel
model_id = 'google/vit-base-patch16-224'
backbone = ViTModel.from_pretrained(model_id)
ipy.clear_output()
不过在下游分类任务中,更常见的是先加载 AutoModelForImageClassification,然后把分类头替换成当前数据集需要的类别数。这是因为 AutoModelForImageClassification 已经把 ViTModel 和分类头封装在一起,会直接输出分类任务需要的 logits。如果使用 ViTModel,我们还需要自己写一个分类头来接收 backbone 的输出。
11.5.5 微调:把 Backbone 迁移到具体任务
预训练完成以后,每次碰到新的任务,我们通常不会再从零开始训练一个 ViT,而是直接把预训练好的 backbone 拿过来,针对这个特定任务做微调(fine-tuning) 。
假设预训练时的类别数是 1000,而现在我们要做一个 10 分类任务。此时,backbone 可以复用,但分类头必须替换:
\[
\operatorname{Linear}(D, 1000)
\quad \rightarrow \quad
\operatorname{Linear}(D, 10)
\]
然后在新数据集上继续训练:
\[
\text{image}
\rightarrow
\text{pretrained ViT backbone}
\rightarrow
\text{new task head}
\]
这就是微调。
根据是否更新 backbone 参数,微调通常有两种做法。
第一种是 linear probing 。也就是冻结 backbone,只训练新的分类头:
\[
\theta_\mathrm{backbone} \text{ fixed},
\quad
\theta_\mathrm{head} \text{ trainable}
\]
这种方式可以测试 backbone 学到的表示是否已经足够好。如果只训练一个线性分类头就能取得不错效果,说明 backbone 的特征具有较强的可迁移性。同时,由于不用更新 backbone 参数,这种训练方式也更快,更不容易过拟合,适合下游数据比较少的情况。
第二种是 full fine-tuning 。也就是分类头和 backbone 一起训练:
\[
\theta_\mathrm{backbone}, \theta_\mathrm{head} \text{ trainable}
\]
这种方式通常效果更好,因为 backbone 可以根据下游任务进行适配。但它也更容易过拟合,尤其是在下游数据比较少的时候。
我们用 Hugging Face 的 beans 数据集做一个小实验。这个数据集是豆叶病害分类数据集,一共有 3 个类别。我们先加载一个 ImageNet 上预训练过的 ViT,然后把分类头替换成 3 类。这段代码现在看不懂也没关系,在章节末尾我们会专门介绍如何微调 ViT,以及如何把 ViT 应用到具体的下游任务上。
下面这段代码运行时间较长(XPU 需要约 10 分钟),因此默认不会在渲染文档时运行。如果你想自己运行,可以把 eval: false 改成 eval: true,或者直接复制代码到 Python 文件里运行。
from transformers import AutoImageProcessor, AutoModelForImageClassification
model_id = 'google/vit-base-patch16-224'
ds = load_dataset('beans' )
train_ds = ds['train' ].shuffle()
val_ds = ds['validation' ].shuffle()
labels = ds['train' ].features['labels' ].names
id2label = {i: label for i, label in enumerate (labels)}
label2id = {label: i for i, label in id2label.items()}
processor = AutoImageProcessor.from_pretrained(model_id)
model = AutoModelForImageClassification.from_pretrained(
model_id,
num_labels= len (labels),
id2label= id2label,
label2id= label2id,
ignore_mismatched_sizes= True ,
device_map= device,
)
ipy.clear_output()
def transform(batch: dict ) -> dict :
images = [image.convert('RGB' ) for image in batch['image' ]]
inputs = processor(images, return_tensors= 'pt' )
inputs['labels' ] = batch['labels' ]
return inputs
train_ds = train_ds.with_transform(transform)
val_ds = val_ds.with_transform(transform)
def collate_fn(images: list [dict ]) -> dict :
pixel_values = torch.stack([image['pixel_values' ] for image in images])
labels = torch.tensor([image['labels' ] for image in images])
return {'pixel_values' : pixel_values, 'labels' : labels}
def compute_metrics(eval_pred: tuple [np.ndarray, np.ndarray]) -> dict :
acc_metric = evaluate.load('accuracy' )
logits, labels = eval_pred
predictions = np.argmax(logits, axis=- 1 )
acc_value = acc_metric.compute(predictions= predictions, references= labels)
return acc_value
training_args = TrainingArguments(
output_dir= 'vit-beans-demo' ,
per_device_train_batch_size= 16 ,
per_device_eval_batch_size= 16 ,
num_train_epochs= 5 ,
learning_rate= 2e-5 ,
logging_steps= 10 ,
save_strategy= 'no' ,
remove_unused_columns= False ,
)
trainer = Trainer(
model= model,
args= training_args,
train_dataset= train_ds,
eval_dataset= val_ds,
processing_class= processor,
data_collator= collate_fn,
compute_metrics= compute_metrics,
)
before_finetuning = trainer.evaluate()
trainer.train()
after_finetuning = trainer.evaluate()
print ('Before fine-tuning:' )
pprint(before_finetuning)
print (' \n ' , end= '' )
print ('After fine-tuning:' )
pprint(after_finetuning)
Before fine-tuning:
{'epoch': 0,
'eval_accuracy': 0.40601503759398494,
'eval_loss': 1.0192097425460815,
'eval_model_preparation_time': 0.003,
'eval_runtime': 8.7041,
'eval_samples_per_second': 15.28,
'eval_steps_per_second': 1.034}
After fine-tuning:
{'epoch': 5.0,
'eval_accuracy': 0.9849624060150376,
'eval_loss': 0.020609118044376373,
'eval_model_preparation_time': 0.003,
'eval_runtime': 4.5798,
'eval_samples_per_second': 29.041,
'eval_steps_per_second': 1.965}
这个例子的对比重点在于观察同一个预训练 ViT 在下游数据集上的变化:
没有微调:预训练 backbone + 随机初始化的新分类头
微调之后:预训练 backbone + 已经适配 beans 数据集的新分类头
可以看到,只要数据量足够、训练设置合理,微调后的准确率通常会明显高于没有微调时的结果。而且,微调通常只需要少量的训练步骤就能取得不错的效果。原因也很直接:预训练 backbone 已经具备了很强的视觉理解能力,微调只是让分类头和 backbone 进一步适配当前数据集的类别。
11.5.6 Backbone 思维:不同任务如何使用 ViT
从这一节开始,我们可以用一种更工程化的方式看 ViT:
\[
\text{model} = \text{backbone} + \text{head}
\]
其中,backbone 负责把图像编码成通用的视觉表示,head 负责解决具体任务。
这个想法不只适用于 ViT,也适用于很多视觉模型。比如 CNN 中常说的 ResNet backbone 也是类似的:前面的卷积网络提取通用视觉特征,后面的分类头、检测头或分割头负责完成具体任务。
ViT 的特殊之处在于,它的 backbone 输出不是传统 CNN 中的单个 feature map,而是一组 token 表示:
\[
Z = [z_\mathrm{cls}, z_1, z_2, \dots, z_N] \in \mathbb{R}^{B \times (N+1) \times D}
\]
其中,\(z_\mathrm{cls}\) 通常用于图像级任务,而 \(z_1,\dots,z_N\) 这些 patch tokens 保留了更细粒度的空间信息。因此,下游任务需要什么样的输出,就会选择不同的 token 使用方式。
对于图像分类,任务需要一个图像级表示,所以通常使用 class token:
\[
z_\mathrm{cls} \rightarrow \text{classification head}
\]
对于语义分割,任务需要每个空间位置的类别,所以通常使用 patch tokens。我们可以先去掉 class token,再把 patch tokens 还原成二维 feature map,接上分割头:
\[
[z_1, z_2, \dots, z_N]
\rightarrow
(B, D, H_p, W_p)
\rightarrow
\text{segmentation head}
\]
对于目标检测,模型需要同时判断“哪里有物体”和“物体是什么类别”。这时,ViT backbone 可以提供 patch-level feature,检测头再基于这些特征预测边界框和类别。
对于多模态任务,ViT 输出的图像 tokens 可以作为视觉上下文,被文本 token 查询;也可以和文本 token 拼接在一起,送入更大的 multimodal Transformer:
\[
\text{image tokens} + \text{text tokens}
\rightarrow
\text{multimodal Transformer}
\]
所以,当我们说“ViT 是一个 backbone”时,重点不是最后的分类头,而是前面的 token 表示学习能力。ViT 把图像变成一组可复用的视觉 tokens,后面的任务头只是在这些 tokens 上提出不同的问题。
11.5.7 本章小结
这一节我们从完整 ViT 分类模型出发,进一步讨论了 ViT 作为视觉 backbone 的用法。
完整的 ViT 可以拆成:
\[
\text{ViT} = \text{backbone} + \text{task head}
\]
其中,backbone 负责把图像编码成 token 表示:
\[
(B, C, H, W) \rightarrow (B, N+1, D)
\]
这些 token 表示有两种常见用法:
使用 class token 表示整张图像,用于分类等图像级任务;
使用 patch tokens 表示局部区域,用于分割、检测等位置相关任务。
ViT 适合作为 backbone 的原因在于:它使用统一的 token 接口,能够建模全局关系,并且输出的 token 表示可以被不同任务头复用。不过,标准 ViT 的视觉先验较弱,因此通常需要先在大规模数据上预训练,让 backbone 学到通用视觉表示,然后再针对下游任务替换任务头,并进行微调。
到这里,我们已经讲完了最基本的 ViT 结构和使用方式。不过,正如我们前面所说,ViT 的结构虽然简单统一,但缺少 CNN 那样强的视觉先验,所以在数据量不够时,训练效果往往不如传统卷积网络。
原始 ViT 论文中的一个关键现象是:当模型在超大规模数据集上预训练时,例如 JFT-300M(ImageNet-1K 的 200 多倍),ViT 可以取得很强的效果;但如果只在 ImageNet-1K 这样规模的数据集(这个数据集其实已经非常大了)上从头训练,ViT 的优势并不明显。换句话说,ViT 很强,但它一开始更像是一个“吃数据”的模型。
这就引出了另一个问题:
如果没有 JFT-300M 这样的超大规模私有数据集,我们能不能也把 ViT 训练好?
下一节要介绍的 DeiT,就是为了解决这个问题提出的。它的目标不是改变 ViT 的基本结构,而是通过更合适的训练策略和知识蒸馏,让 ViT 也能在 ImageNet-1K 这样相对常规的数据规模上训练出很好的效果。