Official implementation of MMD-B-Fair, published at AISTATS 2023 by Namrata Deka and Danica J. Sutherland.
Requirements:
- Python 3.8.10
- PyTorch 1.12.1
- Torchvision 0.13.1
- Wandb 0.13.3
- Create a .yml config file and store it in the config directory (see the available examples there; this is VERY IMPORTANT). The file should specify all data, model and trainer settings and hyperparameters.
- Then execute:
  python main.py -v <relative-path-to-yml-from-config> --seed <seed> -w
  Example:
  python main.py -v eq/adult/lambda_1.yml --seed 42 -w
If you do not wish to sync to wandb while training, add the option -m offline and sync at any time later with the wandb sync command.
If --seed is not specified, the seed defaults to 0.
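To launch several seeds in a row, the command line can be assembled programmatically. The sketch below is a hypothetical helper (build_command is not part of the repository); it uses only the flags documented above (-v, --seed, -w and -m offline), passes -w through unchanged, and assumes it is run from the repository root.

```python
import shlex
from typing import List, Optional


def build_command(config_path: str, seed: Optional[int] = None,
                  offline: bool = False) -> List[str]:
    """Assemble the argument list for one training run of main.py.

    Hypothetical helper: only the flags documented in this README are used.
    """
    cmd = ["python", "main.py", "-v", config_path]
    if seed is not None:
        # When --seed is omitted, main.py falls back to seed 0.
        cmd += ["--seed", str(seed)]
    cmd.append("-w")
    if offline:
        # Log locally instead of syncing; upload later with `wandb sync`.
        cmd += ["-m", "offline"]
    return cmd


if __name__ == "__main__":
    # Print one shell-quoted command per seed (nothing is executed here).
    for s in (0, 1, 2):
        print(shlex.join(build_command("eq/adult/lambda_1.yml", seed=s, offline=True)))
```

To actually run a command, the returned list could be passed to subprocess.run; keeping it as a list avoids shell-quoting issues with config paths.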
Trained models are saved to the location specified by experiment.output_location in the config, in a subfolder named after the seed. In wandb, experiments are logged under <config file name>/<seed> in the mmd-b-fair workspace.
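The layout described above can be sketched as follows, assuming the YAML config has already been parsed into a nested dict. run_locations is a hypothetical helper, the output_location value is made up, and treating <config file name> as the file name without directory or .yml extension (e.g. lambda_1) is an assumption; the README does not pin this down.

```python
from pathlib import Path
from typing import Tuple


def run_locations(config: dict, config_name: str, seed: int = 0) -> Tuple[Path, str]:
    """Return (checkpoint_dir, wandb_run_path) for one run.

    Hypothetical helper mirroring the README: checkpoints go to
    experiment.output_location/<seed>, and wandb logs go under
    <config file name>/<seed>. The default seed of 0 matches main.py's default.
    """
    output_root = Path(config["experiment"]["output_location"])
    checkpoint_dir = output_root / str(seed)    # one subfolder per seed
    wandb_run_path = f"{config_name}/{seed}"    # assumed: config_name is the file stem
    return checkpoint_dir, wandb_run_path


cfg = {"experiment": {"output_location": "outputs/adult"}}  # hypothetical value
print(run_locations(cfg, "lambda_1", seed=42))
```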
This repository makes heavy use of the factory design pattern for modularity. To add new datasets, models and/or trainers, follow the steps below:
- Create the new data/model/trainer class under the appropriate directory. All trainer classes must inherit from BaseTrainer, and all model classes must inherit from BaseModel.
- Create a corresponding builder for each new class.
- Register all builder objects with the respective factories in data/data.py, model/model.py and trainer/trainer.py. The keys data.data_key, model.model_key and model.trainer_key in the config files must match the registered factory keys.
- Specify class-specific arguments in the config file. For example, dataset arguments must go in data.args.
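The steps above can be illustrated with a minimal, self-contained sketch of the factory pattern. DataFactory, ToyDataset, build_toy_dataset and the "toy" key are all hypothetical stand-ins; the repository's actual factory and builder interfaces are not shown in this README and may differ. The sketch only shows the flow the steps describe: a builder is registered under a key, then the config's data.data_key selects that builder and data.args supplies its arguments.

```python
from typing import Any, Callable, Dict, List, Tuple


class DataFactory:
    """Hypothetical factory mapping a config key to a builder callable."""

    def __init__(self) -> None:
        self._builders: Dict[str, Callable[..., Any]] = {}

    def register_builder(self, key: str, builder: Callable[..., Any]) -> None:
        self._builders[key] = builder

    def create(self, key: str, **kwargs: Any) -> Any:
        if key not in self._builders:
            raise ValueError(f"No builder registered for key {key!r}")
        return self._builders[key](**kwargs)


class ToyDataset:
    """Hypothetical new dataset class holding n_samples (feature, group) pairs."""

    def __init__(self, n_samples: int = 10) -> None:
        self.samples: List[Tuple[float, int]] = [(float(i), i % 2) for i in range(n_samples)]

    def __len__(self) -> int:
        return len(self.samples)


def build_toy_dataset(**args: Any) -> ToyDataset:
    """Builder: forwards the class-specific arguments taken from data.args."""
    return ToyDataset(**args)


# Registration step (the README says this is done in data/data.py).
data_factory = DataFactory()
data_factory.register_builder("toy", build_toy_dataset)

# Parsed config: data.data_key must equal the registered key,
# and class-specific arguments live under data.args.
config = {"data": {"data_key": "toy", "args": {"n_samples": 4}}}
dataset = data_factory.create(config["data"]["data_key"], **config["data"]["args"])
print(len(dataset))  # → 4
```

A mismatched key fails loudly at lookup time, which is why the config keys must match the registered factory keys exactly. Models and trainers would follow the same flow, selected by model.model_key and model.trainer_key.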