{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "-ETtu9CvVMDR" }, "source": [ "

Chapter 10 - Creating Text Embedding Models

\n", "Exploring methods for both training and fine-tuning embedding models.\n", "\n", "\n", "\n", "\n", "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1mvUcwC3XSYn-V4lypqf3ebxxNMeM9Gp6?usp=sharing)\n", "\n", "---\n", "\n", "This notebook is for Chapter 10 of the [Hands-On Large Language Models](https://www.amazon.com/Hands-Large-Language-Models-Understanding/dp/1098150961) book by [Jay Alammar](https://www.linkedin.com/in/jalammar) and [Maarten Grootendorst](https://www.linkedin.com/in/mgrootendorst/).\n", "\n", "---\n", "\n", "\n", "\n" ] }, { "cell_type": "markdown", "metadata": { "id": "2UrKluX5YNmu" }, "source": [ "# Creating an Embedding Model" ] }, { "cell_type": "markdown", "metadata": { "id": "ywsyZzm5VSER" }, "source": [ "## **Data**" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "executionInfo": { "elapsed": 6529, "status": "ok", "timestamp": 1717342944433, "user": { "displayName": "Maarten Grootendorst", "userId": "11015108362723620659" }, "user_tz": -120 }, "id": "Ahk0SJDKVy6F", "outputId": "497309ee-333a-4a6c-f008-dd6262a7a52f" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/usr/local/lib/python3.10/dist-packages/huggingface_hub/utils/_token.py:89: UserWarning: \n", "The secret `HF_TOKEN` does not exist in your Colab secrets.\n", "To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.\n", "You will be able to reuse this secret in all of your notebooks.\n", "Please note that authentication is recommended but still optional to access public models or datasets.\n", " warnings.warn(\n" ] } ], "source": [ "from datasets import load_dataset\n", "\n", "# Load MNLI dataset from GLUE\n", "# 0 = entailment, 1 = neutral, 2 = contradiction\n", "train_dataset = load_dataset(\"glue\", \"mnli\", split=\"train\").select(range(50_000))\n", "train_dataset = train_dataset.remove_columns(\"idx\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "executionInfo": { "elapsed": 3, "status": "ok", "timestamp": 1717342944434, "user": { "displayName": "Maarten Grootendorst", "userId": "11015108362723620659" }, "user_tz": -120 }, "id": "t-BHO4-qwMDO", "outputId": "f6671b92-7319-48bb-96a3-c848b45dee33" }, "outputs": [ { "data": { "text/plain": [ "{'premise': 'One of our number will carry out your instructions minutely.',\n", " 'hypothesis': 'A member of my team will execute your orders with immense precision.',\n", " 'label': 0}" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "train_dataset[2]" ] }, { "cell_type": "markdown", "metadata": { "id": "5wO23cXLXeFU" }, "source": [ "## **Model**" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "executionInfo": { "elapsed": 19919, "status": "ok", "timestamp": 1717342964351, "user": { "displayName": "Maarten Grootendorst", "userId": "11015108362723620659" }, "user_tz": -120 }, "id": "C4qLaPR6nrqC", "outputId": "76fa9f0a-9c99-4e67-be82-70f9c41ba1b8" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "WARNING:sentence_transformers.SentenceTransformer:No sentence-transformers model found with name bert-base-uncased. 
Creating a new one with mean pooling.\n", "/usr/local/lib/python3.10/dist-packages/huggingface_hub/file_download.py:1132: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.\n", " warnings.warn(\n" ] } ], "source": [ "from sentence_transformers import SentenceTransformer\n", "\n", "# Use a base model\n", "embedding_model = SentenceTransformer('bert-base-uncased')" ] }, { "cell_type": "markdown", "metadata": { "id": "pAiL21AuYKVI" }, "source": [ "## **Loss Function**" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "OgmtKckBXiK9" }, "outputs": [], "source": [ "from sentence_transformers import losses\n", "\n", "# Define the loss function. In soft-max loss, we will also need to explicitly set the number of labels.\n", "train_loss = losses.SoftmaxLoss(\n", " model=embedding_model,\n", " sentence_embedding_dimension=embedding_model.get_sentence_embedding_dimension(),\n", " num_labels=3\n", ")" ] }, { "cell_type": "markdown", "metadata": { "id": "tH0efspwlOX2" }, "source": [ "## Evaluation" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "f8ZsoY0AretV" }, "outputs": [], "source": [ "from sentence_transformers.evaluation import EmbeddingSimilarityEvaluator\n", "\n", "# Create an embedding similarity evaluator for stsb\n", "val_sts = load_dataset('glue', 'stsb', split='validation')\n", "evaluator = EmbeddingSimilarityEvaluator(\n", " sentences1=val_sts[\"sentence1\"],\n", " sentences2=val_sts[\"sentence2\"],\n", " scores=[score/5 for score in val_sts[\"label\"]],\n", " main_similarity=\"cosine\",\n", ")" ] }, { "cell_type": "markdown", "metadata": { "id": "umikSmoYIP07" }, "source": [ "## **Training**" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "8uAAhNs0ocoV" }, "outputs": [], "source": [ "from sentence_transformers.training_args import SentenceTransformerTrainingArguments\n", "\n", "# Define the training arguments\n", "args = SentenceTransformerTrainingArguments(\n", " output_dir=\"base_embedding_model\",\n", " num_train_epochs=1,\n", " per_device_train_batch_size=32,\n", " per_device_eval_batch_size=32,\n", " warmup_steps=100,\n", " fp16=True,\n", " eval_steps=100,\n", " logging_steps=100,\n", ")" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 618, "referenced_widgets": [ "459f20375cf24b51bd89cf541e2b63fa", "c27d0f3225924807bc9f0dc18b926424", "739d1e36723d4ee09430c2bb98fd7766", "3f2b25e57f3a4bf69b18ba02d8128092", "56245fdec01446f38b956da92a296306", "5d471644fab242568d37184e9bbcb086", "ed4c445ced6a4ad3bd5b8e5a56ea23fb", "d26a9fc85ca849b99fbead500a295f4a", "e11b9173c7e34062be45e43737aa9eb5", "049ddce06bfc4f1e8c81d412814a4e74", "b4eb0759b0874712a9736da2c8203689", "3bad6bb0058f45e1aa59fc4fa44e5a6f", "f0756c62e13741f9b574e6a30c5e9aae", "61131cb0ea4e4d83a9a2debc6557435d", "bada4e4dedfa46d09cf544dc2647801a", "ca673a5e949f4eb79ce87c8b8aac954c", "1429dde1f00e4c74a6680739ed6355ee", "1428f068da84478faa6ad723744f78ec", "f347198f211849bd8cacef1b9094109d", "840dd9e2a63147ea8ce9c1a5423a2451", "7b8f7b0d76e9419a8677510d65f7dc1b", "b7a8360bf9204c5a9b7c0b87555cbdd3", "5446d5dc6ce047eca1b11b91bae71c81", "7d2fa3fb9bb143f3b5f31f74276651dd", "a11c52e4d7dd4c39995561c495ae5fa4", "98e9f40ea3274ef5bd4593e2cee97242", "cb0646d4e66246669a31359ca3e6476a", "3334b517bae2434588cf1fac55a57182", "c2affd1e287640a3967ef68344e86247", 
"406f19fb77b1492691851a0222bf5a28", "80ff5c0e8c3148fdaa1a0a9be16f812a", "e426aaf2f15d46b88d9ffb7855744045", "ee32bcff9df042a4a9cff5437ef61a46" ] }, "executionInfo": { "elapsed": 374122, "status": "ok", "timestamp": 1717343342445, "user": { "displayName": "Maarten Grootendorst", "userId": "11015108362723620659" }, "user_tz": -120 }, "id": "JKA_L39FpAoM", "outputId": "8091a7c2-585f-4668-9f83-3ca1a06acd79" }, "outputs": [ { "data": { "text/html": [ "\n", "
\n", " \n", " \n", " [1563/1563 06:10, Epoch 1/1]\n", "
\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
StepTraining Loss
1001.080700
2000.959400
3000.916200
4000.870200
5000.849100
6000.854200
7000.835200
8000.825200
9000.818100
10000.800300
11000.781600
12000.777100
13000.786600
14000.767900
15000.797100
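The cell that launches this run is not legible in this export. As a rough sketch only (not the author's verbatim code), the pieces defined above — `embedding_model`, `train_dataset`, `train_loss`, `args`, and `evaluator` — are typically combined with the sentence-transformers v3 `SentenceTransformerTrainer`; the `trainer` name matches the object referenced later in the VRAM clean-up cell.

```python
from sentence_transformers.trainer import SentenceTransformerTrainer

# Sketch only: combine the model, data, loss, arguments, and evaluator defined above
trainer = SentenceTransformerTrainer(
    model=embedding_model,
    args=args,
    train_dataset=train_dataset,
    loss=train_loss,
    evaluator=evaluator,
)
trainer.train()

# Score the trained model on the STS-B validation pairs defined earlier
evaluator(embedding_model)
```

With softmax loss, the trainer treats the `label` column as the class target and feeds the remaining columns (`premise`, `hypothesis`) through the model as the sentence pair.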

" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "459f20375cf24b51bd89cf541e2b63fa", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Computing widget examples: 0%| | 0/5 [00:00───────────────────────────────────────────────── Selected tasks ─────────────────────────────────────────────────\n", "\n" ], "text/plain": [ "\u001b[38;5;235m───────────────────────────────────────────────── \u001b[0m\u001b[1mSelected tasks \u001b[0m\u001b[38;5;235m ─────────────────────────────────────────────────\u001b[0m\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "

Classification\n",
       "
\n" ], "text/plain": [ "\u001b[1mClassification\u001b[0m\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
    - Banking77Classification, s2s\n",
       "
\n" ], "text/plain": [ " - Banking77Classification, \u001b[3;38;5;241ms2s\u001b[0m\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
\n",
       "\n",
       "
\n" ], "text/plain": [ "\n", "\n" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stderr", "output_type": "stream", "text": [ "/usr/local/lib/python3.10/dist-packages/joblib/externals/loky/backend/fork_exec.py:38: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n", " pid = os.fork()\n" ] }, { "data": { "text/plain": [ "{'Banking77Classification': {'mteb_version': '1.1.2',\n", " 'dataset_revision': '0fd18e25b25c072e09e0d92ab615fda904d66300',\n", " 'mteb_dataset_name': 'Banking77Classification',\n", " 'test': {'accuracy': 0.46022727272727276,\n", " 'f1': 0.45802738001849663,\n", " 'accuracy_stderr': 0.009556987908238961,\n", " 'f1_stderr': 0.01072225943077292,\n", " 'main_score': 0.46022727272727276,\n", " 'evaluation_time': 29.63}}}" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from mteb import MTEB\n", "\n", "# Choose evaluation task\n", "evaluation = MTEB(tasks=[\"Banking77Classification\"])\n", "\n", "# Calculate results\n", "results = evaluation.run(embedding_model)\n", "results" ] }, { "cell_type": "markdown", "metadata": { "id": "56V2ma89uJwN" }, "source": [ "⚠️ **VRAM Clean-up** - You will need to run the code below to partially empty the VRAM (GPU RAM). If that does not work, it is advised to restart the notebook instead. You can check the resources on the right-hand side (if you are using Google Colab) to check whether the used VRAM is indeed low. You can also run `!nivia-smi` to check current usage." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "c3LX1G0_4QCv" }, "outputs": [], "source": [ "# # Empty and delete trainer/model\n", "# trainer.accelerator.clear()\n", "# del trainer, embedding_model\n", "\n", "# # Garbage collection and empty cache\n", "# import gc\n", "# import torch\n", "\n", "# gc.collect()\n", "# torch.cuda.empty_cache()" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "6d0GcY8cnNs4" }, "outputs": [], "source": [ "import gc\n", "import torch\n", "\n", "gc.collect()\n", "torch.cuda.empty_cache()" ] }, { "cell_type": "markdown", "metadata": { "id": "jYnRRSDN06eB" }, "source": [ "# Loss Fuctions" ] }, { "cell_type": "markdown", "metadata": { "id": "vuSCWbFO7RRM" }, "source": [ "⚠️ **VRAM Clean-up**\n", "* `Restart` the notebook in order to clean-up memory if you move on to the next training example." 
] }, { "cell_type": "markdown", "metadata": { "id": "Tq8Yb6IB2LFI" }, "source": [ "## Cosine Similarity Loss" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "qEmnjQQPuszQ" }, "outputs": [], "source": [ "from datasets import Dataset, load_dataset\n", "\n", "# Load MNLI dataset from GLUE\n", "# 0 = entailment, 1 = neutral, 2 = contradiction\n", "train_dataset = load_dataset(\"glue\", \"mnli\", split=\"train\").select(range(50_000))\n", "train_dataset = train_dataset.remove_columns(\"idx\")\n", "\n", "# (neutral/contradiction)=0 and (entailment)=1\n", "mapping = {2: 0, 1: 0, 0:1}\n", "train_dataset = Dataset.from_dict({\n", " \"sentence1\": train_dataset[\"premise\"],\n", " \"sentence2\": train_dataset[\"hypothesis\"],\n", " \"label\": [float(mapping[label]) for label in train_dataset[\"label\"]]\n", "})" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "np5bMwgO5y8g" }, "outputs": [], "source": [ "from sentence_transformers.evaluation import EmbeddingSimilarityEvaluator\n", "\n", "# Create an embedding similarity evaluator for stsb\n", "val_sts = load_dataset('glue', 'stsb', split='validation')\n", "evaluator = EmbeddingSimilarityEvaluator(\n", " sentences1=val_sts[\"sentence1\"],\n", " sentences2=val_sts[\"sentence2\"],\n", " scores=[score/5 for score in val_sts[\"label\"]],\n", " main_similarity=\"cosine\"\n", ")" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 692, "referenced_widgets": [ "c16cc9c062f547ee938537a719d4f668", "94e4f08e49bc4043917ce262cfd17fb1", "afc5056a4b1a49579fc52faed7ce12b3", "cbb2ad7701bb4324907b30ac30c058bb", "f21767a805b24d3fa93c647517d1c1ee", "8478533cc376477e97a80376e0fb3678", "e70ed406b6c24fd2b22d518bcff8ff0c", "74b7ff4db1c64725b2bdb67e212bfac3", "3071516aab1a407eb4224e98caf675c6", "319644a2e9204cd08017353e867779f4", "97a00463893941be8e9685a990d2ad6a", "055c320e1c4540fb9f2b293dd28ed569", "6d194ebdd6334b998be781281a49fea3", "7ac2fbaa7741499cb0423c3b509f1656", "27f1913c7cea41858091d98eb8f1ea9b", "0fdc301b86594a7a81731fbecfa5f2ef", "0264b9ef5adf43cb8c3d4516b56e9a03", "2d9423c5a1594b4c853ebd363f14d917", "d31dd92c800246798b73035c1f81deeb", "c96b9494c274423bb7c869b16b4b3be1", "58bff1c2f9bf49f99d5d1715d1aaee01", "60d91c4725314ae29ad0019985e7c73a", "677333dacf8a41d08e0b7fc605da27f9", "cad3e9f01c98445e8fee42c396653813", "4eccf40bc6684d22ab26edf01290d2ae", "7868fb9e79ca4925afca2223088c1aa3", "f3ba5c1671304318977ce8f6c1b47391", "78e6d195ff224c66ae58ac6dad70430c", "ccb9de8bcfce439aaadf3732e1b2a5f8", "979b54240a164b70942340c79ab53777", "f7eeb347d6224eb48bbbc5879dbc80ae", "8e25d420a21b4649ac1ed6b4656098b3", "ccbe3bb6aabd4fbdaedad6755886560c" ] }, "executionInfo": { "elapsed": 366439, "status": "ok", "timestamp": 1717343750576, "user": { "displayName": "Maarten Grootendorst", "userId": "11015108362723620659" }, "user_tz": -120 }, "id": "Ikky866vdseY", "outputId": "803f2abe-002a-481c-9403-8ce39dfa5b47" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "WARNING:sentence_transformers.SentenceTransformer:No sentence-transformers model found with name bert-base-uncased. Creating a new one with mean pooling.\n", "/usr/local/lib/python3.10/dist-packages/huggingface_hub/file_download.py:1132: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.\n", " warnings.warn(\n" ] }, { "data": { "text/html": [ "\n", "
\n", " \n", " \n", " [1563/1563 06:04, Epoch 1/1]\n", "
\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
StepTraining Loss
1000.231900
2000.168900
3000.170900
4000.157800
5000.152900
6000.156100
7000.149300
8000.154500
9000.150900
10000.145600
11000.147800
12000.145600
13000.145100
14000.142000
15000.141600
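The loss definition for this run is likewise not readable in this export. Under the assumption that it mirrors the data prepared above — `sentence1`, `sentence2`, and a float `label` of 0.0 or 1.0 — a minimal sketch with `losses.CosineSimilarityLoss` looks as follows; reloading `bert-base-uncased` from scratch matches the "mean pooling" warning printed before this run.

```python
from sentence_transformers import SentenceTransformer, losses

# Start again from a fresh BERT base model (mean pooling is added automatically)
embedding_model = SentenceTransformer("bert-base-uncased")

# Compare the cosine similarity of the (sentence1, sentence2) embeddings
# against the float label (0.0 = not entailment, 1.0 = entailment)
train_loss = losses.CosineSimilarityLoss(model=embedding_model)
```

Training then proceeds with the same `SentenceTransformerTrainer` pattern sketched earlier, with this loss and dataset swapped in.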

" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "c16cc9c062f547ee938537a719d4f668", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Computing widget examples: 0%| | 0/5 [00:00\n", " \n", " \n", " [528/528 03:00, Epoch 1/1]\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
Step     Training Loss
100      0.345200
200      0.105500
300      0.079000
400      0.062200
500      0.069000
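Judging by the shorter run above (528 steps at batch size 32, roughly the number of entailment pairs in the 50,000-example subset), this section appears to train on (anchor, positive) pairs with in-batch negatives. A hedged sketch of that setup with `losses.MultipleNegativesRankingLoss` — column and variable names are illustrative, not taken from the notebook, and the actual cell may also add explicitly sampled negatives:

```python
from datasets import Dataset, load_dataset
from sentence_transformers import SentenceTransformer, losses

# Keep only entailment pairs (label 0) as (anchor, positive) examples
mnli = load_dataset("glue", "mnli", split="train").select(range(50_000))
mnli = mnli.filter(lambda row: row["label"] == 0)
train_dataset = Dataset.from_dict({
    "anchor": mnli["premise"],
    "positive": mnli["hypothesis"],
})

embedding_model = SentenceTransformer("bert-base-uncased")

# Every other positive in the batch serves as an in-batch negative
train_loss = losses.MultipleNegativesRankingLoss(model=embedding_model)
```

As before, these pieces would be handed to a `SentenceTransformerTrainer` together with the training arguments and the STS-B evaluator.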

" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "5913c2edc4de4e48990626d19af0ff2b", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Computing widget examples: 0%| | 0/5 [00:00\n", " \n", " \n", " [1563/1563 01:57, Epoch 1/1]\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
Step     Training Loss
100      0.155500
200      0.110000
300      0.118600
400      0.115300
500      0.110700
600      0.101000
700      0.113100
800      0.099800
900      0.109600
1000     0.105800
1100     0.094900
1200     0.106400
1300     0.105300
1400     0.105200
1500     0.106600
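After this run, the resulting model can be scored on MTEB exactly as the base model was earlier; the nested results dictionary shown there exposes the headline metric under `main_score`. A small sketch mirroring that earlier evaluation cell (no new API, only the dictionary lookup is added):

```python
from mteb import MTEB

# Same task as the earlier baseline evaluation, so the scores are comparable
evaluation = MTEB(tasks=["Banking77Classification"])
results = evaluation.run(embedding_model)

# The headline metric sits under the task name, the split, and "main_score"
print(results["Banking77Classification"]["test"]["main_score"])
```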

" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "511480884e6145cda98eac03e3eaa214", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Computing widget examples: 0%| | 0/5 [00:00