8.8 Encoder-Decoder Transformer: Connecting Encoder and Decoder

Author

jshn9515

Published

2026-05-05

Modified

2026-05-05

In the previous sections, we looked at the Transformer Encoder and the Transformer Decoder separately. The core task of the Encoder is to encode the input sequence into a set of contextual representations. The core task of the Decoder is to predict subsequent tokens from the existing target prefix, under the constraint of the causal mask.

In this section, we will not break down the internal structure of the Encoder Block or Decoder Block again. Instead, we connect them and look at how the complete Encoder-Decoder Transformer performs sequence-to-sequence generation as a whole.

This section focuses on answering three questions:

  1. How the Encoder and Decoder are connected in the complete model;
  2. Why real prefixes can be used for parallel prediction during training;
  3. Why inference must use autoregressive generation.
import math
from collections.abc import Callable

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
import dnnl.nn as dnn
import dnnl.nn.functional as dF

print('PyTorch version:', torch.__version__)
PyTorch version: 2.12.0+xpu

8.8.1 From Seq2Seq to Conditional Autoregressive Generation

The Encoder-Decoder Transformer is suited to one type of task: given an input sequence, generate another output sequence.

For example, machine translation:

English: I love deep learning.
Chinese: 我喜欢深度学习。

Or code generation:

Input: 打印 "Hello, World!"   (打印 is Chinese for "print")
Output: print("Hello, World!")

This type of task is usually called Sequence-to-Sequence, abbreviated as Seq2Seq. Both its input and output are sequences, but their lengths are not necessarily the same, and their languages, structures, and forms of expression may also be completely different.

We can write the input sequence as:

\[ X = (x_1, x_2, \dots, x_m) \]

Write the target sequence as:

\[ Y = (y_1, y_2, \dots, y_n) \]

What the model needs to learn is the conditional probability:

\[ p(Y \mid X) \]

That is, given the input sequence \(X\), the model needs to generate the target sequence \(Y\).

However, the target sequence is usually not generated all at once; it is predicted one token at a time. Therefore, by the chain rule of probability, this conditional probability can be decomposed as:

\[ p(Y \mid X) = \prod_{t=1}^{n} p(y_t \mid y_{<t}, X) \]

where \(y_{<t} = (y_1, \dots, y_{t-1})\) denotes the target prefix before the \(t\)-th token (empty when \(t = 1\)).
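A tiny numeric sketch of this factorization, using made-up per-step probabilities (the numbers are hypothetical, not produced by any model). The sequence probability is the product of the per-step terms. Summing log-probabilities gives the same result in a form that is numerically safer for long sequences.

```python
import math

# Hypothetical per-step conditional probabilities p(y_t | y_<t, X) for a 3-token target
step_probs = [0.9, 0.6, 0.8]

# Chain rule: p(Y | X) is the product of the per-step conditionals
p_sequence = math.prod(step_probs)

# Log form: log p(Y | X) = sum_t log p(y_t | y_<t, X), which avoids underflow for long sequences
log_p_sequence = sum(math.log(p) for p in step_probs)

print(f"p(Y|X)           = {p_sequence:.3f}")  # 0.9 * 0.6 * 0.8 = 0.432
print(f"exp(sum of logs) = {math.exp(log_p_sequence):.3f}")
```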

There are two key sources of information in this formula. The first is the input sequence \(X\), which tells the model “what content to generate based on”; the second is the target prefix \(y_{<t}\), which tells the model “where the generation has reached now”.

This also corresponds exactly to the two parts of the Encoder-Decoder Transformer:

  • The Encoder is responsible for processing the input sequence \(X\);
  • The Decoder is responsible for predicting the next target token \(y_t\) based on the target prefix \(y_{<t}\) and the Encoder output.

So, the core of the Encoder-Decoder Transformer is not simply putting the Encoder and Decoder together, but allowing the model to use two types of information at the same time when generating each token: the content of the source sequence, and the prefix already generated on the target side.

8.8.2 Information Flow in the Encoder-Decoder Transformer

The information flow of a complete Encoder-Decoder Transformer is shown in Figure 1:

Figure 1: Transformer network architecture diagram (Vaswani et al. 2023, fig. 1)

The left half of the figure is the Encoder, and the right half is the Decoder.

The Encoder side is responsible for processing the input sequence. The input tokens first go through token embedding and positional encoding, and are then sent into a multi-layer Encoder Stack to obtain a set of source sequence representations:

\[ H_{\mathrm{enc}} = \operatorname{Encoder}(X) \]

where \(H_{\mathrm{enc}}\) can be understood as the contextual representation of each position in the input sequence.

The Decoder side is responsible for generating the target sequence. It receives the prefix \(y_{<t}\) that has already appeared on the target side, first uses masked self-attention to model the internal relationships inside the target prefix, and then reads the Encoder output \(H_{\mathrm{enc}}\) through cross-attention.

That is, the formula

\[ p(y_t \mid y_{<t}, X) \]

can correspond to two information flows in the model structure:

  • \(X\) first goes through the Encoder and becomes \(H_{\mathrm{enc}}\), and then enters the model through the Decoder’s cross-attention;
  • \(y_{<t}\) enters the model through the Decoder’s masked self-attention.

So we can summarize the complete information flow simply as:

The Encoder first reads the source sequence and obtains source sequence representations; the Decoder looks at the target prefix while reading the source sequence representations through cross-attention, and then predicts the next target token.
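To make the two information flows concrete, here is a minimal sketch that uses PyTorch's nn.MultiheadAttention as the cross-attention layer. Random tensors stand in for \(H_{\mathrm{enc}}\) and for the Decoder's intermediate states; both are placeholders, not outputs of a trained model. The point is the shapes: queries come from the target side, while keys and values come from the Encoder output.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
batch, src_len, tgt_len, d_model = 2, 7, 4, 32

h_enc = torch.randn(batch, src_len, d_model)  # stand-in for H_enc = Encoder(X)
h_dec = torch.randn(batch, tgt_len, d_model)  # stand-in for Decoder states after masked self-attention

cross_attn = nn.MultiheadAttention(d_model, num_heads=4, batch_first=True)

# Queries come from the target prefix; keys and values come from the Encoder output
out, weights = cross_attn(query=h_dec, key=h_enc, value=h_enc)

print(out.shape)      # torch.Size([2, 4, 32]): one vector per target position
print(weights.shape)  # torch.Size([2, 4, 7]): each target position attends over all source positions
```

The output length follows the target side (4), while the attention weights span the source side (7). This is how source information reaches every target position.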

Next, let us look at how this process is organized during training and inference.

8.8.3 Teacher Forcing: Using Real Prefixes for Parallel Prediction During Training

During training, we have the reference answer. For example, in a machine translation task, the input is a source-language sentence, and the target is a human-annotated target-language sentence.

Suppose the target sequence is:

\[ \text{<BOS>},\ y_1,\ y_2,\ y_3,\ \text{<EOS>} \]

where <BOS> represents the beginning of the sentence, and <EOS> represents the end of the sentence.

When training the Decoder, we usually shift the target sequence one position to the right as the input:

Decoder input:  <BOS>   y1    y2    y3
Target  label:  y1      y2    y3    <EOS>

That is, when the model sees <BOS>, it needs to predict \(y_1\); when it sees <BOS>, \(y_1\), it needs to predict \(y_2\); when it sees <BOS>, \(y_1\), \(y_2\), it needs to predict \(y_3\). The key here is that the historical tokens fed to the Decoder during training come from the real answer, not from the model’s own generated results. This is called teacher forcing.
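The shift can be sketched with plain Python lists, using the placeholder token names from the example above. Each position pairs a real prefix with the next real token as its label:

```python
# Target sequence with special tokens, as in the example above
target = ["<BOS>", "y1", "y2", "y3", "<EOS>"]

decoder_input = target[:-1]  # shift right: drop the last token
labels = target[1:]          # drop <BOS>: each label is the next token

# Position t sees the real prefix decoder_input[:t+1] and must predict labels[t]
pairs = [(decoder_input[: t + 1], labels[t]) for t in range(len(labels))]
for prefix, label in pairs:
    print(prefix, "->", label)
```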

We can picture teacher forcing as a teacher standing beside a student and correcting them. During training, the teacher keeps supplying the correct history tokens, so the model always learns to predict the next token under the correct prefix. This makes training more stable: even if the model predicts an early token incorrectly, later positions still see the correct prefix, so a single early mistake does not throw off the rest of the sequence.

Teacher forcing has another important benefit: it allows training to be done in parallel. At first glance, autoregressive modeling seems to require step-by-step computation:

\[ p(y_1 \mid X) \]

\[ p(y_2 \mid y_1, X) \]

\[ p(y_3 \mid y_1, y_2, X) \]

But during training, we already know the complete target sequence, so we can construct the shifted decoder input and the corresponding target label all at once, and then let the Decoder output the predictions for all positions at the same time.

But if we feed in the complete target prefix all at once, can the model peek at the future? No, because the Decoder's self-attention applies a causal mask.

The causal mask guarantees:

Position 1 can only see BOS
Position 2 can only see BOS, y1
Position 3 can only see BOS, y1, y2
Position 4 can only see BOS, y1, y2, y3
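The same pattern can be written as a lower-triangular boolean matrix, sketched here in PyTorch. In this sketch True means "may attend". Be aware that conventions differ across APIs: nn.MultiheadAttention treats True in a boolean attn_mask as "masked out", whereas F.scaled_dot_product_attention treats True as "allowed".

```python
import torch

tokens = ["<BOS>", "y1", "y2", "y3"]
n = len(tokens)

# True where query position i may attend to key position j (j <= i); the upper triangle is the future
allowed = torch.tril(torch.ones(n, n, dtype=torch.bool))

for i, row in enumerate(allowed):
    visible = [tok for tok, ok in zip(tokens, row.tolist()) if ok]
    print(f"Position {i + 1} can only see {visible}")
```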

So, during training, two facts are true at the same time:

  • Computationally, it is parallel: all positions are computed together;
  • Informationally, it is still autoregressive: each position can only see the tokens before itself.
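One way to check both facts is to compare one parallel pass with a step-by-step loop. The sketch below uses PyTorch's built-in nn.TransformerDecoder rather than the dnn classes used elsewhere in this section. It feeds random inputs and runs in eval mode so that dropout is off. The output at position t from the full, causally masked pass should match the last output when the Decoder receives only the first t tokens.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
d_model, n_heads, tgt_len, src_len = 32, 4, 6, 7

layer = nn.TransformerDecoderLayer(d_model, n_heads, dim_feedforward=64, batch_first=True)
decoder = nn.TransformerDecoder(layer, num_layers=2).eval()  # eval(): disable dropout

memory = torch.randn(1, src_len, d_model)  # stand-in for the Encoder output H_enc
tgt = torch.randn(1, tgt_len, d_model)     # stand-in for embedded <BOS>, y1, ..., y5

with torch.no_grad():
    # Parallel: all positions in one call, causal mask hides the future
    full_mask = nn.Transformer.generate_square_subsequent_mask(tgt_len)
    parallel = decoder(tgt, memory, tgt_mask=full_mask)

    # Serial: at step t the Decoder only receives the first t tokens
    for t in range(1, tgt_len + 1):
        mask = nn.Transformer.generate_square_subsequent_mask(t)
        step = decoder(tgt[:, :t], memory, tgt_mask=mask)
        assert torch.allclose(parallel[:, t - 1], step[:, -1], atol=1e-5)

print("parallel and step-by-step outputs match")
```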

During training, the Decoder output is passed through a linear layer that maps it to the vocabulary size, and a cross-entropy loss is computed at every position (here the final label <EOS> is counted among the \(n\) target tokens):

\[ \mathcal{L} = -\sum_{t=1}^{n} \log p(y_t \mid y_{<t}, X) \]
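The loss formula can be checked numerically. This sketch uses random logits and hypothetical labels. It computes \(-\log p(y_t \mid y_{<t}, X)\) per position from a log-softmax and compares the result with F.cross_entropy. By default, F.cross_entropy uses reduction='mean', which averages the per-position terms instead of summing them.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
vocab_size, seq_len = 10, 4
logits = torch.randn(seq_len, vocab_size)  # one row of scores per Decoder position
labels = torch.tensor([3, 1, 7, 2])        # hypothetical next-token labels y_t

# Manual form: -log p(y_t | y_<t, X) at each position, read off the log-softmax
log_probs = F.log_softmax(logits, dim=-1)
nll_per_position = -log_probs[torch.arange(seq_len), labels]

loss_sum = nll_per_position.sum()            # the summed loss in the formula
loss_mean = F.cross_entropy(logits, labels)  # PyTorch default: reduction="mean"

print(torch.allclose(loss_mean, nll_per_position.mean()))
print(torch.allclose(F.cross_entropy(logits, labels, reduction="sum"), loss_sum))
```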

This is the key to efficient Transformer training:

Use teacher forcing to provide real prefixes, and use the causal mask to guarantee that the model cannot see the future, so all positions can be trained in parallel.

8.8.4 Autoregressive Generation: Using the Model’s Own Generated Prefix During Inference

Inference is different. There is no target answer at this point, so there is no real prefix to feed into the Decoder. The model can only start from <BOS> and generate step by step.

The process is roughly:

Step 1 input: <BOS>
       output: y1

Step 2 input: <BOS>, y1
       output: y2

Step 3 input: <BOS>, y1, y2
       output: y3

Step 4 input: <BOS>, y1, y2, y3
       output: <EOS>

Here, \(y_1, y_2, y_3\) are all tokens generated by the model itself.

This is Autoregressive Generation. The key difference from training is where the prefix comes from. During training, the Decoder sees prefixes taken from the real answer; during inference, it sees prefixes produced by the model itself. If the model makes a mistake at the first step, the second step must build on that wrong token, and errors can keep accumulating in later steps. This mismatch between the prefixes seen in training and those seen in inference is usually called Exposure Bias, and error accumulation is its main consequence.
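The difference between the two prefix sources can be shown with a deliberately tiny, hypothetical "model". It is a bigram lookup table that predicts the next token from the previous token only; a real Decoder conditions on the whole prefix and on \(X\). The table is correct everywhere except at the very first step.

```python
# Hypothetical toy "model": predicts the next token from the previous one.
# It is right everywhere except the first step (<BOS> -> "You" instead of "I").
model = {
    "<BOS>": "You", "You": "are", "are": "<EOS>",
    "I": "love", "love": "deep", "deep": "learning", "learning": "<EOS>",
}
reference = ["<BOS>", "I", "love", "deep", "learning", "<EOS>"]

# Training view (teacher forcing): every step is conditioned on the REAL previous token
teacher_forced = [model[tok] for tok in reference[:-1]]

# Inference view (autoregressive): every step is conditioned on the model's OWN previous output
generated = ["<BOS>"]
while generated[-1] != "<EOS>" and len(generated) < 10:
    generated.append(model[generated[-1]])

print(teacher_forced)  # ['You', 'love', 'deep', 'learning', '<EOS>']: only step 1 is wrong
print(generated)       # ['<BOS>', 'You', 'are', '<EOS>']: the first error derails the rest
```

Under teacher forcing, the single mistake stays isolated. Under autoregressive generation, every later step inherits it.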

8.8.5 Comparison Between Training and Inference

We can look at training and inference together.

| Stage | Decoder input | Prefix source | Computation method |
| --- | --- | --- | --- |
| Training | Shifted real target sequence | Real answer | Parallel |
| Inference | Already generated tokens | The model itself | Serial |

Training and inference use the same Encoder-Decoder Transformer, but the source of the Decoder input is different.

During training, the target sequence is already given, so the shifted real target sequence can be used as the prefix, and all positions can be predicted in parallel. During inference, no real target sequence exists, so the model must append the token it just generated to the input and then predict the next token.

8.8.6 Masks in Encoder-Decoder

A complete Encoder-Decoder Transformer usually involves several types of masks. We have already focused on the causal mask earlier, so here we only briefly sort out what each of them masks.

| Mask | Applied in | Main role |
| --- | --- | --- |
| Source padding mask | Encoder self-attention | Masks padding tokens in the source sequence so the Encoder does not attend to them. |
| Target padding mask | Decoder masked self-attention | Masks padding tokens in the target sequence so the Decoder does not attend to them. |
| Source causal mask | Encoder self-attention | Usually not needed, because the Encoder does not generate autoregressively; used only if we want the Encoder to process the input strictly from left to right. |
| Target causal mask | Decoder masked self-attention | Masks future positions in the target sequence so that no position can see future tokens. |
| Memory key padding mask | Decoder cross-attention | Masks the Encoder outputs that correspond to padding tokens in the source sequence so the Decoder does not attend to them. |

The names of these masks may differ slightly across frameworks, but the core is just two types:

Padding mask: do not look at padding.
Causal mask: do not look at the future.
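A padding mask is usually built directly from the token ids. The sketch below assumes PAD = 0 as a hypothetical padding id and uses PyTorch's nn.MultiheadAttention, whose key_padding_mask marks positions to ignore with True. In an Encoder-Decoder model, the same source mask typically also serves as the memory key padding mask for cross-attention.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
PAD = 0  # hypothetical padding id for this sketch
src_ids = torch.tensor([
    [5, 8, 3, 9],
    [7, 2, PAD, PAD],  # shorter sentence padded to length 4
])

# Padding mask: True marks key positions that must NOT be attended to
key_padding_mask = src_ids == PAD

emb = nn.Embedding(10, 16)
attn = nn.MultiheadAttention(16, num_heads=2, batch_first=True)
x = emb(src_ids)
_, weights = attn(x, x, x, key_padding_mask=key_padding_mask)

print(key_padding_mask)
print(weights[1])  # the last two columns (padding keys) are zero for every query
```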

8.8.7 nn.Transformer in PyTorch

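PyTorch provides a complete Encoder-Decoder Transformer module, nn.Transformer. To show what it contains, the class below is a simplified re-implementation with the same structure, assembled from the dnn Encoder and Decoder layers imported above. Unlike PyTorch's version, it names the head count num_heads (PyTorch uses nhead) and always uses the batch-first layout: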

class Transformer(nn.Module):
    def __init__(
        self,
        d_model: int = 512,
        num_heads: int = 8,
        num_encoder_layers: int = 6,
        num_decoder_layers: int = 6,
        dim_feedforward: int = 2048,
        dropout: float = 0.1,
        activation: str | Callable[[Tensor], Tensor] = F.relu,
        layer_norm_eps: float = 1e-5,
        norm_first: bool = False,
        bias: bool = True,
    ):
        super().__init__()
        self.d_model = d_model
        self.num_heads = num_heads

        encoder_layer = dnn.TransformerEncoderLayer(
            d_model,
            num_heads,
            dim_feedforward,
            bias=bias,
            dropout=dropout,
            activation=activation,
            layer_norm_eps=layer_norm_eps,
            norm_first=norm_first,
        )
        encoder_norm = nn.LayerNorm(d_model, eps=layer_norm_eps, bias=bias)
        self.encoder = dnn.TransformerEncoder(
            encoder_layer,
            num_encoder_layers,
            norm=encoder_norm,
        )

        decoder_layer = dnn.TransformerDecoderLayer(
            d_model,
            num_heads,
            dim_feedforward,
            bias=bias,
            dropout=dropout,
            activation=activation,
            layer_norm_eps=layer_norm_eps,
            norm_first=norm_first,
        )
        decoder_norm = nn.LayerNorm(d_model, eps=layer_norm_eps, bias=bias)
        self.decoder = dnn.TransformerDecoder(
            decoder_layer,
            num_decoder_layers,
            norm=decoder_norm,
        )

    def forward(
        self,
        src: Tensor,
        tgt: Tensor,
        src_mask: Tensor | None = None,
        tgt_mask: Tensor | None = None,
        memory_mask: Tensor | None = None,
        src_key_padding_mask: Tensor | None = None,
        tgt_key_padding_mask: Tensor | None = None,
        memory_key_padding_mask: Tensor | None = None,
        src_is_causal: bool = False,
        tgt_is_causal: bool = False,
        memory_is_causal: bool = False,
    ) -> Tensor:
        if src.size(-1) != self.d_model or tgt.size(-1) != self.d_model:
            raise AssertionError(
                'The feature number of `src` and `tgt` must be equal to `d_model`.'
            )
        if src.size(0) != tgt.size(0):
            raise AssertionError('The batch size of `src` and `tgt` must be equal.')

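        # Encoder: embedded source sequence -> memory (H_enc)
        memory = self.encoder(
            src,
            mask=src_mask,
            src_key_padding_mask=src_key_padding_mask,
            is_causal=src_is_causal,
        )
        # Decoder: masked self-attention over tgt, cross-attention over memory
        output = self.decoder(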
            tgt,
            memory,
            tgt_mask=tgt_mask,
            memory_mask=memory_mask,
            tgt_key_padding_mask=tgt_key_padding_mask,
            memory_key_padding_mask=memory_key_padding_mask,
            tgt_is_causal=tgt_is_causal,
            memory_is_causal=memory_is_causal,
        )
        return output


transformer = Transformer(
    d_model=512,
    num_heads=8,
    num_encoder_layers=6,
    num_decoder_layers=6,
    norm_first=True,  # Pre-LN Transformer
)

Internally, it already contains the Encoder and Decoder. When calling it, we usually pass in:

src = torch.randn(2, 32, 512)  # (batch_size, src_seq_len, d_model)
tgt = torch.randn(2, 16, 512)  # (batch_size, tgt_seq_len, d_model)

output = transformer(src, tgt)
print('Output shape:', output.shape)
Output shape: torch.Size([2, 16, 512])

where:

  • src is the Encoder input;
  • tgt is the Decoder input;
  • The output is the Decoder hidden states.

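For PyTorch's own nn.Transformer, these shapes require batch_first=True; with the default batch_first=False, it expects (seq_len, batch_size, d_model). The simplified class above always uses the batch-first layout: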

src: (batch_size, src_seq_len, d_model)
tgt: (batch_size, tgt_seq_len, d_model)
out: (batch_size, tgt_seq_len, d_model)

Note that nn.Transformer does not take token ids; it expects vectors that have already been embedded.

So a complete model usually also includes:

  • Source token embedding;
  • Target token embedding;
  • Positional Encoding;
  • Transformer Encoder-Decoder;
  • Output projection.

When masks are included, a common calling form is:

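src = torch.randn(2, 32, 512)
tgt = torch.randn(2, 16, 512)
# Causal mask over the target positions
tgt_mask = dF.generate_causal_mask(tgt.size(1), device=tgt.device)
# All False: this toy batch contains no padding
src_key_padding_mask = torch.zeros(2, 32, dtype=torch.bool)
tgt_key_padding_mask = torch.zeros(2, 16, dtype=torch.bool)
# Usually identical to src_key_padding_mask: it masks padded Encoder outputs
memory_key_padding_mask = torch.zeros(2, 32, dtype=torch.bool)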

output = transformer(
    src=src,
    tgt=tgt,
    tgt_mask=tgt_mask,  # causal mask
    src_key_padding_mask=src_key_padding_mask,
    tgt_key_padding_mask=tgt_key_padding_mask,
    memory_key_padding_mask=memory_key_padding_mask,
)
print('Output shape:', output.shape)
Output shape: torch.Size([2, 16, 512])
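For comparison, here is a sketch of the same kind of call with PyTorch's built-in torch.nn.Transformer instead of the simplified class above. The head count is passed as nhead. batch_first=True must be set to get (batch_size, seq_len, d_model) tensors. The causal mask comes from nn.Transformer.generate_square_subsequent_mask, which returns a float mask with -inf above the diagonal. The layer counts are reduced only to keep the example light.

```python
import torch
import torch.nn as nn

# PyTorch's built-in module; batch_first=True gives (batch, seq, d_model) tensors
model = nn.Transformer(
    d_model=512, nhead=8, num_encoder_layers=2, num_decoder_layers=2, batch_first=True
)

src = torch.randn(2, 32, 512)
tgt = torch.randn(2, 16, 512)

# Float mask: 0.0 on and below the diagonal, -inf above it (future positions)
tgt_mask = nn.Transformer.generate_square_subsequent_mask(tgt.size(1))
# Boolean padding mask: True would mark a padded source position
src_key_padding_mask = torch.zeros(2, 32, dtype=torch.bool)

output = model(
    src,
    tgt,
    tgt_mask=tgt_mask,
    src_key_padding_mask=src_key_padding_mask,
    memory_key_padding_mask=src_key_padding_mask,  # same source padding reused for cross-attention
)
print('Output shape:', output.shape)  # torch.Size([2, 16, 512])
```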

8.8.8 PyTorch Implementation of a Seq2Seq Transformer

Below, we write a simplified Seq2Seq Transformer. It reuses the Transformer class defined above (our simplified counterpart of nn.Transformer), so the focus is on the complete calling process rather than on re-implementing the Encoder Block and Decoder Block.

class Seq2SeqTransformer(nn.Module):
    def __init__(
        self,
        src_vocab_size: int,
        tgt_vocab_size: int,
        d_model: int = 512,
        num_heads: int = 8,
        num_encoder_layers: int = 6,
        num_decoder_layers: int = 6,
        dim_feedforward: int = 2048,
        dropout: float = 0.1,
        max_len: int = 5000,
    ):
        super().__init__()
        self.d_model = d_model
        self.src_embedding = nn.Embedding(src_vocab_size, d_model)
        self.tgt_embedding = nn.Embedding(tgt_vocab_size, d_model)
        self.pos_encoding = dnn.SinusoidalPositionalEncoding(d_model, max_len)

        self.transformer = Transformer(
            d_model=d_model,
            num_heads=num_heads,
            num_encoder_layers=num_encoder_layers,
            num_decoder_layers=num_decoder_layers,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            norm_first=True,  # Pre-LN Transformer
        )
        self.output_proj = nn.Linear(d_model, tgt_vocab_size)

    def forward(
        self,
        src: Tensor,
        tgt: Tensor,
        src_mask: Tensor | None = None,
        tgt_mask: Tensor | None = None,
        src_key_padding_mask: Tensor | None = None,
        tgt_key_padding_mask: Tensor | None = None,
        memory_key_padding_mask: Tensor | None = None,
    ) -> Tensor:
        # We scale the embeddings by sqrt(d_model) to maintain the variance
        # of the input to the Transformer.
        scale = math.sqrt(self.d_model)
        src_emb = self.src_embedding(src) * scale
        tgt_emb = self.tgt_embedding(tgt) * scale

        src_emb = self.pos_encoding(src_emb)
        tgt_emb = self.pos_encoding(tgt_emb)

        hidden_states = self.transformer(
            src=src_emb,
            tgt=tgt_emb,
            src_mask=src_mask,
            tgt_mask=tgt_mask,
            src_key_padding_mask=src_key_padding_mask,
            tgt_key_padding_mask=tgt_key_padding_mask,
            memory_key_padding_mask=memory_key_padding_mask,
        )

        logits = self.output_proj(hidden_states)
        return logits


src_vocab_size = 100
tgt_vocab_size = 100
src = torch.randint(0, 100, size=(3, 32))
tgt = torch.randint(0, 100, size=(3, 16))
tgt_mask = dF.generate_causal_mask(tgt.size(1))

seq2seq = Seq2SeqTransformer(
    src_vocab_size=src_vocab_size,
    tgt_vocab_size=tgt_vocab_size,
    d_model=512,
    num_heads=8,
    num_encoder_layers=6,
    num_decoder_layers=6,
)

with torch.inference_mode():
    output = seq2seq(src, tgt, tgt_mask=tgt_mask)

print('Output shape:', output.shape)
Output shape: torch.Size([3, 16, 100])

Here, SinusoidalPositionalEncoding can use the version implemented in 8.5.

Both inputs of this model, src and tgt, are token ids. Inside the model, they are first mapped to embeddings, positional encoding is added, and the results are sent into the Transformer. Finally, a linear layer maps the Transformer output to the target vocabulary size, producing logits of shape (batch_size, tgt_seq_len, tgt_vocab_size). The slice at each position, of shape (batch_size, tgt_vocab_size), holds the model's prediction score for every vocabulary token at that position.

8.8.9 Training Process of the Seq2Seq Transformer

To train this model, we first need to construct the Decoder inputs and labels.

Suppose tgt_ids already contains <BOS> and <EOS>:

bos_token_id = 0
eos_token_id = 5
pad_token_id = 6  # must be a valid embedding index: generate() below feeds padding back into the model
src_sentence = ['<BOS>', '我', '爱', '深度', '学习', '<EOS>']
tgt_sentence = ['<BOS>', 'I', 'love', 'deep', 'learning', '<EOS>']
src_ids = torch.tensor([[0, 1, 2, 3, 4, 5]], dtype=torch.long)
tgt_ids = torch.tensor([[0, 1, 2, 3, 4, 5]], dtype=torch.long)

Then we can split it into two parts:

# Decoder inputs: [<BOS>, y1, y2, y3]
tgt_inputs = tgt_ids[:, :-1]

# labels: [y1, y2, y3, <EOS>]
tgt_labels = tgt_ids[:, 1:]

Then generate the causal mask:

tgt_mask = dF.generate_causal_mask(
    tgt_inputs.size(1),
    device=tgt_inputs.device,
)

Forward propagation:

logits = seq2seq(
    src=src_ids,
    tgt=tgt_inputs,
    tgt_mask=tgt_mask,
)
print('Logits shape:', logits.shape)
Logits shape: torch.Size([1, 5, 100])

Then compute the cross-entropy loss:

loss = F.cross_entropy(
    logits.reshape(-1, tgt_vocab_size),
    tgt_labels.reshape(-1),
    ignore_index=pad_token_id,
)
print('Loss:', loss.item())
Loss: 4.449498176574707

In other words, each Decoder position predicts its label, the next real token, from the real tokens up to and including that position together with the source sequence representation.
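The snippet above runs one forward pass and computes the loss; an actual training step also backpropagates and updates the parameters. Below is a minimal, self-contained sketch that overfits the single sentence pair from the example. To stay independent of dnnl, it assumes several substitutions: PyTorch's nn.Transformer, learned positional embeddings in place of the sinusoidal encoding, hypothetical small sizes, and 6 as an assumed <PAD> id. It is an illustration, not a complete training recipe.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical toy setup: ids follow the example (0 = <BOS>, 5 = <EOS>); 6 is an assumed <PAD> id
VOCAB, D_MODEL, PAD = 100, 64, 6


class TinySeq2Seq(nn.Module):
    """Sketch: token embeddings + learned positions + nn.Transformer + output projection."""

    def __init__(self, vocab: int, d_model: int, max_len: int = 64):
        super().__init__()
        self.src_emb = nn.Embedding(vocab, d_model)
        self.tgt_emb = nn.Embedding(vocab, d_model)
        # Learned positional embeddings, swapped in for the sinusoidal encoding to keep this short
        self.pos_emb = nn.Embedding(max_len, d_model)
        self.transformer = nn.Transformer(
            d_model=d_model, nhead=4, num_encoder_layers=2, num_decoder_layers=2,
            dim_feedforward=128, dropout=0.0, batch_first=True,
        )
        self.proj = nn.Linear(d_model, vocab)

    def embed(self, table: nn.Embedding, ids: torch.Tensor) -> torch.Tensor:
        positions = torch.arange(ids.size(1), device=ids.device)
        return table(ids) * self.transformer.d_model ** 0.5 + self.pos_emb(positions)

    def forward(self, src: torch.Tensor, tgt: torch.Tensor) -> torch.Tensor:
        # Causal mask so each target position only sees its own prefix
        tgt_mask = nn.Transformer.generate_square_subsequent_mask(tgt.size(1)).to(tgt.device)
        hidden = self.transformer(
            self.embed(self.src_emb, src), self.embed(self.tgt_emb, tgt), tgt_mask=tgt_mask
        )
        return self.proj(hidden)  # (batch, tgt_len, vocab)


src_ids = torch.tensor([[0, 1, 2, 3, 4, 5]])
tgt_ids = torch.tensor([[0, 1, 2, 3, 4, 5]])
tgt_inputs, tgt_labels = tgt_ids[:, :-1], tgt_ids[:, 1:]  # teacher forcing: shift by one

model = TinySeq2Seq(VOCAB, D_MODEL)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
model.train()

losses = []
for step in range(100):
    logits = model(src_ids, tgt_inputs)  # all 5 positions predicted in one parallel pass
    loss = F.cross_entropy(logits.reshape(-1, VOCAB), tgt_labels.reshape(-1), ignore_index=PAD)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    losses.append(loss.item())

print(f"first loss: {losses[0]:.3f}, last loss: {losses[-1]:.3f}")
```

The loss should fall over the steps because the model memorizes a single pair. This only confirms that the forward pass, loss, and update are wired correctly; it says nothing about generalization.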

8.8.10 A Simple Autoregressive Generation Example

During inference, we start from <BOS> and continuously append the new token generated by the model to the Decoder input.

@torch.inference_mode()
def generate(
    model: Seq2SeqTransformer,
    src_input_ids: Tensor,
    bos_token_id: int,
    eos_token_id: int,
    pad_token_id: int,
    max_new_tokens: int = 50,
):
    model.eval()

    batch_size = src_input_ids.size(0)
    generated_ids = torch.full(
        (batch_size, 1),
        bos_token_id,
        dtype=torch.long,
        device=src_input_ids.device,
    )

    finished = torch.zeros(
        batch_size,
        dtype=torch.bool,
        device=src_input_ids.device,
    )

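    for _ in range(max_new_tokens):
        # Apply the same causal mask as in training; without it, earlier positions
        # would attend to later tokens, and from the second Decoder layer on the
        # last position's output would no longer match the training-time computation.
        tgt_mask = dF.generate_causal_mask(
            generated_ids.size(1),
            device=generated_ids.device,
        )
        logits = model(
            src=src_input_ids,
            tgt=generated_ids,
            tgt_mask=tgt_mask,
        )

        # Greedy decoding: take the highest-scoring token at the last position
        next_token_logits = logits[:, -1, :]
        next_token_id = next_token_logits.argmax(dim=-1)

        # Sequences that already emitted <EOS> only receive padding
        next_token_id = torch.where(
            finished,
            torch.full_like(next_token_id, pad_token_id),
            next_token_id,
        )

        generated_ids = torch.concat(
            [generated_ids, next_token_id.unsqueeze(1)],
            dim=1,
        )

        finished = finished | (next_token_id == eos_token_id)
        if finished.all():
            break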

    return generated_ids


generated_ids = generate(
    model=seq2seq,
    src_input_ids=src_ids,
    bos_token_id=bos_token_id,
    eos_token_id=eos_token_id,
    pad_token_id=pad_token_id,
    max_new_tokens=10,
)
print('Generated IDs:', generated_ids)
Generated IDs: tensor([[ 0, 91, 78, 16, 68, 90, 79, 41, 63, 45, 18]])

Because seq2seq is untrained, these ids are essentially random and will differ between runs. Here we use the simplest strategy, greedy decoding, which selects the token with the highest probability at each step:

\[ \hat{y}_t = \arg\max_y p(y \mid \hat{y}_{<t}, X) \]

In practice, methods such as beam search, top-k sampling, and top-p sampling can also be used. Their differences mainly lie in how to choose the next token from the probability distribution given by the model at each step.

However, no matter how the selection strategy changes, the basic structure of autoregressive generation is the same:

Generate one token, append it back to the Decoder input, and then generate the next token.

8.8.11 Chapter Summary

In this section, we put the Encoder and Decoder into a complete Seq2Seq task and understood how the Encoder-Decoder Transformer works.

From a probability perspective, what the model learns is:

\[ p(Y \mid X) = \prod_{t=1}^{n} p(y_t \mid y_{<t}, X) \]

where \(X\) is processed by the Encoder, and \(y_{<t}\) is the target prefix fed to the Decoder. The Decoder reads the target prefix through masked self-attention, and reads the Encoder output through cross-attention.

During training, we use teacher forcing, shifting the target sequence one position to the right as the Decoder input, allowing the model to predict the next token under the condition of the real prefix. With the help of the causal mask, the Transformer can compute predictions for all positions in parallel while guaranteeing that each position cannot see future tokens.

During inference, there is no real answer available, so the Decoder can only start from <BOS> and generate autoregressively: each step predicts one token, then appends this token back to the input and continues predicting the next token.

So, the complete logic of the Encoder-Decoder Transformer can be summarized as:

The Encoder first understands the input, and then the Decoder generates the output step by step under the joint condition of the target prefix and the input representation.

In the next section, we will continue to look at a problem closely related to inference efficiency: since autoregressive generation must proceed step by step, why is it not necessary to recompute the key and value of past tokens at every step during actual inference?

This leads to KV Cache.

References

Vaswani, Ashish, Noam Shazeer, Niki Parmar, et al. 2023. Attention Is All You Need. https://arxiv.org/abs/1706.03762.
