{
"cells": [
{
"cell_type": "markdown",
"id": "e4433f58",
"metadata": {
"origin_pos": 1
},
"source": [
"# The Bahdanau Attention Mechanism\n",
":label:`sec_seq2seq_attention`\n",
"\n",
"When we encountered machine translation in :numref:`sec_seq2seq`,\n",
"we designed an encoder--decoder architecture for sequence-to-sequence learning\n",
"based on two RNNs :cite:`Sutskever.Vinyals.Le.2014`.\n",
"Specifically, the RNN encoder transforms a variable-length sequence\n",
"into a *fixed-shape* context variable.\n",
"Then, the RNN decoder generates the output (target) sequence token by token\n",
"based on the generated tokens and the context variable.\n",
"\n",
"Recall :numref:`fig_seq2seq_details` which we repeat (:numref:`fig_s2s_attention_state`) with some additional detail. Conventionally, in an RNN all relevant information about a source sequence is translated into some internal *fixed-dimensional* state representation by the encoder. It is this very state that is used by the decoder as the complete and exclusive source of information for generating the translated sequence. In other words, the sequence-to-sequence mechanism treats the intermediate state as a sufficient statistic of whatever string might have served as input.\n",
"\n",
"\n",
":label:`fig_s2s_attention_state`\n",
"\n",
"While this is quite reasonable for short sequences, it is clear that it is infeasible for long ones, such as a book chapter or even just a very long sentence. After all, before too long there will simply not be enough \"space\" in the intermediate representation to store all that is important in the source sequence. Consequently the decoder will fail to translate long and complex sentences. One of the first to encounter this was :citet:`Graves.2013` who tried to design an RNN to generate handwritten text. Since the source text has arbitrary length they designed a differentiable attention model\n",
"to align text characters with the much longer pen trace,\n",
"where the alignment moves only in one direction. This, in turn, draws on decoding algorithms in speech recognition, e.g., hidden Markov models :cite:`rabiner1993fundamentals`.\n",
"\n",
"Inspired by the idea of learning to align,\n",
":citet:`Bahdanau.Cho.Bengio.2014` proposed a differentiable attention model\n",
"*without* the unidirectional alignment limitation.\n",
"When predicting a token,\n",
"if not all the input tokens are relevant,\n",
"the model aligns (or attends)\n",
"only to parts of the input sequence\n",
"that are deemed relevant to the current prediction. This is then used to update the current state before generating the next token. While quite innocuous in its description, this *Bahdanau attention mechanism* has arguably turned into one of the most influential ideas of the past decade in deep learning, giving rise to Transformers :cite:`Vaswani.Shazeer.Parmar.ea.2017` and many related new architectures.\n"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "64405f0b",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T19:46:13.059684Z",
"iopub.status.busy": "2023-08-18T19:46:13.058589Z",
"iopub.status.idle": "2023-08-18T19:46:16.109138Z",
"shell.execute_reply": "2023-08-18T19:46:16.108090Z"
},
"origin_pos": 3,
"tab": [
"pytorch"
]
},
"outputs": [],
"source": [
"import torch\n",
"from torch import nn\n",
"from d2l import torch as d2l"
]
},
{
"cell_type": "markdown",
"id": "2196d7ca",
"metadata": {
"origin_pos": 6
},
"source": [
"## Model\n",
"\n",
"We follow the notation introduced by the sequence-to-sequence architecture of :numref:`sec_seq2seq`, in particular :eqref:`eq_seq2seq_s_t`.\n",
"The key idea is that instead of keeping the state,\n",
"i.e., the context variable $\\mathbf{c}$ summarizing the source sentence, as fixed, we dynamically update it, as a function of both the original text (encoder hidden states $\\mathbf{h}_{t}$) and the text that was already generated (decoder hidden states $\\mathbf{s}_{t'-1}$). This yields $\\mathbf{c}_{t'}$, which is updated after any decoding time step $t'$. Suppose that the input sequence is of length $T$. In this case the context variable is the output of attention pooling:\n",
"\n",
"$$\\mathbf{c}_{t'} = \\sum_{t=1}^{T} \\alpha(\\mathbf{s}_{t' - 1}, \\mathbf{h}_{t}) \\mathbf{h}_{t}.$$\n",
"\n",
"We used $\\mathbf{s}_{t' - 1}$ as the query, and\n",
"$\\mathbf{h}_{t}$ as both the key and the value. Note that $\\mathbf{c}_{t'}$ is then used to generate the state $\\mathbf{s}_{t'}$ and to generate a new token: see :eqref:`eq_seq2seq_s_t`. In particular, the attention weight $\\alpha$ is computed as in :eqref:`eq_attn-scoring-alpha`\n",
"using the additive attention scoring function\n",
"defined by :eqref:`eq_additive-attn`.\n",
"This RNN encoder--decoder architecture\n",
"using attention is depicted in :numref:`fig_s2s_attention_details`. Note that later this model was modified so as to include the already generated tokens in the decoder as further context (i.e., the attention sum does not stop at $T$ but rather it proceeds up to $t'-1$). For instance, see :citet:`chan2015listen` for a description of this strategy, as applied to speech recognition.\n",
"\n",
"\n",
":label:`fig_s2s_attention_details`\n",
"\n",
"## Defining the Decoder with Attention\n",
"\n",
"To implement the RNN encoder--decoder with attention,\n",
"we only need to redefine the decoder (omitting the generated symbols from the attention function simplifies the design). Let's begin with [**the base interface for decoders with attention**] by defining the quite unsurprisingly named `AttentionDecoder` class.\n"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "7392fc80",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T19:46:16.115524Z",
"iopub.status.busy": "2023-08-18T19:46:16.114277Z",
"iopub.status.idle": "2023-08-18T19:46:16.121848Z",
"shell.execute_reply": "2023-08-18T19:46:16.120397Z"
},
"origin_pos": 7,
"tab": [
"pytorch"
]
},
"outputs": [],
"source": [
"class AttentionDecoder(d2l.Decoder): #@save\n",
" \"\"\"The base attention-based decoder interface.\"\"\"\n",
" def __init__(self):\n",
" super().__init__()\n",
"\n",
" @property\n",
" def attention_weights(self):\n",
" raise NotImplementedError"
]
},
{
"cell_type": "markdown",
"id": "33cd556d",
"metadata": {
"origin_pos": 8
},
"source": [
"We need to [**implement the RNN decoder**]\n",
"in the `Seq2SeqAttentionDecoder` class.\n",
"The state of the decoder is initialized with\n",
"(i) the hidden states of the last layer of the encoder at all time steps, used as keys and values for attention;\n",
"(ii) the hidden state of the encoder at all layers at the final time step, which serves to initialize the hidden state of the decoder;\n",
"and (iii) the valid length of the encoder, to exclude the padding tokens in attention pooling.\n",
"At each decoding time step, the hidden state of the final layer of the decoder, obtained at the previous time step, is used as the query of the attention mechanism.\n",
"Both the output of the attention mechanism and the input embedding are concatenated to serve as the input of the RNN decoder.\n"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "f0a3f536",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T19:46:16.128312Z",
"iopub.status.busy": "2023-08-18T19:46:16.125898Z",
"iopub.status.idle": "2023-08-18T19:46:16.142962Z",
"shell.execute_reply": "2023-08-18T19:46:16.141775Z"
},
"origin_pos": 10,
"tab": [
"pytorch"
]
},
"outputs": [],
"source": [
"class Seq2SeqAttentionDecoder(AttentionDecoder):\n",
" def __init__(self, vocab_size, embed_size, num_hiddens, num_layers,\n",
" dropout=0):\n",
" super().__init__()\n",
" self.attention = d2l.AdditiveAttention(num_hiddens, dropout)\n",
" self.embedding = nn.Embedding(vocab_size, embed_size)\n",
" self.rnn = nn.GRU(\n",
" embed_size + num_hiddens, num_hiddens, num_layers,\n",
" dropout=dropout)\n",
" self.dense = nn.LazyLinear(vocab_size)\n",
" self.apply(d2l.init_seq2seq)\n",
"\n",
" def init_state(self, enc_outputs, enc_valid_lens):\n",
" # Shape of outputs: (num_steps, batch_size, num_hiddens).\n",
" # Shape of hidden_state: (num_layers, batch_size, num_hiddens)\n",
" outputs, hidden_state = enc_outputs\n",
" return (outputs.permute(1, 0, 2), hidden_state, enc_valid_lens)\n",
"\n",
" def forward(self, X, state):\n",
" # Shape of enc_outputs: (batch_size, num_steps, num_hiddens).\n",
" # Shape of hidden_state: (num_layers, batch_size, num_hiddens)\n",
" enc_outputs, hidden_state, enc_valid_lens = state\n",
" # Shape of the output X: (num_steps, batch_size, embed_size)\n",
" X = self.embedding(X).permute(1, 0, 2)\n",
" outputs, self._attention_weights = [], []\n",
" for x in X:\n",
" # Shape of query: (batch_size, 1, num_hiddens)\n",
" query = torch.unsqueeze(hidden_state[-1], dim=1)\n",
" # Shape of context: (batch_size, 1, num_hiddens)\n",
" context = self.attention(\n",
" query, enc_outputs, enc_outputs, enc_valid_lens)\n",
" # Concatenate on the feature dimension\n",
" x = torch.cat((context, torch.unsqueeze(x, dim=1)), dim=-1)\n",
" # Reshape x as (1, batch_size, embed_size + num_hiddens)\n",
" out, hidden_state = self.rnn(x.permute(1, 0, 2), hidden_state)\n",
" outputs.append(out)\n",
" self._attention_weights.append(self.attention.attention_weights)\n",
" # After fully connected layer transformation, shape of outputs:\n",
" # (num_steps, batch_size, vocab_size)\n",
" outputs = self.dense(torch.cat(outputs, dim=0))\n",
" return outputs.permute(1, 0, 2), [enc_outputs, hidden_state,\n",
" enc_valid_lens]\n",
"\n",
" @property\n",
" def attention_weights(self):\n",
" return self._attention_weights"
]
},
{
"cell_type": "markdown",
"id": "af402335",
"metadata": {
"origin_pos": 13
},
"source": [
"In the following, we [**test the implemented\n",
"decoder**] with attention\n",
"using a minibatch of four sequences, each of which are seven time steps long.\n"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "e7d6e370",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T19:46:16.149119Z",
"iopub.status.busy": "2023-08-18T19:46:16.148126Z",
"iopub.status.idle": "2023-08-18T19:46:16.201990Z",
"shell.execute_reply": "2023-08-18T19:46:16.200456Z"
},
"origin_pos": 14,
"tab": [
"pytorch"
]
},
"outputs": [],
"source": [
"vocab_size, embed_size, num_hiddens, num_layers = 10, 8, 16, 2\n",
"batch_size, num_steps = 4, 7\n",
"encoder = d2l.Seq2SeqEncoder(vocab_size, embed_size, num_hiddens, num_layers)\n",
"decoder = Seq2SeqAttentionDecoder(vocab_size, embed_size, num_hiddens,\n",
" num_layers)\n",
"X = torch.zeros((batch_size, num_steps), dtype=torch.long)\n",
"state = decoder.init_state(encoder(X), None)\n",
"output, state = decoder(X, state)\n",
"d2l.check_shape(output, (batch_size, num_steps, vocab_size))\n",
"d2l.check_shape(state[0], (batch_size, num_steps, num_hiddens))\n",
"d2l.check_shape(state[1][0], (batch_size, num_hiddens))"
]
},
{
"cell_type": "markdown",
"id": "c6c52079",
"metadata": {
"origin_pos": 15
},
"source": [
"## [**Training**]\n",
"\n",
"Now that we specified the new decoder we can proceed analogously to :numref:`sec_seq2seq_training`:\n",
"specify the hyperparameters, instantiate\n",
"a regular encoder and a decoder with attention,\n",
"and train this model for machine translation.\n"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "a73f9cc6",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T19:46:16.207768Z",
"iopub.status.busy": "2023-08-18T19:46:16.207253Z",
"iopub.status.idle": "2023-08-18T19:46:52.077164Z",
"shell.execute_reply": "2023-08-18T19:46:52.076268Z"
},
"origin_pos": 16,
"tab": [
"pytorch"
]
},
"outputs": [
{
"data": {
"image/svg+xml": [
"\n",
"\n",
"\n"
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"data = d2l.MTFraEng(batch_size=128)\n",
"embed_size, num_hiddens, num_layers, dropout = 256, 256, 2, 0.2\n",
"encoder = d2l.Seq2SeqEncoder(\n",
" len(data.src_vocab), embed_size, num_hiddens, num_layers, dropout)\n",
"decoder = Seq2SeqAttentionDecoder(\n",
" len(data.tgt_vocab), embed_size, num_hiddens, num_layers, dropout)\n",
"model = d2l.Seq2Seq(encoder, decoder, tgt_pad=data.tgt_vocab[''],\n",
" lr=0.005)\n",
"trainer = d2l.Trainer(max_epochs=30, gradient_clip_val=1, num_gpus=1)\n",
"trainer.fit(model, data)"
]
},
{
"cell_type": "markdown",
"id": "62af9112",
"metadata": {
"origin_pos": 17
},
"source": [
"After the model is trained,\n",
"we use it to [**translate a few English sentences**]\n",
"into French and compute their BLEU scores.\n"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "a22c2296",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T19:46:52.082122Z",
"iopub.status.busy": "2023-08-18T19:46:52.081526Z",
"iopub.status.idle": "2023-08-18T19:46:52.108612Z",
"shell.execute_reply": "2023-08-18T19:46:52.107700Z"
},
"origin_pos": 18,
"tab": [
"pytorch"
]
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"go . => ['va', '!'], bleu,1.000\n",
"i lost . => [\"j'ai\", 'perdu', '.'], bleu,1.000\n",
"he's calm . => ['il', 'court', '.'], bleu,0.000\n",
"i'm home . => ['je', 'suis', 'chez', 'moi', '.'], bleu,1.000\n"
]
}
],
"source": [
"engs = ['go .', 'i lost .', 'he\\'s calm .', 'i\\'m home .']\n",
"fras = ['va !', 'j\\'ai perdu .', 'il est calme .', 'je suis chez moi .']\n",
"preds, _ = model.predict_step(\n",
" data.build(engs, fras), d2l.try_gpu(), data.num_steps)\n",
"for en, fr, p in zip(engs, fras, preds):\n",
" translation = []\n",
" for token in data.tgt_vocab.to_tokens(p):\n",
" if token == '':\n",
" break\n",
" translation.append(token)\n",
" print(f'{en} => {translation}, bleu,'\n",
" f'{d2l.bleu(\" \".join(translation), fr, k=2):.3f}')"
]
},
{
"cell_type": "markdown",
"id": "f0015d7f",
"metadata": {
"origin_pos": 19
},
"source": [
"Let's [**visualize the attention weights**]\n",
"when translating the last English sentence.\n",
"We see that each query assigns non-uniform weights\n",
"over key--value pairs.\n",
"It shows that at each decoding step,\n",
"different parts of the input sequences\n",
"are selectively aggregated in the attention pooling.\n"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "5b39b45c",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T19:46:52.111885Z",
"iopub.status.busy": "2023-08-18T19:46:52.111597Z",
"iopub.status.idle": "2023-08-18T19:46:52.130667Z",
"shell.execute_reply": "2023-08-18T19:46:52.129810Z"
},
"origin_pos": 20,
"tab": [
"pytorch"
]
},
"outputs": [],
"source": [
"_, dec_attention_weights = model.predict_step(\n",
" data.build([engs[-1]], [fras[-1]]), d2l.try_gpu(), data.num_steps, True)\n",
"attention_weights = torch.cat(\n",
" [step[0][0][0] for step in dec_attention_weights], 0)\n",
"attention_weights = attention_weights.reshape((1, 1, -1, data.num_steps))"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "e9c665a8",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T19:46:52.134058Z",
"iopub.status.busy": "2023-08-18T19:46:52.133495Z",
"iopub.status.idle": "2023-08-18T19:46:52.369882Z",
"shell.execute_reply": "2023-08-18T19:46:52.368741Z"
},
"origin_pos": 22,
"tab": [
"pytorch"
]
},
"outputs": [
{
"data": {
"image/svg+xml": [
"\n",
"\n",
"\n"
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Plus one to include the end-of-sequence token\n",
"d2l.show_heatmaps(\n",
" attention_weights[:, :, :, :len(engs[-1].split()) + 1].cpu(),\n",
" xlabel='Key positions', ylabel='Query positions')"
]
},
{
"cell_type": "markdown",
"id": "ebbb775a",
"metadata": {
"origin_pos": 25
},
"source": [
"## Summary\n",
"\n",
"When predicting a token, if not all the input tokens are relevant, the RNN encoder--decoder with the Bahdanau attention mechanism selectively aggregates different parts of the input sequence. This is achieved by treating the state (context variable) as an output of additive attention pooling.\n",
"In the RNN encoder--decoder, the Bahdanau attention mechanism treats the decoder hidden state at the previous time step as the query, and the encoder hidden states at all the time steps as both the keys and values.\n",
"\n",
"\n",
"## Exercises\n",
"\n",
"1. Replace GRU with LSTM in the experiment.\n",
"1. Modify the experiment to replace the additive attention scoring function with the scaled dot-product. How does it influence the training efficiency?\n"
]
},
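{
"cell_type": "markdown",
"id": "c3d4e5f6",
"metadata": {},
"source": [
"As a possible starting point for the second exercise, here is a minimal sketch (an illustration, not the book's reference solution): it reuses `Seq2SeqAttentionDecoder` and merely swaps the additive scoring function for scaled dot-product attention, assuming `d2l.DotProductAttention` (defined earlier in the book) is available and takes a dropout rate. Because the queries and keys above already share the feature dimension `num_hiddens`, no extra projection is needed.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d4e5f6a7",
"metadata": {},
"outputs": [],
"source": [
"class Seq2SeqDotProductAttentionDecoder(Seq2SeqAttentionDecoder):\n",
"    \"\"\"Sketch for Exercise 2: swap in scaled dot-product attention\n",
"    (assumes d2l.DotProductAttention(dropout) as defined earlier in the book).\"\"\"\n",
"    def __init__(self, vocab_size, embed_size, num_hiddens, num_layers,\n",
"                 dropout=0):\n",
"        super().__init__(vocab_size, embed_size, num_hiddens, num_layers,\n",
"                         dropout)\n",
"        # Queries and keys both have num_hiddens features, so dot-product\n",
"        # scoring applies without an additional projection\n",
"        self.attention = d2l.DotProductAttention(dropout)"
]
},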
{
"cell_type": "markdown",
"id": "3f7db1a8",
"metadata": {
"origin_pos": 27,
"tab": [
"pytorch"
]
},
"source": [
"[Discussions](https://discuss.d2l.ai/t/1065)\n"
]
}
],
"metadata": {
"language_info": {
"name": "python"
},
"required_libs": []
},
"nbformat": 4,
"nbformat_minor": 5
}