{ "cells": [ { "cell_type": "markdown", "id": "f25f38c1", "metadata": { "origin_pos": 0 }, "source": [ "# Natural Language Inference: Fine-Tuning BERT\n", ":label:`sec_natural-language-inference-bert`\n", "\n", "In earlier sections of this chapter,\n", "we have designed an attention-based architecture\n", "(in :numref:`sec_natural-language-inference-attention`)\n", "for the natural language inference task\n", "on the SNLI dataset (as described in :numref:`sec_natural-language-inference-and-dataset`).\n", "Now we revisit this task by fine-tuning BERT.\n", "As discussed in :numref:`sec_finetuning-bert`,\n", "natural language inference is a sequence-level text pair classification problem,\n", "and fine-tuning BERT only requires an additional MLP-based architecture,\n", "as illustrated in :numref:`fig_nlp-map-nli-bert`.\n", "\n", "![This section feeds pretrained BERT to an MLP-based architecture for natural language inference.](../img/nlp-map-nli-bert.svg)\n", ":label:`fig_nlp-map-nli-bert`\n", "\n", "In this section,\n", "we will download a pretrained small version of BERT,\n", "then fine-tune it\n", "for natural language inference on the SNLI dataset.\n" ] }, { "cell_type": "code", "execution_count": 1, "id": "9f088de5", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T19:39:22.440128Z", "iopub.status.busy": "2023-08-18T19:39:22.439451Z", "iopub.status.idle": "2023-08-18T19:39:25.853436Z", "shell.execute_reply": "2023-08-18T19:39:25.852139Z" }, "origin_pos": 2, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "import json\n", "import multiprocessing\n", "import os\n", "import torch\n", "from torch import nn\n", "from d2l import torch as d2l" ] }, { "cell_type": "markdown", "id": "caaf2d0d", "metadata": { "origin_pos": 3 }, "source": [ "## [**Loading Pretrained BERT**]\n", "\n", "We have explained how to pretrain BERT on the WikiText-2 dataset in\n", ":numref:`sec_bert-dataset` and :numref:`sec_bert-pretraining`\n", "(note that the original BERT model is pretrained on much bigger corpora).\n", "As discussed in :numref:`sec_bert-pretraining`,\n", "the original BERT model has hundreds of millions of parameters.\n", "In the following,\n", "we provide two versions of pretrained BERT:\n", "\"bert.base\" is about as big as the original BERT base model that requires a lot of computational resources to fine-tune,\n", "while \"bert.small\" is a small version to facilitate demonstration.\n" ] }, { "cell_type": "code", "execution_count": 2, "id": "fdd9ca6e", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T19:39:25.859081Z", "iopub.status.busy": "2023-08-18T19:39:25.858390Z", "iopub.status.idle": "2023-08-18T19:39:25.863351Z", "shell.execute_reply": "2023-08-18T19:39:25.862573Z" }, "origin_pos": 5, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "d2l.DATA_HUB['bert.base'] = (d2l.DATA_URL + 'bert.base.torch.zip',\n", " '225d66f04cae318b841a13d32af3acc165f253ac')\n", "d2l.DATA_HUB['bert.small'] = (d2l.DATA_URL + 'bert.small.torch.zip',\n", " 'c72329e68a732bef0452e4b96a1c341c8910f81f')" ] }, { "cell_type": "markdown", "id": "6e761acc", "metadata": { "origin_pos": 6 }, "source": [ "Either pretrained BERT model contains a \"vocab.json\" file that defines the vocabulary set\n", "and a \"pretrained.params\" file of the pretrained parameters.\n", "We implement the following `load_pretrained_model` function to [**load pretrained BERT parameters**].\n" ] }, { "cell_type": "code", "execution_count": 3, "id": "9ca530a2", "metadata": { "execution": { "iopub.execute_input": 
"2023-08-18T19:39:25.867663Z", "iopub.status.busy": "2023-08-18T19:39:25.867063Z", "iopub.status.idle": "2023-08-18T19:39:25.874533Z", "shell.execute_reply": "2023-08-18T19:39:25.873621Z" }, "origin_pos": 8, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "def load_pretrained_model(pretrained_model, num_hiddens, ffn_num_hiddens,\n", " num_heads, num_blks, dropout, max_len, devices):\n", " data_dir = d2l.download_extract(pretrained_model)\n", " # Define an empty vocabulary to load the predefined vocabulary\n", " vocab = d2l.Vocab()\n", " vocab.idx_to_token = json.load(open(os.path.join(data_dir, 'vocab.json')))\n", " vocab.token_to_idx = {token: idx for idx, token in enumerate(\n", " vocab.idx_to_token)}\n", " bert = d2l.BERTModel(\n", " len(vocab), num_hiddens, ffn_num_hiddens=ffn_num_hiddens, num_heads=4,\n", " num_blks=2, dropout=0.2, max_len=max_len)\n", " # Load pretrained BERT parameters\n", " bert.load_state_dict(torch.load(os.path.join(data_dir,\n", " 'pretrained.params')))\n", " return bert, vocab" ] }, { "cell_type": "markdown", "id": "527319d5", "metadata": { "origin_pos": 9 }, "source": [ "To facilitate demonstration on most of machines,\n", "we will load and fine-tune the small version (\"bert.small\") of the pretrained BERT in this section.\n", "In the exercise, we will show how to fine-tune the much larger \"bert.base\" to significantly improve the testing accuracy.\n" ] }, { "cell_type": "code", "execution_count": 4, "id": "b4d73006", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T19:39:25.878388Z", "iopub.status.busy": "2023-08-18T19:39:25.877761Z", "iopub.status.idle": "2023-08-18T19:39:29.552585Z", "shell.execute_reply": "2023-08-18T19:39:29.550325Z" }, "origin_pos": 10, "tab": [ "pytorch" ] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Downloading ../data/bert.small.torch.zip from http://d2l-data.s3-accelerate.amazonaws.com/bert.small.torch.zip...\n" ] } ], "source": [ "devices = d2l.try_all_gpus()\n", "bert, vocab = load_pretrained_model(\n", " 'bert.small', num_hiddens=256, ffn_num_hiddens=512, num_heads=4,\n", " num_blks=2, dropout=0.1, max_len=512, devices=devices)" ] }, { "cell_type": "markdown", "id": "9599a861", "metadata": { "origin_pos": 11 }, "source": [ "## [**The Dataset for Fine-Tuning BERT**]\n", "\n", "For the downstream task natural language inference on the SNLI dataset,\n", "we define a customized dataset class `SNLIBERTDataset`.\n", "In each example,\n", "the premise and hypothesis form a pair of text sequence\n", "and is packed into one BERT input sequence as depicted in :numref:`fig_bert-two-seqs`.\n", "Recall :numref:`subsec_bert_input_rep` that segment IDs\n", "are used to distinguish the premise and the hypothesis in a BERT input sequence.\n", "With the predefined maximum length of a BERT input sequence (`max_len`),\n", "the last token of the longer of the input text pair keeps getting removed until\n", "`max_len` is met.\n", "To accelerate generation of the SNLI dataset\n", "for fine-tuning BERT,\n", "we use 4 worker processes to generate training or testing examples in parallel.\n" ] }, { "cell_type": "code", "execution_count": 5, "id": "0ef1ad4c", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T19:39:29.556910Z", "iopub.status.busy": "2023-08-18T19:39:29.556532Z", "iopub.status.idle": "2023-08-18T19:39:29.575941Z", "shell.execute_reply": "2023-08-18T19:39:29.571719Z" }, "origin_pos": 13, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "class 
{ "cell_type": "markdown", "id": "9599a861", "metadata": { "origin_pos": 11 }, "source": [
"## [**The Dataset for Fine-Tuning BERT**]\n",
"\n",
"For the downstream task of natural language inference on the SNLI dataset,\n",
"we define a customized dataset class `SNLIBERTDataset`.\n",
"In each example,\n",
"the premise and hypothesis form a pair of text sequences\n",
"and are packed into one BERT input sequence as depicted in :numref:`fig_bert-two-seqs`.\n",
"Recall from :numref:`subsec_bert_input_rep` that segment IDs\n",
"are used to distinguish the premise from the hypothesis in a BERT input sequence.\n",
"To satisfy the predefined maximum length of a BERT input sequence (`max_len`),\n",
"the last token of the longer of the two input texts is repeatedly removed until\n",
"`max_len` is met.\n",
"To accelerate generation of the SNLI dataset\n",
"for fine-tuning BERT,\n",
"we use 4 worker processes to generate training or testing examples in parallel.\n" ] }, { "cell_type": "code", "execution_count": 5, "id": "0ef1ad4c", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T19:39:29.556910Z", "iopub.status.busy": "2023-08-18T19:39:29.556532Z", "iopub.status.idle": "2023-08-18T19:39:29.575941Z", "shell.execute_reply": "2023-08-18T19:39:29.571719Z" }, "origin_pos": 13, "tab": [ "pytorch" ] }, "outputs": [], "source": [
"class SNLIBERTDataset(torch.utils.data.Dataset):\n",
"    def __init__(self, dataset, max_len, vocab=None):\n",
"        all_premise_hypothesis_tokens = [[\n",
"            p_tokens, h_tokens] for p_tokens, h_tokens in zip(\n",
"            *[d2l.tokenize([s.lower() for s in sentences])\n",
"              for sentences in dataset[:2]])]\n",
"\n",
"        self.labels = torch.tensor(dataset[2])\n",
"        self.vocab = vocab\n",
"        self.max_len = max_len\n",
"        (self.all_token_ids, self.all_segments,\n",
"         self.valid_lens) = self._preprocess(all_premise_hypothesis_tokens)\n",
"        print('read ' + str(len(self.all_token_ids)) + ' examples')\n",
"\n",
"    def _preprocess(self, all_premise_hypothesis_tokens):\n",
"        pool = multiprocessing.Pool(4)  # Use 4 worker processes\n",
"        out = pool.map(self._mp_worker, all_premise_hypothesis_tokens)\n",
"        all_token_ids = [\n",
"            token_ids for token_ids, segments, valid_len in out]\n",
"        all_segments = [segments for token_ids, segments, valid_len in out]\n",
"        valid_lens = [valid_len for token_ids, segments, valid_len in out]\n",
"        return (torch.tensor(all_token_ids, dtype=torch.long),\n",
"                torch.tensor(all_segments, dtype=torch.long),\n",
"                torch.tensor(valid_lens))\n",
"\n",
"    def _mp_worker(self, premise_hypothesis_tokens):\n",
"        p_tokens, h_tokens = premise_hypothesis_tokens\n",
"        self._truncate_pair_of_tokens(p_tokens, h_tokens)\n",
"        tokens, segments = d2l.get_tokens_and_segments(p_tokens, h_tokens)\n",
"        token_ids = self.vocab[tokens] + [self.vocab['<pad>']] \\\n",
"            * (self.max_len - len(tokens))\n",
"        segments = segments + [0] * (self.max_len - len(segments))\n",
"        valid_len = len(tokens)\n",
"        return token_ids, segments, valid_len\n",
"\n",
"    def _truncate_pair_of_tokens(self, p_tokens, h_tokens):\n",
"        # Reserve slots for the '<CLS>', '<SEP>', and '<SEP>' tokens in the\n",
"        # BERT input\n",
"        while len(p_tokens) + len(h_tokens) > self.max_len - 3:\n",
"            if len(p_tokens) > len(h_tokens):\n",
"                p_tokens.pop()\n",
"            else:\n",
"                h_tokens.pop()\n",
"\n",
"    def __getitem__(self, idx):\n",
"        return (self.all_token_ids[idx], self.all_segments[idx],\n",
"                self.valid_lens[idx]), self.labels[idx]\n",
"\n",
"    def __len__(self):\n",
"        return len(self.all_token_ids)" ] },
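{ "cell_type": "markdown", "id": "b2e31d4f", "metadata": {}, "source": [
"Before preprocessing the full SNLI dataset, it may help to see what a single preprocessed example looks like. The following sketch is illustrative only: the tiny hand-written premise, hypothesis, label, and the small `max_len` are made up for demonstration. It instantiates `SNLIBERTDataset` on this toy data and inspects the first example, namely the token IDs padded with `<pad>` up to `max_len`, the segment IDs separating the premise from the hypothesis, and the number of non-padding tokens.\n" ] }, { "cell_type": "code", "execution_count": null, "id": "b2e31d50", "metadata": {}, "outputs": [], "source": [
"# Illustrative sketch: preprocess one hand-written premise--hypothesis pair\n",
"toy_data = (['two dogs are playing outside .'],  # Premises\n",
"            ['animals are outdoors .'],          # Hypotheses\n",
"            [0])                                 # Labels (0 = entailment)\n",
"toy_set = SNLIBERTDataset(toy_data, max_len=16, vocab=vocab)\n",
"(token_ids, segments, valid_len), label = toy_set[0]\n",
"print(token_ids)  # Padded with the '<pad>' index up to `max_len`\n",
"print(segments)   # 0 for the premise part, 1 for the hypothesis part\n",
"print(valid_len, label)" ] },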
{ "cell_type": "markdown", "id": "972e4031", "metadata": { "origin_pos": 14 }, "source": [
"After downloading the SNLI dataset,\n",
"we [**generate training and testing examples**]\n",
"by instantiating the `SNLIBERTDataset` class.\n",
"Such examples will be read in minibatches during training and testing\n",
"of natural language inference.\n" ] }, { "cell_type": "code", "execution_count": 6, "id": "ba8fa6e9", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T19:39:29.579600Z", "iopub.status.busy": "2023-08-18T19:39:29.579246Z", "iopub.status.idle": "2023-08-18T19:40:35.629014Z", "shell.execute_reply": "2023-08-18T19:40:35.626314Z" }, "origin_pos": 16, "tab": [ "pytorch" ] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "read 549367 examples\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "read 9824 examples\n" ] } ], "source": [
"# Reduce `batch_size` if there is an out-of-memory error. In the original BERT\n",
"# model, `max_len` = 512\n",
"batch_size, max_len, num_workers = 512, 128, d2l.get_dataloader_workers()\n",
"data_dir = d2l.download_extract('SNLI')\n",
"train_set = SNLIBERTDataset(d2l.read_snli(data_dir, True), max_len, vocab)\n",
"test_set = SNLIBERTDataset(d2l.read_snli(data_dir, False), max_len, vocab)\n",
"train_iter = torch.utils.data.DataLoader(train_set, batch_size, shuffle=True,\n",
"                                         num_workers=num_workers)\n",
"test_iter = torch.utils.data.DataLoader(test_set, batch_size,\n",
"                                        num_workers=num_workers)" ] }, { "cell_type": "markdown", "id": "1072ac02", "metadata": { "origin_pos": 17 }, "source": [
"## Fine-Tuning BERT\n",
"\n",
"As :numref:`fig_bert-two-seqs` indicates,\n",
"fine-tuning BERT for natural language inference\n",
"requires only an extra MLP consisting of two fully connected layers\n",
"(see `self.hidden` and `self.output` in the following `BERTClassifier` class).\n",
"[**This MLP transforms the\n",
"BERT representation of the special “<cls>” token**],\n",
"which encodes the information of both the premise and the hypothesis,\n",
"(**into three outputs of natural language inference**):\n",
"entailment, contradiction, and neutral.\n" ] }, { "cell_type": "code", "execution_count": 7, "id": "a65b8d66", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T19:40:35.635313Z", "iopub.status.busy": "2023-08-18T19:40:35.634393Z", "iopub.status.idle": "2023-08-18T19:40:35.643634Z", "shell.execute_reply": "2023-08-18T19:40:35.642590Z" }, "origin_pos": 19, "tab": [ "pytorch" ] }, "outputs": [], "source": [
"class BERTClassifier(nn.Module):\n",
"    def __init__(self, bert):\n",
"        super(BERTClassifier, self).__init__()\n",
"        self.encoder = bert.encoder\n",
"        self.hidden = bert.hidden\n",
"        self.output = nn.LazyLinear(3)\n",
"\n",
"    def forward(self, inputs):\n",
"        tokens_X, segments_X, valid_lens_x = inputs\n",
"        encoded_X = self.encoder(tokens_X, segments_X, valid_lens_x)\n",
"        return self.output(self.hidden(encoded_X[:, 0, :]))" ] }, { "cell_type": "markdown", "id": "02a14374", "metadata": { "origin_pos": 20 }, "source": [
"In the following,\n",
"the pretrained BERT model `bert` is fed into the `BERTClassifier` instance `net` for\n",
"the downstream application.\n",
"In common implementations of BERT fine-tuning,\n",
"only the parameters of the output layer of the additional MLP (`net.output`) will be learned from scratch.\n",
"All the parameters of the pretrained BERT encoder (`net.encoder`) and the hidden layer of the additional MLP (`net.hidden`) will be fine-tuned.\n" ] }, { "cell_type": "code", "execution_count": 8, "id": "b92a26a4", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T19:40:35.647266Z", "iopub.status.busy": "2023-08-18T19:40:35.646712Z", "iopub.status.idle": "2023-08-18T19:40:35.653291Z", "shell.execute_reply": "2023-08-18T19:40:35.652150Z" }, "origin_pos": 22, "tab": [ "pytorch" ] }, "outputs": [], "source": [
"net = BERTClassifier(bert)" ] },
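{ "cell_type": "markdown", "id": "c3d42e61", "metadata": {}, "source": [
"As a quick illustrative check (not a required step), we can confirm that `net` reuses the pretrained encoder and hidden layer of `bert` by reference, so fine-tuning updates the pretrained weights in place, whereas `net.output` is a freshly created layer whose parameters will be learned from scratch.\n" ] }, { "cell_type": "code", "execution_count": null, "id": "c3d42e62", "metadata": {}, "outputs": [], "source": [
"# Sketch: `net` shares the pretrained encoder and hidden layer with `bert` by\n",
"# reference; only `net.output` is newly created\n",
"print(net.encoder is bert.encoder, net.hidden is bert.hidden)  # True True\n",
"print(net.output)  # A lazily initialized linear layer with 3 outputs" ] },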
{ "cell_type": "markdown", "id": "1b303e5c", "metadata": { "origin_pos": 23 }, "source": [
"Recall that\n",
"in :numref:`sec_bert`\n",
"both the `MaskLM` class and the `NextSentencePred` class\n",
"have parameters in their employed MLPs.\n",
"These parameters are part of the pretrained BERT model `bert`;\n",
"however, they are only used for computing\n",
"the masked language modeling loss\n",
"and the next sentence prediction loss\n",
"during pretraining.\n",
"These two loss functions are irrelevant to fine-tuning downstream applications,\n",
"so the MLPs employed in `MaskLM` and `NextSentencePred` are not part of `net`\n",
"(which only reuses `bert.encoder` and `bert.hidden`),\n",
"and their parameters are not updated when BERT is fine-tuned.\n",
"\n",
"We use the `d2l.train_ch13` function to train and evaluate the model `net` on the training set\n",
"(`train_iter`) and the testing set (`test_iter`) of SNLI.\n",
"Given the limited computational resources, [**the training**] and testing accuracy\n",
"can be further improved: we leave this discussion to the exercises.\n" ] }, { "cell_type": "code", "execution_count": 9, "id": "70669ab9", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T19:40:35.658377Z", "iopub.status.busy": "2023-08-18T19:40:35.657672Z", "iopub.status.idle": "2023-08-18T19:45:31.971648Z", "shell.execute_reply": "2023-08-18T19:45:31.970642Z" }, "origin_pos": 25, "tab": [ "pytorch" ] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "loss 0.520, train acc 0.791, test acc 0.786\n", "10588.8 examples/sec on [device(type='cuda', index=0), device(type='cuda', index=1)]\n" ] }, { "data": { "text/plain": [ "Figure: training loss, training accuracy, and test accuracy curves over five epochs (SVG data omitted)" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [
"lr, num_epochs = 1e-4, 5\n",
"trainer = torch.optim.Adam(net.parameters(), lr=lr)\n",
"loss = nn.CrossEntropyLoss(reduction='none')\n",
"net(next(iter(train_iter))[0])  # A forward pass to initialize the lazy output layer\n",
"d2l.train_ch13(net, train_iter, test_iter, loss, trainer, num_epochs, devices)" ] },
\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n" ], "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "lr, num_epochs = 1e-4, 5\n", "trainer = torch.optim.Adam(net.parameters(), lr=lr)\n", "loss = nn.CrossEntropyLoss(reduction='none')\n", "net(next(iter(train_iter))[0])\n", "d2l.train_ch13(net, train_iter, test_iter, loss, trainer, num_epochs, devices)" ] }, { "cell_type": "markdown", "id": "26196099", "metadata": { "origin_pos": 26 }, "source": [ "## Summary\n", "\n", "* We can fine-tune the pretrained BERT model for downstream applications, such as natural language inference on the SNLI dataset.\n", "* During fine-tuning, the BERT model becomes part of the model for the downstream application. Parameters that are only related to pretraining loss will not be updated during fine-tuning. \n", "\n", "\n", "\n", "## Exercises\n", "\n", "1. Fine-tune a much larger pretrained BERT model that is about as big as the original BERT base model if your computational resource allows. Set arguments in the `load_pretrained_model` function as: replacing 'bert.small' with 'bert.base', increasing values of `num_hiddens=256`, `ffn_num_hiddens=512`, `num_heads=4`, and `num_blks=2` to 768, 3072, 12, and 12, respectively. By increasing fine-tuning epochs (and possibly tuning other hyperparameters), can you get a testing accuracy higher than 0.86?\n", "1. How to truncate a pair of sequences according to their ratio of length? Compare this pair truncation method and the one used in the `SNLIBERTDataset` class. What are their pros and cons?\n" ] }, { "cell_type": "markdown", "id": "e4f408d0", "metadata": { "origin_pos": 28, "tab": [ "pytorch" ] }, "source": [ "[Discussions](https://discuss.d2l.ai/t/1526)\n" ] } ], "metadata": { "language_info": { "name": "python" }, "required_libs": [] }, "nbformat": 4, "nbformat_minor": 5 }