{ "cells": [ { "cell_type": "markdown", "id": "105653cd", "metadata": { "origin_pos": 0 }, "source": [ "# Pretraining word2vec\n", ":label:`sec_word2vec_pretraining`\n", "\n", "\n", "We go on to implement the skip-gram\n", "model defined in\n", ":numref:`sec_word2vec`.\n", "Then\n", "we will pretrain word2vec using negative sampling\n", "on the PTB dataset.\n", "First of all,\n", "let's obtain the data iterator\n", "and the vocabulary for this dataset\n", "by calling the `d2l.load_data_ptb`\n", "function, which was described in :numref:`sec_word2vec_data`\n" ] }, { "cell_type": "code", "execution_count": 1, "id": "c74744cb", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T19:46:56.397062Z", "iopub.status.busy": "2023-08-18T19:46:56.396303Z", "iopub.status.idle": "2023-08-18T19:47:13.849836Z", "shell.execute_reply": "2023-08-18T19:47:13.848880Z" }, "origin_pos": 2, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "import math\n", "import torch\n", "from torch import nn\n", "from d2l import torch as d2l\n", "\n", "batch_size, max_window_size, num_noise_words = 512, 5, 5\n", "data_iter, vocab = d2l.load_data_ptb(batch_size, max_window_size,\n", " num_noise_words)" ] }, { "cell_type": "markdown", "id": "edfdfb2d", "metadata": { "origin_pos": 3 }, "source": [ "## The Skip-Gram Model\n", "\n", "We implement the skip-gram model\n", "by using embedding layers and batch matrix multiplications.\n", "First, let's review\n", "how embedding layers work.\n", "\n", "\n", "### Embedding Layer\n", "\n", "As described in :numref:`sec_seq2seq`,\n", "an embedding layer\n", "maps a token's index to its feature vector.\n", "The weight of this layer\n", "is a matrix whose number of rows equals to\n", "the dictionary size (`input_dim`) and\n", "number of columns equals to\n", "the vector dimension for each token (`output_dim`).\n", "After a word embedding model is trained,\n", "this weight is what we need.\n" ] }, { "cell_type": "code", "execution_count": 2, "id": "f3bbd097", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T19:47:13.854261Z", "iopub.status.busy": "2023-08-18T19:47:13.853539Z", "iopub.status.idle": "2023-08-18T19:47:13.879890Z", "shell.execute_reply": "2023-08-18T19:47:13.878774Z" }, "origin_pos": 5, "tab": [ "pytorch" ] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Parameter embedding_weight (torch.Size([20, 4]), dtype=torch.float32)\n" ] } ], "source": [ "embed = nn.Embedding(num_embeddings=20, embedding_dim=4)\n", "print(f'Parameter embedding_weight ({embed.weight.shape}, '\n", " f'dtype={embed.weight.dtype})')" ] }, { "cell_type": "markdown", "id": "801ec8da", "metadata": { "origin_pos": 6 }, "source": [ "The input of an embedding layer is the\n", "index of a token (word).\n", "For any token index $i$,\n", "its vector representation\n", "can be obtained from\n", "the $i^\\textrm{th}$ row of the weight matrix\n", "in the embedding layer.\n", "Since the vector dimension (`output_dim`)\n", "was set to 4,\n", "the embedding layer\n", "returns vectors with shape (2, 3, 4)\n", "for a minibatch of token indices with shape\n", "(2, 3).\n" ] }, { "cell_type": "code", "execution_count": 3, "id": "5f462eec", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T19:47:13.884582Z", "iopub.status.busy": "2023-08-18T19:47:13.883631Z", "iopub.status.idle": "2023-08-18T19:47:13.896267Z", "shell.execute_reply": "2023-08-18T19:47:13.894870Z" }, "origin_pos": 7, "tab": [ "pytorch" ] }, "outputs": [ { "data": { "text/plain": [ "tensor([[[ 0.7606, 
0.3872, -0.1864,  1.1732],\n", "         [ 1.5035,  2.3623, -1.7542, -1.4990],\n", "         [-1.2639, -1.5313,  2.1719,  0.4151]],\n", "\n", "        [[-1.9079,  0.2434,  1.5395,  1.2990],\n", "         [ 0.7470,  1.0129,  0.4039,  0.0591],\n", "         [-0.6293, -0.1814, -0.4782, -0.5289]]], grad_fn=<EmbeddingBackward0>)" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "x = torch.tensor([[1, 2, 3], [4, 5, 6]])\n", "embed(x)" ] },
{ "cell_type": "markdown", "id": "61ffafa8", "metadata": { "origin_pos": 8 }, "source": [ "### Defining the Forward Propagation\n", "\n", "In the forward propagation,\n", "the input of the skip-gram model\n", "includes\n", "the center word indices `center`\n", "of shape (batch size, 1)\n", "and\n", "the concatenated context and noise word indices `contexts_and_negatives`\n", "of shape (batch size, `max_len`),\n", "where `max_len`\n", "is defined\n", "in :numref:`subsec_word2vec-minibatch-loading`.\n", "These two variables are first transformed from\n", "token indices into vectors via the embedding layer,\n", "then their batch matrix multiplication\n", "(described in :numref:`subsec_batch_dot`)\n", "returns\n", "an output of shape (batch size, 1, `max_len`).\n", "Each element in the output is the dot product of\n", "a center word vector and a context or noise word vector.\n" ] },
{ "cell_type": "code", "execution_count": 4, "id": "734187b9", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T19:47:13.902028Z", "iopub.status.busy": "2023-08-18T19:47:13.900944Z", "iopub.status.idle": "2023-08-18T19:47:13.907239Z", "shell.execute_reply": "2023-08-18T19:47:13.906216Z" }, "origin_pos": 10, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "def skip_gram(center, contexts_and_negatives, embed_v, embed_u):\n", "    v = embed_v(center)\n", "    u = embed_u(contexts_and_negatives)\n", "    pred = torch.bmm(v, u.permute(0, 2, 1))\n", "    return pred" ] },
{ "cell_type": "markdown", "id": "5672ddf8", "metadata": { "origin_pos": 11 }, "source": [ "Let's print the output shape of this `skip_gram` function for some example inputs.\n" ] },
{ "cell_type": "code", "execution_count": 5, "id": "265139ff", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T19:47:13.911481Z", "iopub.status.busy": "2023-08-18T19:47:13.910689Z", "iopub.status.idle": "2023-08-18T19:47:13.922107Z", "shell.execute_reply": "2023-08-18T19:47:13.920752Z" }, "origin_pos": 13, "tab": [ "pytorch" ] }, "outputs": [ { "data": { "text/plain": [ "torch.Size([2, 1, 4])" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "skip_gram(torch.ones((2, 1), dtype=torch.long),\n", "          torch.ones((2, 4), dtype=torch.long), embed, embed).shape" ] },
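{ "cell_type": "markdown", "id": "added-bmm-check-md", "metadata": {}, "source": [ "As a quick optional check of the dot-product interpretation, the following cell compares the output of `skip_gram` with the same products computed explicitly via `torch.einsum` for a small random minibatch (the index tensors here are made up purely for illustration).\n" ] },
{ "cell_type": "code", "execution_count": null, "id": "added-bmm-check", "metadata": {}, "outputs": [], "source": [ "# Each output element should equal the dot product of a center word vector\n", "# and one context/noise word vector\n", "center_idx = torch.randint(0, 20, (2, 1))\n", "ctx_neg_idx = torch.randint(0, 20, (2, 4))\n", "pred = skip_gram(center_idx, ctx_neg_idx, embed, embed)\n", "manual = torch.einsum('bik,bjk->bij', embed(center_idx), embed(ctx_neg_idx))\n", "torch.allclose(pred, manual)" ] },
{ "cell_type": "markdown", "id": "08a1ef11", "metadata": { "origin_pos": 14 }, "source": [ "## Training\n", "\n", "Before training the skip-gram model with negative sampling,\n", "let's first define its loss function.\n", "\n", "\n", "### Binary Cross-Entropy Loss\n", "\n", "According to the definition of the loss function\n", "for negative sampling in :numref:`subsec_negative-sampling`,\n", "we will use\n", "the binary cross-entropy loss.\n" ] },
{ "cell_type": "code", "execution_count": 6, "id": "1f152a65", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T19:47:13.926069Z", "iopub.status.busy": "2023-08-18T19:47:13.925631Z", "iopub.status.idle": "2023-08-18T19:47:13.933662Z", "shell.execute_reply": "2023-08-18T19:47:13.932267Z" }, "origin_pos": 16, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "class 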
SigmoidBCELoss(nn.Module):\n", "    # Binary cross-entropy loss with masking\n", "    def __init__(self):\n", "        super().__init__()\n", "\n", "    def forward(self, inputs, target, mask=None):\n", "        out = nn.functional.binary_cross_entropy_with_logits(\n", "            inputs, target, weight=mask, reduction=\"none\")\n", "        return out.mean(dim=1)\n", "\n", "loss = SigmoidBCELoss()" ] },
{ "cell_type": "markdown", "id": "ab567fa0", "metadata": { "origin_pos": 17 }, "source": [ "Recall our descriptions\n", "of the mask variable\n", "and the label variable in\n", ":numref:`subsec_word2vec-minibatch-loading`.\n", "The following\n", "calculates the\n", "binary cross-entropy loss\n", "for the given variables.\n" ] },
{ "cell_type": "code", "execution_count": 7, "id": "41c67449", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T19:47:13.938563Z", "iopub.status.busy": "2023-08-18T19:47:13.937748Z", "iopub.status.idle": "2023-08-18T19:47:13.954225Z", "shell.execute_reply": "2023-08-18T19:47:13.952929Z" }, "origin_pos": 18, "tab": [ "pytorch" ] }, "outputs": [ { "data": { "text/plain": [ "tensor([0.9352, 1.8462])" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "pred = torch.tensor([[1.1, -2.2, 3.3, -4.4]] * 2)\n", "label = torch.tensor([[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0]])\n", "mask = torch.tensor([[1, 1, 1, 1], [1, 1, 0, 0]])\n", "loss(pred, label, mask) * mask.shape[1] / mask.sum(axis=1)" ] },
{ "cell_type": "markdown", "id": "7112ed9f", "metadata": { "origin_pos": 19 }, "source": [ "The following shows\n", "how the above results are calculated\n", "(in a less efficient way)\n", "using the\n", "sigmoid activation function\n", "in the binary cross-entropy loss.\n", "We can consider\n", "the two outputs as\n", "two normalized losses\n", "that are averaged over non-masked predictions.\n" ] },
{ "cell_type": "code", "execution_count": 8, "id": "8bf49999", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T19:47:13.959559Z", "iopub.status.busy": "2023-08-18T19:47:13.958780Z", "iopub.status.idle": "2023-08-18T19:47:13.965184Z", "shell.execute_reply": "2023-08-18T19:47:13.964337Z" }, "origin_pos": 20, "tab": [ "pytorch" ] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "0.9352\n", "1.8462\n" ] } ], "source": [ "def sigmd(x):\n", "    # Negative log-sigmoid: -log(sigma(x))\n", "    return -math.log(1 / (1 + math.exp(-x)))\n", "\n", "print(f'{(sigmd(1.1) + sigmd(2.2) + sigmd(-3.3) + sigmd(4.4)) / 4:.4f}')\n", "print(f'{(sigmd(-1.1) + sigmd(-2.2)) / 2:.4f}')" ] },
{ "cell_type": "markdown", "id": "a8d4adb2", "metadata": { "origin_pos": 21 }, "source": [ "### Initializing Model Parameters\n", "\n", "We define two embedding layers\n", "for all the words in the vocabulary\n", "when they are used as center words\n", "and context words, respectively.\n", "The word vector dimension\n", "`embed_size` is set to 100.\n" ] },
{ "cell_type": "code", "execution_count": 9, "id": "53bf9f52", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T19:47:13.970338Z", "iopub.status.busy": "2023-08-18T19:47:13.969732Z", "iopub.status.idle": "2023-08-18T19:47:13.986034Z", "shell.execute_reply": "2023-08-18T19:47:13.985180Z" }, "origin_pos": 23, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "embed_size = 100\n", "net = nn.Sequential(nn.Embedding(num_embeddings=len(vocab),\n", "                                 embedding_dim=embed_size),\n", "                    nn.Embedding(num_embeddings=len(vocab),\n", "                                 embedding_dim=embed_size))" ] },
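{ "cell_type": "markdown", "id": "added-shape-check-md", "metadata": {}, "source": [ "As a quick sanity check, each of these layers holds one `embed_size`-dimensional row per vocabulary word, so both weight matrices should have shape (`len(vocab)`, 100); the cell below simply prints the two shapes.\n" ] },
{ "cell_type": "code", "execution_count": null, "id": "added-shape-check", "metadata": {}, "outputs": [], "source": [ "# One row per vocabulary word in each embedding layer\n", "net[0].weight.shape, net[1].weight.shape" ] },
{ "cell_type": "markdown", "id": "58b580f4", "metadata": { "origin_pos": 24 }, "source": [ "### 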
Defining the Training Loop\n", "\n", "The training loop is defined below. Because of padding, the calculation of the loss differs slightly from that in the previous training functions.\n" ] },
{ "cell_type": "code", "execution_count": 10, "id": "4bc84232", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T19:47:13.989704Z", "iopub.status.busy": "2023-08-18T19:47:13.989147Z", "iopub.status.idle": "2023-08-18T19:47:13.997948Z", "shell.execute_reply": "2023-08-18T19:47:13.997164Z" }, "origin_pos": 26, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "def train(net, data_iter, lr, num_epochs, device=d2l.try_gpu()):\n", "    def init_weights(module):\n", "        if type(module) == nn.Embedding:\n", "            nn.init.xavier_uniform_(module.weight)\n", "    net.apply(init_weights)\n", "    net = net.to(device)\n", "    optimizer = torch.optim.Adam(net.parameters(), lr=lr)\n", "    animator = d2l.Animator(xlabel='epoch', ylabel='loss',\n", "                            xlim=[1, num_epochs])\n", "    # Sum of normalized losses, no. of normalized losses\n", "    metric = d2l.Accumulator(2)\n", "    for epoch in range(num_epochs):\n", "        timer, num_batches = d2l.Timer(), len(data_iter)\n", "        for i, batch in enumerate(data_iter):\n", "            optimizer.zero_grad()\n", "            center, context_negative, mask, label = [\n", "                data.to(device) for data in batch]\n", "\n", "            pred = skip_gram(center, context_negative, net[0], net[1])\n", "            l = (loss(pred.reshape(label.shape).float(), label.float(), mask)\n", "                 / mask.sum(axis=1) * mask.shape[1])\n", "            l.sum().backward()\n", "            optimizer.step()\n", "            metric.add(l.sum(), l.numel())\n", "            if (i + 1) % (num_batches // 5) == 0 or i == num_batches - 1:\n", "                animator.add(epoch + (i + 1) / num_batches,\n", "                             (metric[0] / metric[1],))\n", "    print(f'loss {metric[0] / metric[1]:.3f}, '\n", "          f'{metric[1] / timer.stop():.1f} tokens/sec on {str(device)}')" ] },
{ "cell_type": "markdown", "id": "33f9b060", "metadata": { "origin_pos": 27 }, "source": [ "Now we can train a skip-gram model using negative sampling.\n" ] },
{ "cell_type": "code", "execution_count": 11, "id": "5e4a73b4", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T19:47:14.001235Z", "iopub.status.busy": "2023-08-18T19:47:14.000704Z", "iopub.status.idle": "2023-08-18T19:47:53.656586Z", "shell.execute_reply": "2023-08-18T19:47:53.655104Z" }, "origin_pos": 28, "tab": [ "pytorch" ] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "loss 0.410, 223485.0 tokens/sec on cuda:0\n" ] }, { "data": { "text/plain": [ "Training loss vs. epoch plot (SVG output omitted)" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "lr, num_epochs = 0.002, 5\n", "train(net, data_iter, lr, num_epochs)" ] },
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "lr, num_epochs = 0.002, 5\n", "train(net, data_iter, lr, num_epochs)" ] }, { "cell_type": "markdown", "id": "c47dd8e2", "metadata": { "origin_pos": 29 }, "source": [ "## Applying Word Embeddings\n", ":label:`subsec_apply-word-embed`\n", "\n", "\n", "After training the word2vec model,\n", "we can use the cosine similarity\n", "of word vectors from the trained model\n", "to \n", "find words from the dictionary\n", "that are most semantically similar\n", "to an input word.\n" ] }, { "cell_type": "code", "execution_count": 12, "id": "81f4d03b", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T19:47:53.660681Z", "iopub.status.busy": "2023-08-18T19:47:53.660377Z", "iopub.status.idle": "2023-08-18T19:47:53.728687Z", "shell.execute_reply": "2023-08-18T19:47:53.727729Z" }, "origin_pos": 31, "tab": [ "pytorch" ] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "cosine sim=0.702: microprocessor\n", "cosine sim=0.649: mips\n", "cosine sim=0.643: intel\n" ] } ], "source": [ "def get_similar_tokens(query_token, k, embed):\n", " W = embed.weight.data\n", " x = W[vocab[query_token]]\n", " # Compute the cosine similarity. Add 1e-9 for numerical stability\n", " cos = torch.mv(W, x) / torch.sqrt(torch.sum(W * W, dim=1) *\n", " torch.sum(x * x) + 1e-9)\n", " topk = torch.topk(cos, k=k+1)[1].cpu().numpy().astype('int32')\n", " for i in topk[1:]: # Remove the input words\n", " print(f'cosine sim={float(cos[i]):.3f}: {vocab.to_tokens(i)}')\n", "\n", "get_similar_tokens('chip', 3, net[0])" ] }, { "cell_type": "markdown", "id": "554c2d35", "metadata": { "origin_pos": 32 }, "source": [ "## Summary\n", "\n", "* We can train a skip-gram model with negative sampling using embedding layers and the binary cross-entropy loss.\n", "* Applications of word embeddings include finding semantically similar words for a given word based on the cosine similarity of word vectors.\n", "\n", "\n", "## Exercises\n", "\n", "1. Using the trained model, find semantically similar words for other input words. Can you improve the results by tuning hyperparameters?\n", "1. When a training corpus is huge, we often sample context words and noise words for the center words in the current minibatch *when updating model parameters*. In other words, the same center word may have different context words or noise words in different training epochs. What are the benefits of this method? Try to implement this training method.\n" ] }, { "cell_type": "markdown", "id": "7742f77d", "metadata": { "origin_pos": 34, "tab": [ "pytorch" ] }, "source": [ "[Discussions](https://discuss.d2l.ai/t/1335)\n" ] } ], "metadata": { "language_info": { "name": "python" }, "required_libs": [] }, "nbformat": 4, "nbformat_minor": 5 }