Fine-Tuning a Language Model for Legal Question Answering March, 2024

The legal domain is filled with complex terminology, nuanced interpretations, and the need for precise answers, which makes it a challenging target for natural language processing (NLP). Generic language models often struggle to grasp the subtleties of legal text. To satisfy my curiosity, in this blog post I dive into the process of fine-tuning a large language model (LLM) specifically for the task of legal question answering. You can find the raw dataset and the complete notebook in the legal-gemma repository.

The Toolkit

  • Google Colab: The experimentation playground for GPU resources.
  • LegalBench: A curated dataset of legal questions and answers.
  • HuggingFace: Its libraries streamline working with LLMs.
  • BitsAndBytes: For memory-efficient quantization.
  • PEFT and LoRA: Techniques to selectively fine-tune parts of the model.

The Dataset: LegalBench

Before diving into code, let's understand the structure of LegalBench. According to the owners of the repository, LegalBench is "a benchmark consisting of different legal reasoning tasks. Each task has an associated dataset, consisting of input-output pairs."

Step-by-Step Guide

Project Setup

First, head to Google Colab and start a new notebook. Then, make sure the following packages are installed in the notebook environment:


!pip install --upgrade pandas datasets transformers bitsandbytes peft trl huggingface_hub

Load and Transform the Data

I do not use the LegalBench dataset in its originally provided format; instead, I transform it from a multi-task, multi-dataset form into a question-answering dataset where the answers can be Yes/No, multiple-choice, or open-ended. You can find the [long and ugly] transformation code here. Essentially, it takes one row from the original dataset and prepares a context, a question, and an answer depending on the task type. I moved the data transformation code into a separate Python file to keep the notebook readable.
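A minimal sketch of what transform_data might look like is shown below (the column names and task handling are illustrative assumptions; the real code in the repository handles many more task types):

  # Hypothetical sketch of the helper in Data.py -- column names are assumptions.
  def transform_data(row):
      """Turn one LegalBench row into (index, context, question, answer)."""
      context = getattr(row, 'text', '')
      if getattr(row, 'task_type', '') == 'yes_no':
          answer = 'Yes' if row.label == 1 else 'No'
      else:  # multiple-choice and open-ended tasks handled similarly
          answer = str(row.answer)
      return row.Index, context, row.question, answer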

So, in the implementation, I start by reading the provided dataset and transforming each row with this helper.


  from Data import transform_data
  import pandas as pd

  df = pd.read_json('raw_data_sample.json')

  data = [transform_data(row) for row in df.itertuples()]
  indexes, contexts, questions, answers = zip(*data)

  assert len(contexts) == len(questions) == len(answers) and len(contexts) > 0
  input_texts = [{"text": f"Answer the Question based on the given Context.\nContext: {c}\nQuestion: {q}\nAnswer: {a}"} for c, q, a in zip(contexts, questions, answers)]
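
For a Yes/No task, one resulting entry might look like this (the context and question below are made up for illustration):

  {"text": "Answer the Question based on the given Context.\n"
           "Context: The tenant shall not sublet the premises without written consent.\n"
           "Question: Does the clause restrict subletting?\n"
           "Answer: Yes"}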

Next, I use scikit-learn to split the data into two sets: train (95%) and test (5%).


  from sklearn.model_selection import train_test_split
  from datasets import Dataset

  train_texts, test_texts = train_test_split(input_texts, test_size=0.05, random_state=7)

  train_dataset = Dataset.from_list(train_texts)
  test_dataset = Dataset.from_list(test_texts)

Prepare for Loading the Language Model

Then, I import the libraries needed for fine-tuning the gemma-2b model. I use the HuggingFace ecosystem for this purpose.


  import torch

  from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments
  from transformers import BitsAndBytesConfig, set_seed

  from peft import LoraConfig, PeftModel, prepare_model_for_kbit_training, get_peft_model

  from trl import SFTTrainer

  set_seed(7)
  model_name = "google/gemma-2b-it"

Since gemma-2b is a gated model, you first need to accept its terms and conditions on the HuggingFace Hub (while logged into your HuggingFace account), and then authenticate in the notebook with an access token.


  from huggingface_hub import login
  login(token='YOUR_TOKEN')
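
If you prefer not to hard-code the token in the notebook, huggingface_hub also provides an interactive prompt:

  from huggingface_hub import notebook_login
  notebook_login()  # prompts for the access token instead of storing it in the code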

Fine-tuning the Model

The next step is loading the model and the corresponding tokenizer.

Fitting the LLM on my GPU!

Since fine-tuning gemma-2b required more memory than I had available (15 GB), I needed to take advantage of quantization to reduce memory usage. In short, BitsAndBytes quantization stores the model weights in lower-precision data types, which makes it possible to load larger models on the same hardware.

My configuration loads the model's linear layers in 4-bit precision using the NF4 (4-bit NormalFloat) data type, with double quantization and bfloat16 for compute.


  bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16
  )
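
A rough back-of-envelope calculation shows why this matters (assuming about 2.5 billion parameters for gemma-2b; the exact count differs slightly):

  params = 2.5e9
  print(f"bf16 weights : ~{params * 2 / 1e9:.2f} GB")    # 2 bytes per parameter
  print(f"4-bit weights: ~{params * 0.5 / 1e9:.2f} GB")  # 0.5 bytes per parameter
  # roughly 5 GB vs 1.25 GB, before activations, optimizer state, and quantization overhead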

Loading the Model

I load the model and its corresponding tokenizer using HuggingFace, passing the BitsAndBytes configuration from above so that the model is loaded in 4-bit. Then, I run the quantized model through prepare_model_for_kbit_training, which prepares it for training (for example, by casting layer norms to full precision and enabling gradient checkpointing).


  tokenizer = AutoTokenizer.from_pretrained(model_name, truncation=True, truncation_side="left")
  tokenizer.pad_token = tokenizer.eos_token
  tokenizer.pad_token_id = tokenizer.eos_token_id
  model = AutoModelForCausalLM.from_pretrained(model_name, quantization_config=bnb_config)
  model = prepare_model_for_kbit_training(model)

The next configuration I found necessary for successfully fine-tuning gemma-2b within my time and hardware constraints was reducing the number of trainable parameters. Recent work has shown that training even a very small percentage of a pretrained model's parameters can be enough. That is where LoraConfig comes in. The LoRA technique freezes all of the pretrained model's weights and injects a pair of small, trainable low-rank matrices into each targeted layer of the Transformer; only these matrices are updated during fine-tuning.
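
To see why this shrinks the number of trainable parameters so dramatically, consider a single weight matrix (the dimensions below are illustrative, not Gemma's exact sizes):

  d, k, r = 2048, 2048, 8
  full = d * k           # parameters if we fine-tuned the matrix directly
  lora = d * r + r * k   # parameters in the low-rank pair B (d x r) and A (r x k)
  print(full, lora, f"{lora / full:.2%}")  # 4194304 32768 0.78%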


  peft_config = LoraConfig(
    r=8,
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
    target_modules=['q_proj', 'up_proj', 'v_proj', 'k_proj', 'o_proj', 'down_proj', 'gate_proj']
  )
  model = get_peft_model(model, peft_config)
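
PEFT can report how few parameters remain trainable after wrapping the model:

  model.print_trainable_parameters()
  # prints something like: trainable params: ... || all params: ... || trainable%: ...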

Then, I set the training arguments. Due to time constraints, I experimented with only three learning rates to pick the best one and kept the rest of the hyperparameters fixed. Since the fine-tuning data is relatively small, I limited training to a small number of steps (a single epoch).


  training_arguments = TrainingArguments(
    output_dir="./results",
    do_eval=True,
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
    gradient_accumulation_steps=4,
    optim="paged_adamw_8bit",
    save_steps=100,
    logging_steps=100,
    learning_rate=1e-5,
    eval_steps=100,
    num_train_epochs=1,
    warmup_ratio=0.02,
    lr_scheduler_type="linear",
    load_best_model_at_end=True,
    save_strategy="steps",
    evaluation_strategy="steps"
  )
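
Note that with per_device_train_batch_size=1 and gradient_accumulation_steps=4, gradients are accumulated over four examples before each optimizer step, so the effective batch size is 4 while peak memory stays at the single-example level.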

Next, I set up the fine-tuning trainer so that it uses the training_arguments defined above.


  trainer = SFTTrainer(
    model=model,
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
    dataset_text_field="text",
    tokenizer=tokenizer,
    args=training_arguments,
    peft_config=peft_config,
    max_seq_length=512
  )
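
Also note that max_seq_length=512 together with truncation_side="left" (set on the tokenizer earlier) means overly long examples are trimmed from the start of the context, so the question and the answer at the end of the prompt are preserved.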

Measuring the Perplexity before Fine-tuning

Before fine-tuning, I measure the perplexity of the original (non-fine-tuned) model on the legal test set. Perplexity is simply the exponential of the average cross-entropy loss, so lower values are better.


  import math

  eval_results = trainer.evaluate()
  print(f"Perplexity: {math.exp(eval_results['eval_loss']):.2f}")

  Perplexity: 67.86

Let's also save the model's answers to 100 randomly chosen samples from the test set. I will use them later to manually evaluate the models' responses.


  from random import sample

  # sample without replacement so the 100 test examples are all distinct
  subset_test_dataset = list(test_dataset.select(sample(range(len(test_dataset)), k=100)))

  with open('ground_truth_subset_test.txt', 'w') as f:
    for i, entry in enumerate(subset_test_dataset):
      f.write(str(i) + "\n")
      f.write(entry['text'])
      f.write('\n====================\n')

  with open('original_subset_test.txt', 'w') as f:
    for i, entry in enumerate(subset_test_dataset):
      text = entry["text"][:entry["text"].find('\nAnswer: ')] + '\nAnswer: '
      inputs = tokenizer(text, return_tensors="pt", add_special_tokens=True).to(model.device)
      outputs = model.generate(**inputs, max_new_tokens=20, pad_token_id=tokenizer.eos_token_id)
      result = tokenizer.decode(outputs[0], skip_special_tokens=True)
      f.write(str(i) + "\n")
      f.write(result)
      f.write('\n====================\n')

Fine-tuning

Now, we can finally start the fine-tuning using the configurations and hyperparameters set earlier.


  torch.cuda.empty_cache()
  model.config.use_cache = False
  trainer.train()
  model.config.use_cache = True

Save the Model


  trainer.model.save_pretrained('legal_'+model_name)
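
Because trainer.model is a PEFT model, save_pretrained writes only the LoRA adapter weights, not the full model. To reuse the fine-tuned model later, the adapter is loaded on top of the base model, roughly like this:

  base_model = AutoModelForCausalLM.from_pretrained(model_name, quantization_config=bnb_config)
  fine_tuned = PeftModel.from_pretrained(base_model, 'legal_' + model_name)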

Evaluating the Perplexity of the New Model

Let's evaluate the perplexity of the fine-tuned model.


  new_eval_results = trainer.evaluate()
  print(f"Perplexity: {math.exp(new_eval_results['eval_loss']):.2f}")
  assert new_eval_results['eval_loss'] < eval_results['eval_loss']
Output:

  Perplexity: 4.56

As expected (and desired!), the perplexity of the fine-tuned model is lower than that of the original model.

Manual Evaluation

Let's record the model's answers to a subset of the test set.


  with open('fine-tuned_subset_test.txt', 'w') as f:
    for i, entry in enumerate(subset_test_dataset):
      text = entry["text"][:entry["text"].find('\nAnswer: ')] + '\nAnswer: '
      inputs = tokenizer(text, return_tensors="pt", add_special_tokens=True).to(model.device)
      outputs = model.generate(**inputs, max_new_tokens=20, pad_token_id=tokenizer.eos_token_id)
      result = tokenizer.decode(outputs[0], skip_special_tokens=True)
      f.write(str(i) + "\n")
      f.write(result)
      f.write('\n====================\n')

I used the ground truth, the answers generated by the original model, and the answers generated by the fine-tuned model to perform a manual evaluation on a subset of the test set. The original model answers the given question correctly in 36% of the test cases, whereas the fine-tuned model correctly answers 62% of them. You can find the models' behavior for individual samples in a separate file. What surprises me is that this significant improvement came from relatively little effort.

Conclusion and Considerations

As we saw in this simple exercise, fine-tuning an LLM can significantly boost its performance on domain-specific tasks. Keep in mind that:
  1. The results are only as good as the dataset used.
  2. Legal language changes over time, so the model will need to be updated periodically.
  3. We should be mindful of the potential biases and fairness issues in the model's output.