A clean PyTorch implementation of the paper "Mean Flows for One-step Generative Modeling" by Geng et al., with on-the-fly FID evaluation.
Our goal is to provide a straightforward, clean PyTorch implementation of MeanFlow models for CIFAR-10 and MNIST, so that researchers can run experiments at minimal cost.
A good way to start is our Colab notebook, which walks you through the details of Mean Flows and trains a toy model on MNIST.
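The README does not write out the objective, so here is a minimal sketch of the MeanFlow training target and one-step sampler as we read them from the paper; it is not code taken from this repository. It assumes the linear path z_t = (1 - t)x + tε (data at t = 0, noise at t = 1) and a network called as `model(z, r, t)` that predicts the average velocity over [r, t]. The MeanFlow identity u = v - (t - r) du/dt needs the total time derivative of the network, which a single forward-mode JVP with tangents (v, 0, 1) for (z, r, t) provides. `meanflow_loss`, `sample_one_step` and `ToyAvgVelocity` are hypothetical names; the real `train_mf.py` may differ in time sampling, loss weighting and conditioning, and we only assume that its `--detach_tgt=1` flag controls the stop-gradient on the target.

```python
import torch
import torch.nn as nn
from torch.func import jvp


def meanflow_loss(model, x, t, r, detach_target=True):
    """MeanFlow regression loss (sketch following the paper's notation).

    model(z, r, t) predicts the average velocity u over [r, t];
    t and r have shape (B,) with 0 <= r <= t <= 1.
    """
    eps = torch.randn_like(x)
    # Reshape times so they broadcast over the non-batch dimensions of x.
    shape = (-1,) + (1,) * (x.dim() - 1)
    t_b, r_b = t.view(shape), r.view(shape)
    z = (1 - t_b) * x + t_b * eps  # assumed path: data at t=0, noise at t=1
    v = eps - x                    # conditional instantaneous velocity dz/dt
    # Total derivative du/dt = (d_z u) v + (d_r u) * 0 + (d_t u) * 1, via one JVP.
    u, dudt = jvp(model, (z, r, t), (v, torch.zeros_like(r), torch.ones_like(t)))
    u_tgt = v - (t_b - r_b) * dudt  # MeanFlow identity
    if detach_target:
        u_tgt = u_tgt.detach()      # stop-gradient on the regression target
    return ((u - u_tgt) ** 2).mean()


@torch.no_grad()
def sample_one_step(model, noise):
    """One-step generation: z_0 = z_1 - u(z_1, r=0, t=1)."""
    b = noise.shape[0]
    r = torch.zeros(b, device=noise.device)
    t = torch.ones(b, device=noise.device)
    return noise - model(noise, r, t)


class ToyAvgVelocity(nn.Module):
    """Hypothetical tiny MLP for flat data; not the repository's ddpmpp network."""

    def __init__(self, dim, hidden=64):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(dim + 2, hidden), nn.SiLU(), nn.Linear(hidden, dim))

    def forward(self, z, r, t):
        # Condition on both time endpoints by concatenating them to the input.
        return self.net(torch.cat([z, r[:, None], t[:, None]], dim=1))


if __name__ == "__main__":
    torch.manual_seed(0)
    model = ToyAvgVelocity(dim=4)
    x = torch.randn(16, 4)   # stand-in for a batch of data
    t = torch.rand(16)
    r = t * torch.rand(16)   # guarantees r <= t
    loss = meanflow_loss(model, x, t, r)
    loss.backward()
    print(loss.item(), sample_one_step(model, torch.randn(8, 4)).shape)
```

Sampling needs only one network evaluation per image, which is the point of the method.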
- 🧹 Clean implementation - Well-structured PyTorch codebase
- 📊 Real-time FID - Evaluation during training
- ⚡ Optimized - Multi-GPU training support
- 📝 Documented - Comprehensive docstrings and comments
- 🧠 Detailed explanation - A Jupyter notebook that walks you through every detail of Mean Flows.
| Researcher | Affiliation | Contact |
|---|---|---|
| Weijian Luo | Humane Intelligence (hi) Lab, Xiaohongshu Inc & Peking University | 📧 |
| Yifei Wang | Rice University | 📧 |
We welcome your input! Please reach out if you:
- Find any issues running the code
- Have suggestions for improvements
- Want to collaborate on extensions
```bash
git clone https://github.com/pkulwj1994/easy_meanflow.git
cd easy_meanflow
conda env create -f environment.yml
conda activate easy_meanflow
```
We prepared our datasets following the instructions in the StyleGAN repository.
The CIFAR-10 dataset can be downloaded directly:

```bash
wget https://huggingface.co/datasets/william94/useful_public_data/resolve/main/cifar10-32x32.zip
```
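Datasets in this format are plain zip archives of PNG images, optionally with a `dataset.json` file holding class labels. The sketch below is a quick sanity check on a downloaded archive. It assumes the EDM/StyleGAN layout `{"labels": [[filename, class_index], ...]}` (or `"labels": null` for unlabeled data), and `summarize_dataset_zip` is a hypothetical helper, not part of this repository.

```python
import json
import zipfile
from collections import Counter


def summarize_dataset_zip(path_or_file):
    """Count PNG images and per-class labels in a StyleGAN/EDM-style dataset zip."""
    with zipfile.ZipFile(path_or_file) as zf:
        names = zf.namelist()
        num_images = sum(1 for n in names if n.lower().endswith(".png"))
        labels = None
        if "dataset.json" in names:
            # Assumed layout: {"labels": [[filename, class_index], ...]} or {"labels": null}
            labels = json.loads(zf.read("dataset.json")).get("labels")
    # Tally how many images carry each class index (empty if unlabeled).
    class_counts = Counter(label for _, label in labels) if labels else Counter()
    return num_images, class_counts
```

For example, `summarize_dataset_zip("cifar10-32x32.zip")` returns the number of images and a per-class `Counter`. Because `zipfile.ZipFile` also accepts an open file object, the helper can be exercised on an in-memory archive.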
To compute FID, the generated images must be compared against statistics of the same dataset the model was trained on. To simplify evaluation, use the exact reference statistics released with EDM, available at: https://nvlabs-fi-cdn.nvidia.com/edm/fid-refs/.
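FID compares two Gaussians fitted to Inception features: FID = ||μ₁ - μ₂||² + Tr(Σ₁ + Σ₂ - 2(Σ₁Σ₂)^{1/2}). Below is a minimal NumPy sketch, assuming the reference `.npz` files store arrays under the keys `mu` and `sigma` (the keys EDM's `fid.py` reads; check the file you download). EDM itself uses `scipy.linalg.sqrtm` for the matrix square root; this sketch instead obtains the trace term from a symmetric eigendecomposition, which gives the same value for positive semidefinite covariances up to round-off. `frechet_distance` and `load_reference_stats` are hypothetical names.

```python
import numpy as np


def frechet_distance(mu1, sigma1, mu2, sigma2):
    """Fréchet distance between N(mu1, sigma1) and N(mu2, sigma2).

    Tr((sigma1 sigma2)^{1/2}) equals the sum of square roots of the eigenvalues of
    sqrt(sigma1) @ sigma2 @ sqrt(sigma1), a symmetric matrix with the same spectrum.
    """
    mu1, mu2 = np.asarray(mu1, np.float64), np.asarray(mu2, np.float64)
    sigma1, sigma2 = np.asarray(sigma1, np.float64), np.asarray(sigma2, np.float64)
    # Symmetric square root of sigma1 from its eigendecomposition.
    w, V = np.linalg.eigh(sigma1)
    sqrt_s1 = (V * np.sqrt(np.clip(w, 0, None))) @ V.T
    m = sqrt_s1 @ sigma2 @ sqrt_s1
    eig = np.linalg.eigvalsh((m + m.T) / 2)  # symmetrize against round-off
    tr_covmean = np.sqrt(np.clip(eig, 0, None)).sum()
    diff = mu1 - mu2
    return float(diff @ diff + np.trace(sigma1) + np.trace(sigma2) - 2 * tr_covmean)


def load_reference_stats(npz_path):
    """Load reference statistics (assumed keys: 'mu' and 'sigma')."""
    with np.load(npz_path) as ref:
        return ref["mu"], ref["sigma"]
```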
Once you have worked through the notebook, you can train a MeanFlow model on CIFAR-10 by running:
```bash
sh ./exps/MF00/train_script.sh
```

or run the command directly:
```bash
export PYTORCH_ENABLE_FUNC_IMPL=1 && \
export PYTORCH_DDP_NO_REBUILD_BUCKETS=1 && \
export TORCH_NCCL_IB_TIMEOUT=23 && \
export NCCL_TIMEOUT=3600 && \
export SETUPTOOLS_USE_DISTUTILS=local && \
torchrun --standalone --nproc_per_node=8 train_mf.py \
--detach_tgt=1 \
--outdir=logs/mf/MF00 \
--data=cifar10-32x32.zip \
--cond=0 --arch=ddpmpp --lr 10e-4 --batch 8
```
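`torchrun --standalone --nproc_per_node=8` starts eight worker processes on one machine and exports `RANK`, `LOCAL_RANK` and `WORLD_SIZE` (plus `MASTER_ADDR`/`MASTER_PORT`) to each of them. The sketch below shows the usual per-worker setup; `train_mf.py` presumably does something equivalent, but the helper names are hypothetical. Whether `--batch` means the global batch (as in EDM's `train.py`, where it is the total across GPUs) or a per-GPU batch is an assumption here.

```python
import os
from dataclasses import dataclass


@dataclass(frozen=True)
class DistInfo:
    rank: int        # global process index, 0 .. world_size - 1
    local_rank: int  # GPU index on this node
    world_size: int  # total number of processes


def read_torchrun_env(env=None):
    """Read the rank variables that torchrun exports to every worker."""
    env = os.environ if env is None else env
    return DistInfo(
        rank=int(env.get("RANK", 0)),
        local_rank=int(env.get("LOCAL_RANK", 0)),
        world_size=int(env.get("WORLD_SIZE", 1)),
    )


def per_gpu_batch(global_batch, world_size):
    """Split a global batch evenly across processes (assumes --batch is global)."""
    if global_batch % world_size != 0:
        raise ValueError(f"batch {global_batch} is not divisible by world size {world_size}")
    return global_batch // world_size


def init_ddp(info):
    """Bind this process to its GPU and join the NCCL process group."""
    # Imported here so the helpers above also work on machines without CUDA.
    import torch
    import torch.distributed as dist

    torch.cuda.set_device(info.local_rank)
    # The default env:// init method picks up MASTER_ADDR/MASTER_PORT from torchrun.
    dist.init_process_group(backend="nccl")
```

Under the global-batch assumption, `--batch 8` on 8 GPUs gives one sample per GPU. Note also that `10e-4` is 1e-3 (10 × 10⁻⁴), not 1e-4.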
The FID score is computed on the fly during training. We also provide a standalone script for computing the Fréchet Inception Distance (FID); simply run:

```bash
sh cal_fid.sh
```
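Computing FID on the fly means accumulating feature statistics over generated batches without keeping every feature vector in memory. Here is a minimal NumPy sketch of that accumulation in the style of EDM's `fid.py`, which sums features and their outer products and normalizes once at the end. `RunningFeatureStats` is a hypothetical name, and the feature extractor itself (EDM uses 2048-dimensional Inception-v3 features) is not included.

```python
import numpy as np


class RunningFeatureStats:
    """Streaming mean and covariance of feature vectors, batch by batch."""

    def __init__(self, dim):
        self.n = 0
        self.sum = np.zeros(dim, dtype=np.float64)
        self.sum_outer = np.zeros((dim, dim), dtype=np.float64)

    def update(self, feats):
        # feats: array of shape (batch, dim)
        feats = np.asarray(feats, dtype=np.float64)
        self.n += feats.shape[0]
        self.sum += feats.sum(axis=0)
        self.sum_outer += feats.T @ feats

    def finalize(self):
        # Unbiased covariance: (sum x x^T - n mu mu^T) / (n - 1), matching np.cov.
        mu = self.sum / self.n
        sigma = (self.sum_outer - self.n * np.outer(mu, mu)) / (self.n - 1)
        return mu, sigma
```

The resulting `mu` and `sigma` can then be compared against the EDM reference statistics with the Fréchet distance formula.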
We thank the authors of MeanFlow for their paper and their JAX implementation.
We are grateful to the authors of the EDM paper for sharing their code, which served as the foundational framework for this repository: NVlabs/edm. We also adapted some of the basic logic of the Diff-Instruct repository, pkulwj1994/diff_instruct. Additionally, we thank DeepSeek for helping us resolve some DDP bugs.

