from collections.abc import Callable
import dnnl.nn as dnn
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
print('PyTorch version:', torch.__version__)PyTorch version: 2.12.0+xpu
jshn9515
2026-05-05
2026-05-05
前面我们已经看了 Transformer Encoder。Encoder 的核心任务是:把输入序列中的每个 token 表示成一个带有上下文信息的向量。经过多层 self-attention 和 feed-forward network 以后,输入序列中的每个位置都已经能够融合来自其他位置的信息。
如果只做文本分类、序列标注、图像分类这类任务,一个 encoder 往往就够了。因为这些任务通常只需要理解输入,然后基于输入表示做预测。
但如果任务是机器翻译、文本摘要、图像描述生成这类序列到序列(sequence-to-sequence)生成任务,情况就不一样了。模型不仅要理解输入,还要一步一步生成输出。
比如,在机器翻译中,输入是英文句子:
I love deep learning.
输出可能是中文句子:
我喜欢深度学习。
这时模型需要做两件事:
Encoder 负责第一件事,decoder 负责第二件事。
因此,Transformer decoder 可以理解成一个生成端的 Transformer。它不仅要看自己已经生成的内容,还要从 encoder 的输出中读取源序列信息。
PyTorch version: 2.12.0+xpu
从结构上看,Transformer Decoder 和 encoder 很像,都是由多个 block 堆叠而成。每个 block 里也有 attention、残差连接、LayerNorm 和 feed-forward network。但是,decoder 比 encoder 多了两个关键点。
第一个区别是,decoder 里的 self-attention 需要加 mask。
在生成任务中,模型是从左到右逐步生成 token 的。当它预测第 \(t\) 个 token 时,只能看到前面已经生成的 token,而不能偷看未来的 token。否则模型就不是在预测下一个词,而是在抄答案了。因此,decoder 里的 self-attention 必须加一个 causal mask,保证每个位置只能关注自己和之前的位置,不能看到未来位置。
第二个区别是,decoder 还需要 cross-attention。
Decoder 不能只看自己已经生成的部分,它还需要参考 encoder 对输入序列的理解。比如做翻译时,在生成中文词的时候,需要不断回头看英文源句中哪些部分最相关。这个过程就是通过 cross-attention 完成的。
所以,一个标准 Transformer Decoder block 通常包含三部分:
我们先来看 decoder 里的 self-attention。
在 encoder 中,每个位置都可以看见整个输入序列。比如输入句子有 5 个 token,那么第 1 个 token 可以关注第 2、3、4、5 个 token,第 5 个 token 也可以关注前面的所有 token。这在理解任务里是合理的。因为输入已经完整给定了,模型当然可以双向地利用上下文。
但 decoder 不一样。Decoder 是用来生成序列的。生成第一个词时,后面的词还没有生成;生成第二个词时,只知道第一个词;生成第三个词时,只知道前两个词。也就是说,在生成第 \(t\) 个位置时,模型只能依赖位置 1 到 \(t\) 的信息,不能依赖 \(t+1\) 之后的信息。如果训练时不加限制,self-attention 会让每个位置都看到整个目标序列。
比如训练翻译时,目标句子是:
我喜欢深度学习。
当模型预测“喜欢”时,如果它已经可以看到后面的“深度学习”,那不就是直接偷看答案了么?这显然不合理。而且,假如模型在训练时利用了未来信息,但在真正生成时,这些未来 token 根本不存在。这会导致推理时模型的性能大幅下降。
因此,decoder 的 self-attention 必须加一个 causal mask,也叫 look-ahead mask。它的作用是:
每个位置只能关注自己和自己之前的位置,不能关注未来位置。
那么,它是怎么实现的呢?
假设我们的目标序列长度是 4。普通 self-attention 会计算一个 \(4 \times 4\) 的注意力分数矩阵:
\[ S = \frac{QK^\top}{\sqrt{d_k}} \]
其中,第 \(i\) 行表示第 \(i\) 个位置对所有位置的注意力分数。
如果不加 mask,每一行都可以看到所有列:
\[ \begin{array}{c|cccc} & \text{pos 1} & \text{pos 2} & \text{pos 3} & \text{pos 4} \\ \hline \text{pos 1} & \checkmark & \checkmark & \checkmark & \checkmark \\ \text{pos 2} & \checkmark & \checkmark & \checkmark & \checkmark \\ \text{pos 3} & \checkmark & \checkmark & \checkmark & \checkmark \\ \text{pos 4} & \checkmark & \checkmark & \checkmark & \checkmark \\ \end{array} \]
但在 decoder 中,我们希望它变成:
\[ \begin{array}{c|cccc} & \text{pos 1} & \text{pos 2} & \text{pos 3} & \text{pos 4} \\ \hline \text{pos 1} & \checkmark & \times & \times & \times \\ \text{pos 2} & \checkmark & \checkmark & \times & \times \\ \text{pos 3} & \checkmark & \checkmark & \checkmark & \times \\ \text{pos 4} & \checkmark & \checkmark & \checkmark & \checkmark \\ \end{array} \]
也就是说,第 1 个位置只能看第 1 个位置;第 2 个位置可以看第 1、2 个位置;第 3 个位置可以看第 1、2、3 个位置,以此类推。
实现时,我们通常会把未来位置的 attention score 设成一个非常小的数,比如 \(-\infty\):
\[ S_{ij} = -\infty, \quad \text{if}\, j > i \]
然后再做 softmax。
因为:
\[ \exp(-\infty) = 0 \]
所以这些未来位置对应的注意力权重就会变成 0。这样模型就无法从未来 token 中取信息。
这样,masked self-attention 的形式就可以写成:
\[ \operatorname{MaskedAttention}(Q, K, V) = \operatorname{softmax} \left(\frac{QK^\top}{\sqrt{d_k}} + M \right)V \]
其中 \(M\) 就是 mask 矩阵。对于允许关注的位置,\(M_{ij}=0\);对于不允许关注的位置,\(M_{ij}=-\infty\)。
至于为什么需要 self-attention,其实也很好理解。因为在机器翻译里,我们需要根据已经生成的目标词来继续生成下一个词。在这个过程中,我们需要知道目标语言的内部的依赖关系。比如,模型可能要判断主语是什么、谓语是什么、当前句子结构发展到哪里了。这部分信息不来自源句,而来自目标句本身。
所以 decoder self-attention 解决的是:
当前要生成的位置,应该如何利用已经生成的目标端上下文?
它和 encoder self-attention 的公式几乎一样,区别只在于 mask。
对于目标序列表示 \(Y\),有:
\[ Q = YW_Q, \quad K = YW_K, \quad V = YW_V \]
然后计算 masked self-attention:
\[ H = \operatorname{MaskedAttention}(Q, K, V) \]
这里的 \(Q\)、\(K\)、\(V\) 都来自 decoder 当前的目标序列表示,所以它仍然是 self-attention。
但是,只有 masked self-attention 还不够。
如果 fecoder 只看自己已经生成的 token,它就像一个普通语言模型,只能根据前文继续写下去。但机器翻译、摘要生成这类任务不是自由续写,而是要根据输入内容生成输出。因此,fecoder 还必须能动态读取 encoder 的输出。这一步就是 cross-attention。
在 cross-attention 中,我们有:
\[ Q = H_{dec} W_Q, \quad K = H_{enc} W_K, \quad V = H_{enc} W_V \]
然后计算:
\[ \operatorname{CrossAttention}(Q, K, V) = \operatorname{softmax} \left(\frac{QK^\top}{\sqrt{d_k}} \right)V \]
这里的 \(H_{dec}\) 是 decoder 在 masked self-attention 之后得到的表示,\(H_{enc}\) 是 encoder 对输入序列编码后得到的表示。
直观来说,Decoder 是在用自己的当前状态提出一个问题:
我现在要生成下一个目标词,需要从源句里找什么信息?
Cross-attention 会计算 decoder 每个位置和 encoder 每个位置之间的相关性,然后从 encoder 的 value 中取回加权信息。这就是为什么 cross-attention 连接了 encoder 和 decoder。
如果你还记得前面讲过的 Bahdanau-attention,就会发现 cross-attention 和早期 attention 很像。
在早期的 Bahdanau-attention 中,decoder 每生成一个词,都会拿当前 decoder 的隐藏状态去和所有 encoder 的隐藏状态做匹配,然后对 encoder 的隐藏状态做加权求和,得到一个 context vector。
对应到 Transformer 里:
所以,Transformer Decoder 里的 cross-attention 并不是一个完全陌生的新东西。它可以看成是早期 encoder-decoder attention 在 Transformer 结构中的现代版本。
不同的是,Transformer 不再依赖 RNN 逐步生成隐藏状态,而是用 self-attention 和位置编码来建模序列表示。同时,cross-attention 也被写成了标准的 Q、K、V 形式,并且可以通过 multi-head attention 并行计算。
现在我们可以把一个 Transformer Decoder block 拼起来。假设输入到 decoder block 的目标序列表示是 \(Y\),encoder 的输出是 \(H_{enc}\)。
第一步,做 masked self-attention:
\[ Y' = \operatorname{MaskedSelfAttention}(Y) \]
然后加残差连接和 LayerNorm:
\[ Y = \operatorname{LayerNorm}(Y + Y') \]
第二步,做 cross-attention。此时 query 来自 decoder,key 和 value 来自 Encoder:
\[ Y' = \operatorname{CrossAttention}(Y, H_{enc}, H_{enc}) \]
再加残差连接和 LayerNorm:
\[ Y = \operatorname{LayerNorm}(Y + Y') \]
第三步,经过 position-wise feed-forward network:
\[ Y' = \operatorname{FFN}(Y) \]
最后再加残差连接和 LayerNorm:
\[ Y = \operatorname{LayerNorm}(Y + Y') \]
因此,一个 decoder block 可以写成:
和 encoder block 相比,decoder block 多了一个 cross-attention,并且 self-attention 需要 causal mask。
下面我们实现一个简化版 Transformer Decoder block。同样,我们直接使用 8.4 节里实现的 multi-head attention 模块,来构建 decoder block。
class TransformerDecoderLayer(nn.Module):
def __init__(
self,
d_model: int,
num_heads: int,
dim_feedforward: int = 2048,
activation: Callable[[Tensor], Tensor] = F.relu,
bias: bool = True,
dropout: float = 0.1,
):
super().__init__()
self.self_attn = dnn.MultiheadAttention(
d_model,
num_heads,
dropout=dropout,
bias=bias,
)
self.mha_attn = dnn.MultiheadAttention(
d_model,
num_heads,
dropout=dropout,
bias=bias,
)
self.linear1 = nn.Linear(d_model, dim_feedforward, bias=bias)
self.dropout = nn.Dropout(dropout)
self.linear2 = nn.Linear(dim_feedforward, d_model, bias=bias)
self.activation = activation
self.norm1 = nn.LayerNorm(d_model, bias=bias)
self.norm2 = nn.LayerNorm(d_model, bias=bias)
self.norm3 = nn.LayerNorm(d_model, bias=bias)
self.dropout1 = nn.Dropout(dropout)
self.dropout2 = nn.Dropout(dropout)
self.dropout3 = nn.Dropout(dropout)
def forward(
self,
tgt: Tensor,
memory: Tensor,
tgt_mask: Tensor | None = None,
memory_mask: Tensor | None = None,
tgt_key_padding_mask: Tensor | None = None,
memory_key_padding_mask: Tensor | None = None,
) -> Tensor:
x = tgt + self._sa_block(
self.norm1(tgt),
tgt_mask,
tgt_key_padding_mask,
)
x = x + self._mha_block(
self.norm2(x),
memory,
memory_mask,
memory_key_padding_mask,
)
return x + self._ff_block(self.norm3(x))
def _sa_block(
self,
x: Tensor,
attn_mask: Tensor | None,
key_padding_mask: Tensor | None,
) -> Tensor:
x, _ = self.self_attn(
x, x, x,
attn_mask=attn_mask,
key_padding_mask=key_padding_mask,
need_weights=False,
) # fmt: skip
return self.dropout1(x)
def _mha_block(
self,
x: Tensor,
memory: Tensor,
attn_mask: Tensor | None,
key_padding_mask: Tensor | None,
) -> Tensor:
x, _ = self.mha_attn(
x,
memory,
memory,
attn_mask=attn_mask,
key_padding_mask=key_padding_mask,
need_weights=False,
)
return self.dropout2(x)
def _ff_block(self, x: Tensor) -> Tensor:
"""Apply the feed-forward block and dropout."""
x = self.linear1(x)
x = self.activation(x)
x = self.dropout(x)
x = self.linear2(x)
x = self.dropout3(x)
return x
x = torch.randn(2, 16, 512) # (batch_size, seq_len, d_model)
memory = torch.randn(2, 32, 512) # (batch_size, src_seq_len, d_model)
decoder_layer = TransformerDecoderLayer(d_model=512, num_heads=8)
with torch.inference_mode():
output = decoder_layer(x, memory)
print('Decoder Block output shape:', output.shape)Decoder Block output shape: torch.Size([2, 16, 512])
这里的 tgt 是目标序列表示,也就是 decoder 当前处理的序列;memory 是 encoder 的输出,也就是源序列经过 encoder 后得到的上下文表示。tgt_key_padding_mask 和 memory_key_padding_mask 分别是目标序列和源序列的 padding mask。
在 masked self-attention 中:
target -> query
target -> key
target -> value
而在 cross-attention 中:
target -> query
encoder output -> key
encoder output -> value
这正好对应前面讲的:decoder 用自己的状态作为 query,去 encoder 输出里检索相关信息。
和 encoder 一样,我们通常会把多个 decoder block 堆叠起来,形成一个完整的 Transformer Decoder。
class TransformerDecoder(nn.Module):
def __init__(
self,
d_model: int,
num_heads: int,
num_layers: int,
dim_feedforward: int = 2048,
activation: Callable[[Tensor], Tensor] = F.relu,
bias: bool = True,
dropout: float = 0.1,
):
super().__init__()
self.layers = nn.ModuleList(
[
TransformerDecoderLayer(
d_model=d_model,
num_heads=num_heads,
dim_feedforward=dim_feedforward,
activation=activation,
bias=bias,
dropout=dropout,
)
for _ in range(num_layers)
]
)
self.norm = nn.LayerNorm(d_model)
def forward(
self,
tgt: Tensor,
memory: Tensor,
tgt_mask: Tensor | None = None,
memory_mask: Tensor | None = None,
tgt_key_padding_mask: Tensor | None = None,
memory_key_padding_mask: Tensor | None = None,
) -> Tensor:
output = tgt
for layer in self.layers:
output = layer(
output,
memory,
tgt_mask=tgt_mask,
memory_mask=memory_mask,
tgt_key_padding_mask=tgt_key_padding_mask,
memory_key_padding_mask=memory_key_padding_mask,
)
output = self.norm(output)
return output
x = torch.randn(2, 16, 512) # (batch_size, tgt_seq_len, d_model)
memory = torch.randn(2, 32, 512) # (batch_size, src_seq_len, d_model)
decoder = TransformerDecoder(d_model=512, num_heads=8, num_layers=6)
with torch.inference_mode():
output = decoder(x, memory)
print('Decoder output shape:', output.shape)Decoder output shape: torch.Size([2, 16, 512])
需要注意的是,decoder 的输入 tgt 是目标序列的表示,而不是原始的 token id。通常我们会先把目标 token id 通过一个 embedding 层转换成向量表示,然后再输入到 decoder 里。这和 encoder 的输入处理方式是一样的。
讨论完 decoder 的结构,我们再来看一下 causal mask 的生成。下面是一个简单的 causal mask 生成函数:
tensor([[0., -inf, -inf, -inf],
[0., 0., -inf, -inf],
[0., 0., 0., -inf],
[0., 0., 0., 0.]])
它会被加到 attention score 上,使得对应位置的 softmax 权重变成 0。
下面是一个简单的例子,展示了如何把 causal mask 应用到 Transformer Decoder block 中:
x = torch.randn(2, 16, 512) # (batch_size, tgt_seq_len, d_model)
memory = torch.randn(2, 32, 512) # (batch_size, src_seq_len, d_model)
decoder_layer = TransformerDecoderLayer(d_model=512, num_heads=8)
mask = generate_causal_mask(x.size(-2), device=x.device)
with torch.inference_mode():
output = decoder_layer(x, memory, tgt_mask=mask)
print('Decoder Block output shape:', output.shape)Decoder Block output shape: torch.Size([2, 16, 512])
你看,我们把生成的 causal mask 传给了 tgt_mask 参数,这样在 masked self-attention 中就会自动应用这个 mask,保证每个位置只能看到自己和之前的位置。
到这里,decoder 的核心机制就基本完整了。它先用 masked self-attention 处理已经生成的目标序列,再通过 cross-attention 从 encoder output 中读取源序列信息。其中,causal mask 的作用非常关键,因为它让 decoder 在训练时也不能提前看到未来 token。
接下来我们就可以看一个更实际的问题:decoder 到底是怎么用来生成文本的?
训练时,我们通常已经有完整的目标序列。比如目标句子是:
<BOS> 我 喜欢 深度 学习
模型要预测的是:
我 喜欢 深度 学习 <EOS>
虽然训练时目标序列完整可见,但由于 causal mask 的存在,每个位置仍然只能看到自己之前的位置。因此,模型就不会偷看未来答案,而是只能根据已经生成的部分来预测下一个词。
推理时,decoder 会从一个起始 token 开始逐步生成:
<BOS>
先预测第一个 token:
我
然后把它接回输入:
<BOS> 我
再预测下一个 token:
喜欢
这个过程不断重复,直到生成 <EOS> 或达到最大长度。
所以,decoder 的生成过程本质上是 autoregressive 的:
\[ p(y_1, y_2, \dots, y_T \mid x) = \prod_{t=1}^{T} p(y_t \mid y_{<t}, x) \]
这里的 \(x\) 是源序列,\(y_{<t}\) 是已经生成的目标 token。
Masked self-attention 负责建模 \(y_{<t}\),cross-attention 负责利用输入序列 \(x\)。
这一节里,我们把 Transformer Decoder 的核心结构拆开看了一遍。
Decoder 和 encoder 很像,都是由 attention、feed-forward network、残差连接和 LayerNorm 组成。但 decoder 有两个关键特点:第一,它的 self-attention 是 masked self-attention,每个位置只能关注自己和之前的位置,不能看到未来 token;第二,它包含 cross-attention,用 decoder 的表示作为 query,去 encoder 的输出中检索源序列信息。
因此,decoder 同时承担了两件事:一方面,它要像语言模型一样,根据已经生成的目标端上下文继续生成;另一方面,它又要不断参考 encoder 输出,保证生成内容和输入序列对应起来。
到这里,Transformer 的 encoder 和 decoder 都已经讲完了。下一节,我们就可以把这些组件合在一起,看看完整的 Encoder-Decoder Transformer 是如何完成序列到序列建模的。