Diffuse and Disperse: Image Generation with Representation Regularization
Runqian Wang, Kaiming He
MIT
We propose Dispersive Loss, a simple plug-and-play regularizer that effectively improves diffusion-based generative models. Our loss function encourages internal representations to disperse in the hidden space, analogous to contrastive self-supervised learning, with the key distinction that it requires no positive sample pairs and therefore does not interfere with the sampling process used for regression.
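Concretely, for a batch of $B$ intermediate representations $z_i$, each flattened to dimension $d$, the InfoNCE-L2 variant computed by the reference code below is

$$\mathcal{L}_{\mathrm{disp}} = \log \frac{1}{B^2} \sum_{i=1}^{B} \sum_{j=1}^{B} \exp\!\left(-\frac{\lVert z_i - z_j \rVert_2^2}{d}\right),$$

where the division by $d$ plays the role of a temperature. Minimizing this term pushes pairwise distances apart, dispersing the batch in the hidden space; during training it is added to the standard diffusion regression loss with a weighting coefficient (see the usage sketch after the code below).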
We implement our Dispersive Loss on top of the SiT codebase. The core implementation of Dispersive Loss is shown below:
```python
def disp_loss(self, z):  # Dispersive Loss implementation (InfoNCE-L2 variant)
    z = z.reshape((z.shape[0], -1))  # flatten to (B, d)
    diff = th.nn.functional.pdist(z).pow(2) / z.shape[1]  # pairwise squared L2 distance, normalized by d
    diff = th.concat((diff, diff, th.zeros(z.shape[0]).cuda()))  # match JAX implementation of full BxB matrix
    return th.log(th.exp(-diff).mean())  # calculate loss
```
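As a rough illustration of how the regularizer plugs into training, here is a minimal sketch of adding Dispersive Loss to the standard denoising regression objective. The names `forward_with_features` and `lambda_disp`, and the weight value, are illustrative assumptions rather than the repo's actual API; see `train.py` for the real training loop.

```python
# Minimal sketch (illustrative, not the repo's actual training code):
# add Dispersive Loss, computed on an intermediate representation, to the
# standard diffusion/flow-matching regression loss.
import torch

lambda_disp = 0.5  # regularization weight (hyperparameter; value shown is illustrative)

def training_step(model, x_t, t, target):
    # hypothetical forward pass that also returns hidden states from one block
    pred, z = model.forward_with_features(x_t, t)
    regression_loss = torch.mean((pred - target) ** 2)  # standard denoising objective
    dispersive_loss = model.disp_loss(z)                # regularizer defined above
    return regression_loss + lambda_disp * dispersive_loss
```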
git clone https://github.com/raywang4/DispLoss.git
cd DispLoss
conda env create -f environment.yml
conda activate SiTTo train with Dispersive Loss, simply add the --disp argument to the training script:
```bash
torchrun --nnodes=1 --nproc_per_node=N train.py --model SiT-XL/2 --data-path /path/to/imagenet/train --disp
```

**Logging.** To enable wandb, first set `WANDB_KEY`, `ENTITY`, and `PROJECT` as environment variables:
```bash
export WANDB_KEY="key"
export ENTITY="entity name"
export PROJECT="project name"
```

Then add the `--wandb` flag to the training command:
```bash
torchrun --nnodes=1 --nproc_per_node=N train.py --model SiT-XL/2 --data-path /path/to/imagenet/train --disp --wandb
```
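For reference, here is a minimal sketch of how these environment variables can feed a wandb run. How `train.py` actually consumes them is an assumption here, but `wandb.login` and `wandb.init` are the standard API calls.

```python
# Minimal sketch (assumption: train.py consumes the env vars roughly like this).
import os
import wandb

wandb.login(key=os.environ["WANDB_KEY"])   # authenticate with the API key
wandb.init(
    entity=os.environ["ENTITY"],           # wandb team or user name
    project=os.environ["PROJECT"],         # wandb project name
)
# metrics such as the training loss are then reported with wandb.log({...})
```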
**Resume training.** To resume training from a custom checkpoint:

```bash
torchrun --nnodes=1 --nproc_per_node=N train.py --model SiT-L/2 --data-path /path/to/imagenet/train --disp --ckpt /path/to/model.pt
```

**Pre-trained checkpoints.** We provide a SiT-B/2 checkpoint and a SiT-XL/2 checkpoint, both trained with Dispersive Loss for 80 epochs on ImageNet 256×256.
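If you want to inspect a downloaded checkpoint before using it, a minimal sketch follows. It assumes the file is a standard PyTorch dictionary and that EMA weights, if stored separately, live under an `"ema"` key; the actual checkpoint layout may differ.

```python
# Minimal sketch for inspecting a checkpoint (the "ema" key is an assumption).
import torch

ckpt = torch.load("/path/to/model.pt", map_location="cpu")
print(list(ckpt.keys()))            # see which weight sets the file contains
ema_state = ckpt.get("ema", ckpt)   # EMA weights if stored separately, else the dict itself
```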
**Sampling from checkpoint.** To sample from the EMA weights of a 256×256 SiT-XL/2 model checkpoint with the ODE sampler, run:
```bash
python sample.py ODE --model SiT-XL/2 --image-size 256 --ckpt /path/to/model.pt
```

**More sampling options.** For more sampling options, such as SDE sampling, please refer to `train_utils.py`.
The `sample_ddp.py` script samples a large number of images from a pre-trained model in parallel. It generates a folder of samples as well as a `.npz` file that can be used directly with ADM's TensorFlow evaluation suite to compute FID, Inception Score, and other metrics. To sample 50K images from a pre-trained SiT-XL/2 model over N GPUs with the default ODE sampler settings, run:
```bash
torchrun --nnodes=1 --nproc_per_node=N sample_ddp.py ODE --model SiT-XL/2 --num-fid-samples 50000 --ckpt /path/to/model.pt
```
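Before running the evaluation suite, you can sanity-check the generated `.npz`. The sketch below assumes the ADM-style layout, with samples stored under the `arr_0` key as uint8 images; the script's actual output naming may differ.

```python
# Minimal sketch for sanity-checking the sampled .npz (assumes ADM-style "arr_0" key).
import numpy as np

data = np.load("/path/to/samples.npz")
print(data.files)                    # keys stored in the archive
samples = data["arr_0"]
print(samples.shape, samples.dtype)  # expect roughly (50000, 256, 256, 3), uint8
```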
Our original implementation is in JAX, and this repo contains our re-implementation in PyTorch, so results from running this repo may have minor numerical differences from those reported in our paper. In our JAX experiments we used 16 devices with a local batch size of 16, whereas in the PyTorch experiments we used 8 devices with a local batch size of 32. We have adjusted the hyperparameter choices slightly for best performance. We report our reproduction results below.

| implementation | config | local batch size | B/2 (80 ep) | XL/2 (80 ep, cfg=1.5) |
|---|---|---|---|---|
| baseline | - | 16 | 36.49 | 6.02 |
| JAX | | 16 | 32.35 | 5.09 |
| PyTorch | | 32 | 32.64 | 4.74 |
This project is under the MIT license. See LICENSE for details.
