Fine-tune an Instruct model over raw text data

Purpose
Getting a modern chatbot to maintain its capabilities on your own data remains a complex task. Context window sizes are increasing rapidly, with leading products like Gemini 1.5 Pro and Claude 3 making big leaps; Gemini 1.5 Pro now offers a 1 million token capacity. However, a company like The Guardian, where I currently work, has countless code repositories containing hundreds of millions of tokens' worth of data.
The recently announced Devin by Cognition Labs likely uses clever RAG techniques to complete its tasks, but relying on injecting all information into the context window can be problematic. The consensus in the community seems to be that GPT-4 128K can retain great performance for up to around 60K tokens, which isn't a lot. Even then, retaining that performance requires better and trickier prompting as the number of tokens grows. Because of these limitations, it seems likely that the most capable models in the near future will use a combination of good prompting, RAG and fine-tuning. For example, for a code assistant tool, the most recent code could be retrieved through a RAG pipeline. A fine-tuned model could then analyse and reason about this code more effectively than a non-fine-tuned model, pointing out any edge cases and risks it may have learned from elsewhere. Additionally, the fine-tuned model would adopt the organisation's coding conventions and best practices, allowing it to provide more insightful guidance to employees.
I found limited resources online about high-performing chatbots fine-tuned on smaller datasets. Instead, most research introduces models like BioMistral, which achieve success using large datasets of around 3 billion tokens, requiring significant budget and expertise.
This experiment seeks a lighter approach that navigates between the constraints of a 128K context window and the complexities of a model fine-tuned on billions of tokens, perhaps something in the realm of tens of millions of tokens. For a smaller-scale test, I'll fine-tune Mistral's 7B Instruct v0.2 model on The Guardian's manage-frontend repository (the dataset being 1.6 million tokens).
The goal of this article is to create a reproducible set of instructions for cost-effective model fine-tuning using easily accessible hardware. Emphasis was placed on ease of use, minimizing trial and error, and maximizing the use of raw text data over labeled conversational data. Hopefully, any software developer with zero experience in deep learning engineering can pick up the notebook and train their own model with ease.
I'll outline the data used, highlight the best hyperparameters and their results, then conclude with a technical explanation for their effectiveness.
Training
A100 40GB
I used an Nvidia A100 40GB on Colab for all training except one run, where I used an H100 80GB.
Unsloth
I used the Unsloth library for faster and more memory-efficient training. This blog post gives a good summary of how the Unsloth library works under the hood and shows benchmarks for training speed increases and memory savings.
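To make this concrete, here is a minimal sketch of loading Mistral 7B Instruct v0.2 with Unsloth. The specific 4-bit model name, sequence length and quantisation choice shown here are illustrative assumptions rather than confirmed settings from my runs.

```python
# A minimal sketch of loading Mistral 7B Instruct v0.2 with Unsloth.
# The 4-bit model name and max_seq_length are illustrative assumptions.
from unsloth import FastLanguageModel

model, tokenizer = FastLanguageModel.from_pretrained(
    model_name="unsloth/mistral-7b-instruct-v0.2-bnb-4bit",
    max_seq_length=8192,   # matches the 8192-token chunks used for the raw dataset
    load_in_4bit=True,     # quantised loading keeps memory usage well within an A100 40GB
)
```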
Differences in training approach compared to state-of-the-art fine-tuned models
Modern examples of fine-tuning to teach a model new domain-specific knowledge include BioMistral and xFinance. xFinance continues the pre-training of the Llama 7B base model, i.e. the non-instruct version, using LoRA. The model is first trained on over 216,626 documents, totalling 236 million tokens. It is then further fine-tuned on 25,000 samples of finance-based conversational data. Similar to standard chatbot training, this approach begins with training on raw text data, lacking instruction tokens or structured conversational elements, and then transitions to training exclusively on conversational data. BioMistral takes a similar approach, though interestingly it starts fine-tuning from the Mistral 7B Instruct v0.2 model.
My approach combines both the raw dataset and the annotated conversational dataset in a single training run, as this produced the best results.
TRL's SFTTrainer
I used the [SFTTrainer](https://huggingface.co/docs/trl/en/sft_trainer) from the [trl](https://huggingface.co/docs/trl/en/index) library. I saw it used in this Unsloth demo notebook with good results. It is a wrapper over the default HuggingFace trainer. I couldn't find much documentation on how the SFTTrainer extends it, and the code suggests minimal changes: it appears to prepare the dataset for training by setting the target labels to be identical to the `input_ids` (see these lines of code). Here's an example of a notebook doing the same thing with the default HuggingFace trainer. This just boils down to next-token prediction with cross-entropy loss using the default trainer provided by HuggingFace, nothing fancy. The only difference in training between the "raw text data" and the conversational data is the addition of the special instruction tokens "[INST]" and "[/INST]" that Mistral Instruct has been trained to recognise. Refer to the cell outputs in the notebook to see what the dataset looks like.
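For reference, here is a minimal sketch of the SFTTrainer wiring in the style of the Unsloth demo notebook, using a trl version from around the time of writing. The toy inline dataset and the training arguments are placeholders, not the settings used for the real runs.

```python
# A minimal SFTTrainer sketch; model and tokenizer are the objects returned by
# FastLanguageModel.from_pretrained in the earlier snippet. The toy dataset and
# training arguments below are placeholders, not this article's real settings.
from datasets import Dataset
from transformers import TrainingArguments
from trl import SFTTrainer

toy_dataset = Dataset.from_dict(
    {"text": ["[INST] What is manage-frontend? [/INST] It is The Guardian's ..."]}
)

trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    train_dataset=toy_dataset,     # the real runs mix raw text and conversational samples
    dataset_text_field="text",     # SFTTrainer tokenizes this column and sets labels = input_ids
    max_seq_length=8192,
    args=TrainingArguments(
        output_dir="outputs",
        per_device_train_batch_size=2,
        num_train_epochs=1,
        learning_rate=2e-5,
        logging_steps=1,
    ),
)
trainer.train()
```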
Creating the raw dataset
My raw dataset consists of the repo's wiki, a snapshot of the main branch from December, and the last 100 pull requests including comments and code changes. I chunked the data so that each sample was at most 8192 tokens.
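Something like the sketch below captures the chunking idea; whether my actual script used overlap or split on file boundaries isn't shown here, so the hypothetical `chunk_text` helper is only an illustration.

```python
# A rough sketch of chunking a long document into samples of at most 8192 tokens.
# Overlap and file-boundary handling are not specified here; this is illustrative only.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2")

def chunk_text(text: str, max_tokens: int = 8192) -> list[str]:
    # Tokenize once, slice into fixed-size windows, then decode each window back to text
    token_ids = tokenizer(text, add_special_tokens=False)["input_ids"]
    windows = [token_ids[i:i + max_tokens] for i in range(0, len(token_ids), max_tokens)]
    return [tokenizer.decode(window) for window in windows]
```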
Scraping the wiki
I just copied and pasted each page into a text file for this.
Scraping the codebase
I wrote a Python script that ran locally and wrote all files to a text file in the following format:
```
- File: productSwitchTypes.ts
Content:
export type ProductSwitchType =
  | 'to-recurring-contribution'
  | 'recurring-contribution-to-supporter-plus';
export interface PreviewResponse {
  amountPayableToday: number;
  supporterPlusPurchaseAmount: number;
  contributionRefundAmount: number;
  nextPaymentDate: string;
  checkChargeAmountBeforeUpdate: boolean;
}
- File: productTypes.ts
Content:
...
...
...
```
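A minimal sketch of that kind of script is shown below, assuming a local checkout and a hand-picked set of file extensions; the directory name and filters are my assumptions, not the original script's.

```python
# Sketch: walk a local checkout of the repo and write every matching file to a
# single text file in the "File: ... / Content: ..." format shown above.
# The directory name and extension filter are illustrative assumptions.
from pathlib import Path

REPO_DIR = Path("manage-frontend")                # local checkout of the repository
EXTENSIONS = {".ts", ".tsx", ".json", ".md"}      # assumed file types of interest

with open("codebase.txt", "w", encoding="utf-8") as out:
    for path in sorted(REPO_DIR.rglob("*")):
        if path.suffix in EXTENSIONS and "node_modules" not in path.parts:
            out.write(f"- File: {path.name}\n")
            out.write("Content:\n")
            out.write(path.read_text(encoding="utf-8", errors="ignore"))
            out.write("\n\n")
```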
Scraping PR data
The corresponding cell in the Colab notebook will produce an output like so for this PR:
```
PR #2989: Create devcontainer.json
URL: https://github.com/octocat/Hello-World/pull/2989
Description: None
Created at: 2024-02-26T11:39:03Z
Merged at: None
File: .devcontainer/devcontainer.json, Status: added
Changes: @@ -0,0 +1,5 @@
+{
+ "image": "mcr.microsoft.com/devcontainers/universal:2",
+ "features": {
+ }
+}
```
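One way to produce this output is with the PyGithub library, sketched below. The library choice, token handling and comment formatting are assumptions for illustration, not the notebook's exact implementation, and the repo name is the placeholder from the example above.

```python
# Sketch: fetch the most recent pull requests, their comments and file diffs via
# PyGithub, and write them out in the format shown above. Token handling and the
# placeholder repo name are assumptions.
from github import Github

gh = Github("YOUR_GITHUB_TOKEN")              # personal access token
repo = gh.get_repo("octocat/Hello-World")     # placeholder repo from the example output

with open("prs.txt", "w", encoding="utf-8") as out:
    pulls = repo.get_pulls(state="all", sort="created", direction="desc")
    for pr in pulls[:100]:
        out.write(f"PR #{pr.number}: {pr.title}\n")
        out.write(f"URL: {pr.html_url}\n")
        out.write(f"Description: {pr.body}\n")
        out.write(f"Created at: {pr.created_at}\n")
        out.write(f"Merged at: {pr.merged_at}\n")
        for comment in pr.get_issue_comments():
            out.write(f"Comment by {comment.user.login}: {comment.body}\n")
        for changed_file in pr.get_files():
            out.write(f"File: {changed_file.filename}, Status: {changed_file.status}\n")
            out.write(f"Changes: {changed_file.patch}\n")
        out.write("\n")
```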
Generating conversational data
Despite the title of this article, I did use a bit of labeled conversational data, but it was synthetically and easily generated. This doesn't match the quality of carefully curated datasets, but synthetic data is becoming common (I read somewhere it accounted for around 50% of the datasets on HuggingFace). While it won't lead to amazing chatbot performance, the intuition is that it may help mitigate any catastrophic forgetting and performance dips, and it's also an easy way of augmenting our dataset. I used three methods of generating the synthetic data:
- For each Wiki page, I used the GPT-4 Turbo API to generate a few QA samples based on the provided text (see the sketch after this list). This resulted in roughly 300 QA pairs.
- For each Wiki page, I created a specific instruction or question. For instance, on the 'Fastly & Caching' page, the instruction might be 'Walk me through how Fastly is used in `manage-frontend`.' The response is then simply the contents of that Wiki page.
- Similar to the previous step, for each file in the codebase, I created a question for it. E.g.: "What does the `package.json` file look like in the `manage-frontend` repo?" I then prefix each code file with the date of the codebase snapshot used for training, i.e.: "As of December 2023, the `package.json` file looks like so: "
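Here is a minimal sketch of the first method using the OpenAI Python client. The prompt wording, the hypothetical `generate_qa_pairs` helper and the `gpt-4-turbo-preview` model name are assumptions for the sketch rather than the exact ones used.

```python
# Sketch of method 1: ask GPT-4 Turbo for a few QA pairs per wiki page.
# The prompt wording and model name are illustrative assumptions.
from openai import OpenAI

client = OpenAI()  # reads OPENAI_API_KEY from the environment

def generate_qa_pairs(wiki_page: str, n: int = 3) -> str:
    prompt = (
        f"Generate {n} question-and-answer pairs about the following internal "
        f'wiki page. Respond with a JSON list of objects with "question" and '
        f'"answer" keys.\n\n{wiki_page}'
    )
    response = client.chat.completions.create(
        model="gpt-4-turbo-preview",
        messages=[{"role": "user", "content": prompt}],
    )
    return response.choices[0].message.content
```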
The QA data was exported to a JSONL file. The following format is recommended, as many tokenizers have a function called `apply_chat_template` which takes in the list inside the `messages` property of each line. Here is an example of the format:
{"messages":[{"role":"user","content":"What is the capital of France?"},{"role":"assistant","content":"The capital of France is Paris."}]}
{"messages":[{"role":"user","content":"What is the capital of England?"},{"role":"assistant","content":"The capital of England is London."}]}
I'm using 10% of this conversational data for the validation dataset.
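Putting that together, the sketch below loads the JSONL, applies the chat template and holds out 10% for validation. The file name, column handling and seed are placeholders rather than my exact preprocessing code.

```python
# Sketch: turn the JSONL conversations into training text with the tokenizer's
# chat template, then hold out 10% for validation. File name and seed are placeholders.
from datasets import load_dataset
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2")
dataset = load_dataset("json", data_files="qa_pairs.jsonl", split="train")

def to_text(example):
    # apply_chat_template wraps the turns in Mistral's [INST] ... [/INST] tokens
    return {"text": tokenizer.apply_chat_template(example["messages"], tokenize=False)}

dataset = dataset.map(to_text, remove_columns=["messages"])
split = dataset.train_test_split(test_size=0.1, seed=42)
train_dataset, eval_dataset = split["train"], split["test"]
```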
Training the model
Hyperparameter sweeps
I used a manual search. My intuition was that the LoRA rank, batch size and learning rate would affect model performance the most. I therefore started with a wide range of these hyperparameters and then iteratively narrowed down the search space based on the performance of the initial sweeps. A learning rate of 2e-5 appeared optimal, which seems to be standard for fine-tuning Mistral. BioMistral continued fine-tuning the instruct model v0.2 with zero warmup, a cosine scheduler and a learning rate of 2e-5. As I upped the rank and lowered the batch size, the eval loss improved. However, it's important to note that just lowering the eval batch size can naturally improve the validation loss due to fewer samples being validated at once, so it's always good to check your model manually after it's done training!
The sweeps in the image below all use a rank of either 512 or 768, with varying alphas of either 1x, 1.5x or 2x the rank. The batch sizes are either 1, 2 or 4. You can see the final hyperparameters I used in the notebook.
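For reference, attaching LoRA adapters at these kinds of rank and alpha values with Unsloth looks roughly like the sketch below; the target modules, dropout and seed are the ones commonly seen in Unsloth's Mistral notebooks and are assumptions here, not my exact final configuration.

```python
# Sketch: attach LoRA adapters with Unsloth at sweep-style rank/alpha values.
# target_modules, dropout and random_state are assumptions, not the exact
# final configuration used for the runs in this article.
from unsloth import FastLanguageModel

model = FastLanguageModel.get_peft_model(
    model,            # the base model loaded earlier
    r=512,            # the sweeps used ranks of 512 or 768
    lora_alpha=1024,  # alphas of 1x, 1.5x or 2x the rank were tried; 2x shown here
    lora_dropout=0,
    bias="none",
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
                    "gate_proj", "up_proj", "down_proj"],
    use_gradient_checkpointing=True,
    random_state=42,
)
```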
Once I found the optimal hyperparameters, I re-ran the training to include all data to make the most of the little data I had, as is common practice. These runs are noted by the `All-Data` tag at the end of the sweep name.
Each sweep took under 3 hours and cost only a few pounds in Colab. In total, the sweeps probably cost me somewhere between £40 and £50.
Note: I accidentally included my Q&A validation data in my raw text data (I forgot I copied and pasted it into one of my text files).