{
"cells": [
{
"cell_type": "markdown",
"id": "201a0938",
"metadata": {
"origin_pos": 0
},
"source": [
"# Gated Recurrent Units (GRU)\n",
":label:`sec_gru`\n",
"\n",
"\n",
"As RNNs and particularly the LSTM architecture (:numref:`sec_lstm`)\n",
"rapidly gained popularity during the 2010s,\n",
"a number of researchers began to experiment \n",
"with simplified architectures in hopes \n",
"of retaining the key idea of incorporating\n",
"an internal state and multiplicative gating mechanisms\n",
"but with the aim of speeding up computation.\n",
"The gated recurrent unit (GRU) :cite:`Cho.Van-Merrienboer.Bahdanau.ea.2014` \n",
"offered a streamlined version of the LSTM memory cell\n",
"that often achieves comparable performance\n",
"but with the advantage of being faster \n",
"to compute :cite:`Chung.Gulcehre.Cho.ea.2014`.\n"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "6851ec0b",
"metadata": {
"attributes": {
"classes": [],
"id": "",
"n": "7"
},
"execution": {
"iopub.execute_input": "2023-08-18T19:50:04.809302Z",
"iopub.status.busy": "2023-08-18T19:50:04.808778Z",
"iopub.status.idle": "2023-08-18T19:50:07.808414Z",
"shell.execute_reply": "2023-08-18T19:50:07.807417Z"
},
"origin_pos": 3,
"tab": [
"pytorch"
]
},
"outputs": [],
"source": [
"import torch\n",
"from torch import nn\n",
"from d2l import torch as d2l"
]
},
{
"cell_type": "markdown",
"id": "9d895702",
"metadata": {
"origin_pos": 6
},
"source": [
"## Reset Gate and Update Gate\n",
"\n",
"Here, the LSTM's three gates are replaced by two:\n",
"the *reset gate* and the *update gate*.\n",
"As with LSTMs, these gates are given sigmoid activations,\n",
"forcing their values to lie in the interval $(0, 1)$.\n",
"Intuitively, the reset gate controls how much of the previous state\n",
"we might still want to remember,\n",
"while the update gate controls\n",
"how much of the new state is just a copy of the old one.\n",
":numref:`fig_gru_1` illustrates the inputs for both\n",
"the reset and update gates in a GRU, \n",
"given the input of the current time step\n",
"and the hidden state of the previous time step.\n",
"The outputs of the gates are given \n",
"by two fully connected layers\n",
"with a sigmoid activation function.\n",
"\n",
"\n",
":label:`fig_gru_1`\n",
"\n",
"Mathematically, for a given time step $t$,\n",
"suppose that the input is a minibatch\n",
"$\\mathbf{X}_t \\in \\mathbb{R}^{n \\times d}$ \n",
"(number of examples $=n$; number of inputs $=d$)\n",
"and the hidden state of the previous time step \n",
"is $\\mathbf{H}_{t-1} \\in \\mathbb{R}^{n \\times h}$ \n",
"(number of hidden units $=h$). \n",
"Then the reset gate $\\mathbf{R}_t \\in \\mathbb{R}^{n \\times h}$ \n",
"and update gate $\\mathbf{Z}_t \\in \\mathbb{R}^{n \\times h}$ are computed as follows:\n",
"\n",
"$$\n",
"\\begin{aligned}\n",
"\\mathbf{R}_t = \\sigma(\\mathbf{X}_t \\mathbf{W}_{\\textrm{xr}} + \\mathbf{H}_{t-1} \\mathbf{W}_{\\textrm{hr}} + \\mathbf{b}_\\textrm{r}),\\\\\n",
"\\mathbf{Z}_t = \\sigma(\\mathbf{X}_t \\mathbf{W}_{\\textrm{xz}} + \\mathbf{H}_{t-1} \\mathbf{W}_{\\textrm{hz}} + \\mathbf{b}_\\textrm{z}),\n",
"\\end{aligned}\n",
"$$\n",
"\n",
"where $\\mathbf{W}_{\\textrm{xr}}, \\mathbf{W}_{\\textrm{xz}} \\in \\mathbb{R}^{d \\times h}$ \n",
"and $\\mathbf{W}_{\\textrm{hr}}, \\mathbf{W}_{\\textrm{hz}} \\in \\mathbb{R}^{h \\times h}$ \n",
"are weight parameters and $\\mathbf{b}_\\textrm{r}, \\mathbf{b}_\\textrm{z} \\in \\mathbb{R}^{1 \\times h}$ \n",
"are bias parameters.\n",
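"\n",
"The following is a minimal PyTorch sketch of the two gate equations for a single time step. It makes the shapes concrete.\n",
"The sizes `n`, `d`, `h` and the randomly drawn weights are arbitrary choices for illustration, not trained parameters.\n",
"Because the biases have shape $1 \\times h$, broadcasting adds them to every row of the minibatch.\n",
"\n",
"```python\n",
"import torch\n",
"\n",
"torch.manual_seed(0)\n",
"n, d, h = 2, 5, 3  # batch size, number of inputs, number of hidden units (arbitrary)\n",
"\n",
"X = torch.randn(n, d)       # X_t: minibatch input at time step t\n",
"H_prev = torch.randn(n, h)  # H_{t-1}: hidden state of the previous time step\n",
"\n",
"# Illustrative random parameters with the shapes given above\n",
"W_xr, W_xz = torch.randn(d, h), torch.randn(d, h)\n",
"W_hr, W_hz = torch.randn(h, h), torch.randn(h, h)\n",
"b_r, b_z = torch.zeros(1, h), torch.zeros(1, h)  # broadcast over the n rows\n",
"\n",
"R = torch.sigmoid(X @ W_xr + H_prev @ W_hr + b_r)  # reset gate R_t\n",
"Z = torch.sigmoid(X @ W_xz + H_prev @ W_hz + b_z)  # update gate Z_t\n",
"\n",
"print(R.shape, Z.shape)  # torch.Size([2, 3]) torch.Size([2, 3])\n",
"print(bool(((R > 0) & (R < 1)).all()))  # sigmoid keeps every entry in (0, 1)\n",
"```\n",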
"\n",
"\n",
"## Candidate Hidden State\n",
"\n",
"Next, we integrate the reset gate $\\mathbf{R}_t$ \n",
"with the regular updating mechanism\n",
"in :eqref:`rnn_h_with_state`,\n",
"leading to the following\n",
"*candidate hidden state*\n",
"$\\tilde{\\mathbf{H}}_t \\in \\mathbb{R}^{n \\times h}$ at time step $t$:\n",
"\n",
"$$\\tilde{\\mathbf{H}}_t = \\tanh(\\mathbf{X}_t \\mathbf{W}_{\\textrm{xh}} + \\left(\\mathbf{R}_t \\odot \\mathbf{H}_{t-1}\\right) \\mathbf{W}_{\\textrm{hh}} + \\mathbf{b}_\\textrm{h}),$$\n",
":eqlabel:`gru_tilde_H`\n",
"\n",
"where $\\mathbf{W}_{\\textrm{xh}} \\in \\mathbb{R}^{d \\times h}$ and $\\mathbf{W}_{\\textrm{hh}} \\in \\mathbb{R}^{h \\times h}$\n",
"are weight parameters,\n",
"$\\mathbf{b}_\\textrm{h} \\in \\mathbb{R}^{1 \\times h}$\n",
"is the bias,\n",
"and the symbol $\\odot$ is the Hadamard (elementwise) product operator.\n",
"Here we use a tanh activation function to ensure that the values in the candidate hidden state remain in the interval $(-1, 1)$.\n",
"\n",
"The result is a *candidate*, since we still need \n",
"to incorporate the action of the update gate.\n",
"Comparing with :eqref:`rnn_h_with_state`,\n",
"the influence of the previous states\n",
"can now be reduced with the\n",
"elementwise multiplication of\n",
"$\\mathbf{R}_t$ and $\\mathbf{H}_{t-1}$\n",
"in :eqref:`gru_tilde_H`.\n",
"Whenever the entries in the reset gate $\\mathbf{R}_t$ are close to 1, \n",
"we recover a vanilla RNN such as that in :eqref:`rnn_h_with_state`.\n",
"For all entries of the reset gate $\\mathbf{R}_t$ that are close to 0, \n",
"the candidate hidden state is the result of an MLP with $\\mathbf{X}_t$ as input. \n",
"Any pre-existing hidden state is thus *reset* to defaults.\n",
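"\n",
"The two extreme cases of the reset gate can be checked numerically.\n",
"The sketch below uses illustrative random parameters and arbitrary sizes. It evaluates :eqref:`gru_tilde_H` once with an all-ones reset gate and once with an all-zeros reset gate.\n",
"It then compares the results with a vanilla RNN update and with an MLP applied to $\\mathbf{X}_t$ alone. The helper name `candidate` is hypothetical.\n",
"\n",
"```python\n",
"import torch\n",
"\n",
"torch.manual_seed(0)\n",
"n, d, h = 2, 5, 3  # arbitrary sizes\n",
"X = torch.randn(n, d)       # X_t\n",
"H_prev = torch.randn(n, h)  # H_{t-1}\n",
"W_xh, W_hh, b_h = torch.randn(d, h), torch.randn(h, h), torch.zeros(1, h)\n",
"\n",
"def candidate(R):\n",
"    # Candidate hidden state: tanh(X_t W_xh + (R_t * H_{t-1}) W_hh + b_h)\n",
"    return torch.tanh(X @ W_xh + (R * H_prev) @ W_hh + b_h)\n",
"\n",
"# Reset gate fully open (all ones): identical to the vanilla RNN update\n",
"H_rnn = torch.tanh(X @ W_xh + H_prev @ W_hh + b_h)\n",
"print(torch.allclose(candidate(torch.ones(n, h)), H_rnn))  # True\n",
"\n",
"# Reset gate fully closed (all zeros): H_{t-1} drops out, leaving an MLP of X_t\n",
"H_mlp = torch.tanh(X @ W_xh + b_h)\n",
"print(torch.allclose(candidate(torch.zeros(n, h)), H_mlp))  # True\n",
"```\n",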
"\n",
":numref:`fig_gru_2` illustrates the computational flow after applying the reset gate.\n",
"\n",
"\n",
":label:`fig_gru_2`\n",
"\n",
"\n",
"## Hidden State\n",
"\n",
"Finally, we need to incorporate the effect of the update gate $\\mathbf{Z}_t$.\n",
"This determines the extent to which the new hidden state $\\mathbf{H}_t \\in \\mathbb{R}^{n \\times h}$ \n",
"matches the old state $\\mathbf{H}_{t-1}$ compared with how much \n",
"it resembles the new candidate state $\\tilde{\\mathbf{H}}_t$.\n",
"The update gate $\\mathbf{Z}_t$ can be used for this purpose, \n",
"simply by taking elementwise convex combinations \n",
"of $\\mathbf{H}_{t-1}$ and $\\tilde{\\mathbf{H}}_t$.\n",
"This leads to the final update equation for the GRU:\n",
"\n",
"$$\\mathbf{H}_t = \\mathbf{Z}_t \\odot \\mathbf{H}_{t-1} + (1 - \\mathbf{Z}_t) \\odot \\tilde{\\mathbf{H}}_t.$$\n",
"\n",
"\n",
"Whenever the update gate $\\mathbf{Z}_t$ is close to 1,\n",
"we simply retain the old state. \n",
"In this case the information from $\\mathbf{X}_t$ is ignored, \n",
"effectively skipping time step $t$ in the dependency chain. \n",
"By contrast, whenever $\\mathbf{Z}_t$ is close to 0,\n",
"the new latent state $\\mathbf{H}_t$ approaches the candidate latent state $\\tilde{\\mathbf{H}}_t$. \n",
":numref:`fig_gru_3` shows the computational flow after the update gate is in action.\n",
"\n",
"\n",
":label:`fig_gru_3`\n",
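"\n",
"Because the update equation is an elementwise convex combination, its extremes and its interpolating behavior are easy to verify.\n",
"In the sketch below, random tensors of arbitrary shape stand in for $\\mathbf{H}_{t-1}$ and $\\tilde{\\mathbf{H}}_t$. The helper name `update` is hypothetical.\n",
"\n",
"```python\n",
"import torch\n",
"\n",
"torch.manual_seed(0)\n",
"H_prev = torch.randn(2, 3)   # old state H_{t-1}\n",
"H_tilde = torch.randn(2, 3)  # stand-in for the candidate state\n",
"\n",
"def update(Z):\n",
"    # Elementwise convex combination controlled by the update gate Z_t\n",
"    return Z * H_prev + (1 - Z) * H_tilde\n",
"\n",
"print(torch.equal(update(torch.ones(2, 3)), H_prev))    # True: old state kept\n",
"print(torch.equal(update(torch.zeros(2, 3)), H_tilde))  # True: candidate adopted\n",
"\n",
"# Intermediate Z_t: each entry lies between the corresponding old and candidate entries\n",
"H = update(torch.full((2, 3), 0.3))\n",
"lo, hi = torch.minimum(H_prev, H_tilde), torch.maximum(H_prev, H_tilde)\n",
"print(bool(((H >= lo - 1e-5) & (H <= hi + 1e-5)).all()))  # True\n",
"```\n",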
"\n",
"\n",
"In summary, GRUs have the following two distinguishing features:\n",
"\n",
"* Reset gates help capture short-term dependencies in sequences.\n",
"* Update gates help capture long-term dependencies in sequences.\n",
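"\n",
"The introduction describes GRUs as faster to compute than LSTMs. One rough proxy for this is the parameter count: a GRU layer has three blocks of weights (two gates plus the candidate state), while an LSTM layer has four.\n",
"The sketch below counts the parameters of PyTorch's `nn.GRU` and `nn.LSTM`. These modules keep two bias vectors per block (`bias_ih` and `bias_hh`), so a layer has $3h(d + h + 2)$ or $4h(d + h + 2)$ parameters, respectively.\n",
"The sizes are arbitrary, and parameter count is only a proxy: actual runtime depends on the implementation and hardware.\n",
"\n",
"```python\n",
"import torch\n",
"from torch import nn\n",
"\n",
"d, h = 28, 32  # arbitrary input size and number of hidden units\n",
"\n",
"def num_params(module):\n",
"    # Total number of scalar parameters in a module\n",
"    return sum(p.numel() for p in module.parameters())\n",
"\n",
"gru, lstm = nn.GRU(d, h), nn.LSTM(d, h)\n",
"print(num_params(gru), num_params(lstm))   # 5952 7936\n",
"print(num_params(gru) / num_params(lstm))  # 0.75\n",
"```\n",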
"\n",
"## Implementation from Scratch\n",
"\n",
"To gain a better understanding of the GRU model, let's implement it from scratch.\n",
"\n",
"### (**Initializing Model Parameters**)\n",
"\n",
"The first step is to initialize the model parameters.\n",
"We draw the weights from a Gaussian distribution\n",
"with standard deviation `sigma` and set the biases to 0.\n",
"The hyperparameter `num_hiddens` defines the number of hidden units.\n",
"We instantiate all weights and biases relating to the update gate, \n",
"the reset gate, and the candidate hidden state.\n"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "1f2fcd5e",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T19:50:07.813979Z",
"iopub.status.busy": "2023-08-18T19:50:07.813174Z",
"iopub.status.idle": "2023-08-18T19:50:07.819841Z",
"shell.execute_reply": "2023-08-18T19:50:07.818739Z"
},
"origin_pos": 7,
"tab": [
"pytorch"
]
},
"outputs": [],
"source": [
"class GRUScratch(d2l.Module):\n",
" def __init__(self, num_inputs, num_hiddens, sigma=0.01):\n",
" super().__init__()\n",
" self.save_hyperparameters()\n",
"\n",
" init_weight = lambda *shape: nn.Parameter(torch.randn(*shape) * sigma)\n",
" triple = lambda: (init_weight(num_inputs, num_hiddens),\n",
" init_weight(num_hiddens, num_hiddens),\n",
" nn.Parameter(torch.zeros(num_hiddens)))\n",
" self.W_xz, self.W_hz, self.b_z = triple() # Update gate\n",
" self.W_xr, self.W_hr, self.b_r = triple() # Reset gate\n",
" self.W_xh, self.W_hh, self.b_h = triple() # Candidate hidden state"
]
},
{
"cell_type": "markdown",
"id": "8b2f43a9",
"metadata": {
"origin_pos": 9
},
"source": [
"### Defining the Model\n",
"\n",
"Now we are ready to [**define the GRU forward computation**].\n",
"Its structure is the same as that of the basic RNN cell, \n",
"except that the update equations are more complex.\n"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "78b86a43",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T19:50:07.824621Z",
"iopub.status.busy": "2023-08-18T19:50:07.823909Z",
"iopub.status.idle": "2023-08-18T19:50:07.830603Z",
"shell.execute_reply": "2023-08-18T19:50:07.829486Z"
},
"origin_pos": 10,
"tab": [
"pytorch"
]
},
"outputs": [],
"source": [
"@d2l.add_to_class(GRUScratch)\n",
"def forward(self, inputs, H=None):\n",
" if H is None:\n",
" # Initial state with shape: (batch_size, num_hiddens)\n",
" H = torch.zeros((inputs.shape[1], self.num_hiddens),\n",
" device=inputs.device)\n",
" outputs = []\n",
" for X in inputs:\n",
" Z = torch.sigmoid(torch.matmul(X, self.W_xz) +\n",
" torch.matmul(H, self.W_hz) + self.b_z)\n",
" R = torch.sigmoid(torch.matmul(X, self.W_xr) +\n",
" torch.matmul(H, self.W_hr) + self.b_r)\n",
" H_tilde = torch.tanh(torch.matmul(X, self.W_xh) +\n",
" torch.matmul(R * H, self.W_hh) + self.b_h)\n",
" H = Z * H + (1 - Z) * H_tilde\n",
" outputs.append(H)\n",
" return outputs, H"
]
},
{
"cell_type": "markdown",
"id": "25f1402e",
"metadata": {
"origin_pos": 12
},
"source": [
"### Training\n",
"\n",
"[**Training**] a language model on *The Time Machine* dataset\n",
"works in exactly the same manner as in :numref:`sec_rnn-scratch`.\n"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "ecd79fad",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T19:50:07.835201Z",
"iopub.status.busy": "2023-08-18T19:50:07.834646Z",
"iopub.status.idle": "2023-08-18T19:51:44.215275Z",
"shell.execute_reply": "2023-08-18T19:51:44.214117Z"
},
"origin_pos": 13,
"tab": [
"pytorch"
]
},
"outputs": [],
"source": [
"data = d2l.TimeMachine(batch_size=1024, num_steps=32)\n",
"gru = GRUScratch(num_inputs=len(data.vocab), num_hiddens=32)\n",
"model = d2l.RNNLMScratch(gru, vocab_size=len(data.vocab), lr=4)\n",
"trainer = d2l.Trainer(max_epochs=50, gradient_clip_val=1, num_gpus=1)\n",
"trainer.fit(model, data)"
]
},
{
"cell_type": "markdown",
"id": "04b6222c",
"metadata": {
"origin_pos": 14
},
"source": [
"## [**Concise Implementation**]\n",
"\n",
"With high-level APIs, we can directly instantiate a GRU model.\n",
"This encapsulates all the configuration details that we made explicit above.\n"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "4b6caa68",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T19:51:44.228345Z",
"iopub.status.busy": "2023-08-18T19:51:44.227717Z",
"iopub.status.idle": "2023-08-18T19:51:44.233105Z",
"shell.execute_reply": "2023-08-18T19:51:44.232084Z"
},
"origin_pos": 15,
"tab": [
"pytorch"
]
},
"outputs": [],
"source": [
"class GRU(d2l.RNN):\n",
" def __init__(self, num_inputs, num_hiddens):\n",
" d2l.Module.__init__(self)\n",
" self.save_hyperparameters()\n",
" self.rnn = nn.GRU(num_inputs, num_hiddens)"
]
},
{
"cell_type": "markdown",
"id": "8cb2801e",
"metadata": {
"origin_pos": 17
},
"source": [
"The code is significantly faster to train because it relies on compiled operators\n",
"rather than a Python loop over time steps.\n"
]
},
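{
"cell_type": "markdown",
"id": "gru-nn-gru-formulas",
"metadata": {},
"source": [
"The PyTorch documentation for `nn.GRU` gives a candidate-state formula that differs slightly from :eqref:`gru_tilde_H`.\n",
"There, the reset gate multiplies the result of the hidden-to-hidden matrix product (plus its bias), rather than multiplying $\\mathbf{H}_{t-1}$ before the product. PyTorch also uses separate input and hidden biases.\n",
"The sketch below reproduces one step of `nn.GRU` by hand, based on those documented formulas. The sizes and random inputs are arbitrary.\n",
"It suggests why the concise model need not match `GRUScratch` parameter for parameter, even though both rely on the same gating idea.\n",
"\n",
"```python\n",
"import torch\n",
"from torch import nn\n",
"\n",
"torch.manual_seed(0)\n",
"d, h, n = 4, 3, 2  # input size, hidden size, batch size (arbitrary)\n",
"gru = nn.GRU(d, h)          # one layer; inputs have shape (num_steps, batch, d)\n",
"X = torch.randn(1, n, d)    # a single time step\n",
"H0 = torch.randn(1, n, h)   # initial state, shape (num_layers, batch, h)\n",
"\n",
"# nn.GRU stacks per-gate weights in the order (reset, update, new/candidate)\n",
"W_ir, W_iz, W_in = gru.weight_ih_l0.chunk(3, dim=0)\n",
"W_hr, W_hz, W_hn = gru.weight_hh_l0.chunk(3, dim=0)\n",
"b_ir, b_iz, b_in = gru.bias_ih_l0.chunk(3)\n",
"b_hr, b_hz, b_hn = gru.bias_hh_l0.chunk(3)\n",
"\n",
"x, hp = X[0], H0[0]\n",
"r = torch.sigmoid(x @ W_ir.T + b_ir + hp @ W_hr.T + b_hr)  # reset gate\n",
"z = torch.sigmoid(x @ W_iz.T + b_iz + hp @ W_hz.T + b_hz)  # update gate\n",
"# The reset gate scales (hp @ W_hn.T + b_hn): it is applied after the\n",
"# matrix product rather than to H_{t-1} before it\n",
"n_t = torch.tanh(x @ W_in.T + b_in + r * (hp @ W_hn.T + b_hn))\n",
"h_manual = (1 - z) * n_t + z * hp  # same convex combination as in this section\n",
"\n",
"_, h_torch = gru(X, H0)\n",
"print(torch.allclose(h_manual, h_torch[0], atol=1e-5))\n",
"```\n"
]
},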
{
"cell_type": "code",
"execution_count": 6,
"id": "66c56966",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T19:51:44.237345Z",
"iopub.status.busy": "2023-08-18T19:51:44.237065Z",
"iopub.status.idle": "2023-08-18T19:52:51.996558Z",
"shell.execute_reply": "2023-08-18T19:52:51.995714Z"
},
"origin_pos": 18,
"tab": [
"pytorch"
]
},
"outputs": [],
"source": [
"gru = GRU(num_inputs=len(data.vocab), num_hiddens=32)\n",
"model = d2l.RNNLM(gru, vocab_size=len(data.vocab), lr=4)\n",
"trainer.fit(model, data)"
]
},
{
"cell_type": "markdown",
"id": "6da99e11",
"metadata": {
"origin_pos": 19
},
"source": [
"After training, we generate the sequence that the model predicts\n",
"following the provided prefix.\n"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "33f8aee3",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T19:52:52.004246Z",
"iopub.status.busy": "2023-08-18T19:52:52.003659Z",
"iopub.status.idle": "2023-08-18T19:52:52.029661Z",
"shell.execute_reply": "2023-08-18T19:52:52.028855Z"
},
"origin_pos": 20,
"tab": [
"pytorch"
]
},
"outputs": [
{
"data": {
"text/plain": [
"'it has so it and the time '"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model.predict('it has', 20, data.vocab, d2l.try_gpu())"
]
},
{
"cell_type": "markdown",
"id": "b28f4509",
"metadata": {
"origin_pos": 23
},
"source": [
"## Summary\n",
"\n",
"Compared with LSTMs, GRUs achieve similar performance but tend to be lighter computationally.\n",
"Generally, compared with simple RNNs, gated RNNs such as LSTMs and GRUs\n",
"can better capture dependencies between time steps that are far apart.\n",
"GRUs contain basic RNNs as an extreme case, namely when the reset gate is fully switched on and the update gate is fully switched off.\n",
"They can also skip subsequences by turning on the update gate.\n",
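"\n",
"The sketch below gives a rough illustration of skipping. It runs the from-scratch GRU recurrence of this section over 20 time steps twice.\n",
"The first run uses a large positive update-gate bias, which pushes $\\mathbf{Z}_t$ toward 1. The second uses a large negative one, which pushes it toward 0.\n",
"The weights, sizes, and input sequence are arbitrary, and the reset and candidate biases are omitted. The bias values of $\\pm 20$ are hypothetical, chosen only to saturate the gate.\n",
"With the gate near 1, the final state should barely move away from the initial state.\n",
"\n",
"```python\n",
"import torch\n",
"\n",
"torch.manual_seed(0)\n",
"d, h, T = 4, 3, 20  # input size, hidden size, number of time steps (arbitrary)\n",
"# Random weights shared by both runs; reset and candidate biases omitted for brevity\n",
"W = {k: torch.randn(d, h) * 0.5 for k in ('xz', 'xr', 'xh')}\n",
"U = {k: torch.randn(h, h) * 0.5 for k in ('hz', 'hr', 'hh')}\n",
"inputs = torch.randn(T, 1, d)  # a sequence of T inputs with batch size 1\n",
"H0 = torch.randn(1, h)         # initial hidden state\n",
"\n",
"def run(b_z):\n",
"    # GRU recurrence from this section; b_z is a hand-set update-gate bias\n",
"    H = H0\n",
"    for X in inputs:\n",
"        Z = torch.sigmoid(X @ W['xz'] + H @ U['hz'] + b_z)\n",
"        R = torch.sigmoid(X @ W['xr'] + H @ U['hr'])\n",
"        H_tilde = torch.tanh(X @ W['xh'] + (R * H) @ U['hh'])\n",
"        H = Z * H + (1 - Z) * H_tilde\n",
"    return H\n",
"\n",
"drift_skip = (run(20.0) - H0).abs().max().item()   # Z near 1: state carried through\n",
"drift_open = (run(-20.0) - H0).abs().max().item()  # Z near 0: state overwritten\n",
"print(drift_skip, drift_open)\n",
"```\n",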
"\n",
"\n",
"## Exercises\n",
"\n",
"1. Assume that we only want to use the input at time step $t'$ to predict the output at time step $t > t'$. What are the best values for the reset and update gates for each time step?\n",
"1. Adjust the hyperparameters and analyze their influence on running time, perplexity, and the output sequence.\n",
"1. Compare runtime, perplexity, and the output strings for `nn.RNN` and `nn.GRU` implementations with each other.\n",
"1. What happens if you implement only parts of a GRU, e.g., with only a reset gate or only an update gate?\n"
]
},
{
"cell_type": "markdown",
"id": "17cec523",
"metadata": {
"origin_pos": 25,
"tab": [
"pytorch"
]
},
"source": [
"[Discussions](https://discuss.d2l.ai/t/1056)\n"
]
}
],
"metadata": {
"language_info": {
"name": "python"
},
"required_libs": []
},
"nbformat": 4,
"nbformat_minor": 5
}