
Easy MeanFlow (PyTorch) 🌊

A clean PyTorch implementation of the paper "Mean Flows for One-step Generative Modeling" by Geng et al., with on-the-fly FID evaluation.

Our goal is to provide a straightforward and clean PyTorch implementation of Mean Flow models for CIFAR-10 and MNIST, so that researchers can conduct experiments at minimal cost.

A good way to start is to play with our Colab notebook, which walks you through the details of MeanFlow and trains a toy model on MNIST.


🚀 Features

  • 🧹 Clean implementation - Well-structured PyTorch codebase
  • 📊 Real-time FID - Evaluation during training
  • ⚡ Optimized - Multi-GPU training support
  • 📝 Documented - Comprehensive docstrings and comments
  • 🧠 Detailed explanation - A Jupyter notebook that walks you through every detail of Mean Flows

👨‍💻 Core Contributors

| Researcher | Affiliation | Contact |
| --- | --- | --- |
| Weijian Luo | Humane Intelligence (hi) Lab, Xiaohongshu Inc. & Peking University | 📧 |
| Yifei Wang | Rice University | 📧 |

💌 Call for Feedback

We welcome your input! Please reach out if you:

  • Find any issues running the code
  • Have suggestions for improvements
  • Want to collaborate on extensions


Environment Setup

git clone https://github.com/pkulwj1994/easy_meanflow.git
cd easy_meanflow

conda env create -f environment.yml
conda activate easy_meanflow

Preparing datasets

We prepare our datasets following the instructions in StyleGAN.

The CIFAR-10 dataset can be downloaded directly with:

wget https://huggingface.co/datasets/william94/useful_public_data/resolve/main/cifar10-32x32.zip

To calculate the FID score, the generated images are compared against the same dataset the model was trained on. To facilitate evaluation, use the exact reference statistics from EDM, available at: https://nvlabs-fi-cdn.nvidia.com/edm/fid-refs/.

Getting started

A good way to start is to play with our Colab notebook, which walks you through the details of MeanFlow and trains a toy model on MNIST.
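In brief, MeanFlow learns the average velocity u(z_t, r, t) over an interval [r, t], trained against a target derived from the MeanFlow identity u = v − (t − r) · du/dt, where the total derivative d/dt u = v·∂_z u + ∂_t u is obtained with a Jacobian-vector product and the target is detached (cf. the `--detach_tgt` flag). A minimal, hedged sketch of that loss on toy 2-D data — `TinyU`, its architecture, and the (z, r, t) calling convention are illustrative stand-ins, not the repo's actual U-Net:

```python
import torch
from torch.func import jvp

class TinyU(torch.nn.Module):
    # Hypothetical stand-in for the real network: maps (z, r, t) -> average velocity u.
    def __init__(self, dim=2):
        super().__init__()
        self.net = torch.nn.Sequential(
            torch.nn.Linear(dim + 2, 64), torch.nn.SiLU(), torch.nn.Linear(64, dim))

    def forward(self, z, r, t):
        return self.net(torch.cat([z, r[:, None], t[:, None]], dim=-1))

def meanflow_loss(model, x):
    """One MeanFlow training step on a batch x (data at t=0, noise at t=1)."""
    eps = torch.randn_like(x)
    t = torch.rand(x.shape[0])
    r = torch.rand(x.shape[0]) * t                       # sample r <= t
    z = (1 - t[:, None]) * x + t[:, None] * eps          # linear interpolant z_t
    v = eps - x                                          # instantaneous velocity dz/dt
    # JVP along (dz/dt, dr/dt, dt/dt) = (v, 0, 1) gives the total derivative d/dt u.
    u, du_dt = jvp(lambda z_, r_, t_: model(z_, r_, t_),
                   (z, r, t), (v, torch.zeros_like(r), torch.ones_like(t)))
    u_tgt = (v - (t - r)[:, None] * du_dt).detach()      # stop-gradient target
    return ((u - u_tgt) ** 2).mean()
```

Note how the loss reduces to plain flow matching when r = t: the correction term vanishes and the target is just the instantaneous velocity v.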

After that, if you want to train a MeanFlow model on CIFAR-10, simply run:

sh ./exps/MF00/train_script.sh

or directly run

export PYTORCH_ENABLE_FUNC_IMPL=1 && \
export PYTORCH_DDP_NO_REBUILD_BUCKETS=1 && \
export TORCH_NCCL_IB_TIMEOUT=23 && \
export NCCL_TIMEOUT=3600 && \
export SETUPTOOLS_USE_DISTUTILS=local && \
torchrun --standalone --nproc_per_node=8 train_mf.py \
    --detach_tgt=1 \
    --outdir=logs/mf/MF00 \
    --data=cifar10-32x32.zip \
    --cond=0 --arch=ddpmpp --lr 10e-4 --batch 8

FID score is computed on the fly.
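Once trained, sampling needs no ODE solver: the model pushes pure noise to data in a single step via z_r = z_t − (t − r) · u(z_t, r, t) with r = 0, t = 1. A hedged sketch of that rule — the (z, r, t) calling convention is an assumption; check the repo's model signature:

```python
import torch

@torch.no_grad()
def sample_one_step(model, num_samples, dim):
    """One-step generation: z_0 = z_1 - (t - r) * u(z_1, r, t) at r=0, t=1."""
    z1 = torch.randn(num_samples, dim)   # z at t = 1 (pure noise)
    r = torch.zeros(num_samples)         # target time r = 0 (data)
    t = torch.ones(num_samples)          # start time t = 1
    return z1 - (t - r)[:, None] * model(z1, r, t)
```

The same rule also supports few-step sampling by splitting [0, 1] into several (r, t) sub-intervals.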

Calculating FID

We also provide scripts for computing the Fréchet inception distance (FID); simply run:

sh cal_fid.sh
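Under the hood, FID is the Fréchet distance between two Gaussians fitted to Inception features of the real and generated images: FID = ‖μ₁ − μ₂‖² + Tr(Σ₁ + Σ₂ − 2(Σ₁Σ₂)^{1/2}). A minimal NumPy/SciPy sketch of that formula — the repo's actual script also handles feature extraction, and the function name here is ours:

```python
import numpy as np
from scipy import linalg

def frechet_distance(mu1, sigma1, mu2, sigma2):
    """Fréchet distance between N(mu1, sigma1) and N(mu2, sigma2)."""
    diff = mu1 - mu2
    # Matrix square root of the covariance product; disp=False returns (sqrtm, errest).
    covmean, _ = linalg.sqrtm(sigma1 @ sigma2, disp=False)
    covmean = covmean.real  # discard tiny imaginary parts from numerical error
    return float(diff @ diff + np.trace(sigma1 + sigma2 - 2.0 * covmean))
```

Identical statistics give a distance of zero, which is a quick sanity check when wiring up the reference .npz files from EDM.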

Results on CIFAR10

(Figure: FID curve on CIFAR-10.)

Acknowledgements

We are thankful to the authors of MeanFlow, as well as their JAX implementation.

We extend our gratitude to the authors of the EDM paper for sharing their code, which served as the foundational framework for this repository: NVlabs/edm. We also borrow some basic logic from the Diff-Instruct repo pkulwj1994/diff_instruct. Additionally, we thank DeepSeek for helping us resolve some DDP bugs.
