{ "cells": [ { "cell_type": "markdown", "id": "9f09fc0f", "metadata": { "origin_pos": 1 }, "source": [ "# Object-Oriented Design for Implementation\n", ":label:`sec_oo-design`\n", "\n", "In our introduction to linear regression,\n", "we walked through various components\n", "including\n", "the data, the model, the loss function,\n", "and the optimization algorithm.\n", "Indeed,\n", "linear regression is\n", "one of the simplest machine learning models.\n", "Training it,\n", "however, uses many of the same components that other models in this book require.\n", "Therefore, \n", "before diving into the implementation details\n", "it is worth \n", "designing some of the APIs\n", "that we use throughout. \n", "Treating components in deep learning\n", "as objects,\n", "we can start by\n", "defining classes for these objects\n", "and their interactions.\n", "This object-oriented design\n", "for implementation\n", "will greatly\n", "streamline the presentation and you might even want to use it in your projects.\n", "\n", "\n", "Inspired by open-source libraries such as [PyTorch Lightning](https://www.pytorchlightning.ai/),\n", "at a high level\n", "we wish to have three classes: \n", "(i) `Module` contains models, losses, and optimization methods; \n", "(ii) `DataModule` provides data loaders for training and validation; \n", "(iii) both classes are combined using the `Trainer` class, which allows us to\n", "train models on a variety of hardware platforms. \n", "Most code in this book adapts `Module` and `DataModule`. 
We will touch upon the `Trainer` class only when we discuss GPUs, CPUs, parallel training, and optimization algorithms.\n" ] }, { "cell_type": "code", "execution_count": 1, "id": "4f04a51a", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T19:26:19.443072Z", "iopub.status.busy": "2023-08-18T19:26:19.442526Z", "iopub.status.idle": "2023-08-18T19:26:22.383415Z", "shell.execute_reply": "2023-08-18T19:26:22.382436Z" }, "origin_pos": 3, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "import time\n", "import numpy as np\n", "import torch\n", "from torch import nn\n", "from d2l import torch as d2l" ] }, { "cell_type": "markdown", "id": "c805c409", "metadata": { "origin_pos": 6 }, "source": [ "## Utilities\n", ":label:`oo-design-utilities`\n", "\n", "We need a few utilities to simplify object-oriented programming in Jupyter notebooks. One of the challenges is that class definitions tend to be fairly long blocks of code. Notebook readability demands short code fragments, interspersed with explanations, a requirement incompatible with the style of programming common for Python libraries. The first\n", "utility function allows us to register functions as methods in a class *after* the class has been created. In fact, we can do so *even after* we have created instances of the class! 
It allows us to split the implementation of a class into multiple code blocks.\n" ] }, { "cell_type": "code", "execution_count": 2, "id": "5403292c", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T19:26:22.389019Z", "iopub.status.busy": "2023-08-18T19:26:22.388325Z", "iopub.status.idle": "2023-08-18T19:26:22.393121Z", "shell.execute_reply": "2023-08-18T19:26:22.392315Z" }, "origin_pos": 7, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "def add_to_class(Class):  #@save\n", "    \"\"\"Register functions as methods in created class.\"\"\"\n", "    def wrapper(obj):\n", "        setattr(Class, obj.__name__, obj)\n", "    return wrapper" ] }, { "cell_type": "markdown", "id": "7d8c1f96", "metadata": { "origin_pos": 8 }, "source": [ "Let's have a quick look at how to use it. We plan to implement a class `A` with a method `do`. Instead of having code for both `A` and `do` in the same code block, we can first declare the class `A` and create an instance `a`.\n" ] }, { "cell_type": "code", "execution_count": 3, "id": "e04e01bb", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T19:26:22.396585Z", "iopub.status.busy": "2023-08-18T19:26:22.395941Z", "iopub.status.idle": "2023-08-18T19:26:22.400066Z", "shell.execute_reply": "2023-08-18T19:26:22.399246Z" }, "origin_pos": 9, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "class A:\n", "    def __init__(self):\n", "        self.b = 1\n", "\n", "a = A()" ] }, { "cell_type": "markdown", "id": "d2db5aae", "metadata": { "origin_pos": 10 }, "source": [ "Next we define the method `do` as we normally would, but not in class `A`'s scope. Instead, we decorate this method with `add_to_class`, passing class `A` as its argument. In doing so, the method is able to access the member variables of `A` just as we would expect had it been included as part of `A`'s definition. 
Let's see what happens when we invoke it for the instance `a`.\n" ] }, { "cell_type": "code", "execution_count": 4, "id": "d2133566", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T19:26:22.403442Z", "iopub.status.busy": "2023-08-18T19:26:22.402774Z", "iopub.status.idle": "2023-08-18T19:26:22.407704Z", "shell.execute_reply": "2023-08-18T19:26:22.406859Z" }, "origin_pos": 11, "tab": [ "pytorch" ] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Class attribute \"b\" is 1\n" ] } ], "source": [ "@add_to_class(A)\n", "def do(self):\n", "    print('Class attribute \"b\" is', self.b)\n", "\n", "a.do()" ] }, { "cell_type": "markdown", "id": "0b09869a", "metadata": { "origin_pos": 12 }, "source": [ "The second utility is a class that saves all arguments of a class's `__init__` method as attributes of the instance. This allows us to extend constructor call signatures implicitly without additional code.\n" ] }, { "cell_type": "code", "execution_count": 5, "id": "b2316a5a", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T19:26:22.410976Z", "iopub.status.busy": "2023-08-18T19:26:22.410447Z", "iopub.status.idle": "2023-08-18T19:26:22.414594Z", "shell.execute_reply": "2023-08-18T19:26:22.413768Z" }, "origin_pos": 13, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "class HyperParameters:  #@save\n", "    \"\"\"The base class of hyperparameters.\"\"\"\n", "    def save_hyperparameters(self, ignore=[]):\n", "        raise NotImplementedError" ] }, { "cell_type": "markdown", "id": "f4c0374c", "metadata": { "origin_pos": 14 }, "source": [ "We defer its implementation to :numref:`sec_utils`. 
To use it, we define a class that inherits from `HyperParameters` and call `save_hyperparameters` in its `__init__` method.\n" ] }, { "cell_type": "code", "execution_count": 6, "id": "e3f26d79", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T19:26:22.417917Z", "iopub.status.busy": "2023-08-18T19:26:22.417395Z", "iopub.status.idle": "2023-08-18T19:26:22.423055Z", "shell.execute_reply": "2023-08-18T19:26:22.422251Z" }, "origin_pos": 15, "tab": [ "pytorch" ] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "self.a = 1 self.b = 2\n", "There is no self.c = True\n" ] } ], "source": [ "# Call the fully implemented HyperParameters class saved in d2l\n", "class B(d2l.HyperParameters):\n", "    def __init__(self, a, b, c):\n", "        self.save_hyperparameters(ignore=['c'])\n", "        print('self.a =', self.a, 'self.b =', self.b)\n", "        print('There is no self.c =', not hasattr(self, 'c'))\n", "\n", "b = B(a=1, b=2, c=3)" ] }, { "cell_type": "markdown", "id": "f66f84be", "metadata": { "origin_pos": 16 }, "source": [ "The final utility allows us to plot experiment progress interactively while it is going on. In deference to the much more powerful (and complex) [TensorBoard](https://www.tensorflow.org/tensorboard) we name it `ProgressBoard`. The implementation is deferred to :numref:`sec_utils`. For now, let's simply see it in action.\n", "\n", "The `draw` method plots a point `(x, y)` in the figure, with `label` specified in the legend. The optional `every_n` smooths the line by showing only $1/n$ of the points in the figure. 
Each displayed value is the average of $n$ neighboring points in the original data.\n" ] }, { "cell_type": "code", "execution_count": 7, "id": "c7b69b94", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T19:26:22.426329Z", "iopub.status.busy": "2023-08-18T19:26:22.425774Z", "iopub.status.idle": "2023-08-18T19:26:22.431284Z", "shell.execute_reply": "2023-08-18T19:26:22.430467Z" }, "origin_pos": 17, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "class ProgressBoard(d2l.HyperParameters):  #@save\n", "    \"\"\"The board that plots data points in animation.\"\"\"\n", "    def __init__(self, xlabel=None, ylabel=None, xlim=None,\n", "                 ylim=None, xscale='linear', yscale='linear',\n", "                 ls=['-', '--', '-.', ':'], colors=['C0', 'C1', 'C2', 'C3'],\n", "                 fig=None, axes=None, figsize=(3.5, 2.5), display=True):\n", "        self.save_hyperparameters()\n", "\n", "    def draw(self, x, y, label, every_n=1):\n", "        raise NotImplementedError" ] }, { "cell_type": "markdown", "id": "1b87a713", "metadata": { "origin_pos": 18 }, "source": [ "In the following example, we draw `sin` and `cos` with different degrees of smoothing. 
If you run this code block, you will see the lines grow in animation.\n" ] }, { "cell_type": "code", "execution_count": 8, "id": "acda0f92", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T19:26:22.434595Z", "iopub.status.busy": "2023-08-18T19:26:22.434058Z", "iopub.status.idle": "2023-08-18T19:26:38.822758Z", "shell.execute_reply": "2023-08-18T19:26:38.821895Z" }, "origin_pos": 19, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "board = d2l.ProgressBoard('x')\n", "for x in np.arange(0, 10, 0.1):\n", "    board.draw(x, np.sin(x), 'sin', every_n=2)\n", "    board.draw(x, np.cos(x), 'cos', every_n=10)" ] }, { "cell_type": "markdown", "id": "27746ca1", "metadata": { "origin_pos": 20 }, "source": [ "## Models\n", ":label:`subsec_oo-design-models`\n", "\n", "The `Module` class is the base class of all models we will implement. At the very least we need three methods. The first, `__init__`, stores the learnable parameters; the second, `training_step`, accepts a data batch and returns the loss value; the third, `configure_optimizers`, returns the optimization method, or a list of them, used to update the learnable parameters. Optionally, we can define `validation_step` to report evaluation measures.\n", "Sometimes we put the code for computing the output into a separate `forward` method to make it more reusable.\n" ] }, { "cell_type": "code", "execution_count": 9, "id": "62a305b5", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T19:26:38.826860Z", "iopub.status.busy": "2023-08-18T19:26:38.826284Z", "iopub.status.idle": "2023-08-18T19:26:38.835993Z", "shell.execute_reply": "2023-08-18T19:26:38.835101Z" }, "origin_pos": 22, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "class Module(nn.Module, d2l.HyperParameters):  #@save\n", "    \"\"\"The base class of models.\"\"\"\n", "    def __init__(self, plot_train_per_epoch=2, plot_valid_per_epoch=1):\n", "        super().__init__()\n", "        self.save_hyperparameters()\n", "        self.board = ProgressBoard()\n", "\n", "    def loss(self, y_hat, y):\n", "        raise NotImplementedError\n", "\n", "    def forward(self, X):\n", "        assert hasattr(self, 'net'), 'Neural network is not defined'\n", "        return self.net(X)\n", "\n", "    def plot(self, key, value, train):\n", "        \"\"\"Plot a point in animation.\"\"\"\n", "        assert hasattr(self, 'trainer'), 'Trainer is not initialized'\n", "        self.board.xlabel = 'epoch'\n",
"        if train:\n", "            x = self.trainer.train_batch_idx / \\\n", "                self.trainer.num_train_batches\n", "            n = self.trainer.num_train_batches / \\\n", "                self.plot_train_per_epoch\n", "        else:\n", "            x = self.trainer.epoch + 1\n", "            n = self.trainer.num_val_batches / \\\n", "                self.plot_valid_per_epoch\n", "        self.board.draw(x, value.to(d2l.cpu()).detach().numpy(),\n", "                        ('train_' if train else 'val_') + key,\n", "                        every_n=int(n))\n", "\n", "    def training_step(self, batch):\n", "        l = self.loss(self(*batch[:-1]), batch[-1])\n", "        self.plot('loss', l, train=True)\n", "        return l\n", "\n", "    def validation_step(self, batch):\n", "        l = self.loss(self(*batch[:-1]), batch[-1])\n", "        self.plot('loss', l, train=False)\n", "\n", "    def configure_optimizers(self):\n", "        raise NotImplementedError" ] }, { "cell_type": "markdown", "id": "68b2d5aa", "metadata": { "origin_pos": 25, "tab": [ "pytorch" ] }, "source": [ "You may notice that `Module` is a subclass of `nn.Module`, the base class of neural networks in PyTorch.\n", "It provides convenient features for handling neural networks. For example, if we define a `forward` method, such as `forward(self, X)`, then for an instance `a` we can invoke this method by `a(X)`. This works because the built-in `__call__` method calls `forward`. You can find more details and examples about `nn.Module` in :numref:`sec_model_construction`.\n" ] }, { "cell_type": "markdown", "id": "b263de76", "metadata": { "origin_pos": 28 }, "source": [ "## Data\n", ":label:`oo-design-data`\n", "\n", "The `DataModule` class is the base class for data. Quite frequently the `__init__` method is used to prepare the data. This includes downloading and preprocessing if needed. The `train_dataloader` method returns the data loader for the training dataset. A data loader is a (Python) iterable that yields one data batch per iteration. This batch is then fed into the `training_step` method of `Module` to compute the loss. 
An optional `val_dataloader` method returns the data loader for the validation dataset. It behaves in the same manner, except that it yields data batches for the `validation_step` method in `Module`.\n" ] }, { "cell_type": "code", "execution_count": 10, "id": "14e2b695", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T19:26:38.839261Z", "iopub.status.busy": "2023-08-18T19:26:38.838718Z", "iopub.status.idle": "2023-08-18T19:26:38.843828Z", "shell.execute_reply": "2023-08-18T19:26:38.843014Z" }, "origin_pos": 29, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "class DataModule(d2l.HyperParameters):  #@save\n", "    \"\"\"The base class of data.\"\"\"\n", "    def __init__(self, root='../data', num_workers=4):\n", "        self.save_hyperparameters()\n", "\n", "    def get_dataloader(self, train):\n", "        raise NotImplementedError\n", "\n", "    def train_dataloader(self):\n", "        return self.get_dataloader(train=True)\n", "\n", "    def val_dataloader(self):\n", "        return self.get_dataloader(train=False)" ] }, { "cell_type": "markdown", "id": "80b834f8", "metadata": { "origin_pos": 30 }, "source": [ "## Training\n", ":label:`oo-design-training`\n" ] }, { "cell_type": "markdown", "id": "19097c4d", "metadata": { "origin_pos": 31, "tab": [ "pytorch" ] }, "source": [ "The `Trainer` class trains the learnable parameters in the `Module` class with data specified in `DataModule`. The key method is `fit`, which accepts two arguments: `model`, an instance of `Module`, and `data`, an instance of `DataModule`. It then iterates over the entire dataset `max_epochs` times to train the model. 
As before, we defer the implementation of the per-epoch training loop, `fit_epoch`, to later chapters.\n" ] }, { "cell_type": "code", "execution_count": 11, "id": "798ecd6d", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T19:26:38.847255Z", "iopub.status.busy": "2023-08-18T19:26:38.846633Z", "iopub.status.idle": "2023-08-18T19:26:38.853801Z", "shell.execute_reply": "2023-08-18T19:26:38.852995Z" }, "origin_pos": 33, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "class Trainer(d2l.HyperParameters):  #@save\n", "    \"\"\"The base class for training models with data.\"\"\"\n", "    def __init__(self, max_epochs, num_gpus=0, gradient_clip_val=0):\n", "        self.save_hyperparameters()\n", "        assert num_gpus == 0, 'No GPU support yet'\n", "\n", "    def prepare_data(self, data):\n", "        self.train_dataloader = data.train_dataloader()\n", "        self.val_dataloader = data.val_dataloader()\n", "        self.num_train_batches = len(self.train_dataloader)\n", "        self.num_val_batches = (len(self.val_dataloader)\n", "                                if self.val_dataloader is not None else 0)\n", "\n", "    def prepare_model(self, model):\n", "        model.trainer = self\n", "        model.board.xlim = [0, self.max_epochs]\n", "        self.model = model\n", "\n", "    def fit(self, model, data):\n", "        self.prepare_data(data)\n", "        self.prepare_model(model)\n", "        self.optim = model.configure_optimizers()\n", "        self.epoch = 0\n", "        self.train_batch_idx = 0\n", "        self.val_batch_idx = 0\n", "        for self.epoch in range(self.max_epochs):\n", "            self.fit_epoch()\n", "\n", "    def fit_epoch(self):\n", "        raise NotImplementedError" ] }, { "cell_type": "markdown", "id": "792b845a", "metadata": { "origin_pos": 34 }, "source": [ "## Summary\n", "\n", "To highlight the object-oriented design\n", "for our future deep learning implementation,\n", "the above classes simply show how their objects \n", "store data and interact with each other.\n", "We will keep enriching implementations of these classes,\n", "such as via `@add_to_class`,\n", "in the rest of the book.\n", "Moreover,\n",
"these fully implemented classes\n", "are saved in the [D2L library](https://github.com/d2l-ai/d2l-en/tree/master/d2l),\n", "a *lightweight toolkit* that makes structured modeling for deep learning easy. \n", "In particular, it facilitates reusing many components between projects without changing much at all. For instance, we can replace just the optimizer, just the model, just the dataset, etc.;\n", "this degree of modularity pays dividends throughout the book in terms of conciseness and simplicity (this is why we added it) and it can do the same for your own projects. \n", "\n", "\n", "## Exercises\n", "\n", "1. Locate full implementations of the above classes that are saved in the [D2L library](https://github.com/d2l-ai/d2l-en/tree/master/d2l). We strongly recommend that you look at the implementation in detail once you have gained some more familiarity with deep learning modeling.\n", "1. Remove the `save_hyperparameters` statement in the `B` class. Can you still print `self.a` and `self.b`? Optional: if you have dived into the full implementation of the `HyperParameters` class, can you explain why?\n" ] }, { "cell_type": "markdown", "id": "8f75e484", "metadata": { "origin_pos": 36, "tab": [ "pytorch" ] }, "source": [ "[Discussions](https://discuss.d2l.ai/t/6646)\n" ] } ], "metadata": { "language_info": { "name": "python" }, "required_libs": [] }, "nbformat": 4, "nbformat_minor": 5 }