Shuchen Xue, Chongjian Ge, Shilong Zhang, Yichen Li, Zhi-Ming Ma
University of Chinese Academy of Sciences · Adobe Research · HKU · MIT
Efficient reinforcement learning for text-to-image diffusion models. This implementation supports Stable Diffusion 3.5 with various reward functions.
# Automated environment setup.
# The script `setup/02_setup_env_reward_server.sh` installs MMCV and MMDetection.
# For GPUs with Ampere architecture (e.g., A100, RTX 30xx),
# please follow the comments in the script to adjust the MMCV/MMDetection installation accordingly.
bash setup_env.sh

export HUGGINGFACE_TOKEN="your_hf_token"
export WANDB_API_KEY="your_wandb_key"

# GenEval on SD3.5
bash train_awm_scripts_geneval/train_single_node.sh
# PickScore on SD3.5
bash train_awm_scripts_pickscore/train_single_node.sh
# OCR on SD3.5
bash train_awm_scripts_ocr/train_single_node.sh

# Master node
bash train_awm_scripts_geneval/train_multi_nodes.sh # node 0
# Worker nodes
bash train_awm_scripts_geneval/train_multi_nodes.sh # node 1
bash train_awm_scripts_geneval/train_multi_nodes.sh # node 2
bash train_awm_scripts_geneval/train_multi_nodes.sh # node 3

Configurations in advantage_weighted_matching/config/dgx_awm.py:
- geneval_sd3_no_cfg_1node - GenEval
- geneval_sd3_no_cfg_4nodes - GenEval
- ocr_sd3_no_cfg_1node - OCR
- ocr_sd3_no_cfg_4nodes - OCR
- pickscore_sd3_no_cfg_1node - PickScore
- pickscore_sd3_no_cfg_4nodes - PickScore
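Since the codebase builds on DDPO-PyTorch, these presets are most likely plain ml_collections config functions selected by name. The sketch below is illustrative only: the get_config dispatcher and the field values shown are assumptions based on that convention and on the defaults documented further down, not the repository's exact code.

```python
# Illustrative sketch of an ml_collections-style named preset (assumed layout,
# following the DDPO-PyTorch convention; not the repository's exact code).
import ml_collections


def geneval_sd3_no_cfg_1node():
    config = ml_collections.ConfigDict()
    config.pretrained = ml_collections.ConfigDict()
    config.pretrained.model = "stabilityai/stable-diffusion-3.5-medium"
    config.resolution = 512
    config.sample = ml_collections.ConfigDict()
    config.sample.guidance_scale = 1.0  # "no_cfg": rollout without CFG
    config.reward_fn = {"geneval": 1.0}
    return config


def get_config(name="geneval_sd3_no_cfg_1node"):
    # Select a preset by name, e.g. --config config/dgx_awm.py:geneval_sd3_no_cfg_1node
    return globals()[name]()
```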
# Model
config.pretrained.model = "stabilityai/stable-diffusion-3.5-medium" # Backbone diffusion model (SD3.5 Medium)
config.resolution = 512 # Training / sampling resolution
# Dataset & prompting
config.dataset = "dataset/{geneval|ocr|pickscore}" # Dataset used for training and evaluation
config.prompt_fn = "geneval | general_ocr" # Prompt generation function
# Sampling (data collection)
gpu_numbers = 8 # Number of GPUs; gpu_numbers * config.sample.num_batches_per_epoch should stay constant
config.sample.num_steps = 14 # Number of diffusion steps during rollout
config.sample.eval_num_steps = 40 # Diffusion steps used for evaluation
config.sample.guidance_scale = 1.0 # CFG scale in rollout (1.0 = no CFG; 3.0/4.5 for CFG)
config.sample.train_batch_size = 6 # Prompts per GPU per rollout batch
config.sample.test_batch_size = 6 # Batch size for evaluation sampling
config.sample.num_image_per_prompt = 24 # Images sampled per prompt (group size)
config.sample.num_batches_per_epoch = 24 * 24 // gpu_numbers // 2 # Rollout batches per epoch
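# Worked example (illustrative): with gpu_numbers = 8,
#   num_batches_per_epoch = 24 * 24 // 8 // 2 = 36 rollout batches per GPU per epoch.
# The product gpu_numbers * num_batches_per_epoch (288 here) is meant to stay constant,
# so running on fewer GPUs is compensated by more rollout batches per epoch.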
config.sample.noise_level = 0.4 # Noise level during sampling (ODE=0.0; SDE > 0.0)
config.sample.sde_frac = 0.0 # Fraction of SDE steps in DDIM (0 = deterministic ODE)
# 0.5: first 50% stochastic DDIM, remaining 50% DDIM
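# Illustrative example: with num_steps = 14 and sde_frac = 0.5, the first 7 rollout
# steps would use stochastic DDIM and the remaining 7 plain DDIM; sde_frac = 0.0
# keeps the whole trajectory deterministic.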
config.sample.global_std = True # Use global std for advantage normalization
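# Assumed normalization (GRPO-style group baseline; the exact code may differ):
#   advantage_i = (reward_i - mean(rewards in the same prompt group)) / std
# with std taken over the whole rollout batch when global_std = True,
# or within each prompt group when global_std = False.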
config.train.use_sa_solver = True # Use SA-Solver for rollout; if False: DDIM sampling
# Training (gradient updates)
config.train.batch_size = 1 # In the current implementation, the actual micro-batch size is
# config.train.batch_size * config.train.train_timesteps
config.train.gradient_accumulation_steps = config.sample.train_batch_size // config.train.batch_size * config.sample.num_batches_per_epoch # Set this way to keep training on-policy
config.train.num_batches_per_epoch = config.sample.train_batch_size // config.train.batch_size * config.sample.num_batches_per_epoch
config.train.train_timesteps = 6 # Number of diffusion timesteps trained
config.train.learning_rate = 3e-4 # AdamW learning rate
# Advantage / KL control
config.train.beta = 0.003 # KL regularization coefficient
config.train.clip_range = 1 # Advantage clipping epsilon; currently has no effect (training is on-policy)
config.train.advantage_max = 1 # Max advantage abs value
config.train.kl_weight = "Uniform" # Time weighting for KL term; Uniform or ELBO
# Loss shaping
config.train.loss_type = "exp_first" # Loss formulation; exp_first or sum_first; gradients are equivalent in the on-policy case
config.train.weighting = "ghuber" # Robust loss weighting scheme
config.train.ghuber_power = 0.25 # Power for generalized Huber weighting
config.train.timestep_fraction = 0.9 # Fraction of timesteps used for loss
# KL-EMA (mimic TRPO)
config.train.kl_ema_weight = "Uniform" # Time weighting for KL-EMA
config.train.kl_ema_decay = 0.3 # EMA decay rate
config.train.kl_ema_decay_type = "linear" # EMA decay schedule
config.train.ema_beta = 1 # Scaling factor for EMA KL
# EMA & optimization tricks
config.train.ema = True # Enable EMA for model weights
config.train.ema_decay = 0.99 # EMA decay rate
config.train.ema_update_step_interval = 1 # EMA update frequency
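# Assumed EMA update (standard form; the repo's exact implementation may differ),
# applied every ema_update_step_interval optimizer steps:
#   for p_ema, p in zip(ema_params, model_params):
#       p_ema.mul_(ema_decay).add_(p, alpha=1 - ema_decay)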
# Time sampling
config.time_type = "discrete_wo_init" # Time sampling strategy
# Rewards
config.reward_fn = { # Reward function(s) and weights
"geneval": 1.0, # or "ocr": 1.0 / "pickscore": 1.0
}
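# Assumed combination (standard weighted sum; the exact code may differ):
#   total_reward(image, prompt) = sum_k weight_k * reward_k(image, prompt)
# so {"geneval": 1.0} trains on the GenEval score alone, and a hypothetical mixture
# such as {"pickscore": 0.5, "ocr": 0.5} would average the two signals.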
# Logging & bookkeeping
config.per_prompt_stat_tracking = True # Track per-prompt statistics
config.num_epochs = 100000 # Total training epochs
config.save_freq = 20–50 # Checkpoint saving frequency in epochs (pick a value in this range)
config.eval_freq = 20–50 # Evaluation frequency in epochs (pick a value in this range)

Training Batch Relationship:
config.train.gradient_accumulation_steps = config.sample.train_batch_size // config.train.batch_size * config.sample.num_batches_per_epoch # Set this way to keep training on-policy
config.train.num_batches_per_epoch = config.sample.train_batch_size // config.train.batch_size * config.sample.num_batches_per_epoch

Our code is based on DDPO-Pytorch and Flow-GRPO.
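As a sanity check, the relationship can be verified with the default values documented above (gpu_numbers = 8, hence 36 rollout batches per epoch). The numbers below are illustrative arithmetic, not output from the training code:

```python
# Illustrative arithmetic for the on-policy batch relationship (defaults above).
sample_train_batch_size = 6                  # config.sample.train_batch_size
train_batch_size = 1                         # config.train.batch_size
num_batches_per_epoch = 24 * 24 // 8 // 2    # config.sample.num_batches_per_epoch = 36

grad_accum_steps = sample_train_batch_size // train_batch_size * num_batches_per_epoch
train_num_batches = sample_train_batch_size // train_batch_size * num_batches_per_epoch
samples_per_gpu_per_epoch = sample_train_batch_size * num_batches_per_epoch

# 216 == 216 == 216: all micro-batches from one rollout epoch are accumulated into
# a single optimizer step, so every update uses only freshly collected (on-policy) data.
assert grad_accum_steps == train_num_batches == samples_per_gpu_per_epoch == 216
```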
If you find this repo helpful, please consider citing:
@article{xue2025advantage,
title={Advantage Weighted Matching: Aligning RL with Pretraining in Diffusion Models},
author={Xue, Shuchen and Ge, Chongjian and Zhang, Shilong and Li, Yichen and Ma, Zhi-Ming},
journal={arXiv preprint arXiv:2509.25050},
year={2025}
}