{ "cells": [ { "cell_type": "markdown", "id": "cb2b0dcb", "metadata": { "origin_pos": 0 }, "source": [ "# Neural Style Transfer\n", "\n", "If you are a photography enthusiast,\n", "you may be familiar with the filter.\n", "It can change the color style of photos\n", "so that landscape photos become sharper\n", "or portrait photos have whitened skins.\n", "However,\n", "one filter usually only changes\n", "one aspect of the photo.\n", "To apply an ideal style\n", "to a photo,\n", "you probably need to\n", "try many different filter combinations.\n", "This process is\n", "as complex as tuning the hyperparameters of a model.\n", "\n", "\n", "\n", "In this section, we will\n", "leverage layerwise representations of a CNN\n", "to automatically apply the style of one image\n", "to another image, i.e., *style transfer* :cite:`Gatys.Ecker.Bethge.2016`.\n", "This task needs two input images:\n", "one is the *content image* and\n", "the other is the *style image*.\n", "We will use neural networks\n", "to modify the content image\n", "to make it close to the style image in style.\n", "For example,\n", "the content image in :numref:`fig_style_transfer` is a landscape photo taken by us\n", "in Mount Rainier National Park in the suburbs of Seattle, while the style image is an oil painting\n", "with the theme of autumn oak trees.\n", "In the output synthesized image,\n", "the oil brush strokes of the style image\n", "are applied, leading to more vivid colors,\n", "while preserving the main shape of the objects\n", "in the content image.\n", "\n", "![Given content and style images, style transfer outputs a synthesized image.](../img/style-transfer.svg)\n", ":label:`fig_style_transfer`\n", "\n", "## Method\n", "\n", ":numref:`fig_style_transfer_model` illustrates\n", "the CNN-based style transfer method with a simplified example.\n", "First, we initialize the synthesized image,\n", "for example, into the content image.\n", "This synthesized image is the only variable that needs to be updated during the style transfer process,\n", "i.e., the model parameters to be updated during training.\n", "Then we choose a pretrained CNN\n", "to extract image features and freeze its\n", "model parameters during training.\n", "This deep CNN uses multiple layers\n", "to extract\n", "hierarchical features for images.\n", "We can choose the output of some of these layers as content features or style features.\n", "Take :numref:`fig_style_transfer_model` as an example.\n", "The pretrained neural network here has 3 convolutional layers,\n", "where the second layer outputs the content features,\n", "and the first and third layers output the style features.\n", "\n", "![CNN-based style transfer process. Solid lines show the direction of forward propagation and dotted lines show backward propagation. 
 ":label:`fig_style_transfer_model`\n", "\n",
 "Next, we calculate the loss function of style transfer through forward propagation (direction of solid arrows), and update the model parameters (the synthesized image for output) through backpropagation (direction of dashed arrows).\n", "The loss function commonly used in style transfer consists of three parts:\n", "(i) *content loss* makes the synthesized image and the content image close in content features;\n", "(ii) *style loss* makes the synthesized image and the style image close in style features;\n", "and (iii) *total variation loss* helps to reduce the noise in the synthesized image.\n", "Finally, when training is over, we output the model parameters of style transfer, namely\n", "the final synthesized image.\n", "\n", "\n", "\n",
 "In the following,\n", "we will explain the technical details of style transfer via a concrete experiment.\n", "\n", "\n",
 "## [**Reading the Content and Style Images**]\n", "\n",
 "First, we read the content and style images.\n", "From their printed coordinate axes,\n", "we can tell that these images have different sizes.\n" ] },
{ "cell_type": "code", "execution_count": 1, "id": "e0868e28", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T19:52:55.964795Z", "iopub.status.busy": "2023-08-18T19:52:55.964503Z", "iopub.status.idle": "2023-08-18T19:52:59.646946Z", "shell.execute_reply": "2023-08-18T19:52:59.645534Z" }, "origin_pos": 2, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "%matplotlib inline\n", "import torch\n", "import torchvision\n", "from torch import nn\n", "from d2l import torch as d2l\n", "\n", "d2l.set_figsize()\n", "content_img = d2l.Image.open('../img/rainier.jpg')\n", "d2l.plt.imshow(content_img);" ] },
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "%matplotlib inline\n", "import torch\n", "import torchvision\n", "from torch import nn\n", "from d2l import torch as d2l\n", "\n", "d2l.set_figsize()\n", "content_img = d2l.Image.open('../img/rainier.jpg')\n", "d2l.plt.imshow(content_img);" ] }, { "cell_type": "code", "execution_count": 2, "id": "283f5e51", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T19:52:59.651637Z", "iopub.status.busy": "2023-08-18T19:52:59.650615Z", "iopub.status.idle": "2023-08-18T19:53:00.264518Z", "shell.execute_reply": "2023-08-18T19:53:00.263173Z" }, "origin_pos": 4, "tab": [ "pytorch" ] }, "outputs": [ { "data": { "image/svg+xml": [ "\n", "\n", "\n", " \n", " \n", " \n", " \n", " 2023-08-18T19:53:00.102067\n", " image/svg+xml\n", " \n", " \n", " Matplotlib v3.7.2, https://matplotlib.org/\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n" ], "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "style_img = d2l.Image.open('../img/autumn-oak.jpg')\n", "d2l.plt.imshow(style_img);" ] }, { "cell_type": "markdown", "id": "13710ade", "metadata": { "origin_pos": 5 }, "source": [ "## [**Preprocessing and Postprocessing**]\n", "\n", "Below, we define two functions for preprocessing and postprocessing images.\n", "The `preprocess` function standardizes\n", "each of the three RGB channels of the input image and transforms the results into the CNN input format.\n", "The `postprocess` function restores the pixel values in the output image to their original values before standardization.\n", "Since the image printing function requires that each pixel has a floating point value from 0 to 1,\n", "we replace any value smaller than 0 or greater than 1 with 0 or 1, respectively.\n" ] }, { "cell_type": "code", "execution_count": 3, "id": "9f1ef9cd", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T19:53:00.269093Z", "iopub.status.busy": "2023-08-18T19:53:00.268290Z", "iopub.status.idle": "2023-08-18T19:53:00.275592Z", "shell.execute_reply": "2023-08-18T19:53:00.274696Z" }, "origin_pos": 7, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "rgb_mean = torch.tensor([0.485, 0.456, 0.406])\n", "rgb_std = torch.tensor([0.229, 0.224, 0.225])\n", "\n", "def preprocess(img, image_shape):\n", " transforms = torchvision.transforms.Compose([\n", " torchvision.transforms.Resize(image_shape),\n", " torchvision.transforms.ToTensor(),\n", " torchvision.transforms.Normalize(mean=rgb_mean, std=rgb_std)])\n", " return transforms(img).unsqueeze(0)\n", "\n", "def postprocess(img):\n", " img = img[0].to(rgb_std.device)\n", " img = torch.clamp(img.permute(1, 2, 0) * rgb_std + rgb_mean, 0, 1)\n", " return torchvision.transforms.ToPILImage()(img.permute(2, 0, 1))" ] }, { "cell_type": "markdown", "id": "9f897786", "metadata": { "origin_pos": 8 }, "source": [ "## [**Extracting Features**]\n", "\n", "We use the VGG-19 model pretrained on the ImageNet dataset to extract image features :cite:`Gatys.Ecker.Bethge.2016`.\n" ] }, { "cell_type": "code", "execution_count": 4, "id": "1e1a4a43", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T19:53:00.278914Z", "iopub.status.busy": "2023-08-18T19:53:00.278636Z", "iopub.status.idle": "2023-08-18T19:53:04.940646Z", "shell.execute_reply": "2023-08-18T19:53:04.939402Z" }, "origin_pos": 10, "tab": [ "pytorch" ] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Downloading: \"https://download.pytorch.org/models/vgg19-dcbb9e9d.pth\" to /home/ci/.cache/torch/hub/checkpoints/vgg19-dcbb9e9d.pth\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", " 0%| | 0.00/548M [00:00\n", "\n", "\n", " \n", " \n", " \n", " \n", " 2023-08-18T19:54:00.827796\n", " image/svg+xml\n", " \n", " \n", " Matplotlib v3.7.2, https://matplotlib.org/\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " 
\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n" ], "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "device, image_shape = d2l.try_gpu(), (300, 450) # PIL Image (h, w)\n", "net = net.to(device)\n", "content_X, contents_Y = get_contents(image_shape, device)\n", "_, styles_Y = get_styles(image_shape, device)\n", "output = train(content_X, contents_Y, styles_Y, device, 0.3, 500, 50)" ] }, { "cell_type": "markdown", "id": "b940c57c", "metadata": { "origin_pos": 45 }, "source": [ "We can see that the synthesized image\n", "retains the scenery and objects of the content image,\n", "and transfers the color of the style image\n", "at the same time.\n", "For example,\n", "the synthesized image has blocks of color like\n", "those in the style image.\n", "Some of these blocks even have the subtle texture of brush strokes.\n", "\n", "\n", "\n", "\n", "## Summary\n", "\n", "* The loss function commonly used in style transfer consists of three parts: (i) content loss makes the synthesized image and the content image close in content features; (ii) style loss makes the synthesized image and style image close in style features; and (iii) total variation loss helps to reduce the noise in the synthesized image.\n", "* We can use a pretrained CNN to extract image features and minimize the loss function to continuously update the synthesized image as model parameters during training.\n", "* We use Gram matrices to represent the style outputs from the style layers.\n", "\n", "\n", "## Exercises\n", "\n", "1. How does the output change when you select different content and style layers?\n", "1. Adjust the weight hyperparameters in the loss function. Does the output retain more content or have less noise?\n", "1. Use different content and style images. Can you create more interesting synthesized images?\n", "1. Can we apply style transfer for text? Hint: you may refer to the survey paper by :citet:`10.1145/3544903.3544906`.\n" ] }, { "cell_type": "markdown", "id": "3ea49c52", "metadata": { "origin_pos": 47, "tab": [ "pytorch" ] }, "source": [ "[Discussions](https://discuss.d2l.ai/t/1476)\n" ] } ], "metadata": { "language_info": { "name": "python" }, "required_libs": [] }, "nbformat": 4, "nbformat_minor": 5 }