# %% [markdown]
# # K-means clustering from scratch
#
# Algorithm:
#
# 1. Randomly initialise k centroids (here: k distinct data points).
# 2. For each data point, compute the Euclidean distance to every centroid.
# 3. Assign the point to the closest centroid.
# 4. Move every centroid to the mean of the points assigned to it.
# 5. Repeat steps 2-4 for a fixed number of iterations, or until no centroid moves.

# %%
from dataclasses import dataclass, field
from typing import Optional

import numpy as np


# %% [markdown]
# ## Implementation

# %%
def k_dist(X: np.ndarray, Y: np.ndarray):
    """Return the Euclidean distance between two points.

    Both arguments are passed through ``np.asarray`` so plain Python
    sequences (e.g. ``[1.2, 1.2]``) are accepted as well as arrays.
    """
    return np.linalg.norm(np.asarray(X) - np.asarray(Y))


@dataclass
class Centroid:
    """One cluster: its current centre and the points currently assigned to it."""

    # Current centre of the cluster, one coordinate per feature.
    location: np.ndarray
    # Data points assigned during the most recent assignment pass of `fit`.
    # Kept as a plain list of row vectors so callers can test it for emptiness
    # with `if not centroid.vectors` and stack it with `np.stack`.
    vectors: list = field(default_factory=list)


class KMeans:
    """Minimal k-means clustering over a fixed training set."""

    def __init__(self, k: int, X: np.ndarray, seed: Optional[int] = None) -> None:
        """Store the data and pick the starting centroids.

        Args:
            k: Number of clusters; must satisfy ``1 <= k <= n_samples``.
            X: Training data of shape ``(n_samples, n_features)``.
            seed: Seed for the random choice of starting centroids. ``None``
                draws fresh entropy, so results differ between runs.

        Raises:
            ValueError: If ``X`` is not a non-empty 2-D array or ``k`` is out
                of range.
        """
        points = np.asarray(X, dtype=float)
        if points.ndim != 2 or points.shape[0] == 0:
            raise ValueError(
                "X must be a non-empty 2-D array of shape (n_samples, n_features)"
            )
        n_samples = points.shape[0]
        if not 1 <= k <= n_samples:
            raise ValueError(f"k must be between 1 and n_samples={n_samples}, got {k}")

        self.k = k
        self.X = points

        # Start from k distinct rows of the data so every centroid begins
        # inside the data cloud instead of at an arbitrary location.
        rng = np.random.default_rng(seed)
        start_rows = rng.choice(n_samples, size=k, replace=False)
        self.centroids = [Centroid(location=points[row].copy()) for row in start_rows]

    def fit(self, iters: int, viz: bool = False) -> None:
        """Alternate assignment and centroid update for at most ``iters`` rounds.

        Each round re-assigns every point to its nearest centroid and then
        moves each non-empty centroid to the mean of its points. Centroid
        locations carry over from one round to the next; only the membership
        lists are rebuilt. The loop stops early once a round moves nothing.

        Args:
            iters: Maximum number of assignment/update rounds.
            viz: Accepted for interface compatibility; currently unused.
        """
        for _ in range(iters):
            locations = np.stack([centroid.location for centroid in self.centroids])
            # (n_samples, k) matrix of point-to-centroid distances.
            distances = np.linalg.norm(
                self.X[:, None, :] - locations[None, :, :], axis=2
            )
            # argmin resolves ties in favour of the lowest centroid index.
            nearest = distances.argmin(axis=1)

            moved = False
            for index, centroid in enumerate(self.centroids):
                members = self.X[nearest == index]
                centroid.vectors = list(members)
                if len(members) == 0:
                    # An empty cluster keeps its previous location.
                    continue
                new_location = members.mean(axis=0)
                if not np.array_equal(new_location, centroid.location):
                    moved = True
                centroid.location = new_location

            if not moved:
                # Assignments are stable, so further rounds would be no-ops.
                break

    def predict(self, X_i: np.ndarray) -> np.ndarray:
        """Return the location of the centroid closest to ``X_i``.

        Note that this is the centroid's coordinates, not an integer label.
        """
        distances = [k_dist(X_i, centroid.location) for centroid in self.centroids]
        return self.centroids[int(np.argmin(distances))].location


# %% [markdown]
# ## Demo: two Gaussian blobs

# %%
if __name__ == "__main__":  # also true inside a Jupyter kernel, so the demo runs there too
    # Demo-only dependencies are imported here so the classes above remain
    # importable without scikit-learn or matplotlib installed.
    import matplotlib.pyplot as plt
    from matplotlib import colormaps
    from sklearn.datasets import make_blobs

    SEED = 123
    N_CLUSTERS = 2  # matches the two blob centres generated below

    X_train, _ = make_blobs(
        n_samples=20, random_state=SEED, cluster_std=0.6, centers=[[1, 1], [3, 3]]
    )
    kmeans = KMeans(k=N_CLUSTERS, X=X_train, seed=SEED)
    kmeans.fit(5)
    print("centroid nearest to [1.2, 1.2]:", kmeans.predict([1.2, 1.2]))
    print("number of centroids:", len(kmeans.centroids))

    # One colour per centroid. 'hsv' wraps around to red at both ends, so
    # sample one extra colour and leave the last one unused.
    cmap = colormaps["hsv"].resampled(N_CLUSTERS + 1)
    cluster_colours = cmap(range(N_CLUSTERS))

    fig, ax = plt.subplots()
    for i, C in enumerate(kmeans.centroids):
        # Nothing was assigned to this centroid, so there is nothing to draw.
        if not C.vectors:
            continue
        # Turn the list of row vectors into an (N, 2) array.
        vecs = np.stack(C.vectors, axis=0)
        ax.scatter(
            vecs[:, 0],
            vecs[:, 1],
            color=cluster_colours[i],
            s=20,
            alpha=0.7,
            label=f"cluster {i}",
        )
        # Mark the centroid itself with a cross in the same colour.
        ax.scatter(
            *C.location, marker="x", s=80, linewidths=2, color=cluster_colours[i]
        )
    ax.set(title="K-means cluster assignments", xlabel="feature 0", ylabel="feature 1")
    ax.legend()
    plt.show()