Fine-Tuning a Language Model for Legal Question Answering
March 2024
The legal domain is filled with complex terminology, nuanced interpretations, and the need for precise answers, which makes it a perfect challenge for natural language processing (NLP). However, generic language models often struggle to grasp the subtleties of legal text. To satisfy my curiosity, in this blog post I dive into the process of fine-tuning a large language model (LLM) specifically for the task of legal question answering. You can find the raw dataset and the complete notebook in the legal-gemma repository.
The Toolkit
- Google Colab: The experimentation playground for GPU resources.
- LegalBench: A benchmark of legal reasoning tasks, used here as the source of questions and answers.
- HuggingFace: Its libraries streamline working with LLMs.
- BitsAndBytes: For memory-efficient quantization.
- PEFT and LoRA: Techniques to fine-tune the model by training only a small set of added parameters.
The Dataset: LegalBench
Before diving into code, let's understand the structure of LegalBench. According to the owners of the repository, LegalBench is "a benchmark consisting of different legal reasoning tasks. Each task has an associated dataset, consisting of input-output pairs."
Step-by-Step Guide
Project Setup
First, head to Google Colab and start a new notebook. Next, make sure the following packages are installed in the notebook:
!pip install --upgrade pandas datasets transformers bitsandbytes peft trl huggingface_hub
Load and Transform the Data
I do not use the LegalBench dataset in its originally provided format; instead, I transform it from a multi-task, multi-dataset form into a question-answering dataset where the answers can be Yes/No, multiple-choice, or open-ended. You can find the [long and ugly] transformation code here. Basically, it takes one row from the original dataset and prepares a tuple of (index, context, question, answer).
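Purely to illustrate the shape of its output, here is a hypothetical sketch; the field names on the raw rows are assumptions, and the real transform_data in the repository handles many different task formats:
# Hypothetical sketch of transform_data's contract; the real implementation
# in the Data module handles the many different LegalBench task formats.
def transform_data_sketch(row):
    index = row.Index                        # row index from df.itertuples()
    context = getattr(row, "text", "")       # assumed: the clause/passage the task refers to
    question = getattr(row, "question", "")  # assumed: the task instruction phrased as a question
    answer = getattr(row, "answer", "")      # assumed: Yes/No, a choice label, or free text
    return index, context, question, answer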
So, in the implementation, I start by reading the raw dataset and transforming it with the function I built.
from Data import transform_data
import pandas as pd
df = pd.read_json('raw_data_sample.json')
data = [transform_data(row) for row in df.itertuples()]
indexes, contexts, questions, answers = zip(*data)
assert len(contexts) == len(questions) == len(answers) and len(contexts) > 0
input_texts = [{"text": f"Answer the Question based on the given Context.\nContext: {c}\nQuestion: {q}\nAnswer: {a}"} for c, q, a in zip(contexts, questions, answers)]
Next, I use scikit-learn to split the data into two sets: train (95%) and test (5%).
from sklearn.model_selection import train_test_split
from datasets import Dataset
train_texts, test_texts = train_test_split(input_texts, test_size=0.05, random_state=7)
train_dataset = Dataset.from_list(train_texts)
test_dataset = Dataset.from_list(test_texts)
Prepare for Loading the Language Model
Then, I import the necessary libraries for fine-tuning the gemma-2b model, using HuggingFace for this purpose.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments
from transformers import BitsAndBytesConfig, set_seed
from peft import LoraConfig, PeftModel, prepare_model_for_kbit_training, get_peft_model
from trl import SFTTrainer
set_seed(7)
model_name = "google/gemma-2b-it"
Since loading gemma-2b requires accepting particular terms and conditions, you need to accept them on the model's HuggingFace page and then log in to the HuggingFace Hub with an access token.
from huggingface_hub import login
login(token='YOUR_TOKEN')
Fine-tuning the Model
The next step is loading the model and the corresponding tokenizer.
Fitting the LLM on my GPU!
Since fine-tuning gemma-2b required more memory than I had available (15 GB), I needed to take advantage of quantization to reduce memory usage. BitsAndBytes quantization stores the model weights in lower-precision data types, which makes it possible to load larger models. My configuration loads the model's linear layers in 4-bit precision (the NF4 data type) and uses bfloat16 for computation.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16
)
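As a rough back-of-the-envelope check (assuming gemma-2b has on the order of 2.5B parameters), the weights alone shrink considerably, and that is before counting the activations, gradients, and optimizer state that training adds on top:
# Rough estimate of memory for the model weights alone (assumes ~2.5B parameters).
params = 2.5e9
print(f"bf16 weights: ~{params * 2 / 1e9:.1f} GB")    # ~5 GB at 2 bytes/param
print(f"4-bit weights: ~{params * 0.5 / 1e9:.2f} GB") # ~1.25 GB at 0.5 bytes/param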
Loading the Model
I load the model with the BitsAndBytes configuration from earlier, along with its corresponding tokenizer, using HuggingFace. Then, I pass the model to prepare_model_for_kbit_training, which readies the quantized model for fine-tuning.
tokenizer = AutoTokenizer.from_pretrained(model_name, truncation=True, truncation_side = "left")
tokenizer.pad_token = tokenizer.eos_token
tokenizer.pad_token_id = tokenizer.eos_token_id
model = AutoModelForCausalLM.from_pretrained(model_name, quantization_config=bnb_config)
model = prepare_model_for_kbit_training(model)
The next configuration I found necessary for successfully fine-tuning gemma-2b within my time and hardware constraints was reducing the number of trainable parameters. Recent work has shown that training even a very small percentage of the pretrained parameters can be beneficial. That is where LoraConfig comes in.
LoRA (Low-Rank Adaptation) freezes all the parameters of the pretrained model and injects a pair of small trainable matrices into each targeted layer of the Transformer; only those matrices are updated during fine-tuning.
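To make the savings concrete: for a frozen weight matrix of shape d×k, LoRA trains a pair of matrices B (d×r) and A (r×k) whose product is added to the frozen weights. A toy calculation (the shapes are illustrative, not gemma-2b's actual layer sizes):
# Toy illustration of LoRA's parameter savings; shapes are illustrative only.
d, k, r = 2048, 2048, 8
full_params = d * k            # updating the full weight matrix
lora_params = r * (d + k)      # updating only B (d x r) and A (r x k)
print(full_params, lora_params, f"{lora_params / full_params:.2%}")
# 4194304 32768 0.78% -- well under 1% of the original parameters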
peft_config = LoraConfig(
    r=8,
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
    target_modules=['q_proj', 'up_proj', 'v_proj', 'k_proj', 'o_proj', 'down_proj', 'gate_proj']
)
model = get_peft_model(model, peft_config)
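To sanity-check that only a small fraction of parameters is trainable after wrapping the model, PEFT can report it directly:
# Print the number and percentage of trainable parameters after applying LoRA.
model.print_trainable_parameters()
# trainable params: ... || all params: ... || trainable%: ...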
Then, I set the training arguments. Due to time constraints, I experimented with only three different learning rates to choose the best one and kept the rest of the hyperparameters fixed (a rough sketch of that comparison appears after the trainer setup below). With a per-device batch size of 1 and 4 gradient-accumulation steps, the effective batch size is 4. Since the fine-tuning data is relatively small, I decided to do only a small number of training steps.
training_arguments = TrainingArguments(
    output_dir="./results",
    do_eval=True,
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
    gradient_accumulation_steps=4,
    optim="paged_adamw_8bit",
    save_steps=100,
    logging_steps=100,
    learning_rate=1e-5,
    eval_steps=100,
    num_train_epochs=1,
    warmup_ratio=0.02,
    lr_scheduler_type="linear",
    load_best_model_at_end=True,
    save_strategy="steps",
    evaluation_strategy="steps"
)
Next, I set up the fine-tuning trainer so that it uses the training_arguments defined above.
trainer = SFTTrainer(
    model=model,
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
    dataset_text_field="text",
    tokenizer=tokenizer,
    args=training_arguments,
    peft_config=peft_config,
    max_seq_length=512
)
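As promised above, here is a rough sketch of how the learning-rate comparison could be run. The candidate values, the short max_steps budget, and reloading a fresh model per trial are my assumptions, not the exact code I ran:
# Hypothetical sketch of the learning-rate comparison: reload a fresh quantized
# model for each candidate, run a short training pass, and keep the rate with
# the lowest evaluation loss.
candidate_lrs = [1e-5, 5e-5, 1e-4]  # assumed candidate values
best_lr, best_loss = None, float("inf")
for lr in candidate_lrs:
    trial_model = prepare_model_for_kbit_training(
        AutoModelForCausalLM.from_pretrained(model_name, quantization_config=bnb_config)
    )
    trial_args = TrainingArguments(
        output_dir=f"./lr_search/{lr}",
        per_device_train_batch_size=1,
        gradient_accumulation_steps=4,
        optim="paged_adamw_8bit",
        learning_rate=lr,
        max_steps=50,            # a short run is enough for a rough comparison
        logging_steps=10,
        report_to="none",
    )
    trial_trainer = SFTTrainer(
        model=trial_model,
        train_dataset=train_dataset,
        eval_dataset=test_dataset,
        dataset_text_field="text",
        tokenizer=tokenizer,
        args=trial_args,
        peft_config=peft_config,
        max_seq_length=512,
    )
    trial_trainer.train()
    eval_loss = trial_trainer.evaluate()["eval_loss"]
    if eval_loss < best_loss:
        best_lr, best_loss = lr, eval_loss
    # Free GPU memory before the next trial.
    del trial_model, trial_trainer
    torch.cuda.empty_cache()
print(f"Best learning rate: {best_lr}")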
Measuring the Perplexity before Fine-tuning
Before fine-tuning, I measure the perplexity of the original (non-fine-tuned) model on the legal test set. Perplexity is the exponential of the model's average cross-entropy loss, so lower is better.
import math
eval_results = trainer.evaluate()
print(f"Perplexity: {math.exp(eval_results['eval_loss']):.2f}")
Output:
Perplexity: 67.86
Let's record the model's answers to 100 randomly chosen samples from the test set. I will use them later to manually evaluate the models' responses.
from random import sample
subset_test_dataset = list(test_dataset.select(sample(range(len(test_dataset)), k=100)))
with open('ground_truth_subset_test.txt', 'w') as f:
    for i, entry in enumerate(subset_test_dataset):
        f.write(str(i) + "\n")
        f.write(entry['text'])
        f.write('\n====================\n')
with open('original_subset_test.txt', 'w') as f:
    for i, entry in enumerate(subset_test_dataset):
        # Keep everything up to the answer as the prompt.
        text = entry["text"][:entry["text"].find('\nAnswer: ')] + '\nAnswer: '
        # Move the inputs to the model's device before generating.
        inputs = tokenizer(text, return_tensors="pt", add_special_tokens=True).to(model.device)
        outputs = model.generate(**inputs, max_new_tokens=20, pad_token_id=tokenizer.eos_token_id)
        result = tokenizer.decode(outputs[0], skip_special_tokens=True)
        f.write(str(i) + "\n")
        f.write(result)
        f.write('\n====================\n')
Fine-tuning
Now, we can finally start the fine-tuning using the configurations and hyperparameters set earlier.
torch.cuda.empty_cache()
model.config.use_cache = False
trainer.train()
model.config.use_cache = True
Save the Model
trainer.model.save_pretrained('legal_'+model_name)
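If you want to reuse the adapter later (for example, in a fresh Colab session), it can be loaded back on top of the quantized base model; a minimal sketch using the objects defined above:
# Reload the quantized base model and attach the saved LoRA adapter weights.
base = AutoModelForCausalLM.from_pretrained(model_name, quantization_config=bnb_config)
fine_tuned_model = PeftModel.from_pretrained(base, 'legal_' + model_name)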
Evaluating the Perplexity of the New Model
Let's evaluate the perplexity of the fine-tuned model on the same test set.
new_eval_results = trainer.evaluate()
print(f"Perplexity: {math.exp(new_eval_results['eval_loss']):.2f}")
assert new_eval_results['eval_loss'] < eval_results['eval_loss']
Output:
Perplexity: 4.56
As expected (and desired!), the perplexity of the fine-tuned model is much lower than that of the original model.
Manual Evaluation
Let's record the fine-tuned model's answers to the same subset of the test set.
with open('fine-tuned_subset_test.txt', 'w') as f:
    for i, entry in enumerate(subset_test_dataset):
        text = entry["text"][:entry["text"].find('\nAnswer: ')] + '\nAnswer: '
        inputs = tokenizer(text, return_tensors="pt", add_special_tokens=True).to(model.device)
        outputs = model.generate(**inputs, max_new_tokens=20, pad_token_id=tokenizer.eos_token_id)
        result = tokenizer.decode(outputs[0], skip_special_tokens=True)
        f.write(str(i) + "\n")
        f.write(result)
        f.write('\n====================\n')
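Before reading through the three files by hand, a rough automatic pre-check can hint at how the models compare. This sketch is not part of my evaluation; it only tests whether the generated text begins with the ground-truth answer:
# Rough pre-check: count how often the generated answer starts with the ground truth.
def answer_part(text):
    return text.split('\nAnswer: ', 1)[1].strip()

matches = 0
for entry in subset_test_dataset:
    prompt = entry["text"][:entry["text"].find('\nAnswer: ')] + '\nAnswer: '
    gold = answer_part(entry["text"])
    inputs = tokenizer(prompt, return_tensors="pt", add_special_tokens=True).to(model.device)
    outputs = model.generate(**inputs, max_new_tokens=20, pad_token_id=tokenizer.eos_token_id)
    generated = answer_part(tokenizer.decode(outputs[0], skip_special_tokens=True))
    matches += generated.lower().startswith(gold.lower())
print(f"Rough exact-match: {matches}/{len(subset_test_dataset)}")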
I used the ground truth, the answers generated by the original model, and the answers generated by the fine-tuned model to perform a manual evaluation on this subset of the test set. The original model answers the given question correctly in 36% of the test cases, whereas the fine-tuned model correctly answers 62% of them. You can find the models' behavior on individual samples in a separate file. What surprises me is that this significant improvement is the result of very little effort.
Conclusion and Considerations
As we saw in this simple exercise, fine-tuning an LLM can significantly boost its performance on domain-specific tasks. Keep in mind that:
- The results are only as good as the dataset used.
- Legal language changes over time, so the model will need to be updated periodically.
- We should be mindful of the potential biases and fairness issues in the model's output.