{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "g_a9QvUFVCUR" }, "source": [ "# Chapter 11 - Fine-tuning Representation Models for Classification\n", "Exploring the classification performance of representation models.\n", "\n", "\n", "\n", "\n", "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/HandsOnLLM/Hands-On-Large-Language-Models/blob/main/chapter11/Chapter%2011%20-%20Fine-Tuning%20BERT.ipynb)\n", "\n", "---\n", "\n", "This notebook is for Chapter 11 of the [Hands-On Large Language Models](https://www.amazon.com/Hands-Large-Language-Models-Understanding/dp/1098150961) book by [Jay Alammar](https://www.linkedin.com/in/jalammar) and [Maarten Grootendorst](https://www.linkedin.com/in/mgrootendorst/).\n", "\n", "---\n", "\n", "\n", "\n" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "### [OPTIONAL] - Installing Packages on Google Colab\n", "\n", "If you are viewing this notebook on Google Colab (or another cloud provider), you need to **uncomment and run** the following code block to install the dependencies for this chapter:\n", "\n", "---\n", "\n", "💡 **NOTE**: We recommend using a GPU to run the examples in this notebook. In Google Colab, go to\n", "**Runtime > Change runtime type > Hardware accelerator > GPU > GPU type > T4**.\n", "\n", "---\n" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# %%capture\n", "# !pip install 'datasets>=2.18.0' 'transformers>=4.38.2' 'sentence-transformers>=2.5.1' 'setfit>=1.0.3' 'accelerate>=0.27.2' 'seqeval>=1.2.2'" ] },
{ "cell_type": "markdown", "metadata": { "id": "UBeVnXxQWy7-" }, "source": [ "## **Data**" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 365, "referenced_widgets": [ "e3c168a65bd34a64b2c4d13129f9d750", "5ad968464c934a15b19e3924663e5ebf", "31756f672e60432da8aa12b033a2ea20", "63adb81b029744fb82e31b3f1190bb9b", "77511fda45a940d1a6664438b07712b2", "677310bbba174f3a828d7f625a483f92", "bbb27e92d9e848e9ada7ba0c9866f87a", "d6299b347ade47bba93df3eff231be6f", "49541355ed224e748754ed4bd89792d8", "7cc8919e7f9e4db48a876619edff0952", "fdb9b32fdccf4156b537df8e000ec7d2", "0cd7314417ab48f99adf352f0cab863f", "bbb68b7249674351abfa8661b4645e2e", "285d6e493dcb462598190862b186b17e", "29419f20fad94db88b6da3b1ecd04725", "3d58f1665e0949d1baa1feaffd52aa29", "82c561fe82084baaa9eedfb78ec476da", "10165d819bea43f1ad531ee0a314d002", "ed2ca51dff8847648010e7b28f03d41e", "9f216034e74a4f0a9794a812addc82c5", "d3748aec876a4ddbb377f16a43b3c23e", "081b5fcad38b4627b1d553244a1bed0c", "3a4ca5b37d6f426a9d294ade901f6984", "673f06ed6138437680a01409cbce3a13", "625db697c4084a4fa7f78905c5c830b2", "a6c837a7a4fb45bcb035537cf41796e3", "0f53e4afed4a42a5a63d8a64f847c27e", "127978e95632464fb3b7396de66ff240", "a7fa1b4de3b6423783861a6d8213f430", "c185b6d9a8514c71b36918027a424204", "9fc58f20135e48b897cb6a2b8a67863d", "24b94b5fc6fd4aa7ae113d9eeb3a4d6f", "7c386e049403493997772cf4aca4a234", "09a021dba22548f3bd101e237c5ce51d", "36fdfe575de440458c4fe7cd572d2a22", "7ef0c97a8b774e57a041a210c3e23ce6", "d9379ba58d8e4de9b4ba3cc12be87b4c", "ee301aff657e405c8cdfd2e0f5d7d19e", 
"87d2c368a3d44c1994c8fe31fea59212", "0db4bd167da44ee7b9fc044d9bb5834d", "ebc2ddfef4cd4129997138c615cd15a1", "1b0fa0592da147fcab20f0c17f487542", "0a5befced7554f0a92523b5dca953ed4", "ec515b2a14ff4414baddca0cc672ef14", "2eaded0fbbd74c67b802fbb92973fa3a", "d1e9c32d66004c70b0e3fa96ccb1e4e3", "d27a323324cd45789808c22917fa5d36", "7d827c0ee22b47c4b46fd4f59ac3ec23", "aba270b810d6461887633f62cf35195c", "f43737300b254f2bb3aef90fabf652e2", "ff95c667621d46b89412c0638ce91673", "5f810926256a4a3aa5271c504cdad278", "117e9997a5db42e59e1f73eddd345b25", "c807a7269c2a4f8daf8022a177175213", "c123b5dae69741d6a10b0b783f5e54ac", "5a5d72b8d2484c54977e5ec692a00004", "fd5ff63e2c204836b1c6924633880077", "1c96b599ae6d466e94062f21e7d3a54b", "690d9fa381b74085af8025af1dcffdcf", "95ab40e005a9419aa7827dbcbd8a2415", "aeffa7363e144c0a9b830c7b532167cb", "59cd08daef304e78be3574b4345084b2", "1700973179db4c91a9c00fe5e84de1a9", "5a22333fd16d46429e16f8b1ed9276a1", "34e7d842cf364b9f8b377b17befa20b2", "84454ba3262b48d489215b874bd78842", "91552574464449e48e54a6939f0a5d18", "0f4c3ac40bf84e6293419447a7fbdd15", "c82ed37a975149d8895adc331aae34eb", "96252dca0e9049fea03228e7ceba066b", "25ed22abc1934a3393593d5b10331988", "b883ce44ce9e47d6b5ac70e36e049559", "17afecdfb20b4e49a79aea97d20a10dd", "c6b1edc0dece4354b7a43314c419cb71", "19b3c0979b3f45499faa57f341d6d53a", "ef4d27b4292240928439104f38c610b9", "ce07f53ddea3469282b20581a4f4e4da" ] }, "executionInfo": { "elapsed": 21399, "status": "ok", "timestamp": 1719386835780, "user": { "displayName": "Maarten Grootendorst", "userId": "11015108362723620659" }, "user_tz": -120 }, "id": "5phRS_z2U_3T", "outputId": "439a9900-6d9a-4a53-ea30-4774d3bf4de7" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/usr/local/lib/python3.10/dist-packages/huggingface_hub/utils/_token.py:89: UserWarning: \n", "The secret `HF_TOKEN` does not exist in your Colab secrets.\n", "To authenticate with the Hugging Face Hub, create a token in your settings tab 
(https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.\n", "You will be able to reuse this secret in all of your notebooks.\n", "Please note that authentication is recommended but still optional to access public models or datasets.\n", " warnings.warn(\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "e3c168a65bd34a64b2c4d13129f9d750", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Downloading readme: 0%| | 0.00/7.46k [00:00\n", " \n", " \n", " [534/534 01:00, Epoch 1/1]\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
Step | Training Loss
500 | 0.424000

" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "TrainOutput(global_step=534, training_loss=0.4183677966228585, metrics={'train_runtime': 61.6658, 'train_samples_per_second': 138.326, 'train_steps_per_second': 8.66, 'total_flos': 227605451772240.0, 'train_loss': 0.4183677966228585, 'epoch': 1.0})" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "trainer.train()" ] }, { "cell_type": "markdown", "metadata": { "id": "xkBUVlUYbUnn" }, "source": [ "Evaluate results." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 173, "referenced_widgets": [ "d656514d509441548af1dd6c7326d18a", "05ce4f6b47354c628c0a41544b9419f5", "d55ba50d84d7412daeb6ccc73be98694", "d5d54eddd21a4999aec6ea6a7782e223", "e8da2d90a93b4ce9a93c2ee93f75d59b", "b94e1b170ccc4ff7abbb9ed87ad2aedd", "3c38333cafb446bab3f3f75824075927", "960641d3e97f4a538a135861ab058bca", "fc4d3609cada4615924d5e92ee1ffdf4", "f914214414b14945a72ac57ac93dcb5e", "8cb0531c591c43f098564ce043eb036f" ] }, "executionInfo": { "elapsed": 3959, "status": "ok", "timestamp": 1719386999026, "user": { "displayName": "Maarten Grootendorst", "userId": "11015108362723620659" }, "user_tz": -120 }, "id": "wCI9uYDObWU8", "outputId": "5c2d54d2-0c81-4a10-d46d-9d7c4c9405d3" }, "outputs": [ { "data": { "text/html": [ "\n", "

\n", " \n", " \n", " [67/67 00:01]\n", "
\n", " " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "d656514d509441548af1dd6c7326d18a", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Downloading builder script: 0%| | 0.00/6.77k [00:00\n", " \n", " \n", " [534/534 00:15, Epoch 1/1]\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
Step | Training Loss
500 | 0.697000

" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "TrainOutput(global_step=534, training_loss=0.6962381677234664, metrics={'train_runtime': 15.234, 'train_samples_per_second': 559.931, 'train_steps_per_second': 35.053, 'total_flos': 227605451772240.0, 'train_loss': 0.6962381677234664, 'epoch': 1.0})" ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from transformers import TrainingArguments, Trainer\n", "\n", "# Trainer which executes the training process\n", "trainer = Trainer(\n", " model=model,\n", " args=training_args,\n", " train_dataset=tokenized_train,\n", " eval_dataset=tokenized_test,\n", " tokenizer=tokenizer,\n", " data_collator=data_collator,\n", " compute_metrics=compute_metrics,\n", ")\n", "trainer.train()" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 141 }, "executionInfo": { "elapsed": 2623, "status": "ok", "timestamp": 1719387017125, "user": { "displayName": "Maarten Grootendorst", "userId": "11015108362723620659" }, "user_tz": -120 }, "id": "eCPpixB1HCsI", "outputId": "3d77ed38-0565-492b-b488-09eb525fc316" }, "outputs": [ { "data": { "text/html": [ "\n", "

\n", " \n", " \n", " [67/67 00:01]\n", "
\n", " " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "{'eval_loss': 0.6823198795318604,\n", " 'eval_f1': 0.637704918032787,\n", " 'eval_runtime': 2.7203,\n", " 'eval_samples_per_second': 391.865,\n", " 'eval_steps_per_second': 24.629,\n", " 'epoch': 1.0}" ] }, "execution_count": 14, "metadata": {}, "output_type": "execute_result" } ], "source": [ "trainer.evaluate()" ] }, { "cell_type": "markdown", "metadata": { "id": "Uw729mLhIQL6" }, "source": [ "### Freeze encoder blocks 0-9" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "executionInfo": { "elapsed": 640, "status": "ok", "timestamp": 1719387017762, "user": { "displayName": "Maarten Grootendorst", "userId": "11015108362723620659" }, "user_tz": -120 }, "id": "RbsLR561Kje-", "outputId": "ae1b3e9b-443d-4a06-94bb-6792d1751e27" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Parameter: 0bert.embeddings.word_embeddings.weight ----- False\n", "Parameter: 1bert.embeddings.position_embeddings.weight ----- False\n", "Parameter: 2bert.embeddings.token_type_embeddings.weight ----- False\n", "Parameter: 3bert.embeddings.LayerNorm.weight ----- False\n", "Parameter: 4bert.embeddings.LayerNorm.bias ----- False\n", "Parameter: 5bert.encoder.layer.0.attention.self.query.weight ----- False\n", "Parameter: 6bert.encoder.layer.0.attention.self.query.bias ----- False\n", "Parameter: 7bert.encoder.layer.0.attention.self.key.weight ----- False\n", "Parameter: 8bert.encoder.layer.0.attention.self.key.bias ----- False\n", "Parameter: 9bert.encoder.layer.0.attention.self.value.weight ----- False\n", "Parameter: 10bert.encoder.layer.0.attention.self.value.bias ----- False\n", "Parameter: 11bert.encoder.layer.0.attention.output.dense.weight ----- False\n", "Parameter: 12bert.encoder.layer.0.attention.output.dense.bias ----- False\n", "Parameter: 
13bert.encoder.layer.0.attention.output.LayerNorm.weight ----- False\n", "Parameter: 14bert.encoder.layer.0.attention.output.LayerNorm.bias ----- False\n", "Parameter: 15bert.encoder.layer.0.intermediate.dense.weight ----- False\n", "Parameter: 16bert.encoder.layer.0.intermediate.dense.bias ----- False\n", "Parameter: 17bert.encoder.layer.0.output.dense.weight ----- False\n", "Parameter: 18bert.encoder.layer.0.output.dense.bias ----- False\n", "Parameter: 19bert.encoder.layer.0.output.LayerNorm.weight ----- False\n", "Parameter: 20bert.encoder.layer.0.output.LayerNorm.bias ----- False\n", "Parameter: 21bert.encoder.layer.1.attention.self.query.weight ----- False\n", "Parameter: 22bert.encoder.layer.1.attention.self.query.bias ----- False\n", "Parameter: 23bert.encoder.layer.1.attention.self.key.weight ----- False\n", "Parameter: 24bert.encoder.layer.1.attention.self.key.bias ----- False\n", "Parameter: 25bert.encoder.layer.1.attention.self.value.weight ----- False\n", "Parameter: 26bert.encoder.layer.1.attention.self.value.bias ----- False\n", "Parameter: 27bert.encoder.layer.1.attention.output.dense.weight ----- False\n", "Parameter: 28bert.encoder.layer.1.attention.output.dense.bias ----- False\n", "Parameter: 29bert.encoder.layer.1.attention.output.LayerNorm.weight ----- False\n", "Parameter: 30bert.encoder.layer.1.attention.output.LayerNorm.bias ----- False\n", "Parameter: 31bert.encoder.layer.1.intermediate.dense.weight ----- False\n", "Parameter: 32bert.encoder.layer.1.intermediate.dense.bias ----- False\n", "Parameter: 33bert.encoder.layer.1.output.dense.weight ----- False\n", "Parameter: 34bert.encoder.layer.1.output.dense.bias ----- False\n", "Parameter: 35bert.encoder.layer.1.output.LayerNorm.weight ----- False\n", "Parameter: 36bert.encoder.layer.1.output.LayerNorm.bias ----- False\n", "Parameter: 37bert.encoder.layer.2.attention.self.query.weight ----- False\n", "Parameter: 38bert.encoder.layer.2.attention.self.query.bias ----- False\n", "Parameter: 
39bert.encoder.layer.2.attention.self.key.weight ----- False\n", "Parameter: 40bert.encoder.layer.2.attention.self.key.bias ----- False\n", "Parameter: 41bert.encoder.layer.2.attention.self.value.weight ----- False\n", "Parameter: 42bert.encoder.layer.2.attention.self.value.bias ----- False\n", "Parameter: 43bert.encoder.layer.2.attention.output.dense.weight ----- False\n", "Parameter: 44bert.encoder.layer.2.attention.output.dense.bias ----- False\n", "Parameter: 45bert.encoder.layer.2.attention.output.LayerNorm.weight ----- False\n", "Parameter: 46bert.encoder.layer.2.attention.output.LayerNorm.bias ----- False\n", "Parameter: 47bert.encoder.layer.2.intermediate.dense.weight ----- False\n", "Parameter: 48bert.encoder.layer.2.intermediate.dense.bias ----- False\n", "Parameter: 49bert.encoder.layer.2.output.dense.weight ----- False\n", "Parameter: 50bert.encoder.layer.2.output.dense.bias ----- False\n", "Parameter: 51bert.encoder.layer.2.output.LayerNorm.weight ----- False\n", "Parameter: 52bert.encoder.layer.2.output.LayerNorm.bias ----- False\n", "Parameter: 53bert.encoder.layer.3.attention.self.query.weight ----- False\n", "Parameter: 54bert.encoder.layer.3.attention.self.query.bias ----- False\n", "Parameter: 55bert.encoder.layer.3.attention.self.key.weight ----- False\n", "Parameter: 56bert.encoder.layer.3.attention.self.key.bias ----- False\n", "Parameter: 57bert.encoder.layer.3.attention.self.value.weight ----- False\n", "Parameter: 58bert.encoder.layer.3.attention.self.value.bias ----- False\n", "Parameter: 59bert.encoder.layer.3.attention.output.dense.weight ----- False\n", "Parameter: 60bert.encoder.layer.3.attention.output.dense.bias ----- False\n", "Parameter: 61bert.encoder.layer.3.attention.output.LayerNorm.weight ----- False\n", "Parameter: 62bert.encoder.layer.3.attention.output.LayerNorm.bias ----- False\n", "Parameter: 63bert.encoder.layer.3.intermediate.dense.weight ----- False\n", "Parameter: 64bert.encoder.layer.3.intermediate.dense.bias ----- 
False\n", "Parameter: 65bert.encoder.layer.3.output.dense.weight ----- False\n", "Parameter: 66bert.encoder.layer.3.output.dense.bias ----- False\n", "Parameter: 67bert.encoder.layer.3.output.LayerNorm.weight ----- False\n", "Parameter: 68bert.encoder.layer.3.output.LayerNorm.bias ----- False\n", "Parameter: 69bert.encoder.layer.4.attention.self.query.weight ----- False\n", "Parameter: 70bert.encoder.layer.4.attention.self.query.bias ----- False\n", "Parameter: 71bert.encoder.layer.4.attention.self.key.weight ----- False\n", "Parameter: 72bert.encoder.layer.4.attention.self.key.bias ----- False\n", "Parameter: 73bert.encoder.layer.4.attention.self.value.weight ----- False\n", "Parameter: 74bert.encoder.layer.4.attention.self.value.bias ----- False\n", "Parameter: 75bert.encoder.layer.4.attention.output.dense.weight ----- False\n", "Parameter: 76bert.encoder.layer.4.attention.output.dense.bias ----- False\n", "Parameter: 77bert.encoder.layer.4.attention.output.LayerNorm.weight ----- False\n", "Parameter: 78bert.encoder.layer.4.attention.output.LayerNorm.bias ----- False\n", "Parameter: 79bert.encoder.layer.4.intermediate.dense.weight ----- False\n", "Parameter: 80bert.encoder.layer.4.intermediate.dense.bias ----- False\n", "Parameter: 81bert.encoder.layer.4.output.dense.weight ----- False\n", "Parameter: 82bert.encoder.layer.4.output.dense.bias ----- False\n", "Parameter: 83bert.encoder.layer.4.output.LayerNorm.weight ----- False\n", "Parameter: 84bert.encoder.layer.4.output.LayerNorm.bias ----- False\n", "Parameter: 85bert.encoder.layer.5.attention.self.query.weight ----- False\n", "Parameter: 86bert.encoder.layer.5.attention.self.query.bias ----- False\n", "Parameter: 87bert.encoder.layer.5.attention.self.key.weight ----- False\n", "Parameter: 88bert.encoder.layer.5.attention.self.key.bias ----- False\n", "Parameter: 89bert.encoder.layer.5.attention.self.value.weight ----- False\n", "Parameter: 90bert.encoder.layer.5.attention.self.value.bias ----- False\n", 
"Parameter: 91bert.encoder.layer.5.attention.output.dense.weight ----- False\n", "Parameter: 92bert.encoder.layer.5.attention.output.dense.bias ----- False\n", "Parameter: 93bert.encoder.layer.5.attention.output.LayerNorm.weight ----- False\n", "Parameter: 94bert.encoder.layer.5.attention.output.LayerNorm.bias ----- False\n", "Parameter: 95bert.encoder.layer.5.intermediate.dense.weight ----- False\n", "Parameter: 96bert.encoder.layer.5.intermediate.dense.bias ----- False\n", "Parameter: 97bert.encoder.layer.5.output.dense.weight ----- False\n", "Parameter: 98bert.encoder.layer.5.output.dense.bias ----- False\n", "Parameter: 99bert.encoder.layer.5.output.LayerNorm.weight ----- False\n", "Parameter: 100bert.encoder.layer.5.output.LayerNorm.bias ----- False\n", "Parameter: 101bert.encoder.layer.6.attention.self.query.weight ----- False\n", "Parameter: 102bert.encoder.layer.6.attention.self.query.bias ----- False\n", "Parameter: 103bert.encoder.layer.6.attention.self.key.weight ----- False\n", "Parameter: 104bert.encoder.layer.6.attention.self.key.bias ----- False\n", "Parameter: 105bert.encoder.layer.6.attention.self.value.weight ----- False\n", "Parameter: 106bert.encoder.layer.6.attention.self.value.bias ----- False\n", "Parameter: 107bert.encoder.layer.6.attention.output.dense.weight ----- False\n", "Parameter: 108bert.encoder.layer.6.attention.output.dense.bias ----- False\n", "Parameter: 109bert.encoder.layer.6.attention.output.LayerNorm.weight ----- False\n", "Parameter: 110bert.encoder.layer.6.attention.output.LayerNorm.bias ----- False\n", "Parameter: 111bert.encoder.layer.6.intermediate.dense.weight ----- False\n", "Parameter: 112bert.encoder.layer.6.intermediate.dense.bias ----- False\n", "Parameter: 113bert.encoder.layer.6.output.dense.weight ----- False\n", "Parameter: 114bert.encoder.layer.6.output.dense.bias ----- False\n", "Parameter: 115bert.encoder.layer.6.output.LayerNorm.weight ----- False\n", "Parameter: 
116bert.encoder.layer.6.output.LayerNorm.bias ----- False\n", "Parameter: 117bert.encoder.layer.7.attention.self.query.weight ----- False\n", "Parameter: 118bert.encoder.layer.7.attention.self.query.bias ----- False\n", "Parameter: 119bert.encoder.layer.7.attention.self.key.weight ----- False\n", "Parameter: 120bert.encoder.layer.7.attention.self.key.bias ----- False\n", "Parameter: 121bert.encoder.layer.7.attention.self.value.weight ----- False\n", "Parameter: 122bert.encoder.layer.7.attention.self.value.bias ----- False\n", "Parameter: 123bert.encoder.layer.7.attention.output.dense.weight ----- False\n", "Parameter: 124bert.encoder.layer.7.attention.output.dense.bias ----- False\n", "Parameter: 125bert.encoder.layer.7.attention.output.LayerNorm.weight ----- False\n", "Parameter: 126bert.encoder.layer.7.attention.output.LayerNorm.bias ----- False\n", "Parameter: 127bert.encoder.layer.7.intermediate.dense.weight ----- False\n", "Parameter: 128bert.encoder.layer.7.intermediate.dense.bias ----- False\n", "Parameter: 129bert.encoder.layer.7.output.dense.weight ----- False\n", "Parameter: 130bert.encoder.layer.7.output.dense.bias ----- False\n", "Parameter: 131bert.encoder.layer.7.output.LayerNorm.weight ----- False\n", "Parameter: 132bert.encoder.layer.7.output.LayerNorm.bias ----- False\n", "Parameter: 133bert.encoder.layer.8.attention.self.query.weight ----- False\n", "Parameter: 134bert.encoder.layer.8.attention.self.query.bias ----- False\n", "Parameter: 135bert.encoder.layer.8.attention.self.key.weight ----- False\n", "Parameter: 136bert.encoder.layer.8.attention.self.key.bias ----- False\n", "Parameter: 137bert.encoder.layer.8.attention.self.value.weight ----- False\n", "Parameter: 138bert.encoder.layer.8.attention.self.value.bias ----- False\n", "Parameter: 139bert.encoder.layer.8.attention.output.dense.weight ----- False\n", "Parameter: 140bert.encoder.layer.8.attention.output.dense.bias ----- False\n", "Parameter: 
141bert.encoder.layer.8.attention.output.LayerNorm.weight ----- False\n", "Parameter: 142bert.encoder.layer.8.attention.output.LayerNorm.bias ----- False\n", "Parameter: 143bert.encoder.layer.8.intermediate.dense.weight ----- False\n", "Parameter: 144bert.encoder.layer.8.intermediate.dense.bias ----- False\n", "Parameter: 145bert.encoder.layer.8.output.dense.weight ----- False\n", "Parameter: 146bert.encoder.layer.8.output.dense.bias ----- False\n", "Parameter: 147bert.encoder.layer.8.output.LayerNorm.weight ----- False\n", "Parameter: 148bert.encoder.layer.8.output.LayerNorm.bias ----- False\n", "Parameter: 149bert.encoder.layer.9.attention.self.query.weight ----- False\n", "Parameter: 150bert.encoder.layer.9.attention.self.query.bias ----- False\n", "Parameter: 151bert.encoder.layer.9.attention.self.key.weight ----- False\n", "Parameter: 152bert.encoder.layer.9.attention.self.key.bias ----- False\n", "Parameter: 153bert.encoder.layer.9.attention.self.value.weight ----- False\n", "Parameter: 154bert.encoder.layer.9.attention.self.value.bias ----- False\n", "Parameter: 155bert.encoder.layer.9.attention.output.dense.weight ----- False\n", "Parameter: 156bert.encoder.layer.9.attention.output.dense.bias ----- False\n", "Parameter: 157bert.encoder.layer.9.attention.output.LayerNorm.weight ----- False\n", "Parameter: 158bert.encoder.layer.9.attention.output.LayerNorm.bias ----- False\n", "Parameter: 159bert.encoder.layer.9.intermediate.dense.weight ----- False\n", "Parameter: 160bert.encoder.layer.9.intermediate.dense.bias ----- False\n", "Parameter: 161bert.encoder.layer.9.output.dense.weight ----- False\n", "Parameter: 162bert.encoder.layer.9.output.dense.bias ----- False\n", "Parameter: 163bert.encoder.layer.9.output.LayerNorm.weight ----- False\n", "Parameter: 164bert.encoder.layer.9.output.LayerNorm.bias ----- False\n", "Parameter: 165bert.encoder.layer.10.attention.self.query.weight ----- False\n", "Parameter: 166bert.encoder.layer.10.attention.self.query.bias 
----- False\n", "Parameter: 167bert.encoder.layer.10.attention.self.key.weight ----- False\n", "Parameter: 168bert.encoder.layer.10.attention.self.key.bias ----- False\n", "Parameter: 169bert.encoder.layer.10.attention.self.value.weight ----- False\n", "Parameter: 170bert.encoder.layer.10.attention.self.value.bias ----- False\n", "Parameter: 171bert.encoder.layer.10.attention.output.dense.weight ----- False\n", "Parameter: 172bert.encoder.layer.10.attention.output.dense.bias ----- False\n", "Parameter: 173bert.encoder.layer.10.attention.output.LayerNorm.weight ----- False\n", "Parameter: 174bert.encoder.layer.10.attention.output.LayerNorm.bias ----- False\n", "Parameter: 175bert.encoder.layer.10.intermediate.dense.weight ----- False\n", "Parameter: 176bert.encoder.layer.10.intermediate.dense.bias ----- False\n", "Parameter: 177bert.encoder.layer.10.output.dense.weight ----- False\n", "Parameter: 178bert.encoder.layer.10.output.dense.bias ----- False\n", "Parameter: 179bert.encoder.layer.10.output.LayerNorm.weight ----- False\n", "Parameter: 180bert.encoder.layer.10.output.LayerNorm.bias ----- False\n", "Parameter: 181bert.encoder.layer.11.attention.self.query.weight ----- False\n", "Parameter: 182bert.encoder.layer.11.attention.self.query.bias ----- False\n", "Parameter: 183bert.encoder.layer.11.attention.self.key.weight ----- False\n", "Parameter: 184bert.encoder.layer.11.attention.self.key.bias ----- False\n", "Parameter: 185bert.encoder.layer.11.attention.self.value.weight ----- False\n", "Parameter: 186bert.encoder.layer.11.attention.self.value.bias ----- False\n", "Parameter: 187bert.encoder.layer.11.attention.output.dense.weight ----- False\n", "Parameter: 188bert.encoder.layer.11.attention.output.dense.bias ----- False\n", "Parameter: 189bert.encoder.layer.11.attention.output.LayerNorm.weight ----- False\n", "Parameter: 190bert.encoder.layer.11.attention.output.LayerNorm.bias ----- False\n", "Parameter: 191bert.encoder.layer.11.intermediate.dense.weight 
----- False\n", "Parameter: 192bert.encoder.layer.11.intermediate.dense.bias ----- False\n", "Parameter: 193bert.encoder.layer.11.output.dense.weight ----- False\n", "Parameter: 194bert.encoder.layer.11.output.dense.bias ----- False\n", "Parameter: 195bert.encoder.layer.11.output.LayerNorm.weight ----- False\n", "Parameter: 196bert.encoder.layer.11.output.LayerNorm.bias ----- False\n", "Parameter: 197bert.pooler.dense.weight ----- False\n", "Parameter: 198bert.pooler.dense.bias ----- False\n", "Parameter: 199classifier.weight ----- True\n", "Parameter: 200classifier.bias ----- True\n" ] } ], "source": [ "# We can check whether the model was correctly updated\n", "for index, (name, param) in enumerate(model.named_parameters()):\n", " print(f\"Parameter: {index}{name} ----- {param.requires_grad}\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 285 }, "executionInfo": { "elapsed": 24489, "status": "ok", "timestamp": 1719387042249, "user": { "displayName": "Maarten Grootendorst", "userId": "11015108362723620659" }, "user_tz": -120 }, "id": "QyleqOHICBjj", "outputId": "f2dc589c-1b8f-40e7-ef93-3ba0f3d3a3a3" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-cased and are newly initialized: ['classifier.bias', 'classifier.weight']\n", "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n" ] }, { "data": { "text/html": [ "\n", "
\n", " \n", " \n", " [534/534 00:21, Epoch 1/1]\n", "
\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
Step | Training Loss
500 | 0.474600

" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "\n", "

\n", " \n", " \n", " [67/67 00:01]\n", "
\n", " " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "{'eval_loss': 0.4092540740966797,\n", " 'eval_f1': 0.8141086749285034,\n", " 'eval_runtime': 2.7437,\n", " 'eval_samples_per_second': 388.523,\n", " 'eval_steps_per_second': 24.419,\n", " 'epoch': 1.0}" ] }, "execution_count": 16, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Load model\n", "model_id = \"bert-base-cased\"\n", "model = AutoModelForSequenceClassification.from_pretrained(model_id, num_labels=2)\n", "tokenizer = AutoTokenizer.from_pretrained(model_id)\n", "\n", "# Encoder block 10 starts at index 165 and\n", "# we freeze everything before that block\n", "for index, (name, param) in enumerate(model.named_parameters()):\n", " if index < 165:\n", " param.requires_grad = False\n", "\n", "# Trainer which executes the training process\n", "trainer = Trainer(\n", " model=model,\n", " args=training_args,\n", " train_dataset=tokenized_train,\n", " eval_dataset=tokenized_test,\n", " tokenizer=tokenizer,\n", " data_collator=data_collator,\n", " compute_metrics=compute_metrics,\n", ")\n", "trainer.train()\n", "trainer.evaluate()" ] }, { "cell_type": "markdown", "metadata": { "id": "HJRMWsLdA913" }, "source": [ "### [BONUS] Freeze blocks" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "ADvKJjNaFAot" }, "outputs": [], "source": [ "# scores = []\n", "# for index in range(12):\n", "# # Re-load model\n", "# model = AutoModelForSequenceClassification.from_pretrained(\"bert-base-cased\", num_labels=2)\n", "# tokenizer = AutoTokenizer.from_pretrained(\"bert-base-cased\")\n", "\n", "# # Freeze encoder blocks 0-index\n", "# for name, param in model.named_parameters():\n", "# if \"layer\" in name:\n", "# layer_nr = int(name.split(\"layer\")[1].split(\".\")[1])\n", "# if layer_nr <= index:\n", "# param.requires_grad = False\n", "# else:\n", "# param.requires_grad = True\n", "\n", "# # Train\n", "# trainer = 
Trainer(\n", "# model=model,\n", "# args=training_args,\n", "# train_dataset=tokenized_train,\n", "# eval_dataset=tokenized_test,\n", "# tokenizer=tokenizer,\n", "# data_collator=data_collator,\n", "# compute_metrics=compute_metrics,\n", "# )\n", "# trainer.train()\n", "\n", "# # Evaluate\n", "# score = trainer.evaluate()[\"eval_f1\"]\n", "# scores.append(score)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "executionInfo": { "elapsed": 212, "status": "ok", "timestamp": 1712321357732, "user": { "displayName": "Maarten Grootendorst", "userId": "11015108362723620659" }, "user_tz": -120 }, "id": "dWYlHFNdLQtk", "outputId": "a28d645b-7633-4bdf-cc51-b7a937d36508" }, "outputs": [ { "data": { "text/plain": [ "[0.8541862652869239,\n", " 0.8525519848771267,\n", " 0.8514664143803217,\n", " 0.8506616257088847,\n", " 0.8398104265402844,\n", " 0.8391345249294448,\n", " 0.8377358490566037,\n", " 0.8433962264150944,\n", " 0.8258801141769743,\n", " 0.816247582205029,\n", " 0.7917485265225934,\n", " 0.7019400352733686]" ] }, "execution_count": 18, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# scores" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 410 }, "executionInfo": { "elapsed": 1383, "status": "ok", "timestamp": 1712388601684, "user": { "displayName": "Maarten Grootendorst", "userId": "11015108362723620659" }, "user_tz": -120 }, "id": "bf3PIvKhOBJ-", "outputId": "aadf0ff7-e4ab-4d62-87b5-cbf1ecdea8da" }, "outputs": [ { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# import matplotlib.pyplot as plt\n", "# import numpy as np\n", "\n", "# # Create Figure\n", "# plt.figure(figsize=(8,4))\n", "\n", "# # Prepare Data\n", "# x = [f\"0-{index}\" for index in range(12)]\n", "# x[0] = \"None\"\n", "# x[-1] = \"All\"\n", "# y = [\n", "# 0.8541862652869239,\n", "# 0.8525519848771267,\n", "# 0.8514664143803217,\n", "# 0.8506616257088847,\n", "# 0.8398104265402844,\n", "# 0.8391345249294448,\n", "# 0.8377358490566037,\n", "# 0.8433962264150944,\n", "# 0.8258801141769743,\n", "# 0.816247582205029,\n", "# 0.7917485265225934,\n", "# 0.7019400352733686\n", "# ][::-1]\n", "\n", "# # Stylize Figure\n", "# plt.grid(color='#ECEFF1')\n", "# plt.axvline(x=4, color=\"#EC407A\", linestyle=\"--\")\n", "# plt.title(\"Effect of Frozen Encoder Blocks on Training Performance\")\n", "# plt.ylabel(\"F1-score\")\n", "# plt.xlabel(\"Trainable encoder blocks\")\n", "\n", "# # Plot Data\n", "# plt.plot(x, y, color=\"black\")\n", "\n", "# # Additional Annotation\n", "# plt.annotate(\n", "# 'Performance stabilizing',\n", "# xy=(4, y[4]),\n", "# xytext=(4.5, y[4]-.05),\n", "# arrowprops=dict(\n", "# arrowstyle=\"-|>\",\n", "# connectionstyle=\"arc3\",\n", "# color=\"#00ACC1\")\n", "# )\n", "# plt.savefig(\"multiple_frozen_blocks.png\", dpi=300, bbox_inches='tight')" ] }, { "cell_type": "markdown", "metadata": { "id": "sf785lzMjwiy" }, "source": [ "## Few-shot Classification" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "8ybeQ3j6kOk4" }, "outputs": [], "source": [ "from setfit import sample_dataset\n", "\n", "# We simulate a few-shot setting by sampling 16 examples per class\n", "sampled_train_data = sample_dataset(tomatoes[\"train\"], num_samples=16)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "executionInfo": { "elapsed": 3418, "status": "ok", "timestamp": 1719390247822, "user": { 
"displayName": "Maarten Grootendorst", "userId": "11015108362723620659" }, "user_tz": -120 }, "id": "1Y55TDrmSqHm", "outputId": "84ec7a9c-92ed-4277-dfff-f0cb4cc918d6" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/usr/local/lib/python3.10/dist-packages/ipykernel/ipkernel.py:283: DeprecationWarning: `should_run_async` will not call `transform_cell` automatically in the future. Please pass the result to `transformed_cell` argument and any exception that happen during thetransform in `preprocessing_exc_tuple` in IPython 7.17 and above.\n", " and should_run_async(code)\n", "/usr/local/lib/python3.10/dist-packages/huggingface_hub/file_download.py:1132: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.\n", " warnings.warn(\n", "model_head.pkl not found on HuggingFace Hub, initialising classification head with random weights. You should TRAIN this model on a downstream task to use it for predictions and inference.\n" ] } ], "source": [ "from setfit import SetFitModel\n", "\n", "# Load a pre-trained SentenceTransformer model\n", "model = SetFitModel.from_pretrained(\"sentence-transformers/all-mpnet-base-v2\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 49, "referenced_widgets": [ "630de7830edb4104938bef1c304129c0", "736c3004e2ec48b2b3a8974e465eb838", "54c6443169d84a3c9983d8c5b125c3e7", "bbee47201aac4b8f96aa772cc751d1d8", "9ad6e36197894003809d3ae6fa1b6a2c", "4708f67e67194377a94ec08aa07fe688", "c03a1bf3d2214c3680ab0b983dd582ef", "e777041b66fa484090714094d9516f35", "c5d50176a0a049769731d1c6fdb37f0f", "aa6d83c964d741178534cd0fae6ccec9", "318f91cb3cdd4f6e88681121de392207" ] }, "executionInfo": { "elapsed": 2, "status": "ok", "timestamp": 1719390248689, "user": { "displayName": "Maarten Grootendorst", "userId": "11015108362723620659" 
}, "user_tz": -120 }, "id": "zZ10SpAXdNvC", "outputId": "f9930b86-e077-4233-f7d7-16c5ec1e640d" }, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "630de7830edb4104938bef1c304129c0", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Map: 0%| | 0/32 [00:00\n", " \n", " \n", " [240/240 00:37, Epoch 3/0]\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
Step    Training Loss

" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# Training loop\n", "trainer.train()" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "executionInfo": { "elapsed": 2455, "status": "ok", "timestamp": 1719390291303, "user": { "displayName": "Maarten Grootendorst", "userId": "11015108362723620659" }, "user_tz": -120 }, "id": "PyRxiY32R3Jd", "outputId": "51d00a55-8086-4456-b3c2-d8720342bcc8" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "***** Running evaluation *****\n" ] }, { "data": { "text/plain": [ "{'f1': 0.8363988383349468}" ] }, "execution_count": 75, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Evaluate the model on our test data\n", "trainer.evaluate()" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 129 }, "executionInfo": { "elapsed": 10, "status": "ok", "timestamp": 1719390291304, "user": { "displayName": "Maarten Grootendorst", "userId": "11015108362723620659" }, "user_tz": -120 }, "id": "-aKIHJpCQdAm", "outputId": "e2f786a4-78a4-4b92-d464-e3336d92edb2" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/usr/local/lib/python3.10/dist-packages/ipykernel/ipkernel.py:283: DeprecationWarning: `should_run_async` will not call `transform_cell` automatically in the future. Please pass the result to `transformed_cell` argument and any exception that happen during thetransform in `preprocessing_exc_tuple` in IPython 7.17 and above.\n", " and should_run_async(code)\n" ] }, { "data": { "text/html": [ "

LogisticRegression()
" ], "text/plain": [ "LogisticRegression()" ] }, "execution_count": 76, "metadata": {}, "output_type": "execute_result" } ], "source": [ "model.model_head" ] }, { "cell_type": "markdown", "metadata": { "id": "L7NbUeSn-QSe" }, "source": [ "## MLM" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "executionInfo": { "elapsed": 1601, "status": "ok", "timestamp": 1719387317479, "user": { "displayName": "Maarten Grootendorst", "userId": "11015108362723620659" }, "user_tz": -120 }, "id": "Z35PD47AXXnv", "outputId": "45667c3b-9421-4406-e6e1-a406d803b8fd" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Some weights of the model checkpoint at bert-base-cased were not used when initializing BertForMaskedLM: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight', 'cls.seq_relationship.bias', 'cls.seq_relationship.weight']\n", "- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. 
initializing a BertForSequenceClassification model from a BertForPreTraining model).\n", "- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n" ] } ], "source": [ "from transformers import AutoTokenizer, AutoModelForMaskedLM\n", "\n", "# Load model for Masked Language Modeling (MLM)\n", "model = AutoModelForMaskedLM.from_pretrained(\"bert-base-cased\")\n", "tokenizer = AutoTokenizer.from_pretrained(\"bert-base-cased\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 81, "referenced_widgets": [ "5333a40e5cbc4bb5b89628ecaa6608b2", "c7a67733b49e48df88a34329bd090349", "40f95e431f264d3eb034f09f3f2e9e09", "827d0cdd85c14d13b78da5c0d7f054c7", "e8403bc0d63b4ec0aa679505a962dc34", "bb38873892e847c6b646c9e3cc4c753d", "496cfba4c0d04e399c6fde17772fa62c", "e21c06a4d4ac4f5987b8e4f34c4b9ac8", "c39ac7f5c60e4bf98e64174f21959d9a", "1caf892f43c94d50bb31f12b830462c6", "526eb0a0a0d246578aa2fb726c7385e4", "83160c120ff64c42af6d38bb3fb932b5", "aa3c27c5a1254b19a7cfd04f28aee5d0", "c350f4d70d5646fc973aa9e68711f830", "1ca86de4b7174993942d43b0222f698f", "cac994c5cc274c14ba2163c73fa4ef8c", "4a1c759888cf472aab34c96be48e95b1", "76aeb09a52f04f97aa006927b60bad84", "60f082d7f4234bfa985c9241a7e1e9c0", "8d84f304aca74bd99d7d624b29630495", "71ea5effe952447fa40b578e64e1b868", "933e2bdff863421abba4ba875a817dc0" ] }, "executionInfo": { "elapsed": 3, "status": "ok", "timestamp": 1719387318933, "user": { "displayName": "Maarten Grootendorst", "userId": "11015108362723620659" }, "user_tz": -120 }, "id": "zgLardIvEFTG", "outputId": "f40b1e95-9a82-437e-ffd8-5e0fdb4be138" }, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "5333a40e5cbc4bb5b89628ecaa6608b2", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Map: 0%| 
| 0/8530 [00:00\n", " \n", " \n", " [5340/5340 12:10, Epoch 10/10]\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
Step    Training Loss
500     2.601700
1000    2.377500
1500    2.313100
2000    2.187500
2500    2.150400
3000    2.096100
3500    2.059500
4000    1.990300
4500    1.986100
5000    1.958500

" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# Save pre-trained tokenizer\n", "tokenizer.save_pretrained(\"mlm\")\n", "\n", "# Train model\n", "trainer.train()\n", "\n", "# Save updated model\n", "model.save_pretrained(\"mlm\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "executionInfo": { "elapsed": 1511, "status": "ok", "timestamp": 1719388054975, "user": { "displayName": "Maarten Grootendorst", "userId": "11015108362723620659" }, "user_tz": -120 }, "id": "HfxN1p8TOg2v", "outputId": "e9bcca59-cebb-4297-e675-b3d2a555d36e" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/usr/local/lib/python3.10/dist-packages/ipykernel/ipkernel.py:283: DeprecationWarning: `should_run_async` will not call `transform_cell` automatically in the future. Please pass the result to `transformed_cell` argument and any exception that happen during thetransform in `preprocessing_exc_tuple` in IPython 7.17 and above.\n", " and should_run_async(code)\n", "Some weights of the model checkpoint at bert-base-cased were not used when initializing BertForMaskedLM: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight', 'cls.seq_relationship.bias', 'cls.seq_relationship.weight']\n", "- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. 
initializing a BertForSequenceClassification model from a BertForPreTraining model).\n", "- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ ">>> What a horrible idea!\n", ">>> What a horrible dream!\n", ">>> What a horrible thing!\n", ">>> What a horrible day!\n", ">>> What a horrible thought!\n" ] } ], "source": [ "from transformers import pipeline\n", "\n", "# Load and create predictions\n", "mask_filler = pipeline(\"fill-mask\", model=\"bert-base-cased\")\n", "preds = mask_filler(\"What a horrible [MASK]!\")\n", "\n", "# Print results\n", "for pred in preds:\n", " print(f\">>> {pred['sequence']}\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "executionInfo": { "elapsed": 3, "status": "ok", "timestamp": 1719388054975, "user": { "displayName": "Maarten Grootendorst", "userId": "11015108362723620659" }, "user_tz": -120 }, "id": "ogk1hJ4zOlAU", "outputId": "316b16d0-065b-41e8-c35c-3ce3be481469" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ ">>> What a horrible movie!\n", ">>> What a horrible film!\n", ">>> What a horrible mess!\n", ">>> What a horrible comedy!\n", ">>> What a horrible story!\n" ] } ], "source": [ "# Load and create predictions\n", "mask_filler = pipeline(\"fill-mask\", model=\"mlm\")\n", "preds = mask_filler(\"What a horrible [MASK]!\")\n", "\n", "# Print results\n", "for pred in preds:\n", " print(f\">>> {pred['sequence']}\")" ] }, { "cell_type": "markdown", "metadata": { "id": "1sDqoG2NyeJO" }, "source": [ "## Named Entity Recognition\n", "\n", "Here are a number of interesting datasets you can also explore for NER:\n", "* tner/mit_movie_trivia\n", "* tner/mit_restaurant\n", "* wnut_17\n", "* conll2003" ] }, 
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "KGDvXU-ZzJ3J" }, "outputs": [], "source": [ "from transformers import AutoModelForTokenClassification, AutoTokenizer\n", "from transformers import DataCollatorWithPadding\n", "from transformers import TrainingArguments, Trainer\n", "import numpy as np" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 298, "referenced_widgets": [ "e8567cb28aff4974acf098b8fc4184ea", "a0960ca186e9491fbff044109a197da2", "1a56e5d0d26e49ad9bd6a919515de403", "2fe1e77e27014f62964c94b62b66f049", "2df3bbada30f48068cf8bd414f40d895", "4457b446302442e4bcab06fc10eb15ae", "23db3d47ec284ebfa6da24d9a58d0881", "a2795e523b5a436d9f94b059e7252597", "027f9ea43b6647ab91915020a2e25771", "15c3f91714a74762989e35950d6c12a5", "19b10fadedbd48288026ff8b21148c8f", "e4c076fd3de24dba84f4d566d4fca48a", "79c5f352cf7f4bf7ace576348e3c55d6", "690b323dd52e48b4b130e63d377b4831", "dbb5c772fc674bedaee90064d8e4163b", "4fe4a9b1182e4383b9ffa8b0b0c0f321", "6d12abf106954c2f9008c8505ed6fce3", "ff5ebeb9078d4f8b975a8aa59977fae3", "befc22d4789e49eb81793e4ce1210f41", "18c3194ae3b8497fb325a424dfb95901", "848c37b4a0224b4e994896ab75d695e4", "bc720c483d74483790aa184e927bd742", "7c9beb946034455e8f4e68cf9d3ed451", "18f06f815052464aa05c788246507c1b", "b58c221705ed490f801fe1cea0a50b60", "c51e7b34c1c049c99753132ddf06a8bf", "ca05b822e80640728b7b98d75fc0a569", "e70e0557dffe4ac4883151f464d88fb3", "caab4f8ab71343dc9d949a4b387ff97d", "114e90e1dc0b49149e94d9352a389352", "95784022cda24c588e9986f5267e3922", "1309c3d2eb8d4bd7b8aa8d7176b225d3", "67d416d3edb747598eefad5de3969330", "ec5ff2eceeaf425bb25a1580b1c86f35", "322b02e7719b4795a0afcbbe99074678", "376dbc09d30742cf9ea209d51b10bb46", "453b539e6093481dabbd4bacc7b746e6", "bb87ce6727e349df960b276b31477d42", "ddf0589272894ee8bc63371aebae6c47", "ce8a5082960345b28883dc1128550675", "69a2c842c8e04a00bccfd9534ead0a3d", "e4b0a87f816c4762b01e6952a4fb91f3", 
"4956e10177f74c30a4dcec1aee5d84d1", "930cc9ef722e4ad09fc5aec2bdb58f51", "d133524c37da40a688ae41d0be4fb608", "5068f7a72b444b368e3064b523d37316", "20fc8f2be16b48a3a330cb07558c4a7d", "b6871ea8d9ee464ca3d68d62c0264f65", "2ba02b04fbee4db093abeffdc880ce1b", "f4bce89187a74fda9f6b8412a886be56", "2118f618461a4f45b864b8ca51561ecd", "ca9ae83245b3466d961f8741b5b5ca22", "c9aa5057f4c64210a50738ab0ae8d92f", "bbc7a519dee04caca31b77c03b599cc1", "5f33680650354398bd53b210141050ad", "ed3c3a380c3a46e9b6eabe8e0df7f7a1", "a1cfaa8574e34145a95b170954a6af0f", "c94da6e530da46b284bee469474890cf", "9e0395ec942f4b42b7983147701806f8", "4730ed0d0d9a43d0b8b6a46f3351a91e", "6eaad3c842a14b5b86fa908a18d86f20", "9437374ca4e44034a7852cc05a345c18", "31d743d1493746cdae66dff187dee185", "4a2604352c6946e5b6ce02da0e82f00c", "243f45f71e3a4385ac447112e2778915", "20c7a8103eb44b61b2f221cc5ab5527b" ] }, "executionInfo": { "elapsed": 19287, "status": "ok", "timestamp": 1719388740030, "user": { "displayName": "Maarten Grootendorst", "userId": "11015108362723620659" }, "user_tz": -120 }, "id": "YOQlMioIGY33", "outputId": "3f5e52d8-e736-411c-e468-d63b48828109" }, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "e8567cb28aff4974acf098b8fc4184ea", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Downloading builder script: 0%| | 0.00/9.57k [00:00\n", " \n", " \n", " [878/878 02:49, Epoch 1/1]\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
Step    Training Loss
500     0.047500

" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "TrainOutput(global_step=878, training_loss=0.04094860494001037, metrics={'train_runtime': 169.4752, 'train_samples_per_second': 82.85, 'train_steps_per_second': 5.181, 'total_flos': 351240792638148.0, 'train_loss': 0.04094860494001037, 'epoch': 1.0})" ] }, "execution_count": 54, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Training arguments for parameter tuning\n", "training_args = TrainingArguments(\n", " \"model\",\n", " learning_rate=2e-5,\n", " per_device_train_batch_size=16,\n", " per_device_eval_batch_size=16,\n", " num_train_epochs=1,\n", " weight_decay=0.01,\n", " save_strategy=\"epoch\",\n", " report_to=\"none\"\n", ")\n", "\n", "# Initialize Trainer\n", "trainer = Trainer(\n", " model=model,\n", " args=training_args,\n", " train_dataset=tokenized[\"train\"],\n", " eval_dataset=tokenized[\"test\"],\n", " tokenizer=tokenizer,\n", " data_collator=data_collator,\n", " compute_metrics=compute_metrics,\n", ")\n", "trainer.train()" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 141 }, "executionInfo": { "elapsed": 14521, "status": "ok", "timestamp": 1712753153056, "user": { "displayName": "Maarten Grootendorst", "userId": "11015108362723620659" }, "user_tz": -120 }, "id": "ds5osPr9T0pq", "outputId": "28e6836f-5eae-43c2-a37f-b92acb9a42b6" }, "outputs": [ { "data": { "text/html": [ "\n", "

\n", " \n", " \n", " [216/216 00:09]\n", "
\n", " " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "{'eval_loss': 0.16888542473316193,\n", " 'eval_f1': 0.9180087380808113,\n", " 'eval_runtime': 14.5731,\n", " 'eval_samples_per_second': 236.943,\n", " 'eval_steps_per_second': 14.822,\n", " 'epoch': 1.0}" ] }, "execution_count": 55, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Evaluate the model on our test data\n", "trainer.evaluate()" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "executionInfo": { "elapsed": 2674, "status": "ok", "timestamp": 1712753204334, "user": { "displayName": "Maarten Grootendorst", "userId": "11015108362723620659" }, "user_tz": -120 }, "id": "Q0PAXyzT-N45", "outputId": "d29825c0-2be2-48dd-c685-fad114afb369" }, "outputs": [ { "data": { "text/plain": [ "[{'entity': 'B-PER',\n", " 'score': 0.99534035,\n", " 'index': 4,\n", " 'word': 'Ma',\n", " 'start': 11,\n", " 'end': 13},\n", " {'entity': 'I-PER',\n", " 'score': 0.9928328,\n", " 'index': 5,\n", " 'word': '##arte',\n", " 'start': 13,\n", " 'end': 17},\n", " {'entity': 'I-PER',\n", " 'score': 0.9954301,\n", " 'index': 6,\n", " 'word': '##n',\n", " 'start': 17,\n", " 'end': 18}]" ] }, "execution_count": 58, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from transformers import pipeline\n", "\n", "# Save our fine-tuned model\n", "trainer.save_model(\"ner_model\")\n", "\n", "# Run inference on the fine-tuned model\n", "token_classifier = pipeline(\n", " \"token-classification\",\n", " model=\"ner_model\",\n", ")\n", "token_classifier(\"My name is Maarten.\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "MAfBhqVwC61e" }, "outputs": [], "source": [] } ], "metadata": { "accelerator": "GPU", "colab": { "authorship_tag": "ABX9TyOzKODTW4KPEpkvmaAMgqlD", "gpuType": "T4", "provenance": [], "toc_visible": true }, "kernelspec": { "display_name": 
"Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.14" } }, "nbformat": 4, "nbformat_minor": 4 }