{ "cells": [ { "cell_type": "markdown", "id": "db492c7b", "metadata": { "origin_pos": 0 }, "source": [ "# Image Augmentation\n", ":label:`sec_image_augmentation`\n", "\n", "In :numref:`sec_alexnet`, \n", "we mentioned that large datasets \n", "are a prerequisite\n", "for the success of\n", "deep neural networks\n", "in various applications.\n", "*Image augmentation* \n", "generates similar but distinct training examples\n", "after a series of random changes to the training images, thereby expanding the size of the training set.\n", "Alternatively,\n", "image augmentation can be motivated\n", "by the fact that \n", "random tweaks of training examples \n", "allow models to rely less on\n", "certain attributes, thereby improving their generalization ability.\n", "For example, we can crop an image in different ways to make the object of interest appear in different positions, thereby reducing the dependence of a model on the position of the object. \n", "We can also adjust factors such as brightness and color to reduce a model's sensitivity to color.\n", "It is probably true\n", "that image augmentation was indispensable\n", "for the success of AlexNet at that time.\n", "In this section we will discuss this widely used technique in computer vision.\n" ] }, { "cell_type": "code", "execution_count": 1, "id": "9156edd4", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T19:29:21.137867Z", "iopub.status.busy": "2023-08-18T19:29:21.137598Z", "iopub.status.idle": "2023-08-18T19:29:26.996310Z", "shell.execute_reply": "2023-08-18T19:29:26.995373Z" }, "origin_pos": 2, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "%matplotlib inline\n", "import torch\n", "import torchvision\n", "from torch import nn\n", "from d2l import torch as d2l" ] }, { "cell_type": "markdown", "id": "b98250c7", "metadata": { "origin_pos": 3 }, "source": [ "## Common Image Augmentation Methods\n", "\n", "In our investigation of common image augmentation methods, we will use the following 
$400\\times 500$ image as an example.\n" ] },
{ "cell_type": "code", "execution_count": 2, "id": "09b9d785", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T19:29:27.001470Z", "iopub.status.busy": "2023-08-18T19:29:27.000767Z", "iopub.status.idle": "2023-08-18T19:29:27.240747Z", "shell.execute_reply": "2023-08-18T19:29:27.239864Z" }, "origin_pos": 5, "tab": [ "pytorch" ] }, "outputs": [ { "data": { "image/svg+xml": [ "2023-08-18T19:29:27.163217\n", "image/svg+xml\n", "Matplotlib v3.7.2, https://matplotlib.org/\n" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "d2l.set_figsize()\n", "img = d2l.Image.open('../img/cat1.jpg')\n", "d2l.plt.imshow(img);" ] },
{ "cell_type": "markdown", "id": "91261ab8", "metadata": { "origin_pos": 6 }, "source": [ "Most image augmentation methods involve some randomness. To make the effect of image augmentation easier to observe, we next define an auxiliary function `apply`. This function runs the image augmentation method `aug` multiple times on the input image `img` and shows all the results.\n" ] },
{ "cell_type": "code", "execution_count": 3, "id": "a0f9c352", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T19:29:27.248513Z", "iopub.status.busy": "2023-08-18T19:29:27.247942Z", "iopub.status.idle": "2023-08-18T19:29:27.252681Z", "shell.execute_reply": "2023-08-18T19:29:27.251835Z" }, "origin_pos": 7, "tab": [ "pytorch" ] }, "outputs": [], "source": [
"def apply(img, aug, num_rows=2, num_cols=4, scale=1.5):\n",
"    Y = [aug(img) for _ in range(num_rows * num_cols)]\n",
"    d2l.show_images(Y, num_rows, num_cols, scale=scale)" ] },
{ "cell_type": "markdown", "id": "a216c50c", "metadata": { "origin_pos": 8 }, "source": [ "### Flipping and Cropping\n" ] },
{ "cell_type": "markdown", "id": "a1ca3607", "metadata": { "origin_pos": 10, "tab": [ "pytorch" ] }, "source": [ "[**Flipping the image left and right**] usually does not change the category of the object. 
\n", "This is one of the earliest and most widely used methods of image augmentation.\n", "Next, we use the `transforms` module to create a `RandomHorizontalFlip` instance, which flips\n", "an image left and right with a 50% chance.\n" ] },
{ "cell_type": "code", "execution_count": 4, "id": "4bab7249", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T19:29:27.257058Z", "iopub.status.busy": "2023-08-18T19:29:27.256507Z", "iopub.status.idle": "2023-08-18T19:29:27.917989Z", "shell.execute_reply": "2023-08-18T19:29:27.917148Z" }, "origin_pos": 12, "tab": [ "pytorch" ] }, "outputs": [ { "data": { "image/svg+xml": [ "2023-08-18T19:29:27.773074\n", "image/svg+xml\n", "Matplotlib v3.7.2, https://matplotlib.org/\n" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "apply(img, torchvision.transforms.RandomHorizontalFlip())" ] },
{ "cell_type": "markdown", "id": "6c8b8d93", "metadata": { "origin_pos": 14, "tab": [ "pytorch" ] }, "source": [ "[**Flipping up and down**] is not as common as flipping left and right. But at least for this example image, flipping up and down does not hinder recognition.\n", "Next, we create a `RandomVerticalFlip` instance to flip\n", "an image up and down with a 50% chance.\n" ] },
{ "cell_type": "code", "execution_count": 5, "id": "6383df69", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T19:29:27.922797Z", "iopub.status.busy": "2023-08-18T19:29:27.922207Z", "iopub.status.idle": "2023-08-18T19:29:28.463825Z", "shell.execute_reply": "2023-08-18T19:29:28.462928Z" }, "origin_pos": 16, "tab": [ "pytorch" ] }, "outputs": [ { "data": { "image/svg+xml": [ "2023-08-18T19:29:28.316410\n", "image/svg+xml\n", "Matplotlib v3.7.2, https://matplotlib.org/\n" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "apply(img, torchvision.transforms.RandomVerticalFlip())" ] },
{ "cell_type": "markdown", "id": "2651d645", "metadata": { "origin_pos": 17 }, "source": [ "In the example image we used, the cat is in the middle of the image, but this may not be the case in general. \n", "In :numref:`sec_pooling`, we explained that the pooling layer can reduce the sensitivity of a convolutional layer to the target position.\n", "We can also randomly crop the image so that objects appear at different positions and scales, which likewise reduces the sensitivity of a model to the target position.\n", "\n", "In the code below, we [**randomly crop**] a region whose area is $10\\% \\sim 100\\%$ of the original image area, with a width-to-height ratio randomly selected from $0.5 \\sim 2$. The region is then rescaled so that both its width and height are 200 pixels. \n", "Unless otherwise specified, the random number between $a$ and $b$ in this section refers to a continuous value obtained by random and uniform sampling from the interval $[a, b]$.\n" ] },
{ "cell_type": "code", "execution_count": 6, "id": "0189b481", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T19:29:28.471478Z", "iopub.status.busy": "2023-08-18T19:29:28.470821Z", "iopub.status.idle": "2023-08-18T19:29:28.849478Z", "shell.execute_reply": "2023-08-18T19:29:28.848615Z" }, "origin_pos": 19, "tab": [ "pytorch" ] }, "outputs": [ { "data": { "image/svg+xml": [ "2023-08-18T19:29:28.751971\n", "image/svg+xml\n", "Matplotlib v3.7.2, https://matplotlib.org/\n", "
\n" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [
"shape_aug = torchvision.transforms.RandomResizedCrop(\n",
"    (200, 200), scale=(0.1, 1), ratio=(0.5, 2))\n",
"apply(img, shape_aug)" ] },
{ "cell_type": "markdown", "id": "9b17c5a5", "metadata": { "origin_pos": 20 }, "source": [ "### Changing Colors\n", "\n", "Another augmentation method is changing colors. We can change four aspects of the image color: brightness, contrast, saturation, and hue. In the example below, we [**randomly change the brightness**] of the image to a value between 50% ($1-0.5$) and 150% ($1+0.5$) of the brightness of the original image.\n" ] },
{ "cell_type": "code", "execution_count": 7, "id": "c677ef85", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T19:29:28.854373Z", "iopub.status.busy": "2023-08-18T19:29:28.853774Z", "iopub.status.idle": "2023-08-18T19:29:29.455908Z", "shell.execute_reply": "2023-08-18T19:29:29.454983Z" }, "origin_pos": 22, "tab": [ "pytorch" ] }, "outputs": [ { "data": { "image/svg+xml": [ "2023-08-18T19:29:29.311386\n", "image/svg+xml\n", "Matplotlib v3.7.2, https://matplotlib.org/\n" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [
"apply(img, torchvision.transforms.ColorJitter(\n",
"    brightness=0.5, contrast=0, saturation=0, hue=0))" ] },
{ "cell_type": "markdown", "id": "4c234904", "metadata": { "origin_pos": 23 }, "source": [ "Similarly, we can [**randomly change the hue**] of the image.\n" ] },
{ "cell_type": "code", "execution_count": 8, "id": "fa26bb09", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T19:29:29.460051Z", "iopub.status.busy": "2023-08-18T19:29:29.459449Z", "iopub.status.idle": "2023-08-18T19:29:30.086603Z", "shell.execute_reply": "2023-08-18T19:29:30.085725Z" }, "origin_pos": 25, "tab": [ "pytorch" ] }, "outputs": [ { "data": { "image/svg+xml": [ "2023-08-18T19:29:29.936350\n", "image/svg+xml\n", "Matplotlib v3.7.2, https://matplotlib.org/\n", "
\n" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [
"apply(img, torchvision.transforms.ColorJitter(\n",
"    brightness=0, contrast=0, saturation=0, hue=0.5))" ] },
{ "cell_type": "markdown", "id": "e01e2f9c", "metadata": { "origin_pos": 26 }, "source": [ "We can also create a `ColorJitter` instance and set how to [**randomly change the `brightness`, `contrast`, `saturation`, and `hue` of the image at the same time**].\n" ] },
{ "cell_type": "code", "execution_count": 9, "id": "5428065d", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T19:29:30.091653Z", "iopub.status.busy": "2023-08-18T19:29:30.090710Z", "iopub.status.idle": "2023-08-18T19:29:30.914768Z", "shell.execute_reply": "2023-08-18T19:29:30.913630Z" }, "origin_pos": 28, "tab": [ "pytorch" ] }, "outputs": [ { "data": { "image/svg+xml": [ "2023-08-18T19:29:30.771423\n", "image/svg+xml\n", "Matplotlib v3.7.2, https://matplotlib.org/\n" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [
"color_aug = torchvision.transforms.ColorJitter(\n",
"    brightness=0.5, contrast=0.5, saturation=0.5, hue=0.5)\n",
"apply(img, color_aug)" ] },
{ "cell_type": "markdown", "id": "4c4b8e74", "metadata": { "origin_pos": 29 }, "source": [ "### Combining Multiple Image Augmentation Methods\n", "\n", "In practice, we will [**combine multiple image augmentation methods**]. \n", "For example,\n", "we can combine the different image augmentation methods defined above and apply them to each image via a `Compose` instance.\n" ] },
{ "cell_type": "code", "execution_count": 10, "id": "2a896155", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T19:29:30.918830Z", "iopub.status.busy": "2023-08-18T19:29:30.917934Z", "iopub.status.idle": "2023-08-18T19:29:31.430226Z", "shell.execute_reply": "2023-08-18T19:29:31.429042Z" }, "origin_pos": 31, "tab": [ "pytorch" ] }, "outputs": [ { "data": { "image/svg+xml": [ "2023-08-18T19:29:31.332988\n", "image/svg+xml\n", "Matplotlib v3.7.2, https://matplotlib.org/\n", "
\n" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [
"augs = torchvision.transforms.Compose([\n",
"    torchvision.transforms.RandomHorizontalFlip(), color_aug, shape_aug])\n",
"apply(img, augs)" ] },
{ "cell_type": "markdown", "id": "a8d21882", "metadata": { "origin_pos": 32 }, "source": [ "## [**Training with Image Augmentation**]\n", "\n", "Let's train a model with image augmentation.\n", "Here we use the CIFAR-10 dataset instead of the Fashion-MNIST dataset that we used before. \n", "This is because the position and size of the objects in the Fashion-MNIST dataset have been normalized, while the color and size of the objects in the CIFAR-10 dataset vary more significantly. \n", "The first 32 training images in the CIFAR-10 dataset are shown below.\n" ] },
{ "cell_type": "code", "execution_count": 11, "id": "24b6f13d", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T19:29:31.434766Z", "iopub.status.busy": "2023-08-18T19:29:31.433902Z", "iopub.status.idle": "2023-08-18T19:29:40.274559Z", "shell.execute_reply": "2023-08-18T19:29:40.273350Z" }, "origin_pos": 34, "tab": [ "pytorch" ] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ../data/cifar-10-python.tar.gz\n" ] }, { "data": { "image/svg+xml": [ "2023-08-18T19:29:40.099619\n", "image/svg+xml\n", "Matplotlib v3.7.2, https://matplotlib.org/\n", "
\n" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [
"all_images = torchvision.datasets.CIFAR10(train=True, root=\"../data\",\n",
"                                          download=True)\n",
"d2l.show_images([all_images[i][0] for i in range(32)], 4, 8, scale=0.8);" ] },
{ "cell_type": "markdown", "id": "fdf61781", "metadata": { "origin_pos": 35 }, "source": [ "In order to obtain deterministic results during prediction, we usually apply image augmentation only to training examples and do not use image augmentation with random operations during prediction. \n", "[**Here we only use the simplest random left-right flipping method**]. In addition, we use a `ToTensor` instance to convert each image into the format required by the deep learning framework, i.e., \n", "32-bit floating point numbers between 0 and 1; once batched, a minibatch has the shape (batch size, number of channels, height, width).\n" ] },
{ "cell_type": "code", "execution_count": 12, "id": "bba8dca6", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T19:29:40.279906Z", "iopub.status.busy": "2023-08-18T19:29:40.279133Z", "iopub.status.idle": "2023-08-18T19:29:40.284581Z", "shell.execute_reply": "2023-08-18T19:29:40.283471Z" }, "origin_pos": 37, "tab": [ "pytorch" ] }, "outputs": [], "source": [
"train_augs = torchvision.transforms.Compose([\n",
"    torchvision.transforms.RandomHorizontalFlip(),\n",
"    torchvision.transforms.ToTensor()])\n",
"\n",
"test_augs = torchvision.transforms.Compose([\n",
"    torchvision.transforms.ToTensor()])" ] },
{ "cell_type": "markdown", "id": "3a6ca32d", "metadata": { "origin_pos": 39, "tab": [ "pytorch" ] }, "source": [ "Next, we [**define an auxiliary function to facilitate reading the image and\n", "applying image augmentation**]. 
\n",
"The `transform` argument provided by PyTorch's\n",
"dataset applies augmentation to transform the images.\n",
"For\n",
"a detailed introduction to `DataLoader`, please refer to :numref:`sec_fashion_mnist`.\n" ] },
{ "cell_type": "code", "execution_count": 13, "id": "432deae2", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T19:29:40.288346Z", "iopub.status.busy": "2023-08-18T19:29:40.287475Z", "iopub.status.idle": "2023-08-18T19:29:40.293168Z", "shell.execute_reply": "2023-08-18T19:29:40.291997Z" }, "origin_pos": 41, "tab": [ "pytorch" ] }, "outputs": [], "source": [
"def load_cifar10(is_train, augs, batch_size):\n",
"    dataset = torchvision.datasets.CIFAR10(root=\"../data\", train=is_train,\n",
"                                           transform=augs, download=True)\n",
"    dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size,\n",
"                    shuffle=is_train, num_workers=d2l.get_dataloader_workers())\n",
"    return dataloader" ] },
{ "cell_type": "markdown", "id": "abe99e62", "metadata": { "origin_pos": 42 }, "source": [
"### Multi-GPU Training\n",
"\n",
"We train the ResNet-18 model from\n",
":numref:`sec_resnet` on the\n",
"CIFAR-10 dataset.\n",
"Recall the introduction to\n",
"multi-GPU training in :numref:`sec_multi_gpu_concise`.\n",
"In the following,\n",
"[**we define a function to train and evaluate the model using multiple GPUs**].\n" ] },
{ "cell_type": "code", "execution_count": 14, "id": "6fb8c6fc", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T19:29:40.296952Z", "iopub.status.busy": "2023-08-18T19:29:40.296069Z", "iopub.status.idle": "2023-08-18T19:29:40.302999Z", "shell.execute_reply": "2023-08-18T19:29:40.301894Z" }, "origin_pos": 44, "tab": [ "pytorch" ] }, "outputs": [], "source": [
"#@save\n",
"def train_batch_ch13(net, X, y, loss, trainer, devices):\n",
"    \"\"\"Train for a minibatch with multiple GPUs (defined in Chapter 13).\"\"\"\n",
"    if isinstance(X, list):\n",
"        # Required for BERT fine-tuning (to be covered later)\n",
"        X = [x.to(devices[0]) for x in X]\n",
"    else:\n",
"        X = X.to(devices[0])\n",
"    y = y.to(devices[0])\n",
"    net.train()\n",
"    trainer.zero_grad()\n",
"    pred = net(X)\n",
"    l = loss(pred, y)\n",
"    l.sum().backward()\n",
"    trainer.step()\n",
"    train_loss_sum = l.sum()\n",
"    train_acc_sum = d2l.accuracy(pred, y)\n",
"    return train_loss_sum, train_acc_sum" ] },
{ "cell_type": "code", "execution_count": 15, "id": "d40afce2", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T19:29:40.306525Z", "iopub.status.busy": "2023-08-18T19:29:40.305771Z", "iopub.status.idle": "2023-08-18T19:29:40.315273Z", "shell.execute_reply": "2023-08-18T19:29:40.314141Z" }, "origin_pos": 46, "tab": [ "pytorch" ] }, "outputs": [], "source": [
"#@save\n",
"def train_ch13(net, train_iter, test_iter, loss, trainer, num_epochs,\n",
"               devices=d2l.try_all_gpus()):\n",
"    \"\"\"Train a model with multiple GPUs (defined in Chapter 13).\"\"\"\n",
"    timer, num_batches = d2l.Timer(), len(train_iter)\n",
"    animator = d2l.Animator(xlabel='epoch', xlim=[1, num_epochs], ylim=[0, 1],\n",
"                            legend=['train loss', 'train acc', 'test acc'])\n",
"    net = nn.DataParallel(net, device_ids=devices).to(devices[0])\n",
"    for epoch in range(num_epochs):\n",
"        # Sum of training loss, sum of training accuracy, no. of examples,\n",
"        # no. of predictions\n",
"        metric = d2l.Accumulator(4)\n",
"        for i, (features, labels) in enumerate(train_iter):\n",
"            timer.start()\n",
"            l, acc = train_batch_ch13(\n",
"                net, features, labels, loss, trainer, devices)\n",
"            metric.add(l, acc, labels.shape[0], labels.numel())\n",
"            timer.stop()\n",
"            if (i + 1) % (num_batches // 5) == 0 or i == num_batches - 1:\n",
"                animator.add(epoch + (i + 1) / num_batches,\n",
"                             (metric[0] / metric[2], metric[1] / metric[3],\n",
"                              None))\n",
"        test_acc = d2l.evaluate_accuracy_gpu(net, test_iter)\n",
"        animator.add(epoch + 1, (None, None, test_acc))\n",
"    print(f'loss {metric[0] / metric[2]:.3f}, train acc '\n",
"          f'{metric[1] / metric[3]:.3f}, test acc {test_acc:.3f}')\n",
"    print(f'{metric[2] * num_epochs / timer.sum():.1f} examples/sec on '\n",
"          f'{str(devices)}')" ] },
{ "cell_type": "markdown", "id": "c6d31b36", "metadata": { "origin_pos": 47 }, "source": [
"Now we can [**define the `train_with_data_aug` function to train the model with image augmentation**].\n",
"This function gets all available GPUs, \n",
"uses Adam as the optimization algorithm,\n",
"applies image augmentation to the training dataset,\n",
"and finally calls the `train_ch13` function just defined to train and evaluate the model.\n" ] },
{ "cell_type": "code", "execution_count": 16, "id": "d9dd7d2e", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T19:29:40.319151Z", "iopub.status.busy": "2023-08-18T19:29:40.318112Z", "iopub.status.idle": "2023-08-18T19:29:40.332876Z", "shell.execute_reply": "2023-08-18T19:29:40.331732Z" }, "origin_pos": 49, "tab": [ "pytorch" ] }, "outputs": [], "source": [
"batch_size, devices, net = 256, d2l.try_all_gpus(), d2l.resnet18(10, 3)\n",
"net.apply(d2l.init_cnn)\n",
"\n",
"def train_with_data_aug(train_augs, test_augs, net, lr=0.001):\n",
"    train_iter = load_cifar10(True, train_augs, batch_size)\n",
"    test_iter = load_cifar10(False, test_augs, batch_size)\n",
"    loss = nn.CrossEntropyLoss(reduction=\"none\")\n",
"    trainer = torch.optim.Adam(net.parameters(), lr=lr)\n",
"    net(next(iter(train_iter))[0])\n",
"    train_ch13(net, train_iter, test_iter, loss, trainer, 10, devices)" ] },
{ "cell_type": "markdown", "id": "317fba8a", "metadata": { "origin_pos": 50 }, "source": [
"Let's [**train the model**] using image augmentation based on random left-right flipping.\n" ] },
{ "cell_type": "code", "execution_count": 17, "id": "52e519e2", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T19:29:40.336596Z", "iopub.status.busy": "2023-08-18T19:29:40.335748Z", "iopub.status.idle": "2023-08-18T19:32:06.346574Z", "shell.execute_reply": "2023-08-18T19:32:06.345519Z" }, "origin_pos": 51, "tab": [ "pytorch" ] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [
"loss 0.215, train acc 0.925, test acc 0.810\n",
"4728.8 examples/sec on [device(type='cuda', index=0), device(type='cuda', index=1)]\n" ] },
{ "data": { "image/svg+xml": [
" 2023-08-18T19:32:06.293200\n",
" image/svg+xml\n",
" Matplotlib v3.7.2, https://matplotlib.org/\n" ], "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "train_with_data_aug(train_augs, test_augs, net)" ] }, { "cell_type": "markdown", "id": "0f64ff62", "metadata": { "origin_pos": 52 }, "source": [ "## Summary\n", "\n", "* Image augmentation generates random images based on existing training data to improve the generalization ability of models.\n", "* In order to obtain definitive results during prediction, we usually only apply image augmentation to training examples, and do not use image augmentation with random operations during prediction.\n", "* Deep learning frameworks provide many different image augmentation methods, which can be applied simultaneously.\n", "\n", "\n", "## Exercises\n", "\n", "1. Train the model without using image augmentation: `train_with_data_aug(test_augs, test_augs)`. Compare training and testing accuracy when using and not using image augmentation. Can this comparative experiment support the argument that image augmentation can mitigate overfitting? Why?\n", "1. Combine multiple different image augmentation methods in model training on the CIFAR-10 dataset. Does it improve test accuracy? \n", "1. Refer to the online documentation of the deep learning framework. What other image augmentation methods does it also provide?\n" ] }, { "cell_type": "markdown", "id": "47342568", "metadata": { "origin_pos": 54, "tab": [ "pytorch" ] }, "source": [ "[Discussions](https://discuss.d2l.ai/t/1404)\n" ] } ], "metadata": { "language_info": { "name": "python" }, "required_libs": [] }, "nbformat": 4, "nbformat_minor": 5 }