This is a PyTorch+GPU re-implementation of the CIFAR-10 experiments in *Mean Flows for One-step Generative Modeling*. The original experiments were done in JAX+TPU.
This repo was tested with PyTorch 2.7.1 and uses torch.compile. Compilation behavior may depend on the PyTorch version.
conda env create -f environment.yml
conda activate meanflow
Run demo.ipynb for a demo of 1-step generation and FID evaluation. The demo should produce an FID below 2.9.
Run the script cifar10_v1.sh to train from scratch with 8 GPUs.
This is an improved configuration that approaches ~2.9 FID at 16000 epochs (800k iterations with a batch size of 128x8). Training takes 0.21 s/iter on 8x H200 GPUs. The checkpoint used in demo.ipynb (~2.80 FID) was trained with this script.
The original configuration used in the paper is in cifar10_v0.sh.
Users may be unfamiliar with the JVP (Jacobian-vector product) operation, which MeanFlow is based on. While JVP is straightforward to implement in JAX, its correct implementation in PyTorch is worth a closer look.
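As a quick illustration of what JVP computes (a toy sketch, not code from this repo): torch.func.jvp(f, primals, tangents) evaluates f at the primals and, in the same pass, the directional derivative of f along the given tangents.

```python
import torch

# Toy stand-in for the network u(z, t, r); not the repo's model.
def u(z, t, r):
    return z * (t - r)

z, v = torch.randn(4), torch.randn(4)        # point and tangent for z
t, r = torch.tensor(0.5), torch.tensor(0.1)  # scalar time inputs
dt, dr = torch.ones(()), torch.zeros(())     # tangents for t and r

# u_out == u(z, t, r); du is d/ds u(z + s*v, t + s*dt, r + s*dr) at s = 0.
u_out, du = torch.func.jvp(u, (z, t, r), (v, dt, dr))
```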
The op torch.func.jvp does not support a DDP (DistributedDataParallel) object. In your code, you may need to replace model with model.module to allow torch.func.jvp to run. However, doing so may bypass the gradient synchronization normally handled by DDP, with no error reported.
In our code, we handle this with synchronize_gradients(model), together with a sanity check, gradient_sanity_check.
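A minimal sketch of what such a manual synchronization can look like (the actual synchronize_gradients in this repo may differ; this version simply sums and averages gradients across ranks after backward):

```python
import torch
import torch.distributed as dist

def synchronize_gradients(model):
    # Average gradients across ranks manually, since calling forward on
    # model.module bypasses DDP's own reducer. Call this after backward()
    # and before optimizer.step().
    world_size = dist.get_world_size()
    for param in model.parameters():
        if param.grad is not None:
            dist.all_reduce(param.grad, op=dist.ReduceOp.SUM)
            param.grad.div_(world_size)
```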
The memory usage and speed of JVP benefit greatly from compilation, in both JAX and PyTorch. In our code, this is done by:
compiled_train_step = torch.compile(
    train_step,
    disable=not args.compile,
)
where train_step is:
def train_step(model_without_ddp, *args, **kwargs):
    loss = model_without_ddp.forward_with_loss(*args, **kwargs)
    # Gradients land on the unwrapped module, bypassing DDP's reducer;
    # they are synchronized manually afterwards (see synchronize_gradients above).
    loss.backward(create_graph=False)
    return loss
Optionally, we also put update_ema() into train_step so that it is compiled as well; a sketch is given below.
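For example, a variant of train_step along these lines keeps the EMA update inside the compiled region. Here update_ema and the extra ema_model argument are placeholders assumed for this sketch; your own EMA utility may differ:

```python
import torch

@torch.no_grad()
def update_ema(ema_model, model, decay=0.9999):
    # Hypothetical EMA helper: ema <- decay * ema + (1 - decay) * online weights.
    for ema_p, p in zip(ema_model.parameters(), model.parameters()):
        ema_p.lerp_(p, 1.0 - decay)

def train_step(model_without_ddp, ema_model, *args, **kwargs):
    loss = model_without_ddp.forward_with_loss(*args, **kwargs)
    loss.backward(create_graph=False)
    update_ema(ema_model, model_without_ddp)  # EMA update compiled with the step
    return loss
```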
If you don't want to compile (for example, because some of your ops are not supported), we recommend computing dudt with torch.func.jvp under torch.no_grad():
u_pred = u_func(z, t, r)
with torch.no_grad():
    _, dudt = torch.func.jvp(u_func, (z, t, r), (v, dtdt, drdt))
The function prediction u_pred is computed separately. Computed this way, dudt does not introduce substantial additional memory usage, and its time cost is roughly that of one forward plus one backward pass. If you want both evaluations of u_func to share the same dropout masks, consider backing up the RNG states with cpu_rng_state = torch.get_rng_state() and cuda_rng_state = torch.cuda.get_rng_state() before the first call to u_func, and restoring them with torch.set_rng_state(cpu_rng_state) and torch.cuda.set_rng_state(cuda_rng_state) after it, i.e., before the jvp call, as sketched below.
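Concretely, the no-compile path with shared dropout masks can be sketched as follows (u_func, z, t, r, v, dtdt, drdt are the same names as in the snippet above):

```python
# Back up RNG states so both evaluations of u_func draw the same dropout masks.
cpu_rng_state = torch.get_rng_state()
cuda_rng_state = torch.cuda.get_rng_state()

u_pred = u_func(z, t, r)  # prediction, with gradients

# Restore RNG states before the second evaluation inside jvp.
torch.set_rng_state(cpu_rng_state)
torch.cuda.set_rng_state(cuda_rng_state)

with torch.no_grad():
    _, dudt = torch.func.jvp(u_func, (z, t, r), (v, dtdt, drdt))
```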
This repo is based on the following repos:
See also:
- Our MeanFlow JAX repo with ImageNet experiments.
- A third-party MeanFlow PyTorch repo with reproduced ImageNet results.

