{ "cells": [ { "cell_type": "code", "execution_count": 32, "id": "15430ec1-d4de-4098-b399-0b74d1928f62", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "NetBNormConv(\n", " (conv1): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", " (bn1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (conv2): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", " (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n", " (dropout): Dropout(p=0.25, inplace=False)\n", " (conv3): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", " (bn3): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (pool2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n", " (relu): ReLU()\n", " (fc1): Linear(in_features=1152, out_features=128, bias=True)\n", " (fc2): Linear(in_features=128, out_features=10, bias=True)\n", ")\n", "Train Epoch: 1 [0/60000 (0%)]\tLoss: 2.269554\n", "Train Epoch: 1 [6400/60000 (11%)]\tLoss: 0.577421\n", "Train Epoch: 1 [12800/60000 (21%)]\tLoss: 0.379949\n", "Train Epoch: 1 [19200/60000 (32%)]\tLoss: 0.544332\n", "Train Epoch: 1 [25600/60000 (43%)]\tLoss: 0.486600\n", "Train Epoch: 1 [32000/60000 (53%)]\tLoss: 0.408367\n", "Train Epoch: 1 [38400/60000 (64%)]\tLoss: 0.385972\n", "Train Epoch: 1 [44800/60000 (75%)]\tLoss: 0.602968\n", "Train Epoch: 1 [51200/60000 (85%)]\tLoss: 0.379398\n", "Train Epoch: 1 [57600/60000 (96%)]\tLoss: 0.430390\n", "[[856. 0. 18. 44. 12. 2. 63. 0. 5. 0.]\n", " [ 1. 962. 1. 26. 7. 0. 1. 0. 2. 0.]\n", " [ 18. 0. 713. 15. 196. 0. 57. 0. 1. 0.]\n", " [ 18. 3. 6. 911. 48. 0. 14. 0. 0. 0.]\n", " [ 1. 1. 21. 31. 910. 1. 35. 0. 0. 0.]\n", " [ 0. 0. 0. 0. 0. 943. 0. 51. 0. 6.]\n", " [191. 0. 68. 32. 153. 0. 547. 0. 9. 0.]\n", " [ 0. 0. 0. 0. 0. 
5. 0. 986. 0. 9.]\n", " [ 1. 1. 5. 8. 12. 9. 2. 6. 955. 1.]\n", " [ 0. 0. 0. 0. 0. 5. 1. 59. 0. 935.]]\n", "\n", "Test set: Average loss: 0.3528, Accuracy: 8718/10000 (87%)\n", "\n", "Train Epoch: 2 [0/60000 (0%)]\tLoss: 0.248701\n", "Train Epoch: 2 [6400/60000 (11%)]\tLoss: 0.382996\n", "Train Epoch: 2 [12800/60000 (21%)]\tLoss: 0.265890\n", "Train Epoch: 2 [19200/60000 (32%)]\tLoss: 0.390715\n", "Train Epoch: 2 [25600/60000 (43%)]\tLoss: 0.316079\n", "Train Epoch: 2 [32000/60000 (53%)]\tLoss: 0.438563\n", "Train Epoch: 2 [38400/60000 (64%)]\tLoss: 0.320268\n", "Train Epoch: 2 [44800/60000 (75%)]\tLoss: 0.469024\n", "Train Epoch: 2 [51200/60000 (85%)]\tLoss: 0.364491\n", "Train Epoch: 2 [57600/60000 (96%)]\tLoss: 0.357442\n", "[[833. 3. 7. 41. 7. 3. 102. 0. 4. 0.]\n", " [ 3. 974. 0. 18. 3. 0. 0. 0. 2. 0.]\n", " [ 18. 1. 629. 14. 210. 0. 125. 0. 3. 0.]\n", " [ 14. 5. 3. 915. 39. 0. 23. 0. 1. 0.]\n", " [ 1. 1. 12. 35. 870. 0. 80. 0. 1. 0.]\n", " [ 0. 0. 0. 0. 0. 901. 0. 83. 0. 16.]\n", " [160. 2. 41. 37. 83. 0. 667. 0. 10. 0.]\n", " [ 0. 0. 0. 0. 0. 1. 0. 986. 0. 13.]\n", " [ 1. 1. 0. 4. 6. 2. 6. 7. 973. 0.]\n", " [ 0. 0. 1. 0. 0. 3. 0. 56. 0. 940.]]\n", "\n", "Test set: Average loss: 0.3520, Accuracy: 8688/10000 (87%)\n", "\n", "Train Epoch: 3 [0/60000 (0%)]\tLoss: 0.352757\n", "Train Epoch: 3 [6400/60000 (11%)]\tLoss: 0.290876\n", "Train Epoch: 3 [12800/60000 (21%)]\tLoss: 0.218306\n", "Train Epoch: 3 [19200/60000 (32%)]\tLoss: 0.233687\n", "Train Epoch: 3 [25600/60000 (43%)]\tLoss: 0.257908\n", "Train Epoch: 3 [32000/60000 (53%)]\tLoss: 0.386713\n", "Train Epoch: 3 [38400/60000 (64%)]\tLoss: 0.186729\n", "Train Epoch: 3 [44800/60000 (75%)]\tLoss: 0.380331\n", "Train Epoch: 3 [51200/60000 (85%)]\tLoss: 0.325970\n", "Train Epoch: 3 [57600/60000 (96%)]\tLoss: 0.189678\n", "[[852. 1. 3. 31. 4. 4. 100. 1. 4. 0.]\n", " [ 2. 971. 0. 22. 3. 0. 0. 0. 2. 0.]\n", " [ 26. 1. 647. 20. 183. 0. 121. 0. 2. 0.]\n", " [ 11. 4. 0. 941. 22. 0. 21. 0. 0. 1.]\n", " [ 2. 1. 7. 41. 871. 
0. 78. 0. 0. 0.]\n", " [ 0. 0. 0. 0. 0. 952. 0. 36. 0. 12.]\n", " [124. 2. 26. 41. 80. 0. 718. 0. 9. 0.]\n", " [ 0. 0. 0. 0. 0. 2. 0. 970. 0. 28.]\n", " [ 1. 1. 0. 5. 3. 3. 2. 3. 981. 1.]\n", " [ 0. 0. 0. 0. 0. 3. 0. 33. 0. 964.]]\n", "\n", "Test set: Average loss: 0.3135, Accuracy: 8867/10000 (89%)\n", "\n", "Train Epoch: 4 [0/60000 (0%)]\tLoss: 0.227521\n", "Train Epoch: 4 [6400/60000 (11%)]\tLoss: 0.265327\n", "Train Epoch: 4 [12800/60000 (21%)]\tLoss: 0.168143\n", "Train Epoch: 4 [19200/60000 (32%)]\tLoss: 0.323850\n", "Train Epoch: 4 [25600/60000 (43%)]\tLoss: 0.337085\n", "Train Epoch: 4 [32000/60000 (53%)]\tLoss: 0.312168\n", "Train Epoch: 4 [38400/60000 (64%)]\tLoss: 0.217497\n", "Train Epoch: 4 [44800/60000 (75%)]\tLoss: 0.357237\n", "Train Epoch: 4 [51200/60000 (85%)]\tLoss: 0.300967\n", "Train Epoch: 4 [57600/60000 (96%)]\tLoss: 0.208505\n", "[[899. 0. 6. 20. 4. 2. 65. 0. 4. 0.]\n", " [ 1. 975. 0. 20. 2. 0. 0. 0. 2. 0.]\n", " [ 27. 1. 787. 15. 96. 1. 72. 0. 1. 0.]\n", " [ 16. 2. 3. 942. 17. 0. 19. 0. 0. 1.]\n", " [ 2. 1. 22. 50. 855. 1. 69. 0. 0. 0.]\n", " [ 0. 0. 0. 0. 0. 964. 0. 23. 0. 13.]\n", " [158. 3. 50. 34. 72. 0. 673. 0. 10. 0.]\n", " [ 0. 0. 0. 0. 0. 2. 0. 982. 0. 16.]\n", " [ 1. 1. 1. 6. 2. 2. 2. 3. 982. 0.]\n", " [ 0. 0. 0. 0. 0. 3. 0. 36. 0. 961.]]\n", "\n", "Test set: Average loss: 0.2750, Accuracy: 9020/10000 (90%)\n", "\n", "Train Epoch: 5 [0/60000 (0%)]\tLoss: 0.147354\n", "Train Epoch: 5 [6400/60000 (11%)]\tLoss: 0.251549\n", "Train Epoch: 5 [12800/60000 (21%)]\tLoss: 0.091474\n", "Train Epoch: 5 [19200/60000 (32%)]\tLoss: 0.234651\n", "Train Epoch: 5 [25600/60000 (43%)]\tLoss: 0.307294\n", "Train Epoch: 5 [32000/60000 (53%)]\tLoss: 0.314786\n", "Train Epoch: 5 [38400/60000 (64%)]\tLoss: 0.208747\n", "Train Epoch: 5 [44800/60000 (75%)]\tLoss: 0.434271\n", "Train Epoch: 5 [51200/60000 (85%)]\tLoss: 0.326049\n", "Train Epoch: 5 [57600/60000 (96%)]\tLoss: 0.215203\n", "[[907. 0. 3. 15. 4. 4. 62. 0. 5. 0.]\n", " [ 3. 971. 0. 21. 3. 0. 0. 0. 
2. 0.]\n", " [ 32. 1. 726. 7. 129. 1. 99. 0. 5. 0.]\n", " [ 21. 0. 3. 910. 36. 0. 28. 0. 0. 2.]\n", " [ 3. 1. 10. 24. 875. 1. 86. 0. 0. 0.]\n", " [ 0. 0. 0. 0. 0. 981. 0. 15. 0. 4.]\n", " [155. 1. 33. 26. 61. 0. 706. 0. 18. 0.]\n", " [ 0. 0. 0. 0. 0. 9. 0. 963. 0. 28.]\n", " [ 2. 0. 1. 4. 1. 3. 0. 5. 984. 0.]\n", " [ 0. 0. 0. 0. 0. 6. 0. 28. 0. 966.]]\n", "\n", "Test set: Average loss: 0.2795, Accuracy: 8989/10000 (90%)\n", "\n", "Train Epoch: 6 [0/60000 (0%)]\tLoss: 0.184679\n", "Train Epoch: 6 [6400/60000 (11%)]\tLoss: 0.244871\n", "Train Epoch: 6 [12800/60000 (21%)]\tLoss: 0.102463\n", "Train Epoch: 6 [19200/60000 (32%)]\tLoss: 0.234760\n", "Train Epoch: 6 [25600/60000 (43%)]\tLoss: 0.401345\n", "Train Epoch: 6 [32000/60000 (53%)]\tLoss: 0.287975\n", "Train Epoch: 6 [38400/60000 (64%)]\tLoss: 0.224324\n", "Train Epoch: 6 [44800/60000 (75%)]\tLoss: 0.329089\n", "Train Epoch: 6 [51200/60000 (85%)]\tLoss: 0.288648\n", "Train Epoch: 6 [57600/60000 (96%)]\tLoss: 0.125468\n", "[[931. 0. 5. 13. 3. 2. 38. 0. 8. 0.]\n", " [ 3. 977. 0. 16. 2. 0. 0. 0. 2. 0.]\n", " [ 22. 1. 800. 10. 109. 1. 53. 0. 4. 0.]\n", " [ 19. 1. 6. 929. 24. 0. 19. 0. 1. 1.]\n", " [ 2. 1. 22. 37. 886. 1. 51. 0. 0. 0.]\n", " [ 0. 0. 0. 0. 0. 977. 0. 14. 0. 9.]\n", " [178. 1. 46. 24. 70. 0. 667. 0. 14. 0.]\n", " [ 0. 0. 0. 0. 0. 8. 0. 967. 0. 25.]\n", " [ 3. 1. 1. 2. 1. 2. 0. 3. 987. 0.]\n", " [ 0. 0. 0. 0. 0. 5. 0. 23. 0. 
972.]]\n", "\n", "Test set: Average loss: 0.2526, Accuracy: 9093/10000 (91%)\n", "\n", "Train Epoch: 7 [0/60000 (0%)]\tLoss: 0.150053\n", "Train Epoch: 7 [6400/60000 (11%)]\tLoss: 0.164031\n", "Train Epoch: 7 [12800/60000 (21%)]\tLoss: 0.106214\n", "Train Epoch: 7 [19200/60000 (32%)]\tLoss: 0.257319\n", "Train Epoch: 7 [25600/60000 (43%)]\tLoss: 0.331515\n", "Train Epoch: 7 [32000/60000 (53%)]\tLoss: 0.274458\n", "Train Epoch: 7 [38400/60000 (64%)]\tLoss: 0.197483\n", "Train Epoch: 7 [44800/60000 (75%)]\tLoss: 0.263878\n", "Train Epoch: 7 [51200/60000 (85%)]\tLoss: 0.206326\n", "Train Epoch: 7 [57600/60000 (96%)]\tLoss: 0.109148\n", "[[925. 0. 5. 15. 4. 1. 45. 0. 5. 0.]\n", " [ 1. 983. 0. 13. 1. 0. 0. 0. 2. 0.]\n", " [ 23. 1. 799. 11. 89. 0. 74. 0. 3. 0.]\n", " [ 26. 3. 6. 920. 22. 0. 22. 0. 0. 1.]\n", " [ 2. 1. 19. 34. 872. 0. 72. 0. 0. 0.]\n", " [ 0. 0. 0. 0. 0. 967. 0. 24. 0. 9.]\n", " [164. 2. 38. 27. 60. 0. 699. 0. 10. 0.]\n", " [ 0. 0. 0. 0. 0. 1. 0. 976. 0. 23.]\n", " [ 2. 1. 1. 3. 1. 2. 1. 2. 987. 0.]\n", " [ 0. 0. 0. 0. 0. 5. 0. 21. 0. 974.]]\n", "\n", "Test set: Average loss: 0.2468, Accuracy: 9102/10000 (91%)\n", "\n", "Train Epoch: 8 [0/60000 (0%)]\tLoss: 0.121899\n", "Train Epoch: 8 [6400/60000 (11%)]\tLoss: 0.193384\n", "Train Epoch: 8 [12800/60000 (21%)]\tLoss: 0.082954\n", "Train Epoch: 8 [19200/60000 (32%)]\tLoss: 0.254026\n", "Train Epoch: 8 [25600/60000 (43%)]\tLoss: 0.221451\n", "Train Epoch: 8 [32000/60000 (53%)]\tLoss: 0.404695\n", "Train Epoch: 8 [38400/60000 (64%)]\tLoss: 0.190577\n", "Train Epoch: 8 [44800/60000 (75%)]\tLoss: 0.283892\n", "Train Epoch: 8 [51200/60000 (85%)]\tLoss: 0.252844\n", "Train Epoch: 8 [57600/60000 (96%)]\tLoss: 0.110405\n", "[[923. 0. 4. 13. 3. 2. 53. 0. 2. 0.]\n", " [ 4. 977. 0. 14. 2. 0. 1. 0. 2. 0.]\n", " [ 20. 1. 752. 8. 141. 0. 77. 0. 1. 0.]\n", " [ 22. 0. 7. 920. 26. 0. 24. 0. 0. 1.]\n", " [ 2. 1. 8. 26. 935. 0. 28. 0. 0. 0.]\n", " [ 0. 0. 0. 0. 0. 973. 0. 19. 0. 8.]\n", " [154. 1. 31. 21. 91. 0. 695. 0. 7. 
0.]\n", " [ 0. 0. 0. 0. 0. 1. 0. 981. 0. 18.]\n", " [ 2. 1. 0. 3. 6. 2. 1. 3. 982. 0.]\n", " [ 0. 0. 0. 0. 0. 5. 0. 29. 0. 966.]]\n", "\n", "Test set: Average loss: 0.2563, Accuracy: 9104/10000 (91%)\n", "\n", "Train Epoch: 9 [0/60000 (0%)]\tLoss: 0.136938\n", "Train Epoch: 9 [6400/60000 (11%)]\tLoss: 0.247591\n", "Train Epoch: 9 [12800/60000 (21%)]\tLoss: 0.090397\n", "Train Epoch: 9 [19200/60000 (32%)]\tLoss: 0.241102\n", "Train Epoch: 9 [25600/60000 (43%)]\tLoss: 0.269329\n", "Train Epoch: 9 [32000/60000 (53%)]\tLoss: 0.310535\n", "Train Epoch: 9 [38400/60000 (64%)]\tLoss: 0.168209\n", "Train Epoch: 9 [44800/60000 (75%)]\tLoss: 0.234917\n", "Train Epoch: 9 [51200/60000 (85%)]\tLoss: 0.158344\n", "Train Epoch: 9 [57600/60000 (96%)]\tLoss: 0.131066\n", "[[926. 0. 5. 12. 2. 1. 48. 0. 6. 0.]\n", " [ 4. 979. 0. 14. 1. 0. 0. 0. 2. 0.]\n", " [ 28. 1. 795. 10. 97. 0. 68. 0. 1. 0.]\n", " [ 21. 0. 6. 952. 8. 0. 12. 0. 0. 1.]\n", " [ 2. 1. 17. 38. 890. 0. 51. 0. 1. 0.]\n", " [ 0. 0. 0. 0. 0. 976. 0. 19. 0. 5.]\n", " [163. 1. 34. 30. 63. 0. 703. 0. 6. 0.]\n", " [ 0. 0. 0. 0. 0. 5. 0. 979. 0. 16.]\n", " [ 3. 1. 1. 2. 1. 2. 0. 3. 987. 0.]\n", " [ 1. 0. 0. 0. 0. 4. 0. 37. 0. 958.]]\n", "\n", "Test set: Average loss: 0.2417, Accuracy: 9145/10000 (91%)\n", "\n", "Train Epoch: 10 [0/60000 (0%)]\tLoss: 0.104825\n", "Train Epoch: 10 [6400/60000 (11%)]\tLoss: 0.194985\n", "Train Epoch: 10 [12800/60000 (21%)]\tLoss: 0.117805\n", "Train Epoch: 10 [19200/60000 (32%)]\tLoss: 0.224127\n", "Train Epoch: 10 [25600/60000 (43%)]\tLoss: 0.160759\n", "Train Epoch: 10 [32000/60000 (53%)]\tLoss: 0.293535\n", "Train Epoch: 10 [38400/60000 (64%)]\tLoss: 0.247059\n", "Train Epoch: 10 [44800/60000 (75%)]\tLoss: 0.283217\n", "Train Epoch: 10 [51200/60000 (85%)]\tLoss: 0.169100\n", "Train Epoch: 10 [57600/60000 (96%)]\tLoss: 0.112807\n", "[[918. 0. 8. 12. 2. 3. 51. 0. 6. 0.]\n", " [ 2. 977. 0. 18. 1. 0. 0. 0. 2. 0.]\n", " [ 24. 1. 819. 10. 84. 1. 57. 0. 4. 0.]\n", " [ 20. 0. 4. 949. 9. 0. 16. 0. 0. 
2.]\n", " [ 2. 1. 14. 38. 916. 0. 27. 0. 1. 1.]\n", " [ 0. 0. 0. 0. 0. 981. 0. 14. 0. 5.]\n", " [168. 0. 38. 28. 93. 1. 662. 0. 10. 0.]\n", " [ 0. 0. 0. 0. 0. 8. 0. 976. 0. 16.]\n", " [ 3. 1. 0. 2. 2. 2. 0. 3. 987. 0.]\n", " [ 1. 0. 0. 0. 0. 4. 0. 29. 0. 966.]]\n", "\n", "Test set: Average loss: 0.2386, Accuracy: 9151/10000 (92%)\n", "\n", "Train Epoch: 11 [0/60000 (0%)]\tLoss: 0.151567\n", "Train Epoch: 11 [6400/60000 (11%)]\tLoss: 0.146887\n", "Train Epoch: 11 [12800/60000 (21%)]\tLoss: 0.062417\n", "Train Epoch: 11 [19200/60000 (32%)]\tLoss: 0.335674\n", "Train Epoch: 11 [25600/60000 (43%)]\tLoss: 0.181816\n", "Train Epoch: 11 [32000/60000 (53%)]\tLoss: 0.277126\n", "Train Epoch: 11 [38400/60000 (64%)]\tLoss: 0.147858\n", "Train Epoch: 11 [44800/60000 (75%)]\tLoss: 0.188083\n", "Train Epoch: 11 [51200/60000 (85%)]\tLoss: 0.163272\n", "Train Epoch: 11 [57600/60000 (96%)]\tLoss: 0.123491\n", "[[948. 0. 6. 11. 3. 2. 27. 0. 3. 0.]\n", " [ 0. 987. 0. 11. 1. 0. 0. 0. 1. 0.]\n", " [ 22. 1. 840. 8. 77. 0. 51. 0. 1. 0.]\n", " [ 22. 0. 6. 942. 15. 0. 13. 0. 0. 2.]\n", " [ 1. 1. 29. 30. 913. 0. 26. 0. 0. 0.]\n", " [ 0. 0. 0. 0. 0. 973. 0. 19. 0. 8.]\n", " [201. 2. 39. 24. 86. 1. 640. 0. 7. 0.]\n", " [ 0. 0. 0. 0. 0. 2. 0. 972. 0. 26.]\n", " [ 3. 1. 0. 3. 2. 2. 0. 3. 985. 1.]\n", " [ 0. 0. 0. 0. 0. 5. 0. 21. 0. 974.]]\n", "\n", "Test set: Average loss: 0.2390, Accuracy: 9174/10000 (92%)\n", "\n", "Train Epoch: 12 [0/60000 (0%)]\tLoss: 0.159876\n", "Train Epoch: 12 [6400/60000 (11%)]\tLoss: 0.125542\n", "Train Epoch: 12 [12800/60000 (21%)]\tLoss: 0.112106\n", "Train Epoch: 12 [19200/60000 (32%)]\tLoss: 0.224706\n", "Train Epoch: 12 [25600/60000 (43%)]\tLoss: 0.265347\n", "Train Epoch: 12 [32000/60000 (53%)]\tLoss: 0.293587\n", "Train Epoch: 12 [38400/60000 (64%)]\tLoss: 0.249993\n", "Train Epoch: 12 [44800/60000 (75%)]\tLoss: 0.196611\n", "Train Epoch: 12 [51200/60000 (85%)]\tLoss: 0.170177\n", "Train Epoch: 12 [57600/60000 (96%)]\tLoss: 0.163824\n", "[[935. 0. 4. 13. 3. 3. 
37. 0. 5. 0.]\n", " [ 0. 986. 0. 11. 1. 0. 0. 0. 2. 0.]\n", " [ 23. 1. 844. 7. 76. 0. 47. 0. 2. 0.]\n", " [ 20. 1. 5. 939. 20. 0. 14. 0. 0. 1.]\n", " [ 2. 2. 22. 27. 926. 0. 21. 0. 0. 0.]\n", " [ 0. 0. 0. 0. 0. 982. 0. 13. 0. 5.]\n", " [173. 1. 45. 34. 91. 0. 646. 0. 10. 0.]\n", " [ 0. 0. 0. 0. 0. 2. 0. 980. 0. 18.]\n", " [ 3. 1. 1. 3. 1. 2. 0. 5. 984. 0.]\n", " [ 0. 0. 0. 0. 0. 5. 0. 26. 0. 969.]]\n", "\n", "Test set: Average loss: 0.2351, Accuracy: 9191/10000 (92%)\n", "\n", "Train Epoch: 13 [0/60000 (0%)]\tLoss: 0.158473\n", "Train Epoch: 13 [6400/60000 (11%)]\tLoss: 0.158532\n", "Train Epoch: 13 [12800/60000 (21%)]\tLoss: 0.038748\n", "Train Epoch: 13 [19200/60000 (32%)]\tLoss: 0.163933\n", "Train Epoch: 13 [25600/60000 (43%)]\tLoss: 0.261832\n", "Train Epoch: 13 [32000/60000 (53%)]\tLoss: 0.280879\n", "Train Epoch: 13 [38400/60000 (64%)]\tLoss: 0.167006\n", "Train Epoch: 13 [44800/60000 (75%)]\tLoss: 0.268756\n", "Train Epoch: 13 [51200/60000 (85%)]\tLoss: 0.225810\n", "Train Epoch: 13 [57600/60000 (96%)]\tLoss: 0.110739\n", "[[922. 0. 5. 9. 3. 2. 50. 0. 9. 0.]\n", " [ 1. 980. 0. 13. 2. 0. 1. 0. 3. 0.]\n", " [ 22. 1. 822. 7. 93. 0. 54. 0. 1. 0.]\n", " [ 18. 1. 6. 923. 28. 0. 23. 0. 0. 1.]\n", " [ 1. 1. 12. 16. 942. 0. 28. 0. 0. 0.]\n", " [ 0. 0. 0. 0. 0. 976. 0. 18. 0. 6.]\n", " [151. 0. 36. 22. 87. 0. 698. 0. 6. 0.]\n", " [ 0. 0. 0. 0. 0. 1. 0. 971. 0. 28.]\n", " [ 4. 1. 0. 3. 3. 2. 0. 2. 985. 0.]\n", " [ 0. 0. 0. 0. 1. 4. 0. 18. 0. 
977.]]\n", "\n", "Test set: Average loss: 0.2319, Accuracy: 9196/10000 (92%)\n", "\n", "Train Epoch: 14 [0/60000 (0%)]\tLoss: 0.115108\n", "Train Epoch: 14 [6400/60000 (11%)]\tLoss: 0.188998\n", "Train Epoch: 14 [12800/60000 (21%)]\tLoss: 0.067578\n", "Train Epoch: 14 [19200/60000 (32%)]\tLoss: 0.158400\n", "Train Epoch: 14 [25600/60000 (43%)]\tLoss: 0.254126\n", "Train Epoch: 14 [32000/60000 (53%)]\tLoss: 0.337398\n", "Train Epoch: 14 [38400/60000 (64%)]\tLoss: 0.123709\n", "Train Epoch: 14 [44800/60000 (75%)]\tLoss: 0.205702\n", "Train Epoch: 14 [51200/60000 (85%)]\tLoss: 0.138448\n", "Train Epoch: 14 [57600/60000 (96%)]\tLoss: 0.153552\n", "[[942. 0. 7. 9. 3. 2. 31. 0. 6. 0.]\n", " [ 1. 976. 0. 17. 1. 0. 2. 0. 3. 0.]\n", " [ 20. 1. 897. 5. 51. 0. 25. 0. 1. 0.]\n", " [ 16. 1. 6. 941. 20. 0. 15. 0. 1. 0.]\n", " [ 2. 1. 44. 22. 907. 0. 24. 0. 0. 0.]\n", " [ 0. 0. 0. 0. 0. 971. 0. 22. 1. 6.]\n", " [192. 0. 67. 28. 77. 0. 626. 0. 10. 0.]\n", " [ 0. 0. 0. 0. 0. 2. 0. 982. 0. 16.]\n", " [ 4. 1. 0. 3. 1. 1. 0. 2. 988. 0.]\n", " [ 1. 0. 0. 0. 0. 4. 0. 31. 0. 964.]]\n", "\n", "Test set: Average loss: 0.2374, Accuracy: 9194/10000 (92%)\n", "\n", "Train Epoch: 15 [0/60000 (0%)]\tLoss: 0.155847\n", "Train Epoch: 15 [6400/60000 (11%)]\tLoss: 0.116800\n", "Train Epoch: 15 [12800/60000 (21%)]\tLoss: 0.039833\n", "Train Epoch: 15 [19200/60000 (32%)]\tLoss: 0.171820\n", "Train Epoch: 15 [25600/60000 (43%)]\tLoss: 0.162665\n", "Train Epoch: 15 [32000/60000 (53%)]\tLoss: 0.248377\n", "Train Epoch: 15 [38400/60000 (64%)]\tLoss: 0.156792\n", "Train Epoch: 15 [44800/60000 (75%)]\tLoss: 0.187477\n", "Train Epoch: 15 [51200/60000 (85%)]\tLoss: 0.152571\n", "Train Epoch: 15 [57600/60000 (96%)]\tLoss: 0.049584\n", "[[922. 0. 6. 15. 6. 1. 44. 0. 6. 0.]\n", " [ 1. 985. 0. 11. 1. 0. 0. 0. 2. 0.]\n", " [ 24. 1. 832. 6. 102. 0. 34. 0. 1. 0.]\n", " [ 10. 3. 5. 935. 28. 0. 17. 0. 1. 1.]\n", " [ 2. 1. 15. 21. 948. 0. 13. 0. 0. 0.]\n", " [ 0. 0. 0. 0. 0. 973. 0. 18. 0. 9.]\n", " [144. 1. 48. 29. 
114. 0. 655. 0. 9. 0.]\n", " [ 0. 0. 0. 0. 0. 3. 0. 974. 0. 23.]\n", " [ 2. 1. 1. 3. 1. 2. 1. 3. 986. 0.]\n", " [ 1. 0. 0. 0. 0. 3. 0. 25. 0. 971.]]\n", "\n", "Test set: Average loss: 0.2455, Accuracy: 9181/10000 (92%)\n", "\n", "Train Epoch: 16 [0/60000 (0%)]\tLoss: 0.117903\n", "Train Epoch: 16 [6400/60000 (11%)]\tLoss: 0.107160\n", "Train Epoch: 16 [12800/60000 (21%)]\tLoss: 0.051522\n", "Train Epoch: 16 [19200/60000 (32%)]\tLoss: 0.121555\n", "Train Epoch: 16 [25600/60000 (43%)]\tLoss: 0.120275\n", "Train Epoch: 16 [32000/60000 (53%)]\tLoss: 0.232415\n", "Train Epoch: 16 [38400/60000 (64%)]\tLoss: 0.120694\n", "Train Epoch: 16 [44800/60000 (75%)]\tLoss: 0.246462\n", "Train Epoch: 16 [51200/60000 (85%)]\tLoss: 0.144428\n", "Train Epoch: 16 [57600/60000 (96%)]\tLoss: 0.069679\n", "[[940. 0. 11. 7. 2. 1. 34. 0. 5. 0.]\n", " [ 2. 980. 1. 11. 3. 0. 0. 0. 3. 0.]\n", " [ 22. 1. 890. 7. 51. 0. 26. 0. 3. 0.]\n", " [ 20. 2. 4. 935. 21. 0. 15. 0. 2. 1.]\n", " [ 2. 0. 54. 23. 902. 0. 16. 0. 1. 2.]\n", " [ 0. 0. 0. 0. 0. 969. 0. 21. 0. 10.]\n", " [189. 1. 67. 26. 101. 0. 605. 0. 11. 0.]\n", " [ 0. 0. 0. 0. 0. 3. 0. 978. 0. 19.]\n", " [ 2. 1. 0. 2. 0. 3. 0. 1. 991. 0.]\n", " [ 0. 0. 0. 0. 1. 2. 0. 27. 0. 970.]]\n", "\n", "Test set: Average loss: 0.2465, Accuracy: 9160/10000 (92%)\n", "\n", "Train Epoch: 17 [0/60000 (0%)]\tLoss: 0.123562\n", "Train Epoch: 17 [6400/60000 (11%)]\tLoss: 0.105745\n", "Train Epoch: 17 [12800/60000 (21%)]\tLoss: 0.045860\n", "Train Epoch: 17 [19200/60000 (32%)]\tLoss: 0.108356\n", "Train Epoch: 17 [25600/60000 (43%)]\tLoss: 0.101529\n", "Train Epoch: 17 [32000/60000 (53%)]\tLoss: 0.245731\n", "Train Epoch: 17 [38400/60000 (64%)]\tLoss: 0.096277\n", "Train Epoch: 17 [44800/60000 (75%)]\tLoss: 0.128771\n", "Train Epoch: 17 [51200/60000 (85%)]\tLoss: 0.132103\n", "Train Epoch: 17 [57600/60000 (96%)]\tLoss: 0.143218\n", "[[944. 0. 6. 7. 3. 1. 34. 0. 5. 0.]\n", " [ 2. 981. 1. 11. 1. 0. 2. 0. 2. 0.]\n", " [ 23. 1. 875. 7. 61. 0. 32. 0. 1. 0.]\n", " [ 21. 
6. 5. 936. 14. 0. 17. 0. 1. 0.]\n", " [ 2. 0. 36. 22. 914. 0. 26. 0. 0. 0.]\n", " [ 0. 0. 0. 0. 0. 973. 0. 16. 0. 11.]\n", " [175. 0. 48. 21. 73. 0. 674. 0. 9. 0.]\n", " [ 0. 0. 0. 0. 0. 1. 0. 975. 0. 24.]\n", " [ 4. 1. 0. 3. 0. 3. 0. 2. 987. 0.]\n", " [ 1. 0. 0. 0. 0. 2. 0. 20. 0. 977.]]\n", "\n", "Test set: Average loss: 0.2334, Accuracy: 9236/10000 (92%)\n", "\n", "Train Epoch: 18 [0/60000 (0%)]\tLoss: 0.143297\n", "Train Epoch: 18 [6400/60000 (11%)]\tLoss: 0.209471\n", "Train Epoch: 18 [12800/60000 (21%)]\tLoss: 0.083080\n", "Train Epoch: 18 [19200/60000 (32%)]\tLoss: 0.173864\n", "Train Epoch: 18 [25600/60000 (43%)]\tLoss: 0.143225\n", "Train Epoch: 18 [32000/60000 (53%)]\tLoss: 0.234640\n", "Train Epoch: 18 [38400/60000 (64%)]\tLoss: 0.128212\n", "Train Epoch: 18 [44800/60000 (75%)]\tLoss: 0.141312\n", "Train Epoch: 18 [51200/60000 (85%)]\tLoss: 0.164612\n", "Train Epoch: 18 [57600/60000 (96%)]\tLoss: 0.079647\n", "[[949. 0. 8. 6. 1. 1. 31. 0. 4. 0.]\n", " [ 1. 986. 0. 10. 1. 0. 0. 0. 2. 0.]\n", " [ 21. 1. 917. 6. 26. 0. 28. 0. 1. 0.]\n", " [ 20. 5. 7. 933. 14. 0. 19. 0. 1. 1.]\n", " [ 4. 1. 72. 26. 862. 0. 34. 0. 1. 0.]\n", " [ 0. 0. 0. 0. 0. 981. 0. 16. 0. 3.]\n", " [193. 3. 83. 23. 53. 0. 636. 0. 9. 0.]\n", " [ 0. 0. 0. 0. 0. 1. 0. 975. 0. 24.]\n", " [ 5. 1. 1. 3. 0. 1. 0. 2. 987. 0.]\n", " [ 1. 0. 0. 0. 0. 5. 0. 17. 0. 
977.]]\n", "\n", "Test set: Average loss: 0.2397, Accuracy: 9203/10000 (92%)\n", "\n", "Train Epoch: 19 [0/60000 (0%)]\tLoss: 0.140702\n", "Train Epoch: 19 [6400/60000 (11%)]\tLoss: 0.074412\n", "Train Epoch: 19 [12800/60000 (21%)]\tLoss: 0.039120\n", "Train Epoch: 19 [19200/60000 (32%)]\tLoss: 0.173565\n", "Train Epoch: 19 [25600/60000 (43%)]\tLoss: 0.213276\n", "Train Epoch: 19 [32000/60000 (53%)]\tLoss: 0.242456\n", "Train Epoch: 19 [38400/60000 (64%)]\tLoss: 0.119806\n", "Train Epoch: 19 [44800/60000 (75%)]\tLoss: 0.139165\n", "Train Epoch: 19 [51200/60000 (85%)]\tLoss: 0.104396\n", "Train Epoch: 19 [57600/60000 (96%)]\tLoss: 0.093405\n", "[[931. 0. 14. 8. 2. 1. 38. 0. 6. 0.]\n", " [ 0. 990. 1. 6. 1. 0. 0. 0. 2. 0.]\n", " [ 20. 1. 913. 6. 33. 0. 26. 0. 1. 0.]\n", " [ 17. 4. 8. 933. 18. 0. 17. 0. 2. 1.]\n", " [ 1. 0. 48. 22. 898. 0. 30. 0. 1. 0.]\n", " [ 0. 0. 0. 0. 0. 972. 0. 20. 0. 8.]\n", " [173. 1. 65. 20. 70. 0. 658. 0. 13. 0.]\n", " [ 0. 0. 0. 0. 0. 2. 0. 973. 0. 25.]\n", " [ 2. 1. 0. 2. 1. 2. 0. 2. 990. 0.]\n", " [ 0. 0. 0. 0. 0. 2. 0. 21. 1. 976.]]\n", "\n", "Test set: Average loss: 0.2422, Accuracy: 9234/10000 (92%)\n", "\n", "Train Epoch: 20 [0/60000 (0%)]\tLoss: 0.093597\n", "Train Epoch: 20 [6400/60000 (11%)]\tLoss: 0.094620\n", "Train Epoch: 20 [12800/60000 (21%)]\tLoss: 0.030099\n", "Train Epoch: 20 [19200/60000 (32%)]\tLoss: 0.139641\n", "Train Epoch: 20 [25600/60000 (43%)]\tLoss: 0.218915\n", "Train Epoch: 20 [32000/60000 (53%)]\tLoss: 0.222010\n", "Train Epoch: 20 [38400/60000 (64%)]\tLoss: 0.083391\n", "Train Epoch: 20 [44800/60000 (75%)]\tLoss: 0.193114\n", "Train Epoch: 20 [51200/60000 (85%)]\tLoss: 0.141759\n", "Train Epoch: 20 [57600/60000 (96%)]\tLoss: 0.078522\n", "[[934. 0. 8. 6. 3. 1. 45. 0. 3. 0.]\n", " [ 2. 980. 0. 13. 1. 0. 1. 0. 3. 0.]\n", " [ 19. 1. 892. 7. 43. 0. 38. 0. 0. 0.]\n", " [ 17. 2. 7. 936. 16. 0. 21. 0. 1. 0.]\n", " [ 2. 0. 41. 18. 902. 0. 37. 0. 0. 0.]\n", " [ 0. 0. 0. 0. 0. 976. 0. 20. 0. 4.]\n", " [144. 1. 58. 20. 59. 
0. 708. 0. 10. 0.]\n", " [ 0. 0. 0. 0. 0. 3. 0. 977. 0. 20.]\n", " [ 4. 1. 0. 3. 1. 1. 0. 2. 988. 0.]\n", " [ 1. 0. 0. 0. 0. 3. 0. 23. 0. 973.]]\n", "\n", "Test set: Average loss: 0.2341, Accuracy: 9266/10000 (93%)\n", "\n" ] } ], "source": [ "import torch\n", "import torch.nn as nn\n", "import torch.nn.functional as F\n", "import torch.optim as optim\n", "import sklearn.metrics as metrics\n", "import numpy as np\n", "from torchvision import datasets, transforms\n", "\n", "class NetBNormConv(nn.Module): #92% accuracy\n", " # much faster convergence.\n", " def __init__(self):\n", " super(NetBNormConv, self).__init__()\n", " self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)\n", " self.bn1 = nn.BatchNorm2d(32)\n", " self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)\n", " self.bn2 = nn.BatchNorm2d(64)\n", " self.pool = nn.MaxPool2d(2, 2)\n", " self.dropout = nn.Dropout(0.25)\n", "\n", " self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)\n", " self.bn3 = nn.BatchNorm2d(128)\n", " self.pool2 = nn.MaxPool2d(2, 2)\n", "\n", " self.relu = nn.ReLU()\n", " self.fc1 = nn.Linear(128 * 3 * 3, 128)\n", "\n", " self.fc2 = nn.Linear(128, 10)\n", "\n", " def forward(self, x):\n", " x = self.pool(self.relu(self.bn1(self.conv1(x))))\n", " x = self.pool(self.relu(self.bn2(self.conv2(x))))\n", " x = self.dropout(x)\n", " x = self.relu(self.bn3(self.conv3(x)))\n", " x = self.pool2(x)\n", "\n", " x = x.view(x.size(0), -1)\n", " x = self.relu(self.fc1(x))\n", " x = self.dropout(x)\n", " x = self.fc2(x)\n", " return F.log_softmax(x, dim=1)\n", "\n", "\n", "class NetConv(nn.Module): # 87% accuracy\n", " # two convolutional layers and one fully connected layer,\n", " # all using relu, followed by log_softmax\n", " def __init__(self):\n", " super(NetConv, self).__init__()\n", " self.conv1 = nn.Conv2d(1, 128, kernel_size=3, stride=1, padding=1)\n", " self.conv2 = nn.Conv2d(128, 256, kernel_size=5, stride=2, padding=1)\n", " self.pool = nn.MaxPool2d(kernel_size=2, 
stride=2)\n",
    "        self.fc = nn.Linear(256 * 6 * 6, 10)\n",
    "        self.relu = nn.ReLU()\n",
    "        self.act = nn.LogSoftmax(dim=1)\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Conv3x3 -> ReLU -> Conv5x5/stride2 -> ReLU -> MaxPool -> Linear -> LogSoftmax.\"\"\"\n",
    "        x = self.conv1(x)\n",
    "        x = self.relu(x)\n",
    "        x = self.conv2(x)\n",
    "        x = self.relu(x)\n",
    "        x = self.pool(x)\n",
    "\n",
    "        #x = x.flatten(1)\n",
    "        # flatten: 28x28 input -> conv2 (stride 2) gives 13x13 -> pool gives 6x6, 256 channels\n",
    "        x = x.view(-1, 256 * 6 * 6)\n",
    "        x = self.fc(x)\n",
    "        x = self.act(x)\n",
    "\n",
    "        return x\n",
    "\n",
    "def train(model, device, train_loader, optimizer, epoch):\n",
    "    \"\"\"Run one training epoch of `model` over `train_loader`, printing the NLL loss every 100 batches.\"\"\"\n",
    "    model.train()\n",
    "    for batch_idx, (data, target) in enumerate(train_loader):\n",
    "        data, target = data.to(device), target.to(device)\n",
    "        optimizer.zero_grad()\n",
    "        output = model(data)\n",
    "        # the models emit log-probabilities (log_softmax), so NLL loss is the matching criterion\n",
    "        loss = F.nll_loss(output, target)\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "        if batch_idx % 100 == 0:\n",
    "            print('Train Epoch: {} [{}/{} ({:.0f}%)]\\tLoss: {:.6f}'.format(\n",
    "                epoch, batch_idx * len(data), len(train_loader.dataset),\n",
    "                100. * batch_idx / len(train_loader), loss.item()))\n",
    "\n",
    "def test(model, device, test_loader):\n",
    "    \"\"\"Evaluate `model` on `test_loader`; print the 10x10 confusion matrix, average loss and accuracy.\"\"\"\n",
    "    model.eval()\n",
    "    test_loss = 0\n",
    "    correct = 0\n",
    "    conf_matrix = np.zeros((10,10)) # initialize confusion matrix\n",
    "    with torch.no_grad():\n",
    "        for data, target in test_loader:\n",
    "            data, target = data.to(device), target.to(device)\n",
    "            output = model(data)\n",
    "            # sum up batch loss\n",
    "            test_loss += F.nll_loss(output, target, reduction='sum').item()\n",
    "            # determine index with maximal log-probability\n",
    "            pred = output.argmax(dim=1, keepdim=True)\n",
    "            correct += pred.eq(target.view_as(pred)).sum().item()\n",
    "            # update confusion matrix\n",
    "            conf_matrix = conf_matrix + metrics.confusion_matrix(\n",
    "                target.cpu(),pred.cpu(),labels=[0,1,2,3,4,5,6,7,8,9])\n",
    "    # print confusion matrix\n",
    "    np.set_printoptions(precision=4, suppress=True)\n",
    "    print(conf_matrix)\n",
    "\n",
    "    test_loss /= len(test_loader.dataset)\n",
    "\n",
    "    print('\\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\\n'.format(\n",
    "        test_loss, correct, 
len(test_loader.dataset),\n", " 100. * correct / len(test_loader.dataset)))\n", "\n", "def main(model, lr=0.01, mom=0.5, epochs=20):\n", " use_mps = torch.backends.mps.is_available()\n", " device = torch.device('mps' if use_mps else 'cpu')\n", "\n", " # normalise.\n", " transform = transforms.Compose([transforms.ToTensor(),\n", "\t\t\t\t transforms.Normalize((0.5,), (0.5,))])\n", "\n", " # fetch and load training data\n", " trainset = datasets.FashionMNIST(root='./data', train=True, download=True, transform=transform)\n", " train_loader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=False)\n", "\n", " # fetch and load test data\n", " testset = datasets.FashionMNIST(root='./data', train=False, download=True, transform=transform)\n", " test_loader = torch.utils.data.DataLoader(testset, batch_size=64, shuffle=False)\n", "\n", " # choose network architecture\n", " if model == 'lin':\n", " net = NetLin().to(device)\n", " elif model == 'full':\n", " net = NetFull().to(device)\n", " elif model == 'bn-conv':\n", " net = NetBNormConv().to(device)\n", " print(net)\n", " else:\n", " net = NetConv().to(device)\n", "\n", " if list(net.parameters()):\n", "\t # use SGD optimizer\n", "\t #optimizer = optim.SGD(net.parameters(), lr=lr, momentum=mom)\n", "\n", "\t # use Adam optimizer\n", "\t #optimizer = optim.Adam(net.parameters(),lr=lr,\n", "\t # weight_decay=0.00001)\n", " optimizer = optim.SGD(net.parameters(),lr=lr,momentum=0.9, weight_decay=0.00001)\n", "\n", "\t # training and testing loop\n", " for epoch in range(1, epochs + 1):\n", " train(net, device, train_loader, optimizer, epoch)\n", " test(net, device, test_loader)\n", "\n", "main('bn-conv')" ] }, { "cell_type": "code", "execution_count": null, "id": "d8f2659e-5bd7-4b89-9fc3-5edc5f167cc5", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": 
"ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.12.9" } }, "nbformat": 4, "nbformat_minor": 5 }