{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "title-md",
   "metadata": {},
   "source": [
    "# CIFAR-10 image classifier\n",
    "\n",
    "Trains a small convolutional network on CIFAR-10, saves and reloads its weights, and reports overall and per-class accuracy on the test set. Runs on Apple MPS when it is available, otherwise on the CPU."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ce0cb736-c157-4eda-999a-7fd3883152c5",
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torch.optim as optim\n",
    "import torchvision\n",
    "import torchvision.transforms as transforms"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "config-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Everything a reader might want to tune lives in this cell.\n",
    "SEED = 0\n",
    "BATCH_SIZE = 4\n",
    "NUM_EPOCHS = 2\n",
    "PRINT_EVERY = 2500  # mini-batches between training-loss reports\n",
    "PATH = './cifar_net.pth'\n",
    "\n",
    "# Seed torch's RNG so data shuffling and weight initialisation are repeatable.\n",
    "torch.manual_seed(SEED)\n",
    "\n",
    "use_mps = torch.backends.mps.is_available()\n",
    "device = torch.device('mps' if use_mps else 'cpu')\n",
    "device"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "data-md",
   "metadata": {},
   "source": [
    "## Data\n",
    "\n",
    "CIFAR-10 is downloaded to `./data` on the first run. `ToTensor` yields pixels in [0, 1] and `Normalize(0.5, 0.5)` maps each channel to [-1, 1]."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "608349fb-20a5-42bb-a296-91080f0c9528",
   "metadata": {},
   "outputs": [],
   "source": [
    "transform = transforms.Compose(\n",
    "    [transforms.ToTensor(),\n",
    "     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])\n",
    "\n",
    "trainset = torchvision.datasets.CIFAR10(root='./data', train=True,\n",
    "                                        download=True, transform=transform)\n",
    "trainloader = torch.utils.data.DataLoader(trainset, batch_size=BATCH_SIZE,\n",
    "                                          shuffle=True, num_workers=2)\n",
    "\n",
    "testset = torchvision.datasets.CIFAR10(root='./data', train=False,\n",
    "                                       download=True, transform=transform)\n",
    "testloader = torch.utils.data.DataLoader(testset, batch_size=BATCH_SIZE,\n",
    "                                         shuffle=False, num_workers=2)\n",
    "\n",
    "classes = ('plane', 'car', 'bird', 'cat',\n",
    "           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "10f1736b-d217-4340-a505-153f6ba7287b",
   "metadata": {},
   "outputs": [],
   "source": [
    "def imshow(img):\n",
    "    \"\"\"Show a normalised (C, H, W) image tensor, e.g. a make_grid output.\"\"\"\n",
    "    img = img / 2 + 0.5  # undo Normalize(mean=0.5, std=0.5): [-1, 1] -> [0, 1]\n",
    "    npimg = img.numpy()\n",
    "    plt.imshow(np.transpose(npimg, (1, 2, 0)))  # CHW -> HWC for matplotlib\n",
    "    plt.show()\n",
    "\n",
    "\n",
    "# Grab one random training batch. It is kept under its own names because the\n",
    "# training loop below rebinds `images`/`labels`-style variables.\n",
    "dataiter = iter(trainloader)\n",
    "sample_images, sample_labels = next(dataiter)\n",
    "\n",
    "imshow(torchvision.utils.make_grid(sample_images))\n",
    "print(' '.join(f'{classes[sample_labels[j]]:5s}' for j in range(BATCH_SIZE)))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "model-md",
   "metadata": {},
   "source": [
    "## Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "86066880-556b-4a23-94cd-e8dd8906057c",
   "metadata": {},
   "outputs": [],
   "source": [
    "class Net(nn.Module):\n",
    "    \"\"\"Small CNN mapping 3x32x32 CIFAR-10 images to 10 class logits.\"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.conv1 = nn.Conv2d(3, 6, 5)\n",
    "        self.pool = nn.MaxPool2d(2, 2)\n",
    "        self.conv2 = nn.Conv2d(6, 16, 5)\n",
    "        # spatial size: 32 -conv-> 28 -pool-> 14 -conv-> 10 -pool-> 5\n",
    "        self.fc1 = nn.Linear(16 * 5 * 5, 120)\n",
    "        self.fc2 = nn.Linear(120, 84)\n",
    "        self.fc3 = nn.Linear(84, 10)\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = self.pool(F.relu(self.conv1(x)))\n",
    "        x = self.pool(F.relu(self.conv2(x)))\n",
    "        x = torch.flatten(x, 1)  # flatten all dimensions except batch\n",
    "        x = F.relu(self.fc1(x))\n",
    "        x = F.relu(self.fc2(x))\n",
    "        x = self.fc3(x)  # raw logits; CrossEntropyLoss applies log-softmax itself\n",
    "        return x\n",
    "\n",
    "\n",
    "net = Net().to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "36cfba59-e3b6-44e8-ada4-a3677bbf4e91",
   "metadata": {},
   "outputs": [],
   "source": [
    "net"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "training-md",
   "metadata": {},
   "source": [
    "## Training\n",
    "\n",
    "SGD with momentum on a cross-entropy loss. The mean loss over the last `PRINT_EVERY` mini-batches is printed as training progresses."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "37b22850-2085-4510-b79d-7e2a0a858dc5",
   "metadata": {},
   "outputs": [],
   "source": [
    "criterion = nn.CrossEntropyLoss()\n",
    "optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e35ffb3e-33b7-4355-b625-81015301c3a3",
   "metadata": {},
   "outputs": [],
   "source": [
    "for epoch in range(NUM_EPOCHS):  # loop over the dataset multiple times\n",
    "    running_loss = 0.0\n",
    "    net.train()\n",
    "    for i, (inputs, labels) in enumerate(trainloader):\n",
    "        inputs, labels = inputs.to(device), labels.to(device)\n",
    "\n",
    "        # zero the parameter gradients\n",
    "        optimizer.zero_grad()\n",
    "\n",
    "        # forward + backward + optimize\n",
    "        outputs = net(inputs)\n",
    "        loss = criterion(outputs, labels)\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "        # report the mean loss of the last PRINT_EVERY mini-batches\n",
    "        running_loss += loss.item()\n",
    "        if i % PRINT_EVERY == PRINT_EVERY - 1:\n",
    "            print(f'[{epoch + 1}, {i + 1:5d}] loss: {running_loss / PRINT_EVERY:.3f}')\n",
    "            running_loss = 0.0\n",
    "\n",
    "print('Finished Training')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "save-md",
   "metadata": {},
   "source": [
    "## Save and reload"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cab280c3-264f-4cc5-960e-eb205dc100c5",
   "metadata": {},
   "outputs": [],
   "source": [
    "torch.save(net.state_dict(), PATH)\n",
    "\n",
    "# Rebuild the model from the checkpoint on `device` so later cells can keep\n",
    "# their inputs there too. weights_only=True restricts unpickling to tensors and\n",
    "# plain containers, which is all a state_dict needs.\n",
    "net = Net().to(device)\n",
    "net.load_state_dict(torch.load(PATH, map_location=device, weights_only=True));"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "sample-md",
   "metadata": {},
   "source": [
    "## Sample predictions\n",
    "\n",
    "Predictions of the reloaded model for the training batch displayed above, next to its true labels."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8c98ea48-95b7-419a-9230-f1637aafe5d9",
   "metadata": {},
   "outputs": [],
   "source": [
    "net.eval()\n",
    "with torch.no_grad():\n",
    "    outputs = net(sample_images.to(device))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "76b46a5d-4860-4d55-93d9-10d90cfdfd8a",
   "metadata": {},
   "outputs": [],
   "source": [
    "_, predicted = torch.max(outputs.cpu(), 1)\n",
    "\n",
    "print('GroundTruth: ', ' '.join(f'{classes[sample_labels[j]]:5s}'\n",
    "                                for j in range(BATCH_SIZE)))\n",
    "print('Predicted: ', ' '.join(f'{classes[predicted[j]]:5s}'\n",
    "                              for j in range(BATCH_SIZE)))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5b3347a1-846a-42be-9f7e-c6387b4563db",
   "metadata": {},
   "source": [
    "## Testing\n",
    "\n",
    "Overall accuracy on the held-out test set, followed by a per-class breakdown."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c27af430-c7b1-4af8-8133-620c4cb5032a",
   "metadata": {},
   "outputs": [],
   "source": [
    "correct = 0\n",
    "total = 0\n",
    "# since we're not training, we don't need to calculate the gradients for our outputs\n",
    "net.eval()\n",
    "with torch.no_grad():\n",
    "    for images, labels in testloader:\n",
    "        images, labels = images.to(device), labels.to(device)\n",
    "        # calculate outputs by running images through the network\n",
    "        outputs = net(images)\n",
    "        # the class with the highest score is what we choose as prediction\n",
    "        _, predicted = torch.max(outputs, 1)\n",
    "        total += labels.size(0)\n",
    "        correct += (predicted == labels).sum().item()\n",
    "\n",
    "print(f'Accuracy of the network on the {total} test images: {100 * correct // total} %')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "570b5e04-f4e5-4f93-9ead-51e574f26988",
   "metadata": {},
   "outputs": [],
   "source": [
    "# prepare to count predictions for each class\n",
    "correct_pred = {classname: 0 for classname in classes}\n",
    "total_pred = {classname: 0 for classname in classes}\n",
    "\n",
    "# again no gradients needed; eval() is repeated so this cell stands on its own\n",
    "net.eval()\n",
    "with torch.no_grad():\n",
    "    for images, labels in testloader:\n",
    "        outputs = net(images.to(device))\n",
    "        _, predictions = torch.max(outputs, 1)\n",
    "        # compare as plain Python ints to avoid one device sync per element\n",
    "        for label, prediction in zip(labels.tolist(), predictions.cpu().tolist()):\n",
    "            if label == prediction:\n",
    "                correct_pred[classes[label]] += 1\n",
    "            total_pred[classes[label]] += 1\n",
    "\n",
    "\n",
    "# print accuracy for each class\n",
    "for classname, correct_count in correct_pred.items():\n",
    "    accuracy = 100 * float(correct_count) / total_pred[classname]\n",
    "    print(f'Accuracy for class: {classname:5s} is {accuracy:.1f} %')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}