GPTCoach MI Training Starter

This repo scaffolds data preparation, supervised fine-tuning (SFT), and preference optimization (DPO) for a multi-turn Motivational Interviewing (MI) coach model, built with Hugging Face Transformers and TRL.

What you get

  • data/ — schema plus a tiny synthetic example dataset (JSONL) to verify the pipeline.
  • scripts/data_prep.py — normalizes MI datasets (MI-TAGS / AnnoMI / MI-Dataset) into a common JSONL format.
  • scripts/sft_train.py — supervised fine-tuning with TRL's SFTTrainer; supports LoRA and 4-bit quantization.
  • scripts/dpo_train.py — preference optimization with TRL's DPOTrainer.
  • scripts/infer_demo.py — runs inference with memory and the MI persona prompt.
  • scripts/metrics_mi.py — simple MI-style behavioral metrics (coverage of open-question/reflection/affirmation tags, etc.); see the sketch below.
  • configs/*.yaml — example hyperparameters.

⚠️ You must provide your own dataset paths for MI-TAGS, AnnoMI, etc. The included example_mi_dialogs.jsonl is for smoke tests only, not for real training.
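
To give a feel for the kind of coverage metric scripts/metrics_mi.py reports, here is an illustrative sketch (not the repo's actual implementation) that counts how often each MI behavior tag appears across coach turns in the unified JSONL:

import json
from collections import Counter

def mi_tag_coverage(jsonl_path):
    """Fraction of coach turns carrying each MI behavior tag."""
    counts, n_turns = Counter(), 0
    with open(jsonl_path, encoding="utf-8") as f:
        for line in f:
            rec = json.loads(line)
            n_turns += 1
            for tag in rec.get("mi_tags", []):
                counts[tag] += 1
    return {tag: c / n_turns for tag, c in counts.items()}

print(mi_tag_coverage("data/example_mi_dialogs.jsonl"))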

Conda Environment Setup

cd /mnt
wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh
bash Miniconda3-latest-Linux-x86_64.sh -b -p /mnt/miniconda3
echo 'export PATH="/mnt/miniconda3/bin:$PATH"' >> ~/.bashrc
source ~/.bashrc
conda create -y -n gptcoach python=3.10
source /mnt/miniconda3/etc/profile.d/conda.sh  # makes `conda activate` available in this shell
conda activate gptcoach
pip install -r /mnt/gptcoaching_mi_training/requirements.txt

Install

pip install -U transformers datasets accelerate trl peft bitsandbytes torch torchvision torchaudio
# If CUDA is not available, install CPU wheels for torch.
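
To confirm the GPU build of torch is usable before training:

python -c "import torch; print(torch.cuda.is_available(), torch.version.cuda)"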

Data format (common JSONL)

Each line is a JSON object (the # comments below are annotations, not part of the JSON):

{
  "dialog_id": "string",
  "turn_id": 7,
  "user_utt": "string",
  "coach_utt": "string",
  "mi_tags": ["open_question","reflection_simple","affirm"],
  "state_before": {},   # optional structured state (goal, barriers, wearable stats, etc.)
  "state_after":  {}    # optional updated state
}

You can include additional fields; unknown keys are ignored by the loader.
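
As an illustrative sketch of that behavior (the repo's actual loading logic lives in scripts/data_prep.py and the training scripts), a loader for this format might read each line and keep only the schema fields:

import json

KNOWN_KEYS = {"dialog_id", "turn_id", "user_utt", "coach_utt",
              "mi_tags", "state_before", "state_after"}

def load_mi_jsonl(path):
    records = []
    with open(path, encoding="utf-8") as f:
        for line in f:
            rec = json.loads(line)
            # keep known fields only; unknown keys are silently ignored
            records.append({k: rec[k] for k in KNOWN_KEYS if k in rec})
    return records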

Prepare data

Set your dataset file paths and run:

Use AnnoMI-full to generate the training set:

python scripts/data_prep.py --annomi_csv ./data/AnnoMI-full.csv --out_jsonl data/mi_unified_from_annomi_full.jsonl

Use AnnoMI-simple to generate the validation set:

python scripts/data_prep.py --annomi_csv ./data/AnnoMI-simple.csv --out_jsonl data/mi_unified_from_annomi_simple.jsonl
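
A quick sanity check on the generated files:

wc -l data/mi_unified_from_annomi_full.jsonl data/mi_unified_from_annomi_simple.jsonl
python -c "import json; [json.loads(l) for l in open('data/mi_unified_from_annomi_full.jsonl')]"  # every line must parse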

SFT (supervised fine-tuning)

python scripts/sft_train.py \
  --model_name_or_path Qwen/Qwen2.5-3B-Instruct \
  --train_file data/mi_unified_from_annomi_full.jsonl \
  --eval_file data/mi_unified_from_annomi_simple.jsonl \
  --output_dir outputs/qwen2p5-3b-mi-sft \
  --num_train_epochs 3 \
  --per_device_train_batch_size 20 \
  --lr 2e-5 \
  --eval_steps 200 \
  --save_steps 200 \
  --wandb --wandb_project mi-coach-sft --wandb_run_name qwen2p5_3b_sft \
  --bnb_4bit --lora

DPO dataset generation

This step generates paired positive and negative samples (MI-style coaching responses) for preference optimization.

python scripts/make_dpo_prefs_v2.py \
  --sft_file data/mi_unified_from_annomi_full.jsonl \
  --out_file data/mi_prefs.jsonl \
  --seed 123 \
  --max_samples 5000

DPO (preference optimization)

This consumes the preference JSONL generated above: pairs of responses (chosen vs. rejected) per context/turn.
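
The exact field names consumed by scripts/dpo_train.py aren't shown here; TRL's DPOTrainer conventionally expects prompt / chosen / rejected, so each line would look roughly like the following (shown pretty-printed; in the file it is one object per line):

{
  "prompt": "User: I want to be more active.\nCoach:",
  "chosen": "What matters most about being active for you?",
  "rejected": "You should just exercise more."
}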

python scripts/dpo_train.py \
  --model_name_or_path outputs/qwen2p5-3b-mi-sft/checkpoint-510 \
  --pref_file data/mi_prefs.jsonl \
  --output_dir runs/qwen2p5-3b-mi-dpo \
  --num_train_epochs 3 \
  --per_device_train_batch_size 1 \
  --lr 2e-5 \
  --logging_steps 10 \
  --save_steps 200 \
  --lora --bnb_4bit \
  --wandb --wandb_project mi-coach-dpo --wandb_run_name qwen2p5_3b_dpo
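
The demo below expects a merged checkpoint (the -merged directory). If the DPO run saved LoRA adapters rather than full weights, one way to produce it is peft's merge_and_unload — a minimal sketch, with a hypothetical checkpoint path:

from peft import AutoPeftModelForCausalLM
from transformers import AutoTokenizer

adapter_dir = "runs/qwen2p5-3b-mi-dpo/checkpoint-XXX"  # hypothetical: pick your best checkpoint
merged_dir = "runs/qwen2p5-3b-mi-dpo-merged"

model = AutoPeftModelForCausalLM.from_pretrained(adapter_dir, torch_dtype="auto")
model = model.merge_and_unload()  # fold the LoRA deltas into the base weights
model.save_pretrained(merged_dir)

# tokenizer files are usually saved alongside the adapter by the trainer;
# otherwise load from the base model (e.g. Qwen/Qwen2.5-3B-Instruct)
tok = AutoTokenizer.from_pretrained(adapter_dir)
tok.save_pretrained(merged_dir)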

Local demo (FastAPI)

export HF_HOME=/mnt/.cache/huggingface
export TRANSFORMERS_CACHE=/mnt/.cache/huggingface/transformers

# merged model dir from your DPO (or SFT) run
export MODEL_PATH=/mnt/gptcoaching_mi_training/runs/qwen2p5-3b-mi-dpo-merged
uvicorn scripts.app_demo:app --host 0.0.0.0 --port 8000 --reload
# POST http://localhost:8000/chat
# body:
# {
#   "history": [{"user":"I want to be more active.","coach":"What matters most about being active for you?"}],
#   "user_msg":"I'm too busy this week."
# }
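
The commented request above translates to a curl call like:

curl -X POST http://localhost:8000/chat \
  -H "Content-Type: application/json" \
  -d '{"history":[{"user":"I want to be more active.","coach":"What matters most about being active for you?"}],"user_msg":"I'\''m too busy this week."}'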
