Blockwise SFT for Diffusion Language Models: Reconciling Bidirectional Attention and Autoregressive Decoding
A lightweight recipe to (1) preprocess an instruction dataset into a compact .pt file and (2) fine-tune with either blockwise masked SFT (discrete-diffusion-friendly) or full-sequence SFT using LoRA + DeepSpeed.
- pt.py – preprocesses MetaMathQA into MathQA_256.pt, containing token IDs and prompt lengths ready for training.
- train_blockwise.py – Random-Block Masked SFT with prefix/suffix ablations and automatic one-block-per-sample sampling.
- train_fullseq.py – Full-sequence SFT with random masking over the entire answer span.
Both trainers:
- Run no validation loop (by design).
- Write (step, loss) to CSV for easy plotting.
- Save fully merged Hugging Face checkpoints (LoRA merged into the base).
Model & dataset defaults
- Base model: GSAI-ML/LLaDA-8B-Instruct
- Dataset: meta-math/MetaMathQA (train split)
# Python 3.10+ recommended
pip install -U torch deepspeed transformers peft datasets tqdm
# (Optional) if you use bf16 on Ampere+ GPUs
# Ensure your CUDA/driver stack supports bf16

Tip: If you plan to push/download models from Hugging Face, make sure your environment has access (e.g., huggingface-cli login).
pt.py renders conversations using the model’s chat template, computes
prompt_len (token count of the user-only prompt), and stores
full_ids (prompt + answer) as token IDs. Samples longer than --max_len are
skipped.
python pt.py \
--output MathQA_256.pt \
--model GSAI-ML/LLaDA-8B-Instruct \
--dataset meta-math/MetaMathQA \
--split train \
--max_len 256 \
--max_samples 20000 \
--filter_type GSM

Arguments (most used)
- --output (str): output path for the .pt file. Default: MathQA_256.pt
- --model (str): tokenizer/model for the chat template. Default: GSAI-ML/LLaDA-8B-Instruct
- --dataset (str): HF dataset ID. Default: meta-math/MetaMathQA
- --split (str): dataset split. Default: train
- --max_len (int): drop samples with len(full_ids) > max_len. Default: 256
- --max_samples (int): cap total kept samples (0 = unlimited). Default: 20000
- --filter_type (str): skip examples whose ex["type"] contains this substring (empty string disables). Default: GSM
The file is a list of dictionaries, saved with torch.save:
[
{"prompt_len": <int>, "full_ids": torch.LongTensor([...])},
...
]
- prompt_len ≤ len(full_ids)
- full_ids already matches the tokenizer used during preprocessing
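A quick sanity check of the record layout can catch a stale or mismatched file before a long run. The sketch below is illustrative only: validate_records is a hypothetical helper, not part of pt.py, and it assumes each record is a dict with the two keys shown above. It relies only on len(), so it works on tensors and plain lists alike.

```python
def validate_records(records, max_len=256):
    """Check the invariants described above; return the number of records.

    Hypothetical helper for illustration. `max_len` mirrors the --max_len
    default of pt.py.
    """
    for i, rec in enumerate(records):
        prompt_len = rec["prompt_len"]
        n_tokens = len(rec["full_ids"])  # works for tensors and lists
        if not isinstance(prompt_len, int):
            raise TypeError(f"record {i}: prompt_len is not an int")
        if not 0 <= prompt_len <= n_tokens:
            raise ValueError(f"record {i}: prompt_len {prompt_len} exceeds {n_tokens} tokens")
        if n_tokens > max_len:
            raise ValueError(f"record {i}: {n_tokens} tokens exceeds max_len={max_len}")
    return len(records)
```

With PyTorch installed, the records would be obtained with torch.load("MathQA_256.pt") and passed straight to the helper.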
Important: The tokenizer’s chat template may append small trailing text. The script applies a conservative workaround (strip 47 tail characters) to match the training pipeline. Adjust if you change models/templates.
Both trainers expect the .pt file (default: MathQA_256.pt).
They create a DeepSpeed config file (ds_zero2.json) automatically.
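For orientation, a ZeRO stage 2 config with bf16 enabled could look roughly like the dictionary below. This is a minimal sketch and not the exact file the trainers write: build_ds_config is a hypothetical helper, and the batch-size values are placeholders, since the scripts' actual MICRO_BSIZE and GRAD_ACC defaults are not stated here. The keys themselves are standard DeepSpeed config fields.

```python
import json


def build_ds_config(micro_bsize, grad_acc):
    """Return a minimal ZeRO stage 2 + bf16 DeepSpeed config as a dict."""
    return {
        "train_micro_batch_size_per_gpu": micro_bsize,
        "gradient_accumulation_steps": grad_acc,
        "bf16": {"enabled": True},
        "zero_optimization": {"stage": 2},
    }


# Placeholder values, chosen only for illustration.
config = build_ds_config(micro_bsize=1, grad_acc=8)
print(json.dumps(config, indent=2))
```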
We recommend using the DeepSpeed launcher:
deepspeed --num_gpus 8 train_blockwise.py
# or
deepspeed --num_gpus 8 train_fullseq.py

For single-GPU experiments you can also try:
deepspeed --num_gpus 1 train_blockwise.py

The scripts call deepspeed.init_distributed() internally, so launching with the DeepSpeed/Torch distributed runner is preferred.
What it does
- Splits the answer into fixed-size blocks (BLOCK_SIZE).
- At each training step, picks one block per sample and masks it with a random rate p ~ U(eps, 1). Only the active block contributes to the loss.
- Prefix ablation (PREFIX_MASK_RATE): randomly masks the answer prefix (between the prompt end and the active block); these tokens are excluded from the loss.
- Suffix ablation (SUFFIX_MASK_PROB): randomly masks tokens after the active block; these tokens are excluded from the loss. 1.0 reproduces “mask all suffix”.
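The masking scheme above can be sketched on plain Python lists. This is an illustration under stated assumptions rather than the code in train_blockwise.py: blockwise_mask is a hypothetical name, the active block is chosen uniformly, masking is an independent per-token draw, PREFIX_MASK_RATE and SUFFIX_MASK_PROB are both read as per-token probabilities, the loss is counted on masked positions inside the active block, and the eps value is a placeholder.

```python
import random

MASK_ID = 126336  # mask token id quoted in this README


def blockwise_mask(full_ids, prompt_len, block_size=32, prefix_mask_rate=0.0,
                   suffix_mask_prob=1.0, eps=1e-3, rng=random):
    """Return (noisy_ids, loss_mask, block_index) for one sample."""
    answer_len = len(full_ids) - prompt_len
    if answer_len <= 0:
        raise ValueError("sample has no answer tokens")

    # Pick one active block among the fixed-size answer blocks.
    n_blocks = (answer_len + block_size - 1) // block_size
    block = rng.randrange(n_blocks)
    start = prompt_len + block * block_size
    end = min(start + block_size, len(full_ids))

    noisy = list(full_ids)
    loss_mask = [False] * len(full_ids)

    # Prefix ablation: answer tokens before the active block, never in the loss.
    for i in range(prompt_len, start):
        if rng.random() < prefix_mask_rate:
            noisy[i] = MASK_ID

    # Active block: mask at rate p ~ U(eps, 1); masked positions enter the loss.
    p = rng.uniform(eps, 1.0)
    for i in range(start, end):
        if rng.random() < p:
            noisy[i] = MASK_ID
            loss_mask[i] = True

    # Suffix ablation: tokens after the active block, never in the loss.
    for i in range(end, len(full_ids)):
        if rng.random() < suffix_mask_prob:
            noisy[i] = MASK_ID

    return noisy, loss_mask, block
```

With the defaults (prefix rate 0.0, suffix probability 1.0), the model sees the clean prompt and earlier answer blocks, a partially masked active block, and a fully masked suffix, which mirrors block-by-block decoding.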
Key knobs (edit in the script)
- BLOCK_SIZE (default: 32)
- PREFIX_MASK_RATE (default: 0.0)
- SUFFIX_MASK_PROB (default: 1.0)
- MICRO_BSIZE, GRAD_ACC, LR
- TOTAL_TRAIN_STEPS (computed from MAX_SAMPLES, EPOCHS, MICRO_BSIZE)
- SAVE_EVERY (defaults to saving once at the end; lower it to save more often)
- CKPT_DIR (checkpoint root directory)
Artifacts
- Checkpoints: ckpts_stagewise_sft_32_stepXXXX/ (directory name is uniquified if it already exists)
- Logs: ckpts_stagewise_sft_32_logs/train_loss.csv
- Ablation config snapshot: ckpts_stagewise_sft_32_logs/ablation_config.json
Merged checkpoints: Saving uses LoRA-merged weights (full base + LoRA), serialized in safetensors, so disk usage is similar to a full model.
What it does
- At each step, masks tokens across the whole answer span at a random rate p ~ U(eps, 1); masked tokens contribute to the loss.
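As a point of comparison with the blockwise trainer, full-sequence masking can be sketched as follows. This is a minimal illustration on plain lists, not the code in train_fullseq.py: fullseq_mask is a hypothetical name, masking is assumed to be an independent per-token draw at rate p, and the eps value is a placeholder.

```python
import random

MASK_ID = 126336  # mask token id quoted in this README


def fullseq_mask(full_ids, prompt_len, eps=1e-3, rng=random):
    """Return (noisy_ids, loss_mask, p) with the answer span masked at rate p."""
    p = rng.uniform(eps, 1.0)
    noisy = list(full_ids)
    loss_mask = [False] * len(full_ids)
    # The prompt stays clean; every answer token is masked with probability p.
    for i in range(prompt_len, len(full_ids)):
        if rng.random() < p:
            noisy[i] = MASK_ID
            loss_mask[i] = True
    return noisy, loss_mask, p
```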
Key knobs (edit in the script)
- MICRO_BSIZE, GRAD_ACC, LR
- TOTAL_TRAIN_STEPS and SAVE_EVERY (same semantics as above)
- CKPT_DIR (default: ckpts_fullseq_sft)
Artifacts
- Checkpoints: ckpts_fullseq_sft_stepXXXX/
- Logs: ckpts_fullseq_sft_logs/train_loss.csv
- Seeds: set via set_seed(SEED) (default SEED=42).
- Loss curve: (step, loss) CSV under *_logs/train_loss.csv.
- Progress: tqdm progress bar on rank 0.
- Early stop (blockwise only): if every (sample, block) has been used, training stops early.
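Since per-step losses under random masking rates tend to be noisy, a smoothed curve is usually easier to read. The sketch below reads the CSV with the standard library and computes a trailing moving average. It assumes the first two columns are step and loss, and it skips any row that does not parse (for example a header line); whether the trainers write a header is not stated here, and both function names are hypothetical.

```python
import csv


def read_loss_csv(path):
    """Read (step, loss) rows; rows that do not parse as numbers are skipped."""
    steps, losses = [], []
    with open(path, newline="") as f:
        for row in csv.reader(f):
            try:
                step, loss = int(row[0]), float(row[1])
            except (ValueError, IndexError):
                continue  # header line or malformed row
            steps.append(step)
            losses.append(loss)
    return steps, losses


def moving_average(values, window=50):
    """Trailing mean over at most `window` values, same length as the input."""
    out, total = [], 0.0
    for i, v in enumerate(values):
        total += v
        if i >= window:
            total -= values[i - window]
        out.append(total / min(i + 1, window))
    return out
```

The two returned lists can be handed to any plotting library, with moving_average(losses) drawn over the raw values.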
- Memory / OOM
  - Decrease MICRO_BSIZE, increase GRAD_ACC.
  - BF16 is enabled by default in the DeepSpeed config; ensure your GPU supports it.
- Save frequency
  - By default SAVE_EVERY == TOTAL_TRAIN_STEPS (save at the end).
  - Lower SAVE_EVERY to get mid-training snapshots.
- Tokenizer alignment
  - Keep the same tokenizer for preprocessing and training.
  - If you switch models/templates, regenerate the .pt file.
- MASK_ID
  - Both trainers assume MASK_ID=126336 (the special <mask> used in the dataset/pipeline). Adjust consistently if your tokenizer uses a different id.
# 1) Preprocess (creates MathQA_256.pt)
python pt.py --output MathQA_256.pt --max_len 256 --max_samples 20000
# 2) Train (blockwise)
deepspeed --num_gpus 8 train_blockwise.py
# or full-seq SFT
deepspeed --num_gpus 8 train_fullseq.py

After training, you’ll find merged Hugging Face checkpoints under ckpts_*_stepXXXX/. You can load them with:
from transformers import AutoModelForCausalLM, AutoTokenizer
tok = AutoTokenizer.from_pretrained("path/to/ckpts_stagewise_sft_32_step0001")
mdl = AutoModelForCausalLM.from_pretrained("path/to/ckpts_stagewise_sft_32_step0001", torch_dtype="auto")
- Code is intended for research use. Check the licenses of:
  - The base model GSAI-ML/LLaDA-8B-Instruct.
  - The dataset meta-math/MetaMathQA.
- Respect dataset and model terms when redistributing derivatives.