{ "cells": [ { "cell_type": "markdown", "id": "ca26ead8-dc86-46c2-b3b3-a650fd415df3", "metadata": {}, "source": [ "# decision trees" ] }, { "cell_type": "code", "execution_count": 75, "id": "02e7c747-e8e6-4dc8-87df-59b36844b3b5", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", "RangeIndex: 178 entries, 0 to 177\n", "Data columns (total 14 columns):\n", " # Column Non-Null Count Dtype \n", "--- ------ -------------- ----- \n", " 0 Wine 178 non-null int64 \n", " 1 Alcohol 178 non-null float64\n", " 2 Malic.acid 178 non-null float64\n", " 3 Ash 178 non-null float64\n", " 4 Acl 178 non-null float64\n", " 5 Mg 178 non-null int64 \n", " 6 Phenols 178 non-null float64\n", " 7 Flavanoids 178 non-null float64\n", " 8 Nonflavanoid.phenols 178 non-null float64\n", " 9 Proanth 178 non-null float64\n", " 10 Color.int 178 non-null float64\n", " 11 Hue 178 non-null float64\n", " 12 OD 178 non-null float64\n", " 13 Proline 178 non-null int64 \n", "dtypes: float64(11), int64(3)\n", "memory usage: 19.6 KB\n" ] }, { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
WineAlcoholMalic.acidAshAclMgPhenolsFlavanoidsNonflavanoid.phenolsProanthColor.intHueODProline
0114.231.712.4315.61272.803.060.282.295.641.043.921065
1113.201.782.1411.21002.652.760.261.284.381.053.401050
2113.162.362.6718.61012.803.240.302.815.681.033.171185
3114.371.952.5016.81133.853.490.242.187.800.863.451480
4113.242.592.8721.01182.802.690.391.824.321.042.93735
.............................................
106212.251.732.1219.0801.652.030.371.633.401.003.17510
107212.721.752.2822.5841.381.760.481.633.300.882.42488
108212.221.291.9419.0922.362.040.392.082.700.863.02312
109211.611.352.7020.0942.742.920.292.492.650.963.26680
110211.463.741.8219.51073.182.580.243.582.900.752.81562
\n", "

111 rows × 14 columns

\n", "
" ], "text/plain": [ " Wine Alcohol Malic.acid Ash Acl Mg Phenols Flavanoids \\\n", "0 1 14.23 1.71 2.43 15.6 127 2.80 3.06 \n", "1 1 13.20 1.78 2.14 11.2 100 2.65 2.76 \n", "2 1 13.16 2.36 2.67 18.6 101 2.80 3.24 \n", "3 1 14.37 1.95 2.50 16.8 113 3.85 3.49 \n", "4 1 13.24 2.59 2.87 21.0 118 2.80 2.69 \n", ".. ... ... ... ... ... ... ... ... \n", "106 2 12.25 1.73 2.12 19.0 80 1.65 2.03 \n", "107 2 12.72 1.75 2.28 22.5 84 1.38 1.76 \n", "108 2 12.22 1.29 1.94 19.0 92 2.36 2.04 \n", "109 2 11.61 1.35 2.70 20.0 94 2.74 2.92 \n", "110 2 11.46 3.74 1.82 19.5 107 3.18 2.58 \n", "\n", " Nonflavanoid.phenols Proanth Color.int Hue OD Proline \n", "0 0.28 2.29 5.64 1.04 3.92 1065 \n", "1 0.26 1.28 4.38 1.05 3.40 1050 \n", "2 0.30 2.81 5.68 1.03 3.17 1185 \n", "3 0.24 2.18 7.80 0.86 3.45 1480 \n", "4 0.39 1.82 4.32 1.04 2.93 735 \n", ".. ... ... ... ... ... ... \n", "106 0.37 1.63 3.40 1.00 3.17 510 \n", "107 0.48 1.63 3.30 0.88 2.42 488 \n", "108 0.39 2.08 2.70 0.86 3.02 312 \n", "109 0.29 2.49 2.65 0.96 3.26 680 \n", "110 0.24 3.58 2.90 0.75 2.81 562 \n", "\n", "[111 rows x 14 columns]" ] }, "execution_count": 75, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import pandas as pd\n", "df = pd.read_csv('wine.csv')\n", "df.info()\n", "df.head(111)" ] }, { "cell_type": "code", "execution_count": 44, "id": "282230e7-b6bb-4742-b4a7-7e88660d3a44", "metadata": {}, "outputs": [], "source": [ "df.head()\n", "X = df.iloc[:, 1:]\n", "y = df.iloc[:, 0]\n", "\n", "import numpy as np\n", "from sklearn.model_selection import train_test_split\n", "#rs = np.random.RandomState(123) # doesn't work with DecisionTreeClassifier.\n", "rs = 123\n", "train_prop = 0.7 \n", "X_train, X_test, y_train, y_test = train_test_split(X, y, train_size=train_prop, random_state=rs)" ] }, { "cell_type": "code", "execution_count": 20, "id": "62f355bf-e32e-4fe8-9f8d-cc37f3684de9", "metadata": {}, "outputs": [], "source": [ "from sklearn.preprocessing import StandardScaler\n", "\n", "scaler = 
StandardScaler()\n", "X_train = scaler.fit_transform(X_train)\n", "X_test = scaler.transform(X_test)" ] }, { "cell_type": "code", "execution_count": 73, "id": "e37aaee4-a46a-4954-b285-682ac0f68895", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "0.9838709677419355\n", "0.9444444444444444\n" ] } ], "source": [ "from sklearn import tree\n", "from sklearn.tree import DecisionTreeClassifier\n", "from sklearn.metrics import accuracy_score\n", "\n", "model = DecisionTreeClassifier(max_depth=3,criterion='gini',random_state=rs)\n", "model.fit(X_train, y_train)\n", "print(accuracy_score(model.predict(X_train), y_train))\n", "print(accuracy_score(model.predict(X_test), y_test))" ] }, { "cell_type": "markdown", "id": "36bc2d39-e632-477c-8b2b-b4145e3c4916", "metadata": {}, "source": [ "# ensembles" ] }, { "cell_type": "markdown", "id": "e1c275aa-9cf5-4571-94a6-c294e89b4f86", "metadata": {}, "source": [ "## random forests (bagging)" ] }, { "cell_type": "code", "execution_count": 81, "id": "411880db-5d8e-42bf-a878-ff482064169f", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "0.9814814814814815" ] }, "execution_count": 81, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from sklearn.ensemble import RandomForestClassifier\n", "\n", "model_rf = RandomForestClassifier(n_estimators=100, random_state=rs).fit(X_train,y_train)\n", "\n", "accuracy_score(model_rf.predict(X_test), y_test)" ] }, { "cell_type": "markdown", "id": "e800e0ec-4fed-4f25-adef-c97572ccfd67", "metadata": {}, "source": [ "## boosting (ada)" ] }, { "cell_type": "code", "execution_count": 90, "id": "57a2a6bd-a170-4d4b-8d15-caa365186cc6", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "0.9629629629629629" ] }, "execution_count": 90, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from sklearn.ensemble import AdaBoostClassifier\n", "model_b = AdaBoostClassifier(estimator=DecisionTreeClassifier(max_depth=1, 
criterion='entropy'),\n", " n_estimators=1000, random_state=rs)\n", "model_b.fit(X_train, y_train)\n", "y_pred = model_b.predict(X_test)\n", "accuracy_score(y_pred, y_test)" ] }, { "cell_type": "markdown", "id": "f112b0a2-ff69-4df0-86f4-15fd49afa445", "metadata": {}, "source": [ "## metrics" ] }, { "cell_type": "code", "execution_count": 91, "id": "e2548a5c-e89c-433b-a597-fa23effaefef", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ " precision recall f1-score support\n", "\n", " 1 0.86 1.00 0.92 12\n", " 2 1.00 0.90 0.95 20\n", " 3 1.00 1.00 1.00 22\n", "\n", " accuracy 0.96 54\n", " macro avg 0.95 0.97 0.96 54\n", "weighted avg 0.97 0.96 0.96 54\n", "\n" ] } ], "source": [ "from sklearn import metrics\n", "print(metrics.classification_report(y_pred, y_test))" ] }, { "cell_type": "code", "execution_count": 95, "id": "6b0b87e7-7416-4b57-b182-9d876e097cfb", "metadata": {}, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAbAAAAGwCAYAAADITjAqAAAAOnRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjEwLjEsIGh0dHBzOi8vbWF0cGxvdGxpYi5vcmcvc2/+5QAAAAlwSFlzAAAPYQAAD2EBqD+naQAAJJhJREFUeJzt3XlYlXX+//HXURCUFDM2I3fzp4ap4FoOllrWlA5TaTqVS1bmQorjL7FSR02xLAnGpSlTs9U2bbG+o6Nlirsmhmu5hDkgEgkCKtv5/uE3ZgiRc/Acbj7yfFwX1yWfc3P7rjt7ep9z3+fY7Ha7XQAAGKaG1QMAAFARBAwAYCQCBgAwEgEDABiJgAEAjETAAABGImAAACMRMACAkTysHsAdDre5y+oRUEnaHvne6hEAuEFB3slyt+EMDABgJAIGADASAQMAGImAAQCMRMAAAEYiYAAAIxEwAICRCBgAwEgEDABgJAIGADASAQMAGImAAQCMRMAAAEYiYAAAIxEwAICRCBgAwEgEDABgJAIGADASAQMAGImAAQCMRMAAAEYiYAAAIxEwAICRCBgAwEgEDABgJAIGADASAQMAGImAAQCMRMAAAEYiYAAAIxEwAICRCBgAwEgEDABgJAIGADASAQMAGImAAQCMRMAAAEYiYAAAIxEwAICRCBgAwEgEDABgJAIGADASAQMAGImAAQCMRMAAAEYiYAAAIxEwAICRCBgAwEgEDABgJAIGADASAQMAGImAAQCMRMAAAEYiYAAAIxEwAICRCBgAwEgEzDC1O4Xo+oV/U/MN76jVgf+RT+/u/3nQo6b8/vqomny6SC13rVLzDe8oaM5E1fRvYN3AcLlRTw7Vj4e3KjvriDZv+lydO3WweiS4Ccf68giYYWy1vXXh0DGlzVxQ6rEa3l7yattSvyx6Vz/dP1b/fmqmPJveoOCFf6v8QeEWAwb010tzp2nm8/PUuetdSty7X1+ufkf+/tdZPRpcjGNdPpvdbrdbPYSrHW5zl9UjVIpWB/5HJ8dOV866LWV
u4xXSSk0+jNfRXo+oIOV0JU5XOdoe+d7qESrV5k2fa8fORI0b/5wkyWaz6fjRHVqwcKlenFv6LzUwV3U/1gV5J8vdxqMS5ihTenq6lixZoi1btig1NVWSFBQUpFtuuUXDhg2Tv7+/leNdFWrW9ZG9qEhFWTlWj4Ir5OnpqdDQmzXnxfnFa3a7XevWb1K3bmEWTgZX41g7xrKnEHfs2KFWrVopPj5evr6+Cg8PV3h4uHx9fRUfH6/WrVtr586d5e7nwoULysrKKvGVV1RUCf8EVZ+tlqf8/vqozq7+RkU5uVaPgyvk59dAHh4eSjuVXmI9Le20ggL5y97VhGPtGMvOwCIjIzVgwAC9+uqrstlsJR6z2+168sknFRkZqS1byn56TJJiYmI0ffr0Emtjr2uhSP+WLp/ZKB411TD2WclmU9r0+eVvDwCGsewMLDExUVFRUaXiJV18rjcqKkp79uwpdz+TJ09WZmZmia+R1zV3w8QG8aip62Ofkef1Afp5xGTOvq4S6ekZKigoUECgX4n1gAB/pZ66+l7frM441o6xLGBBQUHavn17mY9v375dgYGB5e7Hy8tL9erVK/FVq0Y1vrjyt3g1CdbPj05W0ZmzVk8EF8nPz9fu3XvV6/YexWs2m029bu+hrVt3WTgZXI1j7RjLnkKcOHGinnjiCe3atUu9e/cujtWpU6e0bt06vf7663rppZesGq/KstXxVq3G1xd/73lDkLxaN1dh5lkVnM7Q9a88J6+2LXVy1FSpZg3V9LtWklSYeVbKL7BqbLhIbNzrWvpGrHbt3qsdO77TU5GPy8entpa9ucLq0eBiHOvyWRawMWPGyM/PT7GxsVq4cKEKCwslSTVr1lRYWJiWLVumgQMHWjVeleV9Uys1Wv5i8fcB0SMlSZkr1+qX+W/rmv+7sbnpqkUlfu7EkKd1bsfeyhsUbvHhh5/J36+B/jZ1ooKC/JWYuE/33Puw0tLSy/9hGIVjXb4qcR9Yfn6+0tMvHhQ/Pz95enpe0f6qy31gqH73gQHVRZW/D+w3np6eatiwodVjAAAMUo2vdgAAmIyAAQCMRMAAAEYiYAAAIxEwAICRCBgAwEgEDABgJAIGADASAQMAGImAAQCMRMAAAEYiYAAAIxEwAICRCBgAwEgEDABgJAIGADASAQMAGImAAQCMRMAAAEYiYAAAIxEwAICRCBgAwEgEDABgJAIGADASAQMAGImAAQCMRMAAAEYiYAAAIxEwAICRCBgAwEgEDABgJAIGADASAQMAGImAAQCMRMAAAEYiYAAAIxEwAICRCBgAwEgEDABgJAIGADASAQMAGImAAQCMRMAAAEYiYAAAIxEwAICRCBgAwEgEDABgJAIGADASAQMAGImAAQCMRMAAAEYiYAAAIxEwAICRCBgAwEgeVg/gDnefTrN6BFSSs0sftXoEVKK6w5dYPQKqEM7AAABGImAAACMRMACAkQgYAMBIBAwAYCQCBgAwEgEDABiJgAEAjETAAABGImAAACMRMACAkQgYAMBIDr2Z77XXXiubzebQDjMyMq5oIAAAHOFQwF555RU3jwEAgHMcCtjQoUPdPQcAAE6p0GtgR44c0XPPPafBgwcrLe3iZ2999dVX2rdvn0uHAwCgLE4HbMOGDWrXrp22bdumTz75RNnZ2ZKkxMRETZs2zeUDAgBwKU4HLDo6Ws8//7zWrl2rWrVqFa/36tVLW7dudelwAACUxemAff/99/rzn/9caj0gIEDp6ekuGQoAgPI4HbD69esrJSWl1Pp3332n4OBglwwFAEB5nA7YoEGDNGnSJKWmpspms6moqEgJCQmaOHGihgwZ4o4ZAQAoxemAzZ49W61bt1ajRo2UnZ2ttm3bKjw8XLfccouee+45d8wIAEApNrvdbq/IDyYnJyspKUnZ2dnq2LGjbrzxRlfPVmEt/EKtHgGVJCnuHqtHQCWqO3yJ1SOgkhTknSx3G4duZL6Uxo0bq1GjRpLk8NtMAQD
gKhW6kfmNN95QSEiIvL295e3trZCQEC1evNjVswEAUCanz8CmTp2qefPmKTIyUt27d5ckbdmyRVFRUUpOTtaMGTNcPiQAAL/n9Gtg/v7+io+P1+DBg0usv/fee4qMjKwS94LxGlj1wWtg1QuvgVUfjrwG5vRTiPn5+erUqVOp9bCwMBUUFDi7OwAAKsTpgD3yyCNatGhRqfXXXntNDz30kEuGAgCgPA69BjZhwoTiX9tsNi1evFhr1qxRt27dJEnbtm1TcnIyNzIDACqNQwH77rvvSnwfFhYm6eLHqkiSn5+f/Pz8+DgVAEClcShgX3/9tbvnAADAKRW6DwwAAKtV6J04du7cqQ8++EDJycnKy8sr8dgnn3ziksEAALgcp8/A3n//fd1yyy06cOCAVq5cqfz8fO3bt0/r16+Xr6+vO2YEAKCUCr0bfWxsrD7//HPVqlVLcXFxOnjwoAYOHKjGjRu7Y0YAAEpxOmBHjhzRPfdcfPeDWrVqKScnRzabTVFRUXrttddcPiAAAJfidMCuvfZanT17VpIUHByspKQkSdKZM2eUm5vr2ukAACiD0xdxhIeHa+3atWrXrp0GDBigcePGaf369Vq7dq169+7tjhkBACjF6YDNnz9f58+flyQ9++yz8vT01ObNm3X//ffzicwAgErjdMAaNGhQ/OsaNWooOjrapQMBAOAIhwKWlZXl8A7r1atX4WEAAHCUQwGrX7++bDbbZbex2+2y2WwqLCx0yWAAAFwO74UIADCSQwHr2bOnu+cAAMApvJkvAMBIBAwAYCQCBgAwEgEDABiJgAEAjOTQVYgdO3Ys9z6w3+zevfuKBgIAwBEOBSwiIqL41+fPn9fChQvVtm1bde/eXZK0detW7du3T6NHj3bLkAAA/J5DAZs2bVrxrx977DE99dRTmjlzZqltTpw44drpAAAog9OvgX344YcaMmRIqfWHH35YH3/8sUuGAgCgPE4HrHbt2kpISCi1npCQIG9vb5cMBQBAeZz+OJXx48dr1KhR2r17t7p06SJJ2rZtm5YsWaIpU6a4fEAAAC7F6YBFR0erefPmiouL09tvvy1JatOmjZYuXaqBAwe6fEBc3pPjhqvvvb3U/MamunDugnbvSNQLM+J17MefrB4NV2jXT6f15uaDOpDyq05nn9e8gbeqV+vg4sdz8/IVt+57fX3wpDLP5Sm4vo8Gd2mpAZ1aWjg1XGnUk0P11wmjFBTkr71792vc+CnasXOP1WNVGU4HTJIGDhxIrKqIrreE6e03PtDe7/appkdNTXxurN78cKH63nq/zuWet3o8XIFzeQVqFVhfER2bacIHm0s9/tKaRO04lqZZf+6q6+v7aMuRVMV8uVv+dWvrtv8XfIk9wiQDBvTXS3OnafSYaG3f8Z2einxMX65+R21DwnX69C9Wj1clVOhG5jNnzmjx4sV65plnlJGRIeni/V8nT5506XAo3/AHx+rj9z/XD4eO6uC+H/T02GkKbtRQIe3bWj0arlCPGxtqbK926tX6hks+nngiXf3aN1HnpgEKru+jB8JaqFVQfSWdzKjkSeEOUeMe1+I33tWbyz/QgQM/aPSYaOXmntPwYYOsHq3KcDpge/fuVatWrfTCCy9o7ty5OnPmjCTpk08+0eTJk109H5xUt15dSVLmr5kWTwJ3a9/IT98c/rdOZeXKbrdrx7E0/fTLWXVvEWT1aLhCnp6eCg29WevWbyxes9vtWrd+k7p1C7NwsqrF6YBNmDBBw4YN0w8//FDiqsM//vGP+vbbb1063IkTJ/Too49edpsLFy4oKyurxJfdXuTSOUxhs9n03KyJ2rn1Ox0+eMTqceBm0Xd1VHO/eur7yhfqPOsjjX73W02+O1RhTfytHg1XyM+vgTw8PJR2Kr3EelraaQUFcnx/43TAduzYoZEjR5ZaDw4OVmpqqkuG+k1GRobefPPNy24TExMjX1/fEl+/njvl0jlMMf3FaLVq3ULjHudMuDp4b/s
P+v5khuIe7KF3H79Df72jvWK+2q2tR6vnf/+ofpy+iMPLy0tZWVml1g8fPix/f+f+ZvDZZ59d9vGjR4+Wu4/JkydrwoQJJdY6NAt3ao6rwbQ5k9Trzj9oUL/HlJqSZvU4cLPz+QX6+/okzRt4i8JbXS9JahVYX4dOndHyLYfUrXmgxRPiSqSnZ6igoEABgX4l1gMC/JV66rRFU1U9Tgesf//+mjFjhj744ANJF5+2Sk5O1qRJk3T//fc7ta+IiAjZbDbZ7fYytynvTYS9vLzk5eX1u5+pXm+yP23OJN15z+166E+P6+fkf1s9DipBQZFdBUVFqvG7Px81bDYVXebPE8yQn5+v3bv3qtftPfTZZ/+UdPH/hb1u76GFi5ZaPF3V4fT/6V9++WVlZ2crICBA586dU8+ePdWyZUvVrVtXs2bNcmpfDRs21CeffKKioqJLfvHO9uWb/mK0Igb8UVEjn1F2dq78Aq6TX8B18vL2Kv+HUaXl5uXrYOqvOpj6qyTp5JlsHUz9VSmZObrGy1NhTfwV+69E7TieppO/ZuvTPcf0xd6fStwrBnPFxr2ux0b8RY88MkCtW7fUgvlz5ONTW8veXGH1aFWGzX6505/LSEhIUGJiorKzsxUaGqo+ffo4vY/+/furQ4cOmjFjxiUfT0xMVMeOHVVU5NxFGS38Qp2exVRH0i8d+afHTtPH739eydNUvqS4e6wewW12HE/T48u/KbXer31TzfxTF6Vnn1P8uu+15egpZZ3LU0PfOro/tLke7tbK4Y8/Mk3d4UusHqFSjR41rPhG5sTEfRofNVXbd3xn9ViVoiCv/NuynA7Y8uXL9eCDD5Z62i4vL0/vv//+Jd/otywbN25UTk6O7rrrrks+npOTo507d6pnz57OjFitAlbdXc0BQ2nVLWDVmVsCVrNmTaWkpCggIKDE+i+//KKAgAAVFhY6N6UbELDqg4BVLwSs+nAkYE6/Bma32y/59MTPP/8sX19fZ3cHAECFOHwVYseOHWWz2WSz2dS7d295ePznRwsLC3Xs2LEynwoEAMDVHA5YRESEJGnPnj3q27evrrnmmuLHatWqpaZNmzp9GT0AABXlcMCmTZsmSWratKkGDRpU6iIOAAAqk9OvgbVt21Z79uwptb5t2zbt3LnTFTMBAFAupwM2ZswYnThxotT6yZMnNWbMGJcMBQBAeZwO2P79+xUaWvoy9Y4dO2r//v0uGQoAgPI4HTAvLy+dOlX63a5TUlJKXJkIAIA7OR2wO++8U5MnT1Zm5n8+MPHMmTN65plndMcdd7h0OAAAyuL0KdNLL72k8PBwNWnSRB07dpR08dL6wMBAvfXWWy4fEACAS3E6YMHBwdq7d6/eeecdJSYmqnbt2ho+fLgGDx4sT09Pd8wIAEApFXrRysfHR0888YSrZwEAwGEOBeyzzz7T3XffLU9Pz3I/Rbl///4uGQwAgMtxKGARERFKTU1VQEBA8VtKXYrNZqsS70YPALj6ORSw//5ASWc/XBIAAHdw+jJ6AACqAofOwOLj4x3e4VNPPVXhYQAAcJRDAYuNjS3x/enTp5Wbm6v69etLungjc506dRQQEEDAAACVwqGnEI8dO1b8NWvWLHXo0EEHDhxQRkaGMjIydODAAYWGhmrmzJnunhcAAEmSzW632535gRYtWuijjz4qfheO3+zatUsPPPCAjh075tIBK6KFX+k3G8bVKSnuHqtHQCWqO3yJ1SOgkhTknSx3G6cv4khJSVFBQUGp9cLCwku+yS8AAO7gdMB69+6tkSNHavfu3cVru3bt0qhRo9SnTx+XDgcAQFmcDtiSJUsUFBSkTp06ycvLS15eXurSpYsCAwO1ePFid8wIAEApTr8Xor+/v7788ksdPnxYBw8elCS1bt1arVq1cvlwAACUpcKfQNm0aVPZ7Xa1aNGCD7IEAFQ6p59CzM3N1YgRI1SnTh3ddNNNSk5OliRFRkZqzpw5Lh8QAIBLcTpgkydPVmJ
ior755ht5e3sXr/fp00crVqxw6XAAAJTF6ef+Vq1apRUrVqhbt26y2WzF6zfddJOOHDni0uEAACiL02dgp0+fVkBAQKn1nJycEkEDAMCdnA5Yp06dtHr16uLvf4vW4sWL1b17d9dNBgDAZTj9FOLs2bN19913a//+/SooKFBcXJz279+vzZs3a8OGDe6YEQCAUpw+A+vRo4cSExNVUFCgdu3aac2aNQoICNCWLVsUFhbmjhkBACjFqTOw/Px8jRw5UlOmTNHrr7/urpkAACiXU2dgnp6e+vjjj901CwAADnP6KcSIiAitWrXKDaMAAOA4py/iuPHGGzVjxgwlJCQoLCxMPj4+JR7nE5kBAJXB6Q+0bNasWdk7s9l09OjRKx7qSvGBltUHH2hZvfCBltWHIx9o6fQZWFX4xGUAAJx+Dey/2e12OXkCBwCAS1QoYG+88YZCQkLk7e0tb29vhYSE8GGWAIBK5fRTiFOnTtW8efMUGRlZ/NZRW7ZsUVRUlJKTkzVjxgyXDwkAwO85fRGHv7+/4uPjNXjw4BLr7733niIjI5Wenu7SASuCiziqDy7iqF64iKP6cOQiDqefQszPz1enTp1KrYeFhamgoMDZ3QEAUCFOB+yRRx7RokWLSq2/9tpreuihh1wyFAAA5XH6NTDp4kUca9asUbdu3SRJ27ZtU3JysoYMGaIJEyYUbzdv3jzXTAkAwO84HbCkpCSFhl58jem3T2D28/OTn5+fkpKSirfjwy0BAO7kdMC+/vprd8wBAIBTruhGZgAArELAAABGImAAACMRMACAkQgYAMBIBAwAYCQCBgAwEgEDABjJ6XejN4FHrWCrRwDgBuf+vdHqEVBJPP2al7sNZ2AAACMRMACAkQgYAMBIBAwAYCQCBgAwEgEDABiJgAEAjETAAABGImAAACMRMACAkQgYAMBIBAwAYCQCBgAwEgEDABiJgAEAjETAAABGImAAACMRMACAkQgYAMBIBAwAYCQCBgAwEgEDABiJgAEAjETAAABGImAAACMRMACAkQgYAMBIBAwAYCQCBgAwEgEDABiJgAEAjETAAABGImAAACMRMACAkQgYAMBIBAwAYCQCBgAwEgEDABiJgAEAjETAAABGImAAACMRMACAkQgYAMBIBAwAYCQCBgAwEgEDABiJgAEAjETAAABGImAAACMRMACAkQgYAMBIBAwAYCQCBgAwEgEDABiJgAEAjETArgKjnhyqHw9vVXbWEW3e9Lk6d+pg9UhwI4731ef15Sv04Iin1KXPfQq/Z5Ceip6hYz/9XPx4ZtZZzZ63UPcOekxht/9Jfe4botmxi3Q2O8fCqa1HwAw3YEB/vTR3mmY+P0+du96lxL379eXqd+Tvf53Vo8ENON5Xp517vtfg+/rp3ddi9dors5VfUKAnop5V7rnzkqS09F+Ulp6hiWMf08q3FmnWsxOUsG2XpsbEWjy5tWx2u91u9RCu5lEr2OoRKs3mTZ9rx85EjRv/nCTJZrPp+NEdWrBwqV6cu8Di6eBq1f14n/v3RqtHqBQZv55R+L2DtWzBi+rUod0lt/nn+o2KnvGidvxrlTw8albyhO7n6de83G04AzOYp6enQkNv1rr1//lDbbfbtW79JnXrFmbhZHAHjnf1kZ2TK0nyrVe3zG3OZufoGp86V2W8HGV5wM6dO6dNmzZp//79pR47f/68li9fftmfv3DhgrKyskp8XYUnlZfk59dAHh4eSjuVXmI9Le20ggL9LZoK7sLxrh6Kioo0J+4f6nhzW93YvOklt/n1TKb+sew9PdD/7sodroqxNGCHDx9WmzZtFB4ernbt2qlnz55KSUkpfjwzM1PDhw+/7D5iYmLk6+tb4stedNbdowOAWzz/8gL9ePS45k6PvuTj2Tk5Gv3/p6lFs8YaPeLhSp6uarE0YJMmTVJISIjS0tJ06NAh1a1bV7feequSk5Md3sfkyZOVmZlZ4stWo+zT7qtJenqGCgoKFBDoV2I
9IMBfqadOWzQV3IXjffWb9fJCbdi8XUv+/oKCAkqfVefk5GrkhCnyqVNbcbOnyNPDw4Ipqw5LA7Z582bFxMTIz89PLVu21Oeff66+ffvqD3/4g44ePerQPry8vFSvXr0SXzabzc2TVw35+fnavXuvet3eo3jNZrOp1+09tHXrLgsngztwvK9edrtds15eqHXfbtaS+Dm64fqgUttk5+Toiahn5enpob+/ME1eXrUsmLRqsTRg586dk8d//Q3CZrNp0aJF6tevn3r27KnDhw9bOJ0ZYuNe12Mj/qJHHhmg1q1basH8OfLxqa1lb66wejS4Acf76vT8ywv0xZr1euFvT8unTm2l/5Kh9F8ydP7CBUn/F6/xzyr3/HnNiB6vnJzc4m0KCwstnt46lp5/tm7dWjt37lSbNm1KrM+fP1+S1L9/fyvGMsqHH34mf78G+tvUiQoK8ldi4j7dc+/DSktLL/+HYRyO99VpxcrVkqThYyeVWH/+mQmKuOcO7T90RHv3H5Ik/fHBESW2+edHyxTcMLByBq1iLL0PLCYmRhs3btSXX355ycdHjx6tV199VUVFRU7ttzrdBwZUJ9XlPjA4dh8YNzIDMAYBqz64kRkAcNUiYAAAIxEwAICRCBgAwEgEDABgJAIGADASAQMAGImAAQCMRMAAAEYiYAAAIxEwAICRCBgAwEgEDABgJAIGADASAQMAGImAAQCMRMAAAEYiYAAAIxEwAICRCBgAwEgEDABgJAIGADASAQMAGImAAQCMRMAAAEYiYAAAIxEwAICRCBgAwEgEDABgJAIGADASAQMAGImAAQCMRMAAAEYiYAAAIxEwAICRCBgAwEgEDABgJAIGADASAQMAGImAAQCMRMAAAEYiYAAAIxEwAICRCBgAwEgEDABgJAIGADASAQMAGImAAQCMRMAAAEYiYAAAIxEwAICRCBgAwEgEDABgJAIGADASAQMAGMlmt9vtVg+BK3fhwgXFxMRo8uTJ8vLysnocuBHHuvrgWF8eAbtKZGVlydfXV5mZmapXr57V48CNONbVB8f68ngKEQBgJAIGADASAQMAGImAXSW8vLw0bdo0XuitBjjW1QfH+vK4iAMAYCTOwAAARiJgAAAjETAAgJEIGADASATsKrBgwQI1bdpU3t7e6tq1q7Zv3271SHCDb7/9Vv369dP1118vm82mVatWWT0S3CQmJkadO3dW3bp1FRAQoIiICB06dMjqsaocAma4FStWaMKECZo2bZp2796t9u3bq2/fvkpLS7N6NLhYTk6O2rdvrwULFlg9Ctxsw4YNGjNmjLZu3aq1a9cqPz9fd955p3JycqwerUrhMnrDde3aVZ07d9b8+fMlSUVFRWrUqJEiIyMVHR1t8XRwF5vNppUrVyoiIsLqUVAJTp8+rYCAAG3YsEHh4eFWj1NlcAZmsLy8PO3atUt9+vQpXqtRo4b69OmjLVu2WDgZAFfKzMyUJDVo0MDiSaoWAmaw9PR0FRYWKjAwsMR6YGCgUlNTLZoKgCsVFRVp/PjxuvXWWxUSEmL1OFWKh9UDAADKNmbMGCUlJWnTpk1Wj1LlEDCD+fn5qWbNmjp16lSJ9VOnTikoKMiiqQC4ytixY/XFF1/o22+/1Q033GD1OFUOTyEarFatWgoLC9O6deuK14qKirRu3Tp1797dwskAXAm73a6xY8dq5cqVWr9+vZo1a2b1SFUSZ2CGmzBhgoYOHapOnTqpS5cueuWVV5STk6Phw4dbPRpcLDs7Wz/++GPx98eOHdOePXvUoEEDNW7c2MLJ4GpjxozRu+++q08//VR169Ytfk3b19dXtWvXtni6qoPL6K8C8+fP19y5c5WamqoOHTooPj5eXbt2tXosuNg333yj22+/vdT60KFDtWzZssofCG5js9kuub506VINGzascoepwggYAMBIvAYGADASAQMAGImAAQCMRMAAAEYiYAAAIxEwAICRCBgAwEgEDABgJAIGGKpp06Z65ZV
XHN5+2bJlql+//hX/vjabTatWrbri/QBXioABFXDbbbdp/PjxVo8BVGsEDHATu92ugoICq8cArloEDHDSsGHDtGHDBsXFxclms8lms+n48eP65ptvZLPZ9NVXXyksLExeXl7atGmThg0bpoiIiBL7GD9+vG677bbi74uKihQTE6NmzZqpdu3aat++vT766COn5po3b57atWsnHx8fNWrUSKNHj1Z2dnap7VatWqUbb7xR3t7e6tu3r06cOFHi8U8//VShoaHy9vZW8+bNNX36dEKMKomAAU6Ki4tT9+7d9fjjjyslJUUpKSlq1KhR8ePR0dGaM2eODhw4oJtvvtmhfcbExGj58uV69dVXtW/fPkVFRenhhx/Whg0bHJ6rRo0aio+P1759+/Tmm29q/fr1evrpp0tsk5ubq1mzZmn58uVKSEjQmTNnNGjQoOLHN27cqCFDhmjcuHHav3+//vGPf2jZsmWaNWuWw3MAlcYOwGk9e/a0jxs3rsTa119/bZdkX7VqVYn1oUOH2v/0pz+VWBs3bpy9Z8+edrvdbj9//ry9Tp069s2bN5fYZsSIEfbBgweXOUOTJk3ssbGxZT7+4Ycf2q+77rri75cuXWqXZN+6dWvx2oEDB+yS7Nu2bbPb7XZ779697bNnzy6xn7feesvesGHD4u8l2VeuXFnm7wtUFj7QEnCxTp06ObX9jz/+qNzcXN1xxx0l1vPy8tSxY0eH9/Ovf/1LMTExOnjwoLKyslRQUKDz588rNzdXderUkSR5eHioc+fOxT/TunVr1a9fXwcOHFCXLl2UmJiohISEEmdchYWFpfYDVAUEDHAxHx+fEt/XqFFD9t997F5+fn7xr397nWr16tUKDg4usZ2Xl5dDv+fx48d17733atSoUZo1a5YaNGigTZs2acSIEcrLy3M4PNnZ2Zo+fbruu+++Uo95e3s7tA+gshAwoAJq1aqlwsJCh7b19/dXUlJSibU9e/bI09NTktS2bVt5eXkpOTlZPXv2rNA8u3btUlFRkV5++WXVqHHxpe0PPvig1HYFBQXauXOnunTpIkk6dOiQzpw5ozZt2kiSQkNDdejQIbVs2bJCcwCViYABFdC0aVNt27ZNx48f1zXXXKMGDRqUuW2vXr00d+5cLV++XN27d9fbb7+tpKSk4qcH69atq4kTJyoqKkpFRUXq0aOHMjMzlZCQoHr16mno0KHlztOyZUvl5+fr73//u/r166eEhAS9+uqrpbbz9PRUZGSk4uPj5eHhobFjx6pbt27FQZs6daruvfdeNW7cWA888IBq1KihxMREJSUl6fnnn6/gvy3APbgKEaiAiRMnqmbNmmrbtq38/f2VnJxc5rZ9+/bVlClT9PTTT6tz5846e/ashgwZUmKbmTNnasqUKYqJiVGbNm101113afXq1WrWrJlD87Rv317z5s3TCy+8oJCQEL3zzjuKiYkptV2dOnU0adIk/eUvf9Gtt96qa665RitWrCgx6xdffKE1a9aoc+fO6tatm2JjY9WkSRMH/80Alcdm//2T8wAAGIAzMACAkQgYAMBIBAwAYCQCBgAwEgEDABiJgAEAjETAAABGImAAACMRMACAkQgYAMBIBAwAYKT/BU7Y/c9h+NX2AAAAAElFTkSuQmCC", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# pyplot is imported here because no earlier cell in this notebook brings `plt`\n", "# into scope; without it this cell raises NameError on Restart & Run All.\n", "import matplotlib.pyplot as plt\n", "from sklearn.metrics import confusion_matrix\n", "import seaborn as sns\n", "\n", "# confusion_matrix returns rows = true labels, columns = predicted labels;\n", "# the transpose puts the true label on the x axis to match the axis titles below.\n", "mat = confusion_matrix(y_test, y_pred)\n", "sns.heatmap(mat.T, square=True, annot=True, fmt='d', cbar=False)\n", "plt.xlabel('true label')\n", "plt.ylabel('predicted label');" ] }, { "cell_type": "markdown", "id": "632ed73f-beab-4f4f-b59a-7511cb847035", "metadata": {}, "source": [ "# log-reg" ] }, { "cell_type": "code", "execution_count": 46, "id": "b26d104c-67b9-4946-9c6c-6592a5838a98", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Train Accuracy: 0.9758064516129032\n", "Test Accuracy: 0.9444444444444444\n" ] } ], "source": [ "from sklearn.linear_model import LogisticRegression\n", "logistic_mod = LogisticRegression(penalty='l2', solver='liblinear').fit(X_train, y_train)\n", "\n", "print(f'Train Accuracy: {accuracy_score(logistic_mod.predict(X_train), y_train)}')\n", "print(f'Test Accuracy: {accuracy_score(logistic_mod.predict(X_test), y_test)}')" ] }, { "cell_type": "markdown", "id": "a1f1c486-5370-4c03-a7fb-93355b605bfe", "metadata": {}, "source": [ "# svm" ] }, { "cell_type": "code", "execution_count": 61, "id": "3f8ad1b7-1865-4c7c-96a5-1346f9357d3e", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Linear SVC Train Accuracy: 1.0\n", "Linear SVC Test Accuracy: 0.9814814814814815\n", "Radial SVC Train Accuracy: 1.0\n", "Radial SVC Test Accuracy: 0.3333333333333333\n" ] } ], "source": [ "from sklearn import svm\n", "\n", "svc_lin_mod = svm.SVC(kernel=\"linear\").fit(X_train, y_train)\n", "svc_rbf_mod = svm.SVC(kernel=\"rbf\", gamma=0.7).fit(X_train, y_train)\n", "\n", "print(f'Linear SVC Train Accuracy: {accuracy_score(svc_lin_mod.predict(X_train), y_train)}')\n", "print(f'Linear SVC Test Accuracy: {accuracy_score(svc_lin_mod.predict(X_test), y_test)}')\n", "\n", "print(f'Radial SVC Train Accuracy: {accuracy_score(svc_rbf_mod.predict(X_train), 
y_train)}')\n", "print(f'Radial SVC Test Accuracy: {accuracy_score(svc_rbf_mod.predict(X_test), y_test)}')\n" ] }, { "cell_type": "code", "execution_count": 53, "id": "7d22e1c7-a314-4260-a3d2-1d4be750a300", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Polynomial SVC Train Accuracy: 0.7258064516129032\n", "Polynomial SVC Test Accuracy: 0.5555555555555556\n" ] } ], "source": [ "svc_poly_mod = svm.SVC(kernel=\"poly\", degree=1).fit(X_train, y_train)\n", "\n", "print(f'Polynomial SVC Train Accuracy: {accuracy_score(svc_poly_mod.predict(X_train), y_train)}')\n", "print(f'Polynomial SVC Test Accuracy: {accuracy_score(svc_poly_mod.predict(X_test), y_test)}')" ] }, { "cell_type": "code", "execution_count": 57, "id": "20d4a6e3-f78c-4b13-82ba-1a0c53256671", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "0.9919354838709677" ] }, "execution_count": 57, "metadata": {}, "output_type": "execute_result" } ], "source": [ "accuracy_score(svm.LinearSVC(max_iter=10000).fit(X_train, y_train).predict(X_train), y_train)" ] }, { "cell_type": "code", "execution_count": 62, "id": "39fab5b7-1efe-4b04-8efc-f29bb3fa0b66", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "0.9259259259259259" ] }, "execution_count": 62, "metadata": {}, "output_type": "execute_result" } ], "source": [ "accuracy_score(svm.LinearSVC(max_iter=10000).fit(X_train, y_train).predict(X_test), y_test)" ] }, { "cell_type": "code", "execution_count": null, "id": "d74ea0aa-14cf-402e-9574-317392a9e426", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.12.9" } }, "nbformat": 4, "nbformat_minor": 5 }