{
"cells": [
{
"cell_type": "markdown",
"id": "b5336169",
"metadata": {
"origin_pos": 0
},
"source": [
"# Concise Implementation of Recurrent Neural Networks\n",
":label:`sec_rnn-concise`\n",
"\n",
"Like most of our from-scratch implementations,\n",
":numref:`sec_rnn-scratch` was designed \n",
"to provide insight into how each component works.\n",
"But when you are using RNNs every day \n",
"or writing production code,\n",
"you will want to rely more on libraries\n",
"that cut down on both implementation time \n",
"(by supplying library code for common models and functions)\n",
"and computation time \n",
"(by optimizing the heck out of these library implementations).\n",
"This section will show you how to implement \n",
"the same language model more efficiently\n",
"using the high-level API provided \n",
"by your deep learning framework.\n",
"We begin, as before, by loading \n",
"*The Time Machine* dataset.\n"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "37720cdc",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T20:06:50.935241Z",
"iopub.status.busy": "2023-08-18T20:06:50.934437Z",
"iopub.status.idle": "2023-08-18T20:06:54.183049Z",
"shell.execute_reply": "2023-08-18T20:06:54.181824Z"
},
"origin_pos": 3,
"tab": [
"pytorch"
]
},
"outputs": [],
"source": [
"import torch\n",
"from torch import nn\n",
"from torch.nn import functional as F\n",
"from d2l import torch as d2l"
]
},
{
"cell_type": "markdown",
"id": "17a09762",
"metadata": {
"origin_pos": 6
},
"source": [
"## [**Defining the Model**]\n",
"\n",
"We define the following class\n",
"using the RNN implemented\n",
"by high-level APIs.\n"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "694fd386",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T20:06:54.188820Z",
"iopub.status.busy": "2023-08-18T20:06:54.188200Z",
"iopub.status.idle": "2023-08-18T20:06:54.196537Z",
"shell.execute_reply": "2023-08-18T20:06:54.195266Z"
},
"origin_pos": 10,
"tab": [
"pytorch"
]
},
"outputs": [],
"source": [
"class RNN(d2l.Module): #@save\n",
" \"\"\"The RNN model implemented with high-level APIs.\"\"\"\n",
" def __init__(self, num_inputs, num_hiddens):\n",
" super().__init__()\n",
" self.save_hyperparameters()\n",
" self.rnn = nn.RNN(num_inputs, num_hiddens)\n",
"\n",
" def forward(self, inputs, H=None):\n",
" return self.rnn(inputs, H)"
]
},
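{
"cell_type": "markdown",
"id": "sketch-rnn-shapes",
"metadata": {},
"source": [
"To make the input and output shapes concrete,\n",
"here is a minimal sketch that calls `nn.RNN` directly on one-hot encoded tokens.\n",
"It assumes only PyTorch, and the sizes are toy values chosen for illustration.\n",
"With the default `batch_first=False`, the layer expects inputs of shape\n",
"(number of time steps, batch size, number of inputs).\n",
"It returns the hidden states at every time step,\n",
"together with the final hidden state.\n",
"For a single-layer RNN, the leading axis of the final hidden state has size one.\n",
"\n",
"```python\n",
"import torch\n",
"from torch import nn\n",
"from torch.nn import functional as F\n",
"\n",
"# Toy sizes chosen for illustration only\n",
"num_steps, batch_size, vocab_size, num_hiddens = 5, 2, 28, 32\n",
"rnn = nn.RNN(vocab_size, num_hiddens)  # batch_first=False by default\n",
"\n",
"# Random token indices, one-hot encoded to (num_steps, batch_size, vocab_size)\n",
"tokens = torch.randint(0, vocab_size, (num_steps, batch_size))\n",
"X = F.one_hot(tokens, vocab_size).float()\n",
"\n",
"# outputs: hidden state at every step; H: hidden state after the last step\n",
"outputs, H = rnn(X)\n",
"print(outputs.shape)  # torch.Size([5, 2, 32])\n",
"print(H.shape)  # torch.Size([1, 2, 32]): (num_layers, batch_size, num_hiddens)\n",
"```\n",
"\n",
"For a single-layer, unidirectional RNN, `H[0]` equals `outputs[-1]`.\n",
"When no initial state is passed, as with `H=None` in the class above,\n",
"PyTorch starts from a zero hidden state.\n"
]
},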
{
"cell_type": "markdown",
"id": "bdc68a58",
"metadata": {
"origin_pos": 13
},
"source": [
"Inheriting from the `RNNLMScratch` class in :numref:`sec_rnn-scratch`, \n",
"the following `RNNLM` class defines a complete RNN-based language model.\n",
"Note that we need to create a separate fully connected output layer.\n"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "3a92b933",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T20:06:54.201256Z",
"iopub.status.busy": "2023-08-18T20:06:54.200440Z",
"iopub.status.idle": "2023-08-18T20:06:54.207741Z",
"shell.execute_reply": "2023-08-18T20:06:54.206137Z"
},
"origin_pos": 14,
"tab": [
"pytorch"
]
},
"outputs": [],
"source": [
"class RNNLM(d2l.RNNLMScratch): #@save\n",
" \"\"\"The RNN-based language model implemented with high-level APIs.\"\"\"\n",
" def init_params(self):\n",
" self.linear = nn.LazyLinear(self.vocab_size)\n",
"\n",
" def output_layer(self, hiddens):\n",
" return self.linear(hiddens).swapaxes(0, 1)"
]
},
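{
"cell_type": "markdown",
"id": "sketch-output-layer",
"metadata": {},
"source": [
"The `swapaxes(0, 1)` call moves the batch axis to the front.\n",
"The minimal sketch below traces the shapes with toy sizes.\n",
"It assumes only PyTorch and uses an eager `nn.Linear` as a stand-in for the lazily initialized layer.\n",
"The RNN yields hidden states of shape (number of time steps, batch size, number of hidden units).\n",
"The linear layer maps the last axis to vocabulary-sized scores.\n",
"Swapping the first two axes then gives (batch size, number of time steps, vocabulary size).\n",
"This layout matches token labels stored as (batch size, number of time steps),\n",
"so flattening both for the cross-entropy loss keeps each prediction paired with its target.\n",
"The random `labels` are hypothetical placeholders.\n",
"\n",
"```python\n",
"import torch\n",
"from torch import nn\n",
"from torch.nn import functional as F\n",
"\n",
"# Toy sizes chosen for illustration only\n",
"num_steps, batch_size, vocab_size, num_hiddens = 5, 2, 28, 32\n",
"rnn = nn.RNN(vocab_size, num_hiddens)\n",
"linear = nn.Linear(num_hiddens, vocab_size)  # eager stand-in for nn.LazyLinear\n",
"\n",
"tokens = torch.randint(0, vocab_size, (batch_size, num_steps))  # (batch, steps)\n",
"X = F.one_hot(tokens.T, vocab_size).float()  # time-major: (steps, batch, vocab)\n",
"hiddens, _ = rnn(X)                          # (steps, batch, hiddens)\n",
"logits = linear(hiddens).swapaxes(0, 1)      # (batch, steps, vocab)\n",
"\n",
"# Hypothetical next-token labels with the same (batch, steps) layout as tokens\n",
"labels = torch.randint(0, vocab_size, (batch_size, num_steps))\n",
"loss = F.cross_entropy(logits.reshape(-1, vocab_size), labels.reshape(-1))\n",
"print(logits.shape)  # torch.Size([2, 5, 28])\n",
"```\n"
]
},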
{
"cell_type": "markdown",
"id": "79498980",
"metadata": {
"origin_pos": 17
},
"source": [
"## Training and Predicting\n",
"\n",
"Before training the model, let's [**make a prediction \n",
"with a model initialized with random weights.**]\n",
"Given that we have not trained the network, \n",
"it will generate nonsensical predictions.\n"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "b4134fb6",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T20:06:54.212417Z",
"iopub.status.busy": "2023-08-18T20:06:54.211559Z",
"iopub.status.idle": "2023-08-18T20:06:55.980476Z",
"shell.execute_reply": "2023-08-18T20:06:55.979215Z"
},
"origin_pos": 18,
"tab": [
"pytorch"
]
},
"outputs": [
{
"data": {
"text/plain": [
"'it hasoadd dd dd dd dd dd '"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"data = d2l.TimeMachine(batch_size=1024, num_steps=32)\n",
"rnn = RNN(num_inputs=len(data.vocab), num_hiddens=32)\n",
"model = RNNLM(rnn, vocab_size=len(data.vocab), lr=1)\n",
"model.predict('it has', 20, data.vocab)"
]
},
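{
"cell_type": "markdown",
"id": "sketch-greedy-predict",
"metadata": {},
"source": [
"Generation from a prefix can be sketched in two phases.\n",
"First, the prefix is fed in to warm up the hidden state.\n",
"Then the highest-scoring token is repeatedly appended and fed back in.\n",
"The sketch below is a simplified stand-in written for illustration, not the code behind `model.predict`.\n",
"It assumes greedy decoding and a hypothetical 27-character vocabulary.\n",
"It uses a freshly initialized `nn.RNN` and `nn.Linear`, so its continuation is meaningless.\n",
"\n",
"```python\n",
"import torch\n",
"from torch import nn\n",
"from torch.nn import functional as F\n",
"\n",
"# Hypothetical character vocabulary for illustration: space plus a-z\n",
"idx_to_token = list(' abcdefghijklmnopqrstuvwxyz')\n",
"token_to_idx = {t: i for i, t in enumerate(idx_to_token)}\n",
"vocab_size, num_hiddens = len(idx_to_token), 32\n",
"\n",
"torch.manual_seed(0)\n",
"rnn = nn.RNN(vocab_size, num_hiddens)\n",
"linear = nn.Linear(num_hiddens, vocab_size)\n",
"\n",
"def greedy_predict(prefix, num_preds):\n",
"    # Warm up on the prefix, then repeatedly append the argmax token\n",
"    state, outputs = None, [token_to_idx[prefix[0]]]\n",
"    with torch.no_grad():\n",
"        for i in range(len(prefix) + num_preds - 1):\n",
"            # One time step, batch of one: shape (1, 1, vocab_size)\n",
"            X = F.one_hot(torch.tensor([[outputs[-1]]]), vocab_size).float()\n",
"            hiddens, state = rnn(X, state)\n",
"            if i < len(prefix) - 1:  # warm-up: feed the known prefix\n",
"                outputs.append(token_to_idx[prefix[i + 1]])\n",
"            else:  # prediction: pick the highest-scoring token\n",
"                outputs.append(int(linear(hiddens).argmax(dim=-1)))\n",
"    return ''.join(idx_to_token[i] for i in outputs)\n",
"\n",
"text = greedy_predict('it has', 20)\n",
"print(text)  # untrained weights, so the continuation is arbitrary\n",
"```\n",
"\n",
"The result always has `len(prefix) + num_preds` characters,\n",
"matching the 26-character strings printed in this section.\n"
]
},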
{
"cell_type": "markdown",
"id": "70d7bcae",
"metadata": {
"origin_pos": 19
},
"source": [
"Next, we [**train our model, leveraging the high-level API**].\n"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "c17fc0f4",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T20:06:55.986250Z",
"iopub.status.busy": "2023-08-18T20:06:55.985457Z",
"iopub.status.idle": "2023-08-18T20:10:44.012328Z",
"shell.execute_reply": "2023-08-18T20:10:44.010882Z"
},
"origin_pos": 20,
"tab": [
"pytorch"
]
},
"outputs": [
{
"data": {
"image/svg+xml": [
"\n",
"\n",
"\n"
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"trainer = d2l.Trainer(max_epochs=100, gradient_clip_val=1, num_gpus=1)\n",
"trainer.fit(model, data)"
]
},
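{
"cell_type": "markdown",
"id": "sketch-grad-clip",
"metadata": {},
"source": [
"The `gradient_clip_val=1` argument enables gradient clipping.\n",
"When the norm of all gradients taken together exceeds 1,\n",
"they are rescaled so that their norm is (at most) 1 before the update.\n",
"This guards against the exploding gradients that backpropagation through time can produce.\n",
"Below is a minimal sketch of one clipped training step in plain PyTorch.\n",
"It uses `torch.nn.utils.clip_grad_norm_` as a stand-in for the trainer's own clipping code.\n",
"It runs on toy sizes with random, hypothetical data.\n",
"\n",
"```python\n",
"import torch\n",
"from torch import nn\n",
"from torch.nn import functional as F\n",
"\n",
"torch.manual_seed(0)\n",
"# Toy sizes and random (hypothetical) data, for illustration only\n",
"num_steps, batch_size, vocab_size, num_hiddens = 5, 4, 28, 32\n",
"rnn = nn.RNN(vocab_size, num_hiddens)\n",
"linear = nn.Linear(num_hiddens, vocab_size)\n",
"params = list(rnn.parameters()) + list(linear.parameters())\n",
"optimizer = torch.optim.SGD(params, lr=1)\n",
"\n",
"tokens = torch.randint(0, vocab_size, (batch_size, num_steps))\n",
"labels = torch.randint(0, vocab_size, (batch_size, num_steps))\n",
"\n",
"# Forward pass: time-major one-hot inputs, batch-major logits\n",
"X = F.one_hot(tokens.T, vocab_size).float()\n",
"hiddens, _ = rnn(X)\n",
"logits = linear(hiddens).swapaxes(0, 1)\n",
"loss = F.cross_entropy(logits.reshape(-1, vocab_size), labels.reshape(-1))\n",
"\n",
"optimizer.zero_grad()\n",
"loss.backward()\n",
"# Jointly rescale all gradients so that their global L2 norm is at most 1;\n",
"# the return value is the norm measured before clipping\n",
"norm_before = nn.utils.clip_grad_norm_(params, max_norm=1.0)\n",
"norm_after = torch.sqrt(sum((p.grad ** 2).sum() for p in params))\n",
"optimizer.step()\n",
"print(f'grad norm before: {float(norm_before):.3f}, after: {float(norm_after):.3f}')\n",
"```\n",
"\n",
"Clipping the global norm rescales every gradient by the same factor,\n",
"so it preserves the overall update direction and only limits the step size.\n"
]
},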
{
"cell_type": "markdown",
"id": "7cd397c3",
"metadata": {
"origin_pos": 21
},
"source": [
"Compared with :numref:`sec_rnn-scratch`,\n",
"this model achieves comparable perplexity,\n",
"but runs faster due to the optimized implementations.\n",
"As before, we can generate predicted tokens \n",
"following the specified prefix string.\n"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "01c712d0",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T20:10:44.040685Z",
"iopub.status.busy": "2023-08-18T20:10:44.039824Z",
"iopub.status.idle": "2023-08-18T20:10:44.084788Z",
"shell.execute_reply": "2023-08-18T20:10:44.075036Z"
},
"origin_pos": 22,
"tab": [
"pytorch"
]
},
"outputs": [
{
"data": {
"text/plain": [
"'it has and the trave the t'"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model.predict('it has', 20, data.vocab, d2l.try_gpu())"
]
},
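{
"cell_type": "markdown",
"id": "sketch-rnn-equivalence",
"metadata": {},
"source": [
"The high-level layer computes the same recurrence as the from-scratch model of :numref:`sec_rnn-scratch`.\n",
"The only difference is that PyTorch keeps two bias vectors, whose sum plays the role of the single bias.\n",
"The minimal sketch below checks this directly, assuming only PyTorch and toy sizes.\n",
"It unrolls `H_t = tanh(X_t W_ih^T + b_ih + H_{t-1} W_hh^T + b_hh)` over time\n",
"using the parameters stored inside `nn.RNN`.\n",
"For the first layer, PyTorch names these `weight_ih_l0`, `weight_hh_l0`, `bias_ih_l0`, and `bias_hh_l0`.\n",
"The sketch then compares the result with the layer's own output.\n",
"\n",
"```python\n",
"import torch\n",
"from torch import nn\n",
"\n",
"torch.manual_seed(0)\n",
"# Toy sizes chosen for illustration only\n",
"num_steps, batch_size, num_inputs, num_hiddens = 5, 2, 28, 32\n",
"rnn = nn.RNN(num_inputs, num_hiddens)  # tanh nonlinearity by default\n",
"X = torch.randn(num_steps, batch_size, num_inputs)\n",
"\n",
"with torch.no_grad():\n",
"    outputs, H_last = rnn(X)\n",
"\n",
"    # Unroll the same recurrence by hand, starting from a zero hidden state\n",
"    H = torch.zeros(batch_size, num_hiddens)\n",
"    manual = []\n",
"    for X_t in X:  # iterate over the time axis\n",
"        H = torch.tanh(X_t @ rnn.weight_ih_l0.T + rnn.bias_ih_l0\n",
"                       + H @ rnn.weight_hh_l0.T + rnn.bias_hh_l0)\n",
"        manual.append(H)\n",
"    manual = torch.stack(manual)  # (num_steps, batch_size, num_hiddens)\n",
"\n",
"print(torch.allclose(manual, outputs, atol=1e-5))  # True\n",
"```\n",
"\n",
"The two agree up to floating-point error.\n",
"The library version is faster largely because it runs this time loop in optimized native code\n",
"(for example, cuDNN kernels on a GPU) rather than in Python.\n"
]
},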
{
"cell_type": "markdown",
"id": "e34eebb6",
"metadata": {
"origin_pos": 24
},
"source": [
"## Summary\n",
"\n",
"High-level APIs in deep learning frameworks provide implementations of standard RNNs.\n",
"These libraries help you to avoid wasting time reimplementing standard models.\n",
"Moreover,\n",
"framework implementations are often highly optimized, \n",
" leading to significant (computational) performance gains \n",
" when compared with implementations from scratch.\n",
"\n",
"## Exercises\n",
"\n",
"1. Can you make the RNN model overfit using the high-level APIs?\n",
"1. Implement the autoregressive model of :numref:`sec_sequence` using an RNN.\n"
]
},
{
"cell_type": "markdown",
"id": "16726511",
"metadata": {
"origin_pos": 26,
"tab": [
"pytorch"
]
},
"source": [
"[Discussions](https://discuss.d2l.ai/t/1053)\n"
]
}
],
"metadata": {
"language_info": {
"name": "python"
},
"required_libs": []
},
"nbformat": 4,
"nbformat_minor": 5
}