
Learnable Sparsity for Vision Generative Models

Authors: Yang Zhang, Er Jin, Wenzhong Liang, Yanfei Dong, Ashkan Khakzar, Philip Torr, Johannes Stegmaier, Kenji Kawaguchi


Official implementation of the ICLR 2026 paper "Learnable Sparsity for Vision Generative Models", a novel approach to memory-efficient diffusion model pruning.

TL;DR: A model-agnostic structural pruning framework that achieves up to 20% parameter reduction with minimal performance loss through differentiable mask learning and time step gradient checkpointing.

(Teaser figure)

Table of Contents

  1. Overview
  2. Installation
  3. Quick Start
  4. Advanced Usage
  5. Configuration Files
  6. Development
  7. Repository Structure
  8. Models
  9. Model Weights
  10. Citation
  11. License
  12. Acknowledgments

Overview

(Method overview figure)

EcoDiff introduces a model-agnostic structural pruning framework that learns differentiable masks to sparsify diffusion models. Key innovations include:

  • ✨ Model-agnostic pruning for various diffusion architectures
  • 🧪 Differentiable mask learning allowing end-to-end optimization
  • 🧡 Time step gradient checkpointing for memory-efficient training
  • 📉 Up to 20% parameter reduction with minimal performance loss
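The differentiable-mask idea can be illustrated with a toy PyTorch module. This is an illustrative sketch, not the repository's actual implementation: the straight-through estimator here merely stands in for whatever relaxation the `hard_discrete` mask option uses.

```python
import torch
import torch.nn as nn

class MaskedLinear(nn.Module):
    """Linear layer with a learnable per-channel pruning mask (toy sketch)."""
    def __init__(self, in_features, out_features):
        super().__init__()
        self.linear = nn.Linear(in_features, out_features)
        # One logit per output channel; sigmoid(logit) is a soft keep-probability.
        self.mask_logits = nn.Parameter(torch.full((out_features,), 2.0))

    def forward(self, x):
        soft = torch.sigmoid(self.mask_logits)
        # Straight-through estimator: the forward pass applies a hard 0/1 mask,
        # while gradients flow through the soft relaxation.
        hard = (soft > 0.5).float()
        mask = hard + soft - soft.detach()
        return self.linear(x) * mask

layer = MaskedLinear(8, 4)
out = layer(torch.randn(2, 8))
# Reconstruction-style loss plus a sparsity regularizer on the soft masks.
loss = out.pow(2).mean() + 0.01 * torch.sigmoid(layer.mask_logits).sum()
loss.backward()
```

Because the mask is differentiable, it can be trained end to end alongside a frozen diffusion model, which is what makes the framework model-agnostic.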

βš™οΈ Installation

Requirements

  • Python 3.10+
  • Anaconda or Miniconda
  • CUDA-compatible GPU

Setup

# Create conda environment
conda create -n sdib python=3.10 -y
conda activate sdib

# Clone repository
git clone https://github.com/your-repo/ecodiff.git
cd ecodiff

# Install dependencies
pip install -e ".[core,loggers,test]"

Environment Configuration

Create a .env file:

PYTHON=/path/to/miniconda3/envs/sdib/bin/python
RESULTS_DIR=/path/to/ecodiff/results
CONFIG_DIR=/path/to/ecodiff/configs
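A `.env` file in this simple KEY=VALUE form can be parsed with a few lines of Python. This is a minimal hypothetical helper; the repository may well load it differently (e.g. via python-dotenv).

```python
import os
import tempfile
from pathlib import Path

def load_env(path):
    """Parse simple KEY=VALUE lines from a .env file into os.environ."""
    env = {}
    for line in Path(path).read_text().splitlines():
        line = line.strip()
        if line and not line.startswith("#") and "=" in line:
            key, _, value = line.partition("=")
            env[key.strip()] = value.strip()
            os.environ.setdefault(key.strip(), value.strip())
    return env

# Demonstrate on a throwaway .env file with placeholder paths.
with tempfile.TemporaryDirectory() as d:
    p = Path(d) / ".env"
    p.write_text("RESULTS_DIR=/tmp/ecodiff/results\nCONFIG_DIR=/tmp/ecodiff/configs\n")
    env = load_env(p)
```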

🚀 Quick Start

1. Basic Pruning

# SDXL pruning
make visual cfg=sdxl

# FLUX pruning
make visual cfg=flux

2. Hyperparameter Tuning

# Generate configurations
python scripts/utils/hyperparameter_tuning.py --config configs/sdxl.yaml --task gen

# Run tuning
python scripts/utils/hyperparameter_tuning.py --task run --max_job 2

3. Evaluation

# Semantic evaluation
python scripts/evaluation/semantic_eval.py -sp <checkpoint_path> --task all

# Mask analysis
python scripts/evaluation/binary_mask_eval.py --ckpt <checkpoint_path> -lt 0.001
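The `-lt` flag above suggests a threshold below which learned mask entries are treated as pruned. A hedged sketch of that kind of binarization and the resulting sparsity ratio (illustrative only; the actual evaluation script may define sparsity differently):

```python
import torch

def binarize_masks(mask_values, threshold=0.001):
    """Treat mask entries below `threshold` as pruned; return the binary
    mask and the resulting sparsity ratio (fraction of channels removed)."""
    binary = (mask_values >= threshold).float()
    sparsity = 1.0 - binary.mean().item()
    return binary, sparsity

masks = torch.tensor([0.0, 0.0005, 0.3, 0.9, 0.00001, 1.0])
binary, sparsity = binarize_masks(masks)  # 3 of 6 channels pruned
```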

Advanced Usage

Pruning Training

# Direct training script
python scripts/train.py

# Development/debugging mode
make visual cfg=sdxl
make visual cfg=flux
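Time step gradient checkpointing, as named in the overview, trades compute for memory by recomputing per-timestep activations during the backward pass instead of storing them. A minimal sketch with `torch.utils.checkpoint`; the `denoise_step` function is a stand-in, not the project's actual denoiser:

```python
import torch
from torch.utils.checkpoint import checkpoint

def denoise_step(x, t_emb, weight):
    # Stand-in for one diffusion denoising step (not the real model).
    return torch.tanh(x @ weight + t_emb)

def rollout(x, weight, num_steps=4, use_ckpt=True):
    """Unroll several timesteps; checkpointing drops intermediate activations
    and recomputes them during backward, cutting peak memory."""
    for t in range(num_steps):
        t_emb = torch.full_like(x, t / num_steps)
        if use_ckpt:
            x = checkpoint(denoise_step, x, t_emb, weight, use_reentrant=False)
        else:
            x = denoise_step(x, t_emb, weight)
    return x

weight = torch.randn(8, 8, requires_grad=True)
out = rollout(torch.randn(2, 8), weight)
out.mean().backward()
```

Without checkpointing, backpropagating through every sampling step would keep all intermediate activations alive at once, which is prohibitive for large diffusion models.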

Hyperparameter Tuning

# Generate configuration files
python scripts/utils/hyperparameter_tuning.py \
  --config configs/sdxl.yaml \
  --output_dir configs/param_sdxl_tuning \
  -lr 0.1 0.2 \
  -mask "hard_discrete" \
  -re ".*" \
  -lreg 1 0 \
  -lrec 1 2 \
  -b 0.1 0.01 \
  -d 2 \
  -pn sdxl_pruning \
  --task gen

# Run tuning jobs
python scripts/utils/hyperparameter_tuning.py \
  --output_dir configs/param_sdxl_tuning \
  --task run \
  --max_job 2
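The `--task gen` step above expands the listed flag values into a grid of configurations. The combinatorial expansion can be sketched with stdlib Python (a hypothetical `generate_grid` helper; the script's real output is YAML config files):

```python
import itertools

def generate_grid(base, search_space):
    """Expand a base config into one dict per hyperparameter combination."""
    keys = list(search_space)
    for values in itertools.product(*(search_space[k] for k in keys)):
        cfg = dict(base)
        cfg.update(zip(keys, values))
        yield cfg

base = {"model": "sdxl", "mask": "hard_discrete"}
grid = list(generate_grid(base, {"lr": [0.1, 0.2], "beta": [0.1, 0.01]}))
# 2 learning rates x 2 betas -> 4 configurations
```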

Evaluation

# Generate semantic evaluation
python scripts/evaluation/semantic_eval.py -sp <checkpoint_path> --task gen

# Run all semantic evaluations
python scripts/evaluation/semantic_eval.py -sp <checkpoint_path> --task all

# Binary mask evaluation with threshold
python scripts/evaluation/binary_mask_eval.py --ckpt <checkpoint_path> -lt 0.001

Fine-tuning After Pruning

# SDXL LoRA fine-tuning
bash scripts/retraining/train_text_to_image_lora_sdxl.sh 30 0

# FLUX LoRA fine-tuning
bash scripts/retraining/train_text_to_image_lora_flux.sh 30 0

Load Pruned Models

python scripts/load_pruned_model.py
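Loading a pruned model generally amounts to applying the learned binary masks to the base weights. A toy sketch of zeroing pruned output channels of a linear layer (illustrative only; `scripts/load_pruned_model.py` is the authoritative path):

```python
import torch
import torch.nn as nn

def apply_channel_mask(linear, keep_mask):
    """Zero out pruned output channels of a Linear layer in place (sketch)."""
    with torch.no_grad():
        linear.weight.mul_(keep_mask.unsqueeze(1))
        if linear.bias is not None:
            linear.bias.mul_(keep_mask)
    return linear

layer = nn.Linear(4, 3)
keep = torch.tensor([1.0, 0.0, 1.0])  # prune the middle output channel
apply_channel_mask(layer, keep)
out = layer(torch.randn(2, 4))  # column 1 is exactly zero
```

In practice, zeroed channels can then be physically removed from the weight tensors to realize the memory savings, rather than merely masked.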

Configuration Files

The framework uses YAML configuration files located in the configs/ directory:

configs/
β”œβ”€β”€ dit.yaml          # Diffusion Transformers configuration
β”œβ”€β”€ flux.yaml         # FLUX.1 Schnell model configuration
β”œβ”€β”€ flux_dev.yaml     # FLUX.1 Dev model configuration  
β”œβ”€β”€ sd2.yaml          # Stable Diffusion v2 configuration
β”œβ”€β”€ sd3.yaml          # Stable Diffusion 3 configuration
└── sdxl.yaml         # Stable Diffusion XL configuration

πŸ› οΈ Development

For developers contributing to the project:

# Install development dependencies
pip install pre-commit && pre-commit install

# Run tests
make test

# Format code
make format

# Clean generated files
make clean

Repository Structure

Models

Supported

  • SDXL: Stable Diffusion XL
  • FLUX.1: FLUX diffusion models

Experimental

These models are experimental implementations and may require additional hyperparameter tuning for good performance.

  • DiT: Diffusion Transformers
  • SD2: Stable Diffusion v2
  • SD3: Stable Diffusion 3

🤗 Model Weights

Pre-trained pruned models and retrained weights are available on HuggingFace:

Model                | Type                     | Link
SDXL                 | Pruned                   | EcoDiff-SDXL-Pruned
FLUX (Schnell & Dev) | Pruned                   | EcoDiff-FLUX-Pruned
SDXL                 | Retrained (Full & LoRA)  | EcoDiff-SDXL-Retrain-Weights
FLUX                 | Retrained (LoRA)         | EcoDiff-FLUX-Retrain-Weights

πŸ“ Citation

@inproceedings{zhang2026learnable,
  title={Learnable Sparsity for Vision Generative Models},
  author={Zhang, Yang and Jin, Er and Liang, Wenzhong and Dong, Yanfei and Khakzar, Ashkan and Torr, Philip and Stegmaier, Johannes and Kawaguchi, Kenji},
  booktitle={The Fourteenth International Conference on Learning Representations},
  year={2026},
  url={https://openreview.net/forum?id=9pNWZLVZ4r}
}

License

This project is licensed under the MIT License - see the LICENSE file for details.

Acknowledgments
