Train Your Own ChatGPT-like LLM with FlanT5 and Replicate

Written by shanglun
Tech Story Tags: ai | machine-learning | llm | chatgpt | flant5 | replicate | how-to-train-your-own-llm | hackernoon-top-story

TL;DR: We train an open-source LLM to distinguish between William Shakespeare and Anton Chekhov. A proof of concept for natural language classifiers based on a small, cost-efficient but powerful competitor to ChatGPT.

With the growth of LLMs like ChatGPT, there has been a rush by firms to commercialize language-based deep learning applications. Companies like Duolingo and Blinkist are building educational chat applications, firms like Cocounsel are building document analysis models, and some, like MedGPT, are even building specialist models that can do things like medical diagnosis. In a previous article, I wrote about how someone can use ChatGPT and prompt engineering to build a powerful document analyzer.

To support more powerful and domain-specific LLM applications, technology providers have made many cloud solutions available. OpenAI, the company behind ChatGPT, for example, offers a simple but powerful fine-tuning API that lets users build their own language models based on GPT-3 technology. Google, not to be outdone, made their text-bison model, widely considered a capable competitor to GPT-3 and GPT-3.5, available for fine-tuning through the Google Cloud platform. In a previous article, I wrote about how to use the fine-tuning API to create a domain expert LLM.

As powerful as these services can be, a company considering a serious investment in LLM technology will want to learn to train their own models from open-source technologies. Compared to using these vendor-provided endpoints, training your own model gives the following advantages:

  • You gain the flexibility of choosing and changing your deployment infrastructure. This can lead to cost savings, closer integration, and perhaps most importantly in medical and financial applications, more privacy.
  • You get more control over the underlying technology, giving you a choice of open-source models to use. Different open source models are built with different use cases in mind, and you can choose the best tool for the job.
  • Your applications become more future-proof. By using open source technology, you can set your own pace of development. You can use the state-of-the-art technology, and you won’t have to worry about vendor deprecations and service outages.

In this article, we will take a popular and capable open-source LLM, train it on our own data, similar to what we did in a previous article, and validate the results. While the example we’re tackling is non-commercial and based on public information, the techniques can be easily applied to commercial endeavors. We will delve into specific suggestions on what commercial applications can be built using this technique in the “Expert LLM Model” section, where we define the problem we’ll be solving in this article.

Underlying Technologies

Flan-T5

For today’s experiment, we will be relying on Flan-T5 Large, which is a large language model released by Google. While this is not the technology that underlies Bard, this model is widely considered to be competitive with GPT-based technologies. What’s impressive about the Flan-T5 models, however, is that they achieve satisfactory results using far fewer parameters than GPT-based models. Even the XL version of the model, for example, only has 3 billion parameters, compared to GPT-3, which has 175 billion.

As a result of this compactness, it is relatively cheap to train and store these models on cloud computing assets. Additionally, the Flan-T5 family of models is released under the Apache 2.0 license, which allows for commercial use, reducing the potential license headaches that accompany some of the other open-source LLMs. The original LLaMA from Facebook, for example, was released only for research and non-commercial purposes.

To write this article, I experimented with a few different classes of tasks to test the effectiveness of the technology. Generally, Flan-T5, especially the XL variant, seems to have great natural language understanding capabilities similar to some of the GPT models on the market. However, the model falls short somewhat when drawing abstract connections, and has some trouble generating long outputs. Therefore, one should take care to select the right model for the right task.

Replicate

Replicate is a Platform-as-a-Service company that allows people to rent GPUs for training and running large AI models at an affordable price. Their suite of AI model management tools allows users to focus on working with data instead of managing server resources.

In order to write this article, I tried several AI training PaaS offerings including AWS SageMaker, Google Colab, and PaperSpace Gradient. Replicate was by far the easiest platform to get started with, and offered very competitive pricing relative to the other services mentioned.

Python

Python is the lingua franca of data engineering. The extensive ecosystem allows programmers to quickly ingest, analyze, and process data. Most major AI training platforms have first-class support for Python, which makes our job much easier. Because of Replicate’s excellent integration, we will be writing all of our code today in Python.

Expert LLM Model

Playwright Classifier

Because the Flan-T5 family of models is much better at understanding text than generating text, we want to choose a task that is heavy on input but light on output. Natural language classifiers are a perfect use-case for this type of scenario, so today we will be building a playwright identifier. Specifically, we will be giving the model passages from either William Shakespeare or Anton Chekhov, and see if we can teach the model to identify the playwright based on the writing style and word choice.

Of course, because this is a public tutorial, we are intentionally choosing a task with public and easily accessible data. However, this can easily be adapted to a commercial context. Here are some examples where natural language classifiers may be useful:

  • Sorting customer reviews and complaints into various categories such as shipping problems, product quality, customer service, etc.
  • Performing sentiment analysis on a sales call transcript to see if the prospect had a change in their mood during the call.
  • Analyzing a large number of earnings calls transcripts to determine if CEOs are generally bullish or bearish.

Building the Training Data

To create the training data, we can download some Anton Chekhov and William Shakespeare plays from Project Gutenberg. To set up the data ingestion, we can run the following Python script.

import requests
import replicate
import os
import pandas as pd
import random

texts = {
    'chekhov': 'https://www.gutenberg.org/files/7986/7986-0.txt',
    'chekhov_2': 'https://www.gutenberg.org/cache/epub/1755/pg1755.txt',
    'shakespeare_midsummer': 'https://www.gutenberg.org/cache/epub/1514/pg1514.txt',
    'shakespeare_romeo_juliet': 'https://www.gutenberg.org/cache/epub/1112/pg1112.txt',
    'shakespeare_macbeth': 'https://www.gutenberg.org/cache/epub/2264/pg2264.txt',
    'shakespeare_hamlet': 'https://www.gutenberg.org/cache/epub/2265/pg2265.txt',
}

Now we create the training data folder and download the texts:

# Create the output folder if it does not exist yet
os.makedirs('training_text', exist_ok=True)
for name, url in texts.items():
    print(name)
    res = requests.get(url)
    res.raise_for_status()  # stop early if a download failed
    with open(os.path.join('training_text', '%s.txt' % name), 'w') as fp_write:
        fp_write.write(res.text)

You should see some outputs like this to show you it succeeded:

chekhov
chekhov_2
shakespeare_midsummer
shakespeare_romeo_juliet
shakespeare_macbeth
shakespeare_hamlet

You can also check the training_text folder to see that the files were properly downloaded.

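Now we want to read these files back into memory. Each file is stored as a single string, which we will split into lines in the next step. While we’re at it, we’ll print the length of each string in characters.

lines_by_file = {}
for fn in os.listdir('training_text'):
    if not fn.endswith('.txt'):
        continue
    key = fn.split('.')[0]  # file name without the .txt extension
    with open(os.path.join('training_text', fn)) as fp_file:
        # readlines() keeps each trailing newline, so joining with '\n'
        # leaves a blank line between the original lines
        lines_by_file[key] = '\n'.join(fp_file.readlines())
    # len() of a string counts characters, not lines
    print(fn, len(lines_by_file[key]))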

You should see output like the below:

shakespeare_midsummer.txt 120198
shakespeare_romeo_juliet.txt 179726
shakespeare_macbeth.txt 140022
shakespeare_hamlet.txt 204169
chekhov.txt 419063
chekhov_2.txt 148324

Now comes the fun part. We want to split the lines into real training data. To do so, we first remove the first and last 1000 lines, which are taken up by introduction, header and footer content. Then, we will grab the remaining text 50 lines at a time and turn each group of 50 lines into a prompt-and-completion pair.

train_data = []
for k in lines_by_file:
    is_chekhov = 'chekhov' in k
    useful_lines = lines_by_file[k].split('\n')[1000:-1000]
    
    prompt_fmt = "Which playwright wrote the following passage? \n ==== \n %s \n ====" 
    for i in range(0, len(useful_lines), 50):
        training_set = useful_lines[i: i+50]
        train_data.append({
            'prompt': prompt_fmt % '\n'.join(training_set),
            'completion': 'Anton Chekhov' if is_chekhov else 'William Shakespeare'
        })
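The fixed 1000-line trim above is simple but blunt. As a rough alternative sketch, the helper below looks for the "*** START OF" and "*** END OF" marker lines that many Project Gutenberg files contain, and falls back to returning the input unchanged when they are missing. The function names trim_gutenberg and make_examples, the min_lines threshold, and the toy data are all illustrative assumptions, not part of the original script. The second helper repeats the 50-line chunking but skips a very short final chunk.

```python
def trim_gutenberg(lines):
    """Return the lines between the '*** START OF' and '*** END OF' markers.

    Assumption: the file uses these marker lines. If either one is missing,
    the input is returned unchanged so the caller can fall back to a
    fixed-size trim.
    """
    start = end = None
    for i, line in enumerate(lines):
        if start is None and line.startswith('*** START OF'):
            start = i + 1
        elif line.startswith('*** END OF'):
            end = i
            break
    if start is None or end is None or start >= end:
        return lines
    return lines[start:end]


def make_examples(lines, author, chunk_size=50, min_lines=10):
    """Turn lines into prompt/completion pairs, skipping very short chunks."""
    prompt_fmt = "Which playwright wrote the following passage? \n ==== \n %s \n ===="
    examples = []
    for i in range(0, len(lines), chunk_size):
        chunk = lines[i:i + chunk_size]
        if len(chunk) < min_lines:
            continue  # a 5-line tail carries little stylistic signal
        examples.append({
            'prompt': prompt_fmt % '\n'.join(chunk),
            'completion': author,
        })
    return examples


# Toy input: two boilerplate lines, two marker lines, 105 lines of "play"
toy = (
    ['License text', '*** START OF THE PROJECT GUTENBERG EBOOK TOY ***']
    + ['line %d' % n for n in range(105)]
    + ['*** END OF THE PROJECT GUTENBERG EBOOK TOY ***', 'More license text']
)
body = trim_gutenberg(toy)
examples = make_examples(body, 'William Shakespeare')
print(len(body), len(examples))
```

On the toy input, the 105 body lines yield two full 50-line examples, and the 5-line remainder is dropped.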

So now we have clearly defined the problem - given 50 lines of text from a play, determine if the playwright is Anton Chekhov or William Shakespeare. We’re not done just yet. We need to write the data into jsonl (JSON Lines) format for training, and we also want to reserve a few samples for testing purposes. Run the following code:

df = pd.DataFrame(train_data)
df_chekhov = df[df['completion'] == 'Anton Chekhov']
df_shakespeare = df[df['completion'] == 'William Shakespeare']
chekhov_test_indices = random.sample(df_chekhov.index.tolist(), 15)
shakespeare_test_indices = random.sample(df_shakespeare.index.tolist(), 15)
df_chekhov_test = df_chekhov.loc[chekhov_test_indices]
df_shakespeare_test = df_shakespeare.loc[shakespeare_test_indices]
df_chekhov_train = df_chekhov.loc[[i for i in df_chekhov.index if i not in chekhov_test_indices]]
df_shakespeare_train = df_shakespeare.loc[[i for i in df_shakespeare.index if i not in shakespeare_test_indices]]

pd.concat([df_chekhov_train, df_shakespeare_train]).to_json('chekhov_shakespeare_train.jsonl', orient='records', lines=True)
pd.concat([df_chekhov_test, df_shakespeare_test]).to_json('chekhov_shakespeare_test.jsonl', orient='records', lines=True)
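Before uploading anything, it can be worth a quick sanity check that the jsonl file parses and that both classes are present. The sketch below uses only the standard library. The helper name count_labels and the toy file are assumptions made for illustration. In practice you would point it at chekhov_shakespeare_train.jsonl.

```python
import json
import os
import tempfile
from collections import Counter


def count_labels(path):
    """Parse a JSON Lines file and count how often each completion appears."""
    counts = Counter()
    with open(path, encoding='utf-8') as fp:
        for line_no, line in enumerate(fp, start=1):
            if not line.strip():
                continue  # tolerate blank lines
            record = json.loads(line)
            missing = {'prompt', 'completion'} - set(record)
            if missing:
                raise ValueError('line %d is missing %s' % (line_no, sorted(missing)))
            counts[record['completion']] += 1
    return counts


# Toy records standing in for the real training file
sample = [
    {'prompt': 'passage one', 'completion': 'Anton Chekhov'},
    {'prompt': 'passage two', 'completion': 'William Shakespeare'},
    {'prompt': 'passage three', 'completion': 'William Shakespeare'},
]
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, 'toy.jsonl')
    with open(path, 'w', encoding='utf-8') as fp:
        for record in sample:
            fp.write(json.dumps(record) + '\n')
    label_counts = count_labels(path)

print(dict(label_counts))
```

A heavily lopsided count would be a hint to rebalance the classes before training.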

Of course, if you want to use the entire corpus for training, you can simply run

pd.DataFrame(train_data).to_json('output.jsonl', orient='records', lines=True)

Training with Replicate

We need to do two things before we can invoke the training. First, we need to upload the training data to a location that Replicate can access. One very easy way to do this is to upload the file to a Google Cloud Storage bucket, make the bucket and the file public, and supply the URL in the format https://storage.googleapis.com/<bucket_name>/<file_name>.
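As a small illustration of that URL format, the helper below assembles the public URL from a bucket name and a file name, percent-encoding any characters in the file name that are not URL-safe. The function name public_gcs_url and the bucket name my-example-bucket are hypothetical. Substitute your own bucket.

```python
from urllib.parse import quote


def public_gcs_url(bucket_name, file_name):
    """Build the public URL for an object in a public Cloud Storage bucket."""
    # quote() leaves '/' alone by default, so folder-style names still work
    return 'https://storage.googleapis.com/%s/%s' % (bucket_name, quote(file_name))


# 'my-example-bucket' is a placeholder, not a real bucket
train_url = public_gcs_url('my-example-bucket', 'chekhov_shakespeare_train.jsonl')
print(train_url)
```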

Next, we need to create a destination. To do this, simply log into Replicate (which you can do through GitHub OAuth) and create a new model. Once the model is created and named, you will be able to push your trained model to this space.

Once everything is set up, you can kick off the training like so:

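# The replicate client reads your API token from the
# REPLICATE_API_TOKEN environment variable.
training = replicate.trainings.create(
  version="[flant5-large location]",  # base model version to fine-tune
  input={
    "train_data": "[Data Location]",  # public URL of the training jsonl file
  },
  destination="[Destination]"  # the model you created on Replicate
)

print(training)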

You should see some output that tells you the training is starting. Wait a few minutes and check back with the training by running the following code:

training.reload()
print(training)

You can also monitor training progress on the Replicate website. Once the training is done, you can reload the training object to get the output name and proceed to the next step.
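If you would rather not re-run the reload by hand, a small polling loop can do the waiting. The sketch below assumes only that the training object has a reload() method and a status attribute, and that the finished states are named "succeeded", "failed" and "canceled". The helper wait_for_training and the FakeTraining stand-in are illustrative and exist so the example runs without a Replicate account.

```python
import time

# Statuses after which a training will not change again (assumed values)
TERMINAL_STATUSES = {'succeeded', 'failed', 'canceled'}


def wait_for_training(training, poll_seconds=60, max_polls=120, sleep=time.sleep):
    """Reload `training` until it reaches a terminal status, then return it."""
    for _ in range(max_polls):
        training.reload()
        if training.status in TERMINAL_STATUSES:
            return training.status
        sleep(poll_seconds)
    raise TimeoutError('training did not finish after %d polls' % max_polls)


class FakeTraining:
    """Stand-in for a real training object so the sketch runs offline."""

    def __init__(self, statuses):
        self._statuses = iter(statuses)
        self.status = 'starting'

    def reload(self):
        # Each reload advances to the next scripted status
        self.status = next(self._statuses)


fake = FakeTraining(['starting', 'processing', 'succeeded'])
final_status = wait_for_training(fake, poll_seconds=0, sleep=lambda seconds: None)
print(final_status)
```

With a real training you would call wait_for_training(training) and only continue to testing if the returned status is "succeeded".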

Be warned that there are time periods when GPU resources are scarce, and you may get a “training failed” error. If this happens to you, just wait a few hours and try again. There is a GPU shortage, and PaaS providers are not immune!

Testing the Model

All right! Now that we have our fine-tuned model, we need to test it. Remember that we reserved 15 Chekhov and 15 Shakespeare passages for testing. We can use them here like so:

for _, row in df_chekhov_test.iterrows():
    output = replicate.run(
      training.output["version"],
      input={"prompt": row['prompt']}
    )
    for s in output:
        print(s, end="", flush=True)
    print('')

After a short start-up period, you should see the output being printed to the console. The model should be extremely accurate and return “Anton Chekhov” each time. Let’s try this with Shakespeare:

for _, row in df_shakespeare_test.iterrows():
    output = replicate.run(
      training.output["version"],
      input={"prompt": row['prompt']}
    )
    for s in output:
        print(s, end="", flush=True)
    print('')

Similar to the Chekhov example, you should see that the model is able to identify Shakespeare every time.

For good measure, let’s see if the base model is able to identify Shakespeare or Chekhov:

for _, row in df_shakespeare_test.iterrows():
    output = replicate.run(
      "[base flant5-large location]",
      input={"prompt": row['prompt']}
    )
    for s in output:
        print(s, end="", flush=True)
    print('')

for _, row in df_chekhov_test.iterrows():
    output = replicate.run(
      "[base flant5-large location]",
      input={"prompt": row['prompt']}
    )
    for s in output:
        print(s, end="", flush=True)
    print('')

You should see that the base model is unable to reliably identify the playwright for the same passages. This shows that our fine-tuning gave the model new information, and we have built ourselves a natural language playwright classifier!
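To put a number on that comparison instead of eyeballing the printed output, you could score each model against the expected label. The sketch below is one way to do it. The helpers predict and accuracy and the fake_model stand-in are assumptions made for illustration. The only thing it assumes about a model is that calling it with a prompt returns an iterable of text chunks, as in the loops above.

```python
def predict(run_model, prompt):
    """Join the streamed chunks returned by `run_model` into one string."""
    return ''.join(run_model(prompt)).strip()


def accuracy(run_model, prompts, expected):
    """Fraction of prompts whose prediction equals `expected` (case-insensitive)."""
    prompts = list(prompts)
    if not prompts:
        raise ValueError('no prompts to score')
    hits = sum(predict(run_model, p).lower() == expected.lower() for p in prompts)
    return hits / len(prompts)


def fake_model(prompt):
    """Hypothetical stand-in that streams its answer in two chunks."""
    if 'thou' in prompt:
        return iter(['William ', 'Shakespeare'])
    return iter(['Anton ', 'Chekhov'])


toy_prompts = [
    'Wherefore art thou Romeo?',
    'What light through yonder window breaks?',
]
score = accuracy(fake_model, toy_prompts, 'William Shakespeare')
print(score)
```

With Replicate, run_model could be a small wrapper such as lambda p: replicate.run(training.output["version"], input={"prompt": p}), and prompts could be df_shakespeare_test['prompt']. Running the same scoring against the base model gives a direct before-and-after comparison.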

Conclusion

In today’s article, we trained a simple natural language classifier based on Flan-T5, a large language model provided by Google. Because of its compact size and permissive license, Flan-T5 can be trained and deployed on private infrastructure, which sets it apart from many of the other popular models on the market, such as ChatGPT.

While the example today was based on public data and was decidedly non-commercial, this proof of concept can be easily adapted to many other commercial applications as outlined above. If you have an idea involving LLMs that you would like to see turned into reality, feel free to start a conversation by visiting my GitHub or LinkedIn page. Also, feel free to read my previous LLM articles, including one about Building a Document Analyzer using ChatGPT and one about Creating a Domain Expert LLM using OpenAI’s Fine Tuning API.

Happy hacking!


Written by shanglun | Quant, technologist, occasional economist, cat lover, and tango organizer.
Published by HackerNoon