{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "88d4bc0d-a30e-4a74-9cd1-707bb10647e7",
   "metadata": {},
   "source": [
    "# Fast Fourier Transform (FFT) - Complete Guide\n",
    "\n",
    "## Table of Contents\n",
    "1. [Introduction](#introduction)\n",
    "2. [Mathematical Foundation](#mathematical-foundation)\n",
    "3. [The DFT vs FFT](#the-dft-vs-fft)\n",
    "4. [Cooley-Tukey Algorithm](#cooley-tukey-algorithm)\n",
    "5. [Implementation from Scratch](#implementation-from-scratch)\n",
    "6. [Practical Examples](#practical-examples)\n",
    "7. [Applications](#applications)\n",
    "8. [Advanced Topics](#advanced-topics)\n",
    "\n",
    "## Introduction\n",
    "\n",
    "The Fast Fourier Transform (FFT) is one of the most important algorithms in computational science. It efficiently computes the Discrete Fourier Transform (DFT) and its inverse, reducing the cost for a sequence of length N from O(N²) to O(N log N) operations.\n",
    "\n",
    "**Key Applications:**\n",
    "- Signal processing and filtering\n",
    "- Image and audio compression\n",
    "- Solving partial differential equations\n",
    "- Fast polynomial multiplication\n",
    "- Digital communications\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "from scipy.fft import fft, ifft\n",
    "import time\n",
    "\n",
    "# Set up plotting\n",
    "plt.style.use('seaborn-v0_8')\n",
    "plt.rcParams['figure.figsize'] = (12, 8)\n",
    "```\n",
    "\n",
    "## Mathematical Foundation\n",
    "\n",
    "### The Discrete Fourier Transform (DFT)\n",
    "\n",
    "For a sequence of N complex numbers $x_0, x_1, \\dots, x_{N-1}$, the DFT is defined as:\n",
    "\n",
    "$$X_k = \\sum_{n=0}^{N-1} x_n \\, e^{-2\\pi i k n / N}, \\qquad k = 0, 1, \\dots, N-1$$\n",
    "\n",
    "where:\n",
    "- $X_k$ is the $k$-th frequency component\n",
    "- $x_n$ is the $n$-th time-domain sample\n",
    "- $e^{-2\\pi i k n / N}$ is the complex exponential (twiddle factor)\n",
    "\n",
    "```python\n",
    "def naive_dft(x):\n",
    "    \"\"\"\n",
    "    Compute the DFT using the naive O(N²) algorithm\n",
    "    \"\"\"\n",
    "    N = len(x)\n",
    "    X = np.zeros(N, dtype=complex)\n",
    "\n",
    "    for k in range(N):\n",
    "        for n in range(N):\n",
    "            X[k] += x[n] * np.exp(-2j * np.pi * k * n / N)\n",
    "\n",
    "    return X\n",
    "\n",
    "# Example with a simple signal: 8 samples over 1 s (Nyquist frequency = 4 Hz)\n",
    "t = np.linspace(0, 1, 8, endpoint=False)\n",
    "# 3 Hz instead of 4 Hz: a 4 Hz sine sampled here lands only on its zero crossings and vanishes\n",
    "x = np.sin(2 * np.pi * 2 * t) + 0.5 * np.sin(2 * np.pi * 3 * t)\n",
    "\n",
    "print(\"Input signal:\", x)\n",
    "print(\"DFT result:\", naive_dft(x))\n",
    "```\n",
    "\n",
    "### Complex Exponentials and Twiddle Factors\n",
    "\n",
    "The key insight of the FFT is that the twiddle factors $W_N^{kn} = e^{-2\\pi i k n / N}$, which are powers of the N-th root of unity $W_N = e^{-2\\pi i / N}$, have special symmetry properties:\n",
    "- Periodicity: $W_N^{k+N} = W_N^{k}$\n",
    "- Half-period symmetry: $W_N^{k+N/2} = -W_N^{k}$\n",
    "- Halving: $W_N^{2k} = W_{N/2}^{k}$, so the sum over the even-indexed samples of an N-point DFT is itself an N/2-point DFT\n",
    "\n",
    "```python\n",
    "def plot_twiddle_factors():\n",
    "    \"\"\"Visualize twiddle factors on the unit circle\"\"\"\n",
    "    N = 8\n",
    "    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6))\n",
    "\n",
    "    # Unit circle\n",
    "    theta = np.linspace(0, 2*np.pi, 100)\n",
    "    ax1.plot(np.cos(theta), np.sin(theta), 'k--', alpha=0.3)\n",
    "\n",
    "    # Twiddle factors for N=8\n",
    "    for k in range(N):\n",
    "        w = np.exp(-2j * np.pi * k / N)\n",
    "        ax1.plot(w.real, w.imag, 'ro', markersize=10)\n",
    "        ax1.annotate(f'W₈^{k}', (w.real, w.imag), xytext=(5, 5),\n",
    "                     textcoords='offset points', fontsize=12)\n",
    "\n",
    "    ax1.set_xlim(-1.5, 1.5)\n",
    "    ax1.set_ylim(-1.5, 1.5)\n",
    "    ax1.set_aspect('equal')\n",
    "    ax1.grid(True, alpha=0.3)\n",
    "    ax1.set_title('Twiddle Factors on Unit Circle (N=8)')\n",
    "    ax1.set_xlabel('Real Part')\n",
    "    ax1.set_ylabel('Imaginary Part')\n",
    "\n",
    "    # Magnitude of each twiddle factor (all equal to 1: they lie on the unit circle)\n",
    "    k_values = np.arange(N)\n",
    "    w_values = np.exp(-2j * np.pi * k_values / N)\n",
    "\n",
    "    ax2.stem(k_values, np.abs(w_values), basefmt=' ')\n",
    "    ax2.set_title('Magnitude of Twiddle Factors')\n",
    "    ax2.set_xlabel('k')\n",
    "    ax2.set_ylabel('|W₈^k|')\n",
    "    ax2.grid(True, alpha=0.3)\n",
    "\n",
    "    plt.tight_layout()\n",
    "    plt.show()\n",
    "\n",
    "plot_twiddle_factors()\n",
    "```\n",
    "\n",
    "## The DFT vs FFT\n",
    "\n",
    "### Computational Complexity Comparison\n",
    "\n",
    "```python\n",
    "def complexity_comparison():\n",
    "    \"\"\"Compare DFT and FFT computational complexity\"\"\"\n",
    "    sizes = 2 ** np.arange(4, 12)  # Powers of 2 from 16 to 2048\n",
    "    dft_times = []\n",
    "    fft_times = []\n",
    "\n",
    "    for N in sizes:\n",
    "        # Generate random signal\n",
    "        x = np.random.random(N) + 1j * np.random.random(N)\n",
    "\n",
    "        # Time naive DFT (only for smaller sizes); perf_counter is a high-resolution clock\n",
    "        if N <= 256:\n",
    "            start = time.perf_counter()\n",
    "            _ = naive_dft(x)\n",
    "            dft_times.append(time.perf_counter() - start)\n",
    "        else:\n",
    "            dft_times.append(np.nan)\n",
    "\n",
    "        # Time FFT (a single run, so timings for small N are noisy)\n",
    "        start = time.perf_counter()\n",
    "        _ = fft(x)\n",
    "        fft_times.append(time.perf_counter() - start)\n",
    "\n",
    "    # Plot results\n",
    "    plt.figure(figsize=(12, 6))\n",
    "\n",
    "    plt.subplot(1, 2, 1)\n",
    "    valid_mask = ~np.isnan(dft_times)\n",
    "    plt.loglog(sizes[valid_mask], np.array(dft_times)[valid_mask], 'ro-', label='Naive DFT O(n²)')\n",
    "    plt.loglog(sizes, fft_times, 'bo-', label='FFT O(n log n)')\n",
    "    plt.xlabel('Signal Length (N)')\n",
    "    plt.ylabel('Time (seconds)')\n",
    "    plt.title('Execution Time Comparison')\n",
    "    plt.legend()\n",
    "    plt.grid(True, alpha=0.3)\n",
    "\n",
    "    plt.subplot(1, 2, 2)\n",
    "    # Scale each theoretical curve so it matches the first measured time\n",
    "    theoretical_dft = sizes**2 / (sizes[0]**2) * dft_times[0] if not np.isnan(dft_times[0]) else None\n",
    "    theoretical_fft = sizes * np.log2(sizes) / (sizes[0] * np.log2(sizes[0])) * fft_times[0]\n",
    "\n",
    "    if theoretical_dft is not None:\n",
    "        plt.loglog(sizes, theoretical_dft, 'r--', alpha=0.7, label='Theoretical O(n²)')\n",
    "    plt.loglog(sizes, theoretical_fft, 'b--', alpha=0.7, label='Theoretical O(n log n)')\n",
    "    plt.loglog(sizes, fft_times, 'bo-', label='Actual FFT')\n",
    "    plt.xlabel('Signal Length (N)')\n",
    "    plt.ylabel('Time (seconds)')\n",
    "    plt.title('Theoretical vs Actual Complexity')\n",
    "    plt.legend()\n",
    "    plt.grid(True, alpha=0.3)\n",
    "\n",
    "    plt.tight_layout()\n",
    "    plt.show()\n",
    "\n",
    "complexity_comparison()\n",
    "```\n",
    "\n",
    "## Cooley-Tukey Algorithm\n",
    "\n",
    "The most common FFT algorithm, radix-2 Cooley-Tukey, uses a divide-and-conquer approach. For N a power of 2, it splits an N-point DFT into two N/2-point DFTs, $E_k$ over the even-indexed samples and $O_k$ over the odd-indexed samples, and combines them with the butterfly\n",
    "\n",
    "$$X_k = E_k + W_N^{k} O_k, \\qquad X_{k+N/2} = E_k - W_N^{k} O_k, \\qquad k = 0, 1, \\dots, N/2-1$$\n",
    "\n",
    "Applying the split recursively gives $\\log_2 N$ levels of O(N) work each, hence O(N log N) in total.\n",
    "\n",
    "### Decimation-in-Time (DIT) Approach\n",
    "\n",
    "```python\n",
    "def fft_recursive(x):\n",
    "    \"\"\"\n",
    "    Recursive implementation of the radix-2 Cooley-Tukey FFT.\n",
    "    Input length must be a power of 2.\n",
    "    \"\"\"\n",
    "    x = np.asarray(x, dtype=complex)\n",
    "    N = len(x)\n",
    "\n",
    "    # Base case: the DFT of a single sample is the sample itself\n",
    "    if N == 1:\n",
    "        return x\n",
    "\n",
    "    # Divide\n",
    "    even = fft_recursive(x[0::2])  # Even-indexed elements\n",
    "    odd = fft_recursive(x[1::2])   # Odd-indexed elements\n",
    "\n",
    "    # Combine (butterfly) with twiddle factors W_N^k, k = 0 .. N/2-1\n",
    "    T = np.exp(-2j * np.pi * np.arange(N // 2) / N) * odd\n",
    "\n",
    "    return np.concatenate([even + T, even - T])\n",
    "\n",
    "def bit_reverse_permutation(x):\n",
    "    \"\"\"\n",
    "    Rearrange array elements in place according to bit-reversed indices\n",
    "    \"\"\"\n",
    "    N = len(x)\n",
    "    j = 0\n",
    "    for i in range(1, N):\n",
    "        bit = N >> 1\n",
    "        while j & bit:\n",
    "            j ^= bit\n",
    "            bit >>= 1\n",
    "        j ^= bit\n",
    "        if j > i:\n",
    "            x[i], x[j] = x[j], x[i]\n",
    "    return x\n",
    "\n",
    "def fft_iterative(x):\n",
    "    \"\"\"\n",
    "    Iterative implementation of the radix-2 Cooley-Tukey FFT.\n",
    "    Avoids recursion and updates a single array in place.\n",
    "    Input length must be a power of 2.\n",
    "    \"\"\"\n",
    "    # Complex copy: leaves the original untouched and lets real input hold complex results\n",
    "    x = np.array(x, dtype=complex)\n",
    "    N = len(x)\n",
    "\n",
    "    # Bit-reverse permutation\n",
    "    x = bit_reverse_permutation(x)\n",
    "\n",
    "    # Iterative FFT: merge sub-DFTs of size length/2 into DFTs of size length\n",
    "    length = 2\n",
    "    while length <= N:\n",
    "        # Principal twiddle factor for this stage\n",
    "        w = np.exp(-2j * np.pi / length)\n",
    "\n",
    "        for i in range(0, N, length):\n",
    "            wn = 1\n",
    "            for j in range(length // 2):\n",
    "                u = x[i + j]\n",
    "                v = x[i + j + length // 2] * wn\n",
    "                x[i + j] = u + v\n",
    "                x[i + j + length // 2] = u - v\n",
    "                wn *= w\n",
    "\n",
    "        length *= 2\n",
    "\n",
    "    return x\n",
    "\n",
    "# Test our implementations\n",
    "N = 16\n",
    "t = np.linspace(0, 1, N, endpoint=False)\n",
    "test_signal = np.sin(2 * np.pi * 2 * t) + 0.5 * np.cos(2 * np.pi * 4 * t)\n",
    "\n",
    "print(\"Testing FFT implementations:\")\n",
    "print(\"SciPy FFT:     \", np.abs(fft(test_signal)))\n",
    "print(\"Recursive FFT: \", np.abs(fft_recursive(test_signal)))\n",
    "print(\"Iterative FFT: \", np.abs(fft_iterative(test_signal)))\n",
    "print(\"Max difference:\", np.max(np.abs(fft(test_signal) - fft_iterative(test_signal))))\n",
    "```\n",
    "\n",
    "### Visualizing the FFT Output\n",
    "\n",
    "```python\n",
    "def visualize_fft_stages():\n",
    "    \"\"\"Plot a rectangular pulse and the magnitude and phase of its FFT\"\"\"\n",
    "    N = 8\n",
    "    x = np.array([1, 1, 1, 1, 0, 0, 0, 0], dtype=complex)  # Simple rectangular pulse\n",
    "\n",
    "    fig, axes = plt.subplots(3, 1, figsize=(15, 12))\n",
    "\n",
    "    # Original signal\n",
    "    axes[0].stem(range(N), x.real, basefmt=' ')\n",
    "    axes[0].set_title('Original Signal (Time Domain)')\n",
    "    axes[0].set_xlabel('Sample Index')\n",
    "    axes[0].set_ylabel('Amplitude')\n",
    "    axes[0].grid(True, alpha=0.3)\n",
    "\n",
    "    # FFT result\n",
    "    X = fft(x)\n",
    "    axes[1].stem(range(N), np.abs(X), basefmt=' ')\n",
    "    axes[1].set_title('FFT Result (Frequency Domain)')\n",
    "    axes[1].set_xlabel('Frequency Bin')\n",
    "    axes[1].set_ylabel('Magnitude')\n",
    "    axes[1].grid(True, alpha=0.3)\n",
    "\n",
    "    # Phase (bins whose magnitude is ~0 can show an arbitrary, round-off-driven phase)\n",
    "    axes[2].stem(range(N), np.angle(X), basefmt=' ')\n",
    "    axes[2].set_title('Phase Spectrum')\n",
    "    axes[2].set_xlabel('Frequency Bin')\n",
    "    axes[2].set_ylabel('Phase (radians)')\n",
    "    axes[2].grid(True, alpha=0.3)\n",
    "\n",
    "    plt.tight_layout()\n",
    "    plt.show()\n",
    "\n",
    "visualize_fft_stages()\n",
    "```\n",
    "\n",
    "## Implementation from Scratch\n",
    "\n",
    "### Complete FFT Implementation with Error Checking\n",
    "\n",
    "Inputs whose length is not a power of 2 are zero-padded, so the result is the DFT of the padded signal: it samples the spectrum at a different set of frequencies than the N-point DFT of the original.\n",
    "\n",
    "```python\n",
    "class FFT:\n",
    "    \"\"\"\n",
    "    Radix-2 FFT utilities built on fft_iterative: zero-padding, inverse FFT and 2D FFT\n",
    "    \"\"\"\n",
    "\n",
    "    @staticmethod\n",
    "    def is_power_of_2(n):\n",
    "        \"\"\"Check if n is a power of 2\"\"\"\n",
    "        return n > 0 and (n & (n - 1)) == 0\n",
    "\n",
    "    @staticmethod\n",
    "    def next_power_of_2(n):\n",
    "        \"\"\"Find the smallest power of 2 greater than or equal to n\"\"\"\n",
    "        return 1 << (n - 1).bit_length()\n",
    "\n",
    "    @staticmethod\n",
    "    def zero_pad_to_power_of_2(x):\n",
    "        \"\"\"Zero-pad signal to next power of 2 length\"\"\"\n",
    "        N = len(x)\n",
    "        next_pow2 = FFT.next_power_of_2(N)\n",
    "        if next_pow2 > N:\n",
    "            x_padded = np.zeros(next_pow2, dtype=complex)\n",
    "            x_padded[:N] = x\n",
    "            return x_padded\n",
    "        return x\n",
    "\n",
    "    @staticmethod\n",
    "    def cooley_tukey_fft(x):\n",
    "        \"\"\"\n",
    "        Cooley-Tukey FFT with automatic zero-padding to a power-of-2 length\n",
    "        \"\"\"\n",
    "        x = np.array(x, dtype=complex)\n",
    "\n",
    "        # Zero-pad to power of 2 if necessary\n",
    "        if not FFT.is_power_of_2(len(x)):\n",
    "            x = FFT.zero_pad_to_power_of_2(x)\n",
    "\n",
    "        return fft_iterative(x)\n",
    "\n",
    "    @staticmethod\n",
    "    def inverse_fft(X):\n",
    "        \"\"\"\n",
    "        Compute the inverse FFT via ifft(X) = conj(fft(conj(X))) / N.\n",
    "        X must have a power-of-2 length (as returned by cooley_tukey_fft).\n",
    "        \"\"\"\n",
    "        X = np.asarray(X, dtype=complex)\n",
    "        N = len(X)\n",
    "        if not FFT.is_power_of_2(N):\n",
    "            raise ValueError(\"inverse_fft requires a power-of-2 length\")\n",
    "        # Conjugate, apply FFT, conjugate again, and scale\n",
    "        return np.conj(FFT.cooley_tukey_fft(np.conj(X))) / N\n",
    "\n",
    "    @staticmethod\n",
    "    def fft_2d(image):\n",
    "        \"\"\"\n",
    "        2D FFT for image processing (each dimension is zero-padded to a power of 2)\n",
    "        \"\"\"\n",
    "        # Apply 1D FFT to each row\n",
    "        rows_fft = np.array([FFT.cooley_tukey_fft(row) for row in image])\n",
    "        # Apply 1D FFT to each column\n",
    "        return np.array([FFT.cooley_tukey_fft(col) for col in rows_fft.T]).T\n",
    "\n",
    "# Test the complete implementation\n",
    "print(\"Testing complete FFT implementation:\")\n",
    "test_data = np.random.random(15) + 1j * np.random.random(15)  # Non-power of 2\n",
    "our_fft = FFT.cooley_tukey_fft(test_data)\n",
    "scipy_fft = fft(test_data, n=FFT.next_power_of_2(len(test_data)))  # Same zero-padding (n=16)\n",
    "print(f\"Max difference: {np.max(np.abs(our_fft - scipy_fft)):.2e}\")\n",
    "# The inverse of the padded spectrum recovers the padded signal; its first 15 samples are test_data\n",
    "print(f\"Round-trip error: {np.max(np.abs(FFT.inverse_fft(our_fft)[:15] - test_data)):.2e}\")\n",
    "```\n",
    "\n",
    "## Practical Examples\n",
    "\n",
    "### 1. Signal Filtering\n",
    "\n",
    "```python\n",
    "def signal_filtering_example():\n",
    "    \"\"\"Demonstrate signal filtering using FFT\"\"\"\n",
    "    # Create a noisy signal: 1000 samples over 1 s (sampling rate 1000 Hz)\n",
    "    t = np.linspace(0, 1, 1000, endpoint=False)\n",
    "    clean_signal = np.sin(2 * np.pi * 50 * t) + np.sin(2 * np.pi * 120 * t)\n",
    "    noise = 0.5 * np.random.randn(len(t))\n",
    "    noisy_signal = clean_signal + noise\n",
    "\n",
    "    # Compute FFT\n",
    "    X = fft(noisy_signal)\n",
    "    freqs = np.fft.fftfreq(len(t), t[1] - t[0])\n",
    "\n",
    "    # Design a simple (ideal) low-pass filter\n",
    "    cutoff = 150  # Hz, above both signal components (50 Hz and 120 Hz)\n",
    "    X_filtered = X.copy()\n",
    "    X_filtered[np.abs(freqs) > cutoff] = 0  # Zero out high frequencies\n",
    "\n",
    "    # Inverse FFT to get filtered signal\n",
    "    filtered_signal = np.real(ifft(X_filtered))\n",
    "\n",
    "    # Plot results\n",
    "    fig, axes = plt.subplots(2, 2, figsize=(15, 10))\n",
    "\n",
    "    # Time domain signals\n",
    "    axes[0, 0].plot(t[:200], clean_signal[:200], 'g-', label='Clean', linewidth=2)\n",
    "    axes[0, 0].plot(t[:200], noisy_signal[:200], 'r-', alpha=0.7, label='Noisy')\n",
    "    axes[0, 0].plot(t[:200], filtered_signal[:200], 'b--', label='Filtered', linewidth=2)\n",
    "    axes[0, 0].set_title('Time Domain Signals')\n",
    "    axes[0, 0].set_xlabel('Time (s)')\n",
    "    axes[0, 0].set_ylabel('Amplitude')\n",
    "    axes[0, 0].legend()\n",
    "    axes[0, 0].grid(True, alpha=0.3)\n",
    "\n",
    "    # Frequency domain - original\n",
    "    axes[0, 1].semilogy(freqs[:len(freqs)//2], np.abs(X[:len(X)//2]))\n",
    "    axes[0, 1].axvline(cutoff, color='r', linestyle='--', label=f'Cutoff: {cutoff} Hz')\n",
    "    axes[0, 1].set_title('Original Spectrum')\n",
    "    axes[0, 1].set_xlabel('Frequency (Hz)')\n",
    "    axes[0, 1].set_ylabel('Magnitude')\n",
    "    axes[0, 1].legend()\n",
    "    axes[0, 1].grid(True, alpha=0.3)\n",
    "\n",
    "    # Frequency domain - filtered\n",
    "    axes[1, 0].semilogy(freqs[:len(freqs)//2], np.abs(X_filtered[:len(X_filtered)//2]))\n",
    "    axes[1, 0].axvline(cutoff, color='r', linestyle='--', label=f'Cutoff: {cutoff} Hz')\n",
    "    axes[1, 0].set_title('Filtered Spectrum')\n",
    "    axes[1, 0].set_xlabel('Frequency (Hz)')\n",
    "    axes[1, 0].set_ylabel('Magnitude')\n",
    "    axes[1, 0].legend()\n",
    "    axes[1, 0].grid(True, alpha=0.3)\n",
    "\n",
    "    # Error analysis\n",
    "    error = np.abs(clean_signal - filtered_signal)\n",
    "    axes[1, 1].plot(t[:200], error[:200])\n",
    "    axes[1, 1].set_title(f'Reconstruction Error (RMS: {np.sqrt(np.mean(error**2)):.3f})')\n",
    "    axes[1, 1].set_xlabel('Time (s)')\n",
    "    axes[1, 1].set_ylabel('Error')\n",
    "    axes[1, 1].grid(True, alpha=0.3)\n",
    "\n",
    "    plt.tight_layout()\n",
    "    plt.show()\n",
    "\n",
    "signal_filtering_example()\n",
    "```\n",
    "\n",
    "### 2. Image Processing with 2D FFT\n",
    "\n",
    "```python\n",
    "def image_processing_example():\n",
    "    \"\"\"Demonstrate 2D FFT for image processing\"\"\"\n",
    "    # Create a simple test image\n",
    "    x = np.linspace(-5, 5, 128)\n",
    "    y = np.linspace(-5, 5, 128)\n",
    "    X, Y = np.meshgrid(x, y)\n",
    "\n",
    "    # Create a pattern with different frequency components\n",
    "    image = (np.sin(X) * np.cos(Y) + 0.5 * np.sin(4*X) * np.sin(4*Y) +\n",
    "             0.25 * np.sin(8*X) * np.cos(2*Y))\n",
    "\n",
    "    # Add noise\n",
    "    noisy_image = image + 0.3 * np.random.randn(*image.shape)\n",
    "\n",
    "    # Compute 2D FFT\n",
    "    F = np.fft.fft2(noisy_image)\n",
    "    F_shifted = np.fft.fftshift(F)  # Shift zero frequency to center\n",
    "\n",
    "    # Create a circular low-pass filter\n",
    "    rows, cols = image.shape\n",
    "    crow, ccol = rows//2, cols//2\n",
    "    r = 30  # Filter radius (in frequency bins)\n",
    "\n",
    "    # Create mask\n",
    "    y_mask, x_mask = np.ogrid[:rows, :cols]\n",
    "    mask = (x_mask - ccol)**2 + (y_mask - crow)**2 <= r**2\n",
    "\n",
    "    # Apply filter\n",
    "    F_filtered = F_shifted.copy()\n",
    "    F_filtered[~mask] = 0\n",
    "\n",
    "    # Inverse FFT\n",
    "    filtered_image = np.real(np.fft.ifft2(np.fft.ifftshift(F_filtered)))\n",
    "\n",
    "    # Plot results\n",
    "    fig, axes = plt.subplots(2, 3, figsize=(18, 12))\n",
    "\n",
    "    # Original and noisy images\n",
    "    im1 = axes[0, 0].imshow(image, cmap='gray')\n",
    "    axes[0, 0].set_title('Original Image')\n",
    "    axes[0, 0].axis('off')\n",
    "    plt.colorbar(im1, ax=axes[0, 0])\n",
    "\n",
    "    im2 = axes[0, 1].imshow(noisy_image, cmap='gray')\n",
    "    axes[0, 1].set_title('Noisy Image')\n",
    "    axes[0, 1].axis('off')\n",
    "    plt.colorbar(im2, ax=axes[0, 1])\n",
    "\n",
    "    im3 = axes[0, 2].imshow(filtered_image, cmap='gray')\n",
    "    axes[0, 2].set_title('Filtered Image')\n",
    "    axes[0, 2].axis('off')\n",
    "    plt.colorbar(im3, ax=axes[0, 2])\n",
    "\n",
    "    # Frequency domain representations\n",
    "    im4 = axes[1, 0].imshow(np.log(1 + np.abs(F_shifted)), cmap='hot')\n",
    "    axes[1, 0].set_title('Original Spectrum (log scale)')\n",
    "    axes[1, 0].axis('off')\n",
    "    plt.colorbar(im4, ax=axes[1, 0])\n",
    "\n",
    "    im5 = axes[1, 1].imshow(mask, cmap='gray')\n",
    "    axes[1, 1].set_title('Low-pass Filter Mask')\n",
    "    axes[1, 1].axis('off')\n",
    "    plt.colorbar(im5, ax=axes[1, 1])\n",
    "\n",
    "    im6 = axes[1, 2].imshow(np.log(1 + np.abs(F_filtered)), cmap='hot')\n",
    "    axes[1, 2].set_title('Filtered Spectrum (log scale)')\n",
    "    axes[1, 2].axis('off')\n",
    "    plt.colorbar(im6, ax=axes[1, 2])\n",
    "\n",
    "    plt.tight_layout()\n",
    "    plt.show()\n",
    "\n",
    "image_processing_example()\n",
    "```\n",
    "\n",
    "### 3. Fast Convolution\n",
    "\n",
    "By the convolution theorem, convolving two sequences is equivalent to multiplying their DFTs pointwise. Because the DFT implies circular convolution, both inputs are zero-padded to at least N + M - 1 samples (N = signal length, M = kernel length) so that the result equals the linear convolution.\n",
    "\n",
    "```python\n",
    "def convolution_example():\n",
    "    \"\"\"Demonstrate fast convolution using FFT\"\"\"\n",
    "    # Create signals\n",
    "    N = 1000\n",
    "    t = np.linspace(0, 10, N)\n",
    "\n",
    "    # Input signal: sum of sinusoids\n",
    "    signal = np.sin(2 * np.pi * 1 * t) + 0.5 * np.sin(2 * np.pi * 3 * t)\n",
    "\n",
    "    # Filter kernel: Gaussian\n",
    "    kernel_size = 101\n",
    "    kernel = np.exp(-0.5 * ((np.arange(kernel_size) - kernel_size//2) / 10)**2)\n",
    "    kernel = kernel / np.sum(kernel)  # Normalize\n",
    "\n",
    "    # Direct convolution: O(N * M) for signal length N and kernel length M\n",
    "    start_time = time.perf_counter()\n",
    "    conv_direct = np.convolve(signal, kernel, mode='same')\n",
    "    direct_time = time.perf_counter() - start_time\n",
    "\n",
    "    # FFT-based convolution: O((N + M) log(N + M))\n",
    "    start_time = time.perf_counter()\n",
    "    # Zero-pad both signals to avoid circular convolution artifacts\n",
    "    n_conv = len(signal) + len(kernel) - 1\n",
    "    n_fft = 2 ** int(np.ceil(np.log2(n_conv)))\n",
    "\n",
    "    signal_padded = np.zeros(n_fft)\n",
    "    kernel_padded = np.zeros(n_fft)\n",
    "    signal_padded[:len(signal)] = signal\n",
    "    kernel_padded[:len(kernel)] = kernel\n",
    "\n",
    "    # Convolution in frequency domain\n",
    "    conv_fft_full = np.real(ifft(fft(signal_padded) * fft(kernel_padded)))\n",
    "    # Extract the 'same' portion (centered, like np.convolve with mode='same')\n",
    "    start_idx = len(kernel) // 2\n",
    "    conv_fft = conv_fft_full[start_idx:start_idx + len(signal)]\n",
    "    fft_time = time.perf_counter() - start_time\n",
    "\n",
    "    # Plot results\n",
    "    fig, axes = plt.subplots(2, 2, figsize=(15, 10))\n",
    "\n",
    "    # Original signal and kernel\n",
    "    axes[0, 0].plot(t, signal, 'b-', label='Original Signal')\n",
    "    axes[0, 0].set_title('Input Signal')\n",
    "    axes[0, 0].set_xlabel('Time')\n",
    "    axes[0, 0].set_ylabel('Amplitude')\n",
    "    axes[0, 0].grid(True, alpha=0.3)\n",
    "\n",
    "    axes[0, 1].plot(kernel, 'r-', linewidth=2)\n",
    "    axes[0, 1].set_title('Convolution Kernel (Gaussian)')\n",
    "    axes[0, 1].set_xlabel('Sample')\n",
    "    axes[0, 1].set_ylabel('Amplitude')\n",
    "    axes[0, 1].grid(True, alpha=0.3)\n",
    "\n",
    "    # Convolution results\n",
    "    axes[1, 0].plot(t, conv_direct, 'g-', label='Direct Convolution', linewidth=2)\n",
    "    axes[1, 0].plot(t, conv_fft, 'r--', label='FFT Convolution', alpha=0.8)\n",
    "    axes[1, 0].set_title(f'Convolution Results\\nDirect: {direct_time:.4f}s, FFT: {fft_time:.4f}s')\n",
    "    axes[1, 0].set_xlabel('Time')\n",
    "    axes[1, 0].set_ylabel('Amplitude')\n",
    "    axes[1, 0].legend()\n",
    "    axes[1, 0].grid(True, alpha=0.3)\n",
    "\n",
    "    # Error analysis\n",
    "    error = np.abs(conv_direct - conv_fft)\n",
    "    axes[1, 1].plot(t, error)\n",
    "    axes[1, 1].set_title(f'Absolute Error (Max: {np.max(error):.2e})')\n",
    "    axes[1, 1].set_xlabel('Time')\n",
    "    axes[1, 1].set_ylabel('Error')\n",
    "    axes[1, 1].grid(True, alpha=0.3)\n",
    "\n",
    "    plt.tight_layout()\n",
    "    plt.show()\n",
    "\n",
    "    # For short kernels the direct method can win; values above 1 mean the FFT version was faster\n",
    "    print(f\"Direct / FFT time ratio: {direct_time/fft_time:.1f}\")\n",
    "\n",
    "convolution_example()\n",
    "```\n",
    "\n",
    "## Applications\n",
    "\n",
    "### Polynomial Multiplication\n",
    "\n",
    "The coefficients of a product of two polynomials are the linear convolution of their coefficient lists, so the same pad, transform, multiply and invert recipe multiplies polynomials of degree n in O(n log n) time instead of O(n²).\n",
    "\n",
    "```python\n",
    "def polynomial_multiplication():\n",
    "    \"\"\"Demonstrate fast polynomial multiplication using FFT\"\"\"\n",
    "    # Define two polynomials\n",
    "    # p(x) = 1 + 2x + 3x² + 4x³\n",
    "    # q(x) = 2 + 3x + x²\n",
    "\n",
    "    p = [1, 2, 3, 4]  # Coefficients in ascending order\n",
    "    q = [2, 3, 1]\n",
    "\n",
    "    # Direct multiplication\n",
    "    def multiply_polynomials_direct(p, q):\n",
    "        result = [0] * (len(p) + len(q) - 1)\n",
    "        for i in range(len(p)):\n",
    "            for j in range(len(q)):\n",
    "                result[i + j] += p[i] * q[j]\n",
    "        return result\n",
    "\n",
    "    # FFT-based multiplication\n",
    "    def multiply_polynomials_fft(p, q):\n",
    "        n = len(p) + len(q) - 1\n",
    "        n_fft = 2 ** int(np.ceil(np.log2(n)))\n",
    "\n",
    "        # Zero-pad\n",
    "        p_padded = p + [0] * (n_fft - len(p))\n",
    "        q_padded = q + [0] * (n_fft - len(q))\n",
    "\n",
    "        # FFT, multiply, inverse FFT\n",
    "        p_fft = fft(p_padded)\n",
    "        q_fft = fft(q_padded)\n",
    "        result_fft = p_fft * q_fft\n",
    "        result = np.real(ifft(result_fft))\n",
    "\n",
    "        return result[:n]\n",
    "\n",
    "    # Compare results\n",
    "    direct_result = multiply_polynomials_direct(p, q)\n",
    "    fft_result = multiply_polynomials_fft(p, q)\n",
    "\n",
    "    print(\"Polynomial Multiplication:\")\n",
    "    print(f\"p(x) coefficients: {p}\")\n",
    "    print(f\"q(x) coefficients: {q}\")\n",
    "    print(f\"Direct result: {direct_result}\")\n",
    "    print(f\"FFT result: {[round(x) for x in fft_result]}\")\n",
    "    print(f\"Max difference: {np.max(np.abs(np.array(direct_result) - fft_result[:len(direct_result)])):.2e}\")\n",
    "\n",
    "polynomial_multiplication()\n",
    "```\n",
    "\n",
    "## Advanced Topics\n",
    "\n",
    "### Window Functions and Spectral Leakage\n",
    "\n",
    "An FFT of N samples treats the record as one period of a periodic signal. A frequency that does not complete a whole number of cycles in the record therefore spreads (leaks) energy into neighbouring bins; tapering the record with a window function reduces this leakage at the cost of a wider main lobe.\n",
    "\n",
    "```python\n",
    "def window_functions_demo():\n",
    "    \"\"\"Demonstrate the effect of window functions on spectral analysis\"\"\"\n",
    "    # Create a signal with two close frequencies: N samples over 1 s, so bins are 1 Hz apart\n",
    "    N = 512\n",
    "    t = np.linspace(0, 1, N, endpoint=False)\n",
    "    # 55.5 Hz falls between two FFT bins, so it leaks when no window is applied\n",
    "    f1, f2 = 50, 55.5  # Close frequencies\n",
    "    signal = np.sin(2 * np.pi * f1 * t) + np.sin(2 * np.pi * f2 * t)\n",
    "\n",
    "    # Define window functions\n",
    "    windows = {\n",
    "        'Rectangular': np.ones(N),\n",
    "        'Hanning': np.hanning(N),\n",
    "        'Hamming': np.hamming(N),\n",
    "        'Blackman': np.blackman(N)\n",
    "    }\n",
    "\n",
    "    fig, axes = plt.subplots(2, 2, figsize=(15, 10))\n",
    "    axes = axes.flatten()\n",
    "\n",
    "    # Sample spacing is 1/N s, so bin k corresponds to k Hz\n",
    "    freqs = np.fft.fftfreq(N, 1/N)\n",
    "    freq_range = (freqs >= 0) & (freqs <= 100)\n",
    "\n",
    "    for i, (name, window) in enumerate(windows.items()):\n",
    "        # Apply window\n",
    "        windowed_signal = signal * window\n",
    "\n",
    "        # Compute FFT\n",
    "        spectrum = fft(windowed_signal)\n",
    "        magnitude = np.abs(spectrum)\n",
    "\n",
    "        # Plot\n",
    "        axes[i].plot(freqs[freq_range], magnitude[freq_range])\n",
    "        axes[i].set_title(f'{name} Window')\n",
    "        axes[i].set_xlabel('Frequency (Hz)')\n",
    "        axes[i].set_ylabel('Magnitude')\n",
    "        axes[i].axvline(f1, color='r', linestyle='--', alpha=0.7, label=f'f₁={f1} Hz')\n",
    "        axes[i].axvline(f2, color='r', linestyle='--', alpha=0.7, label=f'f₂={f2} Hz')\n",
    "        axes[i].legend()\n",
    "        axes[i].grid(True, alpha=0.3)\n",
    "\n",
    "    plt.tight_layout()\n",
    "    plt.show()\n",
    "\n",
    "window_functions_demo()\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dae15ff8-3dd3-4ad8-a88d-93d33ca454e1",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [conda env:metal]",
   "language": "python",
   "name": "conda-env-metal-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}