The official repository for **Token Pruning for Caching Better: 9× Acceleration on Stable Diffusion for Free**.
This repository contains the core code for DaTo on Stable Diffusion 1.x, including:
- Dynamic token pruning and attention patching
- Multi-objective search for depth-wise cache strategies
- Multi-objective search for per-step pruning ratios
Requirements:
- Python 3.8+
- PyTorch with CUDA support
- `pymoo` for multi-objective optimization
- `pytorch-fid` for FID computation (optional)
- `ImageReward` for image–text reward scoring
- A Stable Diffusion 1.x checkpoint
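A typical environment setup might look like this (the package names are assumptions based on the usual PyPI distributions; match the `torch` build to your CUDA version):

```bash
# Assumed PyPI package names; pin versions as needed for your setup.
pip install torch torchvision pymoo pytorch-fid image-reward
```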
Main DaTo-related components:
```
sd1/
└── scripts/
    ├── datosd/
    │   ├── __init__.py
    │   ├── dato.py        # Core DaTo logic: token pruning, integration with SD
    │   ├── patch.py       # Patching/hooking Stable Diffusion modules
    │   └── utils.py       # Helper utilities
    ├── search_depth.py    # Search over cache strategies across timesteps
    └── search_ratio.py    # Search over per-step pruning ratios
```
- Core implementation: the directory `sd1/scripts/datosd` contains the main DaTo implementation.
- Search scripts:
  - `search_depth.py`: uses NSGA-II to search over depth-wise cache strategies for each diffusion step, jointly optimizing image quality (via ImageReward) and inference time (average generation time).
  - `search_ratio.py`: uses NSGA-II to search over per-step pruning ratios, balancing FID / ImageReward quality against speed-up.
DaTo (Dynamic Token Optimization) reduces redundant tokens during the diffusion sampling process to accelerate inference while maintaining generation quality.
At a high level:
- DaTo applies dynamic token pruning to the relevant attention layers of Stable Diffusion (see the sketch after this list).
- It can be combined with caching strategies to reuse intermediate features at certain timesteps.
- The actual token-pruning and cache schedules are obtained from the multi-objective search procedures implemented in the search scripts.
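As a rough illustration of the pruning step (a sketch only; `prune_tokens` and `restore_tokens` are hypothetical helpers, not the repository's API): score the tokens, keep the top fraction, run attention on the kept subset, and refill the pruned positions from previously computed features.

```python
# Minimal sketch of token pruning around an attention call.
# prune_tokens / restore_tokens are hypothetical, not the repo's API.
import torch

def prune_tokens(x: torch.Tensor, keep_ratio: float):
    """Keep the top-k tokens per batch element, scored by feature norm."""
    B, N, C = x.shape
    k = max(1, int(N * keep_ratio))
    scores = x.norm(dim=-1)                  # (B, N) importance proxy
    idx = scores.topk(k, dim=1).indices      # (B, k) kept token ids
    kept = x.gather(1, idx.unsqueeze(-1).expand(-1, -1, C))
    return kept, idx

def restore_tokens(kept: torch.Tensor, idx: torch.Tensor, full: torch.Tensor):
    """Scatter processed tokens back; pruned positions keep the values in
    `full` (e.g., features cached from an earlier timestep)."""
    out = full.clone()
    out.scatter_(1, idx.unsqueeze(-1).expand(-1, -1, kept.shape[-1]), kept)
    return out

# Usage: run (stand-in) attention only over the kept tokens.
x = torch.randn(2, 64, 32)                   # (B, N, C) dummy token features
kept, idx = prune_tokens(x, keep_ratio=0.5)
processed = kept + 1.0                       # stand-in for attention(kept)
x = restore_tokens(processed, idx, full=x)   # (2, 64, 32) again
```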
Core logic: implemented in `sd1/scripts/datosd/dato.py`, with the module patching handled by `patch.py`.
Script: `sd1/scripts/search_depth.py`
This script uses pymoo’s NSGA-II algorithm to search over a cache-strategy vector of length `steps` (50 by default), where each dimension encodes the cache behavior for a specific diffusion step.
For each candidate strategy, the script:
- Randomly samples a set of prompts from a text file.
- Calls `txt2img_gen` to generate images with:
  - deep cache enabled, and
  - a specific `cache_strategy`.
- Computes the ImageReward score and average inference time.
- Uses NSGA-II to evolve the population and obtain a Pareto front over:
  - maximizing ImageReward, and
  - minimizing inference time.
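A minimal pymoo sketch of this setup (with `evaluate_strategy` as a dummy stand-in for the real generate-and-score loop over prompts):

```python
# NSGA-II over integer-coded cache strategies with pymoo.
import numpy as np
from pymoo.core.problem import ElementwiseProblem
from pymoo.algorithms.moo.nsga2 import NSGA2
from pymoo.optimize import minimize

def evaluate_strategy(strategy: np.ndarray):
    # Dummy objectives; the real script would call txt2img_gen with this
    # cache strategy and measure ImageReward plus average generation time.
    reward = -np.abs(strategy - 2).mean()
    time_s = 1.0 + 0.05 * strategy.sum()
    return reward, time_s

class CacheStrategyProblem(ElementwiseProblem):
    def __init__(self, steps: int = 50):
        # One cache decision per diffusion step.
        super().__init__(n_var=steps, n_obj=2, xl=0, xu=4)

    def _evaluate(self, x, out, *args, **kwargs):
        strategy = np.round(x).astype(int)   # decode to discrete choices
        reward, time_s = evaluate_strategy(strategy)
        out["F"] = [-reward, time_s]         # pymoo minimizes both entries

res = minimize(CacheStrategyProblem(), NSGA2(pop_size=20), ("n_gen", 10), seed=1)
print(res.X.shape, res.F.shape)              # Pareto set and Pareto front
```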
Example usage (adapt `root_dir`, dataset paths, etc., inside the script):

```bash
CUDA_VISIBLE_DEVICES=0 python sd1/scripts/search_depth.py
```

The final result contains:
- the best (Pareto-optimal) strategies in `res.X`, and
- the corresponding objective values in `res.F` (quality vs. time).
These strategies can then be used as the cache configuration in DaTo.
Script: `sd1/scripts/search_ratio.py`
This script searches for a pruning-ratio vector of length `steps` (50 by default). The ratios are encoded as discrete integers (e.g., 3–7) and then mapped to real-valued pruning rates (e.g., 0.3–0.7), as in the decoding example below.
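Under that encoding, a candidate genome decodes as follows (illustrative values only):

```python
# Illustrative decoding of integer genes into per-step pruning rates.
genes = [3, 5, 7, 4]                     # one integer per diffusion step
prune_ratio = [g / 10.0 for g in genes]  # 3-7 maps onto 0.3-0.7
print(prune_ratio)                       # [0.3, 0.5, 0.7, 0.4]
```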
For each candidate pruning strategy, the script:
- Randomly samples prompts from a text file.
- Calls `txt2img_gen` with:
  - deep cache enabled,
  - `enable_dato=True`, and
  - `prune_ratio` set to the candidate strategy.
- Evaluates a combination of quality (e.g., FID or ImageReward) and speed.
The objective typically includes:
- a term related to FID / quality, and
- a term related to inverse runtime.
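Roughly, each candidate's fitness can be assembled as below (a sketch under assumed names and measurements; the script's exact weighting may differ):

```python
# Two objectives for one candidate ratio schedule; fid, runtime_s, and
# baseline_s are assumed to have been measured beforehand.
def objectives(fid: float, runtime_s: float, baseline_s: float):
    speedup = baseline_s / runtime_s     # inverse-runtime term
    return [fid, -speedup]               # NSGA-II minimizes both entries

print(objectives(fid=18.4, runtime_s=0.6, baseline_s=1.8))  # [18.4, -3.0]
```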
Example usage (again, adapt the paths inside the script):

```bash
CUDA_VISIBLE_DEVICES=0 python sd1/scripts/search_ratio.py
```

The resulting `res.X` contains candidate pruning-ratio schedules that can be plugged into DaTo for deployment.
If you find this repository useful, please consider citing:
@article{zhang2024token,
title={Token pruning for caching better: 9 times acceleration on stable diffusion for free},
author={Zhang, Evelyn and Xiao, Bang and Tang, Jiayi and Ma, Qianli and Zou, Chang and Ning, Xuefei and Hu, Xuming and Zhang, Linfeng},
journal={arXiv preprint arXiv:2501.00375},
year={2024}
}