"""
Generate a README-ready animation that VISUALLY proves the Universal Approximation Theorem.

* A tiny ReLU MLP learns to approximate **three** aesthetically-pleasing 3-D functions in
  succession. The network’s prediction surface (right) morphs until it matches the
  ground-truth surface (left), then the scene wipes to the next function.
* Output: `universal.gif` (512×288 px, 10 fps, loops forever). With the default settings
  below it holds 63 frames: 21 snapshots per function (epoch 1, then every 20th epoch
  up to epoch 400). Perfect for GitHub 😎.
* Dependencies: `pip install torch matplotlib imageio tqdm` (CPU-only is fine).

Usage
-----
```bash
python universal_approx_animation.py  # creates universal.gif in the cwd
```

Drop this in your repo and embed in `README.md` with:
```markdown
![UA-Theorem in action](universal.gif)
```
"""
from __future__ import annotations

import math
import os
import shutil
from pathlib import Path
from typing import Callable, List

import imageio.v3 as iio
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
from mpl_toolkits.mplot3d import Axes3D  # noqa: F401 (needed for 3-D projection)
from tqdm import tqdm

# ---------------------------- CONFIGURABLE PARAMS --------------------------- #
GIF_PATH = Path("universal.gif")
FRAME_DIR = Path("frames_tmp")  # scratch dir for PNG frames; emptied at start, removed at end
GRID_N = 50  # 50×50 grid -> 2 500 points per function
SNAPSHOT_EVERY = 20  # add a frame every N training steps
EPOCHS_PER_FUNC = 400  # -> 21 frames per function (epoch 1 + every SNAPSHOT_EVERY-th epoch)
FRAME_DURATION_MS = 100  # display time of each GIF frame (100 ms -> 10 fps)
DEVICE = torch.device("cpu")  # change to torch.device("cuda") if you want
SEED = 42

# Fancy but lightweight colour map for surfaces
CMAP_PRED = plt.cm.viridis
CMAP_TRUE = plt.cm.plasma
# --------------------------------------------------------------------------- #

torch.manual_seed(SEED)
np.random.seed(SEED)


# --------------------------- THE TEST FUNCTIONS ---------------------------- #
def mexican_hat(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Radially symmetric damped wave: sin(pi * r^2) / (1 + 5 * r^2), with r^2 = x^2 + y^2."""
    r2 = x ** 2 + y ** 2
    return torch.sin(r2 * math.pi) / (1 + 5 * r2)


def ripple(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Radial sinc ripple: sin(8r) / (8r), with r = sqrt(x^2 + y^2)."""
    # The tiny offset keeps the origin finite (sin(8r)/(8r) -> 1 as r -> 0).
    r = torch.sqrt(x ** 2 + y ** 2) + 1e-9
    return torch.sin(8 * r) / (8 * r)


def saddle(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Hyperbolic paraboloid: x^2 - y^2."""
    return x ** 2 - y ** 2


FUNCS: List[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = [
    mexican_hat,
    ripple,
    saddle,
]
FUNC_NAMES = ["Mexican Hat", "Rippling Sinc", "Hyperbolic Saddle"]


# --------------------------- THE NETWORK MODEL ----------------------------- #
def make_mlp(in_dim: int = 2, hidden: int = 64, depth: int = 2, out_dim: int = 1) -> nn.Module:
    """Build a ReLU MLP with ``depth`` hidden layers of width ``hidden``.

    Layout: ``[Linear -> ReLU] * depth -> Linear``. With ``depth == 0`` this degrades to
    a single ``Linear(in_dim, out_dim)``.
    """
    layers: List[nn.Module] = []
    width = in_dim  # input width of the next Linear layer
    for _ in range(depth):
        layers += [nn.Linear(width, hidden), nn.ReLU()]
        width = hidden
    layers.append(nn.Linear(width, out_dim))
    return nn.Sequential(*layers)


# ---------------------------- TRAIN + CAPTURE ------------------------------ #
def surface_plot(ax: Axes3D, X: np.ndarray, Y: np.ndarray, Z: np.ndarray, *, cmap, title: str):
    """Draw ``Z`` over the ``(X, Y)`` grid on ``ax`` with ticks hidden and a small title."""
    ax.plot_surface(X, Y, Z, cmap=cmap, rstride=1, cstride=1, linewidth=0, antialiased=True)
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_zticks([])
    ax.set_title(title, fontsize=10, pad=10)


def capture_frame(
    epoch: int,
    func_idx: int,
    X: np.ndarray,
    Y: np.ndarray,
    Z_true: np.ndarray,
    Z_pred: np.ndarray,
    loss: float,
) -> Path:
    """Render one side-by-side frame (truth left, MLP prediction right) into FRAME_DIR.

    Args:
        epoch: Training step within the current function, shown in the title.
        func_idx: Index into ``FUNC_NAMES``; also part of the file name.
        X, Y: Grid coordinates for both surfaces.
        Z_true: Ground-truth values on the grid.
        Z_pred: Network predictions on the grid.
        loss: Loss value shown in the prediction title.

    Returns:
        Path of the written PNG. Names are zero-padded (``frame_FF_EEEE.png``), so a
        lexicographic sort gives chronological order.
    """
    fig = plt.figure(figsize=(5.12, 2.88), dpi=100)  # 512×288 px
    gs = fig.add_gridspec(1, 2, wspace=0.05)

    # Truth on left, prediction on right
    ax1 = fig.add_subplot(gs[0, 0], projection="3d")
    surface_plot(ax1, X, Y, Z_true, cmap=CMAP_TRUE, title=f"Ground Truth: {FUNC_NAMES[func_idx]}")

    ax2 = fig.add_subplot(gs[0, 1], projection="3d")
    surface_plot(ax2, X, Y, Z_pred, cmap=CMAP_PRED, title=f"MLP Approx | epoch {epoch} | loss {loss:.3e}")

    fig.text(0.01, 0.01, "Aayush Bajaj — Universal Approximation Demo", fontsize=6, color="#666")

    frame_path = FRAME_DIR / f"frame_{func_idx:02d}_{epoch:04d}.png"
    # Deliberately NOT bbox_inches="tight": that crops every frame to its own content, so
    # frames with different title/loss text can differ in pixel size, and imageio cannot
    # stack unequal frames into one GIF. Saving the whole figure at a fixed dpi keeps
    # every frame exactly 512×288.
    fig.savefig(frame_path, dpi=100)
    plt.close(fig)
    return frame_path


# ---------------------------- MAIN WORKFLOW -------------------------------- #
def _reset_frame_dir() -> None:
    """Start from an empty FRAME_DIR so stale frames from an aborted run are never encoded."""
    shutil.rmtree(FRAME_DIR, ignore_errors=True)
    FRAME_DIR.mkdir(parents=True)


def _train_and_capture() -> None:
    """Fit a freshly re-initialised MLP to each function in FUNCS, saving snapshot frames."""
    # Build static grid once
    grid = torch.linspace(-1, 1, GRID_N)
    X_grid, Y_grid = torch.meshgrid(grid, grid, indexing="xy")
    # (GRID_N**2, 2) inputs, flattened row-major so reshape(GRID_N, GRID_N) maps
    # per-point outputs back onto the grid.
    XY = torch.stack([X_grid.reshape(-1), Y_grid.reshape(-1)], dim=1).to(DEVICE)
    X_np = X_grid.cpu().numpy()
    Y_np = Y_grid.cpu().numpy()

    model = make_mlp().to(DEVICE)
    criterion = nn.MSELoss()

    for f_idx, f in enumerate(FUNCS):
        # Targets are fixed for the whole fit: compute them once per function.
        with torch.no_grad():
            target = f(XY[:, 0], XY[:, 1])
        Z_true = target.cpu().numpy().reshape(GRID_N, GRID_N)

        # Reset network parameters between functions to avoid bias
        model.apply(lambda m: m.reset_parameters() if hasattr(m, "reset_parameters") else None)
        # Fresh optimiser so Adam's moment estimates don't leak across functions.
        opt = torch.optim.Adam(model.parameters(), lr=2e-2)

        for local_epoch in tqdm(range(1, EPOCHS_PER_FUNC + 1), desc=f"Training {FUNC_NAMES[f_idx]}"):
            model.train()
            preds = model(XY).squeeze(1)
            loss = criterion(preds, target)
            opt.zero_grad()
            loss.backward()
            opt.step()

            # Capture frame
            if local_epoch % SNAPSHOT_EVERY == 0 or local_epoch in (1, EPOCHS_PER_FUNC):
                # Re-evaluate after opt.step() so the drawn surface and the loss in its
                # title describe the same (updated) parameters.
                with torch.no_grad():
                    snap_preds = model(XY).squeeze(1)
                    snap_loss = criterion(snap_preds, target).item()
                capture_frame(
                    epoch=local_epoch,
                    func_idx=f_idx,
                    X=X_np,
                    Y=Y_np,
                    Z_true=Z_true,
                    Z_pred=snap_preds.cpu().numpy().reshape(GRID_N, GRID_N),
                    loss=snap_loss,
                )


def _encode_gif() -> None:
    """Stitch every PNG frame in FRAME_DIR, in file-name order, into GIF_PATH."""
    frames = [iio.imread(fp) for fp in sorted(FRAME_DIR.glob("frame_*.png"))]
    if not frames:
        raise RuntimeError("No frames were captured; check EPOCHS_PER_FUNC and FUNCS.")
    # duration is ms per frame; loop=0 repeats forever (Pillow otherwise plays once).
    iio.imwrite(GIF_PATH, frames, duration=FRAME_DURATION_MS, loop=0)


def main() -> None:
    """Train on every function, snapshot progress, and write the looping GIF to GIF_PATH."""
    _reset_frame_dir()
    try:
        _train_and_capture()
        # Assemble GIF
        print("\nEncoding GIF…")
        _encode_gif()
    finally:
        # Clean-up temp frames to keep repo tidy, even if training/encoding failed.
        shutil.rmtree(FRAME_DIR, ignore_errors=True)

    print(f"Done. ✨ {GIF_PATH} ready for your README!")


if __name__ == "__main__":
    main()