{ "cells": [ { "cell_type": "markdown", "id": "fa341661", "metadata": {}, "source": [ "# Classification with the Perceptron, Logistic Regression and Softmax Regression models" ] }, { "cell_type": "markdown", "id": "3c04bc10", "metadata": {}, "source": [ "Author: Omar Al-Ghattas" ] }, { "cell_type": "markdown", "id": "22f9d5d6", "metadata": {}, "source": [ "## Introduction" ] }, { "cell_type": "markdown", "id": "22b9686f", "metadata": {}, "source": [ "In this lab we will take a closer look at some common classification models used in machine learning. Our focus will be primarily on binary classification (a two-class problem), but we will also see an example of a $k$-class problem using softmax regression. \n", "\n", "We begin with a recap of the perceptron classifier and work through an example of applying it to simulated data. We will then turn our attention to the logistic regression and softmax regression models and apply them to the famous MNIST dataset. Throughout, we will pay close attention to the `sklearn` implementation of the `LogisticRegression` class." ] }, { "cell_type": "markdown", "id": "38368d2b", "metadata": {}, "source": [ "## Recap: Perceptrons" ] }, { "cell_type": "markdown", "id": "8ed02f47", "metadata": {}, "source": [ "We begin by taking a look at a very simple model of computation called the perceptron. Assume that we have $n$ inputs, which we denote by $x_1,x_2,\\dots,x_n$, and that each $x_i$ is a vector in $p$ dimensions, $x_i = [x_{i1}, \\dots,x_{ip}]^T$. In other words, we have $p$ features, or equivalently $x_i \\in \\mathbb{R}^p$. We will always assume that the first feature is a dummy feature, $x_{i1} = 1$, to account for the bias term.\n", "\n", "Next, we let $w \\in \\mathbb{R}^p$ be the weight vector for our perceptron model. Note that the weight vector and the input vectors must have the same dimension. 
This allows us to compute the dot product:\n", "\n", "$$\n", "h_w(x_i) =\\langle w, x_i\\rangle = \\sum_{j=1}^p w_j x_{ij}.\n", "$$\n", "\n", "This is known as the activation. Note that the bias (intercept term) is absorbed into $w$ as its first component $w_1$, since $x_{i1} = 1$. We could equivalently have defined a weight vector $\\tilde{w} \\in \\mathbb{R}^{p-1}$ and a separate bias term $b\\in \\mathbb{R}$, and written \n", "$$\n", "h_{\\tilde{w},b}(x_i) = b +\\sum_{j=2}^p w_j x_{ij},\n", "$$\n", "where $w = (b,\\tilde{w})^T$. The two conventions are equivalent, and we will stick with the first one for the remainder of the lab.\n", "\n", "As discussed in the lecture, the perceptron computes the dot product and then outputs the sign of this value. The sign function, usually denoted $\\text{sgn}(x)$, is defined by \n", "$$\n", "\\text{sgn}(x) = \\begin{cases}\n", "+1 \\quad &\\text{if} \\quad x>0\\\\\n", "0 \\quad &\\text{if} \\quad x=0\\\\\n", "-1 \\quad &\\text{if} \\quad x<0.\n", "\\end{cases}\n", "$$\n", "We can visualise this function:" ] }, { "cell_type": "code", "execution_count": 1, "id": "74329bef", "metadata": { "scrolled": true }, "outputs": [ { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "import matplotlib.pyplot as plt\n", "import numpy as np\n", "\n", "# leave a small gap around 0 so that the jump is visible\n", "xx_neg = np.linspace(-5,-0.15,1000)\n", "xx_pos = np.linspace(0.15,5,1000)\n", "plt.plot(xx_neg, np.sign(xx_neg), color='black')\n", "plt.plot(xx_pos, np.sign(xx_pos), color='black')\n", "plt.scatter(0,0, s=80, color='black')  # sgn(0) = 0 (filled marker)\n", "plt.scatter(0,1, s=80, facecolors='none', edgecolors='black')  # open endpoint (hollow marker)\n", "plt.scatter(0,-1, s=80, facecolors='none', edgecolors='black')  # open endpoint (hollow marker)\n", "plt.title(\"y=sgn(x)\")\n", "plt.show()" ] }, { "cell_type": "markdown", "id": "61b470dc", "metadata": {}, "source": [ "In a binary classification problem we only have two classes, so we will usually set $\\text{sgn}(0)=-1$: anything with a positive value is classified as $+1$, and anything with a non-positive value is classified as $-1$. In other words, we will think of the sgn function as looking like:" ] }, { "cell_type": "code", "execution_count": 2, "id": "64fe6581", "metadata": {}, "outputs": [ { "data": { "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAYIAAAEICAYAAABS0fM3AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/YYfK9AAAACXBIWXMAAAsTAAALEwEAmpwYAAAdX0lEQVR4nO3df7xVdZ3v8ddbiB8GBgqjR0DRpAKqwdvO22Q3i1CpMYFqShsVyySn7Nr0S7rgratUOFNX597RScZMrPyVDnmcagjJLJ00NkWaJEloAaEcxR8QICGf+WN9jy22e58f7H3O5pz1fj4e+3H2+q7vWuuz1z5nvdePvc9SRGBmZsV1QLMLMDOz5nIQmJkVnIPAzKzgHARmZgXnIDAzKzgHgZlZwTkIrEOS7pF0bBf6vVPSTQ1a5jZJR3cw/lFJ0xqxrE7q+Lykb/b0cnpbfv1KulbSgg76hqRjeq86awYHgdUk6Z3A1oj4RWd9I+J2YLKk19a73IgYFhHrUg0dbqj2R5JeIek2SW2StkhaKumVza6rXX797o/6awDvzxwE1pHzgG90o/8NwJweqqUvGQG0Aq8EDgV+BtzWzILMOhQRfjTpAXwauLWi7f8B/9TN+ZwNrAO2Ao8Af5vaBwBfAZ5I7ecDAQxM438EXALck6b9ATAqjRsE7ADG5pbzPeArueEbgWtyw8cDj9So8QPA7bnhh4Fv54bXA1PS8wCOIQuVPwG7gG3t0wOPAp8C7geeAW4ChtRY7gHAfOB3wGbgOuBladz4tKzZwO/TepqXm/bzwDfT8+8CH6uY9/3ArC68Pwen5RxSY/y1wFeBZel9uAs4Mjf+jcCK9FpXAG/swnt/TJrPM+l13ZSbJoBjurjsfN/BwJfTuno8TTd0H9b7W4ANFf0fBaYB09P7/af0nv8yjT+cLFy3AGuBcyvep28D30yv4QHgFcBn07LXAyfl+r8M+BqwCdgILAAGdLbe+vOj6QUU+QG0AH8ERqThgekX93Vp+Erg6RqP+1OflwLPAq/MzXNyen4esBoYC4wE7uDFQfDb9EczNA0vTOMmA3+sqPewVN9U4G/JNkDDc+PbN3gHVXmtR6e6D0h/1L9r3xikcU8BB6Thyg3Vgop5PUq2l314WuavgfNqrOMPpg3H0cAw4N+Ab6Rx49Oy/jW9/r8EngMmpvGf589B8F7gvtx8/xJ4EhjUhfd5JrCpg/HXkm3A3ky2sf0n4O7cOn0KODP9fpyehg/p5L2/AZiX1vcQ4E255VWu36rLrtL3MrKN8cHAcOB24Ev7sN7fQo0gqFzvufE/Jvt7GAJMAdqAqbn+O4GT0zq6jiwU5wEvAc4lt4MCLAGuSuvvL8h+lz7c2Xrrz4+mF1D0B/B90t4NcAqwupvTv5RsA/tuKvbOgB+2/4Kn4Wm8OAjm58Z/BPiP9Px44LEqy3s32R7WE5V/JOmPLoAjatS6HvhvwGnAovQH+Cqyo4XWXL+uBMEZueF/AL5aY5nLgY/khl9Jtrc5kD8HQf6o52fAaen5CxuktFF4CpiQhr8MXNmF92cs2V7n6R30uRa4MTc8DHgeGEcWAD+r6P9TsiOBjt7769I6HltleZXrt+qy830Bke20vDzX96+ofQTY0Xp/C90IgrQenmfvnY4vAdfm+i/LjXsn2dFE+17+8PQ6RpCdqnsuv77IwvXOztZbf374GkHzLQbOSM/PoHvn5ImIPwLvI9v73yTpu5JelUYfTrbxbbe+cnrgsdzz7WQbAsg2esOr9L+d7JTTmoi4u2Jce/+na5R7F9lG4M3p+Y+AE9LjrhrT1FKr7krtRx/tfke2MTq0O/OKiJ1kp6DOkHQA2cajw/dK0miy021XRsQNHfUl995ExDayUyCHV6m//TWM6eS9/wzZxvtnkh6U9MF9WHbeaOBAYKWkpyU9DfxHaq+mK+u9qw4HtkTE1or5jckNP557vgN4IiKezw1D9r4eSbbDsin3Oq4iOzKA7q23fsNB0Hz
fAV4r6dVkRwTfah8h6avpo37VHg+294uIpRFxItmpgYfITnVAdg50bG5Z47pR19qsBI2paP8C2amYFkmnV4ybCDwaEc/WmGd7EPyP9PwuOg+C6EbN1fyB7I+/3RHAbvbecHTVYrJTYm8DtkfET2t1lDSSLARaI+ILXZj3C++NpGFkp1/+wIvrh+w1bITa731EPBYR50bE4cCHgSs7+BhorWXnPUG2QZ0cESPS42URUSuAO1rvfyQLlfZlDmDvQKl8z/8AHCwpv2PywjropvVkRwSjcq/joIiYDN1eb/2Gg6DJ0p7mLcD1ZKcAfp8bd15kH/Wr9pgMIOlQSTMkvZTsF3wbsCfN4mbgAkljJI0ALuxGXbvIrimc0N4m6c1kp3HOIrvA+v8rguIEslNdtdwFvJXssHwD8BOyi4OHAL+oMc3jZOeZ99UNwN9LOipt5L5IdgFwd3dnlDb8e8guwNc8GpB0ELAUuCci5nZx9u+Q9CZJg8gu4N8bEevJLtC/QtL7JQ2U9D5gEvDvHb33kv5GUvtOwFNkG9c9lQvtZNn5176HLGQuk/QXaRljJJ1cY54drfffAEMk/bWkl5BdVB6cm/ZxYHw68iLV8p/AlyQNSR9RPofs4nC3RMQmsoD+iqSDJB0g6eWSTkivqTvrrd9wEOwfFgOvoZunhZIDgE+Q7TVtIdsY/10a969kv/T3k21ov0e2V/b8i2dT1VVk56jbN27XAedHxMaI+AnZJy++Lkmp/+lpmqoi4jdkG6ufpOFnyS4435M7jK/0NWBSOoz/ThfrzruGbL3+mOwC4k7gY/swn3bXkb1XHW2EZgGvBz5QcRR3RAfTXA98juw9fB3pdGFEPEl2pPhJsovTnwFOiYgn6Pi9fz1wn6RtZBd4L4ja3x2ouuwqLiQ7UrxX0rNkOwq1vh9Rc71HxDNk16OuJtur/yOwITftt9PPJyX9PD0/neyazh/ILvZ+LiLuqLHszpxF9qm41WQb+1vIjqige+ut31C6QGJNlDYQDwGHdXBapRHLeTvZRdXKUw0dTXMP2ca/1h57e793AmdGxHvrLHO/JuksYE5EvKmB87yW7OLp/EbN06w7Bja7gKJLh7+fIPvkRkNDQNJQslMxPyC7SPc5sr2pLouI47vY73ayC8n9lqQDyfZkr2x2LWaN5FNDTZTO7T4LnEi2kW74IoD/Q3b4+wuyi7z/uweW0++lc+FtZOevr29yOWYN5VNDZmYF5yMCM7OC65PXCEaNGhXjx49vdhlmZn3KypUrn4iIF30JsE8Gwfjx4ymXy80uw8ysT5FU+S11wKeGzMwKz0FgZlZwDgIzs4Lrk9cIzPYnO3fu5Oabb6a1tZXt27czefJk5syZw4QJE5pdmlmX+IjArA5r1qxh0qRJXH/99cycOZOPfOQjDBgwgDe+8Y188YtfbHZ5Zl3SkCMCSdeQ/WOszRHx6irjRXbno3eQ/b/3syPi52ncbLL/PgjZDUgWN6Ims562Y8cOpk+fzrx58/jQhz70Qvspp5zCBRdcwFvf+laOOOIIzjij1v9wM9s/NOqI4Fqyfydcy9uBCekxB/gXAEkHk/1rhf8OHAd8Lv0fd7P93k033cTEiRP3CoF2LS0tXHHFFVx66aX42/u2v2vIEUFE/FjS+A66zACui+wv4l5JIyS1kN2kZFlEbAGQtIwsUDq7m5NZ0y1ZsoSzzjqr5vipU6fyyCOP8IY3vIGhQ4f2YmXWX02ZMoXLL7+84fPtrWsEY9j7NokbUlut9heRNEdSWVK5ra2txwo166rt27czcmTtA1hJDB48mD17+v19TayP6zOfGoqIRWQ3laZUKvlY25pu0qRJ3HPPPZx8cvWbdG3evJk9e/awfPlyDjrooF6uzqzreuuIYCN73y93bGqr1W623/vwhz/MokWLqHWE+uUvf5l3vetdDgHb7/VWELQCZynzBuCZdO/QpcBJkkami8QnpTaz/d6kSZM499xzmTp1KnffffcLF4W
feOIJ5s6dy6233sqCBQuaXKVZ5xr18dEbyC78jpK0geyTQC8BiIivkt0r9x1k9zvdTnYDdCJii6RLgBVpVhe3Xzg26wsuvvhixo8fzznnnMOuXbsYMWIEjz76KDNnzuTuu++mpaWl85mYNVmfvDFNqVQK//dR259EBA899BA7duzg6KOPZsSIEc0uyexFJK2MiFJle5+5WGy2P5PExIkTm12G2T7xv5gwMys4B4GZWcE5CMzMCs5BYGZWcA4CM7OCcxCYmRWcg8DMrOAcBGZmBecgMDMrOAeBmVnBOQjMzArOQWBmVnAOAjOzgnMQmJkVnIPAzKzgGhIEkqZLWiNpraS5VcZfJmlVevxG0tO5cc/nxrU2oh4zM+u6um9MI2kAcAVwIrABWCGpNSJWt/eJiL/P9f8YcGxuFjsiYkq9dZiZ2b5pxBHBccDaiFgXEbuAG4EZHfQ/HbihAcs1M7MGaEQQjAHW54Y3pLYXkXQkcBTww1zzEEllSfdKmllrIZLmpH7ltra2BpRtZmbQ+xeLTwNuiYjnc21Hppspvx+4XNLLq00YEYsiohQRpdGjR/dGrWZmhdCIINgIjMsNj01t1ZxGxWmhiNiYfq4DfsTe1w/MzKyHNSIIVgATJB0laRDZxv5Fn/6R9CpgJPDTXNtISYPT81HA8cDqymnNzKzn1P2poYjYLel8YCkwALgmIh6UdDFQjoj2UDgNuDEiIjf5ROAqSXvIQmlh/tNGZmbW87T3drlvKJVKUS6Xm12GmVmfImlluia7F3+z2Mys4BwEZmYF5yAwMys4B4GZWcE5CMzMCs5BYGZWcA4CM7OCcxCYmRWcg8DMrOAcBGZmBecgMDMrOAeBmVnBOQjMzArOQWBmVnAOAjOzgnMQmJkVXEOCQNJ0SWskrZU0t8r4syW1SVqVHh/KjZst6eH0mN2IeszMrOvqvlWlpAHAFcCJwAZghaTWKrecvCkizq+Y9mDgc0AJCGBlmvapeusyM7OuacQRwXHA2ohYFxG7gBuBGV2c9mRgWURsSRv/ZcD0BtRkZmZd1IggGAOszw1vSG2V3i3pfkm3SBrXzWmRNEdSWVK5ra2tAWWbmRn03sXi24HxEfFasr3+xd2dQUQsiohSRJRGjx7d8ALNzIqqEUGwERiXGx6b2l4QEU9GxHNp8GrgdV2d1szMelYjgmAFMEHSUZIGAacBrfkOklpyg6cCv07PlwInSRopaSRwUmozM7NeUvenhiJit6TzyTbgA4BrIuJBSRcD5YhoBf6npFOB3cAW4Ow07RZJl5CFCcDFEbGl3prMzKzrFBHNrqHbSqVSlMvlZpdhZtanSFoZEaXKdn+z2Mys4BwEZmYF5yAwMys4B4GZWcE5CMzMCs5BYGZWcA4CM7OCcxCYmRWcg8DMrOAcBGZmBecgMDMrOAeBmVnBOQjMzArOQWBmVnAOAjOzgmtIEEiaLmmNpLWS5lYZ/wlJq9PN65dLOjI37nlJq9KjtXJaMzPrWXXfoUzSAOAK4ERgA7BCUmtErM51+wVQiojtkv4O+AfgfWncjoiYUm8dZma2bxpxRHAcsDYi1kXELuBGYEa+Q0TcGRHb0+C9ZDepNzOz/UAjgmAMsD43vCG11XIO8P3c8BBJZUn3SppZayJJc1K/cltbW10Fm5nZn9V9aqg7JJ0BlIATcs1HRsRGSUcDP5T0QET8tnLaiFgELILsnsW9UrCZWQE04ohgIzAuNzw2te1F0jRgHnBqRDzX3h4RG9PPdcCPgGMbUJOZmXVRI4JgBTBB0lGSBgGnAXt9+kfSscBVZCGwOdc+UtLg9HwUcDyQv8hsZmY9rO5TQxGxW9L5wFJgAHBNRDwo6WKgHBGtwD8Cw4BvSwL4fUScCkwErpK0hyyUFlZ82sjMzHqYIvre6fZSqRTlcrnZZZiZ9SmSVkZEqbLd3yw2Mys4B4GZWcE5CMzMCs5BYGZWcA4CM7OCcxCYmRWcg8DMrOAcBGZmBecgMDMrOAe
BmVnBOQjMzArOQWBmVnAOAjOzgnMQmJkVnIPAzKzgHARmZgXXkCCQNF3SGklrJc2tMn6wpJvS+Pskjc+N+2xqXyPp5EbUY2ZmXVd3EEgaAFwBvB2YBJwuaVJFt3OApyLiGOAy4NI07SSyexxPBqYDV6b5mZlZL2nEEcFxwNqIWBcRu4AbgRkVfWYAi9PzW4C3Kbt58Qzgxoh4LiIeAdam+ZmZWS9pRBCMAdbnhjektqp9ImI38AxwSBenBUDSHEllSeW2trYGlG1mZtCHLhZHxKKIKEVEafTo0c0ux8ys32hEEGwExuWGx6a2qn0kDQReBjzZxWnNzKwHNSIIVgATJB0laRDZxd/Wij6twOz0/D3ADyMiUvtp6VNFRwETgJ81oCYzM+uigfXOICJ2SzofWAoMAK6JiAclXQyUI6IV+BrwDUlrgS1kYUHqdzOwGtgNfDQinq+3JjMz6zplO+Z9S6lUinK53OwyzMz6FEkrI6JU2d5nLhabmVnPcBCYmRWcg8DMrOAcBGZmBecgMDMrOAeBmVnBOQjMzArOQWBmVnAOAjOzgnMQmJkVnIPAzKzgHARmZgXnIDAzKzgHgZlZwTkIzMwKzkFgZlZwdQWBpIMlLZP0cPo5skqfKZJ+KulBSfdLel9u3LWSHpG0Kj2m1FOPmZl1X71HBHOB5RExAViehittB86KiMnAdOBySSNy4z8dEVPSY1Wd9ZiZWTfVGwQzgMXp+WJgZmWHiPhNRDycnv8B2AyMrnO5ZmbWIPUGwaERsSk9fww4tKPOko4DBgG/zTV/IZ0yukzS4A6mnSOpLKnc1tZWZ9lmZtau0yCQdIekX1V5zMj3i4gAooP5tADfAD4QEXtS82eBVwGvBw4GLqw1fUQsiohSRJRGj/YBhZlZowzsrENETKs1TtLjkloiYlPa0G+u0e8g4LvAvIi4Nzfv9qOJ5yR9HfhUt6o3M7O61XtqqBWYnZ7PBm6r7CBpELAEuC4ibqkY15J+iuz6wq/qrMfMzLqp3iBYCJwo6WFgWhpGUknS1anPe4E3A2dX+ZjotyQ9ADwAjAIW1FmPmZl1k7JT+31LqVSKcrnc7DLMzPoUSSsjolTZ7m8Wm5kVnIPAzKzgHARmZgXnIDAzKzgHgZlZwTkIzMwKzkFgZlZwDgIzs4JzEJiZFZyDwMys4BwEZmYF5yAwMys4B4GZWcE5CMzMCs5BYGZWcHUFgaSDJS2T9HD6ObJGv+dzN6VpzbUfJek+SWsl3ZTuZmZmZr2o3iOCucDyiJgALE/D1eyIiCnpcWqu/VLgsog4BngKOKfOeszMrJvqDYIZwOL0fDHZfYe7JN2neCrQfh/jbk1vZmaNUW8QHBoRm9Lzx4BDa/QbIqks6V5JM1PbIcDTEbE7DW8AxtRakKQ5aR7ltra2Oss2M7N2AzvrIOkO4LAqo+blByIiJNW6AfKREbFR0tHAD9MN65/pTqERsQhYBNk9i7szrZmZ1dZpEETEtFrjJD0uqSUiNklqATbXmMfG9HOdpB8BxwK3AiMkDUxHBWOBjfvwGszMrA71nhpqBWan57OB2yo7SBopaXB6Pgo4HlgdEQHcCbyno+nNzKxn1RsEC4ETJT0MTEvDSCpJujr1mQiUJf2SbMO/MCJWp3EXAp+QtJbsmsHX6qzHzMy6SdmOed9SKpWiXC43uwwzsz5F0sqIKFW2+5vFZmYF5yAwMys4B4GZWcE5CMzMCs5BYGZWcA4CM7OCcxCYmRWcg8DMrOAcBGZmBecgMDMrOAeBmVnBOQjMzArOQWBmVnAOAjOzgnMQmJkVnIPAzKzg6goCSQdLWibp4fRzZJU+b5W0KvfYKWlmGnetpEdy46bUU4+ZmXVfvUcEc4HlETEBWJ6G9xIRd0bElIiYAkwFtgM/yHX5dPv4iFhVZz1mZtZN9QbBDGBxer4YmNlJ//cA34+I7XUu18zMGqTeIDg0Ijal548Bh3bS/zTghoq2L0i6X9JlkgbXmlDSHEl
lSeW2trY6SjYzs7xOg0DSHZJ+VeUxI98vIgKIDubTArwGWJpr/izwKuD1wMHAhbWmj4hFEVGKiNLo0aM7K9vMzLpoYGcdImJarXGSHpfUEhGb0oZ+cwezei+wJCL+lJt3+9HEc5K+Dnyqi3WbmVmD1HtqqBWYnZ7PBm7roO/pVJwWSuGBJJFdX/hVnfWYmVk31RsEC4ETJT0MTEvDSCpJurq9k6TxwDjgrorpvyXpAeABYBSwoM56zMysmzo9NdSRiHgSeFuV9jLwodzwo8CYKv2m1rN8MzOrn79ZbGZWcA4CM7OCcxCYmRWcg8DMrOAcBGZmBecgMDMrOAeBmVnBOQjMzArOQWBmVnAOAjOzgnMQmJkVnIPAzKzgHARmZgXnIDAzK7i6/g21mWW2bt3KkiVLeOyxxzjssMOYNWsWw4cPb3ZZZl3iIDCrQ0SwcOFCLrnkEgYMGMDOnTsZMmQI5513HhdddBFz584luwGf2f6rrlNDkv5G0oOS9kgqddBvuqQ1ktZKmptrP0rSfan9JkmD6qnHrLctXLiQBQsWsGPHDrZt28bu3bvZtm0bO3bsYMGCBSxcuLDZJZp1qt5rBL8C3gX8uFYHSQOAK4C3A5OA0yVNSqMvBS6LiGOAp4Bz6qzHrNds3bqVSy65hO3bt1cdv337dhYsWMC2bdt6uTKz7qn3VpW/Bjo79D0OWBsR61LfG4EZkn4NTAXen/otBj4P/Es9NXXk4x//OKtWreqp2VvBPP744+zevbvDPgcccABLlizhzDPP7KWqzLqvNz41NAZYnxvekNoOAZ6OiN0V7VVJmiOpLKnc1tbWY8WaddWuXbs6DYKdO3eyadOmXqrIbN90ekQg6Q7gsCqj5kXEbY0vqbqIWAQsAiiVSrEv87j88ssbWZIV3HXXXcdHP/rRDk/9DBkyhJaWll6syqz7Og2CiJhW5zI2AuNyw2NT25PACEkD01FBe7tZnzBr1izOO++8Dvvs2bOHWbNm9VJFZvumN04NrQAmpE8IDQJOA1ojIoA7gfekfrOBXjvCMKvX8OHDueiiizjwwAOrjj/wwAOZP38+w4YN6+XKzLqn3o+PzpK0Afgr4LuSlqb2wyV9DyDt7Z8PLAV+DdwcEQ+mWVwIfELSWrJrBl+rpx6z3jZ37lzmz5/P0KFDGTZsGAMHDmTYsGEMHTqU+fPnM3fu3M5nYtZkynbM+5ZSqRTlcrnZZZi9YOvWrXznO99h06ZNtLS0MGvWLB8J2H5H0sqIeNF3vvzNYrMGGD58uD8ian2W/+mcmVnBOQjMzArOQWBmVnB98mKxpDbgd82uo5tGAU80u4he5tdcDH7NfceRETG6srFPBkFfJKlc7Wp9f+bXXAx+zX2fTw2ZmRWcg8DMrOAcBL1nUbMLaAK/5mLwa+7jfI3AzKzgfERgZlZwDgIzs4JzEDSBpE9KCkmjml1LT5P0j5IeknS/pCWSRjS7pp4iabqkNZLWSur3/3ZU0jhJd0paLelBSRc0u6beIGmApF9I+vdm19IoDoJeJmkccBLw+2bX0kuWAa+OiNcCvwE+2+R6eoSkAcAVwNuBScDpkiY1t6oetxv4ZERMAt4AfLQArxngArJ/qd9vOAh632XAZ4BCXKWPiB/k7kt9L9md6Pqj44C1EbEuInYBNwIzmlxTj4qITRHx8/R8K9nGseZ9x/sDSWOBvwaubnYtjeQg6EWSZgAbI+KXza6lST4IfL/ZRfSQMcD63PAG+vlGMU/SeOBY4L4ml9LTLifbkdvT5DoayvcjaDBJdwCHVRk1D/hfZKeF+pWOXnNE3Jb6zCM7lfCt3qzNep6kYcCtwMcj4tlm19NTJJ0CbI6IlZLe0uRyGspB0GARMa1au6TXAEcBv5QE2SmSn0s6LiIe68USG67Wa24n6WzgFOBt0X+/uLIRGJcbHpva+jVJLyELgW9FxL81u54edjxwqqR3AEOAgyR9MyLOaHJddfMXyppE0qNAKSL64n8w7DJ
J04H/C5wQEW3NrqenSBpIdjH8bWQBsAJ4f+7+3P2Osj2axcCWiPh4k8vpVemI4FMRcUqTS2kIXyOwnvbPwHBgmaRVkr7a7IJ6Qrogfj6wlOyi6c39OQSS44EzganpvV2V9patj/ERgZlZwfmIwMys4BwEZmYF5yAwMys4B4GZWcE5CMzMCs5BYGZWcA4CM7OC+y8Tx3EXUsXkQwAAAABJRU5ErkJggg==", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "plt.plot(xx_neg, np.sign(xx_neg), color='black')\n", "plt.plot(xx_pos, np.sign(xx_pos), color='black')\n", "plt.scatter(0,-1, s=80, color='black')\n", "plt.scatter(0,1, s=80, facecolors='none', edgecolors='black')\n", "plt.title(\"y=sgn(x) with only 2 possible outcomes\")\n", "plt.show()" ] }, { "cell_type": "markdown", "id": "e2b8ce0c", "metadata": {}, "source": [ "Now, denote the output of the perceptron on input $x_i$ by $o(x_i) = \\text{sgn}(h_w(x_i))$. The idea is to tweak the weight vector $w$ so that $h_w(x_i)$ is positive when $x_i$ has corresponding label $y_i=+1$, and $h_w(x_i)\\le 0$ whenever $y_i = -1$. In the next section, we show a simple example for the binary OR function.\n", "\n", "Recall that the binary OR function takes in two binary inputs, returns TRUE if at least one of the two inputs is equal to $1$, and otherwise returns FALSE. For example, if $x = (1,0)$, then \n", "\n", "$$\n", "\\text{OR}(x) = \\text{TRUE} = +1 = y\n", "$$\n", "\n", "and if $x=(0,0)$, then \n", "\n", "$$\n", "\\text{OR}(x) = \\text{FALSE} = -1 = y.\n", "$$\n", "\n", "Note that here we are choosing the encoding $\\{\\text{TRUE}, \\text{FALSE}\\} = \\{+1,-1\\}$. Sometimes we will instead use the encoding $\\{\\text{TRUE}, \\text{FALSE}\\} = \\{+1,0\\}$; either choice is fine, as long as we use it consistently." ] }, { "cell_type": "markdown", "id": "dbecc8e7", "metadata": {}, "source": [ "## Representing simple Boolean functions as a linear classifier" ] }, { "cell_type": "markdown", "id": "895b7ab5", "metadata": {}, "source": [ "We will first look at modelling a simple two-input Boolean function as a linear classifier. We can think of this as a Perceptron WITHOUT any learning! To get started we will use the OR function, whose truth table will be familiar to you all. You will need to pick weights so that the function outputs the correct value for each input. 
There are many weight vectors that will do the job. Also, take care with the dimension of the weight vector: it must match the input dimension, including the bias feature." ] }, { "cell_type": "code", "execution_count": 3, "id": "78eb939d", "metadata": { "scrolled": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "OR Function: \n", "For Input [1 0 0] with Class -1 Predict -1\n", "For Input [1 0 1] with Class 1 Predict 1\n", "For Input [1 1 0] with Class 1 Predict 1\n", "For Input [1 1 1] with Class 1 Predict 1\n" ] } ], "source": [ "def sgn(x):\n", "    return 1 if x>0 else -1  # treat sgn(0) as -1 (two-class convention)\n", "\n", "# set up\n", "X = np.array([[1,0,0],   # all possible inputs, note that the first\n", "              [1,0,1],   # element is always equal to 1 (bias term)\n", "              [1,1,0], \n", "              [1,1,1]]) \n", "y = np.array([-1,1,1,1])   # labels for the OR function OR(x1,x2)\n", "n = X.shape[0]             # number of data points\n", "m = X[0].shape[0]          # input dimension\n", " \n", "w = np.array([-0.03, 0.04, 0.04])  # example weight vector\n", "\n", "# what predictions will our current model (w) make?\n", "print(\"OR Function: \")\n", "for i in range(n):\n", "    h = np.dot(w, X[i])  # activation\n", "    o = sgn(h)           # output\n", "    print('For Input', X[i], 'with Class', y[i], 'Predict ', o)" ] }, { "cell_type": "markdown", "id": "4c3954de", "metadata": {}, "source": [ "\n", " \n", "#### Exercise: \n", "Repeat the above analysis for the AND function, choosing an appropriate weight vector by trial and error." ] }, { "cell_type": "code", "execution_count": 4, "id": "c2c19a48", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "AND Function: \n", "For Input [1 0 0] with Class -1 Predict -1\n", "For Input [1 0 1] with Class -1 Predict -1\n", "For Input [1 1 0] with Class -1 Predict -1\n", "For Input [1 1 1] with Class 1 Predict 1\n" ] } ], "source": [ "### Solution\n", "y = np.array([-1,-1,-1,1])  # labels for the AND function AND(x1,x2)\n", "# activations for the four inputs: -0.03, -0.01, -0.01, +0.01 (only [1,1,1] is positive)\n", "w = np.array([-0.03, 0.02, 0.02])  # example weight vector\n", "\n", "print(\"AND Function: 
\")\n", "for i in range(n):\n", "    h = np.dot(w, X[i])  # activation\n", "    o = sgn(h)           # output\n", "    print('For Input', X[i], 'with Class', y[i], 'Predict ', o)" ] }, { "cell_type": "markdown", "id": "cc55acad", "metadata": {}, "source": [ "## Perceptron Learning" ] }, { "cell_type": "markdown", "id": "1c0f4983", "metadata": {}, "source": [ "So far we haven't done any actual learning; we just picked a weight vector $w$ that works. We now look at an approach to learning $w$ from data. We first recall the (Batch) Perceptron Algorithm from lectures:\n", "\n", "\\begin{align*}\n", "&\\text{input: } (x_1,y_1),\\dots, (x_n, y_n)\\\\\n", "&\\text{initialise: } w^{(1)} = (0,0,\\dots, 0) \\in \\mathbb{R}^p\\\\\n", "&\\text{for } t=1,\\dots, \\texttt{max\\_iter}\\\\\n", "&\\qquad \\text{if there is an index $i$ such that } y_i \\langle w^{(t)}, x_i \\rangle \\le 0:\\\\\n", "&\\qquad \\qquad w^{(t+1)} = w^{(t)} + y_i x_i\\\\\n", "&\\qquad \\text{else:}\\\\\n", "&\\qquad \\qquad \\text{ output } w^{(t)}, t\\\\\n", "\\end{align*}" ] }, { "cell_type": "markdown", "id": "c44daab7", "metadata": {}, "source": [ "The perceptron is known as a mistake-driven algorithm, since it updates $w$ only when a mistake is made (i.e. when $y_i \\langle w, x_i \\rangle \\le 0$). It is also important to note that the algorithm only outputs a weight vector before reaching `max_iter` once it makes no mistakes on the data, which is possible only if the data set is linearly separable. (Conversely, if the data are linearly separable, the perceptron convergence theorem guarantees that this happens after finitely many updates.) We can implement the above in code and try to visualise the learning process." ] }, { "cell_type": "markdown", "id": "986df7de", "metadata": {}, "source": [ "First, some helper functions to generate toy data and plot the perceptron at each iteration." 
] }, { "cell_type": "code", "execution_count": 5, "id": "5e72de45", "metadata": {}, "outputs": [], "source": [ "def generate_data(n=20, means=[[3,3],[-1,1]], seed=1):\n", "    '''\n", "    generate n points from each of two gaussians, with a leading column of ones\n", "    for the bias term; labels are +1 for the first gaussian and -1 for the second\n", "    '''\n", "    np.random.seed(seed)\n", "    m1 = np.array(means[0])\n", "    m2 = np.array(means[1])\n", "    S1 = np.random.rand(2,2)\n", "    S2 = np.random.rand(2,2)\n", "    # S.T @ S is symmetric positive semi-definite, so it is a valid covariance matrix\n", "    dist_01 = np.random.multivariate_normal(m1, S1.T @ S1, n)\n", "    dist_02 = np.random.multivariate_normal(m2, S2.T @ S2, n)\n", "    X = np.concatenate((np.ones(2*n).reshape(-1,1), \n", "                        np.concatenate((dist_01, dist_02))), axis=1)\n", "    y = np.concatenate((np.ones(n), -1*np.ones(n))).reshape(-1,1)\n", "    shuffle_idx = np.random.choice(2*n, size=2*n, replace=False)\n", "    X = X[shuffle_idx]\n", "    y = y[shuffle_idx]\n", "    return X, y\n", "\n", "def plot_perceptron(ax, X, y, w):\n", "    pos_points = X[np.where(y==1)[0]]\n", "    neg_points = X[np.where(y==-1)[0]]\n", "    ax.scatter(pos_points[:, 1], pos_points[:, 2], color='blue')\n", "    ax.scatter(neg_points[:, 1], neg_points[:, 2], color='red')\n", "    # decision boundary: w[0] + w[1]*x1 + w[2]*x2 = 0\n", "    xx = np.linspace(-6,6)\n", "    yy = -w[0]/w[2] - w[1]/w[2] * xx\n", "    ax.plot(xx, yy, color='orange')\n", "    \n", "    # foot of the perpendicular from the origin to the boundary,\n", "    # i.e. -w[0] * (w[1], w[2]) / (w[1]**2 + w[2]**2)\n", "    ratio = (w[2]/w[1] + w[1]/w[2])\n", "    xpt = (-1*w[0] / w[2]) * 1/ratio\n", "    ypt = (-1*w[0] / w[1]) * 1/ratio\n", "    \n", "    # draw the normal vector (w[1], w[2]) starting on the boundary\n", "    ax.arrow(xpt, ypt, w[1], w[2], head_width=0.2, color='orange')\n", "    ax.axis('equal')" ] }, { "cell_type": "markdown", "id": "6a1d2f87", "metadata": {}, "source": [ "Next, we implement the perceptron learning algorithm. Note that here we randomly initialise the weight vector rather than setting it to $(0,0,\\dots,0)$; this makes the algorithm easier to visualise. We also restrict attention to the two-dimensional case, so $w \\in \\mathbb{R}^3$: two components for the features and one for the bias (intercept) term."
] }, { "cell_type": "code", "execution_count": 6, "id": "b27bad2b", "metadata": {}, "outputs": [], "source": [ "def train_perceptron_for_vis(X, y, max_iter=100):\n", "    np.random.seed(20)       # for consistency in weight init\n", "    w = np.random.random(3)  # init weights randomly\n", "    ctr = 0                  # keep track of the number of updates\n", "    for _ in range(max_iter):\n", "        \n", "        yXw = (y * X) @ w.T  # y_i <w, x_i> for every data point\n", "        mistake_idxs = np.where(yXw <= 0)[0]  # indexes where a mistake is made\n", "        if mistake_idxs.size > 0:\n", "            ctr += 1\n", "            i = np.random.choice(mistake_idxs)  # pick a misclassified idx at random\n", "            w = w + y[i] * X[i]  # perceptron update\n", "            \n", "            # visualisation\n", "            fig, ax = plt.subplots()\n", "            plot_perceptron(ax, X, y, w)\n", "            plt.show()\n", "            print(f\"Iteration {ctr}: w = {w}\")\n", "        else:\n", "            break  # no mistakes left: the data are separated, so stop\n", "    \n", "    # plot final weight vector\n", "    fig, ax = plt.subplots()\n", "    plot_perceptron(ax, X, y, w)\n", "    plt.show()\n", "    print(f\"Iteration {ctr}: w = {w}\")\n", "\n", "    return w" ] }, { "cell_type": "code", "execution_count": 7, "id": "48e99cf4", "metadata": {}, "outputs": [ { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Iteration 1: w = [-0.4118692 0.45148601 -0.79584122]\n" ] }, { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Iteration 2: w = [ 0.5881308 -2.59682971 -2.12119882]\n" ] }, { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Iteration 3: w = [-0.4118692 -2.00249179 -2.76919429]\n" ] }, { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Iteration 4: w = [ 0.5881308 -1.74183941 -2.18951439]\n" ] }, { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Iteration 5: w = [-0.4118692 -1.42242212 -2.63810207]\n" ] }, { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Iteration 6: w = [ 0.5881308 -1.02625027 -2.41014956]\n" ] }, { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Iteration 7: w = [ 1.5881308 -0.76559789 -1.83046966]\n" ] }, { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXIAAAD4CAYAAADxeG0DAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/YYfK9AAAACXBIWXMAAAsTAAALEwEAmpwYAAAUXklEQVR4nO3df4xsZX3H8c93Z/dCFwFFLtJy7+5i1TYoWsxAbYi0Fm0RDbQ1bTCDudaYVaiEJiQE2LR/daOpjZVGsdkAhoRpSKMiRLGKUUnbFMpeRClcawy5e7nExgVqRTZ67+5++8eZYWfnzo8zM2fOc54571dyM5xnZs95uNn7Od95znOeY+4uAEC8pkJ3AAAwGoIcACJHkANA5AhyAIgcQQ4AkZsOcdAzzzzTFxYWQhwaAKJ18ODB59x9b3t7kCBfWFjQ6upqiEMDQLTMbK1TO0MrABA5ghwAIkeQA0DkCHIAiBxBDgCRI8gBjE+9Li0sSFNTyWu9HrpHEynI9EMAJVCvS4uL0sZGsr22lmxLUq0Wrl8TiIocwHgsLe2EeNPGRtKOTBHkAMbjyJHB2jE0ghzAeMzNDdaOoRHkAMZjeVmand3dNjubtCNTBDmA8ajVpJUVaX5eMkteV1a40DkGzFoBMD61GsGdAypyAIgcQQ4AkSPIASByBDkARI4gB4DIEeQAEDmCHAAiR5ADQOQIcgCIHEEOAJEjyAEgcpkFuZlVzOy7ZvaVrPYJAOgvy4r8ekmHMtwfACCFTILczPZJeo+k27PYHwAgvawq8k9LulHSdrcPmNmima2a2er6+npGhwUAjBzkZvZeST9x94O9PufuK+5edffq3r17Rz0sAKAhi4r8YklXmNlhSfdI+n0zuzuD/QIAUhg5yN39Znff5+4Lkq6S9C13v3rkngEAUmEeOQBELtNndrr7dyR9J8t9AgB6oyIHgMgR5AAQOYIcACJHkANA5AhyAIgcQQ4AkSPIgbKp16WFBWlqKnmt10P3CCPKdB45gIKr16XFRWljI9leW0u2JalWC9cvjISKHJhE3arupaWdEG/a2EjaES0qcmDS9Kq6jxzp/DPd2hEFKnJg0vSqus84o/PPzM2Nv18YG4IcmAStQylra50/s7Ymvfjiie0zM9Ly8li7h/FiaAWIXftQSjeVinTs2Intp53Ghc7IUZEDsbv++v4hPjsrbW11fu+FF7LvE3JFkAMxq9el55/v/r6ZND8vrawkr50wPh49hlaAmPWaNjg/Lx0+vLutfQhmdpbx8QlARQ7EqHlxs9uFTenEgK7Vdirz1kqd8fHombvnftBqteqrq6u5HxeYCGkubr761dJzz+XXJ+TCzA66e7W9nYociE2neeKtZmelW2/Nrz8IjiAHYtPrLkyGS0qJi51AbObmOo+Nd7q4iVKgIgdis7ycDJ+0YvZJqRHkQGyYfYI2DK0AMarVCG68jIocACJHkAMx4nFtaMHQChAbHteGNlTkQBH1qrh5XBvaUJEDRdOv4uZxbWhDRQ4UTb+Ku9uysyxHW1oEOVA0/SpubghCG4IcKJp+FTc3BKENQQ4UTZqKu1ZL1lXZ3k5eCfFSI8iBoqHixoCYtQIUEbfgYwBU5AAQOYIcACI3cpCb2X4z+7aZPWVmT5rZ9Vl0DJgorI2CMcpijHxT0g3u/piZnSrpoJk96O5PZbBvIH6sjYIxG7kid/cfu/tjjf9+UdIhSeeMul8gOt2qbtZGwZhlOmvFzBYkXSDpkQ7vLUpalKQ5biXGpOlVdbM
2CsbM3D2bHZm9QtJDkpbd/Uu9PlutVn11dTWT4wKFsLDQ/YHIEg9LRibM7KC7V9vbM5m1YmYzkr4oqd4vxIGJ1KvqZm0UjFkWs1ZM0h2SDrn7p0bvEhCJ1jHxqS7/lObmuFMTY5fFGPnFkj4g6Qkze7zRdou7P5DBvoFiah8T39o68TOtVTd3amKMRg5yd/83SZZBX4A41OvSgQOdw7tSSRaymptLQpzwRg64sxPopNtUwmYl3inEpSTEWZEQOWPRLKBdr6mEneaEt2JqLQKgIgfa9bqBp9fcb2aiIBCCHGjXaypht4q7UmEmCoIhyIF2vR611m1O+F13EeIIhiAH2vW6gYc54SggghxoR1gjMsxaATrpdgMPS9KigKjIgUGwJC0KiCAHBsGStCggghwYRK8ZLUAgBDkwCJakRQER5MAgmNGCAmLWCjAolqRFwVCRA0DkCHIAiBxBjvK49lppejoZ256eTraBCcAYOcrh2mulz31uZ3tra2f7ttvC9AnICBU5ymFlZbB2ICIEOcqh26PZurUDESHIUQ6VymDtQEQIcpRDc4XCtO1ARAhyTJ56PXny/dRU8lqvJxc0r7lmpwKvVJJtLnRiAjBrBZOl13rht91GcGMiUZFjsmS8Xnin4h4oGipyTJYM1wvnYUCIBRU5JkOzdHbv/P4Q64XzMCDEgiBH/Jql89pa5/eb64UPOE7Cw4AQC4Ic8etUOjc11wuXdsLefWecpEeY8zAgxIIgR/y6lchm0uHDyYD2EOMkPAwIsShHkLt3HztF/NKUzkOMk/AwIMSiHLNWnv689MiHpamZ5I/N7Px3z+09KT+Xcru5v4F/rq3NynH+TW15eff0EunE0nlurvMYep9xEh4GhBiUI8hf9RbpjUuSH5e2W/8cS16b7b7Z9v4vpM0Xd7a97efa2z2nBZhsqhHsQ5xoKntSfG6690kt9YmnX/8aa4OPqpm0S0tJhT03l4R4awKnCfs+6vXehwBCMQ8w5FCtVn11dTX3446db0vbmyeeMPqdAIbdbj8Rjbzflv4pp9+L1pNG84RhM51PKsN+o2luP/6E9JWvSc/9r3T6GdIf/6l08e/u/pxNJye7tv3c/9UZ3XzLjH728xkd30r+zOyZ0af/YUZXvX9Gsko2JyWgBzM76O7VE9oJcnS0vRXoRHOsz+c2ux+r235zYfkPuQ06JNfpRMnQXVS6BXk5hlYwuKmKpIpUOTl0T0bjngx5pTnRvHRU+tcrpUu/dcIJofb+45quHNee6WOaqRzXTOW4piub2jN9XB//mz4nkq4nul9Kmz/v/fmtlv3mPnRXkBNN359r7d/07vdL8i2JIMdkM2sMz0xL+pXen50+JXl9zTtOeOvfn+18rXR+Xvr4+aN3MxX3Ib8ZdRk+S/1zKYYEN18a7Jubb+fzd9Y+dJfqZJPy2tP5fy3NnJbP/0cfmQS5mV0m6VZJFUm3u/snstgvMG6tFzDf/dsb+up1nT+XwbXS0Zkl4/fak+NBx8S3hxxaG/JEk/pEd1w6/n/pfu68m6SZ0H+RiZGD3Mwqkj4r6V2Sjkp61Mzud/enRt03ME7ti2L99IWNl9vbZ6OkmRiDAdiUVDkp+YORZXFV4yJJP3L3p939mKR7JF2ZwX6BsWq/2XN2z8bL7Z3UasmNotvbOzeMAkWQRZCfI+mZlu2jjbZdzGzRzFbNbHV9fT2DwwKjab+p85STXurYDhRdbvOM3H3F3avuXt27d29ehwW6ar+psxnkLIqF2GQR5M9K2t+yva/RBhRa+6JYsydtvNwOxCSLIH9U0uvN7Fwz2yPpKkn3Z7BfYKzaF8Xad/bGy+1ATEaeteLum2b2MUlfVzL98E53f3LkngE52LUo1pMb0veCdgcYSibzyN39AUkPZLEvIJjNLg+nAAqORRWApi2CHHEiyIGmlop8wMd7AkGx1grQtLVzZ2frHZ/Nx3tKXAhFMVGRA02bO3d2Dvh4TyAoghxoalTkQzzeEwi
KIAeaGhV5mmc5A0VCkANNWzt3drbe8SkFWLIWGABBDjRt7tzZ2XrH5/x8ss2FThQVs1aAppZ55Lvu+AQKjoocE2moeeDc2YlIEeQIahw33jTnga+tJY+5bM4D77vvPnd2cpMQioogRzBDB24fQ88D33wp974CWSDIEUyawB2mCh56HrhvJU9dH7KvQCgEOVLLemihX+AOWwWPNA98erZjMzcJocgIcqQyjqGFfoE7bBU80jzwSucg5yYhFBlBjlTGMbRw+eXJPO1WrYE7bBU80jzwLhU5NwmhyAhypJL10EK9Lt11V1LdN5lJBw7sBO4oVXCtJh0+LG1vJ6+p54R3qci5SQhFRpAjlayHFjpV+O7SAy3PmQpSBXepyKURTg7AmBHkSCXrUE1T4QepgqdPGePOgfEgyJFK1qGatsLPvQruMrQCFBlBjtSyDNXCXjzsMbQCFBVBjiCGqfBzuUWeihwRIshLpkjrhQxS4ed2izwVOSJEkJdIjOuFNE88V1+d0y3yVOSIEEFeIrGtF9J64ukm81vkqcgRIYK8RGJbL6TTiadd5rfIM/0QESLISyS29UL6nWDSznIZ6LoAQyuIEEFeIqGn/A16obXXCSbtPPaBrwswtIIIEeQlEnK9kH6B2inku5147r47/Tz2ga8LUJEjQuatqxblpFqt+urqau7HRTgLC50vWs7PJ4G9uLg7cGdnk5OMlITukSNJhb68PNiJZ2pq98JcTWbJtMdd/smkS74s7bsy/QGAHJnZQXevtrd3fhwKkLFeF1p7Vc2j3kE6N9f5BNJ12IaKHBFiaAW56HWhdZyzaQa+LsAYOSJEkCMXvQJ1nLNpBr4uQEWOCDG0glw0g7PbeHenMfKsZtPUagMMz1CRI0IEOXLTLVD7hXyuqMgRIYIchTBQ1TxOVOSIEGPkgLQzR5GKHBEaKcjN7JNm9gMz+76Z3Wtmr8yoX0C+fDN5rZwUth/AEEatyB+U9CZ3f7OkH0q6efQuAQFs/SJ5Nb6kIj4j/da6+zfcm6WMHpa0b/QuAQFs9llmESiwLMuPD0n6Wrc3zWzRzFbNbHV9fT3Dw2ISBH9y0RZBjnj1nbViZt+UdHaHt5bc/b7GZ5YkbUrq+s/P3VckrUjJWitD9RYTqbmgVnMeeXNBLSmHmSwbR6X/+KB07IVk+6E/kk5+jXThZ6UpJnUhDiMvmmVmH5T0EUmXunuqsoZFs9Cq14Jahw+P+eAvPCZ9/W2SH99ps4r0vuelPaeP+eDAYLotmjXqrJXLJN0o6Yq0IQ60C/rkolddIM2ctrvt7HcS4ojKqGPkn5F0qqQHzexxM/vHDPqEkgn65CIzae59kizZnj5Vet1HcjgwkJ1RZ628zt33u/tvNf58NKuOoTxCP7lIc3+WBLgk+Zb0a5fndGAgG0yaRXAhn1wkSTrrksYNQSbt/xNuCkJ0uCyPQgi61srUjPSrfygdvVf69Q8H6gQwPCryEgo+Z7uIXntAOvX10llvD90TYGBU5CUTdM52ke27kmd1IlpU5CUz8FPlARQeQV4yQedsAxgLgrxkgs7ZBjAWBHnJBJ+zDSBzBPmE6TcjJficbQCZY9bKBEk7I6Uwz8cEkAkq8gnCjBSgnAjyCcKMFKCcCPIJwowUoJwI8gnCjBSgnAjyCcKMFKCcmLUyYZiRApQPFTkARI4gB4DIEeQAEDmCHAAiR5ADQOQIcgCIHEEOAJEjyAEgcgQ5AESOIAeAyBHkABA5ghwAIkeQA0DkCHIAiBxBDgCRI8gBIHIEeaTqdWlhQZqaSl7r9dA9AhAKTwiKUL0uLS5KGxvJ9tpasi3xdCCgjKjII7S0tBPiTRsbSTuA8iHII3TkyGDtACYbQR6hubnB2gFMtkyC3MxuMDM3szOz2B96W16WZmd3t83OJu0AymfkIDez/ZL+QBJf7HNSq0krK9L8vGSWvK6scKETKKssZq38vaQbJd2Xwb6QUq1GcANIjFSRm9mVkp519++l+Oy
ima2a2er6+voohwUAtOhbkZvZNyWd3eGtJUm3KBlW6cvdVyStSFK1WvUB+ggA6KFvkLv7Ozu1m9n5ks6V9D0zk6R9kh4zs4vc/X8y7SUAoKuhx8jd/QlJZzW3zeywpKq7P5dBvwAAKTGPHAAiZ+75D1eb2bqktZwPe6akGL8t0O980e980e/BzLv73vbGIEEegpmtuns1dD8GRb/zRb/zRb+zwdAKAESOIAeAyJUpyFdCd2BI9Dtf9Dtf9DsDpRkjB4BJVaaKHAAmEkEOAJErXZCb2XVm9gMze9LM/jZ0fwYR27rvZvbJxt/1983sXjN7Zeg+9WJml5nZf5vZj8zsptD9ScPM9pvZt83sqcbv9PWh+zQIM6uY2XfN7Cuh+5KWmb3SzL7Q+N0+ZGa/E7pPpQpyM3uHpCslvcXd3yjp7wJ3KbVI131/UNKb3P3Nkn4o6ebA/enKzCqSPivp3ZLOk/R+MzsvbK9S2ZR0g7ufJ+ltkv4ikn43XS/pUOhODOhWSf/i7r8p6S0qQP9LFeSSrpH0CXf/pSS5+08C92cQzXXfo7k67e7fcPfNxubDShZWK6qLJP3I3Z9292OS7lFy0i80d/+xuz/W+O8XlYTKOWF7lY6Z7ZP0Hkm3h+5LWmZ2uqRLJN0hSe5+zN1/GrRTKl+Qv0HS283sETN7yMwuDN2hNAZZ973APiTpa6E70cM5kp5p2T6qSAKxycwWJF0g6ZHAXUnr00qKk+3A/RjEuZLWJX2+MSR0u5mdErpTWTwhqFD6rJ8+LekMJV9BL5T0z2b2Wi/AHMys1n3PW69+u/t9jc8sKRkCqOfZtzIxs1dI+qKkv3T3n4XuTz9m9l5JP3H3g2b2e4G7M4hpSW+VdJ27P2Jmt0q6SdJfhe7UROm2frokmdk1kr7UCO7/NLNtJYvfBH9kUazrvvf6+5YkM/ugpPdKurQIJ8wenpW0v2V7X6Ot8MxsRkmI1939S6H7k9LFkq4ws8slnSzpNDO7292vDtyvfo5KOuruzW89X1AS5EGVbWjly5LeIUlm9gZJe1Twldfc/Ql3P8vdF9x9Qckv0luLEOL9mNllSr46X+HuG6H708ejkl5vZuea2R5JV0m6P3Cf+rLk7H6HpEPu/qnQ/UnL3W92932N3+mrJH0rghBX49/dM2b2G42mSyU9FbBLkiawIu/jTkl3mtl/STom6UDBq8TYfUbSSZIebHybeNjdPxq2S525+6aZfUzS1yVVJN3p7k8G7lYaF0v6gKQnzOzxRtst7v5AuC5NvOsk1Rsn/Kcl/Xng/nCLPgDErmxDKwAwcQhyAIgcQQ4AkSPIASByBDkARI4gB4DIEeQAELn/BzcnK3v9ORYuAAAAAElFTkSuQmCC", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Iteration 8: w = [ 0.5881308 -0.04577813 -2.32033579]\n" ] }, { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Iteration 9: w = [ 1.5881308 0.21487424 -1.74065589]\n" ] }, { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Iteration 10: w = [ 0.5881308 0.53429153 -2.18924357]\n" ] }, { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Iteration 11: w = [ 1.5881308 0.79494391 -1.60956368]\n" ] }, { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Iteration 12: w = [ 0.5881308 1.38928183 -2.25755914]\n" ] }, { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Iteration 13: w = [ 1.5881308 1.6499342 -1.67787925]\n" ] }, { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXIAAAD4CAYAAADxeG0DAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/YYfK9AAAACXBIWXMAAAsTAAALEwEAmpwYAAAcUElEQVR4nO3de5jcVZ3n8fe3O9cmiTH3S9PdUUGI3GmjiKAYLjFBGHF3BmwdHHcmyigTZ10dQ884OrM9MIOPAyuLaxZwWSmGYQQCD0mAAHKXQAeMkUSupptcSDqJhkATknSf+eNUpao7VdVVXb+qX/3q93k9Tz2VOl2p+sKT/tap7znn+zPnHCIiEl11YQcgIiKlUSIXEYk4JXIRkYhTIhcRiTglchGRiBsRxptOmTLFtbS0hPHWIiKRtXbt2p3OuamDx0NJ5C0tLXR2dobx1iIikWVmXdnGVVoREYk4JXIRkYhTIhcRiTglchGRiFMiFxGJOCVyESmfRAJaWqCuzt8nEmFHVJNC2X4oIjGQSMDixdDb6x93dfnHAG1t4cVVgzQjF5HyaG9PJ/GU3l4/LoFSIheR8ujuLm5chk2JXETKo6mpuHEZNiVyESmPjg5oaBg41tDgxyVQSuQiUh5tbbBsGTQ3g5m/X7ZMC51loF0rIlI+bW1K3BWgGbmISMQpkYuIRJwSuYhIxCmRi4hEnBK5iEjEKZGLiEScErmISMQpkYuIRJwSuYhIxCmRi4hEnBK5iEjEBZbIzazezJ43s3uDek0RERlakDPyJcDGAF9PREQKEEgiN7NGYBFwQxCvJyIihQtqRn4N8G2gP9cTzGyxmXWaWWdPT09AbysiIiUncjM7H9jhnFub73nOuWXOuVbnXOvUqVNLfVsREUkKYkZ+OnCBmW0CbgM+ZWa3BPC6IiJSgJITuXNuqXOu0TnXAlwMPOyc+0LJkYmISEG0j1xEJOICvWanc+4R4JEgX1NERPLTjFxEJOKUyEVEIk6JXEQk4pTIRUQiTolcRCTilMhFRCJOiVwkbhIJaGmBujp/n0iEHZGUKNB95CJS5RIJWLwYenv9464u/xigrS28uKQkmpGL1KJcs+729nQST+nt9eMSWZqRi9SafLPu7u7sfyfXuESCZuQitSbfrHvSpOx/p6mp/HFJ2SiRi9SCzFJKV1f253R1wd69h4+PHAkdHWUNT8pLpRWRqBtcSsmlvh727z98fMIELXRGnGbkIlG3ZMnQSbyhAfr6sv9s9+7gY5KKUiIXibJEAnbtyv1zM2huhmXL/H02qo9HnkorIlGWb9tgczNs2jRwbHAJpqFB9fEaoBm5SBSlFjdzLWzC4Qm6rS09M8+cqas+HnnmnKv4m7a2trrOzs6Kv69ITShkcXPyZNi5s3IxSUWY2VrnXOvgcc3IRaIm2z7xTA0NcO21lYtHQqdELhI1+U5hqlwSS1rsFImapqbstfFsi5sSC5qRi0RNR4cvn2TS7pNYi0ciP/Am9G4OOwqRYGj3iQwSj9JK9x2w5ssw8QSYtdDfppwGdfH4z5ca1NamxC2HxGNGPv2TcPLVMHoybPwBPHgm3DEVnvgTeO1meGd72BGKiAxbPBL5uDlw7P+A+Q/Df9kFZ9wBTZ+Dnsfh6S/BXTPgvnnw6+/BzmfA9YccsMgQdLk2yRDvA0HOwe9/BVtX+tuup30SHz0VZi7wJZiZ58LoHD2cRcKQ7UBQQ4Pq5DGQ60BQvBP5YO/ugm0PwNYVsO0+/9jqfD09VVufeKJfYBIpp0TC
H/zp7vbbDTs60kk619F8bT+seUrkxervg93Ppmfru9f68bGzYNanYdYimHE2jBwfbpxSe4aacdfV+W+Tg5lBv8qCtUyJvFTvvAFbV8G2VbDtfr+lsW4kTD0jPVufcIxm61K6oWbcmpHHlhJ5kPoPQM9Tydn6Ctjzgh8/oiWd1KefBSMa8r6MSFZDzbhVI48tJfJyervbz9a3roQ3HoS+Xqgb7ZP5rIUwexGMe1/YUUpUFDLjzldDl5qlRF4pfftgx+N+pr51Jex92Y9P+CDMXAizF/pyTP3ocOOU6qUZt+SgRB6WN1/2dfWtK2H7I9D/LowY5xdKZy30C6cNjWFHKdVGM27JQom8Ghx8G7b/wif1LSugN9mOVK0DRKQASuTVxjnYsyG9vbHnCXAHYeREfwhp1iKYtQDGTAs7UhGpErkSuaZ+YTGDiR/yt7nfgv17/EJpKrF33w4YTGpNz9Ynt/oDSiIiGUrOCmZ2pJn9wsw2mNkLZrYkiMBiZ9R7fP+Xj94In90CC56DE/7B71V/4R/hgY/AnTPgqT+FTbfBu7vDjliKod4oUkYll1bMbCYw0zn3nJmNB9YCf+Sc25Dr76i0UqR3d/lDSFtXqnVAFGkXigSkbBdfds5tc849l/zzXmAjMLvU15UMoydDy+fhY7fAZ7fDub+ED7X7rY7r2mHVybC8Edb8Obx+FxzYG3bE8ZRr1p3tYsm9vX5cJACBLnaaWQvwGHCcc+7NQT9bDCwGaGpqOrUr24EHKd47b/hZ+pYV8MYDah0Qlnyz7i9+Ub1RJBBl37ViZuOAR4EO59yd+Z6r0kqZDGgdsBL2/MaPHzEno3XAJ9U6oBzyncYE9UaRQJQ1kZvZSOBe4H7n3A+Her4SeYW83Z1O6m885FsH1I+BaanWAQvVOiAo+fqj/OxnqpFLIMpWIzczA24ENhaSxKWCjmiCo74Kn7jHXxnprPvhA1/xbQPWXg73vB/uPRae+2Yy0e8PO+JoyayJ1+X4VWpq0sWSpeyC2LXyceBxYD2QKvhd4ZxbmevvaEZeBVKtA7asgB2PQP9+tQ4oRraa+GCadUvAynYgyDn3BKCVtKiZcJS/ffCvfOuANx5ON/ravNw/R60Dsksk4NJLoa/v8J/V1/sFTPVHkQrSEX0ZKG/rgPOSiT0GrQNyNa0aaiaunShSRjqiL4XJ2jpgdbrfeve/M6B1wOxFMOnU2modMDhZd3X5x5B9T3impqbyxycyiGbkUjjXD7//VTKpr4CdTwMORk9NXsd0oW/4Neq9YUdamnxbCbu7s+9OAdXEpezU/VCCt2+nP4R0WOuAj2W0DjgheoeR8m0lbGrKnuTr6+Hmm5XEpazKtv1QYmzMlIGtA855CuYuhYO9sO4KWHVSsnXAX0SrdUCu8kiqVt4w6EBVQ4OSuIRKiVyCUVcPU0+DE/8nfHotfHYrfOQmPzvvvh0evwjumAwPnQ0bfwh7fpu7RBG2XMk6teCpPeFSZbTYKeUxdia8/8/8bUDrgBXw/Df9rVpbB6SSsi61JhGhGrlU3ttdfsF0ywrY/nC0WgeoJa2ESIudUp369sGOx9LXMX3rFT8+4Zj0bH3qGVA/Ktw4U/LtaFEDLCkzJXKJhjdfTh9GOqx1wKJk64AQ293n29Gig0BSZjoQJNEw4SiYsASOWZLROmDloNYBJ2a0DvhoZVsH5Np+qINAEiIlcqleI46Axs/42+DWARuvhg1XVr51QEdH9hp5R0d531ckDyVyiYacrQNWwNb7Dm8dMGshTG4NvnWAdrRIFVKNXKLvUOuA5Gy9VlsHSOypRi61y+pg0in+dtzf+tYB2+5P7oS5F373/8HqfSveKLcOEMlBM3Kpbf19sOuZ9GGk3z/vx8fOTif1GfNh5Phw4xQpgGbkEk+p1gFTT4Of7ILb18Fx/XDSFjjl/8Gr/xfqRsLUM9OJfcIHNVuXSFEil3j4y7+EH//Y//nR5K3+AHzrAvjjo/1J01TrgHHvSyf1
aZ+EEWNDDFxkaCqtSDyMGJH70mwHD/o/D2gd8BD0veNbB0z/FMw4FxovgHFzKhu3SAad7JR4y1cqyfY70LcPtj+avCrSf8C+bX68WlsHSCyoH7nEW319keNjYNZ50Hot1I/2Y6dcAw1HwkvXwcNn+7a8j10Er9wAvVvKErZIIVQjl3hYvDhdIx88PpS3N8GcP/VtAw5rHbACNt/lnxdm6wCJNc3IpfYkEr5LYV2dv08k4Prr4bLL0jPw+nr/+Prr87/WO9v9/YlXpsdSrQPm/Rgu7IKFv4GT/hlGTYSN/wIPngF3ToMnL4Hf/Qz27SjDf6RImmrkUluC7he+5s/h1Rvh8wX+nhxqHZA8ZbpvO2Aw+cPp2fqkU4NvHSCxoMVOiYeg+4XfamzbM5vZX9tcfFuVzNYBW1bArjWodYCUQgeCJB66u4sbzyORcLQZXHbjdTjnPx9SJfWCknnO1gErsrQOWJRsHXC8DiNJ0fT9TmpDqi6e6xvmMPqFr/n32wC4e+2Fh8Z6e33jw2EZMwXmtMHpt8JFO+Ccp2Dud/zi6bqlsOpEWH4krFkMry+HA3uH+UYSNyqtSPRlq4tnStXIoaj2s7t+MpnJ43djbQN/R8pyMaDerbBtlT+QtO0BOLhXrQPkMCqtSO1qb8+dxJub0xd9yEz2BdRJJo/fzY/u//ph42W5GFDDLHj/f/O3vv2w86n09ka1DpAhaEYu0VfIdTSLXQR9axPcM4fGJbvYsnPSoeFSNsAM29tdyQXTlRmtA8bC9LPSiV2tA2JBu1akdhWSpIu9aPITfwzd/0HCueq6GNCh1gErfHJ/61U/PuHYjNYBH1frgBqlRC61q5C948XOyG81nxzP31COiIPz5svpEsyOR6F/P4wYBzPOSSf2hllhRykBUa8VqV1tbT5pNzf7GXZz8+H1j44On9wz5bpocn+yS2LrjwYMZzswGroJR/m2AZ96AD63C868G1raYHcnPPMXsHw2rDoZ1rVDz5PQfzDsiKUMNCOX+EgkCtu18vJP4NmvDjjNGfSB0bJzDva8kJ6t9zwJrs8fPpp5XvIw0gIYMzXsSKUIKq2IFOrfRoI7OCCRB31gtOL2/yGjdcAqtQ6IKG0/FCmUOwhzlw4YCvDAaDhGTYSm/+pvrt9fu3TLCp/U138f1n9PrQMiTIlcJNOejf7+QwMTeVNT9hl5WfaUl5vV+dn3pFPh+O9mtA5YOah1wMfSs3W1DqhqgXyPMrMFZvaimb1iZt8J4jVFKmHwAubWe//a/2Dk+AHPK2atNHIOtQ5IDGod8Fa6dcDdTb51wOa74cBbYUcsg5RcIzezeuAl4BxgM/AscIlzLue+LdXIpRpkW8B0CWOnm8eUtjVZn19Ve8oroXcrbLvPL5huW51sHTAKpmW0Dhh/tGbrFVK2xU4zOw34nnPuvOTjpQDOuStz/R0lcqkGgxcwR414l3dvHsMfXf80y5/4SGhxVa2+/bDzyXSv9T3JuZpaB1RMORc7ZwOvZzzeDBz2W2Bmi4HFAE2RLCxKrRm8UHn5uX7f+D1PKYlnVT/KtwWYfhacfLVvY7BtlW8d8OqN/lqm9WNh+qcyWge0hB11LFRssdM5twxYBn5GXqn3Fcll8ALmD9q+dWhcCjCuBY66zN/69sH2R9L71reu8M9R64CKCGKxcwtwZMbjxuSYSFXLtoD53TuurI0FzEqrHwOzFkDr/4LPvALnvwin/Cs0NMJLP4KH58MdU+Cxi+CVG3ztXQITxIz8WeAoM5uDT+AXA58P4HVFyiq1UNneDlPr/ZrNsRd9g0tqfQGz3MxgwtH+dsw3/C6X7Q+nG31tvss/770npWfrkz8CddoNPVyBnOw0s4XANUA9cJNzLu+cRoudUnUe+Bjs/GXhF1mW4RnQOmAl9Dyh1gFFKOvJTufcSmBlEK8lEoqdv4Tp88OOovaZwcTj/G3utw9vHdB1G751wLyM1gGnqHXAEPRdRuTg2/7+1GtCDSOWsrYO
SC6Yrv8erP97GDMNZma2DpgYctDVRx9zIhv+2d9PPO7QUFW2rK11qdYBx/8dnPc0XLQdTvuZ/6a05R548k/8gunqM+GFq+AP63NfbDtm1P1Q5NbkqcRkfTxyLWvjoL8Pdq1JLpiu8jN38LtiUiWY6fNh5Lhw4ywztbEVyeVWg9br4OivATXQsjYOerf6w0hbV8aqdYASuUg22x+Bh86Ciw8c2v5W7OU9JWRDtg5YBNM+UROtA9SPXCSbzq/7+4w9zDXVsjYOcrYOWBGb1gFK5BJve16Axs8OGOroyF4j14nPiMhsHXDwHX9R6sGtA94zN53Up5we+dYBKq1IfO3/Pfx8Elzwqv8aniGWLWtrnXOw9+X0CdMdj0L/ARgxHmaek9ze+GlomBV2pDmpRi4y2NpvwIvX6jRnXB14C7Y/lK6t92724+89ydfVD7UOqA81zEy5Ern2kUtNKmgf+IvXwsgJFY5MqsbIcdB4Icz7CVzYDQt/DSdd5f9NbLgKVp8Od06DJz8Pv7sF9vWEHXFOSuQSqnIcvEntA+/q8t+mu7r84wGvnfom2npdqLFKlTDz1yWd+zdw9qPwuZ3w8dth9md8w69ffhHunA73fxTW/wPs6vQnUauESisSmnIdvCloH/jme+CxC+GSvoL6eOiQUIwdah2QrK3vegZwcMFrMG5ORUNRjVyqTiEJdziLjgXtA1/eDL3dBdfHdUhIDtnXAz2Pw5EXVfytVSOXkgVdWhh8qbXB4wWVSLLItd97wHhvN8y5NLBYJUbGTA0lieejRC4FGW5SzWeohNvePrCUAf5xe3v+18125Z8B+8DfecPfn/hPgcUqEiYlcinIcJNqPgsXHt4KIzPhDncW3Nbma9fNzf71m5sH1bLXXZF8s8L3Cw/54SASIiVyKUjQpYVEAm6+eWAt2wwuvTSdcEuZBbe1+dp1f7+/H1BXf+2nMHZ2UfEO+eEgEiIlcilI0KWFbDN852BlxnWmyjILTm0ZK2LbYUreDweRECmRS0GCTqqFzPDLMgve9G/+vvHCEl5EpLookUtBgk6qhc7wA58Fp7od1lifaok3JXIpWJBJNbTFwwN/gKO/XuY3EaksJXIJxXBm+CXvY3/rd/7++O8PM2qR6qREHjPV1C+kmBl+IPvYn/+Wvx89qYSoRaqPEnmMlONQT7mlPni+8IUA9rG/fgdMODbI8ESqghJ5jJTjUE85ZX7w5FLwPvb+g/6+9UclxyVSbZTIYyRq/UKyffAMVvA+9ldv8Pcz5pcUk0g1UiKPkaj1CxnqA6bQXS6JBPSt+RoQ/rqASDkokcdI2P1Cil1ozfcBU+g+9lR5pr6un3+6e2kk1gVEiuacq/jt1FNPdRKOW25xrrnZOTN/f8stlXvfhgbn/DKrvzU0pN8/W1xD/Z1CNDc7N3f2b5xL4I4YvffQ6zQ3B//fKFJuQKfLklN1YQmpiHwXZujoyH31HSjtavZ1dXDf35zLucevxtrS/9YHXGRCJCJyXVhiRBjBSPzkW2jNt5um1BOkTU1w7vGrWfPKvMPGRWqFauRSEfkWWsu5m+bKjncBuPzm9LZD9RGXWqNELhWRb6G1LLtptj8Kr/6US471F5HY0TdPfcSlZqm0IhWRSpy56t3ZauQlzZrXtcPuTuj3M/JNy5fCUZfBEaqpSO1RIpeKaWvLPhMeKskPy4gjDiVxADb+AOrHwvHfLeFFRaqTErlUhVxJfthGjBv4ePQUOOa/B/gGItVDNXKpTSMnpP9cPxbO+DmMHJf7+SIRVlIiN7Orzey3ZvZrM7vLzCYGFJdIaUaO9/f1Y+EDX4Gpp4cbj0gZlTojXw0c55w7AXgJWFp6SCIBSM3Ix86Ck64KNxaRMispkTvnHnDOJfuD8jTQWHpIIgEYMQ6sDs64E+pHhx2NSFkFWSP/MrAq1w/NbLGZdZpZZ09PT4BvK7Ug8CsXNV8MZ94D7z0hgOhEqtuQvVbM7EFgRpYftTvn7k4+px1oBS5y
BTRvUa8VyZTqUJit14oO7oik5eq1UnLTLDP7EvAVYL5zbojLAHhK5JIpX0OtTZsqHY1I9SpL0ywzWwB8G/hEoUlcZLCoXblIpNqUWiO/DhgPrDazX5nZ/wkgJomZqF25SKTalDQjd859IKhAJL5y9SNXh0KRwuhkp4Surc0vbDY3ow6FIsOgXitSFQLvtSISI5qRx1Dge7ZFJFSakcfM4D3bqavKg2bEIlGlGXnM5Ls+pohEkxJ5zGjPtkjtUSKPGe3ZFqk9SuQxk+8iyCISTUrkNWaoHSnasy1Se7RrpYYUuiNFe7ZFaotm5DVEO1JE4kmJvIZoR4pIPCmR1xDtSBGJJyXyGqIdKSLxpEReQ7QjRSSetGulxmhHikj8aEYuIhJxSuQiIhGnRC4iEnFK5CIiEadELiIScUrkIiIRp0QuIhJxSuQiIhGnRC4iEnFK5CIiEadELiIScUrkIiIRp0QuIhJxSuQiIhGnRC4iEnFK5CIiEadEHlGJBLS0QF2dv08kwo5IRMKiKwRFUCIBixdDb69/3NXlH4OuDiQSR5qRR1B7ezqJp/T2+nERiR8l8gjq7i5uXERqmxJ5BDU1FTcuIrUtkERuZt80M2dmU4J4PcmvowMaGgaONTT4cRGJn5ITuZkdCZwL6It9hbS1wbJl0NwMZv5+2TItdIrEVRC7Vv4V+DZwdwCvJQVqa1PiFhGvpBm5mV0IbHHOrSvguYvNrNPMOnt6ekp5WxERyTDkjNzMHgRmZPlRO3AFvqwyJOfcMmAZQGtrqysiRhERyWPIRO6cOzvbuJkdD8wB1pkZQCPwnJnNc869EWiUIiKS07Br5M659cC01GMz2wS0Oud2BhCXiIgUSPvIRUQizpyrfLnazHqArgq/7RQgit8WFHdlKe7KUtzFaXbOTR08GEoiD4OZdTrnWsOOo1iKu7IUd2Up7mCotCIiEnFK5CIiERenRL4s7ACGSXFXluKuLMUdgNjUyEVEalWcZuQiIjVJiVxEJOJil8jN7HIz+62ZvWBm/xJ2PMWIWt93M7s6+f/612Z2l5lNDDumfMxsgZm9aGavmNl3wo6nEGZ2pJn9wsw2JP9NLwk7pmKYWb2ZPW9m94YdS6HMbKKZ/Tz5b3ujmZ0WdkyxSuRmdhZwIXCic+5DwA9CDqlgEe37vho4zjl3AvASsDTkeHIys3rgfwOfBuYCl5jZ3HCjKshB4JvOubnAR4GvRSTulCXAxrCDKNK1wH3OuWOAE6mC+GOVyIHLgKucc+8COOd2hBxPMVJ93yOzOu2ce8A5dzD58Gl8Y7VqNQ94xTn3mnNuP3Ab/kO/qjnntjnnnkv+eS8+qcwON6rCmFkjsAi4IexYCmVm7wHOBG4EcM7td879IdSgiF8iPxo4w8zWmNmjZvbhsAMqRDF936vYl4FVYQeRx2zg9YzHm4lIQkwxsxbgZGBNyKEU6hr85KQ/5DiKMQfoAX6aLAndYGZHhB1UEFcIqipD9E8fAUzCfwX9MHC7mb3PVcEezKD6vldavridc3cnn9OOLwEkKhlbnJjZOOAO4BvOuTfDjmcoZnY+sMM5t9bMPhlyOMUYAZwCXO6cW2Nm1wLfAf4u7KBqSq7+6QBmdhlwZzJxP2Nm/fjmN6Ffsiiqfd/z/f8GMLMvAecD86vhAzOPLcCRGY8bk2NVz8xG4pN4wjl3Z9jxFOh04AIzWwiMASaY2S3OuS+EHNdQNgObnXOpbz0/xyfyUMWttLIcOAvAzI4GRlHlndecc+udc9Occy3OuRb8P6RTqiGJD8XMFuC/Ol/gnOsNO54hPAscZWZzzGwUcDFwT8gxDcn8p/uNwEbn3A/DjqdQzrmlzrnG5L/pi4GHI5DESf7evW5mH0wOzQc2hBgSUIMz8iHcBNxkZr8B9gOXVvksMequA0YDq5PfJp52zn013JCyc84dNLOvA/cD9cBNzrkXQg6rEKcDXwTWm9mvkmNXOOdWhhdSzbscSCQ/8F8D/izkeHRE
X0Qk6uJWWhERqTlK5CIiEadELiIScUrkIiIRp0QuIhJxSuQiIhGnRC4iEnH/CfrmpNxRAgv0AAAAAElFTkSuQmCC", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Iteration 14: w = [ 2.5881308 -0.75825596 -2.61905618]\n" ] }, { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Iteration 15: w = [ 1.5881308 -0.43883867 -3.06764386]\n" ] }, { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXIAAAD4CAYAAADxeG0DAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/YYfK9AAAACXBIWXMAAAsTAAALEwEAmpwYAAAW1klEQVR4nO3dfYxc1XnH8d+zL6xZ22vH3k1QMLtDEgMBkohoSdrSVk1JWwIIqv5TkiEiTaVVoUmpGgmFWP2vK6VNlYYqodUWqFCZKmrz0kQpeQElRIqqkKwJCcF2UkptA6FlDQ02Mdjs7tM/7ox3dnbmzp2dO3Pvmfv9SKvZubu+c2ytf/Psc88519xdAIBwDWU9AABAdwhyAAgcQQ4AgSPIASBwBDkABG4kixednJz0UqmUxUsDQLD2799/zN2nGo9nEuSlUkmLi4tZvDQABMvMjjQ7TmsFAAJHkANA4AhyAAgcQQ4AgSPIASBwBDmA3qlUpFJJGhqKHiuVrEc0kDKZfgigACoVaW5OOnkyen7kSPRcksrl7MY1gKjIAfTGvn1rIV5z8mR0HKkiyAH0xtGjnR3HphHkAHpjerqz49g0ghxAb8zPS+Pj64+Nj0fHkSqCHEBvlMvSwoI0MyOZRY8LC1zo7AFmrQDonXKZ4O4DKnIACBxBDgCBI8gBIHAEOQAEjiAHgMAR5AAQOIIcAAJHkANA4AhyAAgcQQ4AgSPIASBwqQW5mQ2b2Q/M7CtpnRMA0F6aFfmtkg6meD4AQAKpBLmZ7ZF0jaS70jgfACC5tCryT0m6TdJqq28wszkzWzSzxaWlpZReFgDQdZCb2bWSnnP3/XHf5+4L7j7r7rNTU1PdviwAoCqNivwKSdeZ2WFJn5X0m2Z2XwrnBQAk0HWQu/vt7r7H3UuSbpD0TXe/seuRAQASYR45AAQu1Xt2uvtDkh5K85wAgHhU5AAQOIIcAAJHkANA4AhyAAgcQQ4AgSPIASBwBDlQNJWKVCpJQ0PRY6WS9YjQpVTnkQPIuUpFmpuTTp6Mnh85Ej2XpHI5u3GhK1TkwCBqVXXv27cW4jUnT0bHESwqcmDQxFXdR482/zOtjiMIVOTAoImrunftav5npqd7Py70DEEODIL6VsqRI82/58gR6cSJjcdHR6X5+Z4OD71FawUIXWMrpZXhYen06Y3HJya40Bk4KnIgdLfe2j7Ex8ellZXmX3vhhfTHhL4iyIGQVSrS88+3/rqZNDMjLSxEj83QHw8erRUgZHHTBmdmpMOH1x9rbMGMj9MfHwBU5ECIahc3W13YlDYGdLm8VpnXV+r0x4Nn7t73F52dnfXFxcW+vy4wEJJc3Ny9Wzp2rH9jQl+Y2X53n208TkUOhKbZPPF64+PSHXf0bzzIHEEOhCZuFSbtkkLiYicQmunp5r3xZhc3UQhU5EBo5uej9kk9Zp8UGkEOhIbZJ2hAawUIUblMcOMMKnIACBxBDoSI27WhDq0VIDTcrg0NqMiBPIqruLldGxpQkQN5067i5nZtaEBFDuRNu4q71bazbEdbWAQ5kDftKm4WBKEBQQ7kTbuKmwVBaECQA3mTpOIul6N9VVZXo0dCvNAIciBvqLjRIWatAHnEEnx0gIocAAJHkANA4LoOcjM7z8y+ZWYHzOxxM7s1jYEBA4W9UdBDafTIlyV9xN0fMbPtkvab2QPufiCFcwPhY28U9FjXFbm7P+vuj1Q/PyHpoKRzuz0vEJxWVTd7o6DHUp21YmYlSZdJerjJ1+YkzUnSNEuJMW
jiqm72RkGPmbuncyKzbZK+LWne3b8Q972zs7O+uLiYyusCuVAqtb4hssTNkpEKM9vv7rONx1OZtWJmo5I+L6nSLsSBgRRXdbM3CnosjVkrJuluSQfd/ZPdDwkIRH1PfKjFf6XpaVZqoufS6JFfIen9kh4zs0erxz7m7vencG4gnxp74isrG7+nvupmpSZ6qOsgd/fvSLIUxgKEoVKRbrqpeXgPD0cbWU1PRyFOeKMPWNkJNNNqKmGtEm8W4lIU4uxIiD5j0yygUdxUwmZzwusxtRYZoCIHGsUt4Imb+81MFGSEIAcaxU0lbFVxDw8zEwWZIciBRnG3Wms1J/zeewlxZIYgBxrFLeBhTjhyiCAHGhHWCAyzVoBmWi3gYUta5BAVOdAJtqRFDhHkQCfYkhY5RJADnYib0QJkhCAHOsGWtMghghzoBDNakEPMWgE6xZa0yBkqcgAIHEEOAIEjyFEct9wijYxEve2Rkeg5MADokaMYbrlF+ru/W3u+srL2/M47sxkTkBIqchTDwkJnx4GAEOQohla3Zmt1HAgIQY5iGB7u7DgQEIIcxVDboTDpcSAgBDkGT6US3fl+aCh6rFSiC5o337xWgQ8PR8+50IkBwKwVDJa4/cLvvJPgxkCiIsdgSXm/8GbFPQpsdUV65Zj04iHJV7MezRlU5BgsKe4Xzs2ABpy7tHxCOnUsCudT9R9LDc+rH6dfWAvw31uStkxm+3eoIsgxGCqVqOp2b/71TewXHlfcE+Q5tHJqfei+0hjGTcJ59XTzcw2NSmOT0lm7pS1T0s63Rs/rP0bGm//ZDBDkCF9j6dyotl94LeyPHo2CfX4+NpG5GVCGVlek0/8XH8SNQb18ovX5ztq1FsBbS9LuyxuCeUoa2119nJRGJ6KtHAJBkCN8zUrnmpmZtZs+dNgnmZ6Ovq3ZcXTAXVp+qbNqub6F0Whk6/oQnrhwfQiPTUZVdK2aPmuXNDTYUTfYfzsUQ6sS2Uw6fDj6vFTquE8yP7+x0OdmQNrYwmgazA3hHNfCqAXu2KS08y3rWxpnquVacO/OVUsjLwhyhC9J6byJPkkt3zvoxoTHV6VTL/SghbE7amHsml0L5DPhXBfMgbUw8oogR/iSlM6b7JMEdTMgd2n5F+2DuP7rcS2M4fH1ITxxQUN1XNfGGJssRAsjr/hXR/iSlM4p9Ek6vFbavWYtjGYf9UG9eqr5uWx4fZtix6XrQ3jDRb98zcpAPPNW07V6aHZ21hcXF/v+uii4LpK42cSY8fEO7rvsq9EsjHYX+eq/HtvCeM3G6rj+At+6cN4tje6khTEAzGy/u89uOE6QA+2VSvWdGdfWsV9oamJJb9l7TF/+1wRtjHYtjMY2xbqLfNUZGbQwCq9VkPPTAKyclk4/H1st/8N7j2lqYkmT249pctsxbTmrroXxUN25NrQwLlkfwhsu+jELA90jyDFYai2MpCv7Th2TXj3e+nxnvUYam9SuiUkdPTatHxy+TEsnprR0fErHTkxq6OxJ3f1PddX06A5aGOi7VILczK6SdIekYUl3ufvH0zgvCu7MLIwEU+IStTDOXl8db9+rQ/89qa88OKknnprSn13zt7rgdQekq3+8YSHJoePS3F8075Frqvf/FECcroPczIYlfUbSb0l6WtL3zezL7n6g23NjwNRaGEmq5drX287CqLYxarMwml70a97CaLyA+bbpR3XB6w6o8u+XbLiAWYg55QhWGhX5OyQ94e5PSpKZfVbS9ZII8kFW38JIusovroUxunMthM/eI73msviLfim0MBpX9h/82ZvPHG8W0EHNKUehpBHk50p6qu7505Le2fhNZjYnaU6SptmsIl/WtTAStjGSzsIYm5S2vXHjzIt1Ib0rWqrdZ42LOg/97KKmx4G869vFTndfkLQgRdMP+/W6hRTbwmgR1LEtjLoNiXZc0nxKXE6394zTuNizFuTUGQhNGkH+jKTz6p7vqR5DGnxVOv3zzqrlJC2MsUlp/L
yGFkbDBkUDPgujcbHn0y/skST95fxxSRPZDQzoUBpB/n1Je83sfEUBfoOk96Vw3sGzoYURMyXuzMfzkq80P9/wlvXBu+2NdSHcZI/lsd2ZtDDyauMFzOjOh79/1U8lbVhzAeRW10Hu7stm9iFJX1c0/fAed3+865GFoL6FkXTe8sorzc+1bhbGpDRx8cY9lmsBXZuRMbK1v3/fAbThAuY/Szp+UNpNkCMcqfTI3f1+Sfenca7MbGhhJGhjvPpi6/ON7lgL4XWzMGrhPLU+nEd3SMa9sHPhxYNZjwDoSDFWdr70pLT0H22q5SQtjGoIb3vD+nZF4x7LtDDCdvxQ1iMAOlKMIP/fh6SH/zD63IYaWhgXxe+xPDZJC6Nojh/q/5a1QBeKEeR7flea+tXqhkU7aWEglr94qNPbewKZKkaije2q3t1kFyGOeNv3ysxb3t4TyCNSDag3cVHLL7HiE3lFkAP1YoKcFZ/IK4IcqDcRbZw13rDLQIe39wT6iiAH6u2IgnxhQZqZiXYnmJnp4N6cQAaKMWsFSGriQklS+X2ucnkw95jB4KEix0CqVKIbJg8NRY+VSsI/OLozenz52d4MDOgBghyZ2nTgtjnn3Fw0/9t9bR54onPXdnpssrqzF2MF0kCQIzNdBW6Mxjv/SJuYB94Q5L0aK5AGghyZSRK4m6mCW8337mge+PH1G2el8uYA9AhBjsTSbi20C9zNVsGt5nt3NA+8YQfEVN4cgB4hyJFIL1oL7QJ3s1Xw/HwK88AbWiupvDkAPUKQI5FetBauvnrjXeTqA3ezVXC53OU88OFx6eX1dytM5c0B6BGCHImk3VqoVKR7742q+xoz6aab1gK3myq4XJYOH5ZWV6PHjhbzVBcFNZ6PRULIK4IciaTdWmhW4btL99fdZyqzKrjFfitdvTkAPUSQI5G0QzVJhZ9ZFRyzcRaQRwQ5Ekk7VJNW+JlUwQQ5AkOQI7E0QzXXFw8nNvbIgTwjyJGJzVT4fVsiv/1N0ePqqz16ASBdBHnB5Gm/kE4q/L4ukR8eix5P/FcPTg6kjyAvkBD3C6m98dx4YwZL5JtsnAXkEUFeIKHtF1L/xtNKT5fIE+QIBEFeIKHtF9LsjadRT5fIE+QIBEFeIKHtF9LuDSbpLJdNXxdo2AERyCuCvECynvLXaaDGvcEkncfe1XWBFwlyhIEgL5As9wtpF6jNQr7VG8999yWfx77p6wJbXictn0j2lwMyZl6/a1GfzM7O+uLiYt9fF9kplZpftJyZiQJ7bm594I6PR28yUhS6R49GFfr8fGdvPEND6zfmqjGLpj229OC7pOcekt7X//8fQCtmtt/dZxuPj2QxGBRP3IXWuKq52xWk09PN30DaXhfY8eYoyIEA0FpBX8RdaO3lbJpNXxdgvxUEhCBHX8QFai9n02z6ugBBjoDQWkFf1IKzVb+7WY88rdk05fIm2jNsnIWAEOTom1aB2i7kMzF+bvT46nFpdCLDgQDtEeTIhU1Vzb1k1a7j8Z9Iuy/PdixAG/TIgTgsCkIAugpyM/uEmR0ysx+Z2RfNbGdK4wLygf1WEIBuK/IHJF3q7m+V9FNJt3c/JCBHCHIEoKsgd/dvuPty9el3Je3pfkhAjhDkCECaPfIPSvpqqy+a2ZyZLZrZ4tLSUoovi0GQpzsXrcMOiAhA2yA3swfN7MdNPq6v+559kpYltfzv5+4L7j7r7rNTU1PpjB4DIbd3Ltq+N+MBAMm0nX7o7u+O+7qZfUDStZKu9Cx24ELw4vZayXRK4sRF0on/zHAAQDLdzlq5StJtkq5z9zb3cgGay+2di1jdiUB02yP/tKTtkh4ws0fN7O9TGBMKJrd3LtpBkCMMXa3sdPc3pTUQFFer/cj7deeiltg4C4FgZScyl+Wdi2JNXBg9cukHOcdeK8iF3O21IkmjO6PHl5+Vxl+f6VCAOFTkBZTbOdt5YxY9sigIOUeQF0xu52znGYuCkHMEecFs+q7yRcYOiM
g5grxgcjtnO89orSDnCPKCye2c7TwjyJFzBHnBbPqu8kU1slV6+ZmsRwHEIsgHTLsZKbmds51XLApCAJhHPkBqM1JqFzNrM1Kk9UGdyznbeTVxkfTC/qxHAcSiIh8gzEjpATbOQgAI8gHCjJQeYOMsBIAgHyDMSOkBeuQIAEE+QJiR0gPb3hg9rpzOdhxADIJ8gDAjpQeGx6LHl57MdhxADGatDBhmpPTI8UPSDtosyCcqciAJVncixwhyIAl2QESOEeRAEuyAiBwjyIEkaK0gxwhyoJ0t50jLJ7IeBdASQQ60w+pO5BxBDrTD6k7kHEEOtEOQI+cIcqAddkBEzhHkQJzVFWnLa6PPn/uO9LOvSe7ZjglowBJ9oJWXn5W+NKMz9c63r5WWX5J+53vSrrdnOjSgHhU50MrYpDS8VVo9FT1/9UXJhqUdl2Q7LqABQQ60MjQqXfgn0tDY2rEdl67tiAjkBEEOxNl7sySLPrcR6fXXZDocoBmCHIhz9jnS698jyaSRcemcK7MeEbABQQ60c/Ft0vAWaeUVafKdWY8G2IAgD1SlIpVK0tBQ9FipZD2iAbb7ndL4tDRxcRToQM4w/TBAlYo0NyedPBk9P3Ikei5xd6CeMJN+5b612StAzphnsLhhdnbWFxcX+/66g6JUisK70cyMdPhwv0cDoF/MbL+7zzYep7USoKNHOzsOYLAR5AGanu7sOIDBlkqQm9lHzMzNbDKN8yHe/Lw0Pr7+2Ph4dBxA8XQd5GZ2nqTflsQv9n1SLksLC1FP3Cx6XFjgQidQVGnMWvkbSbdJ+lIK50JC5TLBDSDSVUVuZtdLesbdf5jge+fMbNHMFpeWlrp5WQBAnbYVuZk9KOmcJl/aJ+ljitoqbbn7gqQFKZp+2MEYAQAx2ga5u7+72XEze4uk8yX90MwkaY+kR8zsHe7+P6mOEgDQ0qZ75O7+mKTX1p6b2WFJs+5+LIVxAQASYh45AAQukyX6ZrYkqcki856alBTibwuMu78Yd38x7s7MuPtU48FMgjwLZrbYbI+CvGPc/cW4+4txp4PWCgAEjiAHgMAVKcgXsh7AJjHu/mLc/cW4U1CYHjkADKoiVeQAMJAIcgAIXOGC3Mw+bGaHzOxxM/urrMfTidD2fTezT1T/rX9kZl80s51ZjymOmV1lZj8xsyfM7KNZjycJMzvPzL5lZgeqP9O3Zj2mTpjZsJn9wMy+kvVYkjKznWb2uerP9kEz++Wsx1SoIDezd0m6XtLb3P0SSX+d8ZASC3Tf9wckXerub5X0U0m3ZzyelsxsWNJnJL1H0sWS3mtmF2c7qkSWJX3E3S+W9EuS/jiQcdfcKulg1oPo0B2SvubuF0l6m3Iw/kIFuaSbJX3c3U9Jkrs/l/F4OlHb9z2Yq9Pu/g13X64+/a6ijdXy6h2SnnD3J939tKTPKnrTzzV3f9bdH6l+fkJRqJyb7aiSMbM9kq6RdFfWY0nKzHZI+nVJd0uSu592959nOigVL8gvkPRrZvawmX3bzC7PekBJdLLve459UNJXsx5EjHMlPVX3/GkFEog1ZlaSdJmkhzMeSlKfUlScrGY8jk6cL2lJ0j9WW0J3mdnWrAeVxh2CcqXN/ukjknYp+hX0ckn/YmZv8BzMwUxr3/d+ixu3u3+p+j37FLUAKv0cW5GY2TZJn5f0p+5+POvxtGNm10p6zt33m9lvZDycToxIerukD7v7w2Z2h6SPSvrzrAc1UFrtny5JZnazpC9Ug/t7ZraqaPObzG9ZFOq+73H/3pJkZh+QdK2kK/PwhhnjGUnn1T3fUz2We2Y2qijEK+7+hazHk9AVkq4zs6slbZE0YWb3ufuNGY+rnaclPe3utd96PqcoyDNVtNbKv0l6lySZ2QWSzlLOd15z98fc/bXuXnL3kqIfpLfnIcTbMbOrFP3qfJ27n8x6PG18X9JeMzvfzM6SdIOkL2c8prYsene/W9JBd/9k1uNJyt1vd/c91Z/pGyR9M4AQV/
X/3VNmdmH10JWSDmQ4JEkDWJG3cY+ke8zsx5JOS7op51Vi6D4taUzSA9XfJr7r7n+U7ZCac/dlM/uQpK9LGpZ0j7s/nvGwkrhC0vslPWZmj1aPfczd789uSAPvw5Iq1Tf8JyX9QcbjYYk+AISuaK0VABg4BDkABI4gB4DAEeQAEDiCHAACR5ADQOAIcgAI3P8D3Xy3/9Il9EgAAAAASUVORK5CYII=", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Iteration 16: w = [ 0.5881308 0.28098109 -3.55750999]\n" ] }, { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Iteration 17: w = [ 1.5881308 0.54163347 -2.97783009]\n" ] }, { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Iteration 18: w = [ 0.5881308 0.86105076 -3.42641778]\n" ] }, { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Iteration 19: w = [ 1.5881308 1.12170313 -2.84673788]\n" ] }, { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Iteration 19: w = [ 1.5881308 1.12170313 -2.84673788]\n" ] } ], "source": [ "X, y = generate_data(n=20, means=[[-1,-1], [1,2]], seed=204)\n", "train_perceptron_for_vis(X, y)" ] }, { "cell_type": "markdown", "id": "b047f58f", "metadata": {}, "source": [ "Note that the implementation uses matrix algebra to avoid looping over the data points. Namely, instead of computing $y_i \\langle w, x_i \\rangle$ for each data point separately, we can compute all of these values at once by noting that the design matrix is\n", "$$\n", "X = \\begin{bmatrix}\n", " x_1^T\\\\ x_2^T\\\\\\vdots\\\\x_n^T\n", "\\end{bmatrix} \\in \\mathbb{R}^{n \\times 3},\n", "$$\n", "(here $p = 3$: the dummy feature plus two coordinates), $y \\in \\mathbb{R}^n$ and $w \\in \\mathbb{R}^3$, so \n", "$$\n", "y \\odot Xw = \n", "y \\odot \\begin{bmatrix}\n", " \\langle x_1, w \\rangle\\\\ \\langle x_2, w \\rangle\\\\\\vdots\\\\ \\langle x_n, w \\rangle\n", "\\end{bmatrix}=\n", "\\begin{bmatrix}\n", " y_1 \\langle x_1, w \\rangle\\\\ y_2 \\langle x_2, w \\rangle\\\\\\vdots\\\\y_n \\langle x_n, w \\rangle\n", "\\end{bmatrix}\n", "$$\n", "\n", "where $\\odot$ denotes the Hadamard (pointwise) product, i.e. the elementwise multiplication of two vectors. The point $x_i$ is misclassified exactly when the $i$-th entry of this vector is non-positive (an entry equal to $0$ means $\\text{sgn}$ outputs $0$, which matches neither label), so we can find all mistakes at once by checking which entries satisfy $y_i \\langle w, x_i \\rangle \\le 0$." ] }, { "cell_type": "markdown", "id": "94028b80", "metadata": {}, "source": [ "## Binary Classification with Logistic Regression\n", "The perceptron is perhaps the simplest type of binary classifier. In practice, however, one of the most popular approaches to classification is the logistic regression model. In the tutorial this week, we will take a detailed look at the theory behind logistic regression. Here, we are more interested in how logistic regression is applied to real data problems. Specifically, we will look at:\n", "\n", "1. 
The `sklearn` logistic regression implementation\n", "2. Applying logistic regression to MNIST data\n", "3. Common performance metrics used when performing classification\n", "\n", "Further, we will look at:\n", "\n", "4. Softmax Regression: An extension of logistic regression (which on its own handles only binary classification) to multi-class classification\n", "\n", "The discussion of logistic regression here is important for two reasons. First, we will continue our exploration of the `sklearn` library. Second, logistic and softmax regression can be thought of as the simplest types of neural network, so a solid understanding of this material will be very helpful in later weeks." ] }, { "cell_type": "markdown", "id": "3eaae9f6", "metadata": {}, "source": [ "#### MNIST Dataset\n", "We first introduce one of the most famous datasets in machine learning: MNIST, often referred to as the `Hello, World!` of machine learning. The MNIST dataset consists of images of handwritten digits (0-9), so it has a total of $10$ classes. For logistic regression we will use only two of the 10 available labels (any pair will do). Each MNIST digit is a $28 \\times 28$ grayscale image, so we have $28 \\times 28 = 784$ features in total.\n", "\n", "We can load MNIST in a number of ways, but since we already have some experience with `PyTorch` from the previous lab, we will use it to load the data. `PyTorch` gives us an easy way to apply transformations to the data and to create a `DataLoader` object that makes training models more straightforward. We will discuss the `DataLoader` object in more detail later in the course, when we start working with neural nets." ] }, { "cell_type": "markdown", "id": "99718861", "metadata": {}, "source": [ "The following code loads the MNIST train and test sets. `ToTensor()` scales pixel values to $[0,1]$ and `Normalize((0.5,), (1.0,))` then subtracts $0.5$, so pixels lie in $[-0.5, 0.5]$. These transforms are applied only when a sample is indexed (e.g. `trainset[0]`); the `trainset.data` attribute still holds the raw `uint8` pixel values."
] }, { "cell_type": "code", "execution_count": 9, "id": "a1fb116e", "metadata": { "scrolled": true }, "outputs": [], "source": [ "import torchvision #! pip3 install torchvision\n", "import torchvision.transforms as transforms\n", "from torchvision.datasets import MNIST\n", "\n", "transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (1.0,))])\n", "trainset = MNIST(root = './', train=True, download=True, transform=transform)\n", "testset = MNIST(root = './', train=False, download=True, transform=transform)" ] }, { "cell_type": "markdown", "id": "bfd40f6e", "metadata": {}, "source": [ "Let's take a minute to understand these objects:" ] }, { "cell_type": "code", "execution_count": 10, "id": "c28a6b78", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Number of observations in train set: 60000\n", " Number of observations in test set: 10000\n" ] }, { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAPsAAAEICAYAAACZA4KlAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/YYfK9AAAACXBIWXMAAAsTAAALEwEAmpwYAAAReUlEQVR4nO3de9BU9X3H8fdHRKM8tl6oFK+otc4wGI0yTNpai7UmXqqYsbWgbdGmQNPo1Jle4oi3KtGMknTyj5cnxQYv9YoopmpU0kRTO4yAioAXqIJCuIQiAhMvD/LtH3seu8CzZx92z+7Zx9/nNfPMs7vfc/my8OGcPWfP+SkiMLPPvz3KbsDM2sNhN0uEw26WCIfdLBEOu1kiHHazRDjsVhhJIyXNl6R+TPtFSS+2oy+rcNgHIEkrJP1R2X304UZgemRf3pB0r6Q1kjZLekvSX/dOGBGLgE2Szi2r2dQ47NY0SXtKGg6cBjxWVboZGBERvwacB0yTdHJV/T5gStsaTZzDPsBIugc4AnhC0lZJ/yTpy5JelLRJ0quSxlZN/1NJN0r6L0lbJD0jaWhW+0K29f3fbN6XJA3LaodImiNpo6TlkiZVLfN6SY9k824GLgHOABZGxEe900XEkoj4uPdp9nNM1R/np8DpkvYu/p2ynTnsA0xE/AXwLnBuRHRR2Tr+BzANOBD4B2CWpN+omu0i4FLgYGCvbBqAicCvA4cDBwF/A3yY1R4AVgGHAH8C3CTpD6uWOQ54BNg/6+F44M2d+5V0m6RfAW8Aa4Anq/4sq4Ee4LjdfydsdznsA9+fA09GxJMRsT0ingXmA2dXTfNvEfFWRHwIPAScmL3eQyXkvxURn0bEgojYLOlw4PeAb0XERxHxCvCvwF9WLfO/I+KxbJ0fUgn9lp2bi4i/BfYDfh94FPh4p0m2ZPNaiznsA9+RwJ9mu+GbJG0CTgGGV02zturxr4Cu7PE9wI+BByT9QtItkgZT2ZpvjIjq8K4EDq16/t5OfbxPJdS7yP4j+TlwGPCNncr7AZvy/4hWBId9YKq+VP
E94J6I2L/qZ0hEfKfuQiJ6IuKfI2Ik8LvAH1PZev8COFBSdXiPAFbX6AFgEfDbdVa5J1Wf2SUdSuVjxS67/1Y8h31gWgccnT2+FzhX0lclDcoOuo2VdFi9hUg6TdLxkgYBm6ns1m+PiPeAF4Gbs+V9Efh6tq5angVOkvSFbNkHSxovqSvr66vABGBu1Tx/APyk6iCetZDDPjDdDFyd7bL/GZWDZVcBv6Sypf9H+vd3+5tUDrJtBl4HfkZl1x4qwRxBZSs/G7guIp6rtaCIWAf8JOsFKlv+b1A5yPc+MB24IiLmVM12MXBHP/q0Asg3r7CiSBoJzATGRJ1/WNnewp0R8Tttac4cdrNUeDfeLBEOu1kiHHazROzZzpVJ8gECsxaLiD4vMW5qyy7pTElvZhdKXNnMssystRo+Gp99EeMtKlc7rQJeAiZExNKcebxlN2uxVmzZxwDLI+LtiPiEylVS4+rMY2YlaSbsh7LjxRCr2PFCCQAkTc5uVTS/iXWZWZNafoAuIrqBbvBuvFmZmtmyr6Zy04Neh7HjVVFm1kGaCftLwLGSjpK0FzAemFNnHjMrScO78RGxTdJlVG5+MAi4KyKWFNaZmRWqrRfC+DO7Weu15Es1ZjZwOOxmiXDYzRLhsJslwmE3S4TDbpYIh90sEQ67WSIcdrNEOOxmiXDYzRLhsJslwmE3S4TDbpYIh90sEQ67WSIcdrNEOOxmiXDYzRLhsJslwmE3S4TDbpYIh90sEQ67WSIcdrNEOOxmiXDYzRLhsJslwmE3S0TDQzZb/40aNSq3ftRRR+XWzzrrrNz62rVra9Yef/zx3HnHjh2bWz/uuONy6/WMGzeuZm348OFNLVvqc7DSz+SNULxt27bceadOnZpbv/XWW3PrnaipsEtaAWwBPgW2RcToIpoys+IVsWU/LSI2FLAcM2shf2Y3S0SzYQ/gGUkLJE3uawJJkyXNlzS/yXWZWROa3Y0/JSJWSzoYeFbSGxHxfPUEEdENdANIqn3ExMxaqqkte0Sszn6vB2YDY4poysyK13DYJQ2RtF/vY+ArwOKiGjOzYinvXGTujNLRVLbmUPk48O8R8e068yS5G79+/frc+kEHHdSmTqy/tm7dmlufNm1abr3M8/AR0ecXEBr+zB4RbwMnNNyRmbWVT72ZJcJhN0uEw26WCIfdLBEOu1kifIlrAbq6unLrgwYNaun6P/roo5q1VatWtXTdzZg3b15u/eWXX86t77ln/j/fKVOm1KzVu6y43t/pAQcckFvvRN6ymyXCYTdLhMNulgiH3SwRDrtZIhx2s0Q47GaJ8Hn2Auy777659QsvvLCl69+0aVPN2oIFC1q67jJNmDAht17vXHqejRs35tbvvPPOhpddFm/ZzRLhsJslwmE3S4TDbpYIh90sEQ67WSIcdrNE+Dx7AerdKnru3Llt6mRgqXc9+qWXXppbnz59esPrrncL9e7u7tz6ypUrG153WbxlN0uEw26WCIfdLBEOu1kiHHazRDjsZolw2M0S4fPs1lKnnnpqzdrVV1+dO+/pp59edDufybunPMCMGTNatu6y1N2yS7pL0npJi6teO1DSs5KWZb8H3h3zzRLTn934HwJn7vTalcDciDgWmJs9N7MOVjfsEfE8sPM9esYBM7PHM4Hzi23LzIrW6Gf2YRGxJnu8FhhWa0JJk4HJDa7HzArS9AG6iAhJNa8qiIhuoBsgbzoza61GT72tkzQcIPudf9mXmZWu0bDPASZmjycCjxfTjpm1iupd1yvpfmAsMBRYB1wHPAY8BBwBrAQujIj8G23j3fhOVO+a8nrjlE+enH845uKLL65ZGzVqVO689dT7tztp0qSatXvvvTd33p6enoZ66gQRob5er/uZPSJq3Ym/dd94MLPC+euyZolw2M0S4bCbJcJhN0uEw26WCF/i+jl35pk7X8O0oyuuuCK3fsYZZxTYzY4+/vjj3PqDDz6YW58zZ05uffbs2bvd0+eZt+xmiX
DYzRLhsJslwmE3S4TDbpYIh90sEQ67WSJ8nn0AGDRoUG49b2jjG264IXfeYcNq3lGsEEuXLq1Zq9fbww8/XHQ7SfOW3SwRDrtZIhx2s0Q47GaJcNjNEuGwmyXCYTdLhM+zd4B6QxPffPPNufWTTz65yHZ28MYbb+TWr7322tz6rFmzimzHmuAtu1kiHHazRDjsZolw2M0S4bCbJcJhN0uEw26WiLpDNhe6skSHbB4yZEhu/YUXXsitn3DCCUW2U6iFCxfm1pctW1azdskll+TO+8knnzTSUvJqDdlcd8su6S5J6yUtrnrtekmrJb2S/ZxdZLNmVrz+7Mb/EOhrWJF/iYgTs58ni23LzIpWN+wR8TywsQ29mFkLNXOA7jJJi7Ld/ANqTSRpsqT5kuY3sS4za1KjYb8dOAY4EVgDfLfWhBHRHRGjI2J0g+syswI0FPaIWBcRn0bEduAHwJhi2zKzojUUdknDq55+DVhca1oz6wx1z7NLuh8YCwwF1gHXZc9PBAJYAUyJiDV1V5boefZRo0bl1l999dU2ddJZFixYkFsfM8Y7jI2odZ697s0rImJCHy/PaLojM2srf13WLBEOu1kiHHazRDjsZolw2M0S4VtJt8HWrVtz6xMnTsytH3nkkbn1/fffv2btiSeeyJ23Wc8991xuPW+46Z6enqLbsRzespslwmE3S4TDbpYIh90sEQ67WSIcdrNEOOxmifB59jZYsWJFU/VW2nvvvXPrp512Wm59jz0a3150dXU1PK/tPm/ZzRLhsJslwmE3S4TDbpYIh90sEQ67WSIcdrNEJHOe/fLLL8+tL1++PLf+1FNPFdlO29Q7D37ZZZfl1m+55ZYi29nBpk2bWrZs25W37GaJcNjNEuGwmyXCYTdLhMNulgiH3SwRDrtZIuqeZ5d0OHA3MIzKEM3dEfF9SQcCDwIjqAzbfGFEvN+6Vptz0kkn5dZHjhyZWy/zPHu9c+XHH398zdrUqVNz573gggsa6qm/Pvjgg5q1Mq/jT1F/tuzbgL+PiJHAl4FvShoJXAnMjYhjgbnZczPrUHXDHhFrImJh9ngL8DpwKDAOmJlNNhM4v0U9mlkBduszu6QRwJeAecCwiFiTldZS2c03sw7V7+/GS+oCZgFXRMRmSZ/VIiIkRY35JgOTm23UzJrTry27pMFUgn5fRDyavbxO0vCsPhxY39e8EdEdEaMjYnQRDZtZY+qGXZVN+Azg9Yj4XlVpDtA7/OhE4PHi2zOzoiiiz73v/59AOgV4AXgN2J69fBWVz+0PAUcAK6mcettYZ1n5K2uhp59+Orc+ePDg3PojjzxSs3b77bc31FOvoUOH5tYnTZqUW582bVpT62/GvHnzcuvnnXdezdqGDRuKbseAiFBfr9f9zB4RPwf6nBk4vZmmzKx9/A06s0Q47GaJcNjNEuGwmyXCYTdLhMNuloi659kLXVmJ59lvu+223PqUKVNy6z09PTVr7777bkM99dpnn31y64ccckhTy8+zffv23Po111yTW7/jjjty675ddPvVOs/uLbtZIhx2s0Q47GaJcNjNEuGwmyXCYTdLhMNulohkhmx+5513cutbt27NrXd1ddWsHXPMMQ31VJRly5bVrNX7fkHevDBwh6q2XXnLbpYIh90sEQ67WSIcdrNEOOxmiXDYzRLhsJslIpnr2es555xzcuvjx4+vWbvooouaWne9e9rfdNNNufUlS5bUrPl68vT4enazxDnsZolw2M0S4bCbJcJhN0uEw26WCIfdLBH9GZ/9cOBuYBgQQHdEfF/S9cAk4JfZpFdFxJN1ltWx59nNPi9qnWfvT9iHA8MjYqGk/YAFwPnAhcDWiJje3yYcdrPWqxX2uneqiYg1wJrs8RZJrwOHFtuembXabn1mlzQC+BIwL3vpMkmLJN0l6YAa80yWNF/S/OZaNbNm9Pu78ZK6gJ8B346IRyUNAzZQ+Rx/I5Vd/b+qswzvxpu1WMOf2QEkDQZ+BPw4Ir
7XR30E8KOIGFVnOQ67WYs1fCGMJAEzgNerg54duOv1NWBxs02aWev052j8KcALwGtA7/i+VwETgBOp7MavAKZkB/PyluUtu1mLNbUbXxSH3az1fD27WeIcdrNEOOxmiXDYzRLhsJslwmE3S4TDbpYIh90sEQ67WSIcdrNEOOxmiXDYzRLhsJslwmE3S0TdG04WbAOwsur50Oy1TtSpvXVqX+DeGlVkb0fWKrT1evZdVi7Nj4jRpTWQo1N769S+wL01ql29eTfeLBEOu1kiyg57d8nrz9OpvXVqX+DeGtWW3kr9zG5m7VP2lt3M2sRhN0tEKWGXdKakNyUtl3RlGT3UImmFpNckvVL2+HTZGHrrJS2ueu1ASc9KWpb97nOMvZJ6u17S6uy9e0XS2SX1drik/5S0VNISSX+XvV7qe5fTV1vet7Z/Zpc0CHgLOANYBbwETIiIpW1tpAZJK4DREVH6FzAknQpsBe7uHVpL0i3Axoj4TvYf5QER8a0O6e16dnMY7xb1VmuY8Uso8b0rcvjzRpSxZR8DLI+ItyPiE+ABYFwJfXS8iHge2LjTy+OAmdnjmVT+sbRdjd46QkSsiYiF2eMtQO8w46W+dzl9tUUZYT8UeK/q+So6a7z3AJ6RtEDS5LKb6cOwqmG21gLDymymD3WH8W6nnYYZ75j3rpHhz5vlA3S7OiUiTgLOAr6Z7a52pKh8Buukc6e3A8dQGQNwDfDdMpvJhhmfBVwREZura2W+d3301Zb3rYywrwYOr3p+WPZaR4iI1dnv9cBsKh87Osm63hF0s9/rS+7nMxGxLiI+jYjtwA8o8b3LhhmfBdwXEY9mL5f+3vXVV7vetzLC/hJwrKSjJO0FjAfmlNDHLiQNyQ6cIGkI8BU6byjqOcDE7PFE4PESe9lBpwzjXWuYcUp+70of/jwi2v4DnE3liPz/AFPL6KFGX0cDr2Y/S8ruDbifym5dD5VjG18HDgLmAsuA54ADO6i3e6gM7b2ISrCGl9TbKVR20RcBr2Q/Z5f93uX01Zb3zV+XNUuED9CZJcJhN0uEw26WCIfdLBEOu1kiHHazRDjsZon4P6amj+TmJh8JAAAAAElFTkSuQmCC", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "# size of datasets\n", "print(f\"Number of observations in train set: {trainset.data.shape[0]}\")\n", "print(f\" Number of observations in test set: {testset.data.shape[0]}\")\n", "\n", "# plotting an image in the dataset using imshow()\n", "idx = 2489\n", "plt.imshow(trainset.data[idx], cmap='gray')\n", "plt.title(trainset.targets[idx])\n", "plt.show()" ] }, { "cell_type": "markdown", "id": "ce2ab6e8", "metadata": {}, "source": [ "### Logistic Regression with Sklearn\n", "\n", "In the tutorial this week, we define the logistic regression problem as the problem of minimizing the cross-entropy loss:\n", "\n", "\\begin{align*}\n", "\\hat{\\beta}, \\hat{\\beta}_0 &= \\arg \\min_{\\beta, \\beta_0} \\mathcal{L}(\\beta, \\beta_0)\\\\\n", "&= \\arg \\min_{\\beta \\in \\mathbb{R}^p,\\, \\beta_0 \\in \\mathbb{R}} \n", "-\\left [ \n", " \\sum_{i=1}^n y_i \\ln \\left (\\frac{1}{1+e^{-\\beta^T x_i - \\beta_0}} \\right) + (1-y_i) \\ln \\left (1-\\frac{1}{1+e^{-\\beta^T x_i - \\beta_0}} \\right)\n", "\\right]\n", "\\end{align*}\n", "\n", "where $y_i \\in \\{0,1\\}$ is the binary label associated with the input $x_i$. In practice, however, we usually work with the following regularized version:\n", "\n", "\\begin{align*}\n", "\\hat{\\beta}, \\hat{\\beta}_0 &= \\arg \\min_{\\beta, \\beta_0} C\\mathcal{L}(\\beta, \\beta_0) + \\text{Penalty}(\\beta),\n", "\\end{align*}\n", "\n", "where $\\text{Penalty}(\\beta)$ is a penalty (regularization) term applied to the weight vector $\\beta$. For example, $\\text{Penalty}(\\beta) = \\|\\beta\\|_1$ when we want to apply $\\ell_1$ regularization to $\\beta$ (this would be the logistic version of the LASSO). Note that we do not usually penalize the intercept (bias) term $\\beta_0$. Note also the hyperparameter $C$, which controls the trade-off between fitting the data and regularization, as $\\lambda$ does in the LASSO and ridge regression formulations."
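] }, { "cell_type": "markdown", "id": "e5c2a7d1", "metadata": {}, "source": [ "The regularized objective above can be evaluated directly. Below is a minimal NumPy sketch (not the tutorial's or `sklearn`'s implementation) of $C\\mathcal{L}(\\beta, \\beta_0) + \\|\\beta\\|_1$, assuming the $\\ell_1$ penalty. The function names `cross_entropy_loss` and `penalized_objective` and the toy data are hypothetical, chosen for illustration. Writing $\\sigma(z) = 1/(1+e^{-z})$, it uses $-\\ln \\sigma(z) = \\ln(1+e^{-z})$ and $-\\ln(1-\\sigma(z)) = \\ln(1+e^{z})$, and evaluates $\\ln(1+e^{t})$ as `np.logaddexp(0, t)` to avoid overflow for large $|t|$." ] }, { "cell_type": "code", "execution_count": null, "id": "f0b94c38", "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "\n", "# Hypothetical helper (not from the tutorial or sklearn): unpenalized loss L(beta, beta0), labels y in {0, 1}.\n", "# np.logaddexp(0, t) = ln(1 + e^t), computed without overflow for large |t|.\n", "def cross_entropy_loss(beta, beta0, X, y):\n", "    z = X @ beta + beta0  # scores beta^T x_i + beta0, shape (n,)\n", "    return np.sum(y * np.logaddexp(0.0, -z) + (1 - y) * np.logaddexp(0.0, z))\n", "\n", "\n", "# Hypothetical helper: C * L(beta, beta0) + ||beta||_1 (l1 penalty; beta0 is not penalized).\n", "def penalized_objective(beta, beta0, X, y, C):\n", "    return C * cross_entropy_loss(beta, beta0, X, y) + np.sum(np.abs(beta))\n", "\n", "\n", "# Made-up toy data: 6 points in R^2 with binary labels.\n", "rng = np.random.default_rng(0)\n", "X_toy = rng.normal(size=(6, 2))\n", "y_toy = np.array([0, 1, 0, 1, 1, 0])\n", "\n", "# At beta = 0, beta0 = 0 every predicted probability is 1/2, so L = n ln 2.\n", "print(cross_entropy_loss(np.zeros(2), 0.0, X_toy, y_toy), 6 * np.log(2))\n", "\n", "# Larger C weights the data-fit term more heavily relative to the penalty ||beta||_1 = 2.\n", "beta = np.array([1.5, -0.5])\n", "for C in [0.01, 1.0, 100.0]:\n", "    print(C, penalized_objective(beta, 0.0, X_toy, y_toy, C))"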
] }, { "cell_type": "markdown", "id": "bde6e2a6", "metadata": {}, "source": [ "\n", " \n", "#### Exercise\n", "\n", "Consider the `sklearn` logistic regression implementation (Section 1.1.11), which claims to minimize the following objective:\n", " \n", "\\begin{align*}\n", " \\hat{w}, \\hat{c} = \\arg \\min_{w, c} \\left[\\|w\\|_1 + C \\sum_{i=1}^n \\log (1+ \\exp (-\\tilde{y}_i(w^T x_i + c))) \\right ].\n", "\\end{align*}\n", "\n", "It turns out that this objective is identical to our objective above, but only after re-coding the labels to take values in $\\{-1,1\\}$ instead of the binary values $\\{0,1\\}$. That is, $\\tilde{y}_i \\in \\{-1,1\\}$, whereas $y_i \\in \\{0,1 \\}$. Argue rigorously that the two objectives are identical, in that they give us the same solutions ($\\hat{\\beta}_0 = \\hat{c}$ and $\\hat{\\beta} = \\hat{w}$). Further, describe the role of $C$ in the objectives. How does it compare to the standard LASSO parameter $\\lambda$?" ] }, { "cell_type": "markdown", "id": "b396693d", "metadata": {}, "source": [ "\n", " \n", "For simplicity, we can focus on a single $i$. Note that the $i$-th summand in the first objective (absorbing the leading minus sign) is equivalent to \n", "\\begin{align*}\n", " &y_i \\ln (1+\\exp(-\\beta^T x_i - \\beta_0))+ (1-y_i) \n", " \\ln \\left ( 1+\\exp(\\beta^T x_i + \\beta_0) \\right)\\\\\n", " &= \\begin{cases}\n", " \\ln \\left ( 1+\\exp(-1 \\times (\\beta^T x_i + \\beta_0)) \\right) \\qquad &\\text{if} \\quad y_i=1\\\\\n", " \\ln \\left ( 1+\\exp(1 \\times (\\beta^T x_i + \\beta_0)) \\right) \\qquad &\\text{if} \\quad y_i=0\n", " \\end{cases}\\\\\n", " &= \\begin{cases}\n", " \\ln \\left ( 1+\\exp(-\\tilde{y}_i \\times (\\beta^T x_i + \\beta_0)) \\right) \\qquad &\\text{if} \\quad \\tilde{y}_i=1\\\\\n", " \\ln \\left ( 1+\\exp(-\\tilde{y}_i \\times (\\beta^T x_i + \\beta_0)) \\right) \\qquad &\\text{if} \\quad \\tilde{y}_i=-1\n", " \\end{cases}\\\\\n", " &= \\ln \\left ( 1+\\exp(-\\tilde{y}_i(\\beta^T x_i + \\beta_0)) \\right).\n", "\\end{align*}\n", "Summing over $i$, the two losses agree term by term. With $\\text{Penalty}(\\beta) = \\|\\beta\\|_1$, the two regularized objectives are therefore the same function of $(\\beta, \\beta_0) = (w, c)$, so they have the same minimizers.\n", "\n", " $C$ attaches
higher importance to the \"fit\" term: as $C$ increases, we care more about fitting the data than about the penalty. Dividing the objective by $C$ shows that $C$ plays the inverse role of $\\lambda$ in the standard LASSO, i.e. $C \\propto 1/\\lambda$. This is a standard trick for rewriting loss functions; see, for example, the SVM objective from lectures." ] }, { "cell_type": "markdown", "id": "c9920eb3", "metadata": {}, "source": [ "We next create train, validation and test sets to run logistic regression on. These will consist only of the images from two MNIST classes, `class1Label` and `class2Label` (chosen in the code below). Logistic regression expects a vector input, so we treat each $28 \\times 28$ image as a flattened 784-dimensional vector. In practice we would want to account for the spatial structure of the images, but here we ignore it. \n", "\n", "We will take our train set to be 60\\% of the available images, and the validation and test sets will be 20\\% each. We will also convert the tensor datasets to numpy arrays.\n", "\n", "To make the problem more challenging for the model, we will work with a smaller sample of `nSample` images per class, and we will also randomly choose a subset of $s$ features (pixels) and omit the rest."
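, "

Returning briefly to the label-recoding exercise above, the per-observation identity can also be checked numerically. In this minimal sketch the scores `z` stand in for the linear scores `beta^T x_i + beta0` and, like the labels, are drawn at random (an illustrative assumption); `y_tilde = 2 * y - 1` maps `{0, 1}` to `{-1, +1}`.

```python
import numpy as np

rng = np.random.default_rng(1)

# Random linear scores z_i = beta^T x_i + beta0 (assumed for illustration)
z = rng.normal(scale=3.0, size=1000)
y = rng.integers(0, 2, size=1000)   # labels coded in {0, 1}
y_tilde = 2 * y - 1                 # the same labels recoded to {-1, +1}

p = 1.0 / (1.0 + np.exp(-z))        # sigmoid(z_i)

# i-th summand of the {0,1} cross-entropy objective (with the minus sign)
loss_01 = -(y * np.log(p) + (1 - y) * np.log(1 - p))

# i-th summand of the {-1,+1} objective: log(1 + exp(-y_tilde_i * z_i))
loss_pm = np.log1p(np.exp(-y_tilde * z))

print(np.allclose(loss_01, loss_pm))  # True
```
"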
] }, { "cell_type": "code", "execution_count": 11, "id": "f08c1ff1", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Xtrain shape = (1200, 40)\n", "Xvalid shape = (400, 40)\n", "Xtest shape = (400, 40)\n" ] } ], "source": [ "from sklearn.model_selection import train_test_split\n", "\n", "# sample size and features to work with\n", "s = 40 # reduced features to work with \n", "sFeatures = np.random.choice(np.arange(784), size=s, replace=False) # choose s features randomly from the 784\n", "sFeatures.sort()\n", "nSample = 1000\n", "\n", "\n", "# choose two class labels\n", "class1Label = 0\n", "class2Label = 1\n", "\n", "class1Images = trainset.data[trainset.targets==class1Label].reshape(-1,784).numpy() # images with class1\n", "class2Images = trainset.data[trainset.targets==class2Label].reshape(-1,784).numpy() # images with class2\n", "\n", "# keep the first nSample images of each class and only the s selected pixels\n", "class1Images = class1Images[:nSample, sFeatures]\n", "class2Images = class2Images[:nSample, sFeatures]\n", "X = np.concatenate((class1Images, class2Images), axis=0)\n", "y = np.concatenate((np.zeros(class1Images.shape[0]), np.ones(class2Images.shape[0])))\n", "\n", "# create Xtrain, Xvalid, Xtest: 60% train, then split the remaining 40% equally\n", "Xtrain, X_, ytrain, y_ = train_test_split(X, y, test_size=0.4, shuffle=True)\n", "Xvalid, Xtest, yvalid, ytest = train_test_split(X_, y_, test_size=0.5, shuffle=True)\n", "\n", "print(f'Xtrain shape = {Xtrain.shape}')\n", "print(f'Xvalid shape = {Xvalid.shape}')\n", "print(f'Xtest shape = {Xtest.shape}')" ] }, { "cell_type": "markdown", "id": "dbd8fb05", "metadata": {}, "source": [ "\n", " \n", "#### Exercise\n", "\n", "In order to choose a good value of the hyper-parameter $C$, we will use K-fold cross-validation. As always, we use the validation set for any form of hyper-parameter tuning.\n", " \n", "Create a grid of 100 $C$ values ranging from $C=0.0001$ to $C=0.2$ in equally sized increments. 
Use `sklearn.model_selection.GridSearchCV` to perform grid-search cross-validation over your parameter grid. For the `estimator` argument, use the `sklearn.linear_model.LogisticRegression` model with `l1` penalty and `liblinear` solver, and find the optimal value of $C$ when performing 10-fold cross-validation with `neg_log_loss` scoring. Fit your logistic model with the chosen $C$ value. Report train and test accuracy.\n", " \n", "Further, explain why it is a bad idea to tune hyper-parameters using the test set." ] }, { "cell_type": "markdown", "id": "4da28c82", "metadata": {}, "source": [ "\n", "\n", "The test set is an independent, held-out dataset on which we evaluate our final model. If we used it to tune the hyperparameters, the same data would be used both to select the model and to evaluate it, so the reported test performance would be optimistically biased: we would effectively be overfitting to the test set." ] }, { "cell_type": "code", "execution_count": 12, "id": "74a728d9", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Best C according to gridsearch: 0.002119191919191919\n", "Train Accuracy: 0.505\n", "Test Accuracy: 0.475\n" ] } ], "source": [ "#### Solution\n", "from sklearn.linear_model import LogisticRegression\n", "from sklearn.model_selection import GridSearchCV\n", "from sklearn.metrics import accuracy_score\n", "\n", "C_grid = np.linspace(0.0001, 0.2, 100)\n", "# tune C by 10-fold cross-validation on the validation set\n", "grid_lr = GridSearchCV(estimator=LogisticRegression(penalty='l1', solver='liblinear'),\n", " cv=10, \n", " param_grid={'C': C_grid}, \n", " scoring='neg_log_loss')\n", "grid_lr.fit(Xvalid, yvalid)\n", "Cbest = grid_lr.best_params_['C']\n", "print(f'Best C according to gridsearch: {Cbest}')\n", "\n", "# refit on the training set with the selected C; accuracy_score expects (y_true, y_pred)\n", "logistic_mod = LogisticRegression(penalty='l1', solver='liblinear', C=Cbest).fit(Xtrain, ytrain)\n", "print(f'Train Accuracy: {accuracy_score(ytrain, logistic_mod.predict(Xtrain))}')\n", "print(f'Test Accuracy: {accuracy_score(ytest, logistic_mod.predict(Xtest))}')" ] }, { "cell_type": "markdown", "id": 
"44df30fc", "metadata": {}, "source": [ "#### An important distinction: predict_proba vs. predict\n", "Many classification models in `sklearn` (not just logistic regression) have two methods that are easily confused, so let's make sure we understand what each one does and when to use it. In what follows, assume we called our logistic model `logistic_mod`. Recall that logistic regression is a type of regression: it outputs the probability that an input belongs to a particular class (the class coded as $y=1$). To turn this into a classifier, we use a threshold value $t$ (known as the discrimination threshold) and classify any input whose predicted probability is $t$ or greater as belonging to class $y=1$, and as $y=0$ otherwise. The choice of $t$ is up to the modeller, and the default used by `predict()` is $t=0.5$. \n", "\n", "With this in mind, `logistic_mod.predict_proba()` returns the predicted probability of each class for each input, whereas `logistic_mod.predict()` returns the predicted class. If we are solely interested in model accuracy, then two different models that produce the same predicted labels are of equal value to us. However, accuracy can be a crude way of comparing models. \n", "\n", "For example, consider two models: m1 and m2. Let's say we have inputs `[x1,x2,x3,x4]` with labels $[1, 0, 0, 0]$ and the following holds:\n", "\n", "```\n", "m1.predict([x1,x2,x3,x4]) = [1,0,0,0]\n", "m2.predict([x1,x2,x3,x4]) = [1,0,0,0]\n", "```\n", "\n", "So both models have 100\\% accuracy. 
However, let's say we also get:\n", "\n", "```\n", "m1.predict_proba([x1,x2,x3,x4]) = [[0.1, 0.9],\n", " [0.99, 0.01],\n", " [0.9, 0.1],\n", " [0.92, 0.08]]\n", "m2.predict_proba([x1,x2,x3,x4]) = [[0.49, 0.51],\n", " [0.51, 0.49],\n", " [0.52, 0.48],\n", " [0.51, 0.49]]\n", "```\n", "The `predict_proba` method returns the predicted probability of each class (one column per class, here class $0$ then class $1$), which is why there are two numbers for each `xi`.\n", "\n", "Now, m1 clearly discriminates between the positive and negative classes much more effectively than m2, but the two models are equivalent in terms of accuracy. If we are more interested in the ability of a model to discriminate between two classes, we should consider other metrics, such as the cross-entropy loss (see tutorial this week), also called the log-loss, which is implemented in `sklearn.metrics.log_loss`. From tutorials, we know that the cross-entropy measures the discrepancy between the predicted probabilities and the true labels, and is defined by\n", "\n", "$$\n", "\\mathcal{L}(w) = -\\sum_{i=1}^n (y_i \\ln \\hat{p}_i + (1-y_i) \\ln(1-\\hat{p}_i)).\n", "$$\n", "\n", "In our toy example (note that `log_loss` returns the average $\\mathcal{L}(w)/n$ by default, rather than the sum),\n", "\n", "```\n", "ll_m1 = log_loss([1,0,0,0], m1.predict_proba([x1,x2,x3,x4])) = 0.0760\n", "ll_m2 = log_loss([1,0,0,0], m2.predict_proba([x1,x2,x3,x4])) = 0.6685\n", "```\n", "\n", "So the log-loss of m1 is much lower than that of m2, and we would prefer m1 if our criterion is log-loss rather than accuracy." ] }, { "cell_type": "markdown", "id": "3227045b", "metadata": {}, "source": [ "#### The Trade-Off between TPR (Sensitivity) and TNR (Specificity)\n", "A binary classifier can have one of four possible outcomes:\n", "1. True Positive (TP): classifier predicts $\\hat{y}=1$ when true label is $y=1$\n", "2. False Positive (FP): classifier predicts $\\hat{y}=1$ when true label is $y=0$\n", "3. True Negative (TN): classifier predicts $\\hat{y}=0$ when true label is $y=0$\n", "4. 
False Negative (FN): classifier predicts $\\hat{y}=0$ when true label is $y=1$\n", "\n", "True positives and true negatives are correct decisions, whereas false positives and false negatives are errors. Two good examples to keep in mind are:\n", "\n", "Example 1 (CANCER DETECTION): You build a classifier that looks at the output of some costly medical test and attempts to predict whether a patient has a rare form of cancer and needs further testing ($y=1$) or has no cancer ($y=0$). The goal here is to save as many lives as possible.\n", "\n", "Example 2 (MARKETING CAMPAIGN): You work for a bank and wish to identify new customers to whom to send ads about a new credit card. The classifier either identifies a potential customer, who is then sent a targeted ad via email ($y=1$), or labels them as not a potential customer ($y=0$). The goal here is to increase the revenue stream of the bank." ] }, { "cell_type": "markdown", "id": "25157d0a", "metadata": {}, "source": [ "\n", " \n", "#### Exercise\n", "Discuss Examples 1 and 2: are false positives and false negatives equally bad in both cases?" ] }, { "cell_type": "markdown", "id": "53368dae", "metadata": {}, "source": [ "\n", " \n", "In cancer detection, a false positive means that we perform further tests on someone who does not have cancer. A false negative means that we send a patient suffering from cancer home. Clearly a false negative is much more costly (in terms of saving lives) than a false positive.\n", "\n", "In marketing, a false positive means we send an email ad to someone who is not interested; most likely they delete the email without much thought. A false negative means we miss the chance to gain a new customer who would grow our business. A false negative is arguably much worse in this scenario. 
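
One way to make this asymmetry concrete is to attach a cost to each type of error and compare classifiers by total cost instead of accuracy. The minimal sketch below uses made-up labels and hypothetical costs (`COST_FN`, `COST_FP` and the helper `total_cost` are illustrative assumptions, not values or functions from the lab), with a missed cancer case weighted 100 times more heavily than an unnecessary follow-up test.

```python
import numpy as np
from sklearn.metrics import accuracy_score, confusion_matrix

# Hypothetical costs (assumptions): a false negative (missed cancer case)
# is taken to be 100 times worse than a false positive (extra test).
COST_FN = 100.0
COST_FP = 1.0


def total_cost(y_true, y_pred):
    # For binary labels [0, 1], ravel() returns tn, fp, fn, tp in that order
    tn, fp, fn, tp = confusion_matrix(y_true, y_pred, labels=[0, 1]).ravel()
    return COST_FN * fn + COST_FP * fp


# Made-up labels: 2 patients with cancer (y=1), 8 without (y=0)
y_true = np.array([1, 1, 0, 0, 0, 0, 0, 0, 0, 0])
pred_a = np.array([1, 0, 0, 0, 0, 0, 0, 0, 0, 0])  # one false negative
pred_b = np.array([1, 1, 1, 0, 0, 0, 0, 0, 0, 0])  # one false positive

for name, pred in [('a', pred_a), ('b', pred_b)]:
    print(name, accuracy_score(y_true, pred), total_cost(y_true, pred))
```

Both predictors have accuracy 0.9, yet under these assumed costs predictor `a` costs 100 units and predictor `b` only 1.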
" ] }, { "cell_type": "markdown", "id": "a543337b", "metadata": {}, "source": [ "We define the following terms, each of which (somewhat annoyingly) has many names in the literature:\n", "\n", "- True Positive Rate (TPR), Sensitivity, Recall: $\\text{TPR} = \\frac{\\text{TP}}{\\text{P}} = \\frac{\\text{TP}}{\\text{TP} + \\text{FN}}$ where $\\text{P}$ is the actual number of positive examples in the data.\n", "- True Negative Rate (TNR), Specificity, Selectivity: $\\text{TNR} = \\frac{\\text{TN}}{\\text{N}} = \\frac{\\text{TN}}{\\text{TN} + \\text{FP}}$ where $\\text{N}$ is the actual number of negative examples in the data.\n", "\n", "Any good classifier should ideally have a large TPR and a large TNR, but unfortunately there is a trade-off between the two. To see this, consider the following exercise." ] }, { "cell_type": "markdown", "id": "f406db0c", "metadata": {}, "source": [ "\n", " \n", "#### Exercise\n", "Consider fitting logistic regression models to the cancer detection problem discussed earlier. Once you fit your model, you need to choose the discrimination threshold $t$. Explain what happens at the extremes $t=0$ and $t=1$. Explain the trade-off between TPR and TNR as $t$ increases." ] }, { "cell_type": "markdown", "id": "6bbc7289", "metadata": {}, "source": [ "\n", "\n", "At $t=0$ we classify everything as $\\hat{y}=1$, i.e. we classify everyone as having cancer. This gives a TPR of 100%, but also a TNR of 0%. At the other extreme, $t=1$, essentially everything is classified as $\\hat{y}=0$, giving a TPR of (close to) 0% and a TNR of (close to) 100%. As you increase $t$, fewer and fewer instances are classified as positive: some true positives become false negatives, so the TPR cannot increase. At the same time, some false positives become true negatives, so the TNR cannot decrease. Choosing $t$ therefore means trading sensitivity for specificity." ] }, { "cell_type": "markdown", "id": "2338cb1d", "metadata": {}, "source": [ "The Receiver Operating Characteristic (ROC) curve plots the TPR against the FPR $= 1-\\text{TNR}$ as we vary the threshold $t$. 
It is a good way of comparing different classification models. Let's consider a few logistic regression models fit to the binary MNIST problem with different $C$ values, and compare them by looking at their ROC curves. The Area Under the Curve (AUC), i.e. the area under a model's ROC curve, gives a single-number summary of how good that model is." ] }, { "cell_type": "code", "execution_count": 13, "id": "46995bfe", "metadata": {}, "outputs": [ { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "from sklearn.metrics import roc_curve, roc_auc_score\n", "\n", "# model 1: C chosen by cross-validation; compute FPR and TPR at all thresholds\n", "logistic_mod_1 = LogisticRegression(penalty='l1', solver='liblinear', C=Cbest).fit(Xtrain, ytrain)\n", "mod_1_preds = logistic_mod_1.predict_proba(Xtest)[:,1] # roc_curve needs scores for the positive class only (column 1)\n", "fpr_1, tpr_1, _ = roc_curve(ytest, mod_1_preds)\n", "auc_1 = roc_auc_score(ytest, mod_1_preds)\n", "\n", "# model 2: a fixed, hand-picked C = 0.02\n", "logistic_mod_2 = LogisticRegression(penalty='l1', solver='liblinear', C=0.02).fit(Xtrain, ytrain)\n", "mod_2_preds = logistic_mod_2.predict_proba(Xtest)[:,1]\n", "fpr_2, tpr_2, _ = roc_curve(ytest, mod_2_preds)\n", "auc_2 = roc_auc_score(ytest, mod_2_preds)\n", "\n", "# model 3: very large C (effectively no regularization); note this uses the default l2 penalty\n", "logistic_mod_3 = LogisticRegression(solver='liblinear', C=1000).fit(Xtrain, ytrain)\n", "mod_3_preds = logistic_mod_3.predict_proba(Xtest)[:,1]\n", "fpr_3, tpr_3, _ = roc_curve(ytest, mod_3_preds)\n", "auc_3 = roc_auc_score(ytest, mod_3_preds)\n", "\n", "\n", "fig = plt.figure(figsize=(10,10))\n", "plt.plot(fpr_1, tpr_1, label=\"model 1, AUC=\"+str(auc_1))\n", "plt.plot(fpr_2, tpr_2, label=\"model 2, AUC=\"+str(auc_2))\n", "plt.plot(fpr_3, tpr_3, label=\"model 3, AUC=\"+str(auc_3))\n", "plt.ylabel(\"True Positive Rate (TPR)\")\n", "plt.xlabel(\"False Positive Rate (FPR)\")\n", "plt.title(\"ROC Curves for Three Logistic Models\")\n", "plt.legend()\n", "plt.show()\n" ] }, { "cell_type": "markdown", "id": "f6628ac5", "metadata": {}, "source": [ "Some things to note:\n", "\n", "1. The ideal model would have a TPR and a TNR of 1, or equivalently a TPR of 1 and an FPR of 0, so the closer a model's curve is to the top-left corner of the plot, the better.\n", "2. All three models seem to be quite poor at this task. \n", "3. 
At a given threshold, there is no guarantee that one model is better than the other: ROC curves can cross, so which model is preferable may depend on the threshold.\n", "4. The AUC provides a single summary score that we can use to compare the models. You can think of the AUC as the TPR averaged over all FPR values, or equivalently as the probability that the model gives a randomly chosen positive example a higher score than a randomly chosen negative one.\n", "\n", "Extra resource: explore ROC curves interactively here: https://kennis-research.shinyapps.io/ROC-Curves/. " ] }, { "cell_type": "markdown", "id": "b4d6ee5e", "metadata": {}, "source": [ "## Multi-class classification - Softmax Regression" ] }, { "cell_type": "markdown", "id": "c2f07d6b", "metadata": {}, "source": [ "In softmax regression, we allow the label to take on one of $K$ possible values, so given an input $x$, we would like to estimate $P(y=k|x)$ for $k=1,\\dots, K$. The model must therefore output a $K$-dimensional vector of probabilities whose elements sum to $1$, i.e. \n", "\n", "\\begin{align*}\n", "h(x) = \n", "\\begin{bmatrix}\n", "P(y=1|x) \\\\ P(y=2|x)\\\\ \\vdots \\\\ P(y=K|x)\n", "\\end{bmatrix}\n", "\\end{align*}\n", "\n", "We will not go into the mathematical details of how softmax regression is derived (for those who are interested, it is similar to the logistic regression case, but uses the multinomial distribution instead of the Bernoulli). We will, however, give a short motivation for how the probabilities are estimated." ] }, { "cell_type": "markdown", "id": "6c0e9941", "metadata": {}, "source": [ "#### Short motivation for softmax regression \n", "Recall (from lectures and tutorial this week) that in the logistic regression problem (binary $y$), we wish to estimate a probability: $P(y=1|x)$. One approach is to use a linear model:\n", "\n", "$$\n", "P(y=1|x) = w^T x,\n", "$$\n", "\n", "where $w$ is a parameter vector that we estimate. The problem with this, however, is that a probability must lie in the range $[0,1]$, whereas $w^T x \\in (-\\infty, \\infty)$. 
A solution is to wrap the linear model in a logistic sigmoid:\n", "\n", "$$\n", "P(y=1|x) = \\sigma(w^T x) = \\frac{1}{1+e^{-w^T x}},\n", "$$\n", "\n", "which guarantees that our estimated probability is actually a probability, regardless of the choice of $w$. Let's extend this idea to the multi-class case: assume that for each class we have a $p$-dimensional vector $\\theta_1 \\in \\mathbb{R}^p, \\theta_2 \\in \\mathbb{R}^p, \\dots, \\theta_K \\in \\mathbb{R}^p$, so we have a total of $Kp$ unknown parameters to be estimated, and we will stack these unknown vectors into a single matrix for brevity:\n", "\n", "$$\n", "\\Theta = [\\theta_1,\\dots, \\theta_K] \\in \\mathbb{R}^{p \\times K}.\n", "$$\n", "\n", "Now, we want to estimate $K$ probabilities, so we could try the linear model approach:\n", "\n", "\\begin{align*}\n", "P(y=1|x) &= \\theta_1^T x\\\\\n", "P(y=2|x) &= \\theta_2^T x\\\\\n", "& \\vdots\\\\\n", "P(y=K|x) &= \\theta_K^T x\n", "\\end{align*}\n", "\n", "in which case we run into the same issue: each term is not guaranteed to be a probability. We can use the same trick of applying the logistic sigmoid:\n", "\n", "\\begin{align*}\n", "P(y=1|x) &= \\sigma(\\theta_1^T x)\\\\\n", "P(y=2|x) &= \\sigma(\\theta_2^T x)\\\\\n", "& \\vdots\\\\\n", "P(y=K|x) &= \\sigma(\\theta_K^T x).\n", "\\end{align*}\n", "\n", "This seemingly fixes the problem, but now we face another one: there is no guarantee that the probabilities sum to one. An easy fix is to normalize each term by the sum of all $K$ terms:\n", "\n", "\\begin{align*}\n", "P(y=1|x) &= \\frac{\\sigma(\\theta_1^T x)}{\\sum_{j=1}^K \\sigma(\\theta_j^T x)}\\\\\n", "P(y=2|x) &= \\frac{\\sigma(\\theta_2^T x)}{\\sum_{j=1}^K \\sigma(\\theta_j^T x)}\\\\\n", "& \\vdots\\\\\n", "P(y=K|x) &= \\frac{\\sigma(\\theta_K^T x)}{\\sum_{j=1}^K \\sigma(\\theta_j^T x)}.\n", "\\end{align*}\n", "\n", "This ensures two things:\n", "1. 
each element is a valid probability (it lies between $0$ and $1$);\n", "2. the estimated probabilities sum to $1$.\n", "\n", "In fact, the normalization only requires each term to be positive, so the sigmoid is not essential here. Softmax regression uses the exponential $e^{\\theta_k^T x}$ instead, which gives the softmax function (this is also the model that `sklearn` fits with the multinomial option used below). Our softmax regression model therefore takes the form:\n", "\n", "\\begin{align*}\n", "h_\\Theta (x) = \n", "\\begin{bmatrix}\n", "P(y=1|x) \\\\ P(y=2|x)\\\\ \\vdots \\\\ P(y=K|x)\n", "\\end{bmatrix}\n", "=\n", "\\frac{1}{ \\sum_{j=1}^K e^{\\theta_j^T x}}\n", "\\begin{bmatrix}\n", "e^{\\theta_1^T x}\\\\\n", "e^{\\theta_2^T x}\\\\\n", "\\vdots \\\\\n", "e^{\\theta_K^T x}\n", "\\end{bmatrix}\n", "\\end{align*}\n", "\n", "and we now have a parameter matrix $\\Theta$ to estimate. What remains is to define the loss function we will use for the problem. Recall that in the logistic case the loss function is:\n", "\n", "\\begin{align*}\n", "\\mathcal{L} (w)\n", "&= -\\sum_{i=1}^n (y_i \\ln \\hat{p}_i + (1-y_i) \\ln(1-\\hat{p}_i))\\\\\n", "&= -\\sum_{i=1}^n (y_i \\ln P(y_i = 1|x_i) + (1-y_i) \\ln P(y_i = 0|x_i)),\n", "\\end{align*}\n", "\n", "and, noting that $y_i = 0$ or $y_i = 1$, we can rewrite this as \n", "\\begin{align*}\n", "\\mathcal{L} (w)&= -\\sum_{i=1}^n \\sum_{j=0}^1 \\mathbf{1}\\{y_i = j \\} \\ln P(y_i = j|x_i),\n", "\\end{align*}\n", "\n", "where $\\mathbf{1}\\{y_i = j \\}$ is an indicator variable, which takes the value $1$ if the statement inside the braces is true, and $0$ otherwise. For example, $\\mathbf{1}\\{ 1+2=3\\} = 1$. Now, the reason we wrote the loss this way is that it allows us to generalize to the softmax regression problem (with classes labelled $1,\\dots,K$) in the following way:\n", "\n", "\\begin{align*}\n", "\\mathcal{L}(\\Theta) \n", "&= -\\sum_{i=1}^n \\sum_{j=1}^K \\mathbf{1}\\{y_i = j \\} \\ln P(y_i = j|x_i)\\\\\n", "&= -\\sum_{i=1}^n \\sum_{j=1}^K \\mathbf{1}\\{y_i = j \\} \\ln \\frac{e^{\\theta^T_j x_i}}{\\sum_{k=1}^K e^{\\theta^T_k x_i}}.\n", "\\end{align*}" ] }, { "cell_type": "markdown", "id": "b76f7b57", "metadata": {}, "source": [ "#### Softmax Regression in sklearn\n", "Luckily for us, the `sklearn` logistic regression implementation handles softmax regression automatically. In this section we will demonstrate its performance on all ten MNIST classes, using a random subsample of the images. We first create our dataset in the usual way:" ] }, { "cell_type": "code", "execution_count": 14, "id": "981a8c93", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Xtrain shape = (2000, 784)\n", "Xtest shape = (400, 784)\n" ] } ], "source": [ "# subsample the data to make it more manageable (2000 train and 400 test images)\n", "nSample = 2000\n", "idxsTrain = np.random.choice(np.arange(trainset.data.shape[0]), size=nSample, replace=False)\n", "idxsTest = np.random.choice(np.arange(testset.data.shape[0]), size=nSample//5, replace=False)\n", "\n", "Xtrain = trainset.data.reshape(-1, 784).numpy()[idxsTrain, :]\n", "Xtest = testset.data.reshape(-1, 784).numpy()[idxsTest, :] \n", "ytrain = trainset.targets.numpy()[idxsTrain]\n", "ytest = testset.targets.numpy()[idxsTest]\n", "\n", "print(f'Xtrain shape = {Xtrain.shape}')\n", "print(f'Xtest shape = {Xtest.shape}')" ] }, { "cell_type": "markdown", "id": "0923b443", "metadata": {}, "source": [ "\n", " \n", "#### Exercise\n", "As mentioned earlier, `sklearn` handles softmax regression using the same `LogisticRegression` class as for the binary problem, with an added argument `multi_class`, which controls whether we fit several binary logistic regressions or a single softmax regression (also 
referred to as multinomial logistic regression). If we set `multi_class='ovr'`, it fits one binary (one-vs-rest) logistic regression per label (i.e. $K$ standard logistic fits), whereas if we set `multi_class='multinomial'`, we get a softmax regression fit. Note that `multi_class='multinomial'` is not compatible with `solver='liblinear'`. (Recent `sklearn` releases deprecate the `multi_class` argument; for a multiclass problem, solvers such as `sag` already fit the multinomial model by default.)\n", " \n", "Fit a multinomial regression to the (subsampled) ten-class MNIST data using the `sag` solver and `l2` regularization with `C=50` (you may need to increase `max_iter` to 1000 here); additionally, you can increase the `tol` argument to `0.001` for faster convergence. Using `sklearn.metrics.confusion_matrix`, produce a confusion matrix (on the test set) to display your results." ] }, { "cell_type": "code", "execution_count": 15, "id": "36914bf7", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Confusion_matrix: \n", "[[37 0 0 0 0 0 0 0 0 0]\n", " [ 0 34 1 1 0 1 0 0 2 0]\n", " [ 0 1 27 0 1 0 0 1 3 0]\n", " [ 0 0 0 34 0 1 1 1 0 0]\n", " [ 0 0 1 0 36 0 0 0 0 2]\n", " [ 0 0 0 2 2 35 0 0 5 1]\n", " [ 0 0 2 0 2 1 42 1 0 0]\n", " [ 0 0 1 1 0 1 0 40 0 2]\n", " [ 0 2 1 1 0 2 0 0 28 2]\n", " [ 1 0 1 0 3 0 0 1 1 34]]\n" ] } ], "source": [ "#### Solution\n", "from sklearn.metrics import confusion_matrix\n", "\n", "sm_mod = LogisticRegression(multi_class='multinomial',\n", " penalty='l2',\n", " C=50,\n", " solver='sag',\n", " tol=.001,\n", " max_iter=1000\n", " ).fit(Xtrain, ytrain)\n", "\n", "ypred = sm_mod.predict(Xtest)\n", "print(\"Confusion_matrix: \\n\"+str(confusion_matrix(ytest, ypred)))" ] }, { "cell_type": "markdown", "id": "bfb15935", "metadata": {}, "source": [ "\n", " \n", "#### Exercise\n", "Recall that softmax regression requires the estimation of $Kp$ parameters, one $p$-dimensional vector for each of the $K$ classes. For the MNIST problem, that means we have $10 \\times 784$ parameters in total (plus $10$ intercepts, since `sklearn` fits an intercept for each class by default). 
Create a $2 \\times 5$ grid of plots (one for each of the 10 MNIST classes), and use `plt.imshow()` to plot the estimated coefficient vector $\\theta_k$. Note that you will have to reshape each coefficient vector into the $28 \\times 28$ original format of MNIST to see anything interesting. What do you observe?" ] }, { "cell_type": "code", "execution_count": 16, "id": "f317f838", "metadata": {}, "outputs": [ { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "#### Solution\n", "fig, axes = plt.subplots(2, 5, figsize=(14,8))\n", "for i, ax in enumerate(axes.flat):\n", " ax.imshow(sm_mod.coef_[i].reshape(28,28), cmap=plt.cm.RdBu_r, interpolation='nearest')\n", "plt.tight_layout()\n", "plt.show()" ] }, { "cell_type": "markdown", "id": "3db32996", "metadata": {}, "source": [ "\n", " \n", "#### Exercise\n", "In some cases, we would like a single metric to summarize the performance of a model as opposed to having a confusion matrix. The most commonly used (and readily available in `sklearn.metrics`) metrics are the precision, recall and F1 scores. Read the following to gain an understanding of these metrics:\n", "\n", "https://towardsdatascience.com/accuracy-precision-recall-or-f1-331fb37c5cb9\n", " \n", "Further, read the following post to understand the difference between micro and macro versions of these metrics:\n", " \n", "https://datascience.stackexchange.com/questions/15989/micro-average-vs-macro-average-performance-in-a-multiclass-classification-settin" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.12.9" } }, "nbformat": 4, "nbformat_minor": 5 }