{ "cells": [ { "cell_type": "markdown", "id": "24a5f600", "metadata": { "origin_pos": 0 }, "source": [ "# The Dataset for Pretraining Word Embeddings\n", ":label:`sec_word2vec_data`\n", "\n", "Now that we know the technical details of \n", "the word2vec models and approximate training methods,\n", "let's walk through their implementations. \n", "Specifically,\n", "we will take the skip-gram model in :numref:`sec_word2vec`\n", "and negative sampling in :numref:`sec_approx_train`\n", "as our running example.\n", "In this section,\n", "we begin with the dataset\n", "for pretraining the word embedding model:\n", "the original format of the data\n", "will be transformed\n", "into minibatches\n", "that can be iterated over during training.\n" ] }, { "cell_type": "code", "execution_count": 1, "id": "1dc00574", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T19:37:02.153563Z", "iopub.status.busy": "2023-08-18T19:37:02.152664Z", "iopub.status.idle": "2023-08-18T19:37:05.599873Z", "shell.execute_reply": "2023-08-18T19:37:05.597956Z" }, "origin_pos": 2, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "import collections\n", "import math\n", "import os\n", "import random\n", "import torch\n", "from d2l import torch as d2l" ] }, { "cell_type": "markdown", "id": "21d7aa3d", "metadata": { "origin_pos": 3 }, "source": [ "## Reading the Dataset\n", "\n", "The dataset that we use here\n", "is [Penn Tree Bank (PTB)](https://catalog.ldc.upenn.edu/LDC99T42). \n", "This corpus is sampled\n", "from Wall Street Journal articles,\n", "split into training, validation, and test sets.\n", "In the original format,\n", "each line of the text file\n", "represents a sentence of words that are separated by spaces.\n", "Here we treat each word as a token.\n" ] }, { "cell_type": "code", "execution_count": 2, "id": "5ec3dd39", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T19:37:05.606025Z", "iopub.status.busy": "2023-08-18T19:37:05.605205Z", "iopub.status.idle": "2023-08-18T19:37:06.158652Z", "shell.execute_reply": "2023-08-18T19:37:06.157572Z" }, "origin_pos": 4, "tab": [ "pytorch" ] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Downloading ../data/ptb.zip from http://d2l-data.s3-accelerate.amazonaws.com/ptb.zip...\n" ] }, { "data": { "text/plain": [ "'# sentences: 42069'" ] }, "execution_count": 2, "metadata": {}, "output_type": "execute_result" } ], "source": [ "#@save\n", "d2l.DATA_HUB['ptb'] = (d2l.DATA_URL + 'ptb.zip',\n", "                       '319d85e578af0cdc590547f26231e4e31cdf1e42')\n", "\n", "#@save\n", "def read_ptb():\n", "    \"\"\"Load the PTB dataset into a list of text lines.\"\"\"\n", "    data_dir = d2l.download_extract('ptb')\n", "    # Read the training set\n", "    with open(os.path.join(data_dir, 'ptb.train.txt')) as f:\n", "        raw_text = f.read()\n", "    return [line.split() for line in raw_text.split('\\n')]\n", "\n", "sentences = read_ptb()\n", "f'# sentences: {len(sentences)}'" ] }, { "cell_type": "markdown", "id": "c96688b4", "metadata": { "origin_pos": 5 }, "source": [ "After reading the training set,\n", "we build a vocabulary for the corpus,\n", "where any word that appears \n", "fewer than 10 times is replaced by \n", "the \"<unk>\" token.\n", "Note that the original dataset\n", "also contains \"<unk>\" tokens that represent rare (unknown) words.\n" ] }, { "cell_type": "code", "execution_count": 3, "id": "2e0980e1", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T19:37:06.163173Z", "iopub.status.busy": "2023-08-18T19:37:06.162536Z", "iopub.status.idle": "2023-08-18T19:37:06.422758Z", "shell.execute_reply": "2023-08-18T19:37:06.421550Z" }, "origin_pos": 6, "tab": [ "pytorch" ] }, "outputs": [ { "data": { "text/plain": [ "'vocab size: 6719'" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "vocab = d2l.Vocab(sentences, min_freq=10)\n", "f'vocab size: {len(vocab)}'" ] }, { "cell_type": "markdown", "id": "fad600ee", "metadata": { "origin_pos": 7 }, "source": [ "## Subsampling\n", "\n", "Text data\n", "typically have high-frequency words\n", "such as \"the\", \"a\", and \"in\":\n", "they may even occur billions of times in\n", "very large corpora.\n", "However,\n", "these words often co-occur\n", "with many different words in\n", "context windows, providing little useful signal.\n", "For instance,\n", "consider the word \"chip\" in a context window:\n", "intuitively\n", "its co-occurrence with the low-frequency word \"intel\"\n", "is more useful in training\n", "than \n", "its co-occurrence with the high-frequency word \"a\".\n", "Moreover, training on vast numbers of (high-frequency) words\n", "is slow.\n", "Thus, when training word embedding models, \n", "high-frequency words can be *subsampled* :cite:`Mikolov.Sutskever.Chen.ea.2013`.\n", "Specifically, \n", "each indexed word $w_i$ \n", "in the dataset will be discarded with probability\n", "\n", "\n", "$$ P(w_i) = \\max\\left(1 - \\sqrt{\\frac{t}{f(w_i)}}, 0\\right),$$\n", "\n", "where $f(w_i)$ is the ratio of \n", "the number of occurrences of $w_i$\n", "to the total number of words in the dataset, \n", "and the constant $t$ is a hyperparameter\n", "($10^{-4}$ in the experiment). \n", "We can see that only when\n", "the relative frequency\n", "$f(w_i) > t$ can the (high-frequency) word $w_i$ be discarded, \n", "and the higher the relative frequency of the word, \n", "the greater the probability of being discarded.\n" ] }, { "cell_type": "code", "execution_count": 4, "id": "8a996abd", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T19:37:06.430864Z", "iopub.status.busy": "2023-08-18T19:37:06.426528Z", "iopub.status.idle": "2023-08-18T19:37:08.573132Z", "shell.execute_reply": "2023-08-18T19:37:08.571407Z" }, "origin_pos": 8, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "#@save\n", "def subsample(sentences, vocab):\n", "    \"\"\"Subsample high-frequency words.\"\"\"\n", "    # Exclude unknown tokens ('<unk>')\n", "    sentences = [[token for token in line if vocab[token] != vocab.unk]\n", "                 for line in sentences]\n", "    counter = collections.Counter([\n", "        token for line in sentences for token in line])\n", "    num_tokens = sum(counter.values())\n", "\n", "    # Return True if `token` is kept during subsampling\n", "    def keep(token):\n", "        return (random.uniform(0, 1) <\n", "                math.sqrt(1e-4 / counter[token] * num_tokens))\n", "\n", "    return ([[token for token in line if keep(token)] for line in sentences],\n", "            counter)\n", "\n", "subsampled, counter = subsample(sentences, vocab)" ] }, { "cell_type": "markdown", "id": "0171226e", "metadata": { "origin_pos": 9 }, "source": [ "The following code snippet \n", "plots the histogram of\n", "the number of tokens per sentence\n", "before and after subsampling.\n", "As expected, \n", "subsampling significantly shortens sentences\n", "by dropping high-frequency words,\n", "which will speed up training.\n" ] }, { "cell_type": "code", "execution_count": 5, "id": "5d169993", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T19:37:08.578324Z", "iopub.status.busy": "2023-08-18T19:37:08.577516Z", "iopub.status.idle": "2023-08-18T19:37:08.949175Z", "shell.execute_reply": "2023-08-18T19:37:08.947979Z" }, "origin_pos": 10, "tab": [ "pytorch" ] }, "outputs": [ { "data": { "text/plain": [ "[Histogram of the number of tokens per sentence (x-axis: # tokens per sentence, y-axis: count) for the origin and subsampled corpora]" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "d2l.show_list_len_pair_hist(['origin', 'subsampled'], '# tokens per sentence',\n", "                            'count', sentences, subsampled);" ] }, { "cell_type": "markdown", "id": "b91678bc", "metadata": { "origin_pos": 11 }, "source": [ "For individual tokens, the sampling rate of the high-frequency word \"the\" is less than 1/20.\n" ] }, { "cell_type": "code", "execution_count": 6, "id": "f4969244", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T19:37:08.955491Z", "iopub.status.busy": "2023-08-18T19:37:08.954520Z", "iopub.status.idle": "2023-08-18T19:37:08.996260Z", "shell.execute_reply": "2023-08-18T19:37:08.995416Z" }, "origin_pos": 12, "tab": [ "pytorch" ] }, "outputs": [ { "data": { "text/plain": [ "'# of \"the\": before=50770, after=2010'" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "def compare_counts(token):\n", "    return (f'# of \"{token}\": '\n", "            f'before={sum([l.count(token) for l in sentences])}, '\n", "            f'after={sum([l.count(token) for l in subsampled])}')\n", "\n", "compare_counts('the')" ] }, { "cell_type": "markdown", "id": "0a65171a", "metadata": { "origin_pos": 13 }, "source": [ "In contrast, \n", "the low-frequency word \"join\" is completely kept.\n" ] }, { "cell_type": "code", "execution_count": 7, "id": "499e2d66", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T19:37:08.999868Z", "iopub.status.busy": "2023-08-18T19:37:08.999197Z", "iopub.status.idle": "2023-08-18T19:37:09.052157Z", "shell.execute_reply": "2023-08-18T19:37:09.051044Z" }, "origin_pos": 14, "tab": [ "pytorch" ] }, "outputs": [ { "data": { "text/plain": [ "'# of \"join\": before=45, after=45'" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "compare_counts('join')" ] }, { "cell_type": "markdown", "id": "39f616bd", "metadata": { "origin_pos": 15 }, "source": [ "After subsampling, we map tokens to their indices to obtain the corpus.\n" ] }, { "cell_type": "code", "execution_count": 8, "id": "074161dc", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T19:37:09.056059Z", "iopub.status.busy": "2023-08-18T19:37:09.055315Z", "iopub.status.idle": "2023-08-18T19:37:09.375685Z", "shell.execute_reply": "2023-08-18T19:37:09.374502Z" }, "origin_pos": 16, "tab": [ "pytorch" ] }, "outputs": [ { "data": { "text/plain": [ "[[], [4127, 3228, 1773], [3922, 1922, 4743, 2696]]" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "corpus = [vocab[line] for line in subsampled]\n", "corpus[:3]" ] }, { "cell_type": "markdown", "id": "52de792c", "metadata": { "origin_pos": 17 }, "source": [ "## Extracting Center Words and Context Words\n", "\n", "\n", "The following `get_centers_and_contexts`\n", "function extracts all the \n", "center words and their context words\n", "from `corpus`.\n", "It samples an integer between 1 and `max_window_size`\n", "uniformly at random as the context window size.\n", "For any center word,\n", "those words \n", "whose distance from it\n", "does not exceed the sampled\n", "context window size\n", "are its context words.\n" ] }, { "cell_type": "code", "execution_count": 9, "id": "25925d85", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T19:37:09.380090Z", "iopub.status.busy": "2023-08-18T19:37:09.379512Z", "iopub.status.idle": "2023-08-18T19:37:09.386244Z", "shell.execute_reply": "2023-08-18T19:37:09.385233Z" }, "origin_pos": 18, "tab": [ "pytorch" ] }, "outputs": 
[], "source": [ "#@save\n", "def get_centers_and_contexts(corpus, max_window_size):\n", "    \"\"\"Return center words and context words in skip-gram.\"\"\"\n", "    centers, contexts = [], []\n", "    for line in corpus:\n", "        # To form a \"center word--context word\" pair, each sentence needs to\n", "        # have at least 2 words\n", "        if len(line) < 2:\n", "            continue\n", "        centers += line\n", "        for i in range(len(line)):  # Context window centered at `i`\n", "            window_size = random.randint(1, max_window_size)\n", "            indices = list(range(max(0, i - window_size),\n", "                                 min(len(line), i + 1 + window_size)))\n", "            # Exclude the center word from the context words\n", "            indices.remove(i)\n", "            contexts.append([line[idx] for idx in indices])\n", "    return centers, contexts" ] }, { "cell_type": "markdown", "id": "0dbb9f6d", "metadata": { "origin_pos": 19 }, "source": [ "Next, we create an artificial dataset containing two sentences of 7 and 3 words, respectively. \n", "Let the maximum context window size be 2 \n", "and print all the center words and their context words.\n" ] }, { "cell_type": "code", "execution_count": 10, "id": "bc0a75c3", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T19:37:09.390499Z", "iopub.status.busy": "2023-08-18T19:37:09.389680Z", "iopub.status.idle": "2023-08-18T19:37:09.396152Z", "shell.execute_reply": "2023-08-18T19:37:09.395286Z" }, "origin_pos": 20, "tab": [ "pytorch" ] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "dataset [[0, 1, 2, 3, 4, 5, 6], [7, 8, 9]]\n", "center 0 has contexts [1]\n", "center 1 has contexts [0, 2]\n", "center 2 has contexts [0, 1, 3, 4]\n", "center 3 has contexts [1, 2, 4, 5]\n", "center 4 has contexts [2, 3, 5, 6]\n", "center 5 has contexts [3, 4, 6]\n", "center 6 has contexts [5]\n", "center 7 has contexts [8, 9]\n", "center 8 has contexts [7, 9]\n", "center 9 has contexts [7, 8]\n" ] } ], "source": [ "tiny_dataset = [list(range(7)), list(range(7, 10))]\n", "print('dataset', tiny_dataset)\n", "for center, context in zip(*get_centers_and_contexts(tiny_dataset, 2)):\n", "    print('center', center, 'has contexts', context)" ] }, { "cell_type": "markdown", "id": "ffb6b952", "metadata": { "origin_pos": 21 }, "source": [ "When training on the PTB dataset,\n", "we set the maximum context window size to 5. \n", "The following extracts all the center words and their context words in the dataset.\n" ] }, { "cell_type": "code", "execution_count": 11, "id": "c98f0160", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T19:37:09.400516Z", "iopub.status.busy": "2023-08-18T19:37:09.399775Z", "iopub.status.idle": "2023-08-18T19:37:11.117700Z", "shell.execute_reply": "2023-08-18T19:37:11.116857Z" }, "origin_pos": 22, "tab": [ "pytorch" ] }, "outputs": [ { "data": { "text/plain": [ "'# center-context pairs: 1503420'" ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ "all_centers, all_contexts = get_centers_and_contexts(corpus, 5)\n", "f'# center-context pairs: {sum([len(contexts) for contexts in all_contexts])}'" ] }, { "cell_type": "markdown", "id": "97541501", "metadata": { "origin_pos": 23 }, "source": [ "## Negative Sampling\n", "\n", "We use negative sampling for approximate training.\n" ] },
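{ "cell_type": "markdown", "id": "7e0b4a6c", "metadata": {}, "source": [ "As a brief reminder of what the noise words are for (a sketch of the objective from :numref:`sec_approx_train`; we write $\\mathbf{u}_k$ here for the embedding of the $k^\\textrm{th}$ sampled noise word): given a center word $w_c$, one of its context words $w_o$, and $K$ noise words drawn from a noise distribution $P(w)$, the loss of this positive pair is\n", "\n", "$$-\\log\\sigma(\\mathbf{u}_o^\\top \\mathbf{v}_c) - \\sum_{k=1}^K \\log\\sigma(-\\mathbf{u}_k^\\top \\mathbf{v}_c),$$\n", "\n", "where $\\sigma$ is the sigmoid function, $\\mathbf{v}_c$ is the vector of the center word, and $\\mathbf{u}_o$ is the vector of the context word. The minibatches constructed in the rest of this section supply exactly these center, context, and noise word indices.\n" ] },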
{ "cell_type": "markdown", "id": "5a9c1b2f", "metadata": {}, "source": [ "To sample noise words according to \n", "a predefined distribution,\n", "we define the following `RandomGenerator` class,\n", "where the (possibly unnormalized) sampling distribution is passed\n", "via the argument `sampling_weights`.\n" ] }, { "cell_type": "code", "execution_count": 12, "id": "77ae8ed2", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T19:37:11.121021Z", "iopub.status.busy": "2023-08-18T19:37:11.120728Z", "iopub.status.idle": "2023-08-18T19:37:11.127039Z", "shell.execute_reply": "2023-08-18T19:37:11.125995Z" }, "origin_pos": 24, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "#@save\n", "class RandomGenerator:\n", "    \"\"\"Randomly draw among {1, ..., n} according to n sampling weights.\"\"\"\n", "    def __init__(self, sampling_weights):\n", "        # The sampling population is {1, ..., n}: index 0 (the unknown\n", "        # token) is excluded\n", "        self.population = list(range(1, len(sampling_weights) + 1))\n", "        self.sampling_weights = sampling_weights\n", "        self.candidates = []\n", "        self.i = 0\n", "\n", "    def draw(self):\n", "        if self.i == len(self.candidates):\n", "            # Cache `k` random sampling results\n", "            self.candidates = random.choices(\n", "                self.population, self.sampling_weights, k=10000)\n", "            self.i = 0\n", "        self.i += 1\n", "        return self.candidates[self.i - 1]" ] }, { "cell_type": "markdown", "id": "b13c7e0f", "metadata": { "origin_pos": 25 }, "source": [ "For example, \n", "we can draw 10 random variables $X$\n", "among indices 1, 2, and 3\n", "with sampling probabilities $P(X=1)=2/9$, $P(X=2)=3/9$, and $P(X=3)=4/9$ as follows.\n" ] }, { "cell_type": "code", "execution_count": null, "id": "4b6f0d2e", "metadata": { "origin_pos": 26, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "# Draw 10 samples with the (unnormalized) sampling weights 2, 3, and 4\n", "generator = RandomGenerator([2, 3, 4])\n", "[generator.draw() for _ in range(10)]" ] }, { "cell_type": "markdown", "id": "9e301f88", "metadata": { "origin_pos": 27 }, "source": [ "For each pair of a center word and a context word, \n", "we randomly sample `K` (5 in the experiment) noise words. According to the suggestions in the word2vec paper,\n", "the sampling probability $P(w)$ of \n", "a noise word $w$\n", "is \n", "set to its relative frequency \n", "in the dictionary\n", "raised to \n", "the power of 0.75 :cite:`Mikolov.Sutskever.Chen.ea.2013`.\n" ] }, { "cell_type": "code", "execution_count": 13, "id": "d6f7d4e5", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T19:37:11.130327Z", "iopub.status.busy": "2023-08-18T19:37:11.130041Z", "iopub.status.idle": "2023-08-18T19:37:26.129919Z", "shell.execute_reply": "2023-08-18T19:37:26.128681Z" }, "origin_pos": 28, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "#@save\n", "def get_negatives(all_contexts, vocab, counter, K):\n", "    \"\"\"Return noise words in negative sampling.\"\"\"\n", "    # Sampling weights for words with indices 1, 2, ... 
(index 0 is the\n", "    # excluded unknown token) in the vocabulary\n", "    sampling_weights = [counter[vocab.to_tokens(i)]**0.75\n", "                        for i in range(1, len(vocab))]\n", "    all_negatives, generator = [], RandomGenerator(sampling_weights)\n", "    for contexts in all_contexts:\n", "        negatives = []\n", "        while len(negatives) < len(contexts) * K:\n", "            neg = generator.draw()\n", "            # Noise words cannot be context words\n", "            if neg not in contexts:\n", "                negatives.append(neg)\n", "        all_negatives.append(negatives)\n", "    return all_negatives\n", "\n", "all_negatives = get_negatives(all_contexts, vocab, counter, 5)" ] }, { "cell_type": "markdown", "id": "c2c00d37", "metadata": { "origin_pos": 29 }, "source": [ "## Loading Training Examples in Minibatches\n", ":label:`subsec_word2vec-minibatch-loading`\n", "\n", "After\n", "all the center words\n", "together with their\n", "context words and sampled noise words are extracted,\n", "they will be transformed into \n", "minibatches of examples\n", "that can be iteratively loaded\n", "during training.\n", "\n", "\n", "\n", "In a minibatch,\n", "the $i^\\textrm{th}$ example includes a center word\n", "and its $n_i$ context words and $m_i$ noise words. \n", "Due to varying context window sizes,\n", "$n_i+m_i$ varies for different $i$.\n", "Thus,\n", "for each example\n", "we concatenate its context words and noise words in \n", "the `contexts_negatives` variable,\n", "and pad zeros until the concatenation length\n", "reaches $\\max_i (n_i+m_i)$ (`max_len`).\n", "To exclude paddings\n", "in the calculation of the loss,\n", "we define a mask variable `masks`.\n", "There is a one-to-one correspondence\n", "between elements in `masks` and elements in `contexts_negatives`,\n", "where zeros (otherwise ones) in `masks` correspond to paddings in `contexts_negatives`.\n", "\n", "\n", "To distinguish between positive and negative examples,\n", "we separate context words from noise words in `contexts_negatives` via a `labels` variable. 
\n", "Similar to `masks`,\n", "there is also a one-to-one correspondence\n", "between elements in `labels` and elements in `contexts_negatives`,\n", "where ones (otherwise zeros) in `labels` correspond to context words (positive examples) in `contexts_negatives`.\n", "\n", "\n", "The above idea is implemented in the following `batchify` function.\n", "Its input `data` is a list with length\n", "equal to the batch size,\n", "where each element is an example\n", "consisting of\n", "the center word `center`, its context words `context`, and its noise words `negative`.\n", "This function returns \n", "a minibatch, including the mask variable,\n", "that can be loaded for calculations \n", "during training.\n" ] }, { "cell_type": "code", "execution_count": 14, "id": "be9e9c90", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T19:37:26.135084Z", "iopub.status.busy": "2023-08-18T19:37:26.134724Z", "iopub.status.idle": "2023-08-18T19:37:26.142942Z", "shell.execute_reply": "2023-08-18T19:37:26.142046Z" }, "origin_pos": 30, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "#@save\n", "def batchify(data):\n", "    \"\"\"Return a minibatch of examples for skip-gram with negative sampling.\"\"\"\n", "    max_len = max(len(c) + len(n) for _, c, n in data)\n", "    centers, contexts_negatives, masks, labels = [], [], [], []\n", "    for center, context, negative in data:\n", "        cur_len = len(context) + len(negative)\n", "        centers += [center]\n", "        contexts_negatives += [context + negative + [0] * (max_len - cur_len)]\n", "        masks += [[1] * cur_len + [0] * (max_len - cur_len)]\n", "        labels += [[1] * len(context) + [0] * (max_len - len(context))]\n", "    return (torch.tensor(centers).reshape((-1, 1)), torch.tensor(\n", "        contexts_negatives), torch.tensor(masks), torch.tensor(labels))" ] }, { "cell_type": "markdown", "id": "c200c24d", "metadata": { "origin_pos": 31 }, "source": [ "Let's test this function using a minibatch of two examples.\n" ] }, { "cell_type": "code", "execution_count": 15, "id": "79be2d26", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T19:37:26.149204Z", "iopub.status.busy": "2023-08-18T19:37:26.147743Z", "iopub.status.idle": "2023-08-18T19:37:26.162562Z", "shell.execute_reply": "2023-08-18T19:37:26.157787Z" }, "origin_pos": 32, "tab": [ "pytorch" ] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "centers = tensor([[1],\n", "        [1]])\n", "contexts_negatives = tensor([[2, 2, 3, 3, 3, 3],\n", "        [2, 2, 2, 3, 3, 0]])\n", "masks = tensor([[1, 1, 1, 1, 1, 1],\n", "        [1, 1, 1, 1, 1, 0]])\n", "labels = tensor([[1, 1, 0, 0, 0, 0],\n", "        [1, 1, 1, 0, 0, 0]])\n" ] } ], "source": [ "x_1 = (1, [2, 2], [3, 3, 3, 3])\n", "x_2 = (1, [2, 2, 2], [3, 3])\n", "batch = batchify((x_1, x_2))\n", "\n", "names = ['centers', 'contexts_negatives', 'masks', 'labels']\n", "for name, data in zip(names, batch):\n", "    print(name, '=', data)" ] }, { "cell_type": "markdown", "id": "ba41d95f", "metadata": { "origin_pos": 33 }, "source": [ "## Putting It All Together\n", "\n", "Finally, we define the `load_data_ptb` function that reads the PTB dataset and returns the data iterator and the vocabulary.\n" ] }, { "cell_type": "code", "execution_count": 16, "id": "3220f70b", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T19:37:26.167245Z", "iopub.status.busy": "2023-08-18T19:37:26.166697Z", "iopub.status.idle": "2023-08-18T19:37:26.183618Z", "shell.execute_reply": "2023-08-18T19:37:26.178896Z" }, "origin_pos": 35, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "#@save\n", "def load_data_ptb(batch_size, max_window_size, num_noise_words):\n", "    \"\"\"Download the PTB dataset and then load it into memory.\"\"\"\n", "    num_workers = d2l.get_dataloader_workers()\n", "    sentences = read_ptb()\n", "    vocab = d2l.Vocab(sentences, min_freq=10)\n", "    subsampled, counter = subsample(sentences, vocab)\n", "    corpus = [vocab[line] for line in subsampled]\n", "    all_centers, all_contexts = get_centers_and_contexts(\n", "        corpus, max_window_size)\n", "    all_negatives = get_negatives(\n", "        all_contexts, vocab, counter, num_noise_words)\n", "\n", "    class PTBDataset(torch.utils.data.Dataset):\n", "        def __init__(self, centers, contexts, negatives):\n", "            assert len(centers) == len(contexts) == len(negatives)\n", "            self.centers = centers\n", "            self.contexts = contexts\n", "            self.negatives = negatives\n", "\n", "        def __getitem__(self, index):\n", "            return (self.centers[index], self.contexts[index],\n", "                    self.negatives[index])\n", "\n", "        def __len__(self):\n", "            return len(self.centers)\n", "\n", "    dataset = PTBDataset(all_centers, all_contexts, all_negatives)\n", "\n", "    data_iter = torch.utils.data.DataLoader(dataset, batch_size, shuffle=True,\n", "                                            collate_fn=batchify,\n", "                                            num_workers=num_workers)\n", "    return data_iter, vocab" ] }, { "cell_type": "markdown", "id": "05e4e4b4", "metadata": { "origin_pos": 36 }, "source": [ "Let's print the first minibatch of the data iterator.\n" ] }, { "cell_type": "code", "execution_count": 17, "id": "7d42be08", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T19:37:26.188797Z", "iopub.status.busy": "2023-08-18T19:37:26.186965Z", "iopub.status.idle": "2023-08-18T19:37:45.999655Z", "shell.execute_reply": "2023-08-18T19:37:45.998399Z" }, "origin_pos": 37, "tab": [ "pytorch" ] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "centers shape: torch.Size([512, 1])\n", "contexts_negatives shape: torch.Size([512, 60])\n", "masks shape: torch.Size([512, 60])\n", "labels shape: torch.Size([512, 60])\n" ] } ], "source": [ "data_iter, vocab = load_data_ptb(512, 5, 5)\n", "for batch in data_iter:\n", "    for name, data in zip(names, batch):\n", "        print(name, 'shape:', data.shape)\n", "    break" ] }, { "cell_type": "markdown", "id": "b0ea8d0f", "metadata": { "origin_pos": 38 }, "source": [ "## Summary\n", "\n", "* High-frequency words may not be so useful in training. We can subsample them to speed up training.\n", "* For computational efficiency, we load examples in minibatches. We can define other variables to distinguish paddings from non-paddings, and positive examples from negative ones.\n", "\n", "\n", "\n", "## Exercises\n", "\n", "1. How does the running time of the code in this section change if subsampling is not used?\n", "1. The `RandomGenerator` class caches `k` random sampling results. Set `k` to other values and see how it affects the data loading speed.\n", "1. What other hyperparameters in the code of this section may affect the data loading speed?\n" ] }, { "cell_type": "markdown", "id": "a1293d2b", "metadata": { "origin_pos": 40, "tab": [ "pytorch" ] }, "source": [ "[Discussions](https://discuss.d2l.ai/t/1330)\n" ] } ], "metadata": { "language_info": { "name": "python" }, "required_libs": [] }, "nbformat": 4, "nbformat_minor": 5 }