{ "cells": [ { "cell_type": "markdown", "id": "e57838cd", "metadata": { "origin_pos": 1 }, "source": [ "# The Image Classification Dataset\n", ":label:`sec_fashion_mnist`\n", "\n", "(~~The MNIST dataset is one of the widely used dataset for image classification, while it is too simple as a benchmark dataset. We will use the similar, but more complex Fashion-MNIST dataset ~~)\n", "\n", "One widely used dataset for image classification is the [MNIST dataset](https://en.wikipedia.org/wiki/MNIST_database) :cite:`LeCun.Bottou.Bengio.ea.1998` of handwritten digits. At the time of its release in the 1990s it posed a formidable challenge to most machine learning algorithms, consisting of 60,000 images of $28 \\times 28$ pixels resolution (plus a test dataset of 10,000 images). To put things into perspective, back in 1995, a Sun SPARCStation 5 with a whopping 64MB of RAM and a blistering 5 MFLOPs was considered state of the art equipment for machine learning at AT&T Bell Laboratories. Achieving high accuracy on digit recognition was a key component in automating letter sorting for the USPS in the 1990s. Deep networks such as LeNet-5 :cite:`LeCun.Jackel.Bottou.ea.1995`, support vector machines with invariances :cite:`Scholkopf.Burges.Vapnik.1996`, and tangent distance classifiers :cite:`Simard.LeCun.Denker.ea.1998` all could reach error rates below 1%. \n", "\n", "For over a decade, MNIST served as *the* point of reference for comparing machine learning algorithms. \n", "While it had a good run as a benchmark dataset,\n", "even simple models by today's standards achieve classification accuracy over 95%,\n", "making it unsuitable for distinguishing between strong models and weaker ones. Even more, the dataset allows for *very* high levels of accuracy, not typically seen in many classification problems. This skewed algorithmic development towards specific families of algorithms that can take advantage of clean datasets, such as active set methods and boundary-seeking active set algorithms.\n", "Today, MNIST serves as more of a sanity check than as a benchmark. ImageNet :cite:`Deng.Dong.Socher.ea.2009` poses a much \n", "more relevant challenge. Unfortunately, ImageNet is too large for many of the examples and illustrations in this book, as it would take too long to train to make the examples interactive. As a substitute we will focus our discussion in the coming sections on the qualitatively similar, but much smaller Fashion-MNIST\n", "dataset :cite:`Xiao.Rasul.Vollgraf.2017` which was released in 2017. It contains images of 10 categories of clothing at $28 \\times 28$ pixels resolution.\n" ] }, { "cell_type": "code", "execution_count": 1, "id": "270279ba", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T19:35:53.123984Z", "iopub.status.busy": "2023-08-18T19:35:53.123639Z", "iopub.status.idle": "2023-08-18T19:35:56.952902Z", "shell.execute_reply": "2023-08-18T19:35:56.951810Z" }, "origin_pos": 3, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "%matplotlib inline\n", "import time\n", "import torch\n", "import torchvision\n", "from torchvision import transforms\n", "from d2l import torch as d2l\n", "\n", "d2l.use_svg_display()" ] }, { "cell_type": "markdown", "id": "1070bad5", "metadata": { "origin_pos": 6 }, "source": [ "## Loading the Dataset\n", "\n", "Since the Fashion-MNIST dataset is so useful, all major frameworks provide preprocessed versions of it. 
"We can [**download and read it into memory using built-in framework utilities.**]\n" ] }, { "cell_type": "code", "execution_count": 2, "id": "b2e83e15", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T19:35:56.959608Z", "iopub.status.busy": "2023-08-18T19:35:56.958972Z", "iopub.status.idle": "2023-08-18T19:35:56.974091Z", "shell.execute_reply": "2023-08-18T19:35:56.969189Z" }, "origin_pos": 8, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "class FashionMNIST(d2l.DataModule): #@save\n", "    \"\"\"The Fashion-MNIST dataset.\"\"\"\n", "    def __init__(self, batch_size=64, resize=(28, 28)):\n", "        super().__init__()\n", "        self.save_hyperparameters()\n", "        trans = transforms.Compose([transforms.Resize(resize),\n", "                                    transforms.ToTensor()])\n", "        self.train = torchvision.datasets.FashionMNIST(\n", "            root=self.root, train=True, transform=trans, download=True)\n", "        self.val = torchvision.datasets.FashionMNIST(\n", "            root=self.root, train=False, transform=trans, download=True)" ] }, { "cell_type": "markdown", "id": "2ac7e84e", "metadata": { "origin_pos": 10 }, "source": [ "Fashion-MNIST consists of images from 10 categories, each represented\n", "by 6000 images in the training dataset and by 1000 in the test dataset.\n", "A *test dataset* is used for evaluating model performance (it must not be used for training).\n", "Consequently, the training set and the test set\n", "contain 60,000 and 10,000 images, respectively.\n" ] }, { "cell_type": "code", "execution_count": 3, "id": "9702ba11", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T19:35:56.979230Z", "iopub.status.busy": "2023-08-18T19:35:56.978838Z", "iopub.status.idle": "2023-08-18T19:35:57.112651Z", "shell.execute_reply": "2023-08-18T19:35:57.111496Z" }, "origin_pos": 11, "tab": [ "pytorch" ] }, "outputs": [ { "data": { "text/plain": [ "(60000, 10000)" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "data = FashionMNIST(resize=(32, 32))\n", "len(data.train), len(data.val)" ] }, { "cell_type": "markdown", "id": "a66cd0bb", "metadata": { "origin_pos": 13 }, "source": [ "The images are grayscale and, in the code above, upscaled to $32 \times 32$ pixels in resolution. This is similar to the original MNIST dataset, which consisted of (binary) black-and-white images. Note, though, that most modern image data has three channels (red, green, blue) and that hyperspectral images can have in excess of 100 channels (the HyMap sensor has 126 channels).\n", "By convention, we store an image as a $c \times h \times w$ tensor, where $c$ is the number of color channels, $h$ is the height, and $w$ is the width.\n" ] }, { "cell_type": "code", "execution_count": 4, "id": "b31548fa", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T19:35:57.116730Z", "iopub.status.busy": "2023-08-18T19:35:57.116328Z", "iopub.status.idle": "2023-08-18T19:35:57.128533Z", "shell.execute_reply": "2023-08-18T19:35:57.127453Z" }, "origin_pos": 14, "tab": [ "pytorch" ] }, "outputs": [ { "data": { "text/plain": [ "torch.Size([1, 32, 32])" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "data.train[0][0].shape" ] }, { "cell_type": "markdown", "id": "b2e625e5", "metadata": { "origin_pos": 15 }, "source": [ "[~~Two utility functions to visualize the dataset~~]\n", "\n", "The categories of Fashion-MNIST have human-understandable names.\n",
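"As a brief aside, the raw `torchvision` dataset objects already store a variant of these names; a quick check, assuming the `classes` attribute that `torchvision.datasets.FashionMNIST` exposes (the spellings differ slightly from the lowercase names we define below):\n", "\n", "```python\n", "# Aside: the category names as stored by torchvision itself;\n", "# `data` is the FashionMNIST instance created above.\n", "print(data.train.classes)\n", "```\n",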
\n", "The following convenience method converts between numeric labels and their names.\n" ] }, { "cell_type": "code", "execution_count": 5, "id": "ca95ebc5", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T19:35:57.137128Z", "iopub.status.busy": "2023-08-18T19:35:57.136465Z", "iopub.status.idle": "2023-08-18T19:35:57.142322Z", "shell.execute_reply": "2023-08-18T19:35:57.141204Z" }, "origin_pos": 16, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "@d2l.add_to_class(FashionMNIST) #@save\n", "def text_labels(self, indices):\n", " \"\"\"Return text labels.\"\"\"\n", " labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat',\n", " 'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']\n", " return [labels[int(i)] for i in indices]" ] }, { "cell_type": "markdown", "id": "4a87f298", "metadata": { "origin_pos": 17 }, "source": [ "## Reading a Minibatch\n", "\n", "To make our life easier when reading from the training and test sets,\n", "we use the built-in data iterator rather than creating one from scratch.\n", "Recall that at each iteration, a data iterator\n", "[**reads a minibatch of data with size `batch_size`.**]\n", "We also randomly shuffle the examples for the training data iterator.\n" ] }, { "cell_type": "code", "execution_count": 6, "id": "8982acc7", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T19:35:57.146600Z", "iopub.status.busy": "2023-08-18T19:35:57.145989Z", "iopub.status.idle": "2023-08-18T19:35:57.153720Z", "shell.execute_reply": "2023-08-18T19:35:57.150894Z" }, "origin_pos": 19, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "@d2l.add_to_class(FashionMNIST) #@save\n", "def get_dataloader(self, train):\n", " data = self.train if train else self.val\n", " return torch.utils.data.DataLoader(data, self.batch_size, shuffle=train,\n", " num_workers=self.num_workers)" ] }, { "cell_type": "markdown", "id": "f6058b32", "metadata": { "origin_pos": 21 }, "source": [ "To see how this works, let's load a minibatch of images by invoking the `train_dataloader` method. It contains 64 images.\n" ] }, { "cell_type": "code", "execution_count": 7, "id": "81f8afca", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T19:35:57.160131Z", "iopub.status.busy": "2023-08-18T19:35:57.159304Z", "iopub.status.idle": "2023-08-18T19:35:57.397652Z", "shell.execute_reply": "2023-08-18T19:35:57.396242Z" }, "origin_pos": 22, "tab": [ "pytorch" ] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "torch.Size([64, 1, 32, 32]) torch.float32 torch.Size([64]) torch.int64\n" ] } ], "source": [ "X, y = next(iter(data.train_dataloader()))\n", "print(X.shape, X.dtype, y.shape, y.dtype)" ] }, { "cell_type": "markdown", "id": "51090d6f", "metadata": { "origin_pos": 23 }, "source": [ "Let's look at the time it takes to read the images. Even though it is a built-in loader, it is not blazingly fast. Nonetheless, this is sufficient since processing images with a deep network takes quite a bit longer. 
"Hence it is good enough that training a network will not be I/O constrained.\n" ] }, { "cell_type": "code", "execution_count": 8, "id": "47e90ba5", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T19:35:57.402928Z", "iopub.status.busy": "2023-08-18T19:35:57.402306Z", "iopub.status.idle": "2023-08-18T19:36:02.097749Z", "shell.execute_reply": "2023-08-18T19:36:02.096784Z" }, "origin_pos": 24, "tab": [ "pytorch" ] }, "outputs": [ { "data": { "text/plain": [ "'4.69 sec'" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "tic = time.time()\n", "for X, y in data.train_dataloader():\n", "    continue\n", "f'{time.time() - tic:.2f} sec'" ] }, { "cell_type": "markdown", "id": "2887996a", "metadata": { "origin_pos": 25 }, "source": [ "## Visualization\n", "\n", "We will often be using the Fashion-MNIST dataset. A convenience function `show_images` can be used to visualize the images and the associated labels. \n", "We skip the implementation details and show only the interface below: for utility functions like this, we only need to know how to invoke `d2l.show_images`,\n", "not how it works internally.\n" ] }, { "cell_type": "code", "execution_count": 9, "id": "06fb4e72", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T19:36:02.101877Z", "iopub.status.busy": "2023-08-18T19:36:02.101254Z", "iopub.status.idle": "2023-08-18T19:36:02.105863Z", "shell.execute_reply": "2023-08-18T19:36:02.105019Z" }, "origin_pos": 26, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "def show_images(imgs, num_rows, num_cols, titles=None, scale=1.5): #@save\n", "    \"\"\"Plot a list of images.\"\"\"\n", "    raise NotImplementedError" ] }, { "cell_type": "markdown", "id": "2e35cb97", "metadata": { "origin_pos": 27 }, "source": [ "Let's put it to good use. In general, it is a good idea to visualize and inspect data that you are training on. \n", "Humans are very good at spotting oddities, and because of that visualization serves as an additional safeguard against mistakes and errors in the design of experiments.\n",
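"Since the actual implementation is omitted above (the stub simply raises `NotImplementedError`), here is a minimal sketch of what such a helper *might* look like, assuming `matplotlib`; `show_images_sketch` is a hypothetical stand-in, not the actual `d2l.show_images`:\n", "\n", "```python\n", "import matplotlib.pyplot as plt\n", "\n", "def show_images_sketch(imgs, num_rows, num_cols, titles=None, scale=1.5):\n", "    \"\"\"Hypothetical stand-in for d2l.show_images: plot a grid of image tensors.\"\"\"\n", "    figsize = (num_cols * scale, num_rows * scale)\n", "    _, axes = plt.subplots(num_rows, num_cols, figsize=figsize)\n", "    for i, (ax, img) in enumerate(zip(axes.flatten(), imgs)):\n", "        ax.imshow(img.detach().numpy(), cmap='gray')\n", "        ax.axis('off')  # hide ticks and frames for a cleaner look\n", "        if titles:\n", "            ax.set_title(titles[i])\n", "    return axes\n", "```\n", "\n", "The library version handles more input types (e.g. PIL images); this sketch only covers CPU tensors.\n", "\n",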
"Here are [**the images and their corresponding labels**] (in text)\n", "for the first few examples in the validation dataset.\n" ] }, { "cell_type": "code", "execution_count": 10, "id": "b16ddca3", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T19:36:02.109988Z", "iopub.status.busy": "2023-08-18T19:36:02.109439Z", "iopub.status.idle": "2023-08-18T19:36:02.819862Z", "shell.execute_reply": "2023-08-18T19:36:02.811142Z" }, "origin_pos": 28, "tab": [ "pytorch" ] }, "outputs": [ { "data": { "image/svg+xml": [ "[SVG figure (Matplotlib v3.7.2): a row of Fashion-MNIST images shown with their text labels; original SVG markup not preserved]" ], "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "@d2l.add_to_class(FashionMNIST) #@save\n", "def visualize(self, batch, nrows=1, ncols=8, labels=[]):\n", " X, y = batch\n", " if not labels:\n", " labels = self.text_labels(y)\n", " d2l.show_images(X.squeeze(1), nrows, ncols, titles=labels)\n", "batch = next(iter(data.val_dataloader()))\n", "data.visualize(batch)" ] }, { "cell_type": "markdown", "id": "437083bc", "metadata": { "origin_pos": 29 }, "source": [ "We are now ready to work with the Fashion-MNIST dataset in the sections that follow.\n", "\n", "## Summary\n", "\n", "We now have a slightly more realistic dataset to use for classification. Fashion-MNIST is an apparel classification dataset consisting of images representing 10 categories. We will use this dataset in subsequent sections and chapters to evaluate various network designs, from a simple linear model to advanced residual networks. As we commonly do with images, we read them as a tensor of shape (batch size, number of channels, height, width). For now, we only have one channel as the images are grayscale (the visualization above uses a false color palette for improved visibility). \n", "\n", "Lastly, data iterators are a key component for efficient performance. For instance, we might use GPUs for efficient image decompression, video transcoding, or other preprocessing. Whenever possible, you should rely on well-implemented data iterators that exploit high-performance computing to avoid slowing down your training loop.\n", "\n", "\n", "## Exercises\n", "\n", "1. Does reducing the `batch_size` (for instance, to 1) affect the reading performance?\n", "1. The data iterator performance is important. Do you think the current implementation is fast enough? Explore various options to improve it. Use a system profiler to find out where the bottlenecks are.\n", "1. Check out the framework's online API documentation. Which other datasets are available?\n" ] }, { "cell_type": "markdown", "id": "7bff1562", "metadata": { "origin_pos": 31, "tab": [ "pytorch" ] }, "source": [ "[Discussions](https://discuss.d2l.ai/t/49)\n" ] } ], "metadata": { "language_info": { "name": "python" }, "required_libs": [] }, "nbformat": 4, "nbformat_minor": 5 }