Summary: Decoupled MeanFlow (DMF) converts flow transformers into flow maps for accelerated sampling. By fine-tuning the flow maps to predict the average velocity, we obtain state-of-the-art 1-step FID=2.16 on ImageNet 256 and FID=2.12 on ImageNet 512. By increasing the number of steps, we achieve FID=1.51 and FID=1.68 on ImageNet 256 and 512, respectively, matching the performance of flow models while significantly accelerating inference.
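For reference, the average velocity that the flow map predicts follows the MeanFlow formulation that DMF builds on (paraphrased here, not quoted from this repository): u(x_t, r, t) = 1/(t - r) * ∫_r^t v(x_s, s) ds, so that a single jump x_r = x_t - (t - r) u(x_t, r, t) replaces many small integration steps of the instantaneous velocity v.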
We provide our DMF models, base SiT-XL/2+REPA models, and dataset statistics (for evaluation) in our Hugging Face repository. One can download the model checkpoints and dataset statistics as follows:
pip install huggingface_hub
hf download kyungmnlee/DMF --local-dir ckpt
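If you prefer to download from Python instead of the CLI, here is a minimal sketch using huggingface_hub (same repository id and target directory as the command above):

# Minimal sketch: fetch the DMF checkpoints and dataset statistics from the Hub.
from huggingface_hub import snapshot_download

snapshot_download(repo_id="kyungmnlee/DMF", local_dir="ckpt")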
When using Hopper GPUs (e.g., H100, H200, H800), we recommend CUDA==12.8 for compatibility with Flash Attention v3. If using Ampere GPUs (e.g., A100), install Flash-Attention v2.

conda create -n dmf python=3.12 -y
conda activate dmf
pip install torch==2.8.0 torchvision==0.23.0
pip install -r requirements.txt

To install Flash Attention v3, we recommend installing from source by following the official GitHub repository of Flash Attention v3.
Currently, we provide experiments for ImageNet. You can place the data wherever you want and specify its location via the --data-dir argument in the training scripts. Please refer to our preprocessing guide.
We recommend fine-tuning from flow models; we provide the corresponding model weights and training code below. Here we give a skeleton command to launch training; see scripts/train_dmf_xl_2_256.sh and scripts/train_dmf_xl_2_512.sh for details.
accelerate launch train_dmf.py \
--model="DMFT-XL/2" \
--attn-func "fa3" \
--dmf-depth 20 \
--g-type "mg" \
--omega 0.6 \
--g-min 0.0 \
--g-max 0.7 \
--P-mean 0.0 \
--P-mean-t 0.4 \
--P-mean-r -1.2

This script will automatically create a folder in exps to save logs and checkpoints. You can adjust the following options:
- --model: one of [DMFT-B/2, DMFT-L/2, DMFT-XL/2].
- --attn-func: one of ["fa3", "fa2", "base"]. fa3 is Flash-Attention v3 with JVP support, fa2 is Flash-Attention v2 with JVP support, and base is naive attention.
- --g-type: one of ["default", "mg", "distill"]. mg trains with model guidance, default trains without CFG, and distill distills the CFG velocity from the teacher flow model. The default choice is mg.
- --dmf-depth: depth of the DMF model (usually about 2/3 of the total transformer depth).
- --omega: model guidance scale.
- --g-min, --g-max: guidance interval for mg (see the sketch after this list for intuition).
- --P-mean, --P-mean-t, --P-mean-r: time proposal distributions.
- --qk-norm: for the 512 experiments, we find that adding QK-norm during DMF training helps stabilize training (otherwise destabilized by unstable JVP computation).
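For intuition on how --omega, --g-min, and --g-max interact, here is a minimal sketch of interval-restricted guidance: a guided velocity is used only when the current time lies in [g_min, g_max], and the plain conditional prediction is used otherwise. The CFG-style mixing rule and the function below are illustrative assumptions, not the repository's exact model-guidance implementation.

import torch

def apply_interval_guidance(v_cond, v_uncond, t, omega=0.6, g_min=0.0, g_max=0.7):
    # Illustrative CFG-style mixing of conditional/unconditional velocity predictions.
    guided = v_cond + omega * (v_cond - v_uncond)
    # Apply guidance only for times inside the interval [g_min, g_max].
    inside = ((t >= g_min) & (t <= g_max)).view(-1, *([1] * (v_cond.dim() - 1)))
    return torch.where(inside, guided, v_cond)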
One can generate images with the following command; see scripts/eval_dmf_xl_2_256.sh and scripts/eval_dmf_xl_2_512.sh for details:
torchrun --nnodes=1 --nproc_per_node=8 generate.py \
--model "DMFT-XL/2" \
--num-fid-samples 50000 \
--ckpt YOUR_CHECKPOINT_PATH \
--dmf-depth=20 \
--per-proc-batch-size=64 \
--mode "euler" \
--num-steps 4 \
--shift 1.0

You can adjust the following options:
- --mode: one of ["euler", "restart"].
- --shift: time distribution shifting (use shift=1.5 for ImageNet 512); see the sampling sketch below for intuition.
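For intuition on how --num-steps and --shift interact during sampling, here is a minimal, self-contained sketch of a few-step Euler-style flow-map sampler using the time-shift reparameterization t' = s·t / (1 + (s − 1)·t) that is common in flow models. The flow_map signature, the shift formula, and the time convention (t=1 is noise, t=0 is data) are assumptions for illustration, not the repository's exact sampler.

import torch

def shift_time(t, shift=1.0):
    # Assumed time-shift reparameterization; shift > 1 spends more steps at higher noise levels.
    return shift * t / (1.0 + (shift - 1.0) * t)

@torch.no_grad()
def sample_euler(flow_map, x, num_steps=4, shift=1.0):
    # Illustrative few-step sampler: each step jumps from time t to the next time r
    # using the predicted average velocity over [r, t].
    ts = shift_time(torch.linspace(1.0, 0.0, num_steps + 1), shift)
    for t, r in zip(ts[:-1], ts[1:]):
        u = flow_map(x, t, r)   # hypothetical signature: average velocity over [r, t]
        x = x + (r - t) * u     # x_r = x_t - (t - r) * u
    return x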
After generating images, one can compute FID and FD-DINOv2 using the command below. We provide the dataset statistics imgnet256.pkl and imgnet512.pkl in our Hugging Face repository.
python calculate_metrics.py calc \
--images "SAMPLE_DIR" \
--ref="./ckpt/imgnet256.pkl"
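As a reminder of what the script computes, FID is the Fréchet distance between Gaussian fits of reference and sample feature statistics (FD-DINOv2 is the same distance on DINOv2 features). A minimal sketch of the distance itself is below; how calculate_metrics.py extracts features and lays out the .pkl file is not shown here and may differ.

import numpy as np
from scipy import linalg

def frechet_distance(mu1, sigma1, mu2, sigma2):
    # Fréchet distance between two Gaussians N(mu1, sigma1) and N(mu2, sigma2).
    diff = mu1 - mu2
    covmean, _ = linalg.sqrtm(sigma1 @ sigma2, disp=False)
    if np.iscomplexobj(covmean):
        covmean = covmean.real
    return float(diff @ diff + np.trace(sigma1 + sigma2 - 2.0 * covmean))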
We follow the setup in REPA, and we refactor the code for simplicity. The training scripts can be found in scripts/train_sit_xl_2_256.sh and scripts/train_sit_xl_2_512.sh, and we provide pretrained checkpoints in our Hugging Face repository.

It is possible that this code does not exactly replicate the results reported in the paper due to potential human error during the preparation and cleaning of the code for release. If you encounter any difficulties reproducing our findings, please let us know. We will also make an effort to carry out sanity-check experiments in the near future.