8.8 Encoder-Decoder Transformer:把 Encoder 和 Decoder 连接起来

Author

jshn9515

Published

2026-05-05

Modified

2026-05-05

前面几节里,我们已经分别看过 Transformer Encoder 和 Transformer Decoder。Encoder 的核心任务是把输入序列编码成一组上下文表示;decoder 的核心任务是在 causal mask 的限制下,根据已有目标前缀继续预测后面的 token。

这一节,我们不再重复拆解 encoder block 或 decoder block 的内部结构,而是把它们连接起来,从整体上看完整的 Encoder-Decoder Transformer 是如何完成序列到序列生成的。

这一节重点回答三个问题:

  1. Encoder 和 decoder 在完整模型里如何连接;
  2. 训练时为什么可以使用真实前缀并行预测;
  3. 推理时为什么必须自回归生成。
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
import dnnl.nn as dnn
import dnnl.nn.functional as dF

print('PyTorch version:', torch.__version__)
PyTorch version: 2.12.0+xpu

8.8.1 从 Seq2Seq 到条件自回归生成

Encoder-Decoder Transformer 适合处理一类任务:给定一个输入序列,生成另一个输出序列。

比如机器翻译:

English: I love deep learning.
Chinese: 我喜欢深度学习。

或者代码生成:

Input: 打印 "Hello, World!"
Output: print("Hello, World!")

这类任务通常叫做 sequence-to-sequence,简称 seq2seq。它的输入和输出都是序列,但二者的长度不一定相同,语言、结构和表达方式也可能完全不同。

我们可以把输入序列写成:

\[ X = (x_1, x_2, \dots, x_m) \]

把目标序列写成:

\[ Y = (y_1, y_2, \dots, y_n) \]

模型要学习的是条件概率:

\[ p(Y \mid X) \]

也就是说,在给定输入序列 \(X\) 的条件下,模型要生成目标序列 \(Y\)

不过,目标序列通常不是一次性生成出来的,而是一步一步预测下一个 token。因此,这个条件概率可以进一步拆成:

\[ p(Y \mid X) = \prod_{t=1}^{n} p(y_t \mid y_{<t}, X) \]

其中 \(y_{<t}\) 表示第 \(t\) 个 token 之前的目标前缀。

这个公式里有两个关键信息来源。第一个是输入序列 \(X\),它告诉模型“要根据什么内容生成”;第二个是目标前缀 \(y_{<t}\),它告诉模型“现在已经生成到哪里了”。

这也正好对应 Encoder-Decoder Transformer 的两个部分:

  • Encoder 负责处理输入序列 \(X\)
  • Decoder 负责根据目标前缀 \(y_{<t}\),并结合 encoder 输出,预测下一个目标 token \(y_t\)

所以,Encoder-Decoder Transformer 的核心不是简单地把 encoder 和 decoder 拼在一起,而是让模型在生成每个 token 时同时利用两类信息:源序列的内容,以及目标端已经生成的前缀

8.8.2 Encoder-Decoder Transformer 的信息流

完整 Encoder-Decoder Transformer 的信息流可以写成:

图 1:Transformer 网络结构图
图 1:Transformer 网络结构图 (Vaswani et al. 2023, fig. 1)

图中左半部分是 encoder,右半部分是 decoder。

Encoder 侧负责处理输入序列。输入 token 先经过 token embedding 和 positional encoding,然后送入多层 encoder stack,得到一组源序列表示:

\[ H_{\mathrm{enc}} = \operatorname{Encoder}(X) \]

其中 \(H_{\mathrm{enc}}\) 可以理解为输入序列中每个位置的上下文表示。

Decoder 侧负责生成目标序列。它接收目标端已经出现的前缀 \(y_{<t}\),先通过 masked self-attention 建模目标前缀内部的关系,再通过 cross-attention 读取 encoder 输出 \(H_{\mathrm{enc}}\)

也就是说,公式里的

\[ p(y_t \mid y_{<t}, X) \]

在模型结构中可以对应成两条信息流:

  • \(X\) 先经过 encoder 变成 \(H_{\mathrm{enc}}\),再通过 decoder 的 cross-attention 进入模型;
  • \(y_{<t}\) 通过 decoder 的 masked self-attention 进入模型。

所以我们可以把完整信息流简单概括为:

Encoder 先读入源序列,得到源序列表示;decoder 一边看目标前缀,一边通过 cross-attention 读取源序列表示,然后预测下一个目标 token。

接下来,我们就来看看训练和推理时这个过程是怎么组织的。

8.8.3 Teacher Forcing:训练时使用真实前缀并行预测

在训练时,我们是有标准答案的。比如机器翻译任务中,输入是源语言句子,目标是人工标注好的目标语言句子。

假设目标序列是:

\[ \text{<BOS>},\ y_1,\ y_2,\ y_3,\ \text{<EOS>} \]

其中,<BOS> 表示句子开始,<EOS> 表示句子结束。

训练 decoder 时,我们通常会把目标序列右移一位作为输入:

Decoder input:  <BOS>   y1    y2    y3
Target  label:  y1      y2    y3    <EOS>

也就是说,模型看到 <BOS> 时,要预测 \(y_1\);看到 <BOS>\(y_1\) 时,要预测 \(y_2\);看到 <BOS>\(y_1\)\(y_2\) 时,要预测 \(y_3\)。这里的关键是,训练时喂给 decoder 的历史 token 来自真实答案,而不是模型自己生成的结果。这就叫 teacher forcing

我们可以把 teacher forcing 想象成老师在旁边纠正学生。训练时,老师不断提供正确的历史 token,让模型始终在正确前缀下学习预测下一个 token。这样做可以让训练更稳定:即使模型一开始预测错了,后面的位置仍然能看到正确前缀,而不会因为前面一步错了,导致后面所有位置都跟着偏掉。

Teacher forcing 还有另一个重要作用:它让训练可以并行进行。乍一看,自回归建模似乎必须一步一步做:

\[ p(y_1 \mid X) \]

\[ p(y_2 \mid y_1, X) \]

\[ p(y_3 \mid y_1, y_2, X) \]

但训练时,我们已经知道完整目标序列,所以可以一次性构造右移后的 decoder input 和对应的 target label,然后让 decoder 同时输出所有位置的预测。

但是,如果我们一次性输入完整目标前缀,模型会不会偷看未来呢?答案是不会。因为 decoder 的 self-attention 里有 causal mask

Causal mask 会保证:

位置 1 只能看到 BOS
位置 2 只能看到 BOS, y1
位置 3 只能看到 BOS, y1, y2
位置 4 只能看到 BOS, y1, y2, y3

所以,训练时有两个事实同时成立:

  • 计算上是并行的:所有位置一起算;
  • 信息上仍然是自回归的:每个位置只能看到自己之前的 token。

训练时 decoder 的输出通常会经过一个线性层映射到词表大小,并对每个位置计算交叉熵损失:

\[ \mathcal{L} = -\sum_{t=1}^{T} \log p(y_t \mid y_{<t}, X) \]

这就是 Transformer 训练高效的关键:

用 teacher forcing 提供真实前缀,用 causal mask 保证不能看未来,从而并行训练所有位置。

8.8.4 自回归生成:推理时使用模型自己生成的前缀

但是,推理时就不一样了。这时候我们没有目标答案,也就没有真实前缀可以喂给 decoder。模型只能从 <BOS> 开始,一步一步生成。

过程大致是:

Step 1 输入: <BOS>
       输出: y1

Step 2 输入: <BOS>, y1
       输出: y2

Step 3 输入: <BOS>, y1, y2
       输出: y3

Step 4 输入: <BOS>, y1, y2, y3
       输出: <EOS>

这里的 \(y_1, y_2, y_3\) 都是模型自己生成的 token。

这就是自回归生成(Autoregressive Generation)。它和训练最大的区别就是,训练时 decoder 看到的前缀来自真实答案,而推理时 decoder 看到的前缀来自模型自己。如果模型第一步生成错了,那么第二步就只能基于这个错误 token 继续生成,导致错误在后续生成过程中不断累积。这种现象通常叫做 Exposure Bias

8.8.5 训练和推理的对比

我们可以把训练和推理放在一起看。

阶段 Decoder 输入 前缀来源 计算方式
训练 右移后的真实目标序列 真实答案 并行
推理 已经生成的 token 模型自己 串行

训练和推理使用的是同一个 Encoder-Decoder Transformer,只是 decoder 的输入来源不同。

训练时,目标序列已经给定,所以可以用右移后的真实目标序列作为前缀,并行预测所有位置;推理时,真实目标序列不存在,模型只能把自己刚生成的 token 拼回输入,再继续预测下一个 token。

8.8.6 Encoder-Decoder 中的 Mask

完整 Encoder-Decoder Transformer 里通常会涉及几种 mask。前面我们已经重点讲过 causal mask,这里只简单梳理它们各自屏蔽什么。

Mask 作用位置 主要作用
Source padding mask Encoder self-attention 屏蔽源序列中的 padding token,防止 encoder 关注这些位置。
Target padding mask Decoder masked self-attention 屏蔽目标序列中的 padding token,防止 decoder 关注这些位置。
Source causal mask Encoder self-attention 通常不需要,因为 encoder 不涉及自回归生成;只有当希望 encoder 也按从左到右的方式处理输入时才会使用。
Target causal mask Decoder masked self-attention 屏蔽目标序列中的未来位置,防止 decoder 在训练时看到未来 token。
Memory key padding mask Decoder cross-attention 屏蔽 encoder 输出中对应源序列 padding token 的位置,防止 decoder 关注这些位置。

这些 mask 的名字在不同框架里可能略有不同,但核心就是两类:

Padding mask:不要看 padding。
Causal mask:不要看未来。

8.8.7 Transformer 的 PyTorch 实现

接下来,我们来完整实现一下 Encoder-Decoder Transformer。我们直接使用 8.6 和 8.7 中实现的 encoder block 和 decoder block,来构建一个完整的 Transformer 模型。

class Transformer(nn.Module):
    def __init__(
        self,
        d_model: int = 512,
        num_heads: int = 8,
        num_encoder_layers: int = 6,
        num_decoder_layers: int = 6,
        dim_feedforward: int = 2048,
        dropout: float = 0.1,
        activation: str | Callable[[Tensor], Tensor] = F.relu,
        layer_norm_eps: float = 1e-5,
        norm_first: bool = False,
        bias: bool = True,
    ):
        super().__init__()
        self.d_model = d_model
        self.num_heads = num_heads

        encoder_layer = dnn.TransformerEncoderLayer(
            d_model,
            num_heads,
            dim_feedforward,
            bias=bias,
            dropout=dropout,
            activation=activation,
            layer_norm_eps=layer_norm_eps,
            norm_first=norm_first,
        )
        encoder_norm = nn.LayerNorm(d_model, eps=layer_norm_eps, bias=bias)
        self.encoder = dnn.TransformerEncoder(
            encoder_layer,
            num_encoder_layers,
            norm=encoder_norm,
        )

        decoder_layer = dnn.TransformerDecoderLayer(
            d_model,
            num_heads,
            dim_feedforward,
            bias=bias,
            dropout=dropout,
            activation=activation,
            layer_norm_eps=layer_norm_eps,
            norm_first=norm_first,
        )
        decoder_norm = nn.LayerNorm(d_model, eps=layer_norm_eps, bias=bias)
        self.decoder = dnn.TransformerDecoder(
            decoder_layer,
            num_decoder_layers,
            norm=decoder_norm,
        )

    def forward(
        self,
        src: Tensor,
        tgt: Tensor,
        src_mask: Tensor | None = None,
        tgt_mask: Tensor | None = None,
        memory_mask: Tensor | None = None,
        src_key_padding_mask: Tensor | None = None,
        tgt_key_padding_mask: Tensor | None = None,
        memory_key_padding_mask: Tensor | None = None,
        src_is_causal: bool = False,
        tgt_is_causal: bool = False,
        memory_is_causal: bool = False,
    ) -> Tensor:
        if src.size(-1) != self.d_model or tgt.size(-1) != self.d_model:
            raise AssertionError(
                'The feature number of `src` and `tgt` must be equal to `d_model`.'
            )
        if src.size(0) != tgt.size(0):
            raise AssertionError('The batch size of `src` and `tgt` must be equal.')

        memory = self.encoder(
            src,
            mask=src_mask,
            src_key_padding_mask=src_key_padding_mask,
            is_causal=src_is_causal,
        )
        output = self.decoder(
            tgt,
            memory,
            tgt_mask=tgt_mask,
            memory_mask=memory_mask,
            tgt_key_padding_mask=tgt_key_padding_mask,
            memory_key_padding_mask=memory_key_padding_mask,
            tgt_is_causal=tgt_is_causal,
            memory_is_causal=memory_is_causal,
        )
        return output


transformer = Transformer(
    d_model=512,
    num_heads=8,
    num_encoder_layers=6,
    num_decoder_layers=6,
    norm_first=True,  # Pre-LN Transformer
)

它内部已经包含 encoder 和 decoder。调用时通常传入:

src = torch.randn(2, 32, 512)  # (batch_size, src_seq_len, d_model)
tgt = torch.randn(2, 16, 512)  # (batch_size, tgt_seq_len, d_model)

output = transformer(src, tgt)
print('Output shape:', output.shape)
Output shape: torch.Size([2, 16, 512])

其中:

  • src 是 encoder 输入;
  • tgt 是 decoder 输入;
  • 输出是 decoder 的 hidden states。

输入和输出的形状(如果默认 batch 是第一维)一般是:

src: (batch_size, src_seq_len, d_model)
tgt: (batch_size, tgt_seq_len, d_model)
out: (batch_size, tgt_seq_len, d_model)

需要注意的是,Transformer 接收的不是 token id,而是已经 embedding 后的向量。

所以完整模型通常还要包括:

  • Source token embedding;
  • Target token embedding;
  • Positional Encoding;
  • Transformer Encoder-Decoder;
  • Output projection。

带上 mask 时,常见调用形式是:

src = torch.randn(2, 32, 512)
tgt = torch.randn(2, 16, 512)
tgt_mask = dF.generate_causal_mask(tgt.size(1), device=tgt.device)
src_key_padding_mask = torch.zeros(2, 32, dtype=torch.bool)
tgt_key_padding_mask = torch.zeros(2, 16, dtype=torch.bool)
memory_key_padding_mask = torch.zeros(2, 32, dtype=torch.bool)

output = transformer(
    src=src,
    tgt=tgt,
    tgt_mask=tgt_mask,  # causal mask
    src_key_padding_mask=src_key_padding_mask,
    tgt_key_padding_mask=tgt_key_padding_mask,
    memory_key_padding_mask=memory_key_padding_mask,
)
print('Output shape:', output.shape)
Output shape: torch.Size([2, 16, 512])

8.8.8 Seq2Seq Transformer 的 PyTorch 实现

下面写一个简化版 Seq2Seq Transformer。这里直接使用前面写好的 Transformer,在它的基础上加上 token embedding、positional encoding 和输出投影层。

class Seq2SeqTransformer(nn.Module):
    def __init__(
        self,
        src_vocab_size: int,
        tgt_vocab_size: int,
        d_model: int = 512,
        num_heads: int = 8,
        num_encoder_layers: int = 6,
        num_decoder_layers: int = 6,
        dim_feedforward: int = 2048,
        dropout: float = 0.1,
        max_len: int = 5000,
    ):
        super().__init__()
        self.d_model = d_model
        self.src_embedding = nn.Embedding(src_vocab_size, d_model)
        self.tgt_embedding = nn.Embedding(tgt_vocab_size, d_model)
        self.pos_encoding = dnn.SinusoidalPositionalEncoding(d_model, max_len)

        self.transformer = Transformer(
            d_model=d_model,
            num_heads=num_heads,
            num_encoder_layers=num_encoder_layers,
            num_decoder_layers=num_decoder_layers,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            norm_first=True,  # Pre-LN Transformer
        )
        self.output_proj = nn.Linear(d_model, tgt_vocab_size)

    def forward(
        self,
        src: Tensor,
        tgt: Tensor,
        src_mask: Tensor | None = None,
        tgt_mask: Tensor | None = None,
        src_key_padding_mask: Tensor | None = None,
        tgt_key_padding_mask: Tensor | None = None,
        memory_key_padding_mask: Tensor | None = None,
    ) -> Tensor:
        # We scale the embeddings by sqrt(d_model) to maintain the variance
        # of the input to the Transformer.
        scale = math.sqrt(self.d_model)
        src_emb = self.src_embedding(src) * scale
        tgt_emb = self.tgt_embedding(tgt) * scale

        src_emb = self.pos_encoding(src_emb)
        tgt_emb = self.pos_encoding(tgt_emb)

        hidden_states = self.transformer(
            src=src_emb,
            tgt=tgt_emb,
            src_mask=src_mask,
            tgt_mask=tgt_mask,
            src_key_padding_mask=src_key_padding_mask,
            tgt_key_padding_mask=tgt_key_padding_mask,
            memory_key_padding_mask=memory_key_padding_mask,
        )

        logits = self.output_proj(hidden_states)
        return logits


src_vocab_size = 100
tgt_vocab_size = 100
src = torch.randint(0, 100, size=(3, 32))
tgt = torch.randint(0, 100, size=(3, 16))
tgt_mask = dF.generate_causal_mask(tgt.size(1))

seq2seq = Seq2SeqTransformer(
    src_vocab_size=src_vocab_size,
    tgt_vocab_size=tgt_vocab_size,
    d_model=512,
    num_heads=8,
    num_encoder_layers=6,
    num_decoder_layers=6,
)

with torch.inference_mode():
    output = seq2seq(src, tgt, tgt_mask=tgt_mask)

print('Output shape:', output.shape)
Output shape: torch.Size([3, 16, 100])

这里的 SinusoidalPositionalEncoding 可以使用 8.5 中实现的版本。

这个模型的输入是 srctgt,它们都是 token ids。模型内部会先把它们映射成 embedding,然后加上 positional encoding,再送入 Transformer。最后,Transformer 的输出会经过一个线性层映射到目标词表大小,得到每个位置的预测 logits。对于每个位置,logits 的形状是 (batch_size, tgt_vocab_size),表示模型对该位置每个 token 的预测分数。

8.8.9 Seq2Seq Transformer 的训练流程

要想训练这个模型,我们需要先构造 decoder 输入和 labels。

假设 target_ids 已经包含了 <BOS><EOS>

bos_token_id = 0
eos_token_id = 5
pad_token_id = -100
src_sentence = ['<BOS>', '我', '爱', '深度', '学习', '<EOS>']
tgt_sentence = ['<BOS>', 'I', 'love', 'deep', 'learning', '<EOS>']
src_ids = torch.tensor([[0, 1, 2, 3, 4, 5]], dtype=torch.long)
tgt_ids = torch.tensor([[0, 1, 2, 3, 4, 5]], dtype=torch.long)

那么我们可以把它切成两部分:

# Decoder inputs: [<BOS>, y1, y2, y3]
tgt_inputs = tgt_ids[:, :-1]

# labels: [y1, y2, y3, <EOS>]
tgt_labels = tgt_ids[:, 1:]

然后生成 causal mask:

tgt_mask = dF.generate_causal_mask(
    tgt_inputs.size(1),
    device=tgt_inputs.device,
)

前向传播:

logits = seq2seq(
    src=src_ids,
    tgt=tgt_inputs,
    tgt_mask=tgt_mask,
)
print('Logits shape:', logits.shape)
Logits shape: torch.Size([1, 5, 100])

然后计算交叉熵损失:

loss = F.cross_entropy(
    logits.reshape(-1, tgt_vocab_size),
    tgt_labels.reshape(-1),
    ignore_index=pad_token_id,
)
print('Loss:', loss.item())
Loss: 4.288424491882324

这里的含义是:对每个 decoder 位置,模型都要根据当前位置之前的真实 token,以及源序列表示,预测当前位置的真实 token。

8.8.10 一个简单的自回归生成示例

推理时,我们从 <BOS> 开始,不断把模型生成的新 token 接到 decoder 输入后面。

@torch.inference_mode()
def generate(
    model: Seq2SeqTransformer,
    src_input_ids: Tensor,
    bos_token_id: int,
    eos_token_id: int,
    pad_token_id: int,
    max_new_tokens: int = 50,
):
    model.eval()

    batch_size = src_input_ids.size(0)
    generated_ids = torch.full(
        (batch_size, 1),
        bos_token_id,
        dtype=torch.long,
        device=src_input_ids.device,
    )

    finished = torch.zeros(
        batch_size,
        dtype=torch.bool,
        device=src_input_ids.device,
    )

    for _ in range(max_new_tokens):
        logits = model(
            src=src_input_ids,
            tgt=generated_ids,
        )

        next_token_logits = logits[:, -1, :]
        next_token_id = next_token_logits.argmax(dim=-1)

        next_token_id = torch.where(
            finished,
            torch.full_like(next_token_id, pad_token_id),
            next_token_id,
        )

        generated_ids = torch.concat(
            [generated_ids, next_token_id.unsqueeze(1)],
            dim=1,
        )

        finished = finished | (next_token_id == eos_token_id)
        if finished.all():
            break

    return generated_ids


generated_ids = generate(
    model=seq2seq,
    src_input_ids=src_ids,
    bos_token_id=bos_token_id,
    eos_token_id=eos_token_id,
    pad_token_id=pad_token_id,
    max_new_tokens=10,
)
print('Generated IDs:', generated_ids)
Generated IDs: tensor([[ 0, 66, 49, 79, 39, 59, 72, 10, 64, 74, 28]])

这里用了最简单的 greedy decoding,也就是每一步都选择概率最高的 token:

\[ \hat{y}_t = \arg\max_y p(y \mid \hat{y}_{<t}, X) \]

实际中还可以使用 beam search、top-k sampling、top-p sampling 等方法。它们的区别主要在于:每一步如何从模型给出的概率分布里选择下一个 token。

不过无论选择策略怎么变,自回归生成的基本结构都是一样的:

生成一个 token,把它接回 Decoder 输入,再生成下一个 token。

8.8.11 本章小结

这一节里,我们把 encoder 和 decoder 放到完整的 seq2seq 任务中,理解 Encoder-Decoder Transformer 是如何工作的。

从概率角度看,模型学习的是:

\[ p(Y \mid X) = \prod_{t=1}^{n} p(y_t \mid y_{<t}, X) \]

其中,\(X\) 由 encoder 处理,\(y_{<t}\) 由 decoder 的目标前缀提供。Decoder 通过 masked self-attention 读取目标前缀,通过 cross-attention 读取 encoder 输出。

训练时,我们使用 teacher forcing,把目标序列右移一位作为 decoder input,让模型在真实前缀条件下预测下一个 token。借助 causal mask,Transformer 可以并行计算所有位置的预测,同时保证每个位置不能看到未来 token。

推理时,没有真实答案可用,decoder 只能从 <BOS> 开始自回归生成:每一步预测一个 token,再把这个 token 接回输入,继续预测下一个 token。

所以,Encoder-Decoder Transformer 的完整逻辑可以概括为:

Encoder 先理解输入,decoder 再在目标前缀和输入表示的共同条件下,一步一步生成输出。

下一节,我们继续看一个和推理效率密切相关的问题:既然自回归生成必须一步一步进行,为什么实际推理时不需要每一步都重复计算过去 token 的 key 和 value?

这就引出了 KV Cache。

References

Vaswani, Ashish, Noam Shazeer, Niki Parmar, et al. 2023. Attention Is All You Need. https://arxiv.org/abs/1706.03762.

Reuse