{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "WpBVeU0XX8Uk" }, "source": [ "

<h1>Chapter 12 - Fine-tuning Generation Models</h1>

\n", "Exploring a two-step approach for fine-tuning generative LLMs.\n", "\n", "\n", "\n", "\n", "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/HandsOnLLM/Hands-On-Large-Language-Models/blob/main/chapter12/Chapter%2012%20-%20Fine-tuning%20Generation%20Models.ipynb)\n", "\n", "---\n", "\n", "This notebook is for Chapter 12 of the [Hands-On Large Language Models](https://www.amazon.com/Hands-Large-Language-Models-Understanding/dp/1098150961) book by [Jay Alammar](https://www.linkedin.com/in/jalammar) and [Maarten Grootendorst](https://www.linkedin.com/in/mgrootendorst/).\n", "\n", "---\n", "\n", "\n", "" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### [OPTIONAL] - Installing Packages on Google Colab\n", "\n", "If you are viewing this notebook on Google Colab (or any other cloud vendor), you need to **uncomment and run** the following code block to install the dependencies for this chapter:\n", "\n", "---\n", "\n", "💡 **NOTE**: We will want to use a GPU to run the examples in this notebook. 
In Google Colab, go to\n", "**Runtime > Change runtime type > Hardware accelerator > GPU > GPU type > T4**.\n", "\n", "---\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# %%capture\n", "# !pip install -q accelerate==0.31.0 peft==0.11.1 bitsandbytes==0.43.1 transformers==4.41.2 trl==0.9.4 sentencepiece==0.2.0 triton==3.1.0" ] }, { "cell_type": "markdown", "metadata": { "id": "v5luSSUAu_6d" }, "source": [ "# Supervised Fine-Tuning (SFT)" ] }, { "cell_type": "markdown", "metadata": { "id": "VPtcbw38_hVi" }, "source": [ "## Data Preprocessing" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 717, "referenced_widgets": [ "0c0285e0913b46638191933995384e81", "dcf360a0f47a49e4a45077254414c5a3", "15c6aefae79544f9976aaf48c76724f0", "14dfbe9d7ea64d948115cbfc419f088c", "27f6c79febad4975bad1c6826f56bb3a", "e41a3acb92354b39883d0a593cfd134f", "5580337a0b914e39997d375ab566d320", "0cd0b4ceab9d437f853687736c038724", "502d9bd6b8214cd5947b69581d16f43f", "a17a283f2efa46a39b5b08a3c9a3b354", "5632acd62d3f45a38317099f9d22ebb6", "9122ddf6e4c441aa888d237ab95f3db5", "7ef9099272e74f3ba85ccff9bec40fdd", "4bef0d8577984f5984779e4e80f541df", "e6fbcce852524d81af28fc0b308c416a", "997a77c0b71b45e785d07824116698ee", "cf32f48abc8a4b17a9ba384a50a04990", "b342ae0e84234300a0ce4c665ca34f0e", "2be641c16747464fa4030e8adfdc3d46", "cf001da1deae4f1e97454ea275c0f41a", "daa634044dd240efa142477a69643b81", "7697e11d19b54b0a895b89db5f1cab1c", "9aa1d03488e44364889b9b87db48369e", "d8c2e6bc0e4e4dff85217b4d6b294a4c", "e353f2f8d5f844658f90c252f35ab205", "1c6db3d796204b0587475bab3d3accc4", "0e0b34b158f745f0bf80c2b5d9f8047a", "a6394bf93d9847019bc69951a1343314", "3b4c717b24f04a1481b5ef9316abfda0", "72f3a7db32824b94987384879da3547b", "8c6d3817f69f433184e0554e66334202", "f322425b5f9842c39b083986866ddad9", "10962d2bcb2f442881e6be4cc822fb23", "ec75adaddca64c2d95ddcba24a3611f6", 
"531aa8d491394608ad78374c96e1f182", "266bf17916d14b309031ef704ebf8b16", "034fe1a3c54c4aa5afab13b26946c390", "8fb7ab95392f41aba2884c67b4970b93", "5600fc1f67c64444b22f26b040e53bad", "603774211a8f44409f626aec41739f8b", "000129269bda41fc95af5d86e32037e1", "0d7ed05129774701a0e1fd5d9dee938e", "4fa7fbd7f10c43239619efe8c94891c2", "a77b8d97a3ca47fdabf6d3f154440332", "f78630d608e646a985d1af14bbe9108d", "89f8d133c512401caa43f4adc82923db", "38ed39fcb5854a5693cea58d2e7df3da", "5c22e529a1d345f58b2a35b932434ace", "ddd79e89dbd54530905cf3c11fedd4e6", "8db41e735e454cf187047c2453e6097d", "381293f50a7c41518b46e5dba81162b2", "658c9da8d4994da9bd6fc900959f79aa", "8d3012ca9c4a4b75bce993cf668fd062", "475082257c81404c8b1fd270f6dbf895", "ea68d2393dae4db6b7947d9589176c66", "f6768e31065f48598a625ed8831d69ac", "f782508c19df4dfba8a1e3bf14c50131", "9ea8370dddd1469daf46c6e9c9784c59", "ed09b731dfae40feb977fd982e92d237", "550634ce52b24efb9b29231ddd2bcdd8", "9d3cc2f714be4b8e9605f92682d7ac73", "a00545a641274994955018848dba8622", "c2928952588648fa824632bb43e75712", "6ae43605a1fb4dfaa38214c118201269", "57318597fd344d0a8bd0cef8855efe53", "5369f8381c0b43afa2bcb7193db21a39", "5994759ae74240b396b00df3d2e87caf", "6c2124867ce14599917c4f10a211911b", "4a575c47233e4722bc9252f8d650d014", "48869b2592504cc7a0d4dc5a8eb009ae", "3fa9beddb8304d3e8610923802ff6e25", "588bd78207204c74a4b9f45cd660f452", "af85f9be14e040ad9a38390d6629646d", "7b54eabb0cbe4966a0fdab4b3aab652d", "c93c5cb2419847c5a305c5376e3bcac1", "fac0d8b63fb44fd195c1b80d91037055", "504065bf78444cedaab7c54a4d6a888f", "79cb1e5865cb4cf1a3ec6592b2b70a87", "f659f53ffde94033b805b8166fc7b861", "61a6e7784a0d43d086ee4ac57fb138d4", "abd923ac9c324e76a8305bb624934ecc", "c5a657ab85c14fc9a8fe0d53e51edac9", "d285bec3cce74bfab2e54d764192cf94", "faf7e7fea65246a788edb30cc32f430e", "3301916017cc47ccacecdec36b37833d", "da88cae9dab740548ca0593208ff8e41", "05db7dee35e8409791c0b12711e92d4f", "a23266458ea94362bf5f3696429327ce" ] }, "executionInfo": { "elapsed": 31435, "status": 
"ok", "timestamp": 1719390601273, "user": { "displayName": "Maarten Grootendorst", "userId": "11015108362723620659" }, "user_tz": -120 }, "id": "SqeZchJiOXdd", "outputId": "516dbb28-1771-4e35-bc2a-3317a56960d8" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/usr/local/lib/python3.10/dist-packages/huggingface_hub/utils/_token.py:89: UserWarning: \n", "The secret `HF_TOKEN` does not exist in your Colab secrets.\n", "To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.\n", "You will be able to reuse this secret in all of your notebooks.\n", "Please note that authentication is recommended but still optional to access public models or datasets.\n", " warnings.warn(\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "0c0285e0913b46638191933995384e81", "version_major": 2, "version_minor": 0 }, "text/plain": [ "tokenizer_config.json: 0%| | 0.00/1.29k [00:00 template TinyLLama is using\"\"\"\n", "\n", " # Format answers\n", " chat = example[\"messages\"]\n", " prompt = template_tokenizer.apply_chat_template(chat, tokenize=False)\n", "\n", " return {\"text\": prompt}\n", "\n", "# Load and format the data using the template TinyLLama is using\n", "dataset = (\n", " load_dataset(\"HuggingFaceH4/ultrachat_200k\", split=\"test_sft\")\n", " .shuffle(seed=42)\n", " .select(range(3_000))\n", ")\n", "dataset = dataset.map(format_prompt)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "executionInfo": { "elapsed": 3, "status": "ok", "timestamp": 1719390601273, "user": { "displayName": "Maarten Grootendorst", "userId": "11015108362723620659" }, "user_tz": -120 }, "id": "dtl2xZptgyDf", "outputId": "304c49f2-16c8-47ad-f8fb-4d975012e6d3" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "<|user|>\n", "Given the 
text: Knock, knock. Who’s there? Hike.\n", "Can you continue the joke based on the given text material \"Knock, knock. Who’s there? Hike\"?\n", "<|assistant|>\n", "Sure! Knock, knock. Who's there? Hike. Hike who? Hike up your pants, it's cold outside!\n", "<|user|>\n", "Can you tell me another knock-knock joke based on the same text material \"Knock, knock. Who's there? Hike\"?\n", "<|assistant|>\n", "Of course! Knock, knock. Who's there? Hike. Hike who? Hike your way over here and let's go for a walk!\n", "\n" ] } ], "source": [ "# Example of formatted prompt\n", "print(dataset[\"text\"][2576])" ] }, { "cell_type": "markdown", "metadata": { "id": "CyuLZGizDqUB" }, "source": [ "## Models - Quantization" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 241, "referenced_widgets": [ "1b4de2592d454a7ba7560ea6849dc6ba", "16f636aa216e44a1b9ebb66961b45361", "d55c8156644740c48012383eb903b7c5", "b40ba0141f194f36b95882644fb2a41e", "e8499c18c273492d8fb55fdd50f8f9d2", "db0c43c383d3479989e3a307a8c635fd", "2e5ccfc7cf4e45cbbdf74add6bf2fdf2", "b66b520fbdf84c78adc2f16910ed7a2a", "bb4131ce7126465ba7f0107f51d4bd5a", "865041ec26e3499cb66da96354bc9a7a", "32a2bd764cae40d59467ea71dccea11d", "b226cae1a8fc4b8384e0f5fc67d94d89", "220334bb55104a01b4e112a05a79db77", "a956b9fa522e4d0fafc8fb415ac109ec", "a6ae124d59814210913da1e50e82a90b", "c03dd90be661408c8d94238ce901081d", "23399b989c4449e0ae28c9932b104ba4", "16a053e1fc61408a97c650b58fd56913", "373a8dedb2b6498c8831f01912e87f10", "ec9bf18f5baa45d6a90c1fd2be0e941f", "88e2d8f90d814e61b2b8993f12f24d0c", "38e7e69833ec4298872afb91e3ba76a0", "da07fb8f4a204a4aa9125ddb89526992", "3071cd29353c4137a104b974efb903fb", "e650428153f441c19a9ce60c7f36b726", "6f6d884827dd403280cc2bd44febf8f7", "505f2e0ae82f4d5190e7f0a61c693c7d", "3ef9ff15982048bcbd9667af97be7a91", "f7e0a4c846264a3ebf8bec09a6520674", "ac64efb662ca4a2593609ee2797cb188", "2211c8ccb3eb4d0d8e5d7ce14e09c843", 
"dea87b44934648cb86d95f4f081005b1", "cf4f55ef9e2f450b891149a5cf3d190d", "4f88583ce59b467cb9791d09f1d57afd", "825cf142c3a842ccb3f7001138a11930", "12959d7e56374d8ca49ff18544b4db7e", "f301f0072ff649f3a75dea4ed687c294", "37273f57135141deb8cb53d858669834", "3c0da0f585f64a489a28fb6afa1e6f5f", "5b4f672bcb7546eca0e2dbce43900798", "e918450ee4b74d03bab029dd7230728e", "57040592ffb3433fbdfe90ea1a52d1ab", "0ab39d8c544a43bcade2ab94f2a61a0a", "fedef8865ad740d0a5cf4167b0067bf6", "7ab5f2a459764205ae4263b22f64a7aa", "f6092670ea49450e93fbf62eab75d996", "6c042632a5b94c5bb9b387ad17920908", "7ddca642a1bb4605ad508b1d56c3a61f", "1d44d180cee1440ca280e7004aa9bbda", "33fd53bddde646cf99019f98e04cb1d5", "744008ceeebd4809969b054d3a09d6c5", "1deb1756e91941af91bc2f404b5c52c9", "9e2a150555434cf39df5da21932a57bf", "397231517f4e46fbaef4408d36ecb1c8", "1c5d2c4d4d74406fbd4a3fb5b528eed9", "a641c51801964d749c11be9ecdaf8749", "4375456a00da42c4af307d3b0680ec02", "51778001c9214474870749b4af4e2d23", "63ae62bab5394d848265983975840f53", "e1337ad30ae74bf28c4f88ccecb24c06", "a80691a3b01a4c409c49833aa4e2c4ea", "f51f84cc3c9e41ea82720a9ed41cfbe4", "ec18963535394ba2acef196aff2603b9", "9666d1392eb5438b84b76dcfd3222992", "cd70eac32fea43a3bfb08cf52d74ae4b", "4f4660a0cec2426e8d7ff0b673de28be", "5f59b68201d94411b978417aa902e747", "e45f18c8c9b447eda9c64a47df8d4881", "179ee7d876784842b278c92ce0ac4f7d", "e664c66f54e9465c8bd404b2ae5aaf4e", "a55c2efc970b481a8fb1ad771c6fc1ee", "686140ba3e5b428088de8d94b6e3c707", "601b905edc4d410184ee34f83d993971", "b1738e5723ca4a4ca82f9038897ce547", "aa3dc6d6d4d44b2b8d25ab7fbb4bf224", "ad68d7d2d6ef4d90aef6d7536d61bd5b", "9fe1ce51c081490bbbfc9c3688bc3860" ] }, "executionInfo": { "elapsed": 22033, "status": "ok", "timestamp": 1719390623304, "user": { "displayName": "Maarten Grootendorst", "userId": "11015108362723620659" }, "user_tz": -120 }, "id": "M95Y207T7wSp", "outputId": "5ed465a8-9fa9-4c8e-8b50-bdc3630bbef1" }, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { 
"model_id": "1b4de2592d454a7ba7560ea6849dc6ba", "version_major": 2, "version_minor": 0 }, "text/plain": [ "config.json: 0%| | 0.00/560 [00:00\"\n", "tokenizer.padding_side = \"left\"" ] }, { "cell_type": "markdown", "metadata": { "id": "t1iGIch-sAMC" }, "source": [ "## Configuration" ] }, { "cell_type": "markdown", "metadata": { "id": "86o1T5n4DziD" }, "source": [ "### LoRA Configuration" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "0tYs1ZhYDyw9" }, "outputs": [], "source": [ "from peft import LoraConfig, prepare_model_for_kbit_training, get_peft_model\n", "\n", "# Prepare LoRA Configuration\n", "peft_config = LoraConfig(\n", " lora_alpha=32, # LoRA Scaling\n", " lora_dropout=0.1, # Dropout for LoRA Layers\n", " r=64, # Rank\n", " bias=\"none\",\n", " task_type=\"CAUSAL_LM\",\n", " target_modules= # Layers to target\n", " ['k_proj', 'gate_proj', 'v_proj', 'up_proj', 'q_proj', 'o_proj', 'down_proj']\n", ")\n", "\n", "# prepare model for training\n", "model = prepare_model_for_kbit_training(model)\n", "model = get_peft_model(model, peft_config)" ] }, { "cell_type": "markdown", "metadata": { "id": "Zhbh7kKuD24o" }, "source": [ "### Training Configuration" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "TwxZkx80G6bO" }, "outputs": [], "source": [ "from transformers import TrainingArguments\n", "\n", "output_dir = \"./results\"\n", "\n", "# Training arguments\n", "training_arguments = TrainingArguments(\n", " output_dir=output_dir,\n", " per_device_train_batch_size=2,\n", " gradient_accumulation_steps=4,\n", " optim=\"paged_adamw_32bit\",\n", " learning_rate=2e-4,\n", " lr_scheduler_type=\"cosine\",\n", " num_train_epochs=1,\n", " logging_steps=10,\n", " fp16=True,\n", " gradient_checkpointing=True\n", ")" ] }, { "cell_type": "markdown", "metadata": { "id": "RtwIo5a0D6f1" }, "source": [ "## Training!" 
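Before launching training, it can help to estimate how many trainable parameters the LoRA configuration above adds. Each adapted linear layer of shape (d_out, d_in) gains r * (d_in + d_out) parameters from its rank-r A and B matrices. The sketch below computes this analytically; the TinyLlama-1.1B layer dimensions (hidden size 2048, MLP size 5632, 22 layers, 256-dim key/value projections) are assumptions taken from the model's published config, not from this notebook:

```python
# Estimate LoRA trainable parameters for the configuration used above
# (r=64, targeting q/k/v/o and gate/up/down projections).
# Assumed TinyLlama-1.1B dimensions (from the model's config, not this notebook):
hidden = 2048    # hidden_size
inter = 5632     # intermediate_size (MLP)
n_layers = 22    # num_hidden_layers
kv_dim = 256     # 4 key/value heads x head_dim 64 (grouped-query attention)

r = 64  # LoRA rank

# (d_in, d_out) for each targeted module in one decoder layer
shapes = {
    "q_proj": (hidden, hidden),
    "k_proj": (hidden, kv_dim),
    "v_proj": (hidden, kv_dim),
    "o_proj": (hidden, hidden),
    "gate_proj": (hidden, inter),
    "up_proj": (hidden, inter),
    "down_proj": (inter, hidden),
}

# LoRA adds an (r x d_in) matrix A and a (d_out x r) matrix B per module
per_layer = sum(r * (d_in + d_out) for d_in, d_out in shapes.values())
total = per_layer * n_layers
print(f"{total:,} trainable LoRA parameters")  # roughly 50M vs ~1.1B base weights
```

This is why QLoRA fits on a T4: only about 50M adapter parameters are trained in fp16 while the 1.1B base weights stay frozen in 4-bit.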
] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 1000, "referenced_widgets": [ "1b741e85fcf1458f9875c12a9640dfee", "d034f23fd4df4e3296477d8dd76be5b1", "bee0b0ce2fb84c5eb67a04ced69752d1", "4d39d09e1e2648b1b5295f192e9ad356", "3f733eb54fe54d879c97dba7a5204ddd", "e9dab506c3b242d7b6228394ada6084b", "e7d4893b696c4941bf29d349eb2ceabb", "6d7ee17aa7024c8088c374781348f9f0", "7d8478e66e394f0fb077853e5319ee6a", "630e317f036f41f4a9852f7df81eef83", "eb35394940c74f60abe2daaeb243fa88" ] }, "executionInfo": { "elapsed": 774977, "status": "ok", "timestamp": 1719391399990, "user": { "displayName": "Maarten Grootendorst", "userId": "11015108362723620659" }, "user_tz": -120 }, "id": "B2D7RVihsE7Z", "outputId": "1a9f8125-6d39-410e-ff94-9a9ac493ff25" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/usr/local/lib/python3.10/dist-packages/huggingface_hub/utils/_deprecation.py:100: FutureWarning: Deprecated argument(s) used in '__init__': dataset_text_field, max_seq_length. Will not be supported from version '1.0.0'.\n", "\n", "Deprecated positional argument(s) used in SFTTrainer, please use the SFTConfig to set these arguments instead.\n", " warnings.warn(message, FutureWarning)\n", "/usr/local/lib/python3.10/dist-packages/transformers/training_args.py:1965: FutureWarning: `--push_to_hub_token` is deprecated and will be removed in version 5 of 🤗 Transformers. 
Use `--hub_token` instead.\n", " warnings.warn(\n", "/usr/local/lib/python3.10/dist-packages/trl/trainer/sft_trainer.py:269: UserWarning: You passed a `max_seq_length` argument to the SFTTrainer, the value you passed will override the one in the `SFTConfig`.\n", " warnings.warn(\n", "/usr/local/lib/python3.10/dist-packages/trl/trainer/sft_trainer.py:307: UserWarning: You passed a `dataset_text_field` argument to the SFTTrainer, the value you passed will override the one in the `SFTConfig`.\n", " warnings.warn(\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "1b741e85fcf1458f9875c12a9640dfee", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Map: 0%| | 0/3000 [00:00\n", " \n", " \n", " [375/375 12:45, Epoch 1/1]\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
Step    Training Loss
10      1.670600
20      1.475400
30      1.451400
40      1.487800
50      1.477900
60      1.390500
70      1.495200
80      1.450300
90      1.427900
100     1.404400
110     1.414400
120     1.377500
130     1.332100
140     1.497000
150     1.347000
160     1.411500
170     1.454000
180     1.324500
190     1.419300
200     1.474900
210     1.404600
220     1.342100
230     1.361100
240     1.387300
250     1.353700
260     1.345800
270     1.465400
280     1.434000
290     1.387600
300     1.376200
310     1.395000
320     1.437900
330     1.387200
340     1.388100
350     1.313600
360     1.444300
370     1.452000
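The loss column above trends downward but is noisy step to step; comparing the mean of the first and last few logged values gives a quick read on whether training made progress. A small sketch using the values logged above:

```python
# Training losses logged every 10 steps (steps 10-370) during the SFT run above
losses = [
    1.6706, 1.4754, 1.4514, 1.4878, 1.4779, 1.3905, 1.4952, 1.4503, 1.4279,
    1.4044, 1.4144, 1.3775, 1.3321, 1.4970, 1.3470, 1.4115, 1.4540, 1.3245,
    1.4193, 1.4749, 1.4046, 1.3421, 1.3611, 1.3873, 1.3537, 1.3458, 1.4654,
    1.4340, 1.3876, 1.3762, 1.3950, 1.4379, 1.3872, 1.3881, 1.3136, 1.4443,
    1.4520,
]

def window_mean(values):
    """Average a window of logged losses to smooth out step-to-step noise."""
    return sum(values) / len(values)

early = window_mean(losses[:10])   # mean over steps 10-100
late = window_mean(losses[-10:])   # mean over steps 280-370
print(f"early {early:.3f} -> late {late:.3f}")
```

The drop is modest (roughly 1.47 to 1.40), which is expected for a single epoch of QLoRA on 3,000 samples.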

" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stderr", "output_type": "stream", "text": [ "/usr/local/lib/python3.10/dist-packages/huggingface_hub/file_download.py:1132: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.\n", " warnings.warn(\n" ] } ], "source": [ "from trl import SFTTrainer\n", "\n", "# Set supervised fine-tuning parameters\n", "trainer = SFTTrainer(\n", " model=model,\n", " train_dataset=dataset,\n", " dataset_text_field=\"text\",\n", " tokenizer=tokenizer,\n", " args=training_arguments,\n", " max_seq_length=512,\n", "\n", " # Leave this out for regular SFT\n", " peft_config=peft_config,\n", ")\n", "\n", "# Train model\n", "trainer.train()\n", "\n", "# Save QLoRA weights\n", "trainer.model.save_pretrained(\"TinyLlama-1.1B-qlora\")" ] }, { "cell_type": "markdown", "metadata": { "id": "tsIBfv1PsId-" }, "source": [ "### Merge Adapter" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "M6cPdde4Z-ks" }, "outputs": [], "source": [ "from peft import AutoPeftModelForCausalLM\n", "\n", "model = AutoPeftModelForCausalLM.from_pretrained(\n", " \"TinyLlama-1.1B-qlora\",\n", " low_cpu_mem_usage=True,\n", " device_map=\"auto\",\n", ")\n", "\n", "# Merge LoRA and base model\n", "merged_model = model.merge_and_unload()" ] }, { "cell_type": "markdown", "metadata": { "id": "jPRYGimIsM2-" }, "source": [ "### Inference" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "executionInfo": { "elapsed": 6781, "status": "ok", "timestamp": 1719391410095, "user": { "displayName": "Maarten Grootendorst", "userId": "11015108362723620659" }, "user_tz": -120 }, "id": "15dJC3ZrdVnK", "outputId": "3095ed46-5bb8-4288-b3a0-05d7daeaefc4" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ 
"<|user|>\n", "Tell me something about Large Language Models.\n", "<|assistant|>\n", "Large Language Models (LLMs) are a type of artificial intelligence (AI) that can generate human-like language. They are trained on large amounts of data, including text, audio, and video, and are capable of generating complex and nuanced language.\n", "\n", "LLMs are used in a variety of applications, including natural language processing (NLP), machine translation, and chatbots. They can be used to generate text, speech, or images, and can be trained to understand different languages and dialects.\n", "\n", "One of the most significant applications of LLMs is in the field of natural language generation (NLG). LLMs can be used to generate text in a variety of languages, including English, French, and German. They can also be used to generate speech, such as in chatbots or voice assistants.\n", "\n", "LLMs have the potential to revolutionize the way we communicate and interact with each other. They can help us create more engaging and personalized content, and they can also help us understand each other better.\n" ] } ], "source": [ "from transformers import pipeline\n", "\n", "# Use our predefined prompt template\n", "prompt = \"\"\"<|user|>\n", "Tell me something about Large Language Models.\n", "<|assistant|>\n", "\"\"\"\n", "\n", "# Run our instruction-tuned model\n", "pipe = pipeline(task=\"text-generation\", model=merged_model, tokenizer=tokenizer)\n", "print(pipe(prompt)[0][\"generated_text\"])" ] }, { "cell_type": "markdown", "metadata": { "id": "9JNfYZe9vCb8" }, "source": [ "# Preference Tuning (PPO/DPO)" ] }, { "cell_type": "markdown", "metadata": { "id": "ar2h9kZ9qmEG" }, "source": [ "## Data Preprocessing" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 246, "referenced_widgets": [ "d330d84ac98a4d14b51ffab13277e501", "2ed3f903f58e42b4a10c8937e7e8cdf5", "415333d658284e5f9566065f9bfc4808", 
"76038e4dce21442caa069273e8c22e42", "914814f0575948c688991533f85e59dc", "e78efc4e689745d69c6d91ac71d51f39", "f113f1d6bf9244c98c89f5610e371431", "67066b2143a74269bb15a97ace2f0ae2", "45b29b925a224dbebf61f67bf359f393", "f6a86cfcc1d347e5b182a9abf57b15b4", "9dedc5f35e9241a6b9aff057c1c6ef9b", "73e8b10b8ee349d1b91dff30f885e335", "79da0fb8fad247c1a5e6a5b8bde4d498", "c6ff9de93d92425481c29123ac68bf76", "c800d5aacd2f4844856c505f11e09e56", "0c5019fa5fc74a81a6b11afc49fe135c", "c9c3dad975c4436499c3ac3df06bad39", "78fb4c59298b4b10b3dd3d5707ae0d81", "a7ed257939254f11b1494bf7ae6d42f9", "7a856d774904424699aec1f2e9479016", "149367fc059e47f990dde648af05d17c", "3245f7bd4eb244b587eb15e3cd0b1d80", "063c3f5d0ed64ef4901efb5a4fd64149", "d571423cdf204acc8ce3e7a4da3e2526", "06220ed92658447fb48ade092c4bb36d", "6d05c08c79694bed8345d28bdfe19032", "54724f48c14b43f5ba4459459670334a", "f5d383934bbb4f758336f08b567f0824", "64d85e89ed7d46a981d2a00c47bdf8b2", "35a0f7cdae6f4a61acb0dffe5e698130", "850abb487dce4b1aa09f8094bc447a9b", "1aeb4f6ec0234ed3a6be1ac75244ae8e", "93156ea1f4a14de392c28e0b489b4290", "a29338f36ee34153a880c3f5fb985616", "50de9fd7ffd544559ebbd078ef12345f", "4036016f7aad43449ad70cd76f40c5eb", "9bfee60d3ab5416e983c42ae6ecdd0e0", "ff17fc6260e547268b518231cae50451", "80e75187ab95457795532aa6c7b00d76", "261b9119ab994265aac43d6b80ffc90d", "92ef80ab57e44151a0e9419639cb9a34", "bb0813f49acd457cadc27d2384e9274f", "bd41907a546e40059723a07c53f39339", "d30f87bdd81242e989c97a14ec2f98f3", "a98b46f3843b4724b6fad82fac16e219", "c1f7130a16d844ed91aa1c97c36f18a1", "4715e099e4454e9f9c54e0572d0508d4", "e9d5215839c44dbcb782855769ef3d0b", "c96ff8419847405889ef3b88a4684739", "fe77dd4438fb4de389bbf33ec23bf8b4", "3e34500a435a457c9cb321b4b258f76e", "30cfc3189acb4db7954e0b6efc9a5645", "530d30bd4d524415bec77c5d1725a4ac", "3804bdf7cf5f4d9c87ce42c47fadfe1a", "27c3bc5002ba49ef81d6c20b40d5719f" ] }, "executionInfo": { "elapsed": 4958, "status": "ok", "timestamp": 1719391415052, "user": { "displayName": "Maarten 
Grootendorst", "userId": "11015108362723620659" }, "user_tz": -120 }, "id": "UlbPVO_aac33", "outputId": "a2c446d6-e410-4d17-eb98-96b21264e0e9" }, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "d330d84ac98a4d14b51ffab13277e501", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Downloading readme: 0%| | 0.00/10.1k [00:00 template TinyLLama is using\"\"\"\n", "\n", " # Format answers\n", " system = \"<|system|>\\n\" + example['system'] + \"\\n\"\n", " prompt = \"<|user|>\\n\" + example['input'] + \"\\n<|assistant|>\\n\"\n", " chosen = example['chosen'] + \"\\n\"\n", " rejected = example['rejected'] + \"\\n\"\n", "\n", " return {\n", " \"prompt\": system + prompt,\n", " \"chosen\": chosen,\n", " \"rejected\": rejected,\n", " }\n", "\n", "# Apply formatting to the dataset and select relatively short answers\n", "dpo_dataset = load_dataset(\"argilla/distilabel-intel-orca-dpo-pairs\", split=\"train\")\n", "dpo_dataset = dpo_dataset.filter(\n", " lambda r:\n", " r[\"status\"] != \"tie\" and\n", " r[\"chosen_score\"] >= 8 and\n", " not r[\"in_gsm8k_train\"]\n", ")\n", "dpo_dataset = dpo_dataset.map(format_prompt, remove_columns=dpo_dataset.column_names)\n", "dpo_dataset" ] }, { "cell_type": "markdown", "metadata": { "id": "AkCJ4CO5sQG6" }, "source": [ "## Models - Quantization" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "executionInfo": { "elapsed": 5934, "status": "ok", "timestamp": 1719391420979, "user": { "displayName": "Maarten Grootendorst", "userId": "11015108362723620659" }, "user_tz": -120 }, "id": "7YMmilm7c1-P", "outputId": "fbf5e75b-cf63-4ac6-b1b5-514cddceb842" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/usr/local/lib/python3.10/dist-packages/peft/tuners/lora/bnb.py:325: UserWarning: Merge lora module to 4-bit linear may get different generations due to rounding errors.\n", " warnings.warn(\n" ] } ], "source": [ 
"from peft import AutoPeftModelForCausalLM\n", "from transformers import BitsAndBytesConfig, AutoTokenizer\n", "\n", "# 4-bit quantization configuration - Q in QLoRA\n", "bnb_config = BitsAndBytesConfig(\n", " load_in_4bit=True, # Use 4-bit precision model loading\n", " bnb_4bit_quant_type=\"nf4\", # Quantization type\n", " bnb_4bit_compute_dtype=\"float16\", # Compute dtype\n", " bnb_4bit_use_double_quant=True, # Apply nested quantization\n", ")\n", "\n", "# Merge LoRA and base model\n", "model = AutoPeftModelForCausalLM.from_pretrained(\n", " \"TinyLlama-1.1B-qlora\",\n", " low_cpu_mem_usage=True,\n", " device_map=\"auto\",\n", " quantization_config=bnb_config,\n", ")\n", "merged_model = model.merge_and_unload()\n", "\n", "# Load LLaMA tokenizer\n", "model_name = \"TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T\"\n", "tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=False)\n", "tokenizer.pad_token = \"\"\n", "tokenizer.padding_side = \"left\"" ] }, { "cell_type": "markdown", "metadata": { "id": "iidCbaXMs1O4" }, "source": [ "## Configuration" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "m6IfkvLkylVD" }, "outputs": [], "source": [ "from peft import LoraConfig, prepare_model_for_kbit_training, get_peft_model\n", "\n", "# Prepare LoRA Configuration\n", "peft_config = LoraConfig(\n", " lora_alpha=32, # LoRA Scaling\n", " lora_dropout=0.1, # Dropout for LoRA Layers\n", " r=64, # Rank\n", " bias=\"none\",\n", " task_type=\"CAUSAL_LM\",\n", " target_modules= # Layers to target\n", " ['k_proj', 'gate_proj', 'v_proj', 'up_proj', 'q_proj', 'o_proj', 'down_proj']\n", ")\n", "\n", "# prepare model for training\n", "model = prepare_model_for_kbit_training(model)\n", "model = get_peft_model(model, peft_config)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "lk-cEEd8nk27" }, "outputs": [], "source": [ "from trl import DPOConfig\n", "\n", "output_dir = \"./results\"\n", "\n", "# Training 
arguments\n", "training_arguments = DPOConfig(\n", " output_dir=output_dir,\n", " per_device_train_batch_size=2,\n", " gradient_accumulation_steps=4,\n", " optim=\"paged_adamw_32bit\",\n", " learning_rate=1e-5,\n", " lr_scheduler_type=\"cosine\",\n", " max_steps=200,\n", " logging_steps=10,\n", " fp16=True,\n", " gradient_checkpointing=True,\n", " warmup_ratio=0.1\n", ")" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 1000, "referenced_widgets": [ "c98fc033e551443b94dc8dae31d590bf", "b20056d36ac942cc9f8c3a9efc020f36", "3f8a810d4e2549c9a4ba2f3f2ee017e8", "47768d322ac94e1494d7f4a2f01440b7", "656d76b25562479ba3b24f22800a675b", "eab02ca086f44759a9b1b00d8f1a1245", "d960c3e6a7a04fb8b525dec294da6815", "983f9c1d7e47494383099916eed69c0d", "0236bebb8dce4b5dade5728304ffb964", "9e1a61b6c5f8482eadd024280da208f3", "4380e0fb571e41c9a58b09a06a20b853" ] }, "executionInfo": { "elapsed": 805129, "status": "ok", "timestamp": 1719392226734, "user": { "displayName": "Maarten Grootendorst", "userId": "11015108362723620659" }, "user_tz": -120 }, "id": "Pp3tUXhWm0pE", "outputId": "29378dc8-bb8a-435b-8330-e45aa26548c7" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/usr/local/lib/python3.10/dist-packages/huggingface_hub/utils/_deprecation.py:100: FutureWarning: Deprecated argument(s) used in '__init__': max_prompt_length, max_length. 
Will not be supported from version '1.0.0'.\n", "\n", "Deprecated positional argument(s) used in DPOTrainer, please use the DPOConfig to set these arguments instead.\n", " warnings.warn(message, FutureWarning)\n", "/usr/local/lib/python3.10/dist-packages/peft/tuners/lora/bnb.py:325: UserWarning: Merge lora module to 4-bit linear may get different generations due to rounding errors.\n", " warnings.warn(\n", "/usr/local/lib/python3.10/dist-packages/trl/trainer/dpo_trainer.py:358: UserWarning: You passed `max_length` to the DPOTrainer, the value you passed will override the one in the `DPOConfig`.\n", " warnings.warn(\n", "/usr/local/lib/python3.10/dist-packages/trl/trainer/dpo_trainer.py:371: UserWarning: You passed `max_prompt_length` to the DPOTrainer, the value you passed will override the one in the `DPOConfig`.\n", " warnings.warn(\n", "/usr/local/lib/python3.10/dist-packages/trl/trainer/dpo_trainer.py:411: UserWarning: When using DPODataCollatorWithPadding, you should set `remove_unused_columns=False` in your TrainingArguments we have set it for you, but you should do it yourself in the future.\n", " warnings.warn(\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "c98fc033e551443b94dc8dae31d590bf", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Map:   0%|          | 0/5922 [00:00<?, ? examples/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "\n",
 "    <div>\n", "      \n", "      <progress value='200' max='200' style='width:300px; height:20px; vertical-align: middle;'></progress>\n", "      [200/200 12:52, Epoch 0/1]\n", "    </div>\n",
 "    <table border=\"1\" class=\"dataframe\">\n", "  <thead>\n", "    <tr style=\"text-align: left;\">\n", "      <th>Step</th>\n", "      <th>Training Loss</th>\n", "    </tr>\n", "  </thead>\n", "  <tbody>\n",
 "    <tr>\n", "      <td>10</td>\n", "      <td>0.692400</td>\n", "    </tr>\n",
 "    <tr>\n", "      <td>20</td>\n", "      <td>0.678200</td>\n", "    </tr>\n",
 "    <tr>\n", "      <td>30</td>\n", "      <td>0.646000</td>\n", "    </tr>\n",
 "    <tr>\n", "      <td>40</td>\n", "      <td>0.606300</td>\n", "    </tr>\n",
 "    <tr>\n", "      <td>50</td>\n", "      <td>0.595600</td>\n", "    </tr>\n",
 "    <tr>\n", "      <td>60</td>\n", "      <td>0.616800</td>\n", "    </tr>\n",
 "    <tr>\n", "      <td>70</td>\n", "      <td>0.593700</td>\n", "    </tr>\n",
 "    <tr>\n", "      <td>80</td>\n", "      <td>0.531900</td>\n", "    </tr>\n",
 "    <tr>\n", "      <td>90</td>\n", "      <td>0.559200</td>\n", "    </tr>\n",
 "    <tr>\n", "      <td>100</td>\n", "      <td>0.639000</td>\n", "    </tr>\n",
 "    <tr>\n", "      <td>110</td>\n", "      <td>0.496500</td>\n", "    </tr>\n",
 "    <tr>\n", "      <td>120</td>\n", "      <td>0.586000</td>\n", "    </tr>\n",
 "    <tr>\n", "      <td>130</td>\n", "      <td>0.630000</td>\n", "    </tr>\n",
 "    <tr>\n", "      <td>140</td>\n", "      <td>0.590100</td>\n", "    </tr>\n",
 "    <tr>\n", "      <td>150</td>\n", "      <td>0.577500</td>\n", "    </tr>\n",
 "    <tr>\n", "      <td>160</td>\n", "      <td>0.591000</td>\n", "    </tr>\n",
 "    <tr>\n", "      <td>170</td>\n", "      <td>0.606900</td>\n", "    </tr>\n",
 "    <tr>\n", "      <td>180</td>\n", "      <td>0.627800</td>\n", "    </tr>\n",
 "    <tr>\n", "      <td>190</td>\n", "      <td>0.668600</td>\n", "    </tr>\n",
 "    <tr>\n", "      <td>200</td>\n", "      <td>0.555400</td>\n", "    </tr>\n",
 "  </tbody>\n", "</table><p>
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "from trl import DPOTrainer\n", "\n", "# Create DPO trainer\n", "dpo_trainer = DPOTrainer(\n", " model,\n", " args=training_arguments,\n", " train_dataset=dpo_dataset,\n", " tokenizer=tokenizer,\n", " peft_config=peft_config,\n", " beta=0.1,\n", " max_prompt_length=512,\n", " max_length=512,\n", ")\n", "\n", "# Fine-tune model with DPO\n", "dpo_trainer.train()\n", "\n", "# Save adapter\n", "dpo_trainer.model.save_pretrained(\"TinyLlama-1.1B-dpo-qlora\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "QFE4OKFvyLMe" }, "outputs": [], "source": [ "from peft import PeftModel\n", "\n", "# Merge LoRA and base model\n", "model = AutoPeftModelForCausalLM.from_pretrained(\n", " \"TinyLlama-1.1B-qlora\",\n", " low_cpu_mem_usage=True,\n", " device_map=\"auto\",\n", ")\n", "sft_model = model.merge_and_unload()\n", "\n", "# Merge DPO LoRA and SFT model\n", "dpo_model = PeftModel.from_pretrained(\n", " sft_model,\n", " \"TinyLlama-1.1B-dpo-qlora\",\n", " device_map=\"auto\",\n", ")\n", "dpo_model = dpo_model.merge_and_unload()" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "executionInfo": { "elapsed": 6777, "status": "ok", "timestamp": 1719392237608, "user": { "displayName": "Maarten Grootendorst", "userId": "11015108362723620659" }, "user_tz": -120 }, "id": "zAkwJcHYmxr4", "outputId": "631aed7c-1e64-4e2c-db73-3e36ddee4e75" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "<|user|>\n", "Tell me something about Large Language Models.\n", "<|assistant|>\n", "Large Language Models (LLMs) are a type of artificial intelligence (AI) that can generate human-like language. 
They are trained on large amounts of data, including text, audio, and video, and are capable of generating complex and nuanced language.\n", "\n", "LLMs are used in a variety of applications, including natural language processing (NLP), machine translation, and chatbots. They can be used to generate text, speech, or images, and can be trained to understand different languages and dialects.\n", "\n", "One of the most significant applications of LLMs is in the field of natural language generation (NLG). LLMs can be used to generate text in a variety of languages, including English, French, and German. They can also be used to generate speech, such as in chatbots or voice assistants.\n", "\n", "LLMs have the potential to revolutionize the way we communicate and interact with each other. They can help us create more engaging and personalized content, and they can also help us understand each other better.\n" ] } ], "source": [ "from transformers import pipeline\n", "\n", "# Use our predefined prompt template\n", "prompt = \"\"\"<|user|>\n", "Tell me something about Large Language Models.\n", "<|assistant|>\n", "\"\"\"\n", "\n", "# Run our instruction-tuned model\n", "pipe = pipeline(task=\"text-generation\", model=dpo_model, tokenizer=tokenizer)\n", "print(pipe(prompt)[0][\"generated_text\"])" ] } ], "metadata": { "accelerator": "GPU", "colab": { "authorship_tag": "ABX9TyPaxPKtmt1gCzzuqYr6g2+g", "gpuType": "T4", "provenance": [], "toc_visible": true }, "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.14" } }, "nbformat": 4, "nbformat_minor": 4 }