# %% [markdown]
# # Utility Functions and Classes
# :label:`sec_utils`
#
# This section contains the implementations of utility functions and classes
# used in this book.

# %%
import collections
import inspect
from IPython import display
from torch import nn
from d2l import torch as d2l

# %% [markdown]
# Hyperparameters.

# %%
@d2l.add_to_class(d2l.HyperParameters)  #@save
def save_hyperparameters(self, ignore=()):
    """Save function arguments into class attributes.

    Reads the local variables of the *calling* frame, stores the selected
    ones in ``self.hparams`` and also sets each of them as an attribute of
    ``self``.

    Args:
        ignore: Iterable of variable names to leave out. ``'self'`` and every
            name starting with an underscore are always left out.
    """
    # f_back is the frame of whoever called this method (typically __init__).
    frame = inspect.currentframe().f_back
    _, _, _, local_vars = inspect.getargvalues(frame)
    # Build the exclusion set once instead of once per local variable.
    excluded = set(ignore) | {'self'}
    self.hparams = {k: v for k, v in local_vars.items()
                    if k not in excluded and not k.startswith('_')}
    for k, v in self.hparams.items():
        setattr(self, k, v)

# %% [markdown]
# Progress bar.
# %%
@d2l.add_to_class(d2l.ProgressBoard)  #@save
def draw(self, x, y, label, every_n=1):
    """Record one point for ``label`` and redraw the board when due.

    Points are buffered per label. Once exactly ``every_n`` points are
    buffered, their mean x and mean y are appended to that label's line, the
    buffer is emptied and, if ``self.display`` is truthy, every line is
    plotted again.

    Args:
        x: x-coordinate of the new point.
        y: y-coordinate of the new point.
        label: Key selecting the line the point belongs to.
        every_n: Number of buffered points averaged into one plotted point.
    """
    Point = collections.namedtuple('Point', ['x', 'y'])
    if not hasattr(self, 'raw_points'):
        # First call on this board: create the per-label stores.
        self.raw_points = collections.OrderedDict()
        self.data = collections.OrderedDict()
    if label not in self.raw_points:
        self.raw_points[label] = []
        self.data[label] = []
    pending = self.raw_points[label]
    smoothed = self.data[label]
    pending.append(Point(x, y))
    if len(pending) != every_n:
        return

    def average(values):
        return sum(values) / len(values)

    mean_x = average([pt.x for pt in pending])
    mean_y = average([pt.y for pt in pending])
    smoothed.append(Point(mean_x, mean_y))
    pending.clear()
    if not self.display:
        return

    d2l.use_svg_display()
    if self.fig is None:
        self.fig = d2l.plt.figure(figsize=self.figsize)
    handles, names = [], []
    # zip() stops at the shortest input, so at most len(self.ls) /
    # len(self.colors) lines are drawn.
    for (name, series), style, color in zip(
            self.data.items(), self.ls, self.colors):
        curve = d2l.plt.plot([pt.x for pt in series],
                             [pt.y for pt in series],
                             linestyle=style, color=color)
        handles.append(curve[0])
        names.append(name)
    axes = self.axes if self.axes else d2l.plt.gca()
    if self.xlim:
        axes.set_xlim(self.xlim)
    if self.ylim:
        axes.set_ylim(self.ylim)
    if not self.xlabel:
        # NOTE(review): falls back to `self.x`; nothing in this notebook
        # sets such an attribute - confirm it exists on ProgressBoard.
        self.xlabel = self.x
    axes.set_xlabel(self.xlabel)
    axes.set_ylabel(self.ylabel)
    axes.set_xscale(self.xscale)
    axes.set_yscale(self.yscale)
    axes.legend(handles, names)
    display.display(self.fig)
    display.clear_output(wait=True)

# %% [markdown]
# Add FrozenLake environment
# %%
def frozen_lake(seed):  #@save
    """Create the non-slippery FrozenLake-v1 environment and describe its MDP.

    Args:
        seed: Seed applied to the environment and to its action space.

    Returns:
        A dict with the keys ``desc``, ``num_states``, ``num_actions``,
        ``trans_prob_idx``, ``nextstate_idx``, ``reward_idx``, ``done_idx``,
        ``mdp`` (maps ``(state, action)`` to the transition list taken from
        ``env.P``) and ``env`` (the environment itself).
    """
    # See https://www.gymlibrary.dev/environments/toy_text/frozen_lake/ to
    # learn more about this env.
    # How to process env.P.items is adapted from
    # https://sites.google.com/view/deep-rl-bootcamp/labs
    import gym

    env = gym.make('FrozenLake-v1', is_slippery=False)
    env.seed(seed)
    env.action_space.np_random.seed(seed)
    env.action_space.seed(seed)

    env_info = {
        'desc': env.desc,  # 2D array specifying what each grid item means
        'num_states': env.nS,  # Number of observations/states
        'num_actions': env.nA,  # Number of actions
        # Indices into a (transition probability, nextstate, reward, done)
        # tuple
        'trans_prob_idx': 0,
        'nextstate_idx': 1,
        'reward_idx': 2,
        'done_idx': 3,
        'mdp': {},
        'env': env,
    }

    # env.P[s] = {a0: [(p(s'|s,a0), s', reward, done), ...], a1: [...], ...}
    for s, others in env.P.items():
        # pxrds is [(p1, next1, r1, d1), (p2, next2, r2, d2), ...], e.g.
        # [(0.3, 0, 0, False), (0.3, 0, 0, False), (0.3, 4, 1, False)]
        for a, pxrds in others.items():
            env_info['mdp'][(s, a)] = pxrds

    return env_info

# %% [markdown]
# Create environment

# %%
def make_env(name='', seed=0):  #@save
    """Create the environment description for a supported gym environment.

    Args:
        name: Name of the gym environment. Only ``'FrozenLake-v1'`` is
            supported (it is the one used for value iteration).
        seed: Seed forwarded to ``frozen_lake``.

    Returns:
        The dict returned by ``frozen_lake``.

    Raises:
        ValueError: If ``name`` is not a supported environment.
    """
    if name == 'FrozenLake-v1':
        return frozen_lake(seed)
    # The name is now actually interpolated into the message.
    raise ValueError("%s env is not supported in this Notebook" % name)

# %% [markdown]
# Show value function

# %%
def _show_value_policy_grids(env_desc, values, policies, step_numbers):
    """Draw one value heat-map with policy arrows for each snapshot.

    Args:
        env_desc: 2D array of single-byte cell labels (``.decode()`` is called
            on every entry); its shape defines the grid.
        values: Sequence of per-snapshot value vectors, each reshaped to the
            grid shape.
        policies: Sequence of per-snapshot action vectors, same layout.
        step_numbers: Step number shown in the title of each snapshot.
    """
    n_rows, n_cols = env_desc.shape
    # LEFT action: 0, DOWN action: 1
    # RIGHT action: 2, UP action: 3
    # imshow draws row 0 at the top, so DOWN is +y and UP is -y.
    action2dxdy = {0: (-.25, 0), 1: (0, .25),
                   2: (.25, 0), 3: (0, -.25)}
    # (text colour, font size) for the special cells; others use ('g', 15).
    label_style = {'H': ('y', 20), 'G': ('w', 20)}

    panels_per_row = 4
    # Keep the original 4x4 layout and only add rows beyond 16 snapshots.
    panel_rows = max(4, -(-len(step_numbers) // panels_per_row))
    fig = d2l.plt.figure(figsize=(15, 15))

    for k, step in enumerate(step_numbers):
        ax = fig.add_subplot(panel_rows, panels_per_row, k + 1)
        ax.imshow(values[k].reshape(n_rows, n_cols), cmap="bone")
        # Minor ticks sit on the cell borders and carry the white grid.
        ax.set_xticks([i - .5 for i in range(n_cols + 1)], minor=True)
        ax.set_yticks([i - .5 for i in range(n_rows + 1)], minor=True)
        ax.grid(which="minor", color="w", linestyle='-', linewidth=3)
        ax.tick_params(which="minor", bottom=False, left=False)
        ax.set_xticks([])
        ax.set_yticks([])

        policy = policies[k].reshape(n_rows, n_cols)
        for y in range(n_rows):
            for x in range(n_cols):
                cell = env_desc[y, x].decode()
                color, size = label_style.get(cell, ('g', 15))
                ax.text(x, y, cell, ha="center", va="center", color=color,
                        size=size, fontweight='bold')
                # No arrow for cells with G and H labels
                if cell not in ('G', 'H'):
                    # int() makes the lookup work for float or tensor entries.
                    dx, dy = action2dxdy[int(policy[y, x])]
                    ax.arrow(x, y, dx, dy, color='r', head_width=0.2,
                             head_length=0.15)

        ax.set_title("Step = " + str(step), fontsize=20)

    fig.tight_layout()
    d2l.plt.show()


def show_value_function_progress(env_desc, V, pi):  #@save
    """Visualize how the value function and the policy change over time.

    How to visualize the value function is adapted (but changed) from
    https://sites.google.com/view/deep-rl-bootcamp/labs

    Args:
        env_desc: 2D array of byte labels describing the grid.
        V: Values, shape ``[num_iters, num_states]``.
        pi: Actions, shape ``[num_iters, num_states]``.
    """
    num_iters = V.shape[0]
    _show_value_policy_grids(env_desc, V, pi,
                             list(range(1, num_iters + 1)))

# %% [markdown]
# Show Q function

# %%
def show_Q_function_progress(env_desc, V_all, pi_all):  #@save
    """Visualize value and policy at roughly ten evenly spaced iterations.

    Every ``num_iters // 10``-th iteration is shown, plus the final one.

    Args:
        env_desc: 2D array of byte labels describing the grid.
        V_all: Values, shape ``[num_iters, num_states]``.
        pi_all: Actions, shape ``[num_iters, num_states]``.
    """
    num_iters_all = V_all.shape[0]
    # Fewer than 10 iterations used to give a zero stride and crash.
    stride = max(num_iters_all // 10, 1)

    vis_indx = list(range(0, num_iters_all, stride))
    # Always show the last iteration, but never twice.
    if vis_indx and vis_indx[-1] != num_iters_all - 1:
        vis_indx.append(num_iters_all - 1)

    V = [V_all[i] for i in vis_indx]
    pi = [pi_all[i] for i in vis_indx]
    _show_value_policy_grids(env_desc, V, pi, [i + 1 for i in vis_indx])

# %% [markdown]
# Trainer
#
# A bunch of functions that will be deprecated:

# %%
def load_array(data_arrays, batch_size, is_train=True):  #@save
    """Construct a PyTorch data iterator.

    Args:
        data_arrays: Tensors unpacked into a ``TensorDataset``.
        batch_size: Minibatch size.
        is_train: Whether the loader shuffles the data.
    """
    dataset = torch.utils.data.TensorDataset(*data_arrays)
    return torch.utils.data.DataLoader(dataset, batch_size, shuffle=is_train)

def synthetic_data(w, b, num_examples):  #@save
    """Generate y = Xw + b + noise.

    Returns:
        ``(X, y)`` where ``X`` is ``(num_examples, len(w))`` standard-normal
        features and ``y`` is a ``(num_examples, 1)`` column with Gaussian
        noise of standard deviation 0.01 added.
    """
    X = torch.normal(0, 1, (num_examples, len(w)))
    y = torch.matmul(X, w) + b
    y += torch.normal(0, 0.01, y.shape)
    return X, y.reshape((-1, 1))

def sgd(params, lr, batch_size):  #@save
    """Minibatch stochastic gradient descent.

    Updates every parameter in place by ``lr * grad / batch_size`` and then
    zeroes its gradient.
    """
    with torch.no_grad():
        for param in params:
            param -= lr * param.grad / batch_size
            param.grad.zero_()

def get_dataloader_workers():  #@save
    """Use 4 processes to read the data."""
    return 4

def load_data_fashion_mnist(batch_size, resize=None):  #@save
    """Download the Fashion-MNIST dataset and then load it into memory.

    Args:
        batch_size: Minibatch size of both loaders.
        resize: Optional size passed to ``transforms.Resize``, applied before
            the conversion to a tensor.

    Returns:
        ``(train_loader, test_loader)``; only the training loader shuffles.
    """
    steps = [transforms.Resize(resize)] if resize else []
    steps.append(transforms.ToTensor())
    trans = transforms.Compose(steps)

    def make_loader(train):
        dataset = torchvision.datasets.FashionMNIST(
            root="../data", train=train, transform=trans, download=True)
        return torch.utils.data.DataLoader(
            dataset, batch_size, shuffle=train,
            num_workers=get_dataloader_workers())

    return make_loader(True), make_loader(False)

def evaluate_accuracy_gpu(net, data_iter, device=None):  #@save
    """Compute the accuracy for a model on a dataset using a GPU.

    Args:
        net: Model; an ``nn.Module`` is switched to evaluation mode.
        data_iter: Iterable of ``(X, y)`` minibatches.
        device: Target device. When falsy and ``net`` is an ``nn.Module``,
            the device of the first model parameter is used.

    Returns:
        Fraction of correctly predicted examples.
    """
    if isinstance(net, nn.Module):
        net.eval()  # Set the model to evaluation mode
        if not device:
            device = next(iter(net.parameters())).device
    # No. of correct predictions, no. of predictions
    metric = d2l.Accumulator(2)

    with torch.no_grad():
        for X, y in data_iter:
            if isinstance(X, list):
                # Required for BERT Fine-tuning (to be covered later)
                X = [x.to(device) for x in X]
            else:
                X = X.to(device)
            y = y.to(device)
            metric.add(d2l.accuracy(net(X), y), y.numel())
    return metric[0] / metric[1]


#@save
def train_ch6(net, train_iter, test_iter, num_epochs, lr, device):
    """Train a model with a GPU (defined in Chapter 6).

    Args:
        net: Model to train; Linear and Conv2d weights are re-initialised
            with Xavier-uniform values.
        train_iter: Training data loader (must support ``len``).
        test_iter: Test data loader, evaluated after every epoch.
        num_epochs: Number of passes over ``train_iter``.
        lr: Learning rate of the SGD optimizer.
        device: Device the model and the minibatches are moved to.
    """
    def init_weights(m):
        # Exact type comparison: subclasses are left untouched.
        if type(m) in (nn.Linear, nn.Conv2d):
            nn.init.xavier_uniform_(m.weight)
    net.apply(init_weights)
    print('training on', device)
    net.to(device)
    optimizer = torch.optim.SGD(net.parameters(), lr=lr)
    loss = nn.CrossEntropyLoss()
    animator = d2l.Animator(xlabel='epoch', xlim=[1, num_epochs],
                            legend=['train loss', 'train acc', 'test acc'])
    timer, num_batches = d2l.Timer(), len(train_iter)
    # Plot about five times per epoch; at least 1 so that loaders with fewer
    # than five batches do not cause a modulo by zero.
    plot_every = max(num_batches // 5, 1)
    for epoch in range(num_epochs):
        # Sum of training loss, sum of training accuracy, no. of examples
        metric = d2l.Accumulator(3)
        net.train()
        for i, (X, y) in enumerate(train_iter):
            timer.start()
            optimizer.zero_grad()
            X, y = X.to(device), y.to(device)
            y_hat = net(X)
            l = loss(y_hat, y)
            l.backward()
            optimizer.step()
            with torch.no_grad():
                metric.add(l * X.shape[0], d2l.accuracy(y_hat, y), X.shape[0])
            timer.stop()
            train_l = metric[0] / metric[2]
            train_acc = metric[1] / metric[2]
            if (i + 1) % plot_every == 0 or i == num_batches - 1:
                animator.add(epoch + (i + 1) / num_batches,
                             (train_l, train_acc, None))
        test_acc = evaluate_accuracy_gpu(net, test_iter)
        animator.add(epoch + 1, (None, None, test_acc))
    print(f'loss {train_l:.3f}, train acc {train_acc:.3f}, '
          f'test acc {test_acc:.3f}')
    print(f'{metric[2] * num_epochs / timer.sum():.1f} examples/sec '
          f'on {str(device)}')

# %%
def show_images(imgs, num_rows, num_cols, titles=None, scale=1.5):  #@save
    """Plot a list of images.

    Args:
        imgs: Images to show; objects exposing ``detach`` (tensors) are
            converted to NumPy arrays first.
        num_rows: Number of rows of the image grid.
        num_cols: Number of columns of the image grid.
        titles: Optional per-image titles, indexed like ``imgs``.
        scale: Size of one grid cell in inches.

    Returns:
        The flattened array of axes.
    """
    figsize = (num_cols * scale, num_rows * scale)
    # squeeze=False always yields a 2D array, so a 1x1 grid works as well.
    _, axes = d2l.plt.subplots(num_rows, num_cols, figsize=figsize,
                               squeeze=False)
    axes = axes.flatten()
    for i, (ax, img) in enumerate(zip(axes, imgs)):
        if hasattr(img, 'detach'):
            # Tensor: drop the autograd graph and move to host memory.
            img = img.detach().cpu().numpy()
        ax.imshow(img)
        ax.axes.get_xaxis().set_visible(False)
        ax.axes.get_yaxis().set_visible(False)
        if titles:
            ax.set_title(titles[i])
    return axes
# %%
def linreg(X, w, b):  #@save
    """The linear regression model: ``X @ w + b``."""
    return torch.matmul(X, w) + b

def squared_loss(y_hat, y):  #@save
    """Squared loss, element-wise, with ``y`` reshaped to match ``y_hat``."""
    return (y_hat - y.reshape(y_hat.shape)) ** 2 / 2

def get_fashion_mnist_labels(labels):  #@save
    """Return text labels for the Fashion-MNIST dataset.

    Args:
        labels: Iterable of numeric class indices (converted with ``int``).
    """
    text_labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat',
                   'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']
    return [text_labels[int(i)] for i in labels]

class Animator:  #@save
    """For plotting data in animation."""
    def __init__(self, xlabel=None, ylabel=None, legend=None, xlim=None,
                 ylim=None, xscale='linear', yscale='linear',
                 fmts=('-', 'm--', 'g-.', 'r:'), nrows=1, ncols=1,
                 figsize=(3.5, 2.5)):
        """Create the figure and remember how to configure its first axes."""
        # Incrementally plot multiple lines
        if legend is None:
            legend = []
        d2l.use_svg_display()
        self.fig, self.axes = d2l.plt.subplots(nrows, ncols, figsize=figsize)
        if nrows * ncols == 1:
            # Wrap the single axes so that `self.axes[0]` always works.
            self.axes = [self.axes, ]
        # Use a lambda function to capture arguments
        self.config_axes = lambda: d2l.set_axes(
            self.axes[0], xlabel, ylabel, xlim, ylim, xscale, yscale, legend)
        self.X, self.Y, self.fmts = None, None, fmts

    def add(self, x, y):
        """Add one data point per line and redraw the figure.

        Args:
            x: A scalar shared by all lines, or one x value per line.
            y: A scalar or one y value per line. Pairs in which either value
                is ``None`` are skipped.
        """
        # Add multiple data points into the figure
        if not hasattr(y, "__len__"):
            y = [y]
        num_lines = len(y)
        if not hasattr(x, "__len__"):
            x = [x] * num_lines
        # The per-line histories are created lazily on the first call.
        if not self.X:
            self.X = [[] for _ in range(num_lines)]
        if not self.Y:
            self.Y = [[] for _ in range(num_lines)]
        for i, (x_val, y_val) in enumerate(zip(x, y)):
            if x_val is not None and y_val is not None:
                self.X[i].append(x_val)
                self.Y[i].append(y_val)
        self.axes[0].cla()
        for xs, ys, fmt in zip(self.X, self.Y, self.fmts):
            self.axes[0].plot(xs, ys, fmt)
        self.config_axes()
        display.display(self.fig)
        display.clear_output(wait=True)

class Accumulator:  #@save
    """For accumulating sums over `n` variables."""
    def __init__(self, n):
        self.data = [0.0] * n

    def add(self, *args):
        """Add ``args`` position-wise to the running sums (as floats)."""
        self.data = [a + float(b) for a, b in zip(self.data, args)]

    def reset(self):
        """Set every running sum back to zero."""
        self.data = [0.0] * len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]


def accuracy(y_hat, y):  #@save
    """Compute the number of correct predictions.

    Args:
        y_hat: Predicted classes, or per-class scores with more than one
            column (then the arg-max over axis 1 is taken).
        y: Ground-truth classes.

    Returns:
        The count of matching entries as a float.
    """
    if len(y_hat.shape) > 1 and y_hat.shape[1] > 1:
        y_hat = y_hat.argmax(axis=1)
    cmp = y_hat.type(y.dtype) == y
    return float(cmp.type(y.dtype).sum())

# %%
import hashlib
import os
import tarfile
import zipfile
try:
    import requests
except ImportError:
    # Only `download` needs requests; the other helpers stay importable.
    requests = None


def download(url, folder='../data', sha1_hash=None):  #@save
    """Download a file to folder and return the local filepath.

    Args:
        url: An ``http(s)`` URL, or (for backward compatibility) a key of
            ``d2l.DATA_HUB`` that maps to ``(url, sha1_hash)``.
        folder: Directory the file is stored in; created when missing.
        sha1_hash: Expected SHA-1 hex digest. When given and an existing
            local file matches it, that file is reused without downloading.

    Returns:
        Path of the local file.

    Raises:
        ImportError: If a download is needed but ``requests`` is missing.
        requests.HTTPError: If the server answers with an error status.
    """
    if not url.startswith('http'):
        # For backward compatibility
        url, sha1_hash = d2l.DATA_HUB[url]
    os.makedirs(folder, exist_ok=True)
    fname = os.path.join(folder, url.split('/')[-1])
    # Check if hit cache
    if os.path.exists(fname) and sha1_hash:
        sha1 = hashlib.sha1()
        with open(fname, 'rb') as f:
            # Hash in 1 MiB chunks to keep memory use flat.
            for chunk in iter(lambda: f.read(1048576), b''):
                sha1.update(chunk)
        if sha1.hexdigest() == sha1_hash:
            return fname
    # Download
    if requests is None:
        raise ImportError("download() requires the 'requests' package")
    print(f'Downloading {fname} from {url}...')
    partial = fname + '.part'
    with requests.get(url, stream=True, verify=True) as r:
        # Never save an HTTP error page as if it were the dataset.
        r.raise_for_status()
        with open(partial, 'wb') as f:
            for chunk in r.iter_content(chunk_size=1048576):
                f.write(chunk)
    # Publish the file only once it is complete.
    os.replace(partial, fname)
    return fname

def extract(filename, folder=None):  #@save
    """Extract a zip/tar file into folder.

    Args:
        filename: Path of a ``.zip``, ``.tar`` or ``.gz`` archive.
        folder: Destination directory; defaults to the archive's directory.
    """
    base_dir = os.path.dirname(filename)
    _, ext = os.path.splitext(filename)
    assert ext in ('.zip', '.tar', '.gz'), 'Only support zip/tar files.'
    if folder is None:
        folder = base_dir
    opener = zipfile.ZipFile if ext == '.zip' else tarfile.open
    # NOTE(review): extractall() trusts the member paths in the archive; only
    # use it on archives from a trusted source.
    with opener(filename, 'r') as fp:
        fp.extractall(folder)

# %%
def download_extract(name, folder=None):  #@save
    """Download and extract a zip/tar file.

    Args:
        name: Passed to ``download`` (a URL or a data-hub key).
        folder: Optional sub-directory name joined to the download directory
            to form the returned path.

    Returns:
        ``os.path.join(base_dir, folder)`` when ``folder`` is given, else the
        downloaded file's path without its last extension.
    """
    fname = download(name)
    base_dir = os.path.dirname(fname)
    data_dir, ext = os.path.splitext(fname)
    if ext == '.zip':
        opener = zipfile.ZipFile
    elif ext in ('.tar', '.gz'):
        opener = tarfile.open
    else:
        # Raised explicitly so that the check also survives `python -O`.
        raise AssertionError('Only zip/tar files can be extracted.')
    with opener(fname, 'r') as fp:
        fp.extractall(base_dir)
    return os.path.join(base_dir, folder) if folder else data_dir


def tokenize(lines, token='word'):  #@save
    """Split text lines into word or character tokens."""
    assert token in ('word', 'char'), 'Unknown token type: ' + token
    return [line.split() if token == 'word' else list(line) for line in lines]
# %%
def evaluate_loss(net, data_iter, loss):  #@save
    """Evaluate the loss of a model on the given dataset.

    Args:
        net: Callable model.
        data_iter: Iterable of ``(X, y)`` minibatches.
        loss: Callable returning a loss tensor for ``(output, target)``.

    Returns:
        Sum of all loss entries divided by the number of loss entries.
    """
    metric = d2l.Accumulator(2)  # Sum of losses, no. of examples
    for X, y in data_iter:
        out = net(X)
        y = y.reshape(out.shape)
        l = loss(out, y)
        metric.add(l.sum(), l.numel())
    return metric[0] / metric[1]

# %%
def grad_clipping(net, theta):  #@save
    """Clip the gradient.

    Rescales all gradients in place so that their joint L2 norm does not
    exceed ``theta``. Parameters without a gradient are skipped.

    Args:
        net: An ``nn.Module`` or an object exposing ``params``.
        theta: Maximum allowed gradient norm.
    """
    if isinstance(net, nn.Module):
        params = [p for p in net.parameters() if p.requires_grad]
    else:
        params = net.params
    # Parameters that took no part in the backward pass have no gradient.
    grads = [p.grad for p in params if p.grad is not None]
    if not grads:
        return
    norm = torch.sqrt(sum(torch.sum(g ** 2) for g in grads))
    if norm > theta:
        for g in grads:
            g[:] *= theta / norm

# %% [markdown]
# More for the attention chapter.

# %%
#@save
d2l.DATA_HUB['fra-eng'] = (d2l.DATA_URL + 'fra-eng.zip',
                           '94646ad1522d915e7b0f9296181140edcf86a4f5')

#@save
def read_data_nmt():
    """Load the English-French dataset.

    Returns:
        The whole content of ``fra.txt`` as one string.
    """
    data_dir = d2l.download_extract('fra-eng')
    with open(os.path.join(data_dir, 'fra.txt'), 'r', encoding='utf-8') as f:
        return f.read()

#@save
def preprocess_nmt(text):
    """Preprocess the English-French dataset.

    Replaces non-breaking spaces with spaces, lowercases the text and inserts
    a space before ``,.!?`` when the preceding character is not a space.
    """
    punctuation = set(',.!?')

    def no_space(char, prev_char):
        return char in punctuation and prev_char != ' '

    # Replace non-breaking space with space, and convert uppercase letters to
    # lowercase ones
    text = text.replace('\u202f', ' ').replace('\xa0', ' ').lower()
    # Insert space between words and punctuation marks
    out = [' ' + char if i > 0 and no_space(char, text[i - 1]) else char
           for i, char in enumerate(text)]
    return ''.join(out)

#@save
def tokenize_nmt(text, num_examples=None):
    """Tokenize the English-French dataset.

    Args:
        text: Tab-separated ``source<TAB>target`` pairs, one per line.
        num_examples: Maximum number of lines to read; a falsy value (``None``
            or 0) reads all lines.

    Returns:
        ``(source, target)``: two lists of token lists. Lines that do not
        split into exactly two parts are skipped.
    """
    source, target = [], []
    for i, line in enumerate(text.split('\n')):
        # `>=` so that exactly `num_examples` lines are read, not one more.
        if num_examples and i >= num_examples:
            break
        parts = line.split('\t')
        if len(parts) == 2:
            source.append(parts[0].split(' '))
            target.append(parts[1].split(' '))
    return source, target


#@save
def truncate_pad(line, num_steps, padding_token):
    """Truncate or pad sequences to exactly ``num_steps`` items."""
    if len(line) > num_steps:
        return line[:num_steps]  # Truncate
    return line + [padding_token] * (num_steps - len(line))  # Pad


#@save
def build_array_nmt(lines, vocab, num_steps):
    """Transform text sequences of machine translation into minibatches.

    Args:
        lines: List of token lists.
        vocab: Vocabulary mapping tokens (and lists of tokens) to indices; it
            must contain ``'<eos>'`` and ``'<pad>'``.
        num_steps: Length every sequence is truncated or padded to.

    Returns:
        ``(array, valid_len)``: the index tensor and, per row, the number of
        entries that are not the padding index.
    """
    lines = [vocab[l] for l in lines]
    # Mark the end of every sequence before truncating/padding.
    lines = [l + [vocab['<eos>']] for l in lines]
    array = torch.tensor([truncate_pad(
        l, num_steps, vocab['<pad>']) for l in lines])
    valid_len = (array != vocab['<pad>']).type(torch.int32).sum(1)
    return array, valid_len


#@save
def load_data_nmt(batch_size, num_steps, num_examples=600):
    """Return the iterator and the vocabularies of the translation dataset.

    Returns:
        ``(data_iter, src_vocab, tgt_vocab)``; each minibatch consists of
        ``(src_array, src_valid_len, tgt_array, tgt_valid_len)``.
    """
    text = preprocess_nmt(read_data_nmt())
    source, target = tokenize_nmt(text, num_examples)
    src_vocab = d2l.Vocab(source, min_freq=2,
                          reserved_tokens=['<pad>', '<bos>', '<eos>'])
    tgt_vocab = d2l.Vocab(target, min_freq=2,
                          reserved_tokens=['<pad>', '<bos>', '<eos>'])
    src_array, src_valid_len = build_array_nmt(source, src_vocab, num_steps)
    tgt_array, tgt_valid_len = build_array_nmt(target, tgt_vocab, num_steps)
    data_arrays = (src_array, src_valid_len, tgt_array, tgt_valid_len)
    data_iter = d2l.load_array(data_arrays, batch_size)
    return data_iter, src_vocab, tgt_vocab

# %%
#@save
def sequence_mask(X, valid_len, value=0):
    """Mask irrelevant entries in sequences.

    Overwrites, in place, every position of ``X`` along axis 1 whose index is
    not below the row's ``valid_len`` with ``value`` and returns ``X``.
    """
    maxlen = X.size(1)
    mask = torch.arange((maxlen), dtype=torch.float32,
                        device=X.device)[None, :] < valid_len[:, None]
    X[~mask] = value
    return X


#@save
class MaskedSoftmaxCELoss(nn.CrossEntropyLoss):
    """The softmax cross-entropy loss with masks."""
    # `pred` shape: (`batch_size`, `num_steps`, `vocab_size`)
    # `label` shape: (`batch_size`, `num_steps`)
    # `valid_len` shape: (`batch_size`,)
    def forward(self, pred, label, valid_len):
        """Return the per-sequence loss averaged over the time steps.

        Positions at or beyond ``valid_len`` contribute zero loss.
        """
        weights = torch.ones_like(label)
        weights = sequence_mask(weights, valid_len)
        # Per-token losses are needed so that they can be weighted.
        self.reduction = 'none'
        unweighted_loss = super(MaskedSoftmaxCELoss, self).forward(
            pred.permute(0, 2, 1), label)
        weighted_loss = (unweighted_loss * weights).mean(dim=1)
        return weighted_loss

#@save
def train_seq2seq(net, data_iter, lr, num_epochs, tgt_vocab, device):
    """Train a model for sequence to sequence.

    Args:
        net: Encoder-decoder model called as ``net(X, dec_input, X_valid_len)``.
        data_iter: Iterable of ``(X, X_valid_len, Y, Y_valid_len)`` batches.
        lr: Learning rate of the Adam optimizer.
        num_epochs: Number of passes over ``data_iter``.
        tgt_vocab: Target vocabulary; must contain ``'<bos>'``.
        device: Device the model and the minibatches are moved to.
    """
    def xavier_init_weights(m):
        # Exact type comparisons: subclasses are left untouched.
        if type(m) == nn.Linear:
            nn.init.xavier_uniform_(m.weight)
        if type(m) == nn.GRU:
            for param in m._flat_weights_names:
                if "weight" in param:
                    nn.init.xavier_uniform_(m._parameters[param])
    net.apply(xavier_init_weights)
    net.to(device)
    optimizer = torch.optim.Adam(net.parameters(), lr=lr)
    loss = MaskedSoftmaxCELoss()
    net.train()
    animator = d2l.Animator(xlabel='epoch', ylabel='loss',
                            xlim=[10, num_epochs])
    for epoch in range(num_epochs):
        timer = d2l.Timer()
        metric = d2l.Accumulator(2)  # Sum of training loss, no. of tokens
        for batch in data_iter:
            optimizer.zero_grad()
            X, X_valid_len, Y, Y_valid_len = [x.to(device) for x in batch]
            bos = torch.tensor([tgt_vocab['<bos>']] * Y.shape[0],
                               device=device).reshape(-1, 1)
            dec_input = torch.cat([bos, Y[:, :-1]], 1)  # Teacher forcing
            Y_hat, _ = net(X, dec_input, X_valid_len)
            l = loss(Y_hat, Y, Y_valid_len)
            l.sum().backward()  # Make the loss scalar for `backward`
            d2l.grad_clipping(net, 1)
            num_tokens = Y_valid_len.sum()
            optimizer.step()
            with torch.no_grad():
                metric.add(l.sum(), num_tokens)
        if (epoch + 1) % 10 == 0:
            animator.add(epoch + 1, (metric[0] / metric[1],))
    print(f'loss {metric[0] / metric[1]:.3f}, {metric[1] / timer.stop():.1f} '
          f'tokens/sec on {str(device)}')


#@save
def predict_seq2seq(net, src_sentence, src_vocab, tgt_vocab, num_steps,
                    device, save_attention_weights=False):
    """Predict for sequence to sequence.

    Args:
        net: Model exposing ``encoder`` and ``decoder``.
        src_sentence: Source sentence; lowercased and split on single spaces.
        src_vocab: Source vocabulary; must contain ``'<eos>'`` and ``'<pad>'``.
        tgt_vocab: Target vocabulary; must contain ``'<bos>'`` and ``'<eos>'``.
        num_steps: Source length after padding and maximum output length.
        device: Device the input tensors are created on.
        save_attention_weights: Whether to collect
            ``net.decoder.attention_weights`` at every decoding step.

    Returns:
        ``(translation, attention_weight_seq)``: the predicted tokens joined
        by spaces and the collected attention weights (empty if not saved).
    """
    # Set `net` to eval mode for inference
    net.eval()
    src_tokens = src_vocab[src_sentence.lower().split(' ')] + [
        src_vocab['<eos>']]
    enc_valid_len = torch.tensor([len(src_tokens)], device=device)
    src_tokens = d2l.truncate_pad(src_tokens, num_steps, src_vocab['<pad>'])
    # Add the batch axis
    enc_X = torch.unsqueeze(
        torch.tensor(src_tokens, dtype=torch.long, device=device), dim=0)
    enc_outputs = net.encoder(enc_X, enc_valid_len)
    dec_state = net.decoder.init_state(enc_outputs, enc_valid_len)
    # Add the batch axis
    dec_X = torch.unsqueeze(torch.tensor(
        [tgt_vocab['<bos>']], dtype=torch.long, device=device), dim=0)
    output_seq, attention_weight_seq = [], []
    for _ in range(num_steps):
        Y, dec_state = net.decoder(dec_X, dec_state)
        # We use the token with the highest prediction likelihood as input
        # of the decoder at the next time step
        dec_X = Y.argmax(dim=2)
        pred = dec_X.squeeze(dim=0).type(torch.int32).item()
        # Save attention weights (to be covered later)
        if save_attention_weights:
            attention_weight_seq.append(net.decoder.attention_weights)
        # Once the end-of-sequence token is predicted, the generation of the
        # output sequence is complete
        if pred == tgt_vocab['<eos>']:
            break
        output_seq.append(pred)
    return ' '.join(tgt_vocab.to_tokens(output_seq)), attention_weight_seq