{ "cells": [ { "cell_type": "markdown", "id": "37236084", "metadata": { "origin_pos": 0 }, "source": [ "# Image Classification (CIFAR-10) on Kaggle\n", ":label:`sec_kaggle_cifar10`\n", "\n", "So far, we have been using high-level APIs of deep learning frameworks to directly obtain image datasets in tensor format.\n", "However, custom image datasets\n", "often come in the form of image files.\n", "In this section, we will start from\n", "raw image files,\n", "and organize, read, then transform them\n", "into tensor format step by step.\n", "\n", "We experimented with the CIFAR-10 dataset in :numref:`sec_image_augmentation`,\n", "which is an important dataset in computer vision.\n", "In this section,\n", "we will apply the knowledge we learned\n", "in previous sections\n", "to compete in the Kaggle competition of\n", "CIFAR-10 image classification.\n", "(**The web address of the competition is https://www.kaggle.com/c/cifar-10**)\n", "\n", ":numref:`fig_kaggle_cifar10` shows the information on the competition's webpage.\n", "In order to submit the results,\n", "you need to register a Kaggle account.\n", "\n", "![CIFAR-10 image classification competition webpage information. The competition dataset can be obtained by clicking the \"Data\" tab.](../img/kaggle-cifar10.png)\n", ":width:`600px`\n", ":label:`fig_kaggle_cifar10`\n" ] }, { "cell_type": "code", "execution_count": 1, "id": "e5a3fb64", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T19:29:26.260293Z", "iopub.status.busy": "2023-08-18T19:29:26.259545Z", "iopub.status.idle": "2023-08-18T19:29:29.089576Z", "shell.execute_reply": "2023-08-18T19:29:29.088675Z" }, "origin_pos": 2, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "import collections\n", "import math\n", "import os\n", "import shutil\n", "import pandas as pd\n", "import torch\n", "import torchvision\n", "from torch import nn\n", "from d2l import torch as d2l" ] }, { "cell_type": "markdown", "id": "403ecf47", "metadata": { "origin_pos": 3 }, "source": [ "## Obtaining and Organizing the Dataset\n", "\n", "The competition dataset is divided into\n", "a training set and a test set,\n", "which contain 50000 and 300000 images, respectively.\n", "In the test set,\n", "10000 images will be used for evaluation,\n", "while the remaining 290000 images will not\n", "be evaluated:\n", "they are included just\n", "to make it hard\n", "to cheat with\n", "*manually* labeled results of the test set.\n", "The images in this dataset\n", "are all PNG color (RGB channels) image files,\n", "whose height and width are both 32 pixels.\n", "The images cover a total of 10 categories, namely airplanes, cars, birds, cats, deer, dogs, frogs, horses, ships, and trucks.\n", "The upper-left corner of :numref:`fig_kaggle_cifar10` shows some images of airplanes, cars, and birds in the dataset.\n", "\n", "\n", "### Downloading the Dataset\n", "\n", "After logging in to Kaggle, we can click the \"Data\" tab on the CIFAR-10 image classification competition webpage shown in :numref:`fig_kaggle_cifar10` and download the dataset by clicking the \"Download All\" button.\n", "After unzipping the downloaded file in `../data`, and unzipping `train.7z` and `test.7z` inside it, you will find the entire dataset in the following paths:\n", "\n", "* `../data/cifar-10/train/[1-50000].png`\n", "* `../data/cifar-10/test/[1-300000].png`\n", "* `../data/cifar-10/trainLabels.csv`\n", "* `../data/cifar-10/sampleSubmission.csv`\n", "\n", "where the `train` and `test` directories contain the training and testing\n", "
images, respectively, `trainLabels.csv` provides labels for the training images, and `sampleSubmission.csv` is a sample submission file.\n", "\n", "To make it easier to get started, [**we provide a small-scale sample of the dataset that\n", "contains the first 1000 training images and 5 random testing images.**]\n", "To use the full dataset of the Kaggle competition, you need to set the following `demo` variable to `False`.\n" ] }, { "cell_type": "code", "execution_count": 2, "id": "0d41dcd1", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T19:29:29.095074Z", "iopub.status.busy": "2023-08-18T19:29:29.094404Z", "iopub.status.idle": "2023-08-18T19:29:29.393994Z", "shell.execute_reply": "2023-08-18T19:29:29.393137Z" }, "origin_pos": 4, "tab": [ "pytorch" ] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Downloading ../data/kaggle_cifar10_tiny.zip from http://d2l-data.s3-accelerate.amazonaws.com/kaggle_cifar10_tiny.zip...\n" ] } ], "source": [ "#@save\n", "d2l.DATA_HUB['cifar10_tiny'] = (d2l.DATA_URL + 'kaggle_cifar10_tiny.zip',\n", "                                '2068874e4b9a9f0fb07ebe0ad2b29754449ccacd')\n", "\n", "# If you use the full dataset downloaded for the Kaggle competition, set\n", "# `demo` to False\n", "demo = True\n", "\n", "if demo:\n", "    data_dir = d2l.download_extract('cifar10_tiny')\n", "else:\n", "    data_dir = '../data/cifar-10/'" ] }, { "cell_type": "markdown", "id": "0f716217", "metadata": { "origin_pos": 5 }, "source": [ "### [**Organizing the Dataset**]\n", "\n", "We need to organize datasets to facilitate model training and testing.\n", "Let's first read the labels from the CSV file.\n", "The following function returns a dictionary that maps\n", "the non-extension part of the filename to its label.\n" ] }, { "cell_type": "code", "execution_count": 3, "id": "04bf8387", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T19:29:29.399003Z", "iopub.status.busy": "2023-08-18T19:29:29.398718Z", "iopub.status.idle": "2023-08-18T19:29:29.406335Z", "shell.execute_reply": "2023-08-18T19:29:29.405552Z" }, "origin_pos": 6, "tab": [ "pytorch" ] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "# training examples: 1000\n", "# classes: 10\n" ] } ], "source": [ "#@save\n", "def read_csv_labels(fname):\n", "    \"\"\"Read `fname` to return a filename-to-label dictionary.\"\"\"\n", "    with open(fname, 'r') as f:\n", "        # Skip the file header line (column names)\n", "        lines = f.readlines()[1:]\n", "    # Each line has the form `filename,label`\n", "    tokens = [l.rstrip().split(',') for l in lines]\n", "    return dict((name, label) for name, label in tokens)\n", "\n", "labels = read_csv_labels(os.path.join(data_dir, 'trainLabels.csv'))\n", "print('# training examples:', len(labels))\n", "print('# classes:', len(set(labels.values())))" ] }, { "cell_type": "markdown", "id": "abdf4d16", "metadata": { "origin_pos": 7 }, "source": [ "Next, we define the `reorg_train_valid` function to [**split the validation set out of the original training set.**]\n", "The argument `valid_ratio` in this function is the ratio of the number of examples in the validation set to the number of examples in the original training set.\n", "More concretely,\n", "let $n$ be the number of images of the class with the fewest examples, and $r$ be the ratio.\n", "The validation set will split out\n", "$\\max(\\lfloor nr\\rfloor,1)$ images for each class.\n", "Let's use `valid_ratio=0.1` as an example. 
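\n", "Since each of the 10 classes in the original training set has 5000 images, we have $\\max(\\lfloor 5000 \\times 0.1 \\rfloor, 1) = 500$ validation images per class, i.e., 5000 in total.\n", "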
Thus, of the 50000 images in the original training set,\n", "45000 will be used for training in the path `train_valid_test/train`,\n", "while the other 5000 will be split out\n", "as the validation set in the path `train_valid_test/valid`. After organizing the dataset, images of the same class will be placed under the same folder.\n" ] }, { "cell_type": "code", "execution_count": 4, "id": "0ae3357e", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T19:29:29.411145Z", "iopub.status.busy": "2023-08-18T19:29:29.410869Z", "iopub.status.idle": "2023-08-18T19:29:29.418258Z", "shell.execute_reply": "2023-08-18T19:29:29.417439Z" }, "origin_pos": 8, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "#@save\n", "def copyfile(filename, target_dir):\n", "    \"\"\"Copy a file into a target directory.\"\"\"\n", "    os.makedirs(target_dir, exist_ok=True)\n", "    shutil.copy(filename, target_dir)\n", "\n", "#@save\n", "def reorg_train_valid(data_dir, labels, valid_ratio):\n", "    \"\"\"Split the validation set out of the original training set.\"\"\"\n", "    # The number of examples of the class that has the fewest examples in the\n", "    # training dataset\n", "    n = collections.Counter(labels.values()).most_common()[-1][1]\n", "    # The number of examples per class for the validation set\n", "    n_valid_per_label = max(1, math.floor(n * valid_ratio))\n", "    label_count = {}\n", "    for train_file in os.listdir(os.path.join(data_dir, 'train')):\n", "        label = labels[train_file.split('.')[0]]\n", "        fname = os.path.join(data_dir, 'train', train_file)\n", "        # Every image is copied into `train_valid` (the union of the training\n", "        # and validation sets)\n", "        copyfile(fname, os.path.join(data_dir, 'train_valid_test',\n", "                                     'train_valid', label))\n", "        # The first `n_valid_per_label` images seen for each class are copied\n", "        # into `valid`; the rest go into `train`\n", "        if label not in label_count or label_count[label] < n_valid_per_label:\n", "            copyfile(fname, os.path.join(data_dir, 'train_valid_test',\n", "                                         'valid', label))\n", "            label_count[label] = label_count.get(label, 0) + 1\n", "        else:\n", "            copyfile(fname, os.path.join(data_dir, 'train_valid_test',\n", "                                         'train', label))\n", "    return n_valid_per_label" ] }, { "cell_type": "markdown", "id": "6d2cfa19", "metadata": { "origin_pos": 9 }, "source": [ "The `reorg_test` function below [**organizes the testing set for data loading during prediction.**]\n" ] }, { "cell_type": "code", "execution_count": 5, "id": "890972a8", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T19:29:29.422565Z", "iopub.status.busy": "2023-08-18T19:29:29.422289Z", "iopub.status.idle": "2023-08-18T19:29:29.426856Z", "shell.execute_reply": "2023-08-18T19:29:29.426083Z" }, "origin_pos": 10, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "#@save\n", "def reorg_test(data_dir):\n", "    \"\"\"Organize the testing set for data loading during prediction.\"\"\"\n", "    # All test images go into a single placeholder class `unknown`, since\n", "    # their labels are what we want to predict\n", "    for test_file in os.listdir(os.path.join(data_dir, 'test')):\n", "        copyfile(os.path.join(data_dir, 'test', test_file),\n", "                 os.path.join(data_dir, 'train_valid_test', 'test',\n", "                              'unknown'))" ] }, { "cell_type": "markdown", "id": "0e790936", "metadata": { "origin_pos": 11 }, "source": [ "Finally, we define a function to [**invoke**]\n", "the `read_csv_labels`, `reorg_train_valid`, and `reorg_test` (**functions defined above.**)\n" ] }, { "cell_type": "code", "execution_count": 6, "id": "00f50b41", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T19:29:29.431514Z", "iopub.status.busy": "2023-08-18T19:29:29.430810Z", "iopub.status.idle": "2023-08-18T19:29:29.434961Z", "shell.execute_reply": "2023-08-18T19:29:29.434181Z" }, "origin_pos": 12, "tab": [ "pytorch" ] }, "outputs": [], "source": [ 
"def reorg_cifar10_data(data_dir, valid_ratio):\n", "    labels = read_csv_labels(os.path.join(data_dir, 'trainLabels.csv'))\n", "    reorg_train_valid(data_dir, labels, valid_ratio)\n", "    reorg_test(data_dir)" ] }, { "cell_type": "markdown", "id": "d0454395", "metadata": { "origin_pos": 13 }, "source": [ "Here we only set the batch size to 32 for the small-scale sample of the dataset.\n", "When training and testing\n", "the complete dataset of the Kaggle competition,\n", "`batch_size` should be set to a larger integer, such as 128.\n", "We split out 10% of the training examples as the validation set for tuning hyperparameters.\n" ] }, { "cell_type": "code", "execution_count": 7, "id": "1daf58c4", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T19:29:29.439643Z", "iopub.status.busy": "2023-08-18T19:29:29.438882Z", "iopub.status.idle": "2023-08-18T19:29:29.700309Z", "shell.execute_reply": "2023-08-18T19:29:29.699321Z" }, "origin_pos": 14, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "batch_size = 32 if demo else 128\n", "valid_ratio = 0.1\n", "reorg_cifar10_data(data_dir, valid_ratio)" ] }, { "cell_type": "markdown", "id": "240c5c50", "metadata": { "origin_pos": 15 }, "source": [ "## [**Image Augmentation**]\n", "\n", "We use image augmentation to address overfitting.\n", "For example, images can be flipped horizontally at random during training.\n", "We can also perform standardization for the three RGB channels of color images.\n", "Listed below are some of these operations that you can tweak.\n" ] }, { "cell_type": "code", "execution_count": 8, "id": "70e97f85", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T19:29:29.705470Z", "iopub.status.busy": "2023-08-18T19:29:29.704953Z", "iopub.status.idle": "2023-08-18T19:29:29.710326Z", "shell.execute_reply": "2023-08-18T19:29:29.709544Z" }, "origin_pos": 17, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "transform_train = torchvision.transforms.Compose([\n", "    # Scale the image up to a square of 40 pixels in both height and width\n", "    torchvision.transforms.Resize(40),\n", "    # Randomly crop a square image of 40 pixels in both height and width to\n", "    # produce a small square of 0.64 to 1 times the area of the original\n", "    # image, and then scale it to a square of 32 pixels in both height and\n", "    # width\n", "    torchvision.transforms.RandomResizedCrop(32, scale=(0.64, 1.0),\n", "                                             ratio=(1.0, 1.0)),\n", "    torchvision.transforms.RandomHorizontalFlip(),\n", "    torchvision.transforms.ToTensor(),\n", "    # Standardize each channel of the image with the commonly used per-channel\n", "    # means and standard deviations of the CIFAR-10 training set\n", "    torchvision.transforms.Normalize([0.4914, 0.4822, 0.4465],\n", "                                     [0.2023, 0.1994, 0.2010])])" ] }, { "cell_type": "markdown", "id": "d7593105", "metadata": { "origin_pos": 18 }, "source": [ "During testing,\n", "we only perform standardization on images\n", "so as to\n", "remove randomness from the evaluation results.\n" ] }, { "cell_type": "code", "execution_count": 9, "id": "be0d5428", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T19:29:29.714890Z", "iopub.status.busy": "2023-08-18T19:29:29.714292Z", "iopub.status.idle": "2023-08-18T19:29:29.718602Z", "shell.execute_reply": "2023-08-18T19:29:29.717807Z" }, "origin_pos": 20, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "transform_test = torchvision.transforms.Compose([\n", "    torchvision.transforms.ToTensor(),\n", "    torchvision.transforms.Normalize([0.4914, 0.4822, 0.4465],\n", "                                     [0.2023, 0.1994, 0.2010])])" ] }, { "cell_type": "markdown", "id": "27f918e0", "metadata": { "origin_pos": 21 }, "source": [ "## 
Reading the Dataset\n", "\n", "Next, we [**read the organized dataset consisting of raw image files**]. Each example includes an image and a label.\n" ] }, { "cell_type": "code", "execution_count": 10, "id": "056ac33a", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T19:29:29.722917Z", "iopub.status.busy": "2023-08-18T19:29:29.722506Z", "iopub.status.idle": "2023-08-18T19:29:29.733889Z", "shell.execute_reply": "2023-08-18T19:29:29.733119Z" }, "origin_pos": 23, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "train_ds, train_valid_ds = [torchvision.datasets.ImageFolder(\n", " os.path.join(data_dir, 'train_valid_test', folder),\n", " transform=transform_train) for folder in ['train', 'train_valid']]\n", "\n", "valid_ds, test_ds = [torchvision.datasets.ImageFolder(\n", " os.path.join(data_dir, 'train_valid_test', folder),\n", " transform=transform_test) for folder in ['valid', 'test']]" ] }, { "cell_type": "markdown", "id": "16747ae7", "metadata": { "origin_pos": 24 }, "source": [ "During training,\n", "we need to [**specify all the image augmentation operations defined above**].\n", "When the validation set\n", "is used for model evaluation during hyperparameter tuning,\n", "no randomness from image augmentation should be introduced.\n", "Before final prediction,\n", "we train the model on the combined training set and validation set to make full use of all the labeled data.\n" ] }, { "cell_type": "code", "execution_count": 11, "id": "06fa7207", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T19:29:29.738557Z", "iopub.status.busy": "2023-08-18T19:29:29.737952Z", "iopub.status.idle": "2023-08-18T19:29:29.743073Z", "shell.execute_reply": "2023-08-18T19:29:29.742323Z" }, "origin_pos": 26, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "train_iter, train_valid_iter = [torch.utils.data.DataLoader(\n", " dataset, batch_size, shuffle=True, drop_last=True)\n", " for dataset in (train_ds, train_valid_ds)]\n", "\n", "valid_iter = torch.utils.data.DataLoader(valid_ds, batch_size, shuffle=False,\n", " drop_last=True)\n", "\n", "test_iter = torch.utils.data.DataLoader(test_ds, batch_size, shuffle=False,\n", " drop_last=False)" ] }, { "cell_type": "markdown", "id": "9e84ffa3", "metadata": { "origin_pos": 27 }, "source": [ "## Defining the [**Model**]\n" ] }, { "cell_type": "markdown", "id": "aceea7ad", "metadata": { "origin_pos": 33, "tab": [ "pytorch" ] }, "source": [ "We define the ResNet-18 model described in\n", ":numref:`sec_resnet`.\n" ] }, { "cell_type": "code", "execution_count": 12, "id": "d527425d", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T19:29:29.747678Z", "iopub.status.busy": "2023-08-18T19:29:29.747059Z", "iopub.status.idle": "2023-08-18T19:29:29.751129Z", "shell.execute_reply": "2023-08-18T19:29:29.750380Z" }, "origin_pos": 35, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "def get_net():\n", " num_classes = 10\n", " net = d2l.resnet18(num_classes, 3)\n", " return net\n", "\n", "loss = nn.CrossEntropyLoss(reduction=\"none\")" ] }, { "cell_type": "markdown", "id": "b0d9de60", "metadata": { "origin_pos": 36 }, "source": [ "## Defining the [**Training Function**]\n", "\n", "We will select models and tune hyperparameters according to the model's performance on the validation set.\n", "In the following, we define the model training function `train`.\n" ] }, { "cell_type": "code", "execution_count": 13, "id": "bde40789", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T19:29:29.755665Z", "iopub.status.busy": 
"2023-08-18T19:29:29.755131Z", "iopub.status.idle": "2023-08-18T19:29:29.764392Z", "shell.execute_reply": "2023-08-18T19:29:29.763621Z" }, "origin_pos": 38, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "def train(net, train_iter, valid_iter, num_epochs, lr, wd, devices, lr_period,\n", " lr_decay):\n", " trainer = torch.optim.SGD(net.parameters(), lr=lr, momentum=0.9,\n", " weight_decay=wd)\n", " scheduler = torch.optim.lr_scheduler.StepLR(trainer, lr_period, lr_decay)\n", " num_batches, timer = len(train_iter), d2l.Timer()\n", " legend = ['train loss', 'train acc']\n", " if valid_iter is not None:\n", " legend.append('valid acc')\n", " animator = d2l.Animator(xlabel='epoch', xlim=[1, num_epochs],\n", " legend=legend)\n", " net = nn.DataParallel(net, device_ids=devices).to(devices[0])\n", " for epoch in range(num_epochs):\n", " net.train()\n", " metric = d2l.Accumulator(3)\n", " for i, (features, labels) in enumerate(train_iter):\n", " timer.start()\n", " l, acc = d2l.train_batch_ch13(net, features, labels,\n", " loss, trainer, devices)\n", " metric.add(l, acc, labels.shape[0])\n", " timer.stop()\n", " if (i + 1) % (num_batches // 5) == 0 or i == num_batches - 1:\n", " animator.add(epoch + (i + 1) / num_batches,\n", " (metric[0] / metric[2], metric[1] / metric[2],\n", " None))\n", " if valid_iter is not None:\n", " valid_acc = d2l.evaluate_accuracy_gpu(net, valid_iter)\n", " animator.add(epoch + 1, (None, None, valid_acc))\n", " scheduler.step()\n", " measures = (f'train loss {metric[0] / metric[2]:.3f}, '\n", " f'train acc {metric[1] / metric[2]:.3f}')\n", " if valid_iter is not None:\n", " measures += f', valid acc {valid_acc:.3f}'\n", " print(measures + f'\\n{metric[2] * num_epochs / timer.sum():.1f}'\n", " f' examples/sec on {str(devices)}')" ] }, { "cell_type": "markdown", "id": "f285eced", "metadata": { "origin_pos": 39 }, "source": [ "## [**Training and Validating the Model**]\n", "\n", "Now, we can train and validate the model.\n", "All the following hyperparameters can be tuned.\n", "For example, we can increase the number of epochs.\n", "When `lr_period` and `lr_decay` are set to 4 and 0.9, respectively, the learning rate of the optimization algorithm will be multiplied by 0.9 after every 4 epochs. 
Just for ease of demonstration,\n", "we only train 20 epochs here.\n" ] }, { "cell_type": "code", "execution_count": 14, "id": "cd4a55c7", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T19:29:29.768734Z", "iopub.status.busy": "2023-08-18T19:29:29.768227Z", "iopub.status.idle": "2023-08-18T19:30:37.496878Z", "shell.execute_reply": "2023-08-18T19:30:37.495860Z" }, "origin_pos": 41, "tab": [ "pytorch" ] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "train loss 0.654, train acc 0.789, valid acc 0.438\n", "958.1 examples/sec on [device(type='cuda', index=0), device(type='cuda', index=1)]\n" ] }, { "data": { "text/plain": [ "<Figure: train loss, train acc, and valid acc vs. epoch>" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ 
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "devices, num_epochs, lr, wd = d2l.try_all_gpus(), 20, 2e-4, 5e-4\n", "lr_period, lr_decay, net = 4, 0.9, get_net()\n", "net(next(iter(train_iter))[0])\n", "train(net, train_iter, valid_iter, num_epochs, lr, wd, devices, lr_period,\n", " lr_decay)" ] }, { "cell_type": "markdown", "id": "bd51eac4", "metadata": { "origin_pos": 42 }, "source": [ "## [**Classifying the Testing Set**] and Submitting Results on Kaggle\n", "\n", "After obtaining a promising model with hyperparameters,\n", "we use all the labeled data (including the validation set) to retrain the model and classify the testing set.\n" ] }, { "cell_type": "code", "execution_count": 15, "id": "a66ef205", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T19:30:37.501313Z", "iopub.status.busy": "2023-08-18T19:30:37.500748Z", "iopub.status.idle": "2023-08-18T19:31:40.934103Z", "shell.execute_reply": "2023-08-18T19:31:40.932837Z" }, "origin_pos": 44, "tab": [ "pytorch" ] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "train loss 0.608, train acc 0.786\n", "1040.8 examples/sec on [device(type='cuda', index=0), device(type='cuda', index=1)]\n" ] }, { "data": { "image/svg+xml": [ "\n", "\n", "\n", " \n", " \n", " \n", " \n", " 2023-08-18T19:31:40.877905\n", " image/svg+xml\n", " \n", " \n", " Matplotlib v3.7.2, https://matplotlib.org/\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n" ], "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "net, preds = get_net(), []\n", "net(next(iter(train_valid_iter))[0])\n", "train(net, train_valid_iter, None, num_epochs, lr, wd, devices, lr_period,\n", " lr_decay)\n", "\n", "for X, _ in test_iter:\n", " y_hat = net(X.to(devices[0]))\n", " preds.extend(y_hat.argmax(dim=1).type(torch.int32).cpu().numpy())\n", "sorted_ids = list(range(1, len(test_ds) + 1))\n", "sorted_ids.sort(key=lambda x: str(x))\n", "df = pd.DataFrame({'id': sorted_ids, 'label': preds})\n", "df['label'] = df['label'].apply(lambda x: train_valid_ds.classes[x])\n", "df.to_csv('submission.csv', index=False)" ] }, { "cell_type": "markdown", "id": "0a33f6c7", "metadata": { "origin_pos": 45 }, "source": [ "The above code\n", "will generate a `submission.csv` file,\n", "whose format\n", "meets the requirement of the Kaggle competition.\n", "The method\n", "for submitting results to Kaggle\n", "is similar to that in :numref:`sec_kaggle_house`.\n", "\n", "## Summary\n", "\n", "* We can read datasets containing raw image files after organizing them into the required format.\n" ] }, { "cell_type": "markdown", "id": "fe9ddf70", "metadata": { "origin_pos": 47, "tab": [ "pytorch" ] }, "source": [ "* We can use convolutional neural networks and image augmentation in an image classification competition.\n" ] }, { "cell_type": "markdown", "id": "ed9f69e2", "metadata": { "origin_pos": 48 }, "source": [ "## Exercises\n", "\n", "1. Use the complete CIFAR-10 dataset for this Kaggle competition. Set hyperparameters as `batch_size = 128`, `num_epochs = 100`, `lr = 0.1`, `lr_period = 50`, and `lr_decay = 0.1`. See what accuracy and ranking you can achieve in this competition. Can you further improve them?\n", "1. What accuracy can you get when not using image augmentation?\n" ] }, { "cell_type": "markdown", "id": "74fe0ee3", "metadata": { "origin_pos": 50, "tab": [ "pytorch" ] }, "source": [ "[Discussions](https://discuss.d2l.ai/t/1479)\n" ] } ], "metadata": { "language_info": { "name": "python" }, "required_libs": [] }, "nbformat": 4, "nbformat_minor": 5 }