{ "cells": [ { "cell_type": "markdown", "id": "f3bb03f2", "metadata": { "origin_pos": 1 }, "source": [ "# Self-Attention and Positional Encoding\n", ":label:`sec_self-attention-and-positional-encoding`\n", "\n", "In deep learning, we often use CNNs or RNNs to encode sequences.\n", "Now with attention mechanisms in mind, \n", "imagine feeding a sequence of tokens \n", "into an attention mechanism\n", "such that at every step,\n", "each token has its own query, keys, and values.\n", "Here, when computing the value of a token's representation at the next layer,\n", "the token can attend (via its query vector) to any other's token \n", "(matching based on their key vectors).\n", "Using the full set of query-key compatibility scores,\n", "we can compute, for each token, a representation\n", "by building the appropriate weighted sum\n", "over the other tokens. \n", "Because every token is attending to each other token\n", "(unlike the case where decoder steps attend to encoder steps),\n", "such architectures are typically described as *self-attention* models :cite:`Lin.Feng.Santos.ea.2017,Vaswani.Shazeer.Parmar.ea.2017`, \n", "and elsewhere described as *intra-attention* model :cite:`Cheng.Dong.Lapata.2016,Parikh.Tackstrom.Das.ea.2016,Paulus.Xiong.Socher.2017`.\n", "In this section, we will discuss sequence encoding using self-attention,\n", "including using additional information for the sequence order.\n" ] }, { "cell_type": "code", "execution_count": 1, "id": "b2969e34", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T19:30:32.804452Z", "iopub.status.busy": "2023-08-18T19:30:32.803811Z", "iopub.status.idle": "2023-08-18T19:30:35.929844Z", "shell.execute_reply": "2023-08-18T19:30:35.926598Z" }, "origin_pos": 3, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "import math\n", "import torch\n", "from torch import nn\n", "from d2l import torch as d2l" ] }, { "cell_type": "markdown", "id": "5bb88dd8", "metadata": { "origin_pos": 6 }, "source": [ "## [**Self-Attention**]\n", "\n", "Given a sequence of input tokens\n", "$\\mathbf{x}_1, \\ldots, \\mathbf{x}_n$ where any $\\mathbf{x}_i \\in \\mathbb{R}^d$ ($1 \\leq i \\leq n$),\n", "its self-attention outputs\n", "a sequence of the same length\n", "$\\mathbf{y}_1, \\ldots, \\mathbf{y}_n$,\n", "where\n", "\n", "$$\\mathbf{y}_i = f(\\mathbf{x}_i, (\\mathbf{x}_1, \\mathbf{x}_1), \\ldots, (\\mathbf{x}_n, \\mathbf{x}_n)) \\in \\mathbb{R}^d$$\n", "\n", "according to the definition of attention pooling in\n", ":eqref:`eq_attention_pooling`.\n", "Using multi-head attention,\n", "the following code snippet\n", "computes the self-attention of a tensor\n", "with shape (batch size, number of time steps or sequence length in tokens, $d$).\n", "The output tensor has the same shape.\n" ] }, { "cell_type": "code", "execution_count": 2, "id": "13743b61", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T19:30:35.935527Z", "iopub.status.busy": "2023-08-18T19:30:35.934433Z", "iopub.status.idle": "2023-08-18T19:30:35.974177Z", "shell.execute_reply": "2023-08-18T19:30:35.973091Z" }, "origin_pos": 7, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "num_hiddens, num_heads = 100, 5\n", "attention = d2l.MultiHeadAttention(num_hiddens, num_heads, 0.5)\n", "batch_size, num_queries, valid_lens = 2, 4, torch.tensor([3, 2])\n", "X = torch.ones((batch_size, num_queries, num_hiddens))\n", "d2l.check_shape(attention(X, X, X, valid_lens),\n", " (batch_size, num_queries, num_hiddens))" ] }, { "cell_type": "markdown", "id": "41ae9a49", "metadata": { 
"origin_pos": 14 }, "source": [ "## Comparing CNNs, RNNs, and Self-Attention\n", ":label:`subsec_cnn-rnn-self-attention`\n", "\n", "Let's\n", "compare architectures for mapping\n", "a sequence of $n$ tokens\n", "to another one of equal length,\n", "where each input or output token is represented by\n", "a $d$-dimensional vector.\n", "Specifically,\n", "we will consider CNNs, RNNs, and self-attention.\n", "We will compare their\n", "computational complexity, \n", "sequential operations,\n", "and maximum path lengths.\n", "Note that sequential operations prevent parallel computation,\n", "while a shorter path between\n", "any combination of sequence positions\n", "makes it easier to learn long-range dependencies \n", "within the sequence :cite:`Hochreiter.Bengio.Frasconi.ea.2001`.\n", "\n", "\n", "![Comparing CNN (padding tokens are omitted), RNN, and self-attention architectures.](../img/cnn-rnn-self-attention.svg)\n", ":label:`fig_cnn-rnn-self-attention`\n", "\n", "\n", "\n", "Let's regard any text sequence as a \"one-dimensional image\". Similarly, one-dimensional CNNs can process local features such as $n$-grams in text.\n", "Given a sequence of length $n$,\n", "consider a convolutional layer whose kernel size is $k$,\n", "and whose numbers of input and output channels are both $d$.\n", "The computational complexity of the convolutional layer is $\\mathcal{O}(knd^2)$.\n", "As :numref:`fig_cnn-rnn-self-attention` shows,\n", "CNNs are hierarchical,\n", "so there are $\\mathcal{O}(1)$ sequential operations\n", "and the maximum path length is $\\mathcal{O}(n/k)$.\n", "For example, $\\mathbf{x}_1$ and $\\mathbf{x}_5$\n", "are within the receptive field of a two-layer CNN\n", "with kernel size 3 in :numref:`fig_cnn-rnn-self-attention`.\n", "\n", "When updating the hidden state of RNNs,\n", "multiplication of the $d \\times d$ weight matrix\n", "and the $d$-dimensional hidden state has \n", "a computational complexity of $\\mathcal{O}(d^2)$.\n", "Since the sequence length is $n$,\n", "the computational complexity of the recurrent layer\n", "is $\\mathcal{O}(nd^2)$.\n", "According to :numref:`fig_cnn-rnn-self-attention`,\n", "there are $\\mathcal{O}(n)$ sequential operations\n", "that cannot be parallelized\n", "and the maximum path length is also $\\mathcal{O}(n)$.\n", "\n", "In self-attention,\n", "the queries, keys, and values \n", "are all $n \\times d$ matrices.\n", "Consider the scaled dot product attention in\n", ":eqref:`eq_softmax_QK_V`,\n", "where an $n \\times d$ matrix is multiplied by\n", "a $d \\times n$ matrix,\n", "then the output $n \\times n$ matrix is multiplied\n", "by an $n \\times d$ matrix.\n", "As a result,\n", "the self-attention\n", "has a $\\mathcal{O}(n^2d)$ computational complexity.\n", "As we can see from :numref:`fig_cnn-rnn-self-attention`,\n", "each token is directly connected\n", "to any other token via self-attention.\n", "Therefore,\n", "computation can be parallel with $\\mathcal{O}(1)$ sequential operations\n", "and the maximum path length is also $\\mathcal{O}(1)$.\n", "\n", "All in all,\n", "both CNNs and self-attention enjoy parallel computation\n", "and self-attention has the shortest maximum path length.\n", "However, the quadratic computational complexity with respect to the sequence length\n", "makes self-attention prohibitively slow for very long sequences.\n", "\n", "\n", "\n", "\n", "\n", "## [**Positional Encoding**]\n", ":label:`subsec_positional-encoding`\n", "\n", "\n", "Unlike RNNs, which recurrently process\n", "tokens of a sequence 
one-by-one,\n", "self-attention ditches\n", "sequential operations in favor of \n", "parallel computation.\n", "Note that self-attention by itself\n", "does not preserve the order of the sequence. \n", "What do we do if it really matters \n", "that the model knows in which order\n", "the input sequence arrived?\n", "\n", "The dominant approach for preserving \n", "information about the order of tokens\n", "is to represent this to the model \n", "as an additional input associated \n", "with each token. \n", "These inputs are called *positional encodings*,\n", "and they can either be learned or fixed *a priori*.\n", "We now describe a simple scheme for fixed positional encodings\n", "based on sine and cosine functions :cite:`Vaswani.Shazeer.Parmar.ea.2017`.\n", "\n", "Suppose that the input representation \n", "$\\mathbf{X} \\in \\mathbb{R}^{n \\times d}$ \n", "contains the $d$-dimensional embeddings \n", "for $n$ tokens of a sequence.\n", "The positional encoding outputs\n", "$\\mathbf{X} + \\mathbf{P}$\n", "using a positional embedding matrix \n", "$\\mathbf{P} \\in \\mathbb{R}^{n \\times d}$ of the same shape,\n", "whose element on the $i^\\textrm{th}$ row \n", "and the $(2j)^\\textrm{th}$\n", "or the $(2j + 1)^\\textrm{th}$ column is\n", "\n", "$$\\begin{aligned} p_{i, 2j} &= \\sin\\left(\\frac{i}{10000^{2j/d}}\\right),\\\\p_{i, 2j+1} &= \\cos\\left(\\frac{i}{10000^{2j/d}}\\right).\\end{aligned}$$\n", ":eqlabel:`eq_positional-encoding-def`\n", "\n", "At first glance,\n", "this trigonometric function\n", "design looks weird.\n", "Before we give explanations of this design,\n", "let's first implement it in the following `PositionalEncoding` class.\n" ] }, { "cell_type": "code", "execution_count": 3, "id": "3eb1b5ef", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T19:30:35.979909Z", "iopub.status.busy": "2023-08-18T19:30:35.978770Z", "iopub.status.idle": "2023-08-18T19:30:35.987465Z", "shell.execute_reply": "2023-08-18T19:30:35.986155Z" }, "origin_pos": 16, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "class PositionalEncoding(nn.Module): #@save\n", " \"\"\"Positional encoding.\"\"\"\n", " def __init__(self, num_hiddens, dropout, max_len=1000):\n", " super().__init__()\n", " self.dropout = nn.Dropout(dropout)\n", " # Create a long enough P\n", " self.P = torch.zeros((1, max_len, num_hiddens))\n", " X = torch.arange(max_len, dtype=torch.float32).reshape(\n", " -1, 1) / torch.pow(10000, torch.arange(\n", " 0, num_hiddens, 2, dtype=torch.float32) / num_hiddens)\n", " self.P[:, :, 0::2] = torch.sin(X)\n", " self.P[:, :, 1::2] = torch.cos(X)\n", "\n", " def forward(self, X):\n", " X = X + self.P[:, :X.shape[1], :].to(X.device)\n", " return self.dropout(X)" ] }, { "cell_type": "markdown", "id": "17ad8db2", "metadata": { "origin_pos": 19 }, "source": [ "In the positional embedding matrix $\\mathbf{P}$,\n", "[**rows correspond to positions within a sequence\n", "and columns represent different positional encoding dimensions**].\n", "In the example below,\n", "we can see that\n", "the $6^{\\textrm{th}}$ and the $7^{\\textrm{th}}$\n", "columns of the positional embedding matrix \n", "have a higher frequency than \n", "the $8^{\\textrm{th}}$ and the $9^{\\textrm{th}}$\n", "columns.\n", "The offset between \n", "the $6^{\\textrm{th}}$ and the $7^{\\textrm{th}}$ (same for the $8^{\\textrm{th}}$ and the $9^{\\textrm{th}}$) columns\n", "is due to the alternation of sine and cosine functions.\n" ] }, { "cell_type": "code", "execution_count": 4, "id": "51320f4e", "metadata": { 
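"tab": [ "pytorch" ] }, "outputs": [], "source": [
 "# A quick numerical check before plotting (an illustrative sketch; it only\n",
 "# assumes the PositionalEncoding class defined above). Columns 2j and 2j+1\n",
 "# of P hold the sine and cosine of the same angle, so their squares sum to\n",
 "# one, and the angular frequency 1/10000^(2j/d) shrinks as j grows, which is\n",
 "# why columns 6 and 7 oscillate faster than columns 8 and 9.\n",
 "enc = PositionalEncoding(num_hiddens=32, dropout=0)\n",
 "P_check = enc.P[0]  # shape: (max_len, 32)\n",
 "assert torch.allclose(P_check[:, 6]**2 + P_check[:, 7]**2,\n",
 "                      torch.ones(P_check.shape[0]))\n",
 "omega = 1 / torch.pow(10000, torch.arange(0, 32, 2, dtype=torch.float32) / 32)\n",
 "print(f'frequency of columns (6, 7): {omega[3].item():.4f}, '\n",
 "      f'of columns (8, 9): {omega[4].item():.4f}')"
 ] }, { "cell_type": "code", "execution_count": 4, "id": "51320f4f", "metadata": {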
"execution": { "iopub.execute_input": "2023-08-18T19:30:35.991251Z", "iopub.status.busy": "2023-08-18T19:30:35.990632Z", "iopub.status.idle": "2023-08-18T19:30:36.368109Z", "shell.execute_reply": "2023-08-18T19:30:36.366973Z" }, "origin_pos": 21, "tab": [ "pytorch" ] }, "outputs": [ { "data": { "image/svg+xml": [ "\n", "\n", "\n", " \n", " \n", " \n", " \n", " 2023-08-18T19:30:36.288792\n", " image/svg+xml\n", " \n", " \n", " Matplotlib v3.7.2, https://matplotlib.org/\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n" ], "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "encoding_dim, num_steps = 32, 60\n", "pos_encoding = PositionalEncoding(encoding_dim, 0)\n", "X = pos_encoding(torch.zeros((1, num_steps, encoding_dim)))\n", "P = pos_encoding.P[:, :X.shape[1], :]\n", "d2l.plot(torch.arange(num_steps), P[0, :, 6:10].T, xlabel='Row (position)',\n", " figsize=(6, 2.5), legend=[\"Col %d\" % d for d in torch.arange(6, 10)])" ] }, { "cell_type": "markdown", "id": "811eb6d1", "metadata": { "origin_pos": 24 }, "source": [ "### Absolute Positional Information\n", "\n", "To see how the monotonically decreased frequency\n", "along the encoding dimension relates to absolute positional information,\n", "let's print out [**the binary representations**] of $0, 1, \\ldots, 7$.\n", "As we can see, the lowest bit, the second-lowest bit, \n", "and the third-lowest bit alternate on every number, \n", "every two numbers, and every four numbers, respectively.\n" ] }, { "cell_type": "code", "execution_count": 5, "id": "6f42d89b", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T19:30:36.373921Z", "iopub.status.busy": "2023-08-18T19:30:36.373258Z", "iopub.status.idle": "2023-08-18T19:30:36.380089Z", "shell.execute_reply": "2023-08-18T19:30:36.378862Z" }, "origin_pos": 25, "tab": [ "pytorch" ] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "0 in binary is 000\n", "1 in binary is 001\n", "2 in binary is 010\n", "3 in binary is 011\n", "4 in binary is 100\n", "5 in binary is 101\n", "6 in binary is 110\n", "7 in binary is 111\n" ] } ], "source": [ "for i in range(8):\n", " print(f'{i} in binary is {i:>03b}')" ] }, { "cell_type": "markdown", "id": "b617c79b", "metadata": { "origin_pos": 26 }, "source": [ "In binary representations, a higher bit \n", "has a lower frequency than a lower bit.\n", "Similarly, as demonstrated in the heat map below,\n", "[**the positional encoding decreases\n", "frequencies along the encoding dimension**]\n", "by using trigonometric functions.\n", "Since the outputs are float numbers,\n", "such continuous representations\n", "are more space-efficient\n", "than binary representations.\n" ] }, { "cell_type": "code", "execution_count": 6, "id": "c5f60f9f", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T19:30:36.384358Z", "iopub.status.busy": "2023-08-18T19:30:36.383531Z", "iopub.status.idle": "2023-08-18T19:30:36.858217Z", "shell.execute_reply": "2023-08-18T19:30:36.857049Z" }, "origin_pos": 28, "tab": [ "pytorch" ] }, "outputs": [ { "data": { "image/svg+xml": [ "\n", "\n", "\n", " \n", " \n", " \n", " \n", " 2023-08-18T19:30:36.784791\n", " image/svg+xml\n", " \n", " \n", " Matplotlib v3.7.2, https://matplotlib.org/\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " 
\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n" ], "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "P = P[0, :, :].unsqueeze(0).unsqueeze(0)\n", "d2l.show_heatmaps(P, xlabel='Column (encoding dimension)',\n", " ylabel='Row (position)', figsize=(3.5, 4), cmap='Blues')" ] }, { "cell_type": "markdown", "id": "6cb8a898", "metadata": { "origin_pos": 31 }, "source": [ "### Relative Positional Information\n", "\n", "Besides capturing absolute positional information,\n", "the above positional encoding\n", "also allows\n", "a model to easily learn to attend by relative positions.\n", "This is because\n", "for any fixed position offset $\\delta$,\n", "the positional encoding at position $i + \\delta$\n", "can be represented by a linear projection\n", "of that at position $i$.\n", "\n", "\n", "This projection can be explained\n", "mathematically.\n", "Denoting\n", "$\\omega_j = 1/10000^{2j/d}$,\n", "any pair of $(p_{i, 2j}, p_{i, 2j+1})$ \n", "in :eqref:`eq_positional-encoding-def`\n", "can \n", "be linearly projected to $(p_{i+\\delta, 2j}, p_{i+\\delta, 2j+1})$\n", "for any fixed offset $\\delta$:\n", "\n", "$$\\begin{aligned}\n", "\\begin{bmatrix} \\cos(\\delta \\omega_j) & \\sin(\\delta \\omega_j) \\\\ -\\sin(\\delta \\omega_j) & \\cos(\\delta \\omega_j) \\\\ \\end{bmatrix}\n", "\\begin{bmatrix} p_{i, 2j} \\\\ p_{i, 2j+1} \\\\ \\end{bmatrix}\n", "=&\\begin{bmatrix} \\cos(\\delta \\omega_j) \\sin(i \\omega_j) + \\sin(\\delta \\omega_j) \\cos(i \\omega_j) \\\\ -\\sin(\\delta \\omega_j) \\sin(i \\omega_j) + \\cos(\\delta \\omega_j) \\cos(i \\omega_j) \\\\ \\end{bmatrix}\\\\\n", "=&\\begin{bmatrix} \\sin\\left((i+\\delta) \\omega_j\\right) \\\\ \\cos\\left((i+\\delta) \\omega_j\\right) \\\\ \\end{bmatrix}\\\\\n", "=& \n", "\\begin{bmatrix} p_{i+\\delta, 2j} \\\\ p_{i+\\delta, 2j+1} \\\\ \\end{bmatrix},\n", "\\end{aligned}$$\n", "\n", "where the $2\\times 2$ projection matrix does not depend on any position index $i$.\n", "\n", "## Summary\n", "\n", "In self-attention, the queries, keys, and values all come from the same place.\n", "Both CNNs and self-attention enjoy parallel computation\n", "and self-attention has the shortest maximum path length.\n", "However, the quadratic computational complexity\n", "with respect to the sequence length\n", "makes self-attention prohibitively slow\n", "for very long sequences.\n", "To use the sequence order information, \n", "we can inject absolute or relative positional information \n", "by adding positional encoding to the input representations.\n", "\n", "## Exercises\n", "\n", "1. Suppose that we design a deep architecture to represent a sequence by stacking self-attention layers with positional encoding. What could the possible issues be?\n", "1. Can you design a learnable positional encoding method?\n", "1. Can we assign different learned embeddings according to different offsets between queries and keys that are compared in self-attention? Hint: you may refer to relative position embeddings :cite:`shaw2018self,huang2018music`.\n" ] }, { "cell_type": "markdown", "id": "2cee2bcf", "metadata": { "origin_pos": 33, "tab": [ "pytorch" ] }, "source": [ "[Discussions](https://discuss.d2l.ai/t/1652)\n" ] } ], "metadata": { "language_info": { "name": "python" }, "required_libs": [] }, "nbformat": 4, "nbformat_minor": 5 }