{ "cells": [ { "cell_type": "markdown", "id": "3c358c3f", "metadata": { "origin_pos": 0 }, "source": [ "# Fully Convolutional Networks\n", ":label:`sec_fcn`\n", "\n", "As discussed in :numref:`sec_semantic_segmentation`,\n", "semantic segmentation\n", "classifies images in pixel level.\n", "A fully convolutional network (FCN)\n", "uses a convolutional neural network to\n", "transform image pixels to pixel classes :cite:`Long.Shelhamer.Darrell.2015`.\n", "Unlike the CNNs that we encountered earlier\n", "for image classification \n", "or object detection,\n", "a fully convolutional network\n", "transforms \n", "the height and width of intermediate feature maps\n", "back to those of the input image:\n", "this is achieved by\n", "the transposed convolutional layer\n", "introduced in :numref:`sec_transposed_conv`.\n", "As a result,\n", "the classification output\n", "and the input image \n", "have a one-to-one correspondence \n", "in pixel level:\n", "the channel dimension at any output pixel \n", "holds the classification results\n", "for the input pixel at the same spatial position.\n" ] }, { "cell_type": "code", "execution_count": 1, "id": "805a6df5", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T19:34:22.296389Z", "iopub.status.busy": "2023-08-18T19:34:22.296055Z", "iopub.status.idle": "2023-08-18T19:34:25.836644Z", "shell.execute_reply": "2023-08-18T19:34:25.835421Z" }, "origin_pos": 2, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "%matplotlib inline\n", "import torch\n", "import torchvision\n", "from torch import nn\n", "from torch.nn import functional as F\n", "from d2l import torch as d2l" ] }, { "cell_type": "markdown", "id": "05a33edc", "metadata": { "origin_pos": 3 }, "source": [ "## The Model\n", "\n", "Here we describe the basic design of the fully convolutional network model. \n", "As shown in :numref:`fig_fcn`,\n", "this model first uses a CNN to extract image features,\n", "then transforms the number of channels into\n", "the number of classes\n", "via a $1\\times 1$ convolutional layer,\n", "and finally transforms the height and width of\n", "the feature maps\n", "to those\n", "of the input image via\n", "the transposed convolution introduced in :numref:`sec_transposed_conv`. 
\n", "As a result,\n", "the model output has the same height and width as the input image,\n", "where the output channel contains the predicted classes\n", "for the input pixel at the same spatial position.\n", "\n", "\n", "![Fully convolutional network.](../img/fcn.svg)\n", ":label:`fig_fcn`\n", "\n", "Below, we [**use a ResNet-18 model pretrained on the ImageNet dataset to extract image features**]\n", "and denote the model instance as `pretrained_net`.\n", "The last few layers of this model\n", "include a global average pooling layer\n", "and a fully connected layer:\n", "they are not needed\n", "in the fully convolutional network.\n" ] }, { "cell_type": "code", "execution_count": 2, "id": "a2a9a1c6", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T19:34:25.842583Z", "iopub.status.busy": "2023-08-18T19:34:25.841424Z", "iopub.status.idle": "2023-08-18T19:34:27.191787Z", "shell.execute_reply": "2023-08-18T19:34:27.190380Z" }, "origin_pos": 5, "tab": [ "pytorch" ] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Downloading: \"https://download.pytorch.org/models/resnet18-f37072fd.pth\" to /home/ci/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", " 0%| | 0.00/44.7M [00:00\n", "\n", "\n", " \n", " \n", " \n", " \n", " 2023-08-18T19:34:28.085245\n", " image/svg+xml\n", " \n", " \n", " Matplotlib v3.7.2, https://matplotlib.org/\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n" ], "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "d2l.set_figsize()\n", "print('input image shape:', img.permute(1, 2, 0).shape)\n", "d2l.plt.imshow(img.permute(1, 2, 0));\n", "print('output image shape:', out_img.shape)\n", "d2l.plt.imshow(out_img);" ] }, { "cell_type": "markdown", "id": "e85f4c76", "metadata": { "origin_pos": 27 }, "source": [ "In a fully convolutional network, we [**initialize the transposed convolutional layer with upsampling of bilinear interpolation. For the $1\\times 1$ convolutional layer, we use Xavier initialization.**]\n" ] }, { "cell_type": "code", "execution_count": 10, "id": "1ae40200", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T19:34:28.372559Z", "iopub.status.busy": "2023-08-18T19:34:28.371880Z", "iopub.status.idle": "2023-08-18T19:34:28.381788Z", "shell.execute_reply": "2023-08-18T19:34:28.380642Z" }, "origin_pos": 29, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "W = bilinear_kernel(num_classes, num_classes, 64)\n", "net.transpose_conv.weight.data.copy_(W);" ] }, { "cell_type": "markdown", "id": "fbc76cde", "metadata": { "origin_pos": 30 }, "source": [ "## [**Reading the Dataset**]\n", "\n", "We read\n", "the semantic segmentation dataset\n", "as introduced in :numref:`sec_semantic_segmentation`. \n", "The output image shape of random cropping is\n", "specified as $320\\times 480$: both the height and width are divisible by $32$.\n" ] }, { "cell_type": "code", "execution_count": 11, "id": "0bdc2a20", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T19:34:28.386035Z", "iopub.status.busy": "2023-08-18T19:34:28.385440Z", "iopub.status.idle": "2023-08-18T19:35:21.373422Z", "shell.execute_reply": "2023-08-18T19:35:21.369676Z" }, "origin_pos": 31, "tab": [ "pytorch" ] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "read 1114 examples\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "read 1078 examples\n" ] } ], "source": [ "batch_size, crop_size = 32, (320, 480)\n", "train_iter, test_iter = d2l.load_data_voc(batch_size, crop_size)" ] }, { "cell_type": "markdown", "id": "6654107c", "metadata": { "origin_pos": 32 }, "source": [ "## [**Training**]\n", "\n", "\n", "Now we can train our constructed\n", "fully convolutional network. \n", "The loss function and accuracy calculation here\n", "are not essentially different from those in image classification of earlier chapters. 
\n", "Because we use the output channel of the\n", "transposed convolutional layer to\n", "predict the class for each pixel,\n", "the channel dimension is specified in the loss calculation.\n", "In addition, the accuracy is calculated\n", "based on correctness\n", "of the predicted class for all the pixels.\n" ] }, { "cell_type": "code", "execution_count": 12, "id": "b65f6226", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T19:35:21.378517Z", "iopub.status.busy": "2023-08-18T19:35:21.377599Z", "iopub.status.idle": "2023-08-18T19:36:21.659017Z", "shell.execute_reply": "2023-08-18T19:36:21.657836Z" }, "origin_pos": 34, "tab": [ "pytorch" ] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "loss 0.449, train acc 0.861, test acc 0.852\n", "226.7 examples/sec on [device(type='cuda', index=0), device(type='cuda', index=1)]\n" ] }, { "data": { "image/svg+xml": [ "\n", "\n", "\n", " \n", " \n", " \n", " \n", " 2023-08-18T19:36:21.596439\n", " image/svg+xml\n", " \n", " \n", " Matplotlib v3.7.2, https://matplotlib.org/\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n" ], "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "def loss(inputs, targets):\n", " return F.cross_entropy(inputs, targets, reduction='none').mean(1).mean(1)\n", "\n", "num_epochs, lr, wd, devices = 5, 0.001, 1e-3, d2l.try_all_gpus()\n", "trainer = torch.optim.SGD(net.parameters(), lr=lr, weight_decay=wd)\n", "d2l.train_ch13(net, train_iter, test_iter, loss, trainer, num_epochs, devices)" ] }, { "cell_type": "markdown", "id": "9c0c7f12", "metadata": { "origin_pos": 35 }, "source": [ "## [**Prediction**]\n", "\n", "\n", "When predicting, we need to standardize the input image\n", "in each channel and transform the image into the four-dimensional input format required by the CNN.\n" ] }, { "cell_type": "code", "execution_count": 13, "id": "e7f1ceba", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T19:36:21.663792Z", "iopub.status.busy": "2023-08-18T19:36:21.662705Z", "iopub.status.idle": "2023-08-18T19:36:21.669589Z", "shell.execute_reply": "2023-08-18T19:36:21.668481Z" }, "origin_pos": 37, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "def predict(img):\n", " X = test_iter.dataset.normalize_image(img).unsqueeze(0)\n", " pred = net(X.to(devices[0])).argmax(dim=1)\n", " return pred.reshape(pred.shape[1], pred.shape[2])" ] }, { "cell_type": "markdown", "id": "27b9364c", "metadata": { "origin_pos": 38 }, "source": [ "To [**visualize the predicted class**] of each pixel, we map the predicted class back to its label color in the dataset.\n" ] }, { "cell_type": "code", "execution_count": 14, "id": "88b09f25", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T19:36:21.673578Z", "iopub.status.busy": "2023-08-18T19:36:21.672677Z", "iopub.status.idle": "2023-08-18T19:36:21.678207Z", "shell.execute_reply": "2023-08-18T19:36:21.677171Z" }, "origin_pos": 40, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "def label2image(pred):\n", " colormap = torch.tensor(d2l.VOC_COLORMAP, device=devices[0])\n", " X = pred.long()\n", " return colormap[X, :]" ] }, { "cell_type": "markdown", "id": "db007d32", "metadata": { "origin_pos": 41 }, "source": [ "Images in the test dataset vary in size and shape.\n", "Since the model uses a transposed convolutional layer with stride of 32,\n", "when the height or width of an input image is indivisible by 32,\n", "the output height or width of the\n", "transposed convolutional layer will deviate from the shape of the input image.\n", "In order to address this issue,\n", "we can crop multiple rectangular areas with height and width that are integer multiples of 32 in the image,\n", "and perform forward propagation\n", "on the pixels in these areas separately.\n", "Note that\n", "the union of these rectangular areas needs to completely cover the input image.\n", "When a pixel is covered by multiple rectangular areas,\n", "the average of the transposed convolution outputs\n", "in separate areas for this same pixel\n", "can be input to\n", "the softmax operation\n", "to predict the class.\n", "\n", "\n", "For simplicity, we only read a few larger test images,\n", "and crop a $320\\times480$ area for prediction starting from the upper-left corner of an image.\n", "For these test images, we\n", "print their cropped areas,\n", "prediction results,\n", "and ground-truth row by row.\n" ] }, { "cell_type": "code", "execution_count": 15, "id": "8f780e21", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T19:36:21.682056Z", "iopub.status.busy": "2023-08-18T19:36:21.681351Z", "iopub.status.idle": 
"2023-08-18T19:36:53.274281Z", "shell.execute_reply": "2023-08-18T19:36:53.273369Z" }, "origin_pos": 43, "tab": [ "pytorch" ] }, "outputs": [ { "data": { "image/svg+xml": [ "\n", "\n", "\n", " \n", " \n", " \n", " \n", " 2023-08-18T19:36:53.099928\n", " image/svg+xml\n", " \n", " \n", " Matplotlib v3.7.2, https://matplotlib.org/\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n" ], "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "voc_dir = d2l.download_extract('voc2012', 'VOCdevkit/VOC2012')\n", "test_images, test_labels = d2l.read_voc_images(voc_dir, False)\n", "n, imgs = 4, []\n", "for i in range(n):\n", " crop_rect = (0, 0, 320, 480)\n", " X = torchvision.transforms.functional.crop(test_images[i], *crop_rect)\n", " pred = label2image(predict(X))\n", " imgs += [X.permute(1,2,0), pred.cpu(),\n", " torchvision.transforms.functional.crop(\n", " test_labels[i], *crop_rect).permute(1,2,0)]\n", "d2l.show_images(imgs[::3] + imgs[1::3] + imgs[2::3], 3, n, scale=2);" ] }, { "cell_type": "markdown", "id": "f4076170", "metadata": { "origin_pos": 44 }, "source": [ "## Summary\n", "\n", "* The fully convolutional network first uses a CNN to extract image features, then transforms the number of channels into the number of classes via a $1\\times 1$ convolutional layer, and finally transforms the height and width of the feature maps to those of the input image via the transposed convolution.\n", "* In a fully convolutional network, we can use upsampling of bilinear interpolation to initialize the transposed convolutional layer.\n", "\n", "\n", "## Exercises\n", "\n", "1. If we use Xavier initialization for the transposed convolutional layer in the experiment, how does the result change?\n", "1. Can you further improve the accuracy of the model by tuning the hyperparameters?\n", "1. Predict the classes of all pixels in test images.\n", "1. The original fully convolutional network paper also uses outputs of some intermediate CNN layers :cite:`Long.Shelhamer.Darrell.2015`. Try to implement this idea.\n" ] }, { "cell_type": "markdown", "id": "a4ba1b83", "metadata": { "origin_pos": 46, "tab": [ "pytorch" ] }, "source": [ "[Discussions](https://discuss.d2l.ai/t/1582)\n" ] } ], "metadata": { "language_info": { "name": "python" }, "required_libs": [] }, "nbformat": 4, "nbformat_minor": 5 }