[NeurIPS 2025] Adaptive Stochastic Coefficients for Accelerating Diffusion Sampling
Official implementation of the NeurIPS 2025 paper
Abstract: Diffusion-based generative processes, formulated as differential equation solving, frequently balance computational speed with sample quality. Our theoretical investigation of ODE- and SDE-based solvers reveals complementary weaknesses: ODE solvers accumulate irreducible gradient error along deterministic trajectories, while SDE methods suffer from amplified discretization errors when the step budget is limited. Building upon this insight, we introduce AdaSDE, a novel single-step SDE solver that aims to unify the efficiency of ODEs with the error resilience of SDEs. Specifically, we introduce a single per-step learnable coefficient, estimated via lightweight distillation, which dynamically regulates the error correction strength to accelerate diffusion sampling. Notably, our framework can be integrated with existing solvers to enhance their capabilities. Extensive experiments demonstrate state-of-the-art performance: at 5 NFE, AdaSDE achieves FID scores of 4.18 on CIFAR-10, 8.05 on FFHQ and 6.96 on LSUN Bedroom.
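For intuition only, the sketch below illustrates (in a heavily simplified form, not the paper's exact update rule) how a learnable per-step coefficient can regulate the strength of the stochastic error correction added on top of a deterministic solver step; every function and variable name here is an illustrative assumption, not code from this repository.

```python
import math
import torch

# Conceptual sketch only (not the repository's implementation): one Euler-Maruyama-style
# step in which a learnable per-step coefficient gamma_n scales the stochastic correction
# added on top of the deterministic update. All names are illustrative assumptions.
def sde_step_with_learnable_coeff(x, t_cur, t_next, denoiser, gamma_n):
    d_cur = (x - denoiser(x, t_cur)) / t_cur        # EDM-style probability-flow drift
    x_det = x + (t_next - t_cur) * d_cur            # deterministic (ODE) part of the step
    noise_std = gamma_n * math.sqrt(max(t_cur**2 - t_next**2, 0.0))
    return x_det + noise_std * torch.randn_like(x)  # stochastic correction, strength set by gamma_n
```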
- This codebase is mainly built on the EDM codebase; please refer to EDM for installing the required packages.
- This codebase supports pre-trained diffusion models from EDM, ADM, Consistency Models, LDM and Stable Diffusion. To load a pre-trained model from one of these codebases, refer to the corresponding codebase for package installation.
- Run the commands in `launch.sh` for training, sampling and evaluation with our recommended configurations.
- All commands support multi-GPU parallelization by adjusting the `--nproc_per_node` parameter.
- Complete parameter descriptions are available in the next section.
- Required models will be automatically downloaded to "./src/dataset_name".
- Default configurations use 2 RTX 4090 GPUs for CIFAR-10, FFHQ and ImageNet, and 4 A100 GPUs for LSUN Bedroom and Stable Diffusion.
- Adjust the batch size according to your hardware capabilities.
Important Note:
The num_steps parameter specifies the number of original timestamps. AdaSDE inserts a new timestamp between each pair of existing ones, so:
- When num_steps=4, the total becomes 7 timestamps (6 sampling steps)
- NFE = 5 if afs=True, otherwise 6
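For concreteness, the small sketch below reproduces this bookkeeping; it is an illustration of the arithmetic described above, not a function from this repository.

```python
# Minimal sketch of the timestamp/NFE bookkeeping described above (illustration only).
def effective_nfe(num_steps: int, afs: bool) -> int:
    total_timestamps = 2 * num_steps - 1    # one new timestamp inserted between each existing pair
    sampling_steps = total_timestamps - 1   # num_steps=4 -> 7 timestamps -> 6 sampling steps
    return sampling_steps - 1 if afs else sampling_steps  # AFS saves the first model evaluation

print(effective_nfe(4, afs=True))   # 5
print(effective_nfe(4, afs=False))  # 6
```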
```bash
# AdaSDE configuration (CIFAR-10)
SOLVER_FLAGS="--sampler_stu=adasde --sampler_tea=dpm --num_steps=3 --M=3 --afs=True --scale_dir=0.05 --scale_time=0.2 --gamma=0.02 --seed=0 --lr=0.2 --coslr"
SCHEDULE_FLAGS="--schedule_type=time_uniform --schedule_rho=1"
torchrun --standalone --nproc_per_node=2 --master_port=11111 \
    train.py --dataset_name="cifar10" --batch=32 --total_kimg=10 $SOLVER_FLAGS $SCHEDULE_FLAGS
```
SOLVER_FLAGS="--sampler_stu=adasde --sampler_tea=dpm --num_steps=4 --M=3 --afs=True --scale_dir=0.05 --scale_time=0.2 --seed=0 --lr 0.01"
SCHEDULE_FLAGS="--schedule_type=discrete --schedule_rho=1"
ADDITIONAL_FLAGS="--max_order=2 --predict_x0=False --lower_order_final=True"
GUIDANCE_FLAGS="--guidance_type=cfg --guidance_rate=7.5"
torchrun --standalone --nproc_per_node=4 --master_port=11111 \
train.py --dataset_name="ms_coco" --batch=32 --total_kimg=5 $SOLVER_FLAGS $SCHEDULE_FLAGS $ADDITIONAL_FLAGS $GUIDANCE_FLAGS# Generate 50K samples for FID evaluation
```bash
# Generate 50K samples for FID evaluation
torchrun --standalone --nproc_per_node=1 --master_port=22222 \
    sample.py --predictor_path=0 --batch=128 --seeds="0-49999"
```
The generated images will be stored at "./samples" by default. To compute the Fréchet Inception Distance (FID) for a given model and sampler, compare the generated 50K images against the dataset reference statistics using fid.py:
```bash
# FID evaluation
python fid.py calc --images=path/to/images --ref=path/to/fid/stat
```
| Category | Parameter | Default | Description |
|---|---|---|---|
| General Options | dataset_name | None | Supported datasets: ['cifar10', 'ffhq', 'afhqv2', 'imagenet64', 'lsun_bedroom', 'imagenet256', 'lsun_bedroom_ldm', 'ms_coco'] |
| | predictor_path | None | Path or experiment number of trained AdaSDE predictor |
| | batch | 64 | Total batch size |
| | seeds | "0-63" | Random seed range for image generation |
| | grid | False | Organize output images in a grid layout |
| | total_kimg | 10 | Training duration (in thousands of images) |
| | scale_dir | 0.05 | Gradient direction scale (c_n in paper). Range: [1-scale_dir, 1+scale_dir] |
| | scale_time | 0.2 | Input time scale (a_n in paper). Range: [1-scale_time, 1+scale_time] |
| Solver Flags | sampler_stu | 'adasde' | Student solver: ['adasde'] |
| | sampler_tea | 'dpm' | Teacher solver type |
| | num_steps | 4 | Initial timestamps for student solver. Final steps = 2*(num_steps-1) |
| | M | 3 | Intermediate steps inserted between teacher solver steps |
| | afs | True | Enable Accelerated First Step (saves the initial model evaluation) |
| Schedule Flags | schedule_type | 'polynomial' | Time discretization: ['polynomial', 'logsnr', 'time_uniform', 'discrete'] |
| | schedule_rho | 7 | Time step exponent (required for 'polynomial', 'time_uniform', 'discrete') |
| Additional Flags | predict_x0 | True | DPM-Solver++: use data prediction formulation |
| | lower_order_final | True | DPM-Solver++: reduce order at final sampling stages |
| Guidance Flags | guidance_type | None | Guidance method: ['cg' (classifier), 'cfg' (classifier-free), 'uncond' (unconditional), None] |
| | guidance_rate | None | Guidance strength parameter |
| | prompt | None | Text prompt for Stable Diffusion sampling |
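As a reading aid for the scale_dir and scale_time entries above, here is a minimal sketch of how an unconstrained learnable value can be mapped into the documented ranges. The tanh squashing and the names raw_dir / raw_time are assumptions for illustration; the repository's actual parameterization may differ.

```python
import torch

# Illustrative only: map unconstrained learnable values into the ranges listed in the table.
scale_dir, scale_time = 0.05, 0.2
raw_dir = torch.zeros(1, requires_grad=True)    # one learnable value per sampling step (assumed)
raw_time = torch.zeros(1, requires_grad=True)

c_n = 1.0 + scale_dir * torch.tanh(raw_dir)     # gradient direction scale in [1 - scale_dir, 1 + scale_dir]
a_n = 1.0 + scale_time * torch.tanh(raw_time)   # input time scale in [1 - scale_time, 1 + scale_time]
```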
Key Metric: All numbers are measured at 5 NFE (ultra-low latency).
At equal NFE, AdaSDE consistently achieves the best FID and outperforms the best non-AdaSDE method on every dataset.
| Dataset | AdaSDE @5 NFE (FID ↓) | Best non-AdaSDE (method) | FID ↓ | Absolute gain (ΔFID) | Relative gain |
|---|---|---|---|---|---|
| CIFAR-10 (32×32) | 4.18 | AMED-Solver | 7.59 | −3.41 | −44.9% |
| FFHQ-64 | 8.05 | AMED-Solver | 14.80 | −6.75 | −45.6% |
| ImageNet-64 | 6.90 | AMED-Solver | 10.74 | −3.84 | −35.8% |
| LSUN Bedroom (256×256) | 6.96 | AMED-Solver | 13.20 | −6.24 | −47.3% |
We provide pre-trained AdaSDE predictors for:
- CIFAR-10 main result
- CIFAR-10 ablation result
The pre-trained models are available in ./exp/.
Folder naming format:
<exp_num>-<dataset_name>-<num_steps>-<NFE>-<student>-<schedule>-<afs>
Usage:
Run sampling with:
```bash
# Generate 50K samples for FID evaluation
torchrun --standalone --nproc_per_node=1 --master_port=22222 \
    sample.py --predictor_path=EXP_NUM --batch=128 --seeds="0-49999"
```
We perform sampling on a variety of pre-trained diffusion models from different codebases, including EDM, ADM, Consistency Models, LDM and Stable Diffusion. The tested pre-trained models are listed below:
| Codebase | dataset_name | Resolution | Pre-trained Models | Description |
|---|---|---|---|---|
| EDM | cifar10 | 32 | edm-cifar10-32x32-uncond-vp.pkl | |
| EDM | ffhq | 64 | edm-ffhq-64x64-uncond-vp.pkl | |
| EDM | afhqv2 | 64 | edm-afhqv2-64x64-uncond-vp.pkl | |
| EDM | imagenet64 | 64 | edm-imagenet-64x64-cond-adm.pkl | |
| Consistency Models | lsun_bedroom | 256 | edm_bedroom256_ema.pt | Pixel-space |
| ADM | imagenet256 | 256 | 256x256_diffusion.pt and 256x256_classifier.pt | Classifier-guidance. |
| LDM | lsun_bedroom_ldm | 256 | lsun_bedrooms.zip | Latent-space |
| Stable Diffusion | ms_coco | 512 | stable-diffusion-v1-5 | Classifier-free-guidance |
To facilitate FID evaluation of diffusion models, we provide FID reference statistics for various datasets. They were either collected from the Internet or computed by ourselves following the EDM codebase.
You can compute the reference statistics for your own datasets as follows:
```bash
python fid.py ref --data=path/to/my-dataset.zip --dest=path/to/save/my-dataset.npz
```
If you find this repository useful, please consider citing the following paper:
```bibtex
@inproceedings{wang2025adaptivestochasticcoefficientsaccelerating,
  title={Adaptive Stochastic Coefficients for Accelerating Diffusion Sampling},
  author={Ruoyu Wang and Beier Zhu and Junzhi Li and Liangyu Yuan and Chi Zhang},
  year={2025},
  booktitle={Proc. NeurIPS}
}
```


