This repository implements several generative data augmentation strategies that use Stable Diffusion to create synthetic datasets for improving classification.
The key packages and their versions are listed below. The code is tested on a single node with 4 NVIDIA RTX3090 GPUs.
torch==2.0.1+cu118
diffusers==0.25.1
transformers==4.36.2
datasets==2.16.1
accelerate==0.26.1
numpy==1.24.4
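To confirm that an environment matches the pinned versions above, a small check along the following lines may help. It is a minimal sketch: `find_mismatches` is a hypothetical helper written for illustration and is not part of this repository; it relies only on the standard-library `importlib.metadata` module.

```python
from importlib import metadata

# Pinned versions copied from the list above.
PINNED = {
    "torch": "2.0.1+cu118",
    "diffusers": "0.25.1",
    "transformers": "4.36.2",
    "datasets": "2.16.1",
    "accelerate": "0.26.1",
    "numpy": "1.24.4",
}


def find_mismatches(pinned, get_version=metadata.version):
    """Return {package: (expected, installed_or_None)} for every mismatch."""
    mismatches = {}
    for name, expected in pinned.items():
        try:
            installed = get_version(name)
        except metadata.PackageNotFoundError:
            installed = None  # the package is not installed at all
        if installed != expected:
            mismatches[name] = (expected, installed)
    return mismatches


if __name__ == "__main__":
    for name, (expected, installed) in find_mismatches(PINNED).items():
        print(f"{name}: expected {expected}, found {installed}")
```

An empty result means every listed package is installed at the pinned version. Other versions may also work; the list only records what the code was tested with.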
For convenience, well-structured datasets hosted on Hugging Face can be used. The fine-grained datasets we experimented with, CUB and Aircraft, can be downloaded from Multimodal-Fatima/CUB_train and Multimodal-Fatima/FGVC_Aircraft_train, respectively. If you run into network connection problems during training, pre-download the data from the website and set HUG_LOCAL_IMAGE_TRAIN_DIR in dataset/instance/cub.py to the local path where it is saved.
We provide LoRA weights fine-tuned on the full dataset for fast reproduction on the given datasets. Download them from the links in the table below and unzip the file into the ckpts directory, so that the file structure looks like:
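One way to pre-download the data is sketched below. It assumes the `datasets` library's `load_dataset`, `save_to_disk`, and `load_from_disk` calls; `LOCAL_DIR`, `resolve_source`, and `load_cub_train` are hypothetical names used only for illustration. The on-disk layout that dataset/instance/cub.py actually expects is not documented here, so treat this as a starting point rather than the repository's loading code.

```python
import os

# Hypothetical values: point LOCAL_DIR at wherever you want the data saved.
HUB_NAME = "Multimodal-Fatima/CUB_train"
LOCAL_DIR = "data/CUB_train"


def resolve_source(local_dir, hub_name):
    """Prefer a non-empty local copy; otherwise fall back to the Hub name."""
    if os.path.isdir(local_dir) and os.listdir(local_dir):
        return "local", local_dir
    return "hub", hub_name


def load_cub_train(local_dir=LOCAL_DIR, hub_name=HUB_NAME):
    # Imported lazily so resolve_source can be used without `datasets`.
    from datasets import load_dataset, load_from_disk

    kind, source = resolve_source(local_dir, hub_name)
    if kind == "local":
        return load_from_disk(source)  # offline path, no network needed
    dataset = load_dataset(source)     # first run: needs network access
    dataset.save_to_disk(local_dir)    # cache for later offline runs
    return dataset
```

Running `load_cub_train()` once on a machine with network access leaves a local copy that later runs can read without contacting the Hub.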
ckpts
├── cub
│   └── shot-1-lora-rank10
│       ├── learned_embeds-steps-last.bin
│       └── pytorch_lora_weights.safetensors
└── put_finetuned_ckpts_here.txt
| Dataset | data | ckpts (fullshot) |
|---|---|---|
| CUB | huggingface (train/test) | google drive |
| Flower | official website | google drive |
| Aircraft | huggingface (train/test) | google drive |
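After unzipping, the layout shown in the ckpts tree can be sanity-checked with a few lines of Python. This is a minimal sketch: `missing_ckpt_files` is a hypothetical helper, and the two file names are simply those listed in the tree above.

```python
import os

# File names taken from the ckpts directory tree above.
EXPECTED_FILES = (
    "learned_embeds-steps-last.bin",
    "pytorch_lora_weights.safetensors",
)


def missing_ckpt_files(ckpt_dir, expected=EXPECTED_FILES):
    """Return the expected file names that are absent from ckpt_dir."""
    return [
        name
        for name in expected
        if not os.path.isfile(os.path.join(ckpt_dir, name))
    ]


if __name__ == "__main__":
    ckpt_dir = os.path.join("ckpts", "cub", "shot-1-lora-rank10")
    missing = missing_ckpt_files(ckpt_dir)
    if missing:
        print(f"missing in {ckpt_dir}: {missing}")
    else:
        print("checkpoint looks complete")
```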
The scripts/finetune.sh script lets users fine-tune on their own datasets. By default, it uses a fine-tuning strategy that combines DreamBooth and Textual Inversion. Set the examples_per_class argument to fine-tune the model on a dataset with {examples_per_class} shots. Fine-tuning takes around 4 hours on 4 RTX3090 GPUs for the full-shot CUB dataset.
MODEL_NAME="runwayml/stable-diffusion-v1-5"
DATASET='cub'
SHOT=-1 # set -1 for full shot
OUTPUT_DIR="ckpts/${DATASET}/shot${SHOT}-lora-rank10" # same naming as the ckpts layout and FINETUNED_CKPT in scripts/sample.sh
accelerate launch --mixed_precision='fp16' --main_process_port 29507 \
train_lora.py \
--pretrained_model_name_or_path=$MODEL_NAME \
--dataset_name=$DATASET \
--resolution=224 \
--random_flip \
--max_train_steps=35000 \
--num_train_epochs=10 \
--checkpointing_steps=5000 \
--learning_rate=5e-05 \
--lr_scheduler='constant' \
--lr_warmup_steps=0 \
--seed=42 \
--rank=10 \
--local_files_only \
--examples_per_class $SHOT \
--train_batch_size 2 \
--output_dir=$OUTPUT_DIR \
--report_to='tensorboard'
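The examples_per_class setting can be pictured as per-class subsampling of the training set, with -1 keeping every example. The sketch below illustrates that idea only; `select_few_shot` is a hypothetical helper and the repository's own selection logic (for example, which examples it picks and how it seeds the choice) may differ.

```python
import random
from collections import defaultdict


def select_few_shot(labels, examples_per_class, seed=0):
    """Return sorted dataset indices with at most k examples per class.

    examples_per_class == -1 keeps the full dataset (full shot).
    """
    if examples_per_class == -1:
        return list(range(len(labels)))

    # Group dataset indices by their class label.
    by_class = defaultdict(list)
    for index, label in enumerate(labels):
        by_class[label].append(index)

    rng = random.Random(seed)  # fixed seed keeps the subset reproducible
    chosen = []
    for label in sorted(by_class):
        indices = by_class[label]
        k = min(examples_per_class, len(indices))
        chosen.extend(rng.sample(indices, k))
    return sorted(chosen)
```

For example, `select_few_shot(labels, 5)` would give the indices of a 5-shot subset, while classes with fewer than five examples contribute all they have.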
scripts/sample.sh provides a script that synthesizes augmented images using multiple processes. Each item in GPU_IDS denotes a process running on the GPU with that index. A simplified command for sampling with the diff-mix strategy at a fixed strength is:
DATASET='cub'
# set -1 for full shot
SHOT=-1
FINETUNED_CKPT="ckpts/cub/shot${SHOT}-lora-rank10"
# ['diff-mix', 'diff-aug', 'diff-gen', 'real-mix', 'real-aug', 'real-gen', 'ti_mix', 'ti_aug']
SAMPLE_STRATEGY='diff-mix'
STRENGTH=0.8
# ['fixed', 'uniform']. 'fixed': use fixed $STRENGTH, 'uniform': sample from [0.3, 0.5, 0.7, 0.9]
STRENGTH_STRATEGY='fixed'
# expand the dataset by 5 times
MULTIPLIER=5
# spawn 4 processes
GPU_IDS=(0 1 2 3)
python scripts/sample_mp.py \
--model-path='runwayml/stable-diffusion-v1-5' \
--output_root='outputs/aug_samples' \
--dataset=$DATASET \
--finetuned_ckpt=$FINETUNED_CKPT \
--syn_dataset_mulitiplier=$MULTIPLIER \
--strength_strategy=$STRENGTH_STRATEGY \
--sample_strategy=$SAMPLE_STRATEGY \
--examples_per_class=$SHOT \
--resolution=512 \
--batch_size=1 \
--aug_strength=$STRENGTH \
--gpu-ids=${GPU_IDS[@]}
The synthetic output directory is named after the sampling settings; with an augmentation strength of 0.7, for example, it is located at aug_samples/cub/shot-1_diff-mix_fixed_0.7. To create a 5-shot setting, set the examples_per_class argument to 5, and the output directory becomes aug_samples/cub/shot5_diff-mix_fixed_0.7. Make sure the checkpoint passed as finetuned_ckpt was also fine-tuned under the same 5-shot setting.
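The interplay of STRENGTH_STRATEGY, MULTIPLIER, and GPU_IDS can be illustrated with a short sketch. It assumes that the synthetic set contains MULTIPLIER times as many images as the real set and that work is split evenly across one process per GPU; how scripts/sample_mp.py actually schedules its work is not described here, and `pick_strength`, `build_tasks`, and `shard_tasks` are hypothetical names.

```python
import random

# Candidate strengths for the 'uniform' strategy, as listed in the script comment.
STRENGTH_CHOICES = (0.3, 0.5, 0.7, 0.9)


def pick_strength(strategy, fixed_strength, rng):
    """'fixed' returns the given strength; 'uniform' draws one of the choices."""
    if strategy == "fixed":
        return fixed_strength
    if strategy == "uniform":
        return rng.choice(STRENGTH_CHOICES)
    raise ValueError(f"unknown strength strategy: {strategy!r}")


def build_tasks(num_real_images, multiplier, strategy, fixed_strength, seed=0):
    """One (sample_index, strength) task per synthetic image to generate."""
    rng = random.Random(seed)
    total = num_real_images * multiplier  # assumption: 5x means 5 * dataset size
    return [(i, pick_strength(strategy, fixed_strength, rng)) for i in range(total)]


def shard_tasks(tasks, gpu_ids):
    """Round-robin split so each GPU worker gets a near-equal share."""
    return {gpu: tasks[offset::len(gpu_ids)] for offset, gpu in enumerate(gpu_ids)}


if __name__ == "__main__":
    tasks = build_tasks(num_real_images=10, multiplier=5,
                        strategy="fixed", fixed_strength=0.8)
    for gpu, shard in shard_tasks(tasks, [0, 1, 2, 3]).items():
        print(f"GPU {gpu}: {len(shard)} images")
```

Round-robin sharding is used here because it keeps the shards balanced without knowing the task count in advance; a contiguous split would work equally well for this purpose.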
After completing the sampling process, you can integrate the synthetic data into downstream classification and initiate training using the script scripts/classification.sh:
GPU=1
DATASET="cub"
SHOT=-1
# "shot{args.examples_per_class}_{args.sample_strategy}_{args.strength_strategy}_{args.aug_strength}"
SYNDATA_DIR="aug_samples/cub/shot${SHOT}_diff-mix_fixed_0.7" # shot-1 denotes full shot; the trailing 0.7 must match the aug_strength used for sampling
SYNDATA_P=0.1
GAMMA=0.8
python downstream_tasks/train_hub.py \
--dataset $DATASET \
--syndata_dir $SYNDATA_DIR \
--syndata_p $SYNDATA_P \
--model "resnet50" \
--gamma $GAMMA \
--examples_per_class $SHOT \
--gpu $GPU \
--amp 2 \
--note $(date +%m%d%H%M) \
--group_note "fullshot" \
--nepoch 120 \
--res_mode 224 \
--lr 0.05 \
--seed 0 \
--weight_decay 0.0005
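As an illustration of how synthetic data might enter training, the sketch below assumes that SYNDATA_P is the probability of replacing a real training sample with a synthetic one. That interpretation is an assumption, not something stated above, and `choose_sample` is a hypothetical helper rather than code from downstream_tasks/train_hub.py.

```python
import random


def choose_sample(real_item, synthetic_pool, syndata_p, rng):
    """With probability syndata_p return a random synthetic item, else the real one.

    Returns (item, is_synthetic). Falls back to the real item when the
    synthetic pool is empty.
    """
    if synthetic_pool and rng.random() < syndata_p:
        return rng.choice(synthetic_pool), True
    return real_item, False


if __name__ == "__main__":
    rng = random.Random(0)
    pool = ["syn_0.png", "syn_1.png", "syn_2.png"]
    draws = [choose_sample("real.png", pool, 0.1, rng)[1] for _ in range(10_000)]
    print(f"fraction synthetic: {sum(draws) / len(draws):.3f}")
```

Under this assumption, SYNDATA_P=0.1 would mean that roughly one training sample in ten comes from the synthetic directory.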
We also provide scripts for the robustness test and long-tail classification in scripts/classification_waterbird.sh and scripts/classification_imb.sh, respectively.
This project is built upon the repository Da-fusion and diffusers. Special thanks to the contributors.
