8.2 Cross-Attention:一个序列查询另一个序列

Author

jshn9515

Published

2026-04-09

Modified

2026-05-04

在上一节里,我们从 Bahdanau attention 出发,看到了注意力机制最早的一个重要动机:不要把整句输入强行压缩成一个固定长度向量,而是在生成每个目标词时,动态地去源句中寻找当前最相关的信息。这个思想可以概括成一句话:

当前时刻需要什么信息,就用当前状态去输入序列里查找什么信息。

在 Bahdanau attention 中,解码器的隐藏状态会去和编码器的所有隐藏状态进行匹配,然后根据匹配程度对编码器状态做加权求和。这个过程已经非常接近现代注意力机制中的交叉注意力(Cross-Attention)。区别在于,现代 Transformer 会用更统一、更矩阵化的方式来表达这件事:一个序列提供查询,另一个序列提供可被查询的信息。

这一节我们重点讨论 cross-attention,self-attention 会放到下一节单独展开。这样,我们可以先把一个序列查询另一个序列的结构讲清楚,再过渡到一个序列查询自己的情况。

import math

import matplotlib.pyplot as plt
import numpy as np
import scipy.stats as stats
import torch
import torch.nn as nn
from torch import Tensor

plt.rc('savefig', dpi=300, bbox='tight')
print('PyTorch version:', torch.__version__)
PyTorch version: 2.12.0+xpu

8.2.1 从软对齐到 Cross-Attention

在机器翻译任务中,我们有两个序列:

  • 源语言序列,例如英文句子;
  • 目标语言序列,例如中文句子。

传统 seq2seq 模型会先用编码器把源句压缩成一个上下文向量,然后解码器依赖这个向量逐步生成目标句子。Bahdanau attention 改变了这个过程:解码器不再只依赖一个固定的向量,而是在每一步生成时,都重新访问源句的所有隐藏状态。

换句话说,解码器每生成一个词,都会问一次:

对于我现在要生成的这个词,源句里哪些位置最重要?

这就是 cross-attention 的雏形。

所谓 cross-attention,就是让一个序列中的位置,去关注另一个序列中的位置。它和 self-attention 的区别在于,self-attention 是一个序列内部互相看,而 cross-attention 是两个序列之间发生交互。

如果用机器翻译来理解:

  • 目标序列中的某个位置负责提出问题;
  • 源序列中的每个位置负责提供候选信息;
  • 模型根据相关性,从源序列中取回对当前生成最有用的内容。

所以,cross-attention 本质上就是一种跨序列的信息检索机制

8.2.2 Query、Key、Value:动态信息检索的三要素

现代 attention 通常会用 query、key、value 来描述这个过程。虽然这三个名字看起来有些抽象,但它们背后的直觉并不复杂。我们可以把 attention 想象成一次检索:

  • Query 表示当前我们想找什么;
  • Key 表示每个位置如何参与匹配;
  • Value 表示每个位置真正提供什么内容。

从直觉上看,这很像我们平常上网搜索资料。我们在输入框里输入自己想找的内容,搜索引擎会返回一组候选结果。每条结果通常都有一个标题,以及对应的网页内容。我们会先根据自己当前的需求和这些标题,判断哪些结果更值得关注,然后再从这些结果的具体内容里提取真正有用的信息。

在 cross-attention 中,query 和 key、value 通常来自不同序列。例如,在机器翻译里,编码器输出的隐藏状态就是 key 和 value,而解码器当前的隐藏状态就是 query。解码器用 query 去匹配编码器的 key,得到一组权重,然后对编码器的 value 做加权求和,得到当前目标位置需要的上下文信息。

这其实就是 Bahdanau attention 的核心思想,只不过现代 attention 把当前解码器状态和所有编码器状态的匹配过程,统一写成了 Q、K、V 的形式,同时把 K 和 V 拆开,增强模型的表达能力。

因此,在 cross-attention 里,一个很自然的对应关系是:

  • Query 来自正在生成或正在更新的序列;
  • Key 和 value 来自被查询的序列。

也就是说,谁在提问,谁就产生 query;谁被查询,谁就产生 key 和 value。

这句话很重要。Cross-attention 并不是要求两个序列必须同等地互相看,而是有一个明确的方向:一个序列主动查询,另一个序列提供可被查询的信息。

到这里,我们已经大体熟悉了 query、key 和 value。但是,相信大家还有很多疑问:为什么要把这些角色拆开?尤其是为什么来自同一个被查询序列的 key 和 value 还要分开?我们在下一小节里来详细讨论这个问题。

8.2.3 为什么要拆成 Query、Key、Value?

上一小节里,我们把 Q、K、V 看成 attention 中的三种角色:query 表示当前需求,key 用来参与匹配,value 用来提供内容。这个说法很方便理解,但它也会带来几个自然的问题:

  1. 为什么要把输入拆成 query、key、value?
  2. 既然在 cross-attention 中 key 和 value 通常来自同一个序列,为什么 key 和 value 还要分开?
  3. 我能不能自己再加一个,比如 Q、K、V、W?

我们先看第一个问题:为什么需要 Q、K、V。

这个问题比较好解释。实际上,如果不拆成 Q、K、V,那么同一个向量就又要同时承担我想找什么、我是否值得被匹配、如果被关注应该提供什么内容这几个角色。这就又回到了我们之前那个固定上下文向量的问题:它需要同时承担多个不同的功能,模型表达能力会受限。而拆成 Q、K、V 后,模型就可以从不同角度,分别学习出查询需求、匹配特征、可传递内容,从而自己学会如何组织信息。

接下来再看第二个问题:为什么 key 和 value 要分开。

一个直观的回答是:用于匹配的信息最终取回的信息可以不是同一种信息。

在搜索引擎里,我们可能先看标题、关键词或摘要来判断一个网页是否相关;但真正需要的信息,可能藏在网页正文里。标题适合匹配,正文适合提供内容。两者相关,但并不完全相同。

Attention 里一个很重要的设计,就是把两件事分开了:

  • 第一件事:谁和我现在最相关?
  • 第二件事:如果我关注它,它能提供什么信息?

这就好比我们在图书馆找书。我们心里有一个需求,这有点像 query;我们会通过书名、目录、标签或摘要判断哪本书相关,这些线索有点像 key;而真正被我们阅读和吸收的,是书里的具体内容,这有点像 value。

如果 key 和 value 完全绑死,模型只能用同一种表示同时做匹配和内容传递;如果 key 和 value 分开,模型就可以用一种表示决定谁更相关,再用另一种表示决定相关位置应该贡献什么。也就是说:

  • 如果我们改变 key,会改变模型关注源序列的哪些位置;
  • 如果我们改变 value,在注意力权重不变的情况下,会改变模型取回的内容。

所以,attention 的注意力权重由 query 和 key 决定,但最终输出还取决于 value。把 key 和 value 分开以后,模型就可以把匹配空间和内容空间解耦开来,让它们分别学习不同的功能。

对于第三个问题:我能不能自己再加一个,比如 Q、K、V、W?

从理论上讲,当然可以。我们完全可以设计一个更复杂的 attention 机制,增加更多的角色,比如 W 代表某种辅助信息,或者引入更多的匹配维度。但是,Q、K、V 已经被证明是非常有效的设计。它们把 attention 的核心功能分成了三个角色,既有足够的表达能力,又不会过于复杂。增加更多角色可能会让模型更难训练,或者需要更多的数据来学习这些角色的分工。

不过,需要注意的是,attention 里很多具体设计,并不是从某个理论中必然推出来的。我们现在所作的解释是一种事后解释。更准确地说,它们是一组在实践中被证明有效、便于训练和实现的设计选择。也正因为如此,关于为什么不是 Q、K、V、W,打分函数能不能用其他相似度,往往并不存在标准答案。不同设计通常是在表达能力、计算效率、训练稳定性和实现复杂度之间做权衡。这也是为什么 attention 有很多变体的原因。

8.2.5 Cross-Attention 的矩阵形式

现在我们用矩阵形式写出 cross-attention。

假设有两个序列:

  • 查询序列表示为 \(X \in \mathbb{R}^{n_x \times d}\)
  • 被查询序列表示为 \(Y \in \mathbb{R}^{n_y \times d}\)

其中,\(n_x\) 是查询序列的长度,\(n_y\) 是被查询序列的长度,\(d\) 是隐藏维度。

Cross-attention 首先通过线性投影得到:

\[ Q = X W_Q, \quad K = Y W_K, \quad V = Y W_V \]

然后计算 query 和 key 之间的相似度:

\[ S = QK^\top \]

这里的 \(S \in \mathbb{R}^{n_x \times n_y}\)。它的第 \(i\) 行、第 \(j\) 列表示查询序列中的第 \(i\) 个位置对被查询序列中第 \(j\) 个位置的相关性分数。

接着,对每一行做 softmax,把分数变成权重:

\[ P = \operatorname{softmax}\left(\frac{QK^\top}{\sqrt{d_k}}\right) \]

最后,用这些权重对 value 做加权求和:

\[ O = PV \]

合在一起,就是我们非常熟悉的 scaled dot-product attention:

\[ \operatorname{Attention}(Q, K, V) = \operatorname{softmax}\left(\frac{QK^\top}{\sqrt{d_k}}\right)V \]

在 cross-attention 中,最关键的一点是 Q 来自一个序列,而 K 和 V 来自另一个序列。这也是它被称为 cross-attention 的原因。它广泛适用于机器翻译、图文匹配、问答等需要跨序列交互的任务。

下面这段代码,我们用随机张量实现一个最小版的 cross-attention。这里 xy 是两个不同序列,Q 来自 x,K 和 V 来自 y

Note

这里需要说明的是,我们的实现主要是对齐 PyTorch 的 nn.MultiheadAttention,而不是实现 attention 在维度设计上的最一般形式。

nn.MultiheadAttention 中,embed_dim 同时表示 query 的输入维度、query/key/value 投影后的总维度,以及最终 attention 输出的维度。它只额外提供 kdimvdim,用来允许 key/value 的原始输入维度与 query 不同。

不过从 attention 机制本身来看,维度可以更灵活。query、key、value 的输入维度可以不同,query/key 的匹配维度和 value 的内容维度也可以不同,最终输出维度也不一定要等于 embed_dim。只要 query 和 key 能正确计算 attention score,并且 attention weight 能和 value 对齐做加权求和,这个 attention 计算就是成立的。

class CrossAttention(nn.Module):
    def __init__(self, embed_dim: int):
        super().__init__()
        self.embed_dim = embed_dim

        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        self.out_proj = nn.Linear(embed_dim, embed_dim)

    def forward(
        self,
        query: Tensor,
        key: Tensor,
        value: Tensor,
        need_weights: bool = False,
    ) -> tuple[Tensor, Tensor | None]:
        # Make sure the sequence lengths of key and value match
        if key.size(-2) != value.size(-2):
            raise AssertionError(
                '`key` and `value` must have the same sequence length.'
            )

        q = self.q_proj(query)
        k = self.k_proj(key)
        v = self.v_proj(value)

        scores = q @ k.transpose(-2, -1)
        scores = scores / math.sqrt(self.embed_dim)

        attn_weights = scores.softmax(dim=-1)
        attn_output = attn_weights @ v
        output = self.out_proj(attn_output)

        if need_weights:
            return output, attn_weights

        return output, None


x = torch.randn(3, 5, 16)  # query sequence, length = 5
y = torch.randn(3, 8, 16)  # key/value sequence, length = 8
cross_attn = CrossAttention(embed_dim=16)

with torch.inference_mode():
    Q = cross_attn.q_proj(x)
    K = cross_attn.k_proj(y)
    V = cross_attn.v_proj(y)
    output, attn_weights = cross_attn(x, y, y, need_weights=True)

print('Q shape:', Q.shape)
print('K shape:', K.shape)
print('V shape:', V.shape)
print('Attention weights shape:', attn_weights.shape)
print('Output shape:', output.shape)
Q shape: torch.Size([3, 5, 16])
K shape: torch.Size([3, 8, 16])
V shape: torch.Size([3, 8, 16])
Attention weights shape: torch.Size([3, 5, 8])
Output shape: torch.Size([3, 5, 16])

可以看到,attention weights 的形状是 (batch size, query length, key/value length)。也就是说,查询序列中的每个位置,都会对被查询序列中的所有位置分配一组权重。

我们把这个权重矩阵可视化一下。

fig = plt.figure(1, figsize=(5, 3))
ax = fig.add_subplot(1, 1, 1)
im = ax.pcolormesh(attn_weights[0], cmap='Blues', vmin=0, vmax=0.4)
x_ticks = np.arange(y.size(-2))
y_ticks = np.arange(x.size(-2))
cbar_ticks = np.arange(0, 0.5, 0.1)
ax.set_xticks(x_ticks + 0.5, x_ticks)
ax.set_yticks(y_ticks + 0.5, y_ticks)
ax.invert_yaxis()
ax.set_xlabel('key/value position')
ax.set_ylabel('query position')
ax.set_title('Cross-Attention Weights')
ax.set_aspect('equal')
fig.colorbar(im, shrink=0.85, ticks=cbar_ticks)
fig.savefig('figures/ch8.2-cross-attn-weights.svg')
plt.close(fig)

图中的第 \(i\) 行第 \(j\) 列表示查询序列的第 \(i\) 个位置对被查询序列的第 \(j\) 个位置的关注程度。颜色越深表示权重越大,也就是查询位置越关注被查询位置的信息。由于这是 cross-attention,横轴和纵轴对应的是两个不同序列。

不过,这里还有一个看起来很小但其实很关键的细节:我们在计算 scores 时,并不是直接写 Q @ K.T,而是把点积的结果除以了 \(\sqrt{d_k}\)。这个细节虽然看起来不起眼,但它对模型的训练稳定性有着非常重要的影响。那么,这是为什么呢?

8.2.6 为什么要除以 \(\sqrt{d_k}\)

在 Transformer 里,最常见的形式就是我们刚刚讲的 scaled dot-product attention:

\[ \operatorname{Attention}(Q, K, V) = \operatorname{softmax}\left(\frac{QK^\top}{\sqrt{d_k}}\right)V \]

图 1:SDPA 计算示意图 (Vaswani et al. 2023, fig. 2)

它和最原始的 attention 的差异,就是把分数除以 \(\sqrt{d_k}\)

为什么要这样做?因为当特征维度 \(d_k\) 变大时,点积的数值往往也会变大。如果直接把这些较大的分数送进 softmax,那么 softmax 很容易变得特别尖锐,也就是某一个位置的概率接近 1,其他位置几乎接近 0。这样一来,梯度就可能变小,训练也会变得不稳定。所以,除以 \(\sqrt{d_k}\) 本质上是在控制分数的尺度,让 softmax 保持在一个更合适的工作区间。

我们可以用下面的代码直观看一下这个现象。我们随机生成 query 和 key,在不同维度下比较缩放前后的注意力分布熵。熵越低,说明分布越尖锐。

def average_entropy(
    d_k: int, num_queries: int = 2048, num_keys: int = 10
) -> tuple[float, float]:
    q = torch.randn(num_queries, d_k)
    k = torch.randn(num_keys, d_k)

    raw_scores = q @ k.T
    scaled_scores = raw_scores / math.sqrt(d_k)

    raw_prob = raw_scores.softmax(dim=-1)
    scaled_prob = scaled_scores.softmax(dim=-1)

    raw_entropy = stats.entropy(raw_prob, axis=-1).mean()
    scaled_entropy = stats.entropy(scaled_prob, axis=-1).mean()
    return raw_entropy, scaled_entropy


dims = [8, 32, 128, 512]
raw_entropies = []
scaled_entropies = []

for d in dims:
    raw_h, scaled_h = average_entropy(d)
    raw_entropies.append(raw_h)
    scaled_entropies.append(scaled_h)

fig = plt.figure(2, figsize=(5, 3))
ax = fig.add_subplot(1, 1, 1)
ax.plot(dims, raw_entropies, marker='o')
ax.plot(dims, scaled_entropies, marker='o')
ax.set_xscale('log', base=2)
ax.set_yticks(np.arange(0, 2.25, 0.25))
ax.set_xlabel('$d_k$')
ax.set_ylabel('Average Entropy')
ax.legend(['Unscaled', r'Scaled by $\sqrt{d_k}$'], loc='center right')
ax.set_title('Scaling Effect on Softmax Sharpness')
fig.savefig('figures/ch8.2-attn-normalize.svg')
plt.close(fig)

你看,不缩放时,随着 \(d_k\) 的增大,softmax 会越来越尖锐,熵下降得更明显;而做了缩放之后,分布会稳定很多。这就是 scaled dot-product attention 成为标准配置的重要原因。

8.2.7 一个目标位置如何查询整个源序列

为了更具体地理解 cross-attention,我们只看目标序列中的一个位置。

假设解码器现在要生成目标句子中的第 \(i\) 个词。此时它有一个隐藏表示 \(x_i\),这个表示会被投影成 query:

\[ q_i = x_i W_Q \]

源句中每个位置的编码器输出会被投影成 key 和 value:

\[ k_j = h_j W_K, \quad v_j = h_j W_V \]

其中,\(h_j\) 是源句第 \(j\) 个位置的编码器表示。

然后,当前 query 会和所有 key 计算相关性:

\[ s_{ij} = q_i \cdot k_j \]

这些分数经过 softmax 后得到注意力权重:

\[ \alpha_{ij} = \frac{\exp(s_{ij})}{\sum_l \exp(s_{il})} \]

最后,模型对所有 value 做加权求和:

\[ o_i = \sum_j \alpha_{ij} v_j \]

这个输出 \(o_i\) 就是目标位置 \(i\) 从源序列中取回来的上下文信息。

从这个角度看,cross-attention 并不是把源句压缩成一个向量,而是为目标序列中的每个位置都生成一个动态上下文。不同目标位置可以关注源句的不同部分,因此它天然适合建模翻译、图文匹配、问答等需要跨序列交互的任务。

8.2.8 Cross-Attention 在 Transformer 中的位置

在经典 Transformer 的机器翻译结构中,模型分成编码器和解码器两部分。

编码器负责处理源序列。它内部主要使用 self-attention,让源序列中的每个位置都能融合源句内部的上下文信息。

解码器负责生成目标序列。它通常包含三类子模块:

  1. Masked Self-Attention:让目标序列内部已经生成的位置互相交互,同时避免看到未来词;
  2. Cross-Attention:让目标序列去查询编码器输出的源序列表示;
  3. Feed-Forward Network:对每个位置的表示做进一步非线性变换。

其中,cross-attention 是编码器和解码器之间的信息桥梁。

如果没有 cross-attention,解码器只能依赖自己的历史生成结果,很难知道源句具体说了什么。加入 cross-attention 后,解码器在每一步生成时,都可以重新访问编码器输出,并根据当前生成状态选择最相关的源句信息。这延续了 Bahdanau attention 的核心思想:生成不是只依赖一个固定上下文,而是每一步都动态地查找输入序列。

不同的是,Transformer 把这个过程完全矩阵化了。它不再依赖循环神经网络逐步传递隐藏状态,而是通过 Q、K、V 和矩阵乘法,把跨序列的信息交互变成了一个可以高效并行计算的模块。

当然,虽然 cross-attention 最容易从机器翻译中理解,但它的应用远不止翻译。只要一个任务中存在一个序列需要从另一个序列中取信息的结构,就可以使用 cross-attention。例如:

  • 在图像字幕生成中,文本解码器可以查询图像特征;
  • 在视觉问答中,问题表示可以查询图像区域特征;
  • 在文本问答中,问题可以查询文档内容;
  • 在多模态模型中,文本 token 可以查询图像 patch,图像 patch 也可以查询文本 token。

这些任务表面上不同,但都可以抽象成同一个过程:

当前序列提出需求,另一个序列提供候选信息,attention 根据相关性完成动态检索。

因此,cross-attention 是 Transformer 连接不同信息来源的重要机制。它让模型不只是处理一个序列内部的关系,还能建模不同序列、不同模态、不同信息源之间的关系。

8.2.9 本章小结

这一节里,我们从 Bahdanau attention 的软对齐思想过渡到了现代 cross-attention。

从思想上看,cross-attention 延续了 Bahdanau attention 的关键动机:不要把所有输入信息压缩成一个固定向量,而是让模型在需要时动态访问相关信息。从实现上看,现代 Transformer 用 Q、K、V 和矩阵乘法,把这种动态访问写成了统一、高效、可并行的计算模块。

同时,我们也要记住,Q、K、V 不是从某个完美理论中唯一推导出来的结构,而是一种非常成功的工程设计。它把“提出查询”“参与匹配”“提供内容”分成了几个可学习的角色,让模型能够在训练中自己学会如何组织信息。理解这层分工就足够了,不需要把每一个设计都解释成某种必然。

不过,到目前为止,我们讨论的仍然是两个序列之间的交互。接下来,一个更自然的问题是:如果没有另一个序列,只有一个序列本身,它能不能也用同样的方式让内部不同位置互相查询?

这就是下一节要讲的内容:自注意力(Self-Attention)

References

Vaswani, Ashish, Noam Shazeer, Niki Parmar, et al. 2023. Attention Is All You Need. https://arxiv.org/abs/1706.03762.

Reuse