8.9 KV Cache:为什么推理时不用重复算过去

Author

jshn9515

Published

2026-05-05

Modified

2026-05-05

前面几节里,我们已经看到,Transformer decoder 在推理时通常是自回归生成的。也就是说,模型不是一次性生成完整句子,而是一步一步生成:

\[ y_1 \rightarrow y_2 \rightarrow y_3 \rightarrow \cdots \]

每一步生成新 token 时,模型都会把已经生成的 token 作为输入,再预测下一个 token。这个过程看起来很自然,但如果直接按照普通 Transformer decoder 的方式计算,会有一个很大的浪费:

每生成一个新 token,都要重新计算一遍前面所有 token 的 attention 表示。

KV cache 要解决的就是这个问题。

import math

import torch
import torch.nn as nn
from torch import Tensor

print('PyTorch version:', torch.__version__)
PyTorch version: 2.12.0+xpu

8.9.1 自回归推理中的重复计算

我们先看一个简单例子。假设模型已经生成了 4 个 token:

\[ x_1, x_2, x_3, x_4 \]

我们现在要预测第 5 个 token。

首先,decoder 会对当前序列做 masked self-attention。在某一层 self-attention 中,输入会被投影成:

\[ Q = XW_Q,\quad K = XW_K,\quad V = XW_V \]

然后计算:

\[ \operatorname{Attention}(Q, K, V) = \operatorname{softmax} \left(\frac{QK^\top}{\sqrt{d_k}} \right)V \]

如果我们不做任何优化,那么每一步推理时,都会重新计算整个前缀序列的 Q、K、V。例如,生成第 3 个 token 时,要重新计算 \(x_1, x_2\);生成第 4 个 token 时,要重新计算 \(x_1, x_2, x_3\);生成第 5 个 token 时,又要重新计算 \(x_1, x_2, x_3, x_4\)

问题在于,\(x_1, x_2, x_3, x_4\) 这些过去 token 的表示,在前面的步骤里已经算过了。推理时,过去的 token 不会再改变,所以它们对应的 key 和 value 其实也没有必要重复计算。这就是 KV Cache 的基本动机。

8.9.2 为什么缓存的是 K 和 V,而不是 Q

要理解 KV cache,关键是先看清楚自回归生成时 attention 的角色。

在生成新 token 时,模型真正关心的是:

当前新位置应该从过去哪些位置取信息?

这句话可以对应到 attention 里的三个部分:

  • 当前新 token 的 query:表示我现在想找什么;
  • 过去 token 的 key:表示过去每个位置怎么被匹配;
  • 过去 token 的 value:表示过去每个位置能提供什么内容。

当我们生成第 \(t\) 个 token 时,当前这一步只需要为最新位置计算一个新的 query:

\[ q_t = x_t W_Q \]

然后用它去和所有历史位置的 key 做匹配:

\[ q_t k_1^\top,\ q_t k_2^\top,\ \dots,\ q_t k_t^\top \]

最后根据得到的权重,对历史 value 做加权求和:

\[ \operatorname{softmax} \left(\frac{q_t K_{\le t}^\top}{\sqrt{d_k}} \right) V_{\le t} \]

这里的 \(K_{\le t}\)\(V_{\le t}\) 包含从第 1 个位置到第 \(t\) 个位置的 key 和 value。

注意一个重要事实:过去位置的 key 和 value 一旦算出来,在后续生成过程中就可以复用。

比如 \(k_1\)\(v_1\) 在生成第 2 个 token 时用过,在生成第 3 个、第 4 个、第 5 个 token 时还会继续被用到。但它们本身不会因为后面生成了新 token 而改变。因此,我们可以把过去的 key 和 value 存起来。下一步推理时,只需要计算新 token 对应的 key 和 value,再把它们追加到缓存里。

这就是 KV Cache

至于 query,通常不需要缓存。因为在自回归推理中,每一步真正要输出的是最新位置的结果。过去位置的 query 已经用来计算过过去位置的输出了,后续不会再用它们重新生成过去 token。

所以,推理时缓存的是 \(K_{\text{cache}}, V_{\text{cache}}\),而不是 \(Q_{\text{cache}}\)

8.9.3 没有 KV Cache 时发生了什么

假设现在已经生成到第 \(t\) 个 token。如果没有 KV cache,模型每一步都要把完整前缀重新送进 decoder:

\[ x_1, x_2, \dots, x_t \]

然后在每一层重新计算:

\[ Q_{\le t},\quad K_{\le t},\quad V_{\le t} \]

其中:

\[ Q_{\le t} = X_{\le t}W_Q,\quad K_{\le t} = X_{\le t}W_K,\quad V_{\le t} = X_{\le t}W_V \]

也就是说,前面已经算过的 token,并不会被直接复用,而是会在下一步生成时再次参与整层 Transformer 的前向计算。

从计算量上看,主要有两类成本。

第一类是线性层和 FFN 的计算。第 \(t\) 步输入长度是 \(t\),因此一次线性投影大致是:

\[ X_{\le t}W,\quad X_{\le t} \in \mathbb{R}^{t \times d},\quad W \in \mathbb{R}^{d \times d} \]

它的计算量约为:

\[ O(td^2) \]

生成长度为 \(T\) 的序列时,这部分会在每一步重复发生,因此累计为:

\[ \sum_{t=1}^{T} O(td^2) = O(T^2d^2) \]

第二类是 attention 里的匹配和加权计算。没有 KV cache 时,第 \(t\) 步会重新计算完整的 attention score:

\[ Q_{\le t}K_{\le t}^{\top} \]

其中:

\[ Q_{\le t} \in \mathbb{R}^{t \times d},\quad K_{\le t}^{\top} \in \mathbb{R}^{d \times t} \]

所以:

\[ Q_{\le t}K_{\le t}^{\top} \in \mathbb{R}^{t \times t} \]

这一步的计算量约为:

\[ O(t^2d) \]

如果把生成过程从第 1 步累加到第 \(T\) 步,attention score 的累计计算量大约是:

\[ \sum_{t=1}^{T} O(t^2d) = O(T^3d) \]

因此,没有 KV cache 时,问题不只是前面的 token 被重复处理,而是每一步都要把整个前缀重新计算一遍。线性层和 FFN 会产生 \(O(T^2d^2)\) 级别的累计成本,attention score 的重复计算甚至会带来 \(O(T^3d)\) 级别的累计成本。

当然,这里的分析忽略了层数、head 数、batch size、常数因子以及具体 kernel 实现。它的目的不是给出精确运行时间,而是说明一个核心事实:如果每一步生成都重新处理完整前缀,历史 token 的计算会被反复浪费,而且序列越长,这种浪费越明显。

8.9.4 有 KV Cache 时发生了什么

有了 KV cache 以后,推理时不再需要每一步都重新处理完整前缀。模型会把历史 token 在每一层 self-attention 中算出来的 key 和 value 保存下来:

\[ K_{\text{cache}},\quad V_{\text{cache}} \]

当生成到第 \(t\) 个 token 时,模型只需要对当前新 token 计算:

\[ q_t,\quad k_t,\quad v_t \]

然后把新的 key 和 value 追加到缓存里:

\[ K_{\text{cache}} = [k_1, k_2, \dots, k_t] \]

\[ V_{\text{cache}} = [v_1, v_2, \dots, v_t] \]

接着,当前 token 的 query 会和缓存中的所有 key 做匹配,并从缓存中的 value 里取回信息:

\[ \operatorname{softmax}\left( \frac{q_t K_{\text{cache}}^\top}{\sqrt{d_k}} \right)V_{\text{cache}} \]

从计算量上看,KV cache 主要改变了两件事。

第一,线性层和 FFN 不再需要重复作用在完整前缀上。第 \(t\) 步只处理当前这一个新 token,因此一次线性投影大致是:

\[ x_t W,\quad x_t \in \mathbb{R}^{1 \times d},\quad W \in \mathbb{R}^{d \times d} \]

它的计算量约为:

\[ O(d^2) \]

生成长度为 \(T\) 的序列时,累计为:

\[ \sum_{t=1}^{T} O(d^2) = O(Td^2) \]

相比没有 KV cache 时的 \(O(T^2d^2)\),这部分少了大量重复计算。

第二,attention 不再重新计算完整的 \(t \times t\) score 矩阵。第 \(t\) 步只需要计算当前 token 对所有历史 token 的这一行 attention score:

\[ q_t K_{\text{cache}}^\top \]

其中:

\[ q_t \in \mathbb{R}^{1 \times d},\quad K_{\text{cache}}^\top \in \mathbb{R}^{d \times t} \]

所以:

\[ q_t K_{\text{cache}}^\top \in \mathbb{R}^{1 \times t} \]

这一步的计算量约为:

\[ O(td) \]

生成长度为 \(T\) 的序列时,累计为:

\[ \sum_{t=1}^{T} O(td) = O(T^2d) \]

相比没有 KV cache 时每一步重新计算完整的 \(Q_{\le t}K_{\le t}^{\top}\),这部分从累计 \(O(T^3d)\) 降到了 \(O(T^2d)\)

图 1:KV Cache 示意图

因此,有 KV cache 时,模型并不是不再做 attention,而是只计算当前新 token 对历史 token 的 attention。历史 token 的 key/value 已经保存在缓存里,不需要重新计算;历史 token 自己对应的 attention 行也不需要重新计算。

可以把两种情况对比为:

表 1:有无 KV Cache 时的计算对比
计算部分 没有 KV cache 有 KV cache
线性层 / FFN / QKV 投影 \(O(T^2d^2)\) \(O(Td^2)\)
Attention score 计算 \(O(T^3d)\) \(O(T^2d)\)

从表格里可以看到,KV cache 明显减少了重复计算,但它并没有把 attention 的成本完全消除。第 \(t\) 步生成时,当前 query 仍然要和缓存中的所有历史 key 做匹配:

\[ q_t K_{\le t}^\top \]

因此,这部分计算仍然会随着上下文长度增长。KV cache 真正避免的是另一类更大的浪费:不要让历史 token 在每一步都重新经过整层 Transformer。换句话说,它把每一步都重新处理完整前缀变成每一步只处理新增 token,并读取历史缓存。

还有一个点要特别注意:

KV cache 主要用于自回归推理,而不是普通训练。

训练时,我们通常已经知道完整目标序列。因此,我们仍然可以一次性把整个目标序列输入 decoder,并用 causal mask 保证每个位置只能看到自己之前的位置。也就是说,训练时虽然不能偷看未来,但计算上仍然可以并行处理整个序列。这时候没有必要像推理那样一步一步生成,也就不需要 KV cache 来避免前缀重复计算。

推理时则不同。推理时模型一开始并不知道完整输出,它必须先生成第一个 token,再把这个 token 接回输入,继续生成第二个 token,因此推理是串行的。KV cache 不能消除这种串行依赖,它不能让第 10 个 token 在第 1 个 token 之前生成。它解决的是另一个问题:在必须一步步生成的前提下,不要重复计算已经算过的历史 key/value。

所以,KV cache 加速的是自回归推理阶段。

8.9.5 KV Cache 的代价:更快,但更占显存

KV cache 可以显著减少重复计算,但它不是免费的。因为我们把每一层、每一个 head、每一个历史 token 的 key 和 value 都存了下来,所以生成越长,cache 占用的显存就越多。

前面为了方便,我们一直把 KV cache 写成:

\[ K_{\text{cache}},\quad V_{\text{cache}} \]

但在真实模型里,缓存并不是只有一份。Transformer decoder 通常有很多层,每一层都有自己的 self-attention。每一层里的 key 和 value 都是由该层的输入表示投影得到的,因此不同层的 key 和 value 不一样。所以,KV cache 实际上要为每一层都保存一份。

如果模型有 \(L\) 层,那么缓存大致可以写成:

\[ \{(K^{(1)}, V^{(1)}), (K^{(2)}, V^{(2)}), \dots, (K^{(L)}, V^{(L)})\} \]

同时,每一层的 attention 又通常是 multi-head attention。每个 head 都有自己的 key/value 表示,所以缓存里还会包含 head 这个维度。一个常见的 cache 形状可以理解为:

(batch_size, num_heads, seq_len, head_dim)

其中,batch_size 是 batch 大小,num_heads 是注意力头数,seq_len 是已经缓存的历史长度,head_dim 是每个 head 的维度。每生成一个新 token,seq_len 这一维就会增加 1。

粗略地看,KV cache 的大小和下面这些因素成正比:

\[ \text{num\_layers} \times \text{batch\_size} \times \text{num\_heads} \times \text{seq\_len} \times \text{head\_dim} \times 2 \]

最后的 2 是因为要同时存 key 和 value。

由于:

\[ \text{num\_heads} \times \text{head\_dim} = d_\mathrm{model} \]

所以也可以直观理解为:

\[ \text{KV Cache size} \propto 2 \times \text{num\_layers} \times \text{batch\_size} \times \text{seq\_len} \times d_\mathrm{model} \]

这说明 KV cache 的显存占用会随着生成长度线性增长。上下文越长、batch 越大、模型层数越多,cache 占用就越明显。

所以,KV cache 本质上是一种典型的工程权衡:用更多显存,换更少重复计算。这也是为什么大模型推理服务里,经常会非常关注 KV cache 的管理、压缩和分页调度。

8.9.6 一个简化的 KV Cache 示例

下面我们用一个非常简化的 self-attention 例子,看看 KV cache 在代码里大概是什么样子。为了突出核心思想,这里先不考虑多头,只看单头 attention。

from dataclasses import dataclass


@dataclass
class SelfAttentionOutputWithKVCache:
    output: Tensor
    attn_weights: Tensor | None
    present_k: Tensor | None
    present_v: Tensor | None


class SelfAttentionWithKVCache(nn.Module):
    def __init__(self, embed_dim: int):
        super().__init__()
        self.embed_dim = embed_dim

        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        self.out_proj = nn.Linear(embed_dim, embed_dim)

    def forward(
        self,
        x: Tensor,
        past_k: Tensor | None = None,
        past_v: Tensor | None = None,
        use_cache: bool = False,
        need_weights: bool = True,
    ) -> SelfAttentionOutputWithKVCache:
        q = self.q_proj(x)
        k = self.k_proj(x)
        v = self.v_proj(x)

        if past_k is not None:
            k = torch.concat([past_k, k], dim=-2)

        if past_v is not None:
            v = torch.concat([past_v, v], dim=-2)

        scores = q @ k.transpose(-2, -1)
        scores = scores / math.sqrt(self.embed_dim)

        attn_weights = scores.softmax(dim=-1)
        attn_output = attn_weights @ v
        output = self.out_proj(attn_output)

        result = SelfAttentionOutputWithKVCache(
            output=output,
            attn_weights=attn_weights if need_weights else None,
            present_k=k if use_cache else None,
            present_v=v if use_cache else None,
        )
        return result

在推理时,每一步只输入最新 token:

d_model = 512
x = torch.randn(3, 32, d_model)
attention = SelfAttentionWithKVCache(embed_dim=d_model)

past_k = None
past_v = None
outputs = []
max_new_tokens = 10

with torch.inference_mode():
    for step in range(max_new_tokens):
        # We only input the current token, not the whole prefix
        current_x = x[:, step : step + 1, :]

        result = attention(
            x=current_x,
            past_k=past_k,
            past_v=past_v,
            use_cache=True,
        )
        outputs.append(result.output)

        past_k = result.present_k
        past_v = result.present_v

outputs = torch.concat(outputs, dim=1)
print('Output shape:', outputs.shape)
Output shape: torch.Size([3, 10, 512])

这个例子里,past_kpast_v 就是缓存。第一次调用时,cache 为空,模型计算当前 token 的 key/value,并返回它们;第二次调用时,模型只计算新 token 的 key/value,然后和过去缓存的 key/value 拼接起来。

当然,真实大模型的实现会复杂很多。它通常会包含多层 decoder、多头 attention、batch 内不同序列长度、padding 和 attention mask、预分配 cache、beam search 或 speculative decoding,以及 GPU 上更高效的 cache layout。但核心思想就是上面这几行代码:

k = torch.concat([past_k, k], dim=1)
v = torch.concat([past_v, v], dim=1)

把过去算过的 key/value 留下来,下一步继续用。

8.9.7 本章小结

这一节里,我们讨论了 Transformer decoder 推理中的一个关键优化:KV cache。自回归生成必须一步一步进行,因为下一个 token 依赖前面已经生成的 token。如果每一步都重新计算完整前缀,就会对过去 token 产生大量重复计算。

KV cache 的核心思想是:过去 token 的 key 和 value 一旦算出来,在后续生成过程中就不会改变,因此可以缓存起来。之后每生成一个新 token,只需要计算这个新 token 的 query、key 和 value,然后把新的 key/value 追加到缓存中,再让当前 query 去查询整个缓存。

需要注意的是,KV cache 主要用于推理阶段,而不是普通训练阶段。训练时目标序列已知,可以通过 causal mask 并行计算;推理时输出未知,必须自回归生成,因此 KV cache 才变得非常重要。

从整体上看,KV cache 是 Transformer 能够高效生成长文本的重要工程机制。它不改变 attention 的数学定义,也不改变模型参数,而是在推理阶段通过缓存历史 key/value,减少重复计算,用更多显存换取更快的生成速度。