import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
import dnnl.nn as dnn
import dnnl.nn.functional as dF
print('PyTorch version:', torch.__version__)PyTorch version: 2.12.0+xpu
jshn9515
2026-05-05
2026-05-05
前面几节里,我们已经分别看过 Transformer Encoder 和 Transformer Decoder。Encoder 的核心任务是把输入序列编码成一组上下文表示;decoder 的核心任务是在 causal mask 的限制下,根据已有目标前缀继续预测后面的 token。
这一节,我们不再重复拆解 encoder block 或 decoder block 的内部结构,而是把它们连接起来,从整体上看完整的 Encoder-Decoder Transformer 是如何完成序列到序列生成的。
这一节重点回答三个问题:
PyTorch version: 2.12.0+xpu
Encoder-Decoder Transformer 适合处理一类任务:给定一个输入序列,生成另一个输出序列。
比如机器翻译:
English: I love deep learning.
Chinese: 我喜欢深度学习。
或者代码生成:
Input: 打印 "Hello, World!"
Output: print("Hello, World!")
这类任务通常叫做 sequence-to-sequence,简称 seq2seq。它的输入和输出都是序列,但二者的长度不一定相同,语言、结构和表达方式也可能完全不同。
我们可以把输入序列写成:
\[ X = (x_1, x_2, \dots, x_m) \]
把目标序列写成:
\[ Y = (y_1, y_2, \dots, y_n) \]
模型要学习的是条件概率:
\[ p(Y \mid X) \]
也就是说,在给定输入序列 \(X\) 的条件下,模型要生成目标序列 \(Y\)。
不过,目标序列通常不是一次性生成出来的,而是一步一步预测下一个 token。因此,这个条件概率可以进一步拆成:
\[ p(Y \mid X) = \prod_{t=1}^{n} p(y_t \mid y_{<t}, X) \]
其中 \(y_{<t}\) 表示第 \(t\) 个 token 之前的目标前缀。
这个公式里有两个关键信息来源。第一个是输入序列 \(X\),它告诉模型“要根据什么内容生成”;第二个是目标前缀 \(y_{<t}\),它告诉模型“现在已经生成到哪里了”。
这也正好对应 Encoder-Decoder Transformer 的两个部分:
所以,Encoder-Decoder Transformer 的核心不是简单地把 encoder 和 decoder 拼在一起,而是让模型在生成每个 token 时同时利用两类信息:源序列的内容,以及目标端已经生成的前缀。
完整 Encoder-Decoder Transformer 的信息流可以写成:
图中左半部分是 encoder,右半部分是 decoder。
Encoder 侧负责处理输入序列。输入 token 先经过 token embedding 和 positional encoding,然后送入多层 encoder stack,得到一组源序列表示:
\[ H_{\mathrm{enc}} = \operatorname{Encoder}(X) \]
其中 \(H_{\mathrm{enc}}\) 可以理解为输入序列中每个位置的上下文表示。
Decoder 侧负责生成目标序列。它接收目标端已经出现的前缀 \(y_{<t}\),先通过 masked self-attention 建模目标前缀内部的关系,再通过 cross-attention 读取 encoder 输出 \(H_{\mathrm{enc}}\)。
也就是说,公式里的
\[ p(y_t \mid y_{<t}, X) \]
在模型结构中可以对应成两条信息流:
所以我们可以把完整信息流简单概括为:
Encoder 先读入源序列,得到源序列表示;decoder 一边看目标前缀,一边通过 cross-attention 读取源序列表示,然后预测下一个目标 token。
接下来,我们就来看看训练和推理时这个过程是怎么组织的。
在训练时,我们是有标准答案的。比如机器翻译任务中,输入是源语言句子,目标是人工标注好的目标语言句子。
假设目标序列是:
\[ \text{<BOS>},\ y_1,\ y_2,\ y_3,\ \text{<EOS>} \]
其中,<BOS> 表示句子开始,<EOS> 表示句子结束。
训练 decoder 时,我们通常会把目标序列右移一位作为输入:
Decoder input: <BOS> y1 y2 y3
Target label: y1 y2 y3 <EOS>
也就是说,模型看到 <BOS> 时,要预测 \(y_1\);看到 <BOS>,\(y_1\) 时,要预测 \(y_2\);看到 <BOS>,\(y_1\),\(y_2\) 时,要预测 \(y_3\)。这里的关键是,训练时喂给 decoder 的历史 token 来自真实答案,而不是模型自己生成的结果。这就叫 teacher forcing。
我们可以把 teacher forcing 想象成老师在旁边纠正学生。训练时,老师不断提供正确的历史 token,让模型始终在正确前缀下学习预测下一个 token。这样做可以让训练更稳定:即使模型一开始预测错了,后面的位置仍然能看到正确前缀,而不会因为前面一步错了,导致后面所有位置都跟着偏掉。
Teacher forcing 还有另一个重要作用:它让训练可以并行进行。乍一看,自回归建模似乎必须一步一步做:
\[ p(y_1 \mid X) \]
\[ p(y_2 \mid y_1, X) \]
\[ p(y_3 \mid y_1, y_2, X) \]
但训练时,我们已经知道完整目标序列,所以可以一次性构造右移后的 decoder input 和对应的 target label,然后让 decoder 同时输出所有位置的预测。
但是,如果我们一次性输入完整目标前缀,模型会不会偷看未来呢?答案是不会。因为 decoder 的 self-attention 里有 causal mask。
Causal mask 会保证:
位置 1 只能看到 BOS
位置 2 只能看到 BOS, y1
位置 3 只能看到 BOS, y1, y2
位置 4 只能看到 BOS, y1, y2, y3
所以,训练时有两个事实同时成立:
训练时 decoder 的输出通常会经过一个线性层映射到词表大小,并对每个位置计算交叉熵损失:
\[ \mathcal{L} = -\sum_{t=1}^{T} \log p(y_t \mid y_{<t}, X) \]
这就是 Transformer 训练高效的关键:
用 teacher forcing 提供真实前缀,用 causal mask 保证不能看未来,从而并行训练所有位置。
但是,推理时就不一样了。这时候我们没有目标答案,也就没有真实前缀可以喂给 decoder。模型只能从 <BOS> 开始,一步一步生成。
过程大致是:
Step 1 输入: <BOS>
输出: y1
Step 2 输入: <BOS>, y1
输出: y2
Step 3 输入: <BOS>, y1, y2
输出: y3
Step 4 输入: <BOS>, y1, y2, y3
输出: <EOS>
这里的 \(y_1, y_2, y_3\) 都是模型自己生成的 token。
这就是自回归生成(Autoregressive Generation)。它和训练最大的区别就是,训练时 decoder 看到的前缀来自真实答案,而推理时 decoder 看到的前缀来自模型自己。如果模型第一步生成错了,那么第二步就只能基于这个错误 token 继续生成,导致错误在后续生成过程中不断累积。这种现象通常叫做 Exposure Bias。
我们可以把训练和推理放在一起看。
| 阶段 | Decoder 输入 | 前缀来源 | 计算方式 |
|---|---|---|---|
| 训练 | 右移后的真实目标序列 | 真实答案 | 并行 |
| 推理 | 已经生成的 token | 模型自己 | 串行 |
训练和推理使用的是同一个 Encoder-Decoder Transformer,只是 decoder 的输入来源不同。
训练时,目标序列已经给定,所以可以用右移后的真实目标序列作为前缀,并行预测所有位置;推理时,真实目标序列不存在,模型只能把自己刚生成的 token 拼回输入,再继续预测下一个 token。
完整 Encoder-Decoder Transformer 里通常会涉及几种 mask。前面我们已经重点讲过 causal mask,这里只简单梳理它们各自屏蔽什么。
| Mask | 作用位置 | 主要作用 |
|---|---|---|
| Source padding mask | Encoder self-attention | 屏蔽源序列中的 padding token,防止 encoder 关注这些位置。 |
| Target padding mask | Decoder masked self-attention | 屏蔽目标序列中的 padding token,防止 decoder 关注这些位置。 |
| Source causal mask | Encoder self-attention | 通常不需要,因为 encoder 不涉及自回归生成;只有当希望 encoder 也按从左到右的方式处理输入时才会使用。 |
| Target causal mask | Decoder masked self-attention | 屏蔽目标序列中的未来位置,防止 decoder 在训练时看到未来 token。 |
| Memory key padding mask | Decoder cross-attention | 屏蔽 encoder 输出中对应源序列 padding token 的位置,防止 decoder 关注这些位置。 |
这些 mask 的名字在不同框架里可能略有不同,但核心就是两类:
Padding mask:不要看 padding。
Causal mask:不要看未来。
接下来,我们来完整实现一下 Encoder-Decoder Transformer。我们直接使用 8.6 和 8.7 中实现的 encoder block 和 decoder block,来构建一个完整的 Transformer 模型。
class Transformer(nn.Module):
def __init__(
self,
d_model: int = 512,
num_heads: int = 8,
num_encoder_layers: int = 6,
num_decoder_layers: int = 6,
dim_feedforward: int = 2048,
dropout: float = 0.1,
activation: str | Callable[[Tensor], Tensor] = F.relu,
layer_norm_eps: float = 1e-5,
norm_first: bool = False,
bias: bool = True,
):
super().__init__()
self.d_model = d_model
self.num_heads = num_heads
encoder_layer = dnn.TransformerEncoderLayer(
d_model,
num_heads,
dim_feedforward,
bias=bias,
dropout=dropout,
activation=activation,
layer_norm_eps=layer_norm_eps,
norm_first=norm_first,
)
encoder_norm = nn.LayerNorm(d_model, eps=layer_norm_eps, bias=bias)
self.encoder = dnn.TransformerEncoder(
encoder_layer,
num_encoder_layers,
norm=encoder_norm,
)
decoder_layer = dnn.TransformerDecoderLayer(
d_model,
num_heads,
dim_feedforward,
bias=bias,
dropout=dropout,
activation=activation,
layer_norm_eps=layer_norm_eps,
norm_first=norm_first,
)
decoder_norm = nn.LayerNorm(d_model, eps=layer_norm_eps, bias=bias)
self.decoder = dnn.TransformerDecoder(
decoder_layer,
num_decoder_layers,
norm=decoder_norm,
)
def forward(
self,
src: Tensor,
tgt: Tensor,
src_mask: Tensor | None = None,
tgt_mask: Tensor | None = None,
memory_mask: Tensor | None = None,
src_key_padding_mask: Tensor | None = None,
tgt_key_padding_mask: Tensor | None = None,
memory_key_padding_mask: Tensor | None = None,
src_is_causal: bool = False,
tgt_is_causal: bool = False,
memory_is_causal: bool = False,
) -> Tensor:
if src.size(-1) != self.d_model or tgt.size(-1) != self.d_model:
raise AssertionError(
'The feature number of `src` and `tgt` must be equal to `d_model`.'
)
if src.size(0) != tgt.size(0):
raise AssertionError('The batch size of `src` and `tgt` must be equal.')
memory = self.encoder(
src,
mask=src_mask,
src_key_padding_mask=src_key_padding_mask,
is_causal=src_is_causal,
)
output = self.decoder(
tgt,
memory,
tgt_mask=tgt_mask,
memory_mask=memory_mask,
tgt_key_padding_mask=tgt_key_padding_mask,
memory_key_padding_mask=memory_key_padding_mask,
tgt_is_causal=tgt_is_causal,
memory_is_causal=memory_is_causal,
)
return output
transformer = Transformer(
d_model=512,
num_heads=8,
num_encoder_layers=6,
num_decoder_layers=6,
norm_first=True, # Pre-LN Transformer
)它内部已经包含 encoder 和 decoder。调用时通常传入:
Output shape: torch.Size([2, 16, 512])
其中:
src 是 encoder 输入;tgt 是 decoder 输入;输入和输出的形状(如果默认 batch 是第一维)一般是:
src: (batch_size, src_seq_len, d_model)
tgt: (batch_size, tgt_seq_len, d_model)
out: (batch_size, tgt_seq_len, d_model)
需要注意的是,Transformer 接收的不是 token id,而是已经 embedding 后的向量。
所以完整模型通常还要包括:
带上 mask 时,常见调用形式是:
src = torch.randn(2, 32, 512)
tgt = torch.randn(2, 16, 512)
tgt_mask = dF.generate_causal_mask(tgt.size(1), device=tgt.device)
src_key_padding_mask = torch.zeros(2, 32, dtype=torch.bool)
tgt_key_padding_mask = torch.zeros(2, 16, dtype=torch.bool)
memory_key_padding_mask = torch.zeros(2, 32, dtype=torch.bool)
output = transformer(
src=src,
tgt=tgt,
tgt_mask=tgt_mask, # causal mask
src_key_padding_mask=src_key_padding_mask,
tgt_key_padding_mask=tgt_key_padding_mask,
memory_key_padding_mask=memory_key_padding_mask,
)
print('Output shape:', output.shape)Output shape: torch.Size([2, 16, 512])
下面写一个简化版 Seq2Seq Transformer。这里直接使用前面写好的 Transformer,在它的基础上加上 token embedding、positional encoding 和输出投影层。
class Seq2SeqTransformer(nn.Module):
def __init__(
self,
src_vocab_size: int,
tgt_vocab_size: int,
d_model: int = 512,
num_heads: int = 8,
num_encoder_layers: int = 6,
num_decoder_layers: int = 6,
dim_feedforward: int = 2048,
dropout: float = 0.1,
max_len: int = 5000,
):
super().__init__()
self.d_model = d_model
self.src_embedding = nn.Embedding(src_vocab_size, d_model)
self.tgt_embedding = nn.Embedding(tgt_vocab_size, d_model)
self.pos_encoding = dnn.SinusoidalPositionalEncoding(d_model, max_len)
self.transformer = Transformer(
d_model=d_model,
num_heads=num_heads,
num_encoder_layers=num_encoder_layers,
num_decoder_layers=num_decoder_layers,
dim_feedforward=dim_feedforward,
dropout=dropout,
norm_first=True, # Pre-LN Transformer
)
self.output_proj = nn.Linear(d_model, tgt_vocab_size)
def forward(
self,
src: Tensor,
tgt: Tensor,
src_mask: Tensor | None = None,
tgt_mask: Tensor | None = None,
src_key_padding_mask: Tensor | None = None,
tgt_key_padding_mask: Tensor | None = None,
memory_key_padding_mask: Tensor | None = None,
) -> Tensor:
# We scale the embeddings by sqrt(d_model) to maintain the variance
# of the input to the Transformer.
scale = math.sqrt(self.d_model)
src_emb = self.src_embedding(src) * scale
tgt_emb = self.tgt_embedding(tgt) * scale
src_emb = self.pos_encoding(src_emb)
tgt_emb = self.pos_encoding(tgt_emb)
hidden_states = self.transformer(
src=src_emb,
tgt=tgt_emb,
src_mask=src_mask,
tgt_mask=tgt_mask,
src_key_padding_mask=src_key_padding_mask,
tgt_key_padding_mask=tgt_key_padding_mask,
memory_key_padding_mask=memory_key_padding_mask,
)
logits = self.output_proj(hidden_states)
return logits
src_vocab_size = 100
tgt_vocab_size = 100
src = torch.randint(0, 100, size=(3, 32))
tgt = torch.randint(0, 100, size=(3, 16))
tgt_mask = dF.generate_causal_mask(tgt.size(1))
seq2seq = Seq2SeqTransformer(
src_vocab_size=src_vocab_size,
tgt_vocab_size=tgt_vocab_size,
d_model=512,
num_heads=8,
num_encoder_layers=6,
num_decoder_layers=6,
)
with torch.inference_mode():
output = seq2seq(src, tgt, tgt_mask=tgt_mask)
print('Output shape:', output.shape)Output shape: torch.Size([3, 16, 100])
这里的 SinusoidalPositionalEncoding 可以使用 8.5 中实现的版本。
这个模型的输入是 src 和 tgt,它们都是 token ids。模型内部会先把它们映射成 embedding,然后加上 positional encoding,再送入 Transformer。最后,Transformer 的输出会经过一个线性层映射到目标词表大小,得到每个位置的预测 logits。对于每个位置,logits 的形状是 (batch_size, tgt_vocab_size),表示模型对该位置每个 token 的预测分数。
要想训练这个模型,我们需要先构造 decoder 输入和 labels。
假设 target_ids 已经包含了 <BOS> 和 <EOS>:
bos_token_id = 0
eos_token_id = 5
pad_token_id = -100
src_sentence = ['<BOS>', '我', '爱', '深度', '学习', '<EOS>']
tgt_sentence = ['<BOS>', 'I', 'love', 'deep', 'learning', '<EOS>']
src_ids = torch.tensor([[0, 1, 2, 3, 4, 5]], dtype=torch.long)
tgt_ids = torch.tensor([[0, 1, 2, 3, 4, 5]], dtype=torch.long)那么我们可以把它切成两部分:
然后生成 causal mask:
前向传播:
Logits shape: torch.Size([1, 5, 100])
然后计算交叉熵损失:
Loss: 4.288424491882324
这里的含义是:对每个 decoder 位置,模型都要根据当前位置之前的真实 token,以及源序列表示,预测当前位置的真实 token。
推理时,我们从 <BOS> 开始,不断把模型生成的新 token 接到 decoder 输入后面。
@torch.inference_mode()
def generate(
model: Seq2SeqTransformer,
src_input_ids: Tensor,
bos_token_id: int,
eos_token_id: int,
pad_token_id: int,
max_new_tokens: int = 50,
):
model.eval()
batch_size = src_input_ids.size(0)
generated_ids = torch.full(
(batch_size, 1),
bos_token_id,
dtype=torch.long,
device=src_input_ids.device,
)
finished = torch.zeros(
batch_size,
dtype=torch.bool,
device=src_input_ids.device,
)
for _ in range(max_new_tokens):
logits = model(
src=src_input_ids,
tgt=generated_ids,
)
next_token_logits = logits[:, -1, :]
next_token_id = next_token_logits.argmax(dim=-1)
next_token_id = torch.where(
finished,
torch.full_like(next_token_id, pad_token_id),
next_token_id,
)
generated_ids = torch.concat(
[generated_ids, next_token_id.unsqueeze(1)],
dim=1,
)
finished = finished | (next_token_id == eos_token_id)
if finished.all():
break
return generated_ids
generated_ids = generate(
model=seq2seq,
src_input_ids=src_ids,
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
pad_token_id=pad_token_id,
max_new_tokens=10,
)
print('Generated IDs:', generated_ids)Generated IDs: tensor([[ 0, 66, 49, 79, 39, 59, 72, 10, 64, 74, 28]])
这里用了最简单的 greedy decoding,也就是每一步都选择概率最高的 token:
\[ \hat{y}_t = \arg\max_y p(y \mid \hat{y}_{<t}, X) \]
实际中还可以使用 beam search、top-k sampling、top-p sampling 等方法。它们的区别主要在于:每一步如何从模型给出的概率分布里选择下一个 token。
不过无论选择策略怎么变,自回归生成的基本结构都是一样的:
生成一个 token,把它接回 Decoder 输入,再生成下一个 token。
这一节里,我们把 encoder 和 decoder 放到完整的 seq2seq 任务中,理解 Encoder-Decoder Transformer 是如何工作的。
从概率角度看,模型学习的是:
\[ p(Y \mid X) = \prod_{t=1}^{n} p(y_t \mid y_{<t}, X) \]
其中,\(X\) 由 encoder 处理,\(y_{<t}\) 由 decoder 的目标前缀提供。Decoder 通过 masked self-attention 读取目标前缀,通过 cross-attention 读取 encoder 输出。
训练时,我们使用 teacher forcing,把目标序列右移一位作为 decoder input,让模型在真实前缀条件下预测下一个 token。借助 causal mask,Transformer 可以并行计算所有位置的预测,同时保证每个位置不能看到未来 token。
推理时,没有真实答案可用,decoder 只能从 <BOS> 开始自回归生成:每一步预测一个 token,再把这个 token 接回输入,继续预测下一个 token。
所以,Encoder-Decoder Transformer 的完整逻辑可以概括为:
Encoder 先理解输入,decoder 再在目标前缀和输入表示的共同条件下,一步一步生成输出。
下一节,我们继续看一个和推理效率密切相关的问题:既然自回归生成必须一步一步进行,为什么实际推理时不需要每一步都重复计算过去 token 的 key 和 value?
这就引出了 KV Cache。