{ "cells": [ { "cell_type": "markdown", "id": "1fd3a016", "metadata": { "origin_pos": 1 }, "source": [ "# What Is Hyperparameter Optimization?\n", ":label:`sec_what_is_hpo`\n", "\n", "As we have seen in the previous chapters, deep neural networks come with a\n", "large number of parameters or weights that are learned during training. On\n", "top of these, every neural network has additional *hyperparameters* that need\n", "to be configured by the user. For example, to ensure that stochastic gradient\n", "descent converges to a local optimum of the training loss\n", "(see :numref:`chap_optimization`), we have to adjust the learning rate and batch\n", "size. To avoid overfitting on training datasets,\n", "we might have to set regularization parameters, such as weight decay\n", "(see :numref:`sec_weight_decay`) or dropout (see :numref:`sec_dropout`). We can\n", "define the capacity and inductive bias of the model by setting the number of\n", "layers and number of units or filters per layer (i.e., the effective number\n", "of weights).\n", "\n", "Unfortunately, we cannot simply adjust these hyperparameters by minimizing the\n", "training loss, because this would lead to overfitting on the training data. For\n", "example, setting regularization parameters, such as dropout or weight decay\n", "to zero leads to a small training loss, but might hurt the generalization\n", "performance.\n", "\n", "![Typical workflow in machine learning that consists of training the model multiple times with different hyperparameters.](../img/ml_workflow.svg)\n", ":label:`ml_workflow`\n", "\n", "Without a different form of automation, hyperparameters have to be set manually\n", "in a trial-and-error fashion, in what amounts to a time-consuming and difficult\n", "part of machine learning workflows. For example, consider training\n", "a ResNet (see :numref:`sec_resnet`) on CIFAR-10, which requires more than 2 hours\n", "on an Amazon Elastic Cloud Compute (EC2) `g4dn.xlarge` instance. Even just\n", "trying ten hyperparameter configurations in sequence, this would already take us\n", "roughly one day. To make matters worse, hyperparameters are usually not directly\n", "transferable across architectures and datasets\n", ":cite:`feurer-arxiv22,wistuba-ml18,bardenet-icml13a`, and need to be re-optimized\n", "for every new task. Also, for most hyperparameters, there are no rule-of-thumbs,\n", "and expert knowledge is required to find sensible values.\n", "\n", "*Hyperparameter optimization (HPO)* algorithms are designed to tackle this\n", "problem in a principled and automated fashion :cite:`feurer-automlbook18a`, by\n", "framing it as a global optimization problem. The default objective is the error\n", "on a hold-out validation dataset, but could in principle be any other business\n", "metric. It can be combined with or constrained by secondary objectives, such as\n", "training time, inference time, or model complexity. \n", "\n", "Recently, hyperparameter optimization has been extended to *neural architecture\n", "search (NAS)* :cite:`elsken-arxiv18a,wistuba-arxiv19`, where the goal is to find\n", "entirely new neural network architectures. Compared to classical HPO, NAS is even\n", "more expensive in terms of computation and requires additional efforts to remain\n", "feasible in practice. 
Both HPO and NAS can be considered sub-fields of\n", "AutoML :cite:`hutter-book19a`, which aims to automate the entire ML pipeline.\n", "\n", "In this section we will introduce HPO and show how we can automatically find\n", "the best hyperparameters of the logistic regression example introduced in\n", ":numref:`sec_softmax_concise`.\n", "\n", "## The Optimization Problem\n", ":label:`sec_definition_hpo`\n", "\n", "We will start with a simple toy problem: searching for the learning rate of the\n", "multi-class logistic regression model `SoftmaxRegression` from\n", ":numref:`sec_softmax_concise` to minimize the validation error on the Fashion\n", "MNIST dataset. While other hyperparameters like batch size or number of epochs\n", "are also worth tuning, we focus on the learning rate alone for simplicity.\n" ] }, { "cell_type": "code", "execution_count": 1, "id": "37611073", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T19:29:44.506669Z", "iopub.status.busy": "2023-08-18T19:29:44.506127Z", "iopub.status.idle": "2023-08-18T19:29:47.782553Z", "shell.execute_reply": "2023-08-18T19:29:47.781668Z" }, "origin_pos": 2, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "import numpy as np\n", "import torch\n", "from scipy import stats\n", "from torch import nn\n", "from d2l import torch as d2l" ] }, { "cell_type": "markdown", "id": "b51422df", "metadata": { "origin_pos": 3 }, "source": [ "Before we can run HPO, we first need to define two ingredients: the objective\n", "function and the configuration space.\n", "\n", "### The Objective Function\n", "\n", "The performance of a learning algorithm can be seen as a function\n", "$f: \mathcal{X} \rightarrow \mathbb{R}$ that maps from the hyperparameter space\n", "$\mathbf{x} \in \mathcal{X}$ to the validation loss. For every evaluation of\n", "$f(\mathbf{x})$, we have to train and validate our machine learning model, which\n", "can be time and compute intensive in the case of deep neural networks trained on\n", "large datasets. Given our criterion $f(\mathbf{x})$, our goal is to find\n", "$\mathbf{x}_{\star} \in \mathrm{argmin}_{\mathbf{x} \in \mathcal{X}} f(\mathbf{x})$.\n", "\n", "There is no simple way to compute gradients of $f$ with respect to $\mathbf{x}$,\n", "because this would require propagating the gradient through the entire training\n", "process. While there is recent work :cite:`maclaurin-icml15,franceschi-icml17a`\n", "on driving HPO by approximate \"hypergradients\", none of the existing approaches\n", "are competitive with the state of the art yet, and we will not discuss them\n", "here. Furthermore, the computational burden of evaluating $f$ requires HPO\n", "algorithms to approach the global optimum with as few samples as possible.\n", "\n", "The training of neural networks is stochastic (e.g., weights are randomly\n", "initialized, mini-batches are randomly sampled), so that our observations will\n", "be noisy: $y = f(\mathbf{x}) + \epsilon$, where we usually assume that the\n", "observation noise is Gaussian distributed: $\epsilon \sim N(0, \sigma)$.\n", "\n", "Faced with all these challenges, we usually try to quickly identify a small set\n", "of well-performing hyperparameter configurations, instead of hitting the global\n", "optimum exactly. However, due to the large computational demands of most neural\n", "network models, even this can take days or weeks of compute.\n",
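"\n", "To make the noise model concrete, here is a minimal sketch that treats a cheap\n", "synthetic function as a stand-in for the expensive criterion $f(\mathbf{x})$.\n", "The helpers `f_synthetic` and `noisy_evaluation`, the quadratic shape, and the\n", "noise level are hypothetical choices for illustration only, not properties of\n", "the real objective:\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "tab": [ "pytorch" ] }, "outputs": [], "source": [ "def f_synthetic(learning_rate):\n", "    # Hypothetical noise-free criterion: smallest \"validation error\" near\n", "    # learning_rate = 0.1 on a log10 scale (an assumption for illustration).\n", "    return 0.1 + 0.05 * (np.log10(learning_rate) + 1.0) ** 2\n", "\n", "def noisy_evaluation(learning_rate, noise_std=0.01):\n", "    # One HPO evaluation: y = f(x) + epsilon, with Gaussian noise epsilon.\n", "    return f_synthetic(learning_rate) + np.random.normal(0.0, noise_std)\n", "\n", "# Evaluating the same configuration twice yields two different observations.\n", "print(noisy_evaluation(0.1), noisy_evaluation(0.1))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [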
"We will explore\n", "in :numref:`sec_mf_hpo` how we can speed up the optimization process by either\n", "distributing the search or using cheaper-to-evaluate approximations of the\n", "objective function.\n", "\n", "We begin with a method for computing the validation error of a model.\n" ] }, { "cell_type": "code", "execution_count": 2, "id": "8f5a6798", "metadata": { "attributes": { "classes": [], "id": "", "n": "8" }, "execution": { "iopub.execute_input": "2023-08-18T19:29:47.786829Z", "iopub.status.busy": "2023-08-18T19:29:47.786035Z", "iopub.status.idle": "2023-08-18T19:29:47.791984Z", "shell.execute_reply": "2023-08-18T19:29:47.791090Z" }, "origin_pos": 4, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "class HPOTrainer(d2l.Trainer): #@save\n", "    def validation_error(self):\n", "        self.model.eval()\n", "        accuracy = 0\n", "        val_batch_idx = 0\n", "        # Average the accuracy over all validation batches; the validation\n", "        # error is one minus this average.\n", "        for batch in self.val_dataloader:\n", "            with torch.no_grad():\n", "                x, y = self.prepare_batch(batch)\n", "                y_hat = self.model(x)\n", "                accuracy += self.model.accuracy(y_hat, y)\n", "            val_batch_idx += 1\n", "        return 1 - accuracy / val_batch_idx" ] }, { "cell_type": "markdown", "id": "06625e65", "metadata": { "origin_pos": 5 }, "source": [ "We optimize the validation error with respect to the hyperparameter configuration\n", "`config`, which consists of the `learning_rate`. For each evaluation, we train our\n", "model for `max_epochs` epochs, then compute and return its validation error:\n" ] }, { "cell_type": "code", "execution_count": 3, "id": "7bb87209", "metadata": { "attributes": { "classes": [], "id": "", "n": "5" }, "execution": { "iopub.execute_input": "2023-08-18T19:29:47.795434Z", "iopub.status.busy": "2023-08-18T19:29:47.794840Z", "iopub.status.idle": "2023-08-18T19:29:47.799807Z", "shell.execute_reply": "2023-08-18T19:29:47.798996Z" }, "origin_pos": 6, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "def hpo_objective_softmax_classification(config, max_epochs=8):\n", "    learning_rate = config[\"learning_rate\"]\n", "    trainer = d2l.HPOTrainer(max_epochs=max_epochs)\n", "    data = d2l.FashionMNIST(batch_size=16)\n", "    model = d2l.SoftmaxRegression(num_outputs=10, lr=learning_rate)\n", "    trainer.fit(model=model, data=data)\n", "    return trainer.validation_error().detach().numpy()" ] }, { "cell_type": "markdown", "id": "45aec920", "metadata": { "origin_pos": 7 }, "source": [ "### The Configuration Space\n", ":label:`sec_intro_config_spaces`\n", "\n", "Along with the objective function $f(\mathbf{x})$, we also need to define the\n", "feasible set $\mathcal{X}$ of hyperparameter values to optimize over, known as the\n", "*configuration space* or *search space*. For our logistic regression example,\n", "we will use:\n" ] }, { "cell_type": "code", "execution_count": 4, "id": "c439a206", "metadata": { "attributes": { "classes": [], "id": "", "n": "6" }, "execution": { "iopub.execute_input": "2023-08-18T19:29:47.803280Z", "iopub.status.busy": "2023-08-18T19:29:47.802540Z", "iopub.status.idle": "2023-08-18T19:29:47.807469Z", "shell.execute_reply": "2023-08-18T19:29:47.806663Z" }, "origin_pos": 8, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "config_space = {\"learning_rate\": stats.loguniform(1e-4, 1)}" ] }, { "cell_type": "markdown", "id": "f1ba685a", "metadata": { "origin_pos": 9 }, "source": [ "Here we use the `loguniform` object from SciPy, which represents a\n", "uniform distribution between $-4$ and $0$ in base-10 logarithmic space. This object\n", "allows us to sample random variables from this distribution.\n",
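"For instance, a handful of draws spreads over several orders of magnitude\n", "(fixing `random_state` here is an assumption made only so this illustration is\n", "reproducible; the search loop below samples without a fixed seed):\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "tab": [ "pytorch" ] }, "outputs": [], "source": [ "# Five random learning rates from the log-uniform prior over [1e-4, 1].\n", "print(config_space[\"learning_rate\"].rvs(size=5, random_state=42))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [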
"\n", "Each hyperparameter has a data type, such as `float` for `learning_rate`, as\n", "well as a closed bounded range (i.e., lower and upper bounds). We usually assign\n", "a prior distribution (e.g., uniform or log-uniform) to each hyperparameter to\n", "sample from. Some positive parameters, such as `learning_rate`, are best\n", "represented on a logarithmic scale, as optimal values can differ by several\n", "orders of magnitude, while others, such as momentum, come with a linear scale.\n", "\n", "Below we show a simple example of a configuration space consisting of typical\n", "hyperparameters of a multi-layer perceptron, including their types and standard\n", "ranges.\n", "\n", ": Example configuration space of a multi-layer perceptron\n", ":label:`tab_example_configspace`\n", "\n", "| Name | Type | Hyperparameter Ranges | log-scale |\n", "| :----: | :----: |:------------------------------:|:---------:|\n", "| learning rate | float | $[10^{-6},10^{-1}]$ | yes |\n", "| batch size | integer | $[8,256]$ | yes |\n", "| momentum | float | $[0,0.99]$ | no |\n", "| activation function | categorical | $\{\textrm{tanh}, \textrm{relu}\}$ | - |\n", "| number of units | integer | $[32, 1024]$ | yes |\n", "| number of layers | integer | $[1, 6]$ | no |\n", "\n", "In general, the structure of the configuration space $\mathcal{X}$ can be complex,\n", "and it can be quite different from $\mathbb{R}^d$. In practice, some\n", "hyperparameters may depend on the values of others. For example, assume we try\n", "to tune the number of layers for a multi-layer perceptron, and for each layer\n", "the number of units. The number of units of the $l\textrm{-th}$ layer is\n", "relevant only if the network has at least $l+1$ layers. These advanced HPO\n", "problems are beyond the scope of this chapter. We refer the interested reader\n", "to :cite:`hutter-lion11a,jenatton-icml17a,baptista-icml18a`.\n",
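"\n", "To make the idea of such a conditional space concrete, here is a minimal sketch\n", "of a sampler for the nested layers-and-units space described above; the function\n", "name and the sampling scheme are hypothetical choices for illustration, not an\n", "API used in the rest of this chapter:\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "tab": [ "pytorch" ] }, "outputs": [], "source": [ "def sample_mlp_config(max_layers=6):\n", "    # First draw the number of layers, then draw a width for exactly that\n", "    # many layers: the width of layer l is only sampled if layer l exists.\n", "    num_layers = np.random.randint(1, max_layers + 1)\n", "    num_units = [int(stats.loguniform(32, 1024).rvs()) for _ in range(num_layers)]\n", "    return {\"num_layers\": num_layers, \"num_units\": num_units}\n", "\n", "print(sample_mlp_config())" ] }, { "cell_type": "markdown", "metadata": {}, "source": [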
"\n", "The configuration space plays an important role for hyperparameter optimization,\n", "since no algorithm can find configurations that are not contained in it. On the\n", "other hand, if the ranges are too large, the computational budget needed to find\n", "well-performing configurations might become infeasible.\n", "\n", "## Random Search\n", ":label:`sec_rs`\n", "\n", "*Random search* is the first hyperparameter optimization algorithm we will\n", "consider. The main idea of random search is to sample independently from the\n", "configuration space until a predefined budget (e.g., a maximum\n", "number of iterations) is exhausted, and to return the best observed\n", "configuration. All evaluations can be executed independently in parallel (see\n", ":numref:`sec_rs_async`), but here we use a sequential loop for simplicity.\n" ] }, { "cell_type": "code", "execution_count": 5, "id": "07f883db", "metadata": { "attributes": { "classes": [], "id": "", "n": "7" }, "execution": { "iopub.execute_input": "2023-08-18T19:29:47.810993Z", "iopub.status.busy": "2023-08-18T19:29:47.810248Z", "iopub.status.idle": "2023-08-18T19:40:58.104742Z", "shell.execute_reply": "2023-08-18T19:40:58.103591Z" }, "origin_pos": 10, "tab": [ "pytorch" ] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ " validation_error = 0.17070001363754272\n" ] }, { "data": { "text/plain": [ "<Matplotlib training-progress figures for the five trials (SVG output stripped)>" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "errors, values = [], []\n", "num_iterations = 5\n", "\n", "for i in range(num_iterations):\n", "    # Draw a configuration at random and evaluate its validation error.\n", "    learning_rate = config_space[\"learning_rate\"].rvs()\n", "    print(f\"Trial {i}: learning_rate = {learning_rate}\")\n", "    y = hpo_objective_softmax_classification({\"learning_rate\": learning_rate})\n", "    print(f\"    validation_error = {y}\")\n", "    values.append(learning_rate)\n", "    errors.append(y)" ] }, { "cell_type": "markdown", "id": "126a63c5", "metadata": { "origin_pos": 11 }, "source": [ "The best learning rate is then simply the one with the lowest validation error.\n" ] }, { "cell_type": "code", "execution_count": 6, "id": "43170dca", "metadata": { "attributes": { "classes": [], "id": "", "n": "7" }, "execution": { "iopub.execute_input": "2023-08-18T19:40:58.110279Z", "iopub.status.busy": "2023-08-18T19:40:58.109413Z", "iopub.status.idle": "2023-08-18T19:40:58.115714Z", "shell.execute_reply": "2023-08-18T19:40:58.114693Z" }, "origin_pos": 12, "tab": [ "pytorch" ] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "optimal learning rate = 0.09844872561810249\n" ] } ], "source": [ "best_idx = np.argmin(errors)\n", "print(f\"optimal learning rate = {values[best_idx]}\")" ] }, { "cell_type": "markdown", "id": "267bcd29", "metadata": { "origin_pos": 13 }, "source": [ "Due to its simplicity and generality, random search is one of the most frequently\n", "used HPO algorithms. It does not require any sophisticated implementation and\n", "can be applied to any configuration space as long as we can define some\n", "probability distribution for each hyperparameter.\n", "\n", "Unfortunately, random search also comes with a few shortcomings. First, it does\n", "not adapt the sampling distribution based on the observations collected so far.\n", "Hence, it is just as likely to sample a poorly performing configuration as a\n", "better performing one. Second, the same amount of resources is spent on every\n", "configuration, even on those that show poor initial performance and are unlikely\n", "to outperform previously seen configurations.\n", "\n", "In the next sections we will look at more sample-efficient hyperparameter\n", "optimization algorithms that overcome the shortcomings of random search by\n", "using a model to guide the search. We will also look at algorithms that\n", "automatically stop the evaluation of poorly performing configurations\n", "to speed up the optimization process.\n", "\n", "## Summary\n", "\n", "In this section we introduced hyperparameter optimization (HPO) and how we can\n", "phrase it as a global optimization problem by defining a configuration space and\n", "an objective function. We also implemented our first HPO algorithm, random search,\n", "and applied it to a simple softmax classification problem.\n", "\n", "While random search is very simple, it is usually a better alternative than grid\n", "search, which simply evaluates a fixed set of hyperparameter configurations. Random\n", "search somewhat mitigates the curse of dimensionality :cite:`bellman-science66`, and\n", "can be far more efficient than grid search if the criterion most strongly\n", "depends on a small subset of the hyperparameters.\n", "\n", "## Exercises\n", "\n", "1. In this chapter, we optimize the validation error of a model after training on a disjoint training set. For simplicity, our code uses `Trainer.val_dataloader`, which maps to a loader around `FashionMNIST.val`.\n", "    1. 
Convince yourself (by looking at the code) that this means we use the original FashionMNIST training set (60000 examples) for training, and the original *test set* (10000 examples) for validation.\n", "    2. Why could this practice be problematic? Hint: Re-read :numref:`sec_generalization_basics`, especially about *model selection*.\n", "    3. What should we have done instead?\n", "2. We stated above that hyperparameter optimization by gradient descent is very hard to do. Consider a small problem, such as training a two-layer perceptron on the FashionMNIST dataset (:numref:`sec_mlp-implementation`) with a batch size of 256. We would like to tune the learning rate of SGD in order to minimize a validation metric after one epoch of training.\n", "    1. Why can we not use the validation *error* for this purpose? What metric on the validation set would you use?\n", "    2. Sketch (roughly) the computational graph of the validation metric after training for one epoch. You may assume that initial weights and hyperparameters (such as learning rate) are input nodes to this graph. Hint: Re-read about computational graphs in :numref:`sec_backprop`.\n", "    3. Give a rough estimate of the number of floating point values you need to store during a forward pass on this graph. Hint: FashionMNIST has 60000 cases. Assume the required memory is dominated by the activations after each layer, and look up the layer widths in :numref:`sec_mlp-implementation`.\n", "    4. Apart from the sheer amount of compute and storage required, what other issues would gradient-based hyperparameter optimization run into? Hint: Re-read about vanishing and exploding gradients in :numref:`sec_numerical_stability`.\n", "    5. *Advanced*: Read :cite:`maclaurin-icml15` for an elegant (yet still somewhat impractical) approach to gradient-based HPO.\n", "3. Grid search is another HPO baseline, where we define an equi-spaced grid for each hyperparameter, then iterate over the (combinatorial) Cartesian product in order to suggest configurations.\n", "    1. We stated above that random search can be much more efficient than grid search for HPO on a sizable number of hyperparameters, if the criterion most strongly depends on a small subset of the hyperparameters. Why is this? Hint: Read :cite:`bergstra2011algorithms`.\n" ] }, { "cell_type": "markdown", "id": "b1f1c594", "metadata": { "origin_pos": 14, "tab": [ "pytorch" ] }, "source": [ "[Discussions](https://discuss.d2l.ai/t/12090)\n" ] } ], "metadata": { "language_info": { "name": "python" }, "required_libs": [] }, "nbformat": 4, "nbformat_minor": 5 }