{ "cells": [ { "cell_type": "markdown", "id": "565b35c4", "metadata": { "origin_pos": 0 }, "source": [ "# Long Short-Term Memory (LSTM)\n", ":label:`sec_lstm`\n", "\n", "\n", "Shortly after the first Elman-style RNNs were trained using backpropagation \n", ":cite:`elman1990finding`, the problems of learning long-term dependencies\n", "(owing to vanishing and exploding gradients)\n", "became salient, with Bengio and Hochreiter \n", "discussing the problem\n", ":cite:`bengio1994learning,Hochreiter.Bengio.Frasconi.ea.2001`.\n", "Hochreiter had articulated this problem as early \n", "as 1991 in his Master's thesis, although the results \n", "were not widely known because the thesis was written in German.\n", "While gradient clipping helps with exploding gradients, \n", "handling vanishing gradients appears \n", "to require a more elaborate solution. \n", "One of the first and most successful techniques \n", "for addressing vanishing gradients \n", "came in the form of the long short-term memory (LSTM) model \n", "due to :citet:`Hochreiter.Schmidhuber.1997`. \n", "LSTMs resemble standard recurrent neural networks \n", "but here each ordinary recurrent node\n", "is replaced by a *memory cell*.\n", "Each memory cell contains an *internal state*,\n", "i.e., a node with a self-connected recurrent edge of fixed weight 1,\n", "ensuring that the gradient can pass across many time steps \n", "without vanishing or exploding.\n", "\n", "The term \"long short-term memory\" comes from the following intuition.\n", "Simple recurrent neural networks \n", "have *long-term memory* in the form of weights.\n", "The weights change slowly during training, \n", "encoding general knowledge about the data.\n", "They also have *short-term memory*\n", "in the form of ephemeral activations,\n", "which pass from each node to successive nodes.\n", "The LSTM model introduces an intermediate type of storage via the memory cell.\n", "A memory cell is a composite unit, \n", "built from simpler nodes \n", "in a specific connectivity pattern,\n", "with the novel inclusion of multiplicative nodes.\n" ] }, { "cell_type": "code", "execution_count": 1, "id": "af24541f", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T19:53:39.252438Z", "iopub.status.busy": "2023-08-18T19:53:39.251775Z", "iopub.status.idle": "2023-08-18T19:53:42.244563Z", "shell.execute_reply": "2023-08-18T19:53:42.243546Z" }, "origin_pos": 3, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "import torch\n", "from torch import nn\n", "from d2l import torch as d2l" ] }, { "cell_type": "markdown", "id": "4547196c", "metadata": { "origin_pos": 6 }, "source": [ "## Gated Memory Cell\n", "\n", "Each memory cell is equipped with an *internal state*\n", "and a number of multiplicative gates that determine whether\n", "(i) a given input should impact the internal state (the *input gate*),\n", "(ii) the internal state should be flushed to $0$ (the *forget gate*),\n", "and (iii) the internal state of a given neuron \n", "should be allowed to impact the cell's output (the *output* gate). 
\n", "\n", "\n", "### Gated Hidden State\n", "\n", "The key distinction between vanilla RNNs and LSTMs\n", "is that the latter support gating of the hidden state.\n", "This means that we have dedicated mechanisms for\n", "when a hidden state should be *updated* and\n", "also for when it should be *reset*.\n", "These mechanisms are learned and they address the concerns listed above.\n", "For instance, if the first token is of great importance\n", "we will learn not to update the hidden state after the first observation.\n", "Likewise, we will learn to skip irrelevant temporary observations.\n", "Last, we will learn to reset the latent state whenever needed.\n", "We discuss this in detail below.\n", "\n", "### Input Gate, Forget Gate, and Output Gate\n", "\n", "The data feeding into the LSTM gates are\n", "the input at the current time step and\n", "the hidden state of the previous time step,\n", "as illustrated in :numref:`fig_lstm_0`.\n", "Three fully connected layers with sigmoid activation functions\n", "compute the values of the input, forget, and output gates.\n", "As a result of the sigmoid activation,\n", "all values of the three gates\n", "are in the range of $(0, 1)$.\n", "Additionally, we require an *input node*,\n", "typically computed with a *tanh* activation function. \n", "Intuitively, the *input gate* determines how much\n", "of the input node's value should be added \n", "to the current memory cell internal state.\n", "The *forget gate* determines whether to keep\n", "the current value of the memory or flush it. \n", "And the *output gate* determines whether \n", "the memory cell should influence the output\n", "at the current time step. \n", "\n", "\n", "![Computing the input gate, the forget gate, and the output gate in an LSTM model.](../img/lstm-0.svg)\n", ":label:`fig_lstm_0`\n", "\n", "Mathematically, suppose that there are $h$ hidden units, \n", "the batch size is $n$, and the number of inputs is $d$.\n", "Thus, the input is $\\mathbf{X}_t \\in \\mathbb{R}^{n \\times d}$ \n", "and the hidden state of the previous time step \n", "is $\\mathbf{H}_{t-1} \\in \\mathbb{R}^{n \\times h}$. \n", "Correspondingly, the gates at time step $t$\n", "are defined as follows: the input gate is $\\mathbf{I}_t \\in \\mathbb{R}^{n \\times h}$, \n", "the forget gate is $\\mathbf{F}_t \\in \\mathbb{R}^{n \\times h}$, \n", "and the output gate is $\\mathbf{O}_t \\in \\mathbb{R}^{n \\times h}$. 
\n", "They are calculated as follows:\n", "\n", "$$\n", "\\begin{aligned}\n", "\\mathbf{I}_t &= \\sigma(\\mathbf{X}_t \\mathbf{W}_{\\textrm{xi}} + \\mathbf{H}_{t-1} \\mathbf{W}_{\\textrm{hi}} + \\mathbf{b}_\\textrm{i}),\\\\\n", "\\mathbf{F}_t &= \\sigma(\\mathbf{X}_t \\mathbf{W}_{\\textrm{xf}} + \\mathbf{H}_{t-1} \\mathbf{W}_{\\textrm{hf}} + \\mathbf{b}_\\textrm{f}),\\\\\n", "\\mathbf{O}_t &= \\sigma(\\mathbf{X}_t \\mathbf{W}_{\\textrm{xo}} + \\mathbf{H}_{t-1} \\mathbf{W}_{\\textrm{ho}} + \\mathbf{b}_\\textrm{o}),\n", "\\end{aligned}\n", "$$\n", "\n", "where $\\mathbf{W}_{\\textrm{xi}}, \\mathbf{W}_{\\textrm{xf}}, \\mathbf{W}_{\\textrm{xo}} \\in \\mathbb{R}^{d \\times h}$ and $\\mathbf{W}_{\\textrm{hi}}, \\mathbf{W}_{\\textrm{hf}}, \\mathbf{W}_{\\textrm{ho}} \\in \\mathbb{R}^{h \\times h}$ are weight parameters \n", "and $\\mathbf{b}_\\textrm{i}, \\mathbf{b}_\\textrm{f}, \\mathbf{b}_\\textrm{o} \\in \\mathbb{R}^{1 \\times h}$ are bias parameters.\n", "Note that broadcasting \n", "(see :numref:`subsec_broadcasting`)\n", "is triggered during the summation.\n", "We use sigmoid functions \n", "(as introduced in :numref:`sec_mlp`) \n", "to map the input values to the interval $(0, 1)$.\n", "\n", "\n", "### Input Node\n", "\n", "Next we design the memory cell. \n", "Since we have not specified the action of the various gates yet, \n", "we first introduce the *input node* \n", "$\\tilde{\\mathbf{C}}_t \\in \\mathbb{R}^{n \\times h}$.\n", "Its computation is similar to that of the three gates described above, \n", "but uses a $\\tanh$ function with a value range for $(-1, 1)$ as the activation function. \n", "This leads to the following equation at time step $t$:\n", "\n", "$$\\tilde{\\mathbf{C}}_t = \\textrm{tanh}(\\mathbf{X}_t \\mathbf{W}_{\\textrm{xc}} + \\mathbf{H}_{t-1} \\mathbf{W}_{\\textrm{hc}} + \\mathbf{b}_\\textrm{c}),$$\n", "\n", "where $\\mathbf{W}_{\\textrm{xc}} \\in \\mathbb{R}^{d \\times h}$ and $\\mathbf{W}_{\\textrm{hc}} \\in \\mathbb{R}^{h \\times h}$ are weight parameters and $\\mathbf{b}_\\textrm{c} \\in \\mathbb{R}^{1 \\times h}$ is a bias parameter.\n", "\n", "A quick illustration of the input node is shown in :numref:`fig_lstm_1`.\n", "\n", "![Computing the input node in an LSTM model.](../img/lstm-1.svg)\n", ":label:`fig_lstm_1`\n", "\n", "\n", "### Memory Cell Internal State\n", "\n", "In LSTMs, the input gate $\\mathbf{I}_t$ governs \n", "how much we take new data into account via $\\tilde{\\mathbf{C}}_t$ \n", "and the forget gate $\\mathbf{F}_t$ addresses \n", "how much of the old cell internal state $\\mathbf{C}_{t-1} \\in \\mathbb{R}^{n \\times h}$ we retain. \n", "Using the Hadamard (elementwise) product operator $\\odot$\n", "we arrive at the following update equation:\n", "\n", "$$\\mathbf{C}_t = \\mathbf{F}_t \\odot \\mathbf{C}_{t-1} + \\mathbf{I}_t \\odot \\tilde{\\mathbf{C}}_t.$$\n", "\n", "If the forget gate is always 1 and the input gate is always 0, \n", "the memory cell internal state $\\mathbf{C}_{t-1}$\n", "will remain constant forever, \n", "passing unchanged to each subsequent time step.\n", "However, input gates and forget gates\n", "give the model the flexibility of being able to learn \n", "when to keep this value unchanged\n", "and when to perturb it in response \n", "to subsequent inputs. \n", "In practice, this design alleviates the vanishing gradient problem,\n", "resulting in models that are much easier to train,\n", "especially when facing datasets with long sequence lengths. 
\n", "\n", "We thus arrive at the flow diagram in :numref:`fig_lstm_2`.\n", "\n", "![Computing the memory cell internal state in an LSTM model.](../img/lstm-2.svg)\n", "\n", ":label:`fig_lstm_2`\n", "\n", "\n", "### Hidden State\n", "\n", "Last, we need to define how to compute the output\n", "of the memory cell, i.e., the hidden state $\\mathbf{H}_t \\in \\mathbb{R}^{n \\times h}$, as seen by other layers. \n", "This is where the output gate comes into play.\n", "In LSTMs, we first apply $\\tanh$ to the memory cell internal state\n", "and then apply another point-wise multiplication,\n", "this time with the output gate.\n", "This ensures that the values of $\\mathbf{H}_t$ \n", "are always in the interval $(-1, 1)$:\n", "\n", "$$\\mathbf{H}_t = \\mathbf{O}_t \\odot \\tanh(\\mathbf{C}_t).$$\n", "\n", "\n", "Whenever the output gate is close to 1, \n", "we allow the memory cell internal state to impact the subsequent layers uninhibited,\n", "whereas for output gate values close to 0,\n", "we prevent the current memory from impacting other layers of the network\n", "at the current time step. \n", "Note that a memory cell can accrue information \n", "across many time steps without impacting the rest of the network\n", "(as long as the output gate takes values close to 0),\n", "and then suddenly impact the network at a subsequent time step\n", "as soon as the output gate flips from values close to 0\n", "to values close to 1. :numref:`fig_lstm_3` has a graphical illustration of the data flow.\n", "\n", "![Computing the hidden state in an LSTM model.](../img/lstm-3.svg)\n", ":label:`fig_lstm_3`\n", "\n", "\n", "\n", "## Implementation from Scratch\n", "\n", "Now let's implement an LSTM from scratch.\n", "As same as the experiments in :numref:`sec_rnn-scratch`,\n", "we first load *The Time Machine* dataset.\n", "\n", "### [**Initializing Model Parameters**]\n", "\n", "Next, we need to define and initialize the model parameters. \n", "As previously, the hyperparameter `num_hiddens` \n", "dictates the number of hidden units.\n", "We initialize weights following a Gaussian distribution\n", "with 0.01 standard deviation, \n", "and we set the biases to 0.\n" ] }, { "cell_type": "code", "execution_count": 2, "id": "344044a5", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T19:53:42.249641Z", "iopub.status.busy": "2023-08-18T19:53:42.248681Z", "iopub.status.idle": "2023-08-18T19:53:42.259080Z", "shell.execute_reply": "2023-08-18T19:53:42.257966Z" }, "origin_pos": 7, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "class LSTMScratch(d2l.Module):\n", " def __init__(self, num_inputs, num_hiddens, sigma=0.01):\n", " super().__init__()\n", " self.save_hyperparameters()\n", "\n", " init_weight = lambda *shape: nn.Parameter(torch.randn(*shape) * sigma)\n", " triple = lambda: (init_weight(num_inputs, num_hiddens),\n", " init_weight(num_hiddens, num_hiddens),\n", " nn.Parameter(torch.zeros(num_hiddens)))\n", " self.W_xi, self.W_hi, self.b_i = triple() # Input gate\n", " self.W_xf, self.W_hf, self.b_f = triple() # Forget gate\n", " self.W_xo, self.W_ho, self.b_o = triple() # Output gate\n", " self.W_xc, self.W_hc, self.b_c = triple() # Input node" ] }, { "cell_type": "markdown", "id": "7be7fb35", "metadata": { "origin_pos": 9, "tab": [ "pytorch" ] }, "source": [ "[**The actual model**] is defined as described above,\n", "consisting of three gates and an input node. 
\n", "Note that only the hidden state is passed to the output layer.\n" ] }, { "cell_type": "code", "execution_count": 3, "id": "3284d4fa", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T19:53:42.263023Z", "iopub.status.busy": "2023-08-18T19:53:42.262354Z", "iopub.status.idle": "2023-08-18T19:53:42.269844Z", "shell.execute_reply": "2023-08-18T19:53:42.269034Z" }, "origin_pos": 11, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "@d2l.add_to_class(LSTMScratch)\n", "def forward(self, inputs, H_C=None):\n", " if H_C is None:\n", " # Initial state with shape: (batch_size, num_hiddens)\n", " H = torch.zeros((inputs.shape[1], self.num_hiddens),\n", " device=inputs.device)\n", " C = torch.zeros((inputs.shape[1], self.num_hiddens),\n", " device=inputs.device)\n", " else:\n", " H, C = H_C\n", " outputs = []\n", " for X in inputs:\n", " I = torch.sigmoid(torch.matmul(X, self.W_xi) +\n", " torch.matmul(H, self.W_hi) + self.b_i)\n", " F = torch.sigmoid(torch.matmul(X, self.W_xf) +\n", " torch.matmul(H, self.W_hf) + self.b_f)\n", " O = torch.sigmoid(torch.matmul(X, self.W_xo) +\n", " torch.matmul(H, self.W_ho) + self.b_o)\n", " C_tilde = torch.tanh(torch.matmul(X, self.W_xc) +\n", " torch.matmul(H, self.W_hc) + self.b_c)\n", " C = F * C + I * C_tilde\n", " H = O * torch.tanh(C)\n", " outputs.append(H)\n", " return outputs, (H, C)" ] }, { "cell_type": "markdown", "id": "8c83e2d3", "metadata": { "origin_pos": 13 }, "source": [ "### [**Training**] and Prediction\n", "\n", "Let's train an LSTM model by instantiating the `RNNLMScratch` class from :numref:`sec_rnn-scratch`.\n" ] }, { "cell_type": "code", "execution_count": 4, "id": "3c605094", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T19:53:42.273652Z", "iopub.status.busy": "2023-08-18T19:53:42.273097Z", "iopub.status.idle": "2023-08-18T19:55:28.400186Z", "shell.execute_reply": "2023-08-18T19:55:28.399180Z" }, "origin_pos": 14, "tab": [ "pytorch" ] }, "outputs": [ { "data": { "image/svg+xml": [ "\n", "\n", "\n", " \n", " \n", " \n", " \n", " 2023-08-18T19:55:28.211147\n", " image/svg+xml\n", " \n", " \n", " Matplotlib v3.7.2, https://matplotlib.org/\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " 
\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", 
" \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n" ], "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "data = d2l.TimeMachine(batch_size=1024, num_steps=32)\n", "lstm = LSTMScratch(num_inputs=len(data.vocab), num_hiddens=32)\n", "model = d2l.RNNLMScratch(lstm, vocab_size=len(data.vocab), lr=4)\n", "trainer = d2l.Trainer(max_epochs=50, gradient_clip_val=1, num_gpus=1)\n", "trainer.fit(model, data)" ] }, { "cell_type": "markdown", "id": "1ca6b8b1", "metadata": { "origin_pos": 15 }, "source": [ "## [**Concise Implementation**]\n", "\n", "Using high-level APIs,\n", "we can directly instantiate an LSTM model.\n", "This encapsulates all the configuration details \n", "that we made explicit above. \n", "The code is significantly faster as it uses \n", "compiled operators rather than Python\n", "for many details that we spelled out before.\n" ] }, { "cell_type": "code", "execution_count": 5, "id": "915b335f", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T19:55:28.408950Z", "iopub.status.busy": "2023-08-18T19:55:28.408640Z", "iopub.status.idle": "2023-08-18T19:55:28.413951Z", "shell.execute_reply": "2023-08-18T19:55:28.413093Z" }, "origin_pos": 17, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "class LSTM(d2l.RNN):\n", " def __init__(self, num_inputs, num_hiddens):\n", " d2l.Module.__init__(self)\n", " self.save_hyperparameters()\n", " self.rnn = nn.LSTM(num_inputs, num_hiddens)\n", "\n", " def forward(self, inputs, H_C=None):\n", " return self.rnn(inputs, H_C)" ] }, { "cell_type": "code", "execution_count": 6, "id": "af084fc2", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T19:55:28.417830Z", "iopub.status.busy": "2023-08-18T19:55:28.417111Z", "iopub.status.idle": "2023-08-18T19:56:41.586045Z", "shell.execute_reply": "2023-08-18T19:56:41.585125Z" }, "origin_pos": 20, "tab": [ "pytorch" ] }, "outputs": [ { "data": { "image/svg+xml": [ "\n", "\n", "\n", " \n", " \n", " \n", " \n", " 2023-08-18T19:56:41.375391\n", " image/svg+xml\n", " \n", " \n", " Matplotlib v3.7.2, https://matplotlib.org/\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", 
" \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " 
\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n" ], "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "lstm = LSTM(num_inputs=len(data.vocab), num_hiddens=32)\n", "model = d2l.RNNLM(lstm, vocab_size=len(data.vocab), lr=4)\n", "trainer.fit(model, data)" ] }, { "cell_type": "code", "execution_count": 7, "id": "40007e27", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T19:56:41.599247Z", "iopub.status.busy": "2023-08-18T19:56:41.598314Z", "iopub.status.idle": "2023-08-18T19:56:41.638040Z", "shell.execute_reply": "2023-08-18T19:56:41.636677Z" }, "origin_pos": 21, "tab": [ "pytorch" ] }, "outputs": [ { "data": { "text/plain": [ "'it has a the time travelly'" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "model.predict('it has', 20, data.vocab, d2l.try_gpu())" ] }, { "cell_type": "markdown", "id": "d4ede917", "metadata": { "origin_pos": 24 }, "source": [ "LSTMs are the prototypical latent variable autoregressive model with nontrivial state control.\n", "Many variants thereof have been proposed over the years, e.g., multiple layers, residual connections, different types of regularization. However, training LSTMs and other sequence models (such as GRUs) is quite costly because of the long range dependency of the sequence.\n", "Later we will encounter alternative models such as Transformers that can be used in some cases.\n", "\n", "\n", "## Summary\n", "\n", "While LSTMs were published in 1997, \n", "they rose to great prominence \n", "with some victories in prediction competitions in the mid-2000s,\n", "and became the dominant models for sequence learning from 2011 \n", "until the rise of Transformer models, starting in 2017.\n", "Even Tranformers owe some of their key ideas \n", "to architecture design innovations introduced by the LSTM.\n", "\n", "\n", "LSTMs have three types of gates: \n", "input gates, forget gates, and output gates \n", "that control the flow of information.\n", "The hidden layer output of LSTM includes the hidden state and the memory cell internal state. \n", "Only the hidden state is passed into the output layer while \n", "the memory cell internal state remains entirely internal.\n", "LSTMs can alleviate vanishing and exploding gradients.\n", "\n", "\n", "\n", "## Exercises\n", "\n", "1. Adjust the hyperparameters and analyze their influence on running time, perplexity, and the output sequence.\n", "1. How would you need to change the model to generate proper words rather than just sequences of characters?\n", "1. Compare the computational cost for GRUs, LSTMs, and regular RNNs for a given hidden dimension. Pay special attention to the training and inference cost.\n", "1. Since the candidate memory cell ensures that the value range is between $-1$ and $1$ by using the $\\tanh$ function, why does the hidden state need to use the $\\tanh$ function again to ensure that the output value range is between $-1$ and $1$?\n", "1. Implement an LSTM model for time series prediction rather than character sequence prediction.\n" ] }, { "cell_type": "markdown", "id": "3dba411a", "metadata": { "origin_pos": 26, "tab": [ "pytorch" ] }, "source": [ "[Discussions](https://discuss.d2l.ai/t/1057)\n" ] } ], "metadata": { "language_info": { "name": "python" }, "required_libs": [] }, "nbformat": 4, "nbformat_minor": 5 }