SFT on single-host TPUs#

Supervised fine-tuning (SFT) adapts a pre-trained large language model to perform better on specific tasks by training it further on a labeled dataset.

This tutorial provides step-by-step instructions for setting up the environment and then training a model on a Hugging Face dataset using SFT.

We use Tunix, a JAX-based library designed for post-training tasks, to perform SFT.

In this tutorial we use a single-host TPU VM, such as a v6e-8 or v5p-8. Let’s get started!

Install MaxText and Post-Training dependencies#

For instructions on installing MaxText with post-training dependencies on your VM, please refer to the official documentation and use the maxtext[tpu-post-train] installation path to include all necessary post-training dependencies.

Note: If you have previously installed MaxText with a different option (e.g., maxtext[tpu]), we strongly recommend using a fresh virtual environment for maxtext[tpu-post-train] to avoid potential library version conflicts.
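As a sketch, a fresh-environment install might look like the following. The extra name `maxtext[tpu-post-train]` comes from the note above, but whether you install from PyPI or from a source checkout depends on your setup, so follow the official documentation for the exact command:

```shell
# Create and activate a fresh virtual environment to avoid version conflicts
# with any earlier maxtext[tpu] install.
python3 -m venv ~/maxtext-venv
source ~/maxtext-venv/bin/activate

# Install MaxText with the post-training dependencies (exact source may differ;
# see the official installation docs).
pip install --upgrade pip
pip install "maxtext[tpu-post-train]"
```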

Set up environment variables#

Log in to Hugging Face. Provide your access token when prompted:

hf auth login
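To confirm the login succeeded before proceeding, you can query the authenticated account with the same Hugging Face CLI:

```shell
# Prints the username associated with the stored access token;
# fails if no valid token is configured.
hf auth whoami
```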

Set up the following environment variables to configure your training run. Replace placeholders with your actual values.

# -- Model configuration --
# The MaxText model name. See `src/maxtext/configs/types.py` for `ModelName` for a
# full list of supported models.
export MODEL=<MaxText Model> # e.g., 'llama3.1-8b-Instruct'

# -- MaxText configuration --
# Use a GCS bucket you own to store logs and checkpoints. Ideally in the same
# region as your TPUs to minimize latency and costs.
# You can list your buckets and their locations in the
# [Cloud Console](https://console.cloud.google.com/storage/browser).
export BASE_OUTPUT_DIRECTORY=<gcs bucket path> # e.g., gs://my-bucket/maxtext-runs

# An arbitrary string to identify this specific run.
# We recommend including the model, user, and timestamp.
# Note: Kubernetes requires workload names to be valid DNS labels (lowercase, no underscores or periods).
export RUN_NAME=<Name for this run>

export STEPS=<number of fine-tuning steps to run> # e.g., 1000
export PER_DEVICE_BATCH_SIZE=<batch size per device> # e.g., 1

# -- Dataset configuration --
export DATASET_NAME=<Hugging Face dataset name> # e.g., HuggingFaceH4/ultrachat_200k
export TRAIN_SPLIT=<data split for train> # e.g., train_sft
export TRAIN_DATA_COLUMNS=<data columns to train on> # e.g., ['messages']

Get your model checkpoint#

This section explains how to prepare your model checkpoint for use with MaxText. You have two options: using an existing MaxText checkpoint or converting a Hugging Face checkpoint.

Option 1: Using an existing MaxText checkpoint#

If you already have a MaxText-compatible model checkpoint, simply set the following environment variable and move on to the next section.

export MAXTEXT_CKPT_PATH=<gcs path for MaxText checkpoint> # e.g., gs://my-bucket/my-model-checkpoint/0/items

Option 2: Converting a Hugging Face checkpoint#

Refer to the steps in Hugging Face to MaxText to convert a Hugging Face checkpoint to the MaxText format. Make sure the checkpoint files are converted and saved correctly. As in Option 1, set the following environment variable and move on to the next section.

export MAXTEXT_CKPT_PATH=<gcs path for MaxText checkpoint> # e.g., gs://my-bucket/my-model-checkpoint/0/items
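Whichever option you chose, it can help to verify that the path actually resolves before starting a long training run. Listing it with `gcloud storage ls` (the current replacement for `gsutil ls`) should show the checkpoint files:

```shell
# Should list the checkpoint files under the path; a failure here usually means
# a typo in the bucket path or missing read permissions on the bucket.
gcloud storage ls "${MAXTEXT_CKPT_PATH?}"
```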

Run SFT on a Hugging Face dataset#

Now you are ready to run SFT using the following command:

python3 -m maxtext.trainers.post_train.sft.train_sft \
    run_name=${RUN_NAME?} \
    base_output_directory=${BASE_OUTPUT_DIRECTORY?} \
    model_name=${MODEL?} \
    load_parameters_path=${MAXTEXT_CKPT_PATH?} \
    per_device_batch_size=${PER_DEVICE_BATCH_SIZE?} \
    steps=${STEPS?} \
    hf_path=${DATASET_NAME?} \
    train_split=${TRAIN_SPLIT?} \
    train_data_columns=${TRAIN_DATA_COLUMNS?} \
    profiler=xplane

Your fine-tuned model checkpoints will be saved to $BASE_OUTPUT_DIRECTORY/$RUN_NAME/checkpoints.
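Once the run finishes, you can confirm the outputs landed where expected, again using `gcloud storage ls`:

```shell
# Lists the checkpoints written by this run.
gcloud storage ls "${BASE_OUTPUT_DIRECTORY?}/${RUN_NAME?}/checkpoints/"
```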