import math
import torch
import torch.nn as nn
from torch import Tensor
print('PyTorch version:', torch.__version__)

PyTorch version: 2.12.0+xpu
jshn9515
2026-05-05
2026-05-05
In the previous sections, we have seen that the Transformer decoder usually performs autoregressive generation during inference. In other words, the model does not generate the whole sentence at once. Instead, it generates token by token:
\[ y_1 \rightarrow y_2 \rightarrow y_3 \rightarrow \cdots \]
Each time a new token is generated, the model takes the tokens that have already been generated as input, and then predicts the next token. This process looks natural, but if we run it naively, feeding the whole prefix through the decoder at every step, a large amount of computation is wasted:
Every time we generate a new token, we recompute the attention representations of all previous tokens.
KV cache is meant to solve exactly this problem.
Let’s first look at a simple example. Suppose the model has already generated 4 tokens:
\[ x_1, x_2, x_3, x_4 \]
Now we want to predict the 5th token.
First, the Decoder will perform masked self-attention on the current sequence. In one self-attention layer, the input is projected into:
\[ Q = XW_Q,\quad K = XW_K,\quad V = XW_V \]
Then it computes:
\[ \operatorname{Attention}(Q, K, V) = \operatorname{softmax} \left(\frac{QK^\top}{\sqrt{d_k}} \right)V \]
If we do not apply any optimization, then at every inference step the model recomputes \(Q\), \(K\), and \(V\) for the whole prefix. For example, when generating the 3rd token, it recomputes the projections for \(x_1, x_2\); when generating the 4th token, it recomputes them for \(x_1, x_2, x_3\); when generating the 5th token, it recomputes them for \(x_1, x_2, x_3, x_4\) yet again.
The problem is that the representations of past tokens such as \(x_1, x_2, x_3, x_4\) have already been computed in earlier steps. During inference, past tokens will not change anymore, so their corresponding keys and values also do not need to be recomputed. This is the basic motivation behind KV cache.
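To make this concrete, here is a minimal sketch, assuming a single projection layer and random inputs (the names `k_proj`, `keys_step4`, and `keys_step5` are illustrative and not part of the original text). It checks that the keys of the first four tokens come out the same whether we project a 4-token prefix or a 5-token prefix.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

d_model = 8
k_proj = nn.Linear(d_model, d_model)  # stands in for W_K of one layer

# Five token embeddings; the first four play the role of x_1..x_4.
x = torch.randn(5, d_model)

with torch.inference_mode():
    keys_step4 = k_proj(x[:4])  # keys computed when the prefix has 4 tokens
    keys_step5 = k_proj(x[:5])  # keys recomputed when the prefix has 5 tokens

# The first four rows agree up to floating-point noise,
# so recomputing them at step 5 is redundant work.
same = torch.allclose(keys_step4, keys_step5[:4], atol=1e-6)
print('Past keys unchanged:', same)
```

In a multi-layer decoder the same property holds layer by layer, because the causal mask keeps the hidden state of each position independent of later tokens.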
To understand KV cache, the key is to first see clearly what role attention plays during autoregressive generation.
When generating a new token, what the model really cares about is:
Which past positions should the current new position retrieve information from?
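This question maps onto the three parts of attention: the query represents the current new position that is asking, the keys represent the past positions it is matched against, and the values carry the information that is retrieved from those positions.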
When we generate the \(t\)-th token, this step only needs to compute one new query for the latest position:
\[ q_t = x_t W_Q \]
Then it uses this query to match against the keys of all historical positions:
\[ q_t k_1^\top,\ q_t k_2^\top,\ \dots,\ q_t k_t^\top \]
Finally, according to the resulting weights, it takes a weighted sum over the historical values:
\[ \operatorname{softmax} \left(\frac{q_t K_{\le t}^\top}{\sqrt{d_k}} \right) V_{\le t} \]
Here, \(K_{\le t}\) and \(V_{\le t}\) contain the keys and values from position 1 to position \(t\).
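The single-query formula above can be checked numerically. The following is a rough sketch with random tensors standing in for real projections (all names and sizes are assumptions chosen for illustration). It compares the output of the latest position under full causal attention with the output obtained from one query row alone.

```python
import math
import torch

torch.manual_seed(0)

t, d_k = 5, 16
Q = torch.randn(t, d_k)  # queries for positions 1..t
K = torch.randn(t, d_k)  # keys for positions 1..t
V = torch.randn(t, d_k)  # values for positions 1..t

# Full causal attention: every position attends to itself and the past.
scores = Q @ K.T / math.sqrt(d_k)
causal_mask = torch.tril(torch.ones(t, t, dtype=torch.bool))
scores = scores.masked_fill(~causal_mask, float('-inf'))
full_output = scores.softmax(dim=-1) @ V  # shape (t, d_k)

# Single-query attention: only the latest position asks the question.
q_t = Q[-1:]  # shape (1, d_k)
row_scores = q_t @ K.T / math.sqrt(d_k)  # shape (1, t)
last_output = row_scores.softmax(dim=-1) @ V  # shape (1, d_k)

matches = torch.allclose(full_output[-1:], last_output, atol=1e-6)
print('Last row matches:', matches)
```

The last row of the causal mask allows every position, which is why the single query needs no mask of its own.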
Notice an important fact: once the keys and values of past positions have been computed, they can be reused during later generation steps.
For example, \(k_1\) and \(v_1\) were used when generating the 2nd token, and they will continue to be used when generating the 3rd, 4th, and 5th tokens. But they themselves will not change just because later tokens are generated. Therefore, we can store the past keys and values. At the next inference step, we only need to compute the key and value for the new token, and then append them to the cache.
This is KV cache.
As for the query, we usually do not need to cache it. In autoregressive inference, each step only needs to output the result for the latest position. The queries of past positions have already been used to compute the outputs of those past positions, and later they will not be used again to regenerate past tokens.
So during inference, we cache \(K_{\text{cache}}, V_{\text{cache}}\), not \(Q_{\text{cache}}\).
Suppose we have now generated up to the \(t\)-th token. Without KV cache, at every step the model has to feed the full prefix into the Decoder again:
\[ x_1, x_2, \dots, x_t \]
Then in every layer, it recomputes:
\[ Q_{\le t},\quad K_{\le t},\quad V_{\le t} \]
where:
\[ Q_{\le t} = X_{\le t}W_Q,\quad K_{\le t} = X_{\le t}W_K,\quad V_{\le t} = X_{\le t}W_V \]
That is, the results already computed for earlier tokens are not reused. Instead, those tokens pass through the whole Transformer layer again at the next generation step.
From the perspective of computation, there are mainly two kinds of costs.
The first kind is the computation of linear layers and FFNs. At step \(t\), the input length is \(t\), so one linear projection is roughly:
\[ X_{\le t}W,\quad X_{\le t} \in \mathbb{R}^{t \times d},\quad W \in \mathbb{R}^{d \times d} \]
Its computation is about:
\[ O(td^2) \]
When generating a sequence of length \(T\), this part happens repeatedly at every step, so the accumulated cost is:
\[ \sum_{t=1}^{T} O(td^2) = O(T^2d^2) \]
The second kind is the matching and weighted-sum computation inside attention. Without KV cache, at step \(t\), the full attention score is recomputed:
\[ Q_{\le t}K_{\le t}^{\top} \]
where:
\[ Q_{\le t} \in \mathbb{R}^{t \times d},\quad K_{\le t}^{\top} \in \mathbb{R}^{d \times t} \]
So:
\[ Q_{\le t}K_{\le t}^{\top} \in \mathbb{R}^{t \times t} \]
The computation of this step is about:
\[ O(t^2d) \]
If we accumulate the generation process from step 1 to step \(T\), the total computation for attention scores is roughly:
\[ \sum_{t=1}^{T} O(t^2d) = O(T^3d) \]
Therefore, without KV cache, every step reprocesses the whole prefix. Linear layers and FFNs accumulate a cost on the order of \(O(T^2d^2)\), and the repeated computation of attention scores accumulates a cost on the order of \(O(T^3d)\).
Of course, this analysis ignores the number of layers, the number of heads, batch size, constant factors, and the concrete kernel implementation. Its purpose is not to give an exact runtime, but to explain one core fact: if every generation step reprocesses the full prefix, the computation for historical tokens is repeatedly wasted, and the longer the sequence is, the more obvious this waste becomes.
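As a quick sanity check on the two accumulated costs, the sums can be written out with the standard closed forms for \(\sum t\) and \(\sum t^2\), dropping the constants hidden inside the \(O(\cdot)\) terms:

\[ \sum_{t=1}^{T} t\,d^2 = \frac{T(T+1)}{2}\,d^2 = O(T^2d^2) \]

\[ \sum_{t=1}^{T} t^2 d = \frac{T(T+1)(2T+1)}{6}\,d = O(T^3d) \]

For instance, with \(T = 1000\) the first factor is \(500{,}500\) and the second is \(333{,}833{,}500\), so the attention-score term grows far faster with the sequence length than the projection term does.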
With KV cache, inference no longer needs to reprocess the full prefix at every step. The model stores the keys and values computed from historical tokens in each self-attention layer:
\[ K_{\text{cache}},\quad V_{\text{cache}} \]
When generation reaches the \(t\)-th token, the model only needs to compute, for the current new token:
\[ q_t,\quad k_t,\quad v_t \]
Then it appends the new key and value to the cache:
\[ K_{\text{cache}} = [k_1, k_2, \dots, k_t] \]
\[ V_{\text{cache}} = [v_1, v_2, \dots, v_t] \]
Next, the query of the current token matches against all keys in the cache, and retrieves information from the values in the cache:
\[ \operatorname{softmax}\left( \frac{q_t K_{\text{cache}}^\top}{\sqrt{d_k}} \right)V_{\text{cache}} \]
From the perspective of computation, KV cache mainly changes two things.
First, linear layers and FFNs no longer need to be repeatedly applied to the full prefix. At step \(t\), only the current new token is processed, so one linear projection is roughly:
\[ x_t W,\quad x_t \in \mathbb{R}^{1 \times d},\quad W \in \mathbb{R}^{d \times d} \]
Its computation is about:
\[ O(d^2) \]
When generating a sequence of length \(T\), the accumulated cost is:
\[ \sum_{t=1}^{T} O(d^2) = O(Td^2) \]
Compared with \(O(T^2d^2)\) without KV cache, this removes a large amount of repeated computation.
Second, attention no longer recomputes the full \(t \times t\) score matrix. At step \(t\), it only needs to compute this one row of attention scores from the current token to all historical tokens:
\[ q_t K_{\text{cache}}^\top \]
where:
\[ q_t \in \mathbb{R}^{1 \times d},\quad K_{\text{cache}}^\top \in \mathbb{R}^{d \times t} \]
So:
\[ q_t K_{\text{cache}}^\top \in \mathbb{R}^{1 \times t} \]
The computation of this step is about:
\[ O(td) \]
When generating a sequence of length \(T\), the accumulated cost is:
\[ \sum_{t=1}^{T} O(td) = O(T^2d) \]
Compared with recomputing the full \(Q_{\le t}K_{\le t}^{\top}\) at every step without KV cache, this part drops from accumulated \(O(T^3d)\) to \(O(T^2d)\).
Therefore, with KV cache, the model does not stop doing attention. It only computes the attention from the current new token to historical tokens. The keys/values of historical tokens have already been stored in the cache and do not need to be recomputed; the attention rows corresponding to historical tokens themselves also do not need to be recomputed.
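The claim that caching changes only the amount of work and not the result can be tested directly. Below is a minimal sketch, assuming single-head attention with random weight matrices (`W_q`, `W_k`, `W_v`, `k_cache`, and `v_cache` are illustrative names). It decodes one token per step with a growing cache and compares the result with full causal attention over the whole sequence.

```python
import math
import torch

torch.manual_seed(0)

T, d = 6, 16
X = torch.randn(T, d)  # token representations x_1..x_T
W_q, W_k, W_v = (torch.randn(d, d) / math.sqrt(d) for _ in range(3))

# Reference: project the whole sequence and apply a causal mask.
Q, K, V = X @ W_q, X @ W_k, X @ W_v
scores = (Q @ K.T / math.sqrt(d)).masked_fill(
    ~torch.tril(torch.ones(T, T, dtype=torch.bool)), float('-inf')
)
reference = scores.softmax(dim=-1) @ V  # shape (T, d)

# Cached decoding: one token per step, appending k_t and v_t to the cache.
k_cache = torch.empty(0, d)
v_cache = torch.empty(0, d)
rows = []
for step in range(T):
    x_t = X[step : step + 1]  # shape (1, d)
    q_t = x_t @ W_q
    k_cache = torch.cat([k_cache, x_t @ W_k], dim=0)
    v_cache = torch.cat([v_cache, x_t @ W_v], dim=0)
    weights = (q_t @ k_cache.T / math.sqrt(d)).softmax(dim=-1)  # shape (1, step + 1)
    rows.append(weights @ v_cache)

cached = torch.cat(rows, dim=0)
equal = torch.allclose(reference, cached, atol=1e-5)
print('Cached decoding matches full causal attention:', equal)
```

Each step touches only one row of queries, yet the stacked rows reproduce the full masked attention output up to floating-point tolerance.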
We can compare the two cases like this:
| Computation part | Without KV cache | With KV cache |
|---|---|---|
| Linear layers / FFN / QKV projection | \(O(T^2d^2)\) | \(O(Td^2)\) |
| Attention score computation | \(O(T^3d)\) | \(O(T^2d)\) |
From the table, we can see that KV cache clearly reduces repeated computation, but it does not completely remove the cost of attention. When generating at step \(t\), the current query still has to match against all historical keys in the cache:
\[ q_t K_{\le t}^\top \]
Therefore, this part of the computation still grows as the context length increases. What KV cache really avoids is another, larger kind of waste: it prevents historical tokens from going through the whole Transformer layer again at every step. In other words, it changes “reprocess the full prefix at every step” into “process only the newly added token at every step, and read from the historical cache.”
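To get a feel for the size of the saving, the accumulated costs can be counted directly. This is a rough sketch that counts multiply-adds for one projection and for the attention scores only, using the same simplifications as the analysis above (the function names and the sizes `T = 1024`, `d = 512` are assumptions for illustration).

```python
def projection_cost(T: int, d: int, use_cache: bool) -> int:
    """Accumulated multiply-adds for one d x d projection over T steps."""
    # With a cache: 1 token per step. Without: the whole prefix of t tokens.
    return sum((1 if use_cache else t) * d * d for t in range(1, T + 1))


def score_cost(T: int, d: int, use_cache: bool) -> int:
    """Accumulated multiply-adds for the attention scores over T steps."""
    # With a cache: one row of t scores. Without: the full t x t matrix.
    return sum((t if use_cache else t * t) * d for t in range(1, T + 1))


T, d = 1024, 512  # illustrative sizes, not taken from a real model
for name, cost in [('projection', projection_cost), ('scores', score_cost)]:
    without = cost(T, d, use_cache=False)
    with_cache = cost(T, d, use_cache=True)
    print(f'{name}: {without / with_cache:.1f}x fewer operations with cache')
```

The two ratios work out to \((T+1)/2\) and \((2T+1)/3\), so both savings grow linearly with the generated length.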
There is another point that needs special attention:
KV cache is mainly used for autoregressive inference, not ordinary training.
During training, we usually already know the full target sequence. Therefore, we can feed the whole target sequence into the Decoder at once, and use a causal mask to make sure each position can only see itself and the positions before it. That is, although the model cannot peek at the future during training, the computation can still process the whole sequence in parallel. In this setting, there is no need to generate step by step as in inference, so KV cache is not needed to avoid repeated prefix computation.
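The parallel, masked computation used in training can be sketched as follows. This is an illustrative single-head version in which the input is used directly as query, key, and value (an assumption made to keep the example short; `causal_attention` is a hypothetical helper). It checks that changing the last token leaves the outputs of all earlier positions untouched.

```python
import math
import torch

torch.manual_seed(0)


def causal_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    """Masked self-attention over a whole sequence in one pass."""
    seq_len, d_k = q.shape[-2], q.shape[-1]
    scores = q @ k.transpose(-2, -1) / math.sqrt(d_k)
    # Lower-triangular mask: position i may attend to positions 0..i only.
    mask = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))
    scores = scores.masked_fill(~mask, float('-inf'))
    return scores.softmax(dim=-1) @ v


x = torch.randn(2, 5, 8)  # (batch, seq_len, dim); q = k = v = x for simplicity
out = causal_attention(x, x, x)

# Changing the last token must not affect the outputs of earlier positions.
x_changed = x.clone()
x_changed[:, -1, :] = torch.randn(2, 8)
out_changed = causal_attention(x_changed, x_changed, x_changed)
no_peeking = torch.allclose(out[:, :-1], out_changed[:, :-1], atol=1e-6)
print('Earlier positions ignore the future token:', no_peeking)
```

All five positions are computed in a single call, which is why training does not need a cache.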
Inference is different. During inference, the model does not know the full output at the beginning. It must first generate the first token, then feed this token back into the input, and continue generating the second token. Therefore, inference is serial. KV cache cannot remove this serial dependency. It cannot make the 10th token be generated before the 1st token. It solves another problem: under the condition that we must generate step by step, do not recompute historical keys/values that have already been computed.
So, KV cache accelerates the autoregressive inference stage.
KV cache can significantly reduce repeated computation, but it is not free. Because we store the keys and values for every layer, every head, and every historical token, the longer the generation is, the more GPU memory the cache occupies.
For convenience, we have been writing KV cache as:
\[ K_{\text{cache}},\quad V_{\text{cache}} \]
But in a real model, there is not just one cache. A Transformer decoder usually has many layers, and each layer has its own self-attention. The keys and values in each layer are obtained by projecting the input representations of that layer, so the keys and values are different across layers. Therefore, KV cache actually needs to store one copy for every layer.
If the model has \(L\) layers, then the cache can roughly be written as:
\[ \{(K^{(1)}, V^{(1)}), (K^{(2)}, V^{(2)}), \dots, (K^{(L)}, V^{(L)})\} \]
At the same time, attention in each layer is usually multi-head attention. Each head has its own key/value representation, so the cache also includes the head dimension. A common cache shape can be understood as:
(batch_size, num_heads, seq_len, head_dim)
Here, batch_size is the batch size, num_heads is the number of attention heads, seq_len is the historical length that has already been cached, and head_dim is the dimension of each head. Every time a new token is generated, the seq_len dimension increases by 1.
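The per-layer, per-head layout described here can be sketched with plain tensors. The sizes below are illustrative, and random tensors stand in for the keys and values a real layer would produce. Each layer keeps its own pair, and every step appends one position along the `seq_len` dimension.

```python
import torch

num_layers, batch_size, num_heads, head_dim = 4, 2, 8, 64  # illustrative sizes

# One (K, V) pair per layer, each starting with seq_len = 0.
cache = [
    (
        torch.empty(batch_size, num_heads, 0, head_dim),
        torch.empty(batch_size, num_heads, 0, head_dim),
    )
    for _ in range(num_layers)
]

for step in range(3):
    for layer_idx, (k, v) in enumerate(cache):
        # Stand-ins for the new token's key/value in this layer.
        new_k = torch.randn(batch_size, num_heads, 1, head_dim)
        new_v = torch.randn(batch_size, num_heads, 1, head_dim)
        # Append along the seq_len dimension (dim=2).
        cache[layer_idx] = (torch.cat([k, new_k], dim=2), torch.cat([v, new_v], dim=2))

print('Layers cached:', len(cache))
print('Key shape in layer 0:', tuple(cache[0][0].shape))
```

After three steps, every layer holds keys and values of shape `(2, 8, 3, 64)`.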
Roughly speaking, the size of KV cache is proportional to these factors:
\[ \text{num\_layers} \times \text{batch\_size} \times \text{num\_heads} \times \text{seq\_len} \times \text{head\_dim} \times 2 \]
The final 2 is because both key and value need to be stored.
Since, in standard multi-head attention:
\[ \text{num\_heads} \times \text{head\_dim} = d_\mathrm{model} \]
we can also intuitively understand it as:
\[ \text{KV cache size} \propto 2 \times \text{num\_layers} \times \text{batch\_size} \times \text{seq\_len} \times d_\mathrm{model} \]
This means the GPU memory usage of KV cache grows linearly with the generation length. The longer the context, the larger the batch, and the more layers the model has, the more obvious the cache usage becomes.
So KV cache is essentially a typical engineering trade-off: use more GPU memory in exchange for less repeated computation. This is also why large-model inference services often pay a lot of attention to KV cache management, compression, and paged scheduling.
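The size formula above is easy to turn into a back-of-the-envelope calculator. The sketch below assumes 16-bit storage (2 bytes per element) and a hypothetical model configuration chosen only to keep the arithmetic simple; `kv_cache_bytes` is an illustrative helper, not a library function.

```python
def kv_cache_bytes(
    num_layers: int,
    batch_size: int,
    num_heads: int,
    seq_len: int,
    head_dim: int,
    bytes_per_element: int = 2,  # assumes 16-bit storage such as float16
) -> int:
    """Bytes needed to hold keys and values for every layer, head, and token."""
    elements = 2 * num_layers * batch_size * num_heads * seq_len * head_dim
    return elements * bytes_per_element


# Hypothetical configuration, chosen only to make the arithmetic easy to follow.
size = kv_cache_bytes(num_layers=32, batch_size=1, num_heads=32, seq_len=4096, head_dim=128)
print(f'KV cache: {size / 2**30:.2f} GiB')  # prints "KV cache: 2.00 GiB"
```

Doubling either `seq_len` or `batch_size` doubles the result, which matches the linear growth described above.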
Below, we use a very simplified self-attention example to see what KV cache roughly looks like in code. To highlight the core idea, we will not consider multi-head attention yet, and will only look at single-head attention.
from dataclasses import dataclass


@dataclass
class SelfAttentionOutputWithKVCache:
    output: Tensor
    attn_weights: Tensor | None
    present_k: Tensor | None
    present_v: Tensor | None


class SelfAttentionWithKVCache(nn.Module):
    def __init__(self, embed_dim: int):
        super().__init__()
        self.embed_dim = embed_dim
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        self.out_proj = nn.Linear(embed_dim, embed_dim)

    def forward(
        self,
        x: Tensor,
        past_k: Tensor | None = None,
        past_v: Tensor | None = None,
        use_cache: bool = False,
        need_weights: bool = True,
    ) -> SelfAttentionOutputWithKVCache:
        # Project only the tokens passed in at this step
        q = self.q_proj(x)
        k = self.k_proj(x)
        v = self.v_proj(x)

        # Put the cached keys/values in front along the sequence dimension
        if past_k is not None:
            k = torch.concat([past_k, k], dim=-2)
        if past_v is not None:
            v = torch.concat([past_v, v], dim=-2)

        # No causal mask is needed when x holds a single new token:
        # every cached position is already in the past.
        # Passing several tokens at once would require a causal mask here.
        scores = q @ k.transpose(-2, -1)
        scores = scores / math.sqrt(self.embed_dim)
        attn_weights = scores.softmax(dim=-1)

        attn_output = attn_weights @ v
        output = self.out_proj(attn_output)

        result = SelfAttentionOutputWithKVCache(
            output=output,
            attn_weights=attn_weights if need_weights else None,
            present_k=k if use_cache else None,
            present_v=v if use_cache else None,
        )
        return result

During inference, each step only inputs the latest token:
d_model = 512
x = torch.randn(3, 32, d_model)
attention = SelfAttentionWithKVCache(embed_dim=d_model)

past_k = None
past_v = None
outputs = []
max_new_tokens = 10

with torch.inference_mode():
    for step in range(max_new_tokens):
        # We only input the current token, not the whole prefix
        current_x = x[:, step : step + 1, :]
        result = attention(
            x=current_x,
            past_k=past_k,
            past_v=past_v,
            use_cache=True,
        )
        outputs.append(result.output)
        past_k = result.present_k
        past_v = result.present_v

outputs = torch.concat(outputs, dim=1)
print('Output shape:', outputs.shape)

Output shape: torch.Size([3, 10, 512])
In this example, past_k and past_v are the cache. During the first call, the cache is empty. The model computes the key/value of the current token and returns them. During the second call, the model only computes the key/value of the new token, and then concatenates them with the cached key/value from the past.
Of course, the implementation in real large models is much more complex. It usually includes multiple Decoder layers, multi-head attention, different sequence lengths inside the batch, padding and attention masks, preallocated cache, beam search or speculative decoding, and a more efficient cache layout on GPU. But the core idea is exactly these few lines of code above:
Keep the key/value that were computed in the past, and continue to use them in the next step.
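One of the refinements mentioned above, a preallocated cache, can be sketched in a few lines. Instead of concatenating tensors at every step, which allocates a new tensor each time, the buffers are created once at a maximum length and filled in place. The class below is a hypothetical single-head illustration under that assumption, not the layout used by any particular library.

```python
import torch


class PreallocatedKVCache:
    """Fixed-size key/value buffers that are filled in place (illustrative sketch)."""

    def __init__(self, batch_size: int, max_len: int, dim: int):
        self.k = torch.zeros(batch_size, max_len, dim)
        self.v = torch.zeros(batch_size, max_len, dim)
        self.length = 0  # number of positions written so far

    def append(self, new_k: torch.Tensor, new_v: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Write the new positions and return views over the filled part."""
        n = new_k.shape[1]
        if self.length + n > self.k.shape[1]:
            raise ValueError('cache is full')
        self.k[:, self.length : self.length + n] = new_k
        self.v[:, self.length : self.length + n] = new_v
        self.length += n
        return self.k[:, : self.length], self.v[:, : self.length]


cache = PreallocatedKVCache(batch_size=2, max_len=16, dim=8)
for _ in range(3):
    k, v = cache.append(torch.randn(2, 1, 8), torch.randn(2, 1, 8))
print('Filled part:', tuple(k.shape))
```

The returned tensors are views into the buffers, so attention at each step still sees only the positions written so far.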
In this section, we discussed a key optimization in Transformer decoder inference: KV cache. Autoregressive generation must proceed step by step, because the next token depends on the tokens that have already been generated. If the full prefix is recomputed at every step, a large amount of repeated computation will be spent on past tokens.
The core idea of KV cache is: once the keys and values of past tokens have been computed, they will not change during later generation, so they can be cached. After that, every time a new token is generated, we only need to compute the query, key, and value of this new token, append the new key/value to the cache, and then let the current query attend to the whole cache.
It is important to note that KV cache is mainly used during inference, not during ordinary training. During training, the target sequence is known, so we can use a causal mask and compute in parallel. During inference, the output is unknown and must be generated autoregressively, so KV cache becomes very important.
Overall, KV cache is an important engineering mechanism that allows Transformer models to generate long text efficiently. It does not change the mathematical definition of attention, and it does not change the model parameters. Instead, during inference, it caches historical keys/values to reduce repeated computation, using more GPU memory in exchange for faster generation.