import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix
import numpy as np


# Custom erf activation function
class ErfActivation(nn.Module):
    def forward(self, x):
        return torch.erf(x)


# Define network with erf activations
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(28 * 28, 128)
        self.act1 = ErfActivation()
        self.fc2 = nn.Linear(128, 64)
        self.act2 = ErfActivation()
        self.fc3 = nn.Linear(64, 10)  # output layer (no activation for CrossEntropy)

    def forward(self, x):
        x = x.view(-1, 28 * 28)
        x = self.act1(self.fc1(x))
        x = self.act2(self.fc2(x))
        x = self.fc3(x)
        return x


# Use MPS if available
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
print(f"Using device: {device}")

# Load MNIST data
transform = transforms.ToTensor()
train_data = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_data = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
train_loader = DataLoader(train_data, batch_size=64, shuffle=True)
test_loader = DataLoader(test_data, batch_size=1000, shuffle=False)

# Initialize model
model = Net().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# Training loop
epochs = 10
losses = []
accuracies = []
for epoch in range(epochs):
    model.train()
    running_loss = 0
    correct = 0
    total = 0
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        _, preds = torch.max(outputs, 1)
        correct += (preds == labels).sum().item()
        total += labels.size(0)
    avg_loss = running_loss / len(train_loader)
    acc = correct / total
    losses.append(avg_loss)
    accuracies.append(acc)
    print(f"Epoch {epoch + 1}: Loss = {avg_loss:.4f}, Accuracy = {acc:.4f}")

# Plot training curves
plt.figure(figsize=(10, 4))
plt.subplot(1, 2, 1)
plt.plot(losses, label="Loss")
plt.title("Training Loss")
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.grid(True)
plt.subplot(1, 2, 2)
plt.plot(accuracies, label="Accuracy")
plt.title("Training Accuracy")
plt.xlabel("Epoch")
plt.ylabel("Accuracy")
plt.grid(True)
plt.tight_layout()
plt.show()

# Evaluate on test set
model.eval()
all_preds = []
all_labels = []
with torch.no_grad():
    for images, labels in test_loader:
        images = images.to(device)
        outputs = model(images)
        preds = outputs.argmax(dim=1).cpu().numpy()
        all_preds.extend(preds)
        all_labels.extend(labels.numpy())

# Confusion matrix
cm = confusion_matrix(all_labels, all_preds)
plt.figure(figsize=(10, 8))
sns.heatmap(cm, annot=True, fmt="d", cmap="Blues")
plt.title("Confusion Matrix (Test Set)")
plt.xlabel("Predicted")
plt.ylabel("Actual")
plt.tight_layout()
plt.show()
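
# Optional follow-up: summarize overall test accuracy from the predictions
# collected above. A minimal sketch that reuses only `all_preds`, `all_labels`,
# and `np` as already defined in this script; nothing new is assumed.
test_acc = (np.array(all_preds) == np.array(all_labels)).mean()
print(f"Overall test accuracy: {test_acc:.4f}")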