{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "g_a9QvUFVCUR" }, "source": [ "

Chapter 4 - Text Classification

\n", "Classifying text with both representative and generative models\n", "\n", "\n", "\n", "\n", "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1wGBtHF-lTI4G4QMdE0eOgmOhRDJT044T?usp=sharing)\n", "\n", "---\n", "\n", "This notebook is for Chapter 4 of the [Hands-On Large Language Models](https://www.amazon.com/Hands-Large-Language-Models-Understanding/dp/1098150961) book by [Jay Alammar](https://www.linkedin.com/in/jalammar) and [Maarten Grootendorst](https://www.linkedin.com/in/mgrootendorst/).\n", "\n", "---\n", "\n", "\n", "" ] }, { "cell_type": "markdown", "metadata": { "id": "UBeVnXxQWy7-" }, "source": [ "# **Data**" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "executionInfo": { "elapsed": 9938, "status": "ok", "timestamp": 1709737297789, "user": { "displayName": "Maarten Grootendorst", "userId": "11015108362723620659" }, "user_tz": -60 }, "id": "5phRS_z2U_3T", "outputId": "27f79175-2ec3-4922-e0ba-5bffd56c82cd" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/usr/local/lib/python3.10/dist-packages/huggingface_hub/utils/_token.py:88: UserWarning: \n", "The secret `HF_TOKEN` does not exist in your Colab secrets.\n", "To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.\n", "You will be able to reuse this secret in all of your notebooks.\n", "Please note that authentication is recommended but still optional to access public models or datasets.\n", " warnings.warn(\n" ] }, { "data": { "text/plain": [ "DatasetDict({\n", " train: Dataset({\n", " features: ['text', 'label'],\n", " num_rows: 8530\n", " })\n", " validation: Dataset({\n", " features: ['text', 'label'],\n", " num_rows: 1066\n", " })\n", " test: Dataset({\n", " features: ['text', 'label'],\n", " num_rows: 1066\n", " })\n", "})" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from datasets import load_dataset\n", "\n", "# Load our data\n", "data = load_dataset(\"rotten_tomatoes\")\n", "data" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "executionInfo": { "elapsed": 15, "status": "ok", "timestamp": 1709737297790, "user": { "displayName": "Maarten Grootendorst", "userId": "11015108362723620659" }, "user_tz": -60 }, "id": "xJJmaJzHDLZv", "outputId": "04501032-aed3-425c-8d70-b069a34c280b" }, "outputs": [ { "data": { "text/plain": [ "{'text': ['the rock is destined to be the 21st century\\'s new \" conan \" and that he\\'s going to make a splash even greater than arnold schwarzenegger , jean-claud van damme or steven segal .',\n", " 'things really get weird , though not particularly scary : the movie is all portent and no content .'],\n", " 'label': [1, 0]}" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "data[\"train\"][0, -1]" ] }, { "cell_type": "markdown", "metadata": { "id": "xya5dfmVoR1R" }, "source": [ "# **Text Classification with Representation Models**" ] }, { "cell_type": "markdown", "metadata": { "id": "co68g-Eloknf" }, "source": [ "## **Using a Task-specific Model**" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "executionInfo": { "elapsed": 17052, "status": "ok", "timestamp": 1709737314828, 
"user": { "displayName": "Maarten Grootendorst", "userId": "11015108362723620659" }, "user_tz": -60 }, "id": "ph-3T3XJopdN", "outputId": "62abf01e-ba0f-42fc-a8f3-fb467b02583b" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/usr/local/lib/python3.10/dist-packages/torch/_utils.py:831: UserWarning: TypedStorage is deprecated. It will be removed in the future and UntypedStorage will be the only storage class. This should only matter to you if you are using storages directly. To access UntypedStorage directly, use tensor.untyped_storage() instead of tensor.storage()\n", " return self.fget.__get__(instance, owner)()\n", "Some weights of the model checkpoint at cardiffnlp/twitter-roberta-base-sentiment-latest were not used when initializing RobertaForSequenceClassification: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']\n", "- This IS expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n", "- This IS NOT expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n", "/usr/local/lib/python3.10/dist-packages/transformers/pipelines/text_classification.py:104: UserWarning: `return_all_scores` is now deprecated, if want a similar functionality use `top_k=None` instead of `return_all_scores=True` or `top_k=1` instead of `return_all_scores=False`.\n", " warnings.warn(\n" ] } ], "source": [ "from transformers import pipeline\n", "\n", "# Path to our HF model\n", "model_path = \"cardiffnlp/twitter-roberta-base-sentiment-latest\"\n", "\n", "# Load model into pipeline\n", "pipe = pipeline(\n", " model=model_path,\n", " tokenizer=model_path,\n", " return_all_scores=True,\n", " device=\"cuda:0\"\n", ")" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "executionInfo": { "elapsed": 37634, "status": "ok", "timestamp": 1709737352458, "user": { "displayName": "Maarten Grootendorst", "userId": "11015108362723620659" }, "user_tz": -60 }, "id": "B2gbnL5Q69Y5", "outputId": "11c80fd6-0609-429f-caa4-443ee7298bbe" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 1066/1066 [00:37<00:00, 28.25it/s]\n" ] } ], "source": [ "import numpy as np\n", "from tqdm import tqdm\n", "from transformers.pipelines.pt_utils import KeyDataset\n", "\n", "# Run inference\n", "y_pred = []\n", "for output in tqdm(pipe(KeyDataset(data[\"test\"], \"text\")), total=len(data[\"test\"])):\n", " negative_score = output[0][\"score\"]\n", " positive_score = output[2][\"score\"]\n", " assignment = np.argmax([negative_score, positive_score])\n", " y_pred.append(assignment)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "X0KyKHtqyjn3" }, "outputs": [], "source": [ "from sklearn.metrics import classification_report\n", "\n", "def evaluate_performance(y_true, y_pred):\n", " \"\"\"Create and print the classification report\"\"\"\n", " performance = classification_report(\n", " y_true, y_pred,\n", " target_names=[\"Negative Review\", \"Positive Review\"]\n", " )\n", " print(performance)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "executionInfo": { 
"elapsed": 16, "status": "ok", "timestamp": 1709737352459, "user": { "displayName": "Maarten Grootendorst", "userId": "11015108362723620659" }, "user_tz": -60 }, "id": "fum3MTSyymlW", "outputId": "3a04041d-6a9f-44d3-ca15-c458bbb947ac" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ " precision recall f1-score support\n", "\n", "Negative Review 0.76 0.88 0.81 533\n", "Positive Review 0.86 0.72 0.78 533\n", "\n", " accuracy 0.80 1066\n", " macro avg 0.81 0.80 0.80 1066\n", " weighted avg 0.81 0.80 0.80 1066\n", "\n" ] } ], "source": [ "evaluate_performance(data[\"test\"][\"label\"], y_pred)" ] }, { "cell_type": "markdown", "metadata": { "id": "Wr3WT4jzoNZE" }, "source": [ "## **Classification Tasks that Leverage Embeddings**" ] }, { "cell_type": "markdown", "metadata": { "id": "l8yuSP3heMzT" }, "source": [ "### Supervised Classification" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 81, "referenced_widgets": [ "2b933a85712a4571b2a01b5838ca3a9f", "b45b64fe077247ada77a3cc3955353c2", "0289cd2d988748e29548821e78b286cc", "0ae0629c02c247e698b13a665690d2af", "918cf345f3da443691d1e61ed68790d4", "8a627f49d3b1401eafbce793d6ff974b", "c56eb48aeb314c2fa8e7ce996c37c90a", "9b93708354e643039877c072090b830f", "fede446ab0ab46f3b919d305980d70a4", "a95f1bb3b57548bb87b5895b997349e5", "f5c82f510d134762bbe2f88b732bf9cb", "1d7dce3fb68f457c82ce7792809e55b5", "42bbdc9bcc464622903e4d2efe803f3b", "6910f4cf7f9a4f4aafaa7944fbd85200", "42d72d490e484787a6206d008dbc26b8", "cf3ed2d1780e4e26b701e8b70038278c", "a02435a08f1d4ab2aa67ea174f40b485", "89d34ec3f9b14dd1a655500aa451ba6a", "fcd513c6053d4a5888a354de040e47d0", "3d91c4caf4ec40c5b429a298a9803cbc", "42ba1c78e1ae43a1aff0eeaa2ada6a97", "c3870de37eb64d1ea358d1887b4445fa" ] }, "executionInfo": { "elapsed": 26978, "status": "ok", "timestamp": 1709737379425, "user": { "displayName": "Maarten Grootendorst", "userId": "11015108362723620659" }, "user_tz": -60 }, "id": "jGV9VS4bhq7f", "outputId": "47fc54ba-27a7-4043-e5a1-c4b60fa8de93" }, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "2b933a85712a4571b2a01b5838ca3a9f", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Batches: 0%| | 0/267 [00:00#sk-container-id-1 {color: black;background-color: white;}#sk-container-id-1 pre{padding: 0;}#sk-container-id-1 div.sk-toggleable {background-color: white;}#sk-container-id-1 label.sk-toggleable__label {cursor: pointer;display: block;width: 100%;margin-bottom: 0;padding: 0.3em;box-sizing: border-box;text-align: center;}#sk-container-id-1 label.sk-toggleable__label-arrow:before {content: \"▸\";float: left;margin-right: 0.25em;color: #696969;}#sk-container-id-1 label.sk-toggleable__label-arrow:hover:before {color: black;}#sk-container-id-1 div.sk-estimator:hover label.sk-toggleable__label-arrow:before {color: black;}#sk-container-id-1 div.sk-toggleable__content {max-height: 0;max-width: 0;overflow: hidden;text-align: left;background-color: #f0f8ff;}#sk-container-id-1 div.sk-toggleable__content pre {margin: 0.2em;color: black;border-radius: 0.25em;background-color: #f0f8ff;}#sk-container-id-1 input.sk-toggleable__control:checked~div.sk-toggleable__content {max-height: 200px;max-width: 100%;overflow: auto;}#sk-container-id-1 input.sk-toggleable__control:checked~label.sk-toggleable__label-arrow:before {content: \"▾\";}#sk-container-id-1 div.sk-estimator input.sk-toggleable__control:checked~label.sk-toggleable__label {background-color: 
" ], "text/plain": [ "LogisticRegression(random_state=42)" ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from sklearn.linear_model import LogisticRegression\n", "\n", "# Train a Logistic Regression on our train embeddings\n", "clf = LogisticRegression(random_state=42)\n", "clf.fit(train_embeddings, data[\"train\"][\"label\"])" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "executionInfo": { "elapsed": 15, "status": "ok", "timestamp": 1709737379426, "user": { "displayName": "Maarten Grootendorst", "userId": "11015108362723620659" }, "user_tz": -60 }, "id": "tFvO9KhMokF7", "outputId": "6b93052c-0c1a-4e55-dc44-2280a2a66f5c" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ " precision recall f1-score support\n", "\n", "Negative Review 0.85 0.86 0.85 533\n", "Positive Review 0.86 0.85 0.85 533\n", "\n", " accuracy 0.85 1066\n", " macro avg 0.85 0.85 0.85 1066\n", " weighted avg 0.85 0.85 0.85 1066\n", "\n" ] } ], "source": [ "# Predict previously unseen instances\n", "y_pred = clf.predict(test_embeddings)\n", "evaluate_performance(data[\"test\"][\"label\"], y_pred)" ] }, { "cell_type": "markdown", "metadata": { "id": "dwGIHxXpJgrC" }, "source": [ "**Tip!** \n", "\n", "What would happen if we would not use a classifier at all? Instead, we can average the embeddings per class and apply cosine similarity to predict which classes match the documents best:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "executionInfo": { "elapsed": 499, "status": "ok", "timestamp": 1709737379911, "user": { "displayName": "Maarten Grootendorst", "userId": "11015108362723620659" }, "user_tz": -60 }, "id": "3f_DnG1uJ7Sk", "outputId": "077a413d-05a3-429a-c098-2f2c73b4b53f" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ " precision recall f1-score support\n", "\n", "Negative Review 0.85 0.84 0.84 533\n", "Positive Review 0.84 0.85 0.84 533\n", "\n", " accuracy 0.84 1066\n", " macro avg 0.84 0.84 0.84 1066\n", " weighted avg 0.84 0.84 0.84 1066\n", "\n" ] } ], "source": [ "import numpy as np\n", "import pandas as pd\n", "from sklearn.metrics import classification_report\n", "from sklearn.metrics.pairwise import cosine_similarity\n", "\n", "# Average the embeddings of all documents in each target label\n", "df = pd.DataFrame(np.hstack([train_embeddings, np.array(data[\"train\"][\"label\"]).reshape(-1, 1)]))\n", "averaged_target_embeddings = df.groupby(768).mean().values\n", "\n", "# Find the best matching embeddings between evaluation documents and target embeddings\n", "sim_matrix = cosine_similarity(test_embeddings, averaged_target_embeddings)\n", "y_pred = np.argmax(sim_matrix, axis=1)\n", "\n", "# Evaluate the model\n", "evaluate_performance(data[\"test\"][\"label\"], y_pred)" ] }, { "cell_type": "markdown", "metadata": { "id": "wCWdzjMIjzx0" }, "source": [ "### Zero-shot Classification" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "YSj6CdAetsNp" }, "outputs": [], "source": [ "# Create embeddings for our labels\n", "label_embeddings = model.encode([\"A negative review\", \"A positive review\"])" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "ZEIN7XnbtsQJ" }, "outputs": [], "source": [ "from sklearn.metrics.pairwise import cosine_similarity\n", "\n", "# Find the best matching label for each document\n", "sim_matrix = 
cosine_similarity(test_embeddings, label_embeddings)\n", "y_pred = np.argmax(sim_matrix, axis=1)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "executionInfo": { "elapsed": 5, "status": "ok", "timestamp": 1709737379912, "user": { "displayName": "Maarten Grootendorst", "userId": "11015108362723620659" }, "user_tz": -60 }, "id": "u6LyeuEUxIbW", "outputId": "7fd0aa01-6de0-449d-c2cb-2534d37fa6b9" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ " precision recall f1-score support\n", "\n", "Negative Review 0.78 0.77 0.78 533\n", "Positive Review 0.77 0.79 0.78 533\n", "\n", " accuracy 0.78 1066\n", " macro avg 0.78 0.78 0.78 1066\n", " weighted avg 0.78 0.78 0.78 1066\n", "\n" ] } ], "source": [ "evaluate_performance(data[\"test\"][\"label\"], y_pred)" ] }, { "cell_type": "markdown", "metadata": { "id": "Ox27Rg71zclg" }, "source": [ "**Tip!** \n", "\n", "What would happen if you were to use different descriptions? Use **\"A very negative movie review\"** and **\"A very positive movie review\"** to see what happens!" ] }, { "cell_type": "markdown", "metadata": { "id": "4CC9iEGcuUit" }, "source": [ "## **Classification with Generative Models**" ] }, { "cell_type": "markdown", "metadata": { "id": "qFPPzUHoEESB" }, "source": [ "### Encoder-decoder Models" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "nVbTUMktEfJ3" }, "outputs": [], "source": [ "# Load our model\n", "pipe = pipeline(\n", " \"text2text-generation\",\n", " model=\"google/flan-t5-small\",\n", " device=\"cuda:0\"\n", ")" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "executionInfo": { "elapsed": 7, "status": "ok", "timestamp": 1709737381485, "user": { "displayName": "Maarten Grootendorst", "userId": "11015108362723620659" }, "user_tz": -60 }, "id": "o5nWQORcFlNn", "outputId": "3705715d-3286-41b1-e07f-29792826c9f6" }, "outputs": [ { "data": { "text/plain": [ "DatasetDict({\n", " train: Dataset({\n", " features: ['text', 'label', 't5'],\n", " num_rows: 8530\n", " })\n", " validation: Dataset({\n", " features: ['text', 'label', 't5'],\n", " num_rows: 1066\n", " })\n", " test: Dataset({\n", " features: ['text', 'label', 't5'],\n", " num_rows: 1066\n", " })\n", "})" ] }, "execution_count": 18, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Prepare our data\n", "prompt = \"Is the following sentence positive or negative? \"\n", "data = data.map(lambda example: {\"t5\": prompt + example['text']})\n", "data" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "executionInfo": { "elapsed": 41739, "status": "ok", "timestamp": 1709737423219, "user": { "displayName": "Maarten Grootendorst", "userId": "11015108362723620659" }, "user_tz": -60 }, "id": "Nas574KFFSvR", "outputId": "b51b7694-3523-44a7-fcaa-48b2cfa2188a" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "\r", " 0%| | 0/1066 [00:00
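The output of the Flan-T5 inference cell is truncated above. As a minimal sketch (not necessarily the notebook's exact cell), the prompted `t5` column can be streamed through the text2text-generation pipeline and the generated strings mapped back to the integer labels used by `evaluate_performance` (0 = negative, 1 = positive). The indexing assumes each yielded output is a list like `[{"generated_text": "positive"}]`:

```python
from tqdm import tqdm
from transformers.pipelines.pt_utils import KeyDataset

# Run inference on the prompted "t5" column and convert the generated
# text back to integer labels (assumption: 0 = negative, 1 = positive,
# matching the labels used elsewhere in this notebook).
y_pred = []
for output in tqdm(pipe(KeyDataset(data["test"], "t5")), total=len(data["test"])):
    text = output[0]["generated_text"]  # e.g. "positive" or "negative"
    y_pred.append(0 if "negative" in text.lower() else 1)

# Evaluate the generative model with the same report as before
evaluate_performance(data["test"]["label"], y_pred)
```

Checking for the substring "negative" rather than an exact string match is a small hedge against the model generating extra tokens around the label word.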