{ "cells": [ { "cell_type": "markdown", "id": "14c03cda", "metadata": { "origin_pos": 0 }, "source": [ "# The Dataset for Pretraining BERT\n", ":label:`sec_bert-dataset`\n", "\n", "To pretrain the BERT model as implemented in :numref:`sec_bert`,\n", "we need to generate the dataset in the ideal format to facilitate\n", "the two pretraining tasks:\n", "masked language modeling and next sentence prediction.\n", "On the one hand,\n", "the original BERT model is pretrained on the concatenation of\n", "two huge corpora BookCorpus and English Wikipedia (see :numref:`subsec_bert_pretraining_tasks`),\n", "making it hard to run for most readers of this book.\n", "On the other hand,\n", "the off-the-shelf pretrained BERT model\n", "may not fit for applications from specific domains like medicine.\n", "Thus, it is getting popular to pretrain BERT on a customized dataset.\n", "To facilitate the demonstration of BERT pretraining,\n", "we use a smaller corpus WikiText-2 :cite:`Merity.Xiong.Bradbury.ea.2016`.\n", "\n", "Comparing with the PTB dataset used for pretraining word2vec in :numref:`sec_word2vec_data`,\n", "WikiText-2 (i) retains the original punctuation, making it suitable for next sentence prediction; (ii) retains the original case and numbers; (iii) is over twice larger.\n" ] }, { "cell_type": "code", "execution_count": 1, "id": "b349c114", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T19:27:51.349012Z", "iopub.status.busy": "2023-08-18T19:27:51.348380Z", "iopub.status.idle": "2023-08-18T19:27:54.176466Z", "shell.execute_reply": "2023-08-18T19:27:54.175190Z" }, "origin_pos": 2, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "import os\n", "import random\n", "import torch\n", "from d2l import torch as d2l" ] }, { "cell_type": "markdown", "id": "00639076", "metadata": { "origin_pos": 3 }, "source": [ "In [**the WikiText-2 dataset**],\n", "each line represents a paragraph where\n", "space is inserted between any punctuation and its preceding token.\n", "Paragraphs with at least two sentences are retained.\n", "To split sentences, we only use the period as the delimiter for simplicity.\n", "We leave discussions of more complex sentence splitting techniques in the exercises\n", "at the end of this section.\n" ] }, { "cell_type": "code", "execution_count": 2, "id": "ea9beb6f", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T19:27:54.181061Z", "iopub.status.busy": "2023-08-18T19:27:54.180516Z", "iopub.status.idle": "2023-08-18T19:27:54.187506Z", "shell.execute_reply": "2023-08-18T19:27:54.186224Z" }, "origin_pos": 4, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "#@save\n", "d2l.DATA_HUB['wikitext-2'] = (\n", " 'https://s3.amazonaws.com/research.metamind.io/wikitext/'\n", " 'wikitext-2-v1.zip', '3c914d17d80b1459be871a5039ac23e752a53cbe')\n", "\n", "#@save\n", "def _read_wiki(data_dir):\n", " file_name = os.path.join(data_dir, 'wiki.train.tokens')\n", " with open(file_name, 'r') as f:\n", " lines = f.readlines()\n", " # Uppercase letters are converted to lowercase ones\n", " paragraphs = [line.strip().lower().split(' . ')\n", " for line in lines if len(line.split(' . 
{ "cell_type": "markdown", "id": "686f3b50", "metadata": { "origin_pos": 5 }, "source": [
"## Defining Helper Functions for Pretraining Tasks\n",
"\n",
"In the following,\n",
"we begin by implementing helper functions for the two BERT pretraining tasks:\n",
"next sentence prediction and masked language modeling.\n",
"These helper functions will be invoked later\n",
"when transforming the raw text corpus\n",
"into the dataset of the ideal format to pretrain BERT.\n",
"\n",
"### [**Generating the Next Sentence Prediction Task**]\n",
"\n",
"According to the descriptions in :numref:`subsec_nsp`,\n",
"the `_get_next_sentence` function generates a training example\n",
"for the binary classification task.\n" ] },
{ "cell_type": "code", "execution_count": 3, "id": "bb7772df", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T19:27:54.193280Z", "iopub.status.busy": "2023-08-18T19:27:54.192815Z", "iopub.status.idle": "2023-08-18T19:27:54.200587Z", "shell.execute_reply": "2023-08-18T19:27:54.199221Z" }, "origin_pos": 6, "tab": [ "pytorch" ] }, "outputs": [], "source": [
"#@save\n",
"def _get_next_sentence(sentence, next_sentence, paragraphs):\n",
"    if random.random() < 0.5:\n",
"        is_next = True\n",
"    else:\n",
"        # `paragraphs` is a list of lists of lists\n",
"        next_sentence = random.choice(random.choice(paragraphs))\n",
"        is_next = False\n",
"    return sentence, next_sentence, is_next" ] },
{ "cell_type": "markdown", "id": "a005849c", "metadata": { "origin_pos": 7 }, "source": [
"The following function generates training examples for next sentence prediction\n",
"from the input `paragraph` by invoking the `_get_next_sentence` function.\n",
"Here `paragraph` is a list of sentences, where each sentence is a list of tokens.\n",
"The argument `max_len` specifies the maximum length of a BERT input sequence during pretraining.\n" ] },
{ "cell_type": "code", "execution_count": 4, "id": "8a515271", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T19:27:54.204406Z", "iopub.status.busy": "2023-08-18T19:27:54.204109Z", "iopub.status.idle": "2023-08-18T19:27:54.210105Z", "shell.execute_reply": "2023-08-18T19:27:54.209309Z" }, "origin_pos": 8, "tab": [ "pytorch" ] }, "outputs": [], "source": [
"#@save\n",
"def _get_nsp_data_from_paragraph(paragraph, paragraphs, vocab, max_len):\n",
"    nsp_data_from_paragraph = []\n",
"    for i in range(len(paragraph) - 1):\n",
"        tokens_a, tokens_b, is_next = _get_next_sentence(\n",
"            paragraph[i], paragraph[i + 1], paragraphs)\n",
"        # Consider 1 '<cls>' token and 2 '<sep>' tokens\n",
"        if len(tokens_a) + len(tokens_b) + 3 > max_len:\n",
"            continue\n",
"        tokens, segments = d2l.get_tokens_and_segments(tokens_a, tokens_b)\n",
"        nsp_data_from_paragraph.append((tokens, segments, is_next))\n",
"    return nsp_data_from_paragraph" ] },
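{ "cell_type": "markdown", "id": "3c4d5e6f", "metadata": {}, "source": [
"To see what a single next sentence prediction example looks like,\n",
"the following sketch pairs two toy tokenized sentences and prints the resulting BERT input tokens,\n",
"segment ids, and binary label.\n",
"It is not part of the saved pretraining code; the toy paragraphs are made up,\n",
"and `d2l.get_tokens_and_segments` (from :numref:`sec_bert`) adds the “<cls>” and “<sep>” tokens.\n" ] },
{ "cell_type": "code", "execution_count": null, "id": "4d5e6f7a", "metadata": { "tab": [ "pytorch" ] }, "outputs": [], "source": [
"# Toy tokenized paragraphs, made up for illustration only\n",
"toy_paragraphs = [[['hello', 'world'], ['nice', 'to', 'see', 'you']],\n",
"                  [['this', 'movie', 'is', 'great'], ['i', 'like', 'it']]]\n",
"tokens_a, tokens_b, is_next = _get_next_sentence(\n",
"    toy_paragraphs[0][0], toy_paragraphs[0][1], toy_paragraphs)\n",
"# `get_tokens_and_segments` inserts '<cls>' and '<sep>' and assigns segment ids\n",
"tokens, segments = d2l.get_tokens_and_segments(tokens_a, tokens_b)\n",
"print(tokens, segments, is_next)" ] },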
{ "cell_type": "markdown", "id": "074b9a12", "metadata": { "origin_pos": 9 }, "source": [
"### [**Generating the Masked Language Modeling Task**]\n",
":label:`subsec_prepare_mlm_data`\n",
"\n",
"In order to generate training examples\n",
"for the masked language modeling task\n",
"from a BERT input sequence,\n",
"we define the following `_replace_mlm_tokens` function.\n",
"In its inputs, `tokens` is a list of tokens representing a BERT input sequence,\n",
"`candidate_pred_positions` is a list of token indices of the BERT input sequence\n",
"excluding those of special tokens (special tokens are not predicted in the masked language modeling task),\n",
"and `num_mlm_preds` indicates the number of predictions (recall that 15% of the tokens are to be predicted).\n",
"Following the definition of the masked language modeling task in :numref:`subsec_mlm`,\n",
"at each prediction position, the input may be replaced by\n",
"a special “<mask>” token or a random token, or remain unchanged.\n",
"In the end, the function returns the input tokens after possible replacement,\n",
"the token indices where predictions take place, and the labels for these predictions.\n" ] },
{ "cell_type": "code", "execution_count": 5, "id": "5b4c14fe", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T19:27:54.213353Z", "iopub.status.busy": "2023-08-18T19:27:54.213061Z", "iopub.status.idle": "2023-08-18T19:27:54.220268Z", "shell.execute_reply": "2023-08-18T19:27:54.219354Z" }, "origin_pos": 10, "tab": [ "pytorch" ] }, "outputs": [], "source": [
"#@save\n",
"def _replace_mlm_tokens(tokens, candidate_pred_positions, num_mlm_preds,\n",
"                        vocab):\n",
"    # For the input of a masked language model, make a new copy of tokens and\n",
"    # replace some of them by '<mask>' or random tokens\n",
"    mlm_input_tokens = [token for token in tokens]\n",
"    pred_positions_and_labels = []\n",
"    # Shuffle for getting 15% random tokens for prediction in the masked\n",
"    # language modeling task\n",
"    random.shuffle(candidate_pred_positions)\n",
"    for mlm_pred_position in candidate_pred_positions:\n",
"        if len(pred_positions_and_labels) >= num_mlm_preds:\n",
"            break\n",
"        masked_token = None\n",
"        # 80% of the time: replace the word with the '<mask>' token\n",
"        if random.random() < 0.8:\n",
"            masked_token = '<mask>'\n",
"        else:\n",
"            # 10% of the time: keep the word unchanged\n",
"            if random.random() < 0.5:\n",
"                masked_token = tokens[mlm_pred_position]\n",
"            # 10% of the time: replace the word with a random word\n",
"            else:\n",
"                masked_token = random.choice(vocab.idx_to_token)\n",
"        mlm_input_tokens[mlm_pred_position] = masked_token\n",
"        pred_positions_and_labels.append(\n",
"            (mlm_pred_position, tokens[mlm_pred_position]))\n",
"    return mlm_input_tokens, pred_positions_and_labels" ] },
{ "cell_type": "markdown", "id": "9cc4dc18", "metadata": { "origin_pos": 11 }, "source": [
"By invoking the aforementioned `_replace_mlm_tokens` function,\n",
"the following function takes a BERT input sequence (`tokens`)\n",
"as an input and returns indices of the input tokens\n",
"(after possible token replacement as described in :numref:`subsec_mlm`),\n",
"the token indices where predictions take place,\n",
"and label indices for these predictions.\n" ] },
{ "cell_type": "code", "execution_count": 6, "id": "c71c28c0", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T19:27:54.224660Z", "iopub.status.busy": "2023-08-18T19:27:54.224368Z", "iopub.status.idle": "2023-08-18T19:27:54.231144Z", "shell.execute_reply": "2023-08-18T19:27:54.230276Z" }, "origin_pos": 12, "tab": [ "pytorch" ] }, "outputs": [], "source": [
"#@save\n",
"def _get_mlm_data_from_tokens(tokens, vocab):\n",
"    candidate_pred_positions = []\n",
"    # `tokens` is a list of strings\n",
"    for i, token in enumerate(tokens):\n",
"        # Special tokens are not predicted in the masked language modeling\n",
"        # task\n",
"        if token in ['<cls>', '<sep>']:\n",
"            continue\n",
"        candidate_pred_positions.append(i)\n",
"    # 15% of random tokens are predicted in the masked language modeling task\n",
"    num_mlm_preds = max(1, round(len(tokens) * 0.15))\n",
"    mlm_input_tokens, pred_positions_and_labels = _replace_mlm_tokens(\n",
"        tokens, candidate_pred_positions, num_mlm_preds, vocab)\n",
"    pred_positions_and_labels = sorted(pred_positions_and_labels,\n",
"                                       key=lambda x: x[0])\n",
"    pred_positions = [v[0] for v in pred_positions_and_labels]\n",
"    mlm_pred_labels = [v[1] for v in pred_positions_and_labels]\n",
"    return vocab[mlm_input_tokens], pred_positions, vocab[mlm_pred_labels]" ] },
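{ "cell_type": "markdown", "id": "5e6f7a8b", "metadata": {}, "source": [
"As a quick sanity check,\n",
"we can apply `_get_mlm_data_from_tokens` to one toy BERT input sequence and a tiny vocabulary.\n",
"This cell is only an illustration and is not part of the saved pretraining code;\n",
"the toy tokens and vocabulary are made up, and since masking is random,\n",
"the printed positions and labels vary across runs.\n" ] },
{ "cell_type": "code", "execution_count": null, "id": "6f7a8b9c", "metadata": { "tab": [ "pytorch" ] }, "outputs": [], "source": [
"# A made-up vocabulary and BERT input sequence, for illustration only\n",
"toy_vocab = d2l.Vocab([['this', 'movie', 'is', 'great', 'i', 'like', 'it']],\n",
"                      reserved_tokens=['<pad>', '<mask>', '<cls>', '<sep>'])\n",
"toy_tokens = ['<cls>', 'this', 'movie', 'is', 'great', '<sep>', 'i', 'like',\n",
"              'it', '<sep>']\n",
"# Returns the token ids after masking, the prediction positions, and the\n",
"# label ids at those positions\n",
"print(_get_mlm_data_from_tokens(toy_tokens, toy_vocab))" ] },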
{ "cell_type": "markdown", "id": "d4716fe1", "metadata": { "origin_pos": 13 }, "source": [
"## Transforming Text into the Pretraining Dataset\n",
"\n",
"Now we are almost ready to customize a `Dataset` class for pretraining BERT.\n",
"Before that,\n",
"we still need to define a helper function `_pad_bert_inputs`\n",
"to [**append the special “<pad>” tokens to the inputs.**]\n",
"Its argument `examples` contains the outputs from the helper functions `_get_nsp_data_from_paragraph` and `_get_mlm_data_from_tokens` for the two pretraining tasks.\n" ] },
{ "cell_type": "code", "execution_count": 7, "id": "9b2d2247", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T19:27:54.234922Z", "iopub.status.busy": "2023-08-18T19:27:54.234060Z", "iopub.status.idle": "2023-08-18T19:27:54.243205Z", "shell.execute_reply": "2023-08-18T19:27:54.242138Z" }, "origin_pos": 15, "tab": [ "pytorch" ] }, "outputs": [], "source": [
"#@save\n",
"def _pad_bert_inputs(examples, max_len, vocab):\n",
"    max_num_mlm_preds = round(max_len * 0.15)\n",
"    all_token_ids, all_segments, valid_lens = [], [], []\n",
"    all_pred_positions, all_mlm_weights, all_mlm_labels = [], [], []\n",
"    nsp_labels = []\n",
"    for (token_ids, pred_positions, mlm_pred_label_ids, segments,\n",
"         is_next) in examples:\n",
"        all_token_ids.append(torch.tensor(token_ids + [vocab['<pad>']] * (\n",
"            max_len - len(token_ids)), dtype=torch.long))\n",
"        all_segments.append(torch.tensor(segments + [0] * (\n",
"            max_len - len(segments)), dtype=torch.long))\n",
"        # `valid_lens` excludes count of '<pad>' tokens\n",
"        valid_lens.append(torch.tensor(len(token_ids), dtype=torch.float32))\n",
"        all_pred_positions.append(torch.tensor(pred_positions + [0] * (\n",
"            max_num_mlm_preds - len(pred_positions)), dtype=torch.long))\n",
"        # Predictions of padded tokens will be filtered out in the loss via\n",
"        # multiplication of 0 weights\n",
"        all_mlm_weights.append(\n",
"            torch.tensor([1.0] * len(mlm_pred_label_ids) + [0.0] * (\n",
"                max_num_mlm_preds - len(pred_positions)),\n",
"                dtype=torch.float32))\n",
"        all_mlm_labels.append(torch.tensor(mlm_pred_label_ids + [0] * (\n",
"            max_num_mlm_preds - len(mlm_pred_label_ids)), dtype=torch.long))\n",
"        nsp_labels.append(torch.tensor(is_next, dtype=torch.long))\n",
"    return (all_token_ids, all_segments, valid_lens, all_pred_positions,\n",
"            all_mlm_weights, all_mlm_labels, nsp_labels)" ] },
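{ "cell_type": "markdown", "id": "7a8b9c0d", "metadata": {}, "source": [
"The following minimal sketch shows the padding behavior on one hand-written example.\n",
"It is not part of the saved pretraining code: the example tuple, its ids, and the choice of `max_len`\n",
"are all made up for illustration, and only the index of “<pad>” in the vocabulary actually matters here.\n" ] },
{ "cell_type": "code", "execution_count": null, "id": "8b9c0d1e", "metadata": { "tab": [ "pytorch" ] }, "outputs": [], "source": [
"# A hypothetical example in the format produced by combining\n",
"# `_get_nsp_data_from_paragraph` and `_get_mlm_data_from_tokens`:\n",
"# (token ids, MLM prediction positions, MLM label ids, segments, is_next)\n",
"toy_examples = [([3, 5, 6, 7, 8, 4, 9, 10, 11, 4], [2], [6],\n",
"                 [0, 0, 0, 0, 0, 0, 1, 1, 1, 1], True)]\n",
"# Only the id of '<pad>' is looked up, so a tiny vocabulary suffices\n",
"toy_vocab = d2l.Vocab([['a']],\n",
"                      reserved_tokens=['<pad>', '<mask>', '<cls>', '<sep>'])\n",
"toy_batch = _pad_bert_inputs(toy_examples, max_len=16, vocab=toy_vocab)\n",
"# The token ids are padded with the id of '<pad>' up to `max_len`\n",
"print(toy_batch[0][0])" ] },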
{ "cell_type": "markdown", "id": "fdc8f394", "metadata": { "origin_pos": 16 }, "source": [
"Putting together the helper functions for\n",
"generating training examples of the two pretraining tasks\n",
"and the helper function for padding inputs,\n",
"we customize the following `_WikiTextDataset` class as [**the WikiText-2 dataset for pretraining BERT**].\n",
"By implementing the `__getitem__` function,\n",
"we can arbitrarily access the pretraining (masked language modeling and next sentence prediction) examples\n",
"generated from a pair of sentences from the WikiText-2 corpus.\n",
"\n",
"The original BERT model uses WordPiece embeddings whose vocabulary size is 30000 :cite:`Wu.Schuster.Chen.ea.2016`.\n",
"The tokenization method of WordPiece is a slight modification of\n",
"the original byte pair encoding algorithm in :numref:`subsec_Byte_Pair_Encoding`.\n",
"For simplicity, we use the `d2l.tokenize` function for tokenization.\n",
"Infrequent tokens that appear fewer than five times are filtered out.\n" ] },
{ "cell_type": "code", "execution_count": 8, "id": "27a00351", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T19:27:54.247275Z", "iopub.status.busy": "2023-08-18T19:27:54.246496Z", "iopub.status.idle": "2023-08-18T19:27:54.255307Z", "shell.execute_reply": "2023-08-18T19:27:54.254197Z" }, "origin_pos": 18, "tab": [ "pytorch" ] }, "outputs": [], "source": [
"#@save\n",
"class _WikiTextDataset(torch.utils.data.Dataset):\n",
"    def __init__(self, paragraphs, max_len):\n",
"        # Input `paragraphs[i]` is a list of sentence strings representing a\n",
"        # paragraph; while output `paragraphs[i]` is a list of sentences\n",
"        # representing a paragraph, where each sentence is a list of tokens\n",
"        paragraphs = [d2l.tokenize(\n",
"            paragraph, token='word') for paragraph in paragraphs]\n",
"        sentences = [sentence for paragraph in paragraphs\n",
"                     for sentence in paragraph]\n",
"        self.vocab = d2l.Vocab(sentences, min_freq=5, reserved_tokens=[\n",
"            '<pad>', '<mask>', '<cls>', '<sep>'])\n",
"        # Get data for the next sentence prediction task\n",
"        examples = []\n",
"        for paragraph in paragraphs:\n",
"            examples.extend(_get_nsp_data_from_paragraph(\n",
"                paragraph, paragraphs, self.vocab, max_len))\n",
"        # Get data for the masked language model task\n",
"        examples = [(_get_mlm_data_from_tokens(tokens, self.vocab)\n",
"                     + (segments, is_next))\n",
"                    for tokens, segments, is_next in examples]\n",
"        # Pad inputs\n",
"        (self.all_token_ids, self.all_segments, self.valid_lens,\n",
"         self.all_pred_positions, self.all_mlm_weights,\n",
"         self.all_mlm_labels, self.nsp_labels) = _pad_bert_inputs(\n",
"            examples, max_len, self.vocab)\n",
"\n",
"    def __getitem__(self, idx):\n",
"        return (self.all_token_ids[idx], self.all_segments[idx],\n",
"                self.valid_lens[idx], self.all_pred_positions[idx],\n",
"                self.all_mlm_weights[idx], self.all_mlm_labels[idx],\n",
"                self.nsp_labels[idx])\n",
"\n",
"    def __len__(self):\n",
"        return len(self.all_token_ids)" ] },
{ "cell_type": "markdown", "id": "16d11c15", "metadata": { "origin_pos": 19 }, "source": [
"By using the `_read_wiki` function and the `_WikiTextDataset` class,\n",
"we define the following `load_data_wiki` function to [**download the WikiText-2 dataset\n",
"and generate pretraining examples**] from it.\n" ] },
{ "cell_type": "code", "execution_count": 9, "id": "6e8efaa8", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T19:27:54.258912Z", "iopub.status.busy": "2023-08-18T19:27:54.258166Z", "iopub.status.idle": "2023-08-18T19:27:54.264300Z", "shell.execute_reply": "2023-08-18T19:27:54.263141Z" }, "origin_pos": 21, "tab": [ "pytorch" ] }, "outputs": [], "source": [
"#@save\n",
"def load_data_wiki(batch_size, max_len):\n",
"    \"\"\"Load the WikiText-2 dataset.\"\"\"\n",
"    num_workers = d2l.get_dataloader_workers()\n",
"    data_dir = d2l.download_extract('wikitext-2', 'wikitext-2')\n",
"    paragraphs = _read_wiki(data_dir)\n",
"    train_set = _WikiTextDataset(paragraphs, max_len)\n",
"    train_iter = torch.utils.data.DataLoader(train_set, batch_size,\n",
"                                        shuffle=True, num_workers=num_workers)\n",
"    return train_iter, train_set.vocab" ] },
{ "cell_type": "markdown", "id": "d2986f60", "metadata": { "origin_pos": 22 }, "source": [
"Setting the batch size to 512 and the maximum length of a BERT input sequence to 64,\n",
"we [**print out the shapes of a minibatch of BERT pretraining examples**].\n",
"Note that in each BERT input sequence,\n",
"$10$ ($64 \\times 0.15$, rounded) positions are predicted for the masked language modeling task.\n" ] },
{ "cell_type": "code", "execution_count": 10, "id": "6b866b9c", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T19:27:54.267688Z", "iopub.status.busy": "2023-08-18T19:27:54.267106Z", "iopub.status.idle": "2023-08-18T19:28:02.924683Z", "shell.execute_reply": "2023-08-18T19:28:02.923277Z" }, "origin_pos": 23, "tab": [ "pytorch" ] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Downloading ../data/wikitext-2-v1.zip from https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-v1.zip...\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "torch.Size([512, 64]) torch.Size([512, 64]) torch.Size([512]) torch.Size([512, 10]) torch.Size([512, 10]) torch.Size([512, 10]) torch.Size([512])\n" ] } ], "source": [
"batch_size, max_len = 512, 64\n",
"train_iter, vocab = load_data_wiki(batch_size, max_len)\n",
"\n",
"for (tokens_X, segments_X, valid_lens_x, pred_positions_X, mlm_weights_X,\n",
"     mlm_Y, nsp_y) in train_iter:\n",
"    print(tokens_X.shape, segments_X.shape, valid_lens_x.shape,\n",
"          pred_positions_X.shape, mlm_weights_X.shape, mlm_Y.shape,\n",
"          nsp_y.shape)\n",
"    break" ] },
{ "cell_type": "markdown", "id": "e5d9fd56", "metadata": { "origin_pos": 24 }, "source": [
"In the end, let's take a look at the vocabulary size.\n",
"Even after filtering out infrequent tokens,\n",
"it is still more than twice as large as that of the PTB dataset.\n" ] },
{ "cell_type": "code", "execution_count": 11, "id": "76cf9172", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T19:28:02.929142Z", "iopub.status.busy": "2023-08-18T19:28:02.928158Z", "iopub.status.idle": "2023-08-18T19:28:02.936806Z", "shell.execute_reply": "2023-08-18T19:28:02.935612Z" }, "origin_pos": 25, "tab": [ "pytorch" ] }, "outputs": [ { "data": { "text/plain": [ "20256" ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [
"len(vocab)" ] },
{ "cell_type": "markdown", "id": "3b598fc5", "metadata": { "origin_pos": 26 }, "source": [
"## Summary\n",
"\n",
"* Compared with the PTB dataset, the WikiText-2 dataset retains the original punctuation, case, and numbers, and is more than twice as large.\n",
"* We can arbitrarily access the pretraining (masked language modeling and next sentence prediction) examples generated from a pair of sentences from the WikiText-2 corpus.\n",
"\n",
"\n",
"## Exercises\n",
"\n",
"1. For simplicity, the period is used as the only delimiter for splitting sentences. Try other sentence splitting techniques, such as spaCy and NLTK. Take NLTK as an example. You need to install NLTK first: `pip install nltk`. In the code, first `import nltk`. Then, download the Punkt sentence tokenizer: `nltk.download('punkt')`. To split sentences such as `sentences = 'This is great ! Why not ?'`, invoking `nltk.tokenize.sent_tokenize(sentences)` will return a list of two sentence strings: `['This is great !', 'Why not ?']`.\n",
"1. What is the vocabulary size if we do not filter out any infrequent token?\n" ] },
{ "cell_type": "markdown", "id": "aad35f4a", "metadata": { "origin_pos": 28, "tab": [ "pytorch" ] }, "source": [
"[Discussions](https://discuss.d2l.ai/t/1496)\n" ] } ], "metadata": { "language_info": { "name": "python" }, "required_libs": [] }, "nbformat": 4, "nbformat_minor": 5 }