DATA_HUB = dict() DATA_URL = 'http://d2l-data.s3-accelerate.amazonaws.com/' import jax import flax from jax import numpy as jnp from flax import linen as nn import random get_seed = lambda: random.randint(0, 1e6) get_key = lambda: jax.random.PRNGKey(get_seed()) nn_Module = nn.Module ################# WARNING ################ # The below part is generated automatically through: # d2lbook build lib # Don't edit it directly import collections import hashlib import inspect import math import os import random import re import shutil import sys import tarfile import time import zipfile from collections import defaultdict import pandas as pd import requests from IPython import display from matplotlib import pyplot as plt from matplotlib_inline import backend_inline d2l = sys.modules[__name__] from dataclasses import field from functools import partial from types import FunctionType from typing import Any import flax import jax import numpy as np import optax import tensorflow as tf import tensorflow_datasets as tfds from flax import linen as nn from flax.training import train_state from jax import grad from jax import numpy as jnp from jax import vmap def use_svg_display(): """Use the svg format to display a plot in Jupyter. Defined in :numref:`sec_calculus`""" backend_inline.set_matplotlib_formats('svg') def set_figsize(figsize=(3.5, 2.5)): """Set the figure size for matplotlib. Defined in :numref:`sec_calculus`""" use_svg_display() d2l.plt.rcParams['figure.figsize'] = figsize def set_axes(axes, xlabel, ylabel, xlim, ylim, xscale, yscale, legend): """Set the axes for matplotlib. Defined in :numref:`sec_calculus`""" axes.set_xlabel(xlabel), axes.set_ylabel(ylabel) axes.set_xscale(xscale), axes.set_yscale(yscale) axes.set_xlim(xlim), axes.set_ylim(ylim) if legend: axes.legend(legend) axes.grid() def plot(X, Y=None, xlabel=None, ylabel=None, legend=[], xlim=None, ylim=None, xscale='linear', yscale='linear', fmts=('-', 'm--', 'g-.', 'r:'), figsize=(3.5, 2.5), axes=None): """Plot data points. Defined in :numref:`sec_calculus`""" def has_one_axis(X): # True if X (tensor or list) has 1 axis return (hasattr(X, "ndim") and X.ndim == 1 or isinstance(X, list) and not hasattr(X[0], "__len__")) if has_one_axis(X): X = [X] if Y is None: X, Y = [[]] * len(X), X elif has_one_axis(Y): Y = [Y] if len(X) != len(Y): X = X * len(Y) set_figsize(figsize) if axes is None: axes = d2l.plt.gca() axes.cla() for x, y, fmt in zip(X, Y, fmts): axes.plot(x,y,fmt) if len(x) else axes.plot(y,fmt) set_axes(axes, xlabel, ylabel, xlim, ylim, xscale, yscale, legend) def add_to_class(Class): """Register functions as methods in created class. Defined in :numref:`sec_oo-design`""" def wrapper(obj): setattr(Class, obj.__name__, obj) return wrapper class HyperParameters: """The base class of hyperparameters.""" def save_hyperparameters(self, ignore=[]): """Defined in :numref:`sec_oo-design`""" raise NotImplemented def save_hyperparameters(self, ignore=[]): """Save function arguments into class attributes. Defined in :numref:`sec_utils`""" frame = inspect.currentframe().f_back _, _, _, local_vars = inspect.getargvalues(frame) self.hparams = {k:v for k, v in local_vars.items() if k not in set(ignore+['self']) and not k.startswith('_')} for k, v in self.hparams.items(): setattr(self, k, v) class ProgressBoard(d2l.HyperParameters): """The board that plots data points in animation. 
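
    A minimal usage sketch (illustrative only; assumes a Jupyter session in
    which matplotlib figures can be displayed):

        board = d2l.ProgressBoard(xlabel='x', ylabel='y')
        for x in range(100):
            board.draw(x, x ** 0.5, 'sqrt(x)', every_n=10)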
    Defined in :numref:`sec_oo-design`"""
    def __init__(self, xlabel=None, ylabel=None, xlim=None, ylim=None,
                 xscale='linear', yscale='linear',
                 ls=['-', '--', '-.', ':'], colors=['C0', 'C1', 'C2', 'C3'],
                 fig=None, axes=None, figsize=(3.5, 2.5), display=True):
        self.save_hyperparameters()

    def draw(self, x, y, label, every_n=1):
        raise NotImplementedError

    def draw(self, x, y, label, every_n=1):
        """Defined in :numref:`sec_utils`"""
        Point = collections.namedtuple('Point', ['x', 'y'])
        if not hasattr(self, 'raw_points'):
            self.raw_points = collections.OrderedDict()
            self.data = collections.OrderedDict()
        if label not in self.raw_points:
            self.raw_points[label] = []
            self.data[label] = []
        points = self.raw_points[label]
        line = self.data[label]
        points.append(Point(x, y))
        if len(points) != every_n:
            return
        mean = lambda x: sum(x) / len(x)
        line.append(Point(mean([p.x for p in points]),
                          mean([p.y for p in points])))
        points.clear()
        if not self.display:
            return
        d2l.use_svg_display()
        if self.fig is None:
            self.fig = d2l.plt.figure(figsize=self.figsize)
        plt_lines, labels = [], []
        for (k, v), ls, color in zip(self.data.items(), self.ls, self.colors):
            plt_lines.append(d2l.plt.plot([p.x for p in v], [p.y for p in v],
                                          linestyle=ls, color=color)[0])
            labels.append(k)
        axes = self.axes if self.axes else d2l.plt.gca()
        if self.xlim: axes.set_xlim(self.xlim)
        if self.ylim: axes.set_ylim(self.ylim)
        if not self.xlabel: self.xlabel = self.x
        axes.set_xlabel(self.xlabel)
        axes.set_ylabel(self.ylabel)
        axes.set_xscale(self.xscale)
        axes.set_yscale(self.yscale)
        axes.legend(plt_lines, labels)
        display.display(self.fig)
        display.clear_output(wait=True)

class Module(d2l.nn_Module, d2l.HyperParameters):
    """The base class of models.

    Defined in :numref:`sec_oo-design`"""
    # No need for save_hyperparam when using Python dataclass
    plot_train_per_epoch: int = field(default=2, init=False)
    plot_valid_per_epoch: int = field(default=1, init=False)
    # Use default_factory to make sure new plots are generated on each run
    board: ProgressBoard = field(default_factory=lambda: ProgressBoard(),
                                 init=False)

    def loss(self, y_hat, y):
        raise NotImplementedError

    # JAX & Flax do not have a forward-method-like syntax. Flax uses setup
    # and built-in __call__ magic methods for forward pass.
Adding here # for consistency def forward(self, X, *args, **kwargs): assert hasattr(self, 'net'), 'Neural network is defined' return self.net(X, *args, **kwargs) def __call__(self, X, *args, **kwargs): return self.forward(X, *args, **kwargs) def plot(self, key, value, train): """Plot a point in animation.""" assert hasattr(self, 'trainer'), 'Trainer is not inited' self.board.xlabel = 'epoch' if train: x = self.trainer.train_batch_idx / \ self.trainer.num_train_batches n = self.trainer.num_train_batches / \ self.plot_train_per_epoch else: x = self.trainer.epoch + 1 n = self.trainer.num_val_batches / \ self.plot_valid_per_epoch self.board.draw(x, d2l.to(value, d2l.cpu()), ('train_' if train else 'val_') + key, every_n=int(n)) def training_step(self, params, batch, state): l, grads = jax.value_and_grad(self.loss)(params, batch[:-1], batch[-1], state) self.plot("loss", l, train=True) return l, grads def validation_step(self, params, batch, state): l = self.loss(params, batch[:-1], batch[-1], state) self.plot('loss', l, train=False) def apply_init(self, dummy_input, key): """To be defined later in :numref:`sec_lazy_init`""" raise NotImplementedError def configure_optimizers(self): raise NotImplementedError def configure_optimizers(self): """Defined in :numref:`sec_classification`""" return optax.sgd(self.lr) def apply_init(self, dummy_input, key): """Defined in :numref:`sec_lazy_init`""" params = self.init(key, *dummy_input) # dummy_input tuple unpacked return params class DataModule(d2l.HyperParameters): """The base class of data. Defined in :numref:`subsec_oo-design-models`""" def __init__(self, root='../data'): self.save_hyperparameters() def get_dataloader(self, train): raise NotImplementedError def train_dataloader(self): return self.get_dataloader(train=True) def val_dataloader(self): return self.get_dataloader(train=False) def get_tensorloader(self, tensors, train, indices=slice(0, None)): """Defined in :numref:`sec_synthetic-regression-data`""" tensors = tuple(a[indices] for a in tensors) # Use Tensorflow Datasets & Dataloader. JAX or Flax do not provide # any dataloading functionality shuffle_buffer = tensors[0].shape[0] if train else 1 return tfds.as_numpy( tf.data.Dataset.from_tensor_slices(tensors).shuffle( buffer_size=shuffle_buffer).batch(self.batch_size)) class Trainer(d2l.HyperParameters): """The base class for training models with data. 
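
    A typical usage sketch (illustrative; mirrors the linear regression
    chapters of the book):

        model = d2l.LinearRegression(lr=0.03)
        data = d2l.SyntheticRegressionData(w=jnp.array([2, -3.4]), b=4.2)
        trainer = d2l.Trainer(max_epochs=3)
        trainer.fit(model, data)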
Defined in :numref:`subsec_oo-design-models`""" def __init__(self, max_epochs, num_gpus=0, gradient_clip_val=0): self.save_hyperparameters() assert num_gpus == 0, 'No GPU support yet' def prepare_data(self, data): self.train_dataloader = data.train_dataloader() self.val_dataloader = data.val_dataloader() self.num_train_batches = len(self.train_dataloader) self.num_val_batches = (len(self.val_dataloader) if self.val_dataloader is not None else 0) def prepare_model(self, model): model.trainer = self model.board.xlim = [0, self.max_epochs] self.model = model def fit(self, model, data, key=None): self.prepare_data(data) self.prepare_model(model) self.optim = model.configure_optimizers() if key is None: root_key = d2l.get_key() else: root_key = key params_key, dropout_key = jax.random.split(root_key) key = {'params': params_key, 'dropout': dropout_key} dummy_input = next(iter(self.train_dataloader))[:-1] variables = model.apply_init(dummy_input, key=key) params = variables['params'] if 'batch_stats' in variables.keys(): # Here batch_stats will be used later (e.g., for batch norm) batch_stats = variables['batch_stats'] else: batch_stats = {} # Flax uses optax under the hood for a single state obj TrainState. # More will be discussed later in the dropout and batch # normalization section class TrainState(train_state.TrainState): batch_stats: Any dropout_rng: jax.random.PRNGKeyArray self.state = TrainState.create(apply_fn=model.apply, params=params, batch_stats=batch_stats, dropout_rng=dropout_key, tx=model.configure_optimizers()) self.epoch = 0 self.train_batch_idx = 0 self.val_batch_idx = 0 for self.epoch in range(self.max_epochs): self.fit_epoch() def fit_epoch(self): raise NotImplementedError def prepare_batch(self, batch): """Defined in :numref:`sec_linear_scratch`""" return batch def fit_epoch(self): """Defined in :numref:`sec_linear_scratch`""" self.model.training = True if self.state.batch_stats: # Mutable states will be used later (e.g., for batch norm) for batch in self.train_dataloader: (_, mutated_vars), grads = self.model.training_step(self.state.params, self.prepare_batch(batch), self.state) self.state = self.state.apply_gradients(grads=grads) # Can be ignored for models without Dropout Layers self.state = self.state.replace( dropout_rng=jax.random.split(self.state.dropout_rng)[0]) self.state = self.state.replace(batch_stats=mutated_vars['batch_stats']) self.train_batch_idx += 1 else: for batch in self.train_dataloader: _, grads = self.model.training_step(self.state.params, self.prepare_batch(batch), self.state) self.state = self.state.apply_gradients(grads=grads) # Can be ignored for models without Dropout Layers self.state = self.state.replace( dropout_rng=jax.random.split(self.state.dropout_rng)[0]) self.train_batch_idx += 1 if self.val_dataloader is None: return self.model.training = False for batch in self.val_dataloader: self.model.validation_step(self.state.params, self.prepare_batch(batch), self.state) self.val_batch_idx += 1 def __init__(self, max_epochs, num_gpus=0, gradient_clip_val=0): """Defined in :numref:`sec_use_gpu`""" self.save_hyperparameters() self.gpus = [d2l.gpu(i) for i in range(min(num_gpus, d2l.num_gpus()))] def prepare_batch(self, batch): """Defined in :numref:`sec_use_gpu`""" if self.gpus: batch = [d2l.to(a, self.gpus[0]) for a in batch] return batch def clip_gradients(self, grad_clip_val, grads): """Defined in :numref:`sec_rnn-scratch`""" grad_leaves, _ = jax.tree_util.tree_flatten(grads) norm = jnp.sqrt(sum(jnp.vdot(x, x) for x in grad_leaves)) clip = 
lambda grad: jnp.where(norm < grad_clip_val, grad, grad * (grad_clip_val / norm)) return jax.tree_util.tree_map(clip, grads) class SyntheticRegressionData(d2l.DataModule): """Synthetic data for linear regression. Defined in :numref:`sec_synthetic-regression-data`""" def __init__(self, w, b, noise=0.01, num_train=1000, num_val=1000, batch_size=32): super().__init__() self.save_hyperparameters() n = num_train + num_val key = jax.random.PRNGKey(0) key1, key2 = jax.random.split(key) self.X = jax.random.normal(key1, (n, w.shape[0])) noise = jax.random.normal(key2, (n, 1)) * noise self.y = d2l.matmul(self.X, d2l.reshape(w, (-1, 1))) + b + noise def get_dataloader(self, train): """Defined in :numref:`sec_synthetic-regression-data`""" i = slice(0, self.num_train) if train else slice(self.num_train, None) return self.get_tensorloader((self.X, self.y), train, i) class LinearRegressionScratch(d2l.Module): """The linear regression model implemented from scratch. Defined in :numref:`sec_linear_scratch`""" num_inputs: int lr: float sigma: float = 0.01 def setup(self): self.w = self.param('w', nn.initializers.normal(self.sigma), (self.num_inputs, 1)) self.b = self.param('b', nn.initializers.zeros, (1)) def forward(self, X): """Defined in :numref:`sec_linear_scratch`""" return d2l.matmul(X, self.w) + self.b def loss(self, params, X, y, state): """Defined in :numref:`sec_linear_scratch`""" y_hat = state.apply_fn({'params': params}, *X) # X unpacked from a tuple l = (y_hat - d2l.reshape(y, y_hat.shape)) ** 2 / 2 return d2l.reduce_mean(l) def configure_optimizers(self): """Defined in :numref:`sec_linear_scratch`""" return SGD(self.lr) class SGD(d2l.HyperParameters): """Minibatch stochastic gradient descent. Defined in :numref:`sec_linear_scratch`""" # The key transformation of Optax is the GradientTransformation # defined by two methods, the init and the update. # The init initializes the state and the update transforms the gradients. # https://github.com/deepmind/optax/blob/master/optax/_src/transform.py def __init__(self, lr): self.save_hyperparameters() def init(self, params): # Delete unused params del params return optax.EmptyState def update(self, updates, state, params=None): del params # When state.apply_gradients method is called to update flax's # train_state object, it internally calls optax.apply_updates method # adding the params to the update equation defined below. updates = jax.tree_util.tree_map(lambda g: -self.lr * g, updates) return updates, state def __call__(): return optax.GradientTransformation(self.init, self.update) class LinearRegression(d2l.Module): """The linear regression model implemented with high-level APIs. Defined in :numref:`sec_linear_concise`""" lr: float def setup(self): self.net = nn.Dense(1, kernel_init=nn.initializers.normal(0.01)) def forward(self, X): """Defined in :numref:`sec_linear_concise`""" return self.net(X) def loss(self, params, X, y, state): """Defined in :numref:`sec_linear_concise`""" y_hat = state.apply_fn({'params': params}, *X) return d2l.reduce_mean(optax.l2_loss(y_hat, y)) def configure_optimizers(self): """Defined in :numref:`sec_linear_concise`""" return optax.sgd(self.lr) def get_w_b(self, state): """Defined in :numref:`sec_linear_concise`""" net = state.params['net'] return net['kernel'], net['bias'] class FashionMNIST(d2l.DataModule): """The Fashion-MNIST dataset. 
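
    A small usage sketch (illustrative; downloads the dataset through
    tf.keras on first use):

        data = d2l.FashionMNIST(batch_size=64, resize=(32, 32))
        X, y = next(iter(data.train_dataloader()))
        # X: (64, 32, 32, 1) float batch, y: (64,) int32 labels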
Defined in :numref:`sec_fashion_mnist`""" def __init__(self, batch_size=64, resize=(28, 28)): super().__init__() self.save_hyperparameters() self.train, self.val = tf.keras.datasets.fashion_mnist.load_data() def text_labels(self, indices): """Return text labels. Defined in :numref:`sec_fashion_mnist`""" labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat', 'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot'] return [labels[int(i)] for i in indices] def get_dataloader(self, train): """Defined in :numref:`sec_fashion_mnist`""" data = self.train if train else self.val process = lambda X, y: (tf.expand_dims(X, axis=3) / 255, tf.cast(y, dtype='int32')) resize_fn = lambda X, y: (tf.image.resize_with_pad(X, *self.resize), y) shuffle_buf = len(data[0]) if train else 1 return tfds.as_numpy( tf.data.Dataset.from_tensor_slices(process(*data)).batch( self.batch_size).map(resize_fn).shuffle(shuffle_buf)) def visualize(self, batch, nrows=1, ncols=8, labels=[]): """Defined in :numref:`sec_fashion_mnist`""" X, y = batch if not labels: labels = self.text_labels(y) d2l.show_images(jnp.squeeze(X), nrows, ncols, titles=labels) def show_images(imgs, num_rows, num_cols, titles=None, scale=1.5): """Plot a list of images. Defined in :numref:`sec_fashion_mnist`""" raise NotImplementedError class Classifier(d2l.Module): """The base class of classification models. Defined in :numref:`sec_classification`""" def training_step(self, params, batch, state): # Here value is a tuple since models with BatchNorm layers require # the loss to return auxiliary data value, grads = jax.value_and_grad( self.loss, has_aux=True)(params, batch[:-1], batch[-1], state) l, _ = value self.plot("loss", l, train=True) return value, grads def validation_step(self, params, batch, state): # Discard the second returned value. It is used for training models # with BatchNorm layers since loss also returns auxiliary data l, _ = self.loss(params, batch[:-1], batch[-1], state) self.plot('loss', l, train=False) self.plot('acc', self.accuracy(params, batch[:-1], batch[-1], state), train=False) @partial(jax.jit, static_argnums=(0, 5)) def accuracy(self, params, X, Y, state, averaged=True): """Compute the number of correct predictions. 
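
        An illustrative sketch of the comparison performed below, written
        with the d2l aliases and independent of any trained model:

            Y_hat = jnp.array([[0.1, 0.9], [0.8, 0.2]])
            Y = jnp.array([1, 1])
            preds = d2l.astype(d2l.argmax(Y_hat, axis=1), Y.dtype)   # [1, 0]
            d2l.reduce_mean(d2l.astype(preds == Y, d2l.float32))     # 0.5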
Defined in :numref:`sec_classification`""" Y_hat = state.apply_fn({'params': params, 'batch_stats': state.batch_stats}, # BatchNorm Only *X) Y_hat = d2l.reshape(Y_hat, (-1, Y_hat.shape[-1])) preds = d2l.astype(d2l.argmax(Y_hat, axis=1), Y.dtype) compare = d2l.astype(preds == d2l.reshape(Y, -1), d2l.float32) return d2l.reduce_mean(compare) if averaged else compare @partial(jax.jit, static_argnums=(0, 5)) def loss(self, params, X, Y, state, averaged=True): """Defined in :numref:`sec_softmax_concise`""" # To be used later (e.g., for batch norm) Y_hat = state.apply_fn({'params': params}, *X, mutable=False, rngs=None) Y_hat = d2l.reshape(Y_hat, (-1, Y_hat.shape[-1])) Y = d2l.reshape(Y, (-1,)) fn = optax.softmax_cross_entropy_with_integer_labels # The returned empty dictionary is a placeholder for auxiliary data, # which will be used later (e.g., for batch norm) return (fn(Y_hat, Y).mean(), {}) if averaged else (fn(Y_hat, Y), {}) @partial(jax.jit, static_argnums=(0, 5)) def loss(self, params, X, Y, state, averaged=True): """Defined in :numref:`sec_dropout`""" Y_hat = state.apply_fn({'params': params}, *X, mutable=False, # To be used later (e.g., batch norm) rngs={'dropout': state.dropout_rng}) Y_hat = d2l.reshape(Y_hat, (-1, Y_hat.shape[-1])) Y = d2l.reshape(Y, (-1,)) fn = optax.softmax_cross_entropy_with_integer_labels # The returned empty dictionary is a placeholder for auxiliary data, # which will be used later (e.g., for batch norm) return (fn(Y_hat, Y).mean(), {}) if averaged else (fn(Y_hat, Y), {}) def layer_summary(self, X_shape, key=d2l.get_key()): """Defined in :numref:`sec_lenet`""" X = jnp.zeros(X_shape) params = self.init(key, X) bound_model = self.clone().bind(params, mutable=['batch_stats']) _ = bound_model(X) for layer in bound_model.net.layers: X = layer(X) print(layer.__class__.__name__, 'output shape:\t', X.shape) @partial(jax.jit, static_argnums=(0, 5)) def loss(self, params, X, Y, state, averaged=True): """Defined in :numref:`subsec_layer-normalization-in-bn`""" Y_hat, updates = state.apply_fn({'params': params, 'batch_stats': state.batch_stats}, *X, mutable=['batch_stats'], rngs={'dropout': state.dropout_rng}) Y_hat = d2l.reshape(Y_hat, (-1, Y_hat.shape[-1])) Y = d2l.reshape(Y, (-1,)) fn = optax.softmax_cross_entropy_with_integer_labels return (fn(Y_hat, Y).mean(), updates) if averaged else (fn(Y_hat, Y), updates) class SoftmaxRegression(d2l.Classifier): """Defined in :numref:`sec_softmax_concise`""" num_outputs: int lr: float @nn.compact def __call__(self, X): X = X.reshape((X.shape[0], -1)) # Flatten X = nn.Dense(self.num_outputs)(X) return X def cpu(): """Get the CPU device. Defined in :numref:`sec_use_gpu`""" return jax.devices('cpu')[0] def gpu(i=0): """Get a GPU device. Defined in :numref:`sec_use_gpu`""" return jax.devices('gpu')[i] def num_gpus(): """Get the number of available GPUs. Defined in :numref:`sec_use_gpu`""" try: return jax.device_count('gpu') except: return 0 # No GPU backend found def try_gpu(i=0): """Return gpu(i) if exists, otherwise return cpu(). Defined in :numref:`sec_use_gpu`""" if num_gpus() >= i + 1: return gpu(i) return cpu() def try_all_gpus(): """Return all available GPUs, or [cpu(),] if no GPU exists. Defined in :numref:`sec_use_gpu`""" return [gpu(i) for i in range(num_gpus())] def corr2d(X, K): """Compute 2D cross-correlation. 
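
    A worked example (illustrative):

        X = jnp.array([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0], [6.0, 7.0, 8.0]])
        K = jnp.array([[0.0, 1.0], [2.0, 3.0]])
        corr2d(X, K)
        # Array([[19., 25.],
        #        [37., 43.]], dtype=float32)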
Defined in :numref:`sec_conv_layer`""" h, w = K.shape Y = jnp.zeros((X.shape[0] - h + 1, X.shape[1] - w + 1)) for i in range(Y.shape[0]): for j in range(Y.shape[1]): Y = Y.at[i, j].set((X[i:i + h, j:j + w] * K).sum()) return Y class LeNet(d2l.Classifier): """The LeNet-5 model. Defined in :numref:`sec_lenet`""" lr: float = 0.1 num_classes: int = 10 kernel_init: FunctionType = nn.initializers.xavier_uniform def setup(self): self.net = nn.Sequential([ nn.Conv(features=6, kernel_size=(5, 5), padding='SAME', kernel_init=self.kernel_init()), nn.sigmoid, lambda x: nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2)), nn.Conv(features=16, kernel_size=(5, 5), padding='VALID', kernel_init=self.kernel_init()), nn.sigmoid, lambda x: nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2)), lambda x: x.reshape((x.shape[0], -1)), # flatten nn.Dense(features=120, kernel_init=self.kernel_init()), nn.sigmoid, nn.Dense(features=84, kernel_init=self.kernel_init()), nn.sigmoid, nn.Dense(features=self.num_classes, kernel_init=self.kernel_init()) ]) class Residual(nn.Module): """The Residual block of ResNet models. Defined in :numref:`sec_resnet`""" num_channels: int use_1x1conv: bool = False strides: tuple = (1, 1) training: bool = True def setup(self): self.conv1 = nn.Conv(self.num_channels, kernel_size=(3, 3), padding='same', strides=self.strides) self.conv2 = nn.Conv(self.num_channels, kernel_size=(3, 3), padding='same') if self.use_1x1conv: self.conv3 = nn.Conv(self.num_channels, kernel_size=(1, 1), strides=self.strides) else: self.conv3 = None self.bn1 = nn.BatchNorm(not self.training) self.bn2 = nn.BatchNorm(not self.training) def __call__(self, X): Y = nn.relu(self.bn1(self.conv1(X))) Y = self.bn2(self.conv2(Y)) if self.conv3: X = self.conv3(X) Y += X return nn.relu(Y) class ResNeXtBlock(nn.Module): """The ResNeXt block. Defined in :numref:`subsec_residual-blks`""" num_channels: int groups: int bot_mul: int use_1x1conv: bool = False strides: tuple = (1, 1) training: bool = True def setup(self): bot_channels = int(round(self.num_channels * self.bot_mul)) self.conv1 = nn.Conv(bot_channels, kernel_size=(1, 1), strides=(1, 1)) self.conv2 = nn.Conv(bot_channels, kernel_size=(3, 3), strides=self.strides, padding='same', feature_group_count=bot_channels//self.groups) self.conv3 = nn.Conv(self.num_channels, kernel_size=(1, 1), strides=(1, 1)) self.bn1 = nn.BatchNorm(not self.training) self.bn2 = nn.BatchNorm(not self.training) self.bn3 = nn.BatchNorm(not self.training) if self.use_1x1conv: self.conv4 = nn.Conv(self.num_channels, kernel_size=(1, 1), strides=self.strides) self.bn4 = nn.BatchNorm(not self.training) else: self.conv4 = None def __call__(self, X): Y = nn.relu(self.bn1(self.conv1(X))) Y = nn.relu(self.bn2(self.conv2(Y))) Y = self.bn3(self.conv3(Y)) if self.conv4: X = self.bn4(self.conv4(X)) return nn.relu(Y + X) class TimeMachine(d2l.DataModule): """The Time Machine dataset. 
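
    A small usage sketch (illustrative; downloads the raw text on first use):

        data = d2l.TimeMachine(batch_size=2, num_steps=10)
        X, Y = next(iter(data.train_dataloader()))
        # X, Y: integer arrays of shape (2, 10); Y is X shifted by one token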
Defined in :numref:`sec_text-sequence`""" def _download(self): fname = d2l.download(d2l.DATA_URL + 'timemachine.txt', self.root, '090b5e7e70c295757f55df93cb0a180b9691891a') with open(fname) as f: return f.read() def _preprocess(self, text): """Defined in :numref:`sec_text-sequence`""" return re.sub('[^A-Za-z]+', ' ', text).lower() def _tokenize(self, text): """Defined in :numref:`sec_text-sequence`""" return list(text) def build(self, raw_text, vocab=None): """Defined in :numref:`sec_text-sequence`""" tokens = self._tokenize(self._preprocess(raw_text)) if vocab is None: vocab = Vocab(tokens) corpus = [vocab[token] for token in tokens] return corpus, vocab def __init__(self, batch_size, num_steps, num_train=10000, num_val=5000): """Defined in :numref:`sec_language-model`""" super(d2l.TimeMachine, self).__init__() self.save_hyperparameters() corpus, self.vocab = self.build(self._download()) array = d2l.tensor([corpus[i:i+num_steps+1] for i in range(len(corpus)-num_steps)]) self.X, self.Y = array[:,:-1], array[:,1:] def get_dataloader(self, train): """Defined in :numref:`subsec_partitioning-seqs`""" idx = slice(0, self.num_train) if train else slice( self.num_train, self.num_train + self.num_val) return self.get_tensorloader([self.X, self.Y], train, idx) class Vocab: """Vocabulary for text.""" def __init__(self, tokens=[], min_freq=0, reserved_tokens=[]): """Defined in :numref:`sec_text-sequence`""" # Flatten a 2D list if needed if tokens and isinstance(tokens[0], list): tokens = [token for line in tokens for token in line] # Count token frequencies counter = collections.Counter(tokens) self.token_freqs = sorted(counter.items(), key=lambda x: x[1], reverse=True) # The list of unique tokens self.idx_to_token = list(sorted(set([''] + reserved_tokens + [ token for token, freq in self.token_freqs if freq >= min_freq]))) self.token_to_idx = {token: idx for idx, token in enumerate(self.idx_to_token)} def __len__(self): return len(self.idx_to_token) def __getitem__(self, tokens): if not isinstance(tokens, (list, tuple)): return self.token_to_idx.get(tokens, self.unk) return [self.__getitem__(token) for token in tokens] def to_tokens(self, indices): if hasattr(indices, '__len__') and len(indices) > 1: return [self.idx_to_token[int(index)] for index in indices] return self.idx_to_token[indices] @property def unk(self): # Index for the unknown token return self.token_to_idx[''] class RNNScratch(nn.Module): """The RNN model implemented from scratch. Defined in :numref:`sec_rnn-scratch`""" num_inputs: int num_hiddens: int sigma: float = 0.01 def setup(self): self.W_xh = self.param('W_xh', nn.initializers.normal(self.sigma), (self.num_inputs, self.num_hiddens)) self.W_hh = self.param('W_hh', nn.initializers.normal(self.sigma), (self.num_hiddens, self.num_hiddens)) self.b_h = self.param('b_h', nn.initializers.zeros, (self.num_hiddens)) def __call__(self, inputs, state=None): """Defined in :numref:`sec_rnn-scratch`""" if state is not None: state, = state outputs = [] for X in inputs: # Shape of inputs: (num_steps, batch_size, num_inputs) state = d2l.tanh(d2l.matmul(X, self.W_xh) + ( d2l.matmul(state, self.W_hh) if state is not None else 0) + self.b_h) outputs.append(state) return outputs, state def check_len(a, n): """Check the length of a list. Defined in :numref:`sec_rnn-scratch`""" assert len(a) == n, f'list\'s length {len(a)} != expected length {n}' def check_shape(a, shape): """Check the shape of a tensor. 
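
    Example (illustrative):

        check_shape(jnp.zeros((2, 3)), (2, 3))  # passes silently
        check_shape(jnp.zeros((3, 2)), (2, 3))  # raises AssertionError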
Defined in :numref:`sec_rnn-scratch`""" assert a.shape == shape, \ f'tensor\'s shape {a.shape} != expected shape {shape}' class RNNLMScratch(d2l.Classifier): """The RNN-based language model implemented from scratch. Defined in :numref:`sec_rnn-scratch`""" rnn: nn.Module vocab_size: int lr: float = 0.01 def setup(self): self.W_hq = self.param('W_hq', nn.initializers.normal(self.rnn.sigma), (self.rnn.num_hiddens, self.vocab_size)) self.b_q = self.param('b_q', nn.initializers.zeros, (self.vocab_size)) def training_step(self, params, batch, state): value, grads = jax.value_and_grad( self.loss, has_aux=True)(params, batch[:-1], batch[-1], state) l, _ = value self.plot('ppl', d2l.exp(l), train=True) return value, grads def validation_step(self, params, batch, state): l, _ = self.loss(params, batch[:-1], batch[-1], state) self.plot('ppl', d2l.exp(l), train=False) def one_hot(self, X): """Defined in :numref:`sec_rnn-scratch`""" # Output shape: (num_steps, batch_size, vocab_size) return jax.nn.one_hot(X.T, self.vocab_size) def output_layer(self, rnn_outputs): """Defined in :numref:`sec_rnn-scratch`""" outputs = [d2l.matmul(H, self.W_hq) + self.b_q for H in rnn_outputs] return d2l.stack(outputs, 1) def forward(self, X, state=None): """Defined in :numref:`sec_rnn-scratch`""" embs = self.one_hot(X) rnn_outputs, _ = self.rnn(embs, state) return self.output_layer(rnn_outputs) def predict(self, prefix, num_preds, vocab, params): """Defined in :numref:`sec_rnn-scratch`""" state, outputs = None, [vocab[prefix[0]]] for i in range(len(prefix) + num_preds - 1): X = d2l.tensor([[outputs[-1]]]) embs = self.one_hot(X) rnn_outputs, state = self.rnn.apply({'params': params['rnn']}, embs, state) if i < len(prefix) - 1: # Warm-up period outputs.append(vocab[prefix[i + 1]]) else: # Predict num_preds steps Y = self.apply({'params': params}, rnn_outputs, method=self.output_layer) outputs.append(int(d2l.reshape(d2l.argmax(Y, axis=2), 1))) return ''.join([vocab.idx_to_token[i] for i in outputs]) class RNN(nn.Module): """The RNN model implemented with high-level APIs. Defined in :numref:`sec_rnn-concise`""" num_hiddens: int @nn.compact def __call__(self, inputs, H=None): raise NotImplementedError class RNNLM(d2l.RNNLMScratch): """The RNN-based language model implemented with high-level APIs. Defined in :numref:`sec_rnn-concise`""" training: bool = True def setup(self): self.linear = nn.Dense(self.vocab_size) def output_layer(self, hiddens): return d2l.swapaxes(self.linear(hiddens), 0, 1) def forward(self, X, state=None): embs = self.one_hot(X) rnn_outputs, _ = self.rnn(embs, state, self.training) return self.output_layer(rnn_outputs) class GRU(d2l.RNN): """The multilayer GRU model. 
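
    A shape-level usage sketch (illustrative; assumes the Flax version pinned
    by the book, in which nn.GRUCell takes no constructor arguments):

        gru = d2l.GRU(num_hiddens=16, num_layers=2)
        X = jnp.zeros((5, 4, 8))  # (num_steps, batch_size, num_inputs)
        (outputs, state), _ = gru.init_with_output(jax.random.PRNGKey(0), X)
        # outputs: (5, 4, 16); state: (2, 4, 16), one carry per layer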
Defined in :numref:`sec_deep_rnn`""" num_hiddens: int num_layers: int dropout: float = 0 @nn.compact def __call__(self, X, state=None, training=False): outputs = X new_state = [] if state is None: batch_size = X.shape[1] state = [nn.GRUCell.initialize_carry(jax.random.PRNGKey(0), (batch_size,), self.num_hiddens)] * self.num_layers GRU = nn.scan(nn.GRUCell, variable_broadcast="params", in_axes=0, out_axes=0, split_rngs={"params": False}) # Introduce a dropout layer after every GRU layer except last for i in range(self.num_layers - 1): layer_i_state, X = GRU()(state[i], outputs) new_state.append(layer_i_state) X = nn.Dropout(self.dropout, deterministic=not training)(X) # Final GRU layer without dropout out_state, X = GRU()(state[-1], X) new_state.append(out_state) return X, jnp.array(new_state) class MTFraEng(d2l.DataModule): """The English-French dataset. Defined in :numref:`sec_machine_translation`""" def _download(self): d2l.extract(d2l.download( d2l.DATA_URL+'fra-eng.zip', self.root, '94646ad1522d915e7b0f9296181140edcf86a4f5')) with open(self.root + '/fra-eng/fra.txt', encoding='utf-8') as f: return f.read() def _preprocess(self, text): """Defined in :numref:`sec_machine_translation`""" # Replace non-breaking space with space text = text.replace('\u202f', ' ').replace('\xa0', ' ') # Insert space between words and punctuation marks no_space = lambda char, prev_char: char in ',.!?' and prev_char != ' ' out = [' ' + char if i > 0 and no_space(char, text[i - 1]) else char for i, char in enumerate(text.lower())] return ''.join(out) def _tokenize(self, text, max_examples=None): """Defined in :numref:`sec_machine_translation`""" src, tgt = [], [] for i, line in enumerate(text.split('\n')): if max_examples and i > max_examples: break parts = line.split('\t') if len(parts) == 2: # Skip empty tokens src.append([t for t in f'{parts[0]} '.split(' ') if t]) tgt.append([t for t in f'{parts[1]} '.split(' ') if t]) return src, tgt def __init__(self, batch_size, num_steps=9, num_train=512, num_val=128): """Defined in :numref:`sec_machine_translation`""" super(MTFraEng, self).__init__() self.save_hyperparameters() self.arrays, self.src_vocab, self.tgt_vocab = self._build_arrays( self._download()) def _build_arrays(self, raw_text, src_vocab=None, tgt_vocab=None): """Defined in :numref:`subsec_loading-seq-fixed-len`""" def _build_array(sentences, vocab, is_tgt=False): pad_or_trim = lambda seq, t: ( seq[:t] if len(seq) > t else seq + [''] * (t - len(seq))) sentences = [pad_or_trim(s, self.num_steps) for s in sentences] if is_tgt: sentences = [[''] + s for s in sentences] if vocab is None: vocab = d2l.Vocab(sentences, min_freq=2) array = d2l.tensor([vocab[s] for s in sentences]) valid_len = d2l.reduce_sum( d2l.astype(array != vocab[''], d2l.int32), 1) return array, vocab, valid_len src, tgt = self._tokenize(self._preprocess(raw_text), self.num_train + self.num_val) src_array, src_vocab, src_valid_len = _build_array(src, src_vocab) tgt_array, tgt_vocab, _ = _build_array(tgt, tgt_vocab, True) return ((src_array, tgt_array[:,:-1], src_valid_len, tgt_array[:,1:]), src_vocab, tgt_vocab) def get_dataloader(self, train): """Defined in :numref:`subsec_loading-seq-fixed-len`""" idx = slice(0, self.num_train) if train else slice(self.num_train, None) return self.get_tensorloader(self.arrays, train, idx) def build(self, src_sentences, tgt_sentences): """Defined in :numref:`subsec_loading-seq-fixed-len`""" raw_text = '\n'.join([src + '\t' + tgt for src, tgt in zip( src_sentences, tgt_sentences)]) arrays, _, _ = 
self._build_arrays( raw_text, self.src_vocab, self.tgt_vocab) return arrays def show_list_len_pair_hist(legend, xlabel, ylabel, xlist, ylist): """Plot the histogram for list length pairs. Defined in :numref:`sec_machine_translation`""" d2l.set_figsize() _, _, patches = d2l.plt.hist( [[len(l) for l in xlist], [len(l) for l in ylist]]) d2l.plt.xlabel(xlabel) d2l.plt.ylabel(ylabel) for patch in patches[1].patches: patch.set_hatch('/') d2l.plt.legend(legend) class Encoder(nn.Module): """The base encoder interface for the encoder--decoder architecture. Defined in :numref:`sec_encoder-decoder`""" def setup(self): raise NotImplementedError # Later there can be additional arguments (e.g., length excluding padding) def __call__(self, X, *args): raise NotImplementedError class Decoder(nn.Module): """The base decoder interface for the encoder--decoder architecture. Defined in :numref:`sec_encoder-decoder`""" def setup(self): raise NotImplementedError # Later there can be additional arguments (e.g., length excluding padding) def init_state(self, enc_all_outputs, *args): raise NotImplementedError def __call__(self, X, state): raise NotImplementedError class EncoderDecoder(d2l.Classifier): """The base class for the encoder--decoder architecture. Defined in :numref:`sec_encoder-decoder`""" encoder: nn.Module decoder: nn.Module training: bool def __call__(self, enc_X, dec_X, *args): enc_all_outputs = self.encoder(enc_X, *args, training=self.training) dec_state = self.decoder.init_state(enc_all_outputs, *args) # Return decoder output only return self.decoder(dec_X, dec_state, training=self.training)[0] def predict_step(self, params, batch, num_steps, save_attention_weights=False): """Defined in :numref:`sec_seq2seq_training`""" src, tgt, src_valid_len, _ = batch enc_all_outputs, inter_enc_vars = self.encoder.apply( {'params': params['encoder']}, src, src_valid_len, training=False, mutable='intermediates') # Save encoder attention weights if inter_enc_vars containing encoder # attention weights is not empty. (to be covered later) enc_attention_weights = [] if bool(inter_enc_vars) and save_attention_weights: # Encoder Attention Weights saved in the intermediates collection enc_attention_weights = inter_enc_vars[ 'intermediates']['enc_attention_weights'][0] dec_state = self.decoder.init_state(enc_all_outputs, src_valid_len) outputs, attention_weights = [d2l.expand_dims(tgt[:,0], 1), ], [] for _ in range(num_steps): (Y, dec_state), inter_dec_vars = self.decoder.apply( {'params': params['decoder']}, outputs[-1], dec_state, training=False, mutable='intermediates') outputs.append(d2l.argmax(Y, 2)) # Save attention weights (to be covered later) if save_attention_weights: # Decoder Attention Weights saved in the intermediates collection dec_attention_weights = inter_dec_vars[ 'intermediates']['dec_attention_weights'][0] attention_weights.append(dec_attention_weights) return d2l.concat(outputs[1:], 1), (attention_weights, enc_attention_weights) class Seq2SeqEncoder(d2l.Encoder): """The RNN encoder for sequence-to-sequence learning. 
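
    A shape-level usage sketch (illustrative; the expected shapes are also
    noted in the comments of __call__ below):

        encoder = d2l.Seq2SeqEncoder(vocab_size=10, embed_size=8,
                                     num_hiddens=16, num_layers=2)
        X = jnp.zeros((4, 9), dtype=jnp.int32)  # (batch_size, num_steps)
        (outputs, state), _ = encoder.init_with_output(jax.random.PRNGKey(0), X)
        # outputs: (9, 4, 16); state: (2, 4, 16)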
Defined in :numref:`sec_seq2seq`""" vocab_size: int embed_size: int num_hiddens: int num_layers: int dropout: float = 0 def setup(self): self.embedding = nn.Embed(self.vocab_size, self.embed_size) self.rnn = d2l.GRU(self.num_hiddens, self.num_layers, self.dropout) def __call__(self, X, *args, training=False): # X shape: (batch_size, num_steps) embs = self.embedding(d2l.astype(d2l.transpose(X), d2l.int32)) # embs shape: (num_steps, batch_size, embed_size) outputs, state = self.rnn(embs, training=training) # outputs shape: (num_steps, batch_size, num_hiddens) # state shape: (num_layers, batch_size, num_hiddens) return outputs, state class Seq2Seq(d2l.EncoderDecoder): """The RNN encoder--decoder for sequence to sequence learning. Defined in :numref:`sec_seq2seq_decoder`""" encoder: nn.Module decoder: nn.Module tgt_pad: int lr: float def validation_step(self, params, batch, state): l, _ = self.loss(params, batch[:-1], batch[-1], state) self.plot('loss', l, train=False) def configure_optimizers(self): # Adam optimizer is used here return optax.adam(learning_rate=self.lr) def bleu(pred_seq, label_seq, k): """Compute the BLEU. Defined in :numref:`sec_seq2seq_training`""" pred_tokens, label_tokens = pred_seq.split(' '), label_seq.split(' ') len_pred, len_label = len(pred_tokens), len(label_tokens) score = math.exp(min(0, 1 - len_label / len_pred)) for n in range(1, min(k, len_pred) + 1): num_matches, label_subs = 0, collections.defaultdict(int) for i in range(len_label - n + 1): label_subs[' '.join(label_tokens[i: i + n])] += 1 for i in range(len_pred - n + 1): if label_subs[' '.join(pred_tokens[i: i + n])] > 0: num_matches += 1 label_subs[' '.join(pred_tokens[i: i + n])] -= 1 score *= math.pow(num_matches / (len_pred - n + 1), math.pow(0.5, n)) return score def show_heatmaps(matrices, xlabel, ylabel, titles=None, figsize=(2.5, 2.5), cmap='Reds'): """Show heatmaps of matrices. Defined in :numref:`sec_queries-keys-values`""" d2l.use_svg_display() num_rows, num_cols, _, _ = matrices.shape fig, axes = d2l.plt.subplots(num_rows, num_cols, figsize=figsize, sharex=True, sharey=True, squeeze=False) for i, (row_axes, row_matrices) in enumerate(zip(axes, matrices)): for j, (ax, matrix) in enumerate(zip(row_axes, row_matrices)): pcm = ax.imshow(matrix, cmap=cmap) if i == num_rows - 1: ax.set_xlabel(xlabel) if j == 0: ax.set_ylabel(ylabel) if titles: ax.set_title(titles[j]) fig.colorbar(pcm, ax=axes, shrink=0.6); def masked_softmax(X, valid_lens): """Perform softmax operation by masking elements on the last axis. Defined in :numref:`sec_attention-scoring-functions`""" # X: 3D tensor, valid_lens: 1D or 2D tensor def _sequence_mask(X, valid_len, value=0): maxlen = X.shape[1] mask = jnp.arange((maxlen), dtype=jnp.float32)[None, :] < valid_len[:, None] return jnp.where(mask, X, value) if valid_lens is None: return nn.softmax(X, axis=-1) else: shape = X.shape if valid_lens.ndim == 1: valid_lens = jnp.repeat(valid_lens, shape[1]) else: valid_lens = valid_lens.reshape(-1) # On the last axis, replace masked elements with a very large negative # value, whose exponentiation outputs 0 X = _sequence_mask(X.reshape(-1, shape[-1]), valid_lens, value=-1e6) return nn.softmax(X.reshape(shape), axis=-1) class DotProductAttention(nn.Module): """Scaled dot product attention. Defined in :numref:`subsec_batch_dot`""" dropout: float # Shape of queries: (batch_size, no. of queries, d) # Shape of keys: (batch_size, no. of key-value pairs, d) # Shape of values: (batch_size, no. 
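    # Illustrative shape sketch (not part of the saved library code):
    #   queries = jnp.ones((2, 1, 4)); keys = jnp.ones((2, 6, 4))
    #   values = jnp.ones((2, 6, 8))
    #   attention = d2l.DotProductAttention(dropout=0.5)
    #   variables = attention.init(jax.random.PRNGKey(0), queries, keys, values)
    #   output, weights = attention.apply(variables, queries, keys, values)
    #   # output: (2, 1, 8); weights: (2, 1, 6)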
of key-value pairs, value dimension) # Shape of valid_lens: (batch_size,) or (batch_size, no. of queries) @nn.compact def __call__(self, queries, keys, values, valid_lens=None, training=False): d = queries.shape[-1] # Swap the last two dimensions of keys with keys.swapaxes(1, 2) scores = queries@(keys.swapaxes(1, 2)) / math.sqrt(d) attention_weights = masked_softmax(scores, valid_lens) dropout_layer = nn.Dropout(self.dropout, deterministic=not training) return dropout_layer(attention_weights)@values, attention_weights class AdditiveAttention(nn.Module): """Defined in :numref:`subsec_batch_dot`""" num_hiddens: int dropout: float def setup(self): self.W_k = nn.Dense(self.num_hiddens, use_bias=False) self.W_q = nn.Dense(self.num_hiddens, use_bias=False) self.w_v = nn.Dense(1, use_bias=False) @nn.compact def __call__(self, queries, keys, values, valid_lens, training=False): queries, keys = self.W_q(queries), self.W_k(keys) # After dimension expansion, shape of queries: (batch_size, no. of # queries, 1, num_hiddens) and shape of keys: (batch_size, 1, no. of # key-value pairs, num_hiddens). Sum them up with broadcasting features = jnp.expand_dims(queries, axis=2) + jnp.expand_dims(keys, axis=1) features = nn.tanh(features) # There is only one output of self.w_v, so we remove the last # one-dimensional entry from the shape. Shape of scores: (batch_size, # no. of queries, no. of key-value pairs) scores = self.w_v(features).squeeze(-1) attention_weights = masked_softmax(scores, valid_lens) dropout_layer = nn.Dropout(self.dropout, deterministic=not training) # Shape of values: (batch_size, no. of key-value pairs, value # dimension) return dropout_layer(attention_weights)@values, attention_weights class MultiHeadAttention(nn.Module): """Defined in :numref:`sec_multihead-attention`""" num_hiddens: int num_heads: int dropout: float bias: bool = False def setup(self): self.attention = d2l.DotProductAttention(self.dropout) self.W_q = nn.Dense(self.num_hiddens, use_bias=self.bias) self.W_k = nn.Dense(self.num_hiddens, use_bias=self.bias) self.W_v = nn.Dense(self.num_hiddens, use_bias=self.bias) self.W_o = nn.Dense(self.num_hiddens, use_bias=self.bias) @nn.compact def __call__(self, queries, keys, values, valid_lens, training=False): # Shape of queries, keys, or values: # (batch_size, no. of queries or key-value pairs, num_hiddens) # Shape of valid_lens: (batch_size,) or (batch_size, no. of queries) # After transposing, shape of output queries, keys, or values: # (batch_size * num_heads, no. of queries or key-value pairs, # num_hiddens / num_heads) queries = self.transpose_qkv(self.W_q(queries)) keys = self.transpose_qkv(self.W_k(keys)) values = self.transpose_qkv(self.W_v(values)) if valid_lens is not None: # On axis 0, copy the first item (scalar or vector) for num_heads # times, then copy the next item, and so on valid_lens = jnp.repeat(valid_lens, self.num_heads, axis=0) # Shape of output: (batch_size * num_heads, no. of queries, # num_hiddens / num_heads) output, attention_weights = self.attention( queries, keys, values, valid_lens, training=training) # Shape of output_concat: (batch_size, no. of queries, num_hiddens) output_concat = self.transpose_output(output) return self.W_o(output_concat), attention_weights def transpose_qkv(self, X): """Transposition for parallel computation of multiple attention heads. Defined in :numref:`sec_multihead-attention`""" # Shape of input X: (batch_size, no. of queries or key-value pairs, # num_hiddens). Shape of output X: (batch_size, no. 
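        # Illustrative numbers (not part of the saved library code): with
        # batch_size=2, 3 queries, num_hiddens=8 and num_heads=2,
        #   X = jnp.zeros((2, 3, 8))
        #   X = X.reshape((2, 3, 2, 4))         # split num_hiddens into heads
        #   X = jnp.transpose(X, (0, 2, 1, 3))  # (2, 2, 3, 4)
        #   X = X.reshape((-1, 3, 4))           # (4, 3, 4) = (batch*heads, ...)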
of queries or # key-value pairs, num_heads, num_hiddens / num_heads) X = X.reshape((X.shape[0], X.shape[1], self.num_heads, -1)) # Shape of output X: (batch_size, num_heads, no. of queries or key-value # pairs, num_hiddens / num_heads) X = jnp.transpose(X, (0, 2, 1, 3)) # Shape of output: (batch_size * num_heads, no. of queries or key-value # pairs, num_hiddens / num_heads) return X.reshape((-1, X.shape[2], X.shape[3])) def transpose_output(self, X): """Reverse the operation of transpose_qkv. Defined in :numref:`sec_multihead-attention`""" X = X.reshape((-1, self.num_heads, X.shape[1], X.shape[2])) X = jnp.transpose(X, (0, 2, 1, 3)) return X.reshape((X.shape[0], X.shape[1], -1)) class PositionalEncoding(nn.Module): """Positional encoding. Defined in :numref:`sec_self-attention-and-positional-encoding`""" num_hiddens: int dropout: float max_len: int = 1000 def setup(self): # Create a long enough P self.P = d2l.zeros((1, self.max_len, self.num_hiddens)) X = d2l.arange(self.max_len, dtype=jnp.float32).reshape( -1, 1) / jnp.power(10000, jnp.arange( 0, self.num_hiddens, 2, dtype=jnp.float32) / self.num_hiddens) self.P = self.P.at[:, :, 0::2].set(jnp.sin(X)) self.P = self.P.at[:, :, 1::2].set(jnp.cos(X)) @nn.compact def __call__(self, X, training=False): # Flax sow API is used to capture intermediate variables self.sow('intermediates', 'P', self.P) X = X + self.P[:, :X.shape[1], :] return nn.Dropout(self.dropout)(X, deterministic=not training) class PositionWiseFFN(nn.Module): """The positionwise feed-forward network. Defined in :numref:`sec_transformer`""" ffn_num_hiddens: int ffn_num_outputs: int def setup(self): self.dense1 = nn.Dense(self.ffn_num_hiddens) self.dense2 = nn.Dense(self.ffn_num_outputs) def __call__(self, X): return self.dense2(nn.relu(self.dense1(X))) class AddNorm(nn.Module): """The residual connection followed by layer normalization. Defined in :numref:`subsec_positionwise-ffn`""" dropout: int @nn.compact def __call__(self, X, Y, training=False): return nn.LayerNorm()( nn.Dropout(self.dropout)(Y, deterministic=not training) + X) class TransformerEncoderBlock(nn.Module): """The Transformer encoder block. Defined in :numref:`subsec_positionwise-ffn`""" num_hiddens: int ffn_num_hiddens: int num_heads: int dropout: float use_bias: bool = False def setup(self): self.attention = d2l.MultiHeadAttention(self.num_hiddens, self.num_heads, self.dropout, self.use_bias) self.addnorm1 = AddNorm(self.dropout) self.ffn = PositionWiseFFN(self.ffn_num_hiddens, self.num_hiddens) self.addnorm2 = AddNorm(self.dropout) def __call__(self, X, valid_lens, training=False): output, attention_weights = self.attention(X, X, X, valid_lens, training=training) Y = self.addnorm1(X, output, training=training) return self.addnorm2(Y, self.ffn(Y), training=training), attention_weights class TransformerEncoder(d2l.Encoder): """The Transformer encoder. 
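
    A shape-level usage sketch (illustrative):

        encoder = d2l.TransformerEncoder(vocab_size=200, num_hiddens=24,
                                         ffn_num_hiddens=48, num_heads=8,
                                         num_blks=2, dropout=0.5)
        X = jnp.ones((2, 100), dtype=jnp.int32)
        output, _ = encoder.init_with_output(jax.random.PRNGKey(0), X, None)
        # output: (2, 100, 24)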
Defined in :numref:`subsec_transformer-encoder`""" vocab_size: int num_hiddens:int ffn_num_hiddens: int num_heads: int num_blks: int dropout: float use_bias: bool = False def setup(self): self.embedding = nn.Embed(self.vocab_size, self.num_hiddens) self.pos_encoding = d2l.PositionalEncoding(self.num_hiddens, self.dropout) self.blks = [TransformerEncoderBlock(self.num_hiddens, self.ffn_num_hiddens, self.num_heads, self.dropout, self.use_bias) for _ in range(self.num_blks)] def __call__(self, X, valid_lens, training=False): # Since positional encoding values are between -1 and 1, the embedding # values are multiplied by the square root of the embedding dimension # to rescale before they are summed up X = self.embedding(X) * math.sqrt(self.num_hiddens) X = self.pos_encoding(X, training=training) attention_weights = [None] * len(self.blks) for i, blk in enumerate(self.blks): X, attention_w = blk(X, valid_lens, training=training) attention_weights[i] = attention_w # Flax sow API is used to capture intermediate variables self.sow('intermediates', 'enc_attention_weights', attention_weights) return X def show_images(imgs, num_rows, num_cols, titles=None, scale=1.5): """Plot a list of images. Defined in :numref:`sec_utils`""" figsize = (num_cols * scale, num_rows * scale) _, axes = d2l.plt.subplots(num_rows, num_cols, figsize=figsize) axes = axes.flatten() for i, (ax, img) in enumerate(zip(axes, imgs)): try: img = d2l.numpy(img) except: pass ax.imshow(img) ax.axes.get_xaxis().set_visible(False) ax.axes.get_yaxis().set_visible(False) if titles: ax.set_title(titles[i]) return axes def download(url, folder='../data', sha1_hash=None): """Download a file to folder and return the local filepath. Defined in :numref:`sec_utils`""" if not url.startswith('http'): # For back compatability url, sha1_hash = DATA_HUB[url] os.makedirs(folder, exist_ok=True) fname = os.path.join(folder, url.split('/')[-1]) # Check if hit cache if os.path.exists(fname) and sha1_hash: sha1 = hashlib.sha1() with open(fname, 'rb') as f: while True: data = f.read(1048576) if not data: break sha1.update(data) if sha1.hexdigest() == sha1_hash: return fname # Download print(f'Downloading {fname} from {url}...') r = requests.get(url, stream=True, verify=True) with open(fname, 'wb') as f: f.write(r.content) return fname def extract(filename, folder=None): """Extract a zip/tar file into folder. Defined in :numref:`sec_utils`""" base_dir = os.path.dirname(filename) _, ext = os.path.splitext(filename) assert ext in ('.zip', '.tar', '.gz'), 'Only support zip/tar files.' 
    if ext == '.zip':
        fp = zipfile.ZipFile(filename, 'r')
    else:
        fp = tarfile.open(filename, 'r')
    if folder is None:
        folder = base_dir
    fp.extractall(folder)

# Alias defined in config.ini
nn_Module = nn.Module

to = jax.device_put
numpy = np.asarray
transpose = lambda a: a.T
sigmoid = jax.nn.sigmoid
ones_like = jnp.ones_like
ones = jnp.ones
zeros_like = jnp.zeros_like
zeros = jnp.zeros
arange = jnp.arange
meshgrid = jnp.meshgrid
sin = jnp.sin
sinh = jnp.sinh
cos = jnp.cos
cosh = jnp.cosh
tanh = jnp.tanh
linspace = jnp.linspace
exp = jnp.exp
log = jnp.log
tensor = jnp.array
expand_dims = jnp.expand_dims
matmul = jnp.matmul
int32 = jnp.int32
int64 = jnp.int64
float32 = jnp.float32
concat = jnp.concatenate
stack = jnp.stack
abs = jnp.abs
eye = jnp.eye
reshape = lambda x, *args, **kwargs: x.reshape(*args, **kwargs)
reduce_sum = lambda x, *args, **kwargs: x.sum(*args, **kwargs)
argmax = lambda x, *args, **kwargs: x.argmax(*args, **kwargs)
astype = lambda x, *args, **kwargs: x.astype(*args, **kwargs)
reduce_mean = lambda x, *args, **kwargs: x.mean(*args, **kwargs)
swapaxes = lambda x, *args, **kwargs: x.swapaxes(*args, **kwargs)
repeat = lambda x, *args, **kwargs: x.repeat(*args, **kwargs)
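
# A minimal self-check of the aliases above (illustrative; runs only when the
# module is executed directly, never on import).
if __name__ == '__main__':
    _x = tensor([[1.0, 2.0], [3.0, 4.0]])
    assert reshape(_x, (4,)).shape == (4,)
    assert float(reduce_mean(_x)) == 2.5
    assert astype(argmax(_x, axis=1), int32).dtype == int32
    print('d2l JAX alias self-check passed')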