from collections.abc import Callable
import dnnl.nn as dnn
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
print('PyTorch version:', torch.__version__)
PyTorch version: 2.12.0+xpu
jshn9515
2026-05-05
Earlier, we looked at the Transformer Encoder. The core task of the encoder is to represent each token in the input sequence as a vector with contextual information. After multiple layers of self-attention and feed-forward networks, each position in the input sequence can already fuse information from other positions.
If we are only doing tasks such as text classification, sequence labeling, or image classification, an encoder is often enough. This is because these tasks usually only need to understand the input and then make predictions based on the input representation.
But if the task is a sequence-to-sequence generation task such as machine translation, text summarization, or image captioning, things are different. The model not only needs to understand the input, but also needs to generate the output step by step.
For example, in machine translation, the input is an English sentence:
I love deep learning.
The output might be a Chinese sentence:
我喜欢深度学习。
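At this point, the model needs to do two things:
1. Understand the input English sentence.
2. Generate the Chinese sentence token by token, based on that understanding.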
The encoder is responsible for the first thing, and the decoder is responsible for the second thing.
Therefore, the Transformer decoder can be understood as a generation-side Transformer. It not only needs to look at the content it has already generated, but also needs to read source-sequence information from the encoder output.
Structurally, the Transformer decoder and encoder are very similar. Both are built by stacking multiple blocks, and each block contains attention, residual connections, LayerNorm, and a feed-forward network. However, the decoder differs from the encoder in two key ways.
The first difference is that self-attention in the decoder needs a mask.
In generation tasks, the model generates tokens from left to right step by step. When it predicts the \(t\)-th token, it can only see the tokens that have already been generated before it, and it cannot peek at future tokens. Otherwise, the model is not predicting the next word; it is copying the answer. Therefore, self-attention in the decoder must add a causal mask, ensuring that each position can only attend to itself and previous positions, and cannot see future positions.
The second difference is that the decoder also needs cross-attention.
The decoder cannot rely only on what it has already generated. It also needs to consult the encoder's representation of the input sequence. In translation, for example, each time it generates a Chinese word, it needs to look back at the parts of the English source sentence that are most relevant. This is done through cross-attention.
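So, a standard Transformer Decoder block usually contains three parts:
1. Masked self-attention over the target sequence.
2. Cross-attention, which reads the encoder output.
3. A position-wise feed-forward network.
Each part is wrapped with a residual connection and LayerNorm.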
Let’s first look at self-attention inside the decoder.
In the encoder, every position can see the entire input sequence. For example, if the input sentence has 5 tokens, then the 1st token can attend to the 2nd, 3rd, 4th, and 5th tokens, and the 5th token can also attend to all the previous tokens. This is reasonable for understanding tasks. Since the input is already fully given, the model can naturally use context bidirectionally.
But the decoder is different, because it is used to generate sequences. When generating the first word, the later words have not been generated yet. When generating the second word, it only knows the first word, and when generating the third word, it only knows the first two words. In other words, the output at position \(t\) may depend only on positions 1 to \(t\), never on positions \(t+1\) and beyond. If we do not enforce this restriction during training, self-attention will let every position see the entire target sequence.
For example, when training translation, the target sentence is:
我喜欢深度学习。
When the model predicts “喜欢”, if it can already see the later “深度学习”, it is simply peeking at the answer. Worse, a model that learns to rely on future tokens during training will find those tokens missing during real generation, so its performance at inference time drops sharply.
Therefore, decoder self-attention must add a causal mask, also called a look-ahead mask. Its function is:
Each position can only attend to itself and the positions before it, and cannot attend to future positions.
Then how is it implemented?
Suppose our target sequence length is 4. Ordinary self-attention will compute a \(4 \times 4\) attention score matrix:
\[ S = \frac{QK^\top}{\sqrt{d_k}} \]
Here, the \(i\)-th row represents the attention scores from the \(i\)-th position to all positions.
Without a mask, each row can see all columns:
\[ \begin{array}{c|cccc} & \text{pos 1} & \text{pos 2} & \text{pos 3} & \text{pos 4} \\ \hline \text{pos 1} & \checkmark & \checkmark & \checkmark & \checkmark \\ \text{pos 2} & \checkmark & \checkmark & \checkmark & \checkmark \\ \text{pos 3} & \checkmark & \checkmark & \checkmark & \checkmark \\ \text{pos 4} & \checkmark & \checkmark & \checkmark & \checkmark \\ \end{array} \]
But in the decoder, we want it to become:
\[ \begin{array}{c|cccc} & \text{pos 1} & \text{pos 2} & \text{pos 3} & \text{pos 4} \\ \hline \text{pos 1} & \checkmark & \times & \times & \times \\ \text{pos 2} & \checkmark & \checkmark & \times & \times \\ \text{pos 3} & \checkmark & \checkmark & \checkmark & \times \\ \text{pos 4} & \checkmark & \checkmark & \checkmark & \checkmark \\ \end{array} \]
That is, the 1st position can only see the 1st position; the 2nd position can see the 1st and 2nd positions; the 3rd position can see the 1st, 2nd, and 3rd positions, and so on.
In implementation, we usually set the attention scores of future positions to \(-\infty\) (or, in practice, a very large negative number):
\[ S_{ij} = -\infty, \quad \text{if}\, j > i \]
Then we apply softmax. Because
\[ \exp(-\infty) = 0, \]
the attention weights of these future positions become 0, so the model cannot take information from future tokens.
Thus, masked self-attention can be written as:
\[ \operatorname{MaskedAttention}(Q, K, V) = \operatorname{softmax} \left(\frac{QK^\top}{\sqrt{d_k}} + M \right)V \]
Here, \(M\) is the mask matrix. For positions that are allowed to be attended to, \(M_{ij}=0\); for positions that are not allowed to be attended to, \(M_{ij}=-\infty\).
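The formula above can be checked numerically. The following is a minimal sketch that assumes single-head attention on random tensors. The helper name `masked_attention` is ours, for illustration only. The code builds \(M\) with `torch.triu`, adds it to the scaled scores, and applies softmax. Every weight above the diagonal should come out as exactly 0.

```python
import math

import torch


def masked_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
    """Compute softmax(QK^T / sqrt(d_k) + M) V with a causal mask M."""
    seq_len, d_k = q.size(-2), q.size(-1)
    scores = q @ k.transpose(-2, -1) / math.sqrt(d_k)
    # M[i, j] = -inf for future positions (j > i), 0 elsewhere.
    mask = torch.triu(torch.full((seq_len, seq_len), float('-inf')), diagonal=1)
    weights = torch.softmax(scores + mask, dim=-1)
    return weights @ v, weights


torch.manual_seed(0)
q, k, v = torch.randn(4, 8), torch.randn(4, 8), torch.randn(4, 8)
out, weights = masked_attention(q, k, v)
print(weights)  # lower-triangular: row i only has weight on columns 0..i
```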
Why does the decoder need self-attention at all? In machine translation, the next word is generated from the target words produced so far, so the model needs to track dependencies within the target language. For example, it may need to determine what the subject is, what the predicate is, and how far the current sentence structure has progressed. This information does not come from the source sentence but from the target sentence itself.
So decoder self-attention solves the problem:
For the current position to be generated, how should it use the already generated target-side context?
Its formula is almost the same as encoder self-attention. The only difference is the mask.
For the target sequence representation \(Y\), we have:
\[ Q = YW_Q, \quad K = YW_K, \quad V = YW_V \]
Then compute masked self-attention:
\[ H = \operatorname{MaskedAttention}(Q, K, V) \]
Here, \(Q\), \(K\), and \(V\) all come from the decoder’s current target sequence representation, so it is still self-attention.
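The mask has a testable consequence. With the causal mask in place, overwriting the future positions of \(Y\) should leave the outputs at earlier positions unchanged. The following is a minimal single-head sketch that assumes bias-free `nn.Linear` layers play the roles of \(W_Q\), \(W_K\), and \(W_V\). The names `w_q`, `w_k`, `w_v`, and `decoder_self_attention` are illustrative only.

```python
import math

import torch
import torch.nn as nn

torch.manual_seed(0)
d_model, seq_len = 16, 5

# Hypothetical projections W_Q, W_K, W_V (bias-free linear layers).
w_q = nn.Linear(d_model, d_model, bias=False)
w_k = nn.Linear(d_model, d_model, bias=False)
w_v = nn.Linear(d_model, d_model, bias=False)


def decoder_self_attention(y: torch.Tensor) -> torch.Tensor:
    # Q, K, and V all come from the same target representation Y.
    q, k, v = w_q(y), w_k(y), w_v(y)
    scores = q @ k.transpose(-2, -1) / math.sqrt(d_model)
    n = y.size(-2)
    mask = torch.triu(torch.full((n, n), float('-inf')), diagonal=1)
    return torch.softmax(scores + mask, dim=-1) @ v


y = torch.randn(seq_len, d_model)
y_changed = y.clone()
y_changed[3:] = torch.randn(2, d_model)  # overwrite positions 4 and 5 (the "future")

with torch.no_grad():
    h = decoder_self_attention(y)
    h_changed = decoder_self_attention(y_changed)

# Positions 1-3 only attend to positions 1-3, so their outputs are unchanged.
print(torch.allclose(h[:3], h_changed[:3]))  # True
```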
However, masked self-attention alone is not enough.
If the decoder only looks at the tokens it has already generated, it is like an ordinary language model: it can only continue writing based on the previous context. But tasks such as machine translation and summarization are not free continuation; they need to generate output based on the input content. Therefore, the decoder must also be able to dynamically read the encoder output. This step is cross-attention.
In cross-attention, we have:
\[ Q = H_{dec} W_Q, \quad K = H_{enc} W_K, \quad V = H_{enc} W_V \]
Then compute:
\[ \operatorname{CrossAttention}(Q, K, V) = \operatorname{softmax} \left(\frac{QK^\top}{\sqrt{d_k}} \right)V \]
Here, \(H_{dec}\) is the representation obtained by the decoder after masked self-attention, and \(H_{enc}\) is the representation obtained after the encoder encodes the input sequence.
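In cross-attention, the target length and the source length can differ, and the weight matrix is rectangular. The following is a minimal sketch that assumes random stand-ins for \(H_{dec}\), \(H_{enc}\), and the projection weights. Note that it applies no causal mask, because the whole source sentence is already known.

```python
import math

import torch

torch.manual_seed(0)
d_model = 8
tgt_len, src_len = 3, 5  # decoder and encoder lengths may differ

h_dec = torch.randn(tgt_len, d_model)  # decoder states after masked self-attention
h_enc = torch.randn(src_len, d_model)  # encoder output
# Hypothetical projection weights W_Q, W_K, W_V.
w_q, w_k, w_v = (torch.randn(d_model, d_model) for _ in range(3))

q = h_dec @ w_q  # (tgt_len, d_model): one query per target position
k = h_enc @ w_k  # (src_len, d_model): one key per source position
v = h_enc @ w_v  # (src_len, d_model): one value per source position

# No causal mask here: every target position may look at the whole source.
weights = torch.softmax(q @ k.T / math.sqrt(d_model), dim=-1)  # (tgt_len, src_len)
context = weights @ v  # (tgt_len, d_model)
print(weights.shape, context.shape)  # torch.Size([3, 5]) torch.Size([3, 8])
```

Row \(i\) of `weights` tells us how much target position \(i\) draws from each source position when it builds its context vector.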
Intuitively, the decoder is using its current state to ask a question:
I am about to generate the next target word. What information do I need to find from the source sentence?
Cross-attention calculates the relevance between each decoder position and each encoder position, and then retrieves weighted information from the encoder values. This is why cross-attention connects the encoder and the decoder.
If you remember the Bahdanau attention discussed earlier, you will notice that cross-attention is very similar to that early attention mechanism.
In early Bahdanau attention, every time the decoder generates a word, it takes the current decoder hidden state and matches it with all encoder hidden states. Then it performs a weighted sum over the encoder hidden states to obtain a context vector.
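Mapping this onto the Transformer:
decoder hidden state -> query
encoder hidden states -> key, value
context vector -> cross-attention output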
So the cross-attention in the Transformer Decoder is not a completely unfamiliar new thing. It can be seen as the modern version of early encoder-decoder attention inside the Transformer structure.
The difference is that the Transformer no longer relies on RNNs to produce hidden states step by step. Instead, it models sequence representations with self-attention and positional encoding. Cross-attention is also written in the standard Q, K, V form, so like the other attention layers it is computed with multi-head attention, in parallel across all positions.
Now we can put together a Transformer Decoder block. Suppose the target sequence representation input to the decoder block is \(Y\), and the encoder output is \(H_{enc}\).
First, perform masked self-attention:
\[ Y' = \operatorname{MaskedSelfAttention}(Y) \]
Then add the residual connection and LayerNorm:
\[ Y = \operatorname{LayerNorm}(Y + Y') \]
Second, perform cross-attention. At this point, the query comes from the decoder, while the key and value come from the encoder:
\[ Y' = \operatorname{CrossAttention}(Y, H_{enc}, H_{enc}) \]
Then add the residual connection and LayerNorm:
\[ Y = \operatorname{LayerNorm}(Y + Y') \]
Third, pass through the position-wise feed-forward network:
\[ Y' = \operatorname{FFN}(Y) \]
Finally, add the residual connection and LayerNorm again:
\[ Y = \operatorname{LayerNorm}(Y + Y') \]
Therefore, a decoder block can be written as:
Masked Self-Attention -> Add & Norm -> Cross-Attention -> Add & Norm -> FFN -> Add & Norm
Compared with the encoder block, the decoder block has one more cross-attention module, and its self-attention needs a causal mask.
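The three post-norm equations above translate almost line by line into code. The following is a minimal sketch that assumes PyTorch's `nn.MultiheadAttention` with `batch_first=True` and omits dropout for brevity. The class name `PostNormDecoderBlock` is hypothetical.

```python
import torch
import torch.nn as nn


class PostNormDecoderBlock(nn.Module):
    """Post-norm block: Y = LayerNorm(Y + Sublayer(Y)), applied three times."""

    def __init__(self, d_model: int, num_heads: int, dim_feedforward: int = 2048):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, num_heads, batch_first=True)
        self.cross_attn = nn.MultiheadAttention(d_model, num_heads, batch_first=True)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, dim_feedforward),
            nn.ReLU(),
            nn.Linear(dim_feedforward, d_model),
        )
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)

    def forward(self, y, h_enc, causal_mask=None):
        # 1) Masked self-attention, then Add & Norm.
        y_prime, _ = self.self_attn(y, y, y, attn_mask=causal_mask, need_weights=False)
        y = self.norm1(y + y_prime)
        # 2) Cross-attention (query: decoder, key/value: encoder), then Add & Norm.
        y_prime, _ = self.cross_attn(y, h_enc, h_enc, need_weights=False)
        y = self.norm2(y + y_prime)
        # 3) Position-wise FFN, then Add & Norm.
        return self.norm3(y + self.ffn(y))


block = PostNormDecoderBlock(d_model=64, num_heads=4, dim_feedforward=128)
y = torch.randn(2, 6, 64)  # (batch_size, tgt_seq_len, d_model)
h_enc = torch.randn(2, 9, 64)  # (batch_size, src_seq_len, d_model)
mask = torch.triu(torch.full((6, 6), float('-inf')), diagonal=1)
with torch.no_grad():
    out = block(y, h_enc, causal_mask=mask)
print(out.shape)  # torch.Size([2, 6, 64])
```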
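Below, we implement a simplified Transformer Decoder block. To keep the structure clear, we use PyTorch's built-in nn.MultiheadAttention directly. Note that this implementation uses the pre-norm arrangement, x + Sublayer(LayerNorm(x)), rather than the post-norm form LayerNorm(x + Sublayer(x)) written in the equations above. Pre-norm is often reported to train more stably in deep stacks, and it is the same layout PyTorch's nn.TransformerDecoderLayer uses when norm_first=True.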
class TransformerDecoderLayer(nn.Module):
    """Pre-norm Transformer decoder block.

    Sub-layers: masked self-attention, cross-attention over the encoder
    output (memory), and a position-wise feed-forward network. Each one is
    computed as x + Dropout(Sublayer(LayerNorm(x))).
    """

    def __init__(
        self,
        d_model: int,
        num_heads: int,
        dim_feedforward: int = 2048,
        activation: Callable[[Tensor], Tensor] = F.relu,
        bias: bool = True,
        dropout: float = 0.1,
    ):
        super().__init__()
        # Masked self-attention over the target sequence.
        # batch_first=True: tensors are (batch_size, seq_len, d_model).
        self.self_attn = nn.MultiheadAttention(
            d_model,
            num_heads,
            dropout=dropout,
            bias=bias,
            batch_first=True,
        )
        # Cross-attention: queries from the decoder, keys/values from memory.
        self.mha_attn = nn.MultiheadAttention(
            d_model,
            num_heads,
            dropout=dropout,
            bias=bias,
            batch_first=True,
        )
        # Position-wise feed-forward network.
        self.linear1 = nn.Linear(d_model, dim_feedforward, bias=bias)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model, bias=bias)
        self.activation = activation
        # One LayerNorm and one residual dropout per sub-layer.
        self.norm1 = nn.LayerNorm(d_model, bias=bias)
        self.norm2 = nn.LayerNorm(d_model, bias=bias)
        self.norm3 = nn.LayerNorm(d_model, bias=bias)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)

    def forward(
        self,
        tgt: Tensor,
        memory: Tensor,
        tgt_mask: Tensor | None = None,
        memory_mask: Tensor | None = None,
        tgt_key_padding_mask: Tensor | None = None,
        memory_key_padding_mask: Tensor | None = None,
    ) -> Tensor:
        # 1) Masked self-attention with a residual connection.
        x = tgt + self._sa_block(
            self.norm1(tgt),
            tgt_mask,
            tgt_key_padding_mask,
        )
        # 2) Cross-attention over the encoder output with a residual connection.
        x = x + self._mha_block(
            self.norm2(x),
            memory,
            memory_mask,
            memory_key_padding_mask,
        )
        # 3) Feed-forward network with a residual connection.
        return x + self._ff_block(self.norm3(x))

    def _sa_block(
        self,
        x: Tensor,
        attn_mask: Tensor | None,
        key_padding_mask: Tensor | None,
    ) -> Tensor:
        """Apply masked self-attention (query = key = value = x) and dropout."""
        x, _ = self.self_attn(
            x, x, x,
            attn_mask=attn_mask,
            key_padding_mask=key_padding_mask,
            need_weights=False,
        )  # fmt: skip
        return self.dropout1(x)

    def _mha_block(
        self,
        x: Tensor,
        memory: Tensor,
        attn_mask: Tensor | None,
        key_padding_mask: Tensor | None,
    ) -> Tensor:
        """Apply cross-attention (query = x, key = value = memory) and dropout."""
        x, _ = self.mha_attn(
            x,
            memory,
            memory,
            attn_mask=attn_mask,
            key_padding_mask=key_padding_mask,
            need_weights=False,
        )
        return self.dropout2(x)

    def _ff_block(self, x: Tensor) -> Tensor:
        """Apply the feed-forward block and dropout."""
        x = self.linear1(x)
        x = self.activation(x)
        x = self.dropout(x)
        x = self.linear2(x)
        x = self.dropout3(x)
        return x
x = torch.randn(2, 16, 512)  # (batch_size, tgt_seq_len, d_model)
memory = torch.randn(2, 32, 512)  # (batch_size, src_seq_len, d_model)
decoder_layer = TransformerDecoderLayer(d_model=512, num_heads=8)
with torch.inference_mode():
    output = decoder_layer(x, memory)
print('Decoder Block output shape:', output.shape)
Decoder Block output shape: torch.Size([2, 16, 512])
Here, tgt is the target sequence representation, that is, the sequence currently processed by the decoder; memory is the encoder output, that is, the contextual representation obtained after the source sequence passes through the encoder. tgt_key_padding_mask and memory_key_padding_mask are the padding masks for the target sequence and source sequence respectively.
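A padding mask is usually built from the true sequence lengths. The following is a minimal sketch that assumes PyTorch's convention for a boolean `key_padding_mask`, where True means "ignore this key". It also uses hypothetical lengths and random tensors. Padded source positions should receive zero attention weight.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
batch_size, src_len, d_model = 2, 5, 16
lengths = torch.tensor([5, 3])  # hypothetical: the second source sentence has 2 pad tokens

# True marks padding positions (index >= true length) that should be ignored.
memory_key_padding_mask = torch.arange(src_len)[None, :] >= lengths[:, None]
print(memory_key_padding_mask)

cross_attn = nn.MultiheadAttention(d_model, num_heads=2, batch_first=True)
tgt = torch.randn(batch_size, 4, d_model)
memory = torch.randn(batch_size, src_len, d_model)
with torch.no_grad():
    _, weights = cross_attn(tgt, memory, memory, key_padding_mask=memory_key_padding_mask)

# weights: (batch_size, tgt_len, src_len), averaged over heads by default.
print(weights[1, :, 3:])  # all zeros: padded source positions receive no attention
```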
In masked self-attention:
target -> query
target -> key
target -> value
In cross-attention:
target -> query
encoder output -> key
encoder output -> value
This exactly corresponds to what we discussed earlier: the decoder uses its own state as the query to retrieve relevant information from the encoder output.
Like the encoder, we usually stack multiple decoder blocks to form a complete Transformer Decoder.
class TransformerDecoder(nn.Module):
    def __init__(
        self,
        d_model: int,
        num_heads: int,
        num_layers: int,
        dim_feedforward: int = 2048,
        activation: Callable[[Tensor], Tensor] = F.relu,
        bias: bool = True,
        dropout: float = 0.1,
    ):
        super().__init__()
        # A stack of num_layers independent decoder blocks.
        self.layers = nn.ModuleList(
            [
                TransformerDecoderLayer(
                    d_model=d_model,
                    num_heads=num_heads,
                    dim_feedforward=dim_feedforward,
                    activation=activation,
                    bias=bias,
                    dropout=dropout,
                )
                for _ in range(num_layers)
            ]
        )
        # Final LayerNorm: pre-norm blocks leave the last output unnormalized.
        self.norm = nn.LayerNorm(d_model)

    def forward(
        self,
        tgt: Tensor,
        memory: Tensor,
        tgt_mask: Tensor | None = None,
        memory_mask: Tensor | None = None,
        tgt_key_padding_mask: Tensor | None = None,
        memory_key_padding_mask: Tensor | None = None,
    ) -> Tensor:
        output = tgt
        # Every layer reads the same encoder output (memory) and the same masks.
        for layer in self.layers:
            output = layer(
                output,
                memory,
                tgt_mask=tgt_mask,
                memory_mask=memory_mask,
                tgt_key_padding_mask=tgt_key_padding_mask,
                memory_key_padding_mask=memory_key_padding_mask,
            )
        output = self.norm(output)
        return output
x = torch.randn(2, 16, 512)  # (batch_size, tgt_seq_len, d_model)
memory = torch.randn(2, 32, 512)  # (batch_size, src_seq_len, d_model)
decoder = TransformerDecoder(d_model=512, num_heads=8, num_layers=6)
with torch.inference_mode():
    output = decoder(x, memory)
print('Decoder output shape:', output.shape)
Decoder output shape: torch.Size([2, 16, 512])
It should be noted that the decoder input tgt is the representation of the target sequence, not the original token ids. Usually, we first convert the target token ids into vector representations through an embedding layer, and then input them into the decoder. This is the same as how the encoder input is handled.
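The step from token ids to the decoder input `tgt` can be sketched as follows. This sketch assumes hypothetical vocabulary and length sizes and uses learned positional embeddings; a sinusoidal encoding would plug in the same way. Scaling token embeddings by \(\sqrt{d_{model}}\) follows the original Transformer paper.

```python
import math

import torch
import torch.nn as nn

torch.manual_seed(0)
vocab_size, d_model, max_len = 1000, 512, 128  # hypothetical sizes

token_embedding = nn.Embedding(vocab_size, d_model)
# Assumption: learned positional embeddings (one vector per position index).
position_embedding = nn.Embedding(max_len, d_model)

tgt_ids = torch.randint(0, vocab_size, (2, 16))  # (batch_size, tgt_seq_len) token ids
positions = torch.arange(tgt_ids.size(1))  # (tgt_seq_len,)

# (2, 16, 512) + (16, 512) broadcasts over the batch dimension.
tgt = token_embedding(tgt_ids) * math.sqrt(d_model) + position_embedding(positions)
print(tgt.shape)  # torch.Size([2, 16, 512])
```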
After discussing the decoder structure, let’s look at how to generate the causal mask. Below is a simple causal mask generation function:
tensor([[0., -inf, -inf, -inf],
[0., 0., -inf, -inf],
[0., 0., 0., -inf],
[0., 0., 0., 0.]])
This mask is added to the attention scores, so after softmax the weights at the masked (future) positions become 0.
Below is a simple example showing how to apply the causal mask to the Transformer Decoder block:
x = torch.randn(2, 16, 512)  # (batch_size, tgt_seq_len, d_model)
memory = torch.randn(2, 32, 512)  # (batch_size, src_seq_len, d_model)
decoder_layer = TransformerDecoderLayer(d_model=512, num_heads=8)
mask = generate_causal_mask(x.size(-2), device=x.device)
with torch.inference_mode():
    output = decoder_layer(x, memory, tgt_mask=mask)
print('Decoder Block output shape:', output.shape)
Decoder Block output shape: torch.Size([2, 16, 512])
As you can see, we pass the generated causal mask to the tgt_mask parameter. In this way, the mask will be automatically applied in masked self-attention, ensuring that each position can only see itself and the positions before it.
At this point, the core mechanism of the decoder is basically complete. It first uses masked self-attention to process the target sequence that has already been generated, and then reads source-sequence information from the encoder output through cross-attention. The causal mask plays a very important role here, because it prevents the decoder from seeing future tokens in advance even during training.
Next, we can look at a more practical question: how exactly is the decoder used to generate text?
During training, we usually already have the complete target sequence. For example, the target sentence is:
<BOS> 我 喜欢 深度 学习
What the model needs to predict is:
我 喜欢 深度 学习 <EOS>
Although the full target sequence is visible during training, because of the causal mask, each position can still only see the positions before it. Therefore, the model will not peek at future answers. It can only predict the next word based on the part that has already been generated.
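In code, the one-position shift between decoder input and labels (often called teacher forcing) might look like the sketch below. The token ids and special-token numbering are hypothetical. Random logits stand in for the real decoder plus output layer, so only the shapes and the shift matter here.

```python
import torch
import torch.nn.functional as F

# Hypothetical ids: 1 = <BOS>, 2 = <EOS>; 5, 8, 13, 21 stand for 我 喜欢 深度 学习.
BOS, EOS = 1, 2
target = torch.tensor([[5, 8, 13, 21]])

# Decoder input: <BOS> 我 喜欢 深度 学习. Labels: 我 喜欢 深度 学习 <EOS>.
decoder_input = torch.cat([torch.full((1, 1), BOS), target], dim=1)
labels = torch.cat([target, torch.full((1, 1), EOS)], dim=1)
print(decoder_input)
print(labels)

# Position t of the decoder output is scored against labels[:, t].
vocab_size = 32
logits = torch.randn(1, decoder_input.size(1), vocab_size)  # stand-in for decoder output
loss = F.cross_entropy(logits.reshape(-1, vocab_size), labels.reshape(-1))
print(loss.item())
```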
During inference, the decoder starts from a start token and generates step by step:
<BOS>
First predict the first token:
我
Then append it back to the input:
<BOS> 我
Then predict the next token:
喜欢
This process repeats continuously until <EOS> is generated or the maximum length is reached.
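The loop just described is greedy autoregressive decoding. The following is a minimal sketch that assumes untrained stand-in modules and hypothetical special-token ids. The stand-ins are PyTorch's built-in `nn.TransformerDecoder`, a hypothetical embedding, and a hypothetical output projection, with positional encoding omitted. With random weights, the generated ids are meaningless; the control flow is the point.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
vocab_size, d_model, max_len = 20, 32, 10
BOS, EOS = 1, 2  # hypothetical special-token ids

# Untrained stand-ins for a real model's components.
embedding = nn.Embedding(vocab_size, d_model)
layer = nn.TransformerDecoderLayer(d_model, nhead=4, dim_feedforward=64, batch_first=True)
decoder = nn.TransformerDecoder(layer, num_layers=2).eval()
to_vocab = nn.Linear(d_model, vocab_size)
memory = torch.randn(1, 7, d_model)  # pretend encoder output for one source sentence


@torch.no_grad()
def greedy_decode(memory: torch.Tensor) -> list[int]:
    ys = torch.tensor([[BOS]])  # start from <BOS>
    for _ in range(max_len):
        t = ys.size(1)
        causal = torch.triu(torch.full((t, t), float('-inf')), diagonal=1)
        h = decoder(embedding(ys), memory, tgt_mask=causal)
        # Pick the most likely next token from the last position only.
        next_token = to_vocab(h[:, -1]).argmax(dim=-1, keepdim=True)  # (1, 1)
        ys = torch.cat([ys, next_token], dim=1)  # append it and feed it back in
        if next_token.item() == EOS:
            break
    return ys[0].tolist()


print(greedy_decode(memory))
```

For simplicity, this sketch recomputes the whole prefix at every step. Practical implementations usually cache the keys and values from earlier steps instead.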
So the decoder generation process is essentially autoregressive:
\[ p(y_1, y_2, \dots, y_T \mid x) = \prod_{t=1}^{T} p(y_t \mid y_{<t}, x) \]
Here, \(x\) is the source sequence, and \(y_{<t}\) denotes the target tokens generated before step \(t\).
Masked self-attention is responsible for modeling \(y_{<t}\), and cross-attention is responsible for using the input sequence \(x\).
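In practice, the product is evaluated as a sum of log-probabilities, which avoids numerical underflow for long sequences. The following small sketch uses random stand-in distributions for \(p(y_t \mid y_{<t}, x)\) and hypothetical target ids. It shows that the sum and product forms agree, and that the summed cross-entropy loss is exactly the negative log-likelihood.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
T, vocab_size = 5, 10
# Stand-in for decoder outputs: one log-distribution log p(. | y_<t, x) per step.
log_probs = torch.randn(T, vocab_size).log_softmax(dim=-1)
y = torch.tensor([3, 1, 4, 1, 5])  # hypothetical target ids y_1..y_T

# log p(y_t | y_<t, x) for the observed token at every step.
step_log_probs = log_probs[torch.arange(T), y]

# log p(y_1..y_T | x) = sum over t of log p(y_t | y_<t, x)
sequence_log_prob = step_log_probs.sum()
product = step_log_probs.exp().prod()  # the product form of the same probability
print(sequence_log_prob.exp().item(), product.item())

# Summed cross-entropy is the negative log-likelihood of the whole sequence.
nll = F.cross_entropy(log_probs, y, reduction='sum')
print(nll.item(), -sequence_log_prob.item())
```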
In this section, we broke down the core structure of the Transformer Decoder.
The decoder is very similar to the encoder. Both are composed of attention, feed-forward networks, residual connections, and LayerNorm. But the decoder has two key characteristics. First, its self-attention is masked self-attention, where each position can only attend to itself and previous positions, and cannot see future tokens. Second, it contains cross-attention, which uses the decoder representation as the query to retrieve source-sequence information from the encoder output.
Therefore, the decoder takes on two tasks at the same time. On the one hand, it needs to behave like a language model and continue generating based on the already generated target-side context. On the other hand, it also needs to continuously refer to the encoder output, ensuring that the generated content corresponds to the input sequence.
At this point, both the Transformer encoder and decoder have been covered. In the next section, we can combine these components and see how the complete Encoder-Decoder Transformer performs sequence-to-sequence modeling.