Implementation of Flow++ in PyTorch. Based on the paper:
Flow++: Improving Flow-Based Generative Models with Variational Dequantization and Architecture Design
Jonathan Ho, Xi Chen, Aravind Srinivas, Yan Duan, Pieter Abbeel
OpenReview:Hyg74h05tX
The training script and hyperparameters are designed to match the CIFAR-10 experiments described in the paper.
- Make sure you have Anaconda or Miniconda installed.
- Clone the repo with `git clone https://github.com/chrischute/flowplusplus.git flowplusplus`.
- Go into the cloned repo: `cd flowplusplus`.
- Create the environment: `conda env create -f environment.yml`.
- Activate the environment: `source activate f++` (on newer conda versions, `conda activate f++`).
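Before training, you can quickly check that the `f++` environment is actually active. The sketch below assumes conda exposes the active environment through the `CONDA_DEFAULT_ENV` variable, which `conda activate` and the older `source activate` normally set. The helper names `active_conda_env` and `check_env` are hypothetical and are not part of the repo.

```python
import os
import sys
from typing import Mapping, Optional


def active_conda_env(environ: Optional[Mapping[str, str]] = None) -> Optional[str]:
    """Return the active conda environment name, or None if none is active."""
    env = os.environ if environ is None else environ
    value = env.get("CONDA_DEFAULT_ENV")
    # Some conda versions store a full path; keep only the last component.
    return os.path.basename(value) if value else None


def check_env(expected: str = "f++", environ: Optional[Mapping[str, str]] = None) -> bool:
    """Report whether the expected environment is active (assumes CONDA_DEFAULT_ENV)."""
    name = active_conda_env(environ)
    if name != expected:
        print(f"Expected conda env {expected!r}, found {name!r}", file=sys.stderr)
        return False
    return True


if __name__ == "__main__":
    print("environment OK" if check_env() else "activate the f++ environment first")
```

If the check fails, re-run the activation command above in the same shell you will train from.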
- Make sure you've created and activated the conda environment as described above.
- Run `python train.py -h` to see options.
- Run `python train.py [FLAGS]` to train. For example, run `python train.py` for the default configuration, or run `python train.py --gpu_ids=0,1` to run on 2 GPUs instead of the default of 1 GPU. Note that this also doubles the total batch size.
- At the end of each epoch, samples from the model will be saved to `samples/epoch_N.png`, where `N` is the epoch number.
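The effect of the GPU flag on the batch size can be illustrated with a small `argparse` sketch. This is not the repo's `train.py`. The `--batch_size` flag, its default of 64, and the helper functions are hypothetical, and the sketch assumes the total batch size is the per-GPU batch size multiplied by the number of listed GPUs, which matches the doubling described above for `--gpu_ids=0,1`.

```python
import argparse
from typing import List, Optional, Sequence, Tuple

# Hypothetical per-GPU default; run `python train.py -h` for the real flags.
DEFAULT_BATCH_SIZE = 64


def parse_gpu_ids(text: str) -> List[int]:
    """Turn a flag value like '0,1' into [0, 1]; empty entries are ignored."""
    return [int(tok) for tok in text.split(",") if tok.strip()]


def effective_batch_size(per_gpu_batch: int, gpu_ids: List[int]) -> int:
    """Scale the batch size by the number of GPUs, treating zero GPUs as one device."""
    return per_gpu_batch * max(1, len(gpu_ids))


def sample_path(epoch: int) -> str:
    """Path of the sample grid for a given epoch, following samples/epoch_N.png."""
    return f"samples/epoch_{epoch}.png"


def main(argv: Optional[Sequence[str]] = None) -> Tuple[List[int], int]:
    parser = argparse.ArgumentParser(description="Sketch of GPU-count batch scaling")
    parser.add_argument("--gpu_ids", type=str, default="0")  # default: 1 GPU
    parser.add_argument("--batch_size", type=int, default=DEFAULT_BATCH_SIZE)  # hypothetical flag
    args = parser.parse_args(argv)
    gpu_ids = parse_gpu_ids(args.gpu_ids)
    return gpu_ids, effective_batch_size(args.batch_size, gpu_ids)


if __name__ == "__main__":
    print(main(["--gpu_ids=0,1"]))  # → ([0, 1], 128)
```

Scaling the total batch with the GPU count keeps the per-GPU load constant, which is why adding a second GPU doubles the batch rather than halving each device's share.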