This repository is the official implementation of "Towards Scaling Difference Target Propagation with Backprop Targets", currently under review at ICML 2022. The following code runs on Python > 3.7 with PyTorch >= 1.7.0.
To install the package, run:
```console
pip install -e .
```
(Optional) We suggest using a conda environment. The specs of our environment are stored in `conda_env_specs.txt`.
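Before installing, you may want to confirm that your environment meets the version requirements stated above. The snippet below is a minimal sketch for that. `parse_version` and `meets_minimum` are hypothetical helpers, not part of this repository. They compare only the numeric release part of a version string, so local tags such as `+cu110` are ignored, and they treat both requirements as inclusive lower bounds (3.7 and 1.7.0).

```python
import re
import sys
from typing import Tuple


def parse_version(version: str) -> Tuple[int, ...]:
    """Extract the leading numeric release part, e.g. '1.7.0+cu110' -> (1, 7, 0)."""
    match = re.match(r"\d+(?:\.\d+)*", version)
    if match is None:
        raise ValueError(f"Unrecognized version string: {version!r}")
    return tuple(int(part) for part in match.group(0).split("."))


def meets_minimum(version: str, minimum: str) -> bool:
    """Return True if `version` is at least `minimum` (inclusive lower bound)."""
    found, required = parse_version(version), parse_version(minimum)
    # Pad with zeros so that '1.7' compares equal to '1.7.0'.
    width = max(len(found), len(required))
    found += (0,) * (width - len(found))
    required += (0,) * (width - len(required))
    # Tuple comparison is numeric per component, so '1.10' counts as newer than '1.7'.
    return found >= required


if __name__ == "__main__":
    python_version = "{}.{}.{}".format(*sys.version_info[:3])
    print(f"Python {python_version}: ok={meets_minimum(python_version, '3.7')}")
    try:
        import torch
    except ImportError:
        print("PyTorch is not installed.")
    else:
        print(f"PyTorch {torch.__version__}: ok={meets_minimum(str(torch.__version__), '1.7.0')}")
```

Comparing integer tuples rather than raw strings matters here: as strings, `"1.10.0" < "1.7.0"`, which would wrongly reject newer PyTorch releases.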
The table below maps the method names used in the paper to their names in the codebase:

| Name in paper | Name in codebase |
|---|---|
| L-DRL | DTP |
| Backpropagation | BaselineModel |
| DRL | meulemans_dtp (Based on the original authors' repo) |
| Target Propagation | TargetProp |
| Difference Target Propagation | VanillaDTP |
| "Parallel" L-DRL (not in the paper) | ParallelDTP |
The main logic of our method is in `target_prop/models/dtp.py`.
An initial PyTorch implementation of our DTP model can be found under `target_prop/legacy`. This model was then re-implemented using PyTorch Lightning.
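For orientation, here is a minimal, dependency-free sketch of how targets are propagated in vanilla difference target propagation (Lee et al., 2015), the method listed as `VanillaDTP` in the table above. It is not the L-DRL method in `target_prop/models/dtp.py`, and the repository's `VanillaDTP` code may differ in its details. The top target takes a gradient step on the output, `t_L = h_L - step * dLoss/dh_L`, and lower targets use the difference correction `t_{l-1} = h_{l-1} + g_l(t_l) - g_l(h_l)`, where `g_l` is the feedback function of layer `l`. The element-wise layers `tanh(w * h)` and the linear feedbacks `h / w` are hypothetical toy choices; the feedback is deliberately an imperfect inverse.

```python
import math
from typing import Callable, Dict, List, Sequence

Vector = List[float]
Fn = Callable[[Vector], Vector]


def make_layer(weight: float) -> Fn:
    """Hypothetical toy forward layer: f(h) = tanh(weight * h), element-wise."""
    return lambda h: [math.tanh(weight * v) for v in h]


def make_feedback(weight: float) -> Fn:
    """Hypothetical toy feedback g(h) = h / weight. It ignores the tanh, so it is
    only an approximate inverse of the matching forward layer."""
    return lambda h: [v / weight for v in h]


def forward(x: Vector, layers: Sequence[Fn]) -> List[Vector]:
    """Return all activations [h_0 = x, h_1, ..., h_L]."""
    activations = [x]
    for f in layers:
        activations.append(f(activations[-1]))
    return activations


def top_target(h_top: Vector, y: Vector, step: float) -> Vector:
    """t_L = h_L - step * dLoss/dh_L for the loss 0.5 * ||h_L - y||^2."""
    return [h - step * (h - t) for h, t in zip(h_top, y)]


def dtp_targets(activations: List[Vector], feedbacks: Sequence[Fn], t_top: Vector) -> Dict[int, Vector]:
    """Propagate targets from layer L down to layer 1 with the difference correction
    t_{l-1} = h_{l-1} + g_l(t_l) - g_l(h_l); feedbacks[l - 1] plays the role of g_l."""
    num_layers = len(feedbacks)
    targets = {num_layers: t_top}
    for l in range(num_layers, 1, -1):
        g = feedbacks[l - 1]
        # Subtracting g_l(h_l) cancels the feedback's reconstruction error:
        # if t_l == h_l, then t_{l-1} == h_{l-1} even though g_l is imperfect.
        correction = [a - b for a, b in zip(g(targets[l]), g(activations[l]))]
        targets[l - 1] = [h + c for h, c in zip(activations[l - 1], correction)]
    return targets


if __name__ == "__main__":
    weights = [1.5, 0.8, 1.2]
    layers = [make_layer(w) for w in weights]
    feedbacks = [make_feedback(w) for w in weights]
    acts = forward([0.2, -0.4], layers)
    targets = dtp_targets(acts, feedbacks, top_target(acts[-1], [0.5, 0.1], step=0.1))
    for l in sorted(targets):
        print(f"layer {l}: h={acts[l]} target={targets[l]}")
```

In DTP, each forward layer is then trained locally to move `f_l(h_{l-1})` toward `t_l`, while the feedback functions are trained separately to approximately invert the forward layers; neither training step is shown in this sketch.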
Here is how the codebase is roughly structured:
```
├── main_pl.py               # Training script used in the paper
├── main.py                  # Training script (legacy)
├── figure_4_3.py            # Script for Figure 4.3
├── data                     # Data for Figure 4.3
├── final_figures            # Resulting Figure 4.3
├── meulemans_dtp            # Codebase for DRL (Meulemans repo)
├── numerical_experiments    # Initial scripts for creating the figures (used for Fig. 4.2)
└── target_prop
    ├── datasets             # Datasets
    ├── legacy               # Initial implementation
    ├── models               # Code for all the models except DRL
    └── networks             # Networks (SimpleVGG, LeNet, ResNet)
```
- Recreating Figure 4.2:
  ```console
  $ python -m numerical_experiments figure_4_2
  ```
  The figure save location will then be displayed on the console.
- Recreating Figure 4.3:
  ```console
  $ pytest target_prop/networks/lenet_test.py
  $ python plot.py
  ```
To run the PyTorch Lightning re-implementation of DTP on CIFAR-10 using a VGG-like architecture, use the following command:
```console
python main_pl.py run dtp simple_vgg
```
To see a list of available command-line options, use the `--help` flag:
```console
python main_pl.py --help
python main_pl.py run --help
```
To use the modified version of the above DTP model, with "parallel" feedback weight training on CIFAR-10, use the following command:
```console
python main_pl.py run parallel_dtp simple_vgg
```
To train with DTP on the downsampled ImageNet 32x32 dataset, run:
```console
python main_pl.py run dtp <architecture> --dataset imagenet32
```
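If you want to launch several of these runs in sequence, a small wrapper can build the command lines. The following sketch uses only the sub-commands, model names, and `--dataset` flag documented above. `main_pl_command` and `run_all` are hypothetical helpers, not part of the repository; `run_all` only prints the commands unless `dry_run=False` is passed, and it assumes it is executed from the repository root, where `main_pl.py` lives.

```python
import shlex
import subprocess
import sys
from typing import List, Optional, Sequence


def main_pl_command(model: str, architecture: str, dataset: Optional[str] = None) -> List[str]:
    """Build the argv for one `python main_pl.py run <model> <architecture>` call."""
    cmd = [sys.executable, "main_pl.py", "run", model, architecture]
    if dataset is not None:
        # Only `imagenet32` is documented in this README as a value for --dataset.
        cmd += ["--dataset", dataset]
    return cmd


def run_all(commands: Sequence[List[str]], dry_run: bool = True) -> List[str]:
    """Print each command; execute them one after another only if dry_run is False."""
    printed = []
    for cmd in commands:
        line = " ".join(shlex.quote(arg) for arg in cmd)
        print(("[dry run] " if dry_run else "Running: ") + line)
        printed.append(line)
        if not dry_run:
            # check=True raises CalledProcessError and stops at the first failed run.
            subprocess.run(cmd, check=True)
    return printed


if __name__ == "__main__":
    runs = [
        main_pl_command("dtp", "simple_vgg"),
        main_pl_command("parallel_dtp", "simple_vgg"),
    ]
    run_all(runs)  # pass dry_run=False to actually start training
```

Using `sys.executable` launches each run with the same interpreter (and therefore the same conda environment) as the wrapper itself.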
To train the legacy implementation (`main.py`) on CIFAR-10, use the following command:
```console
python main.py --batch-size 128 \
    --C 128 128 256 256 512 \
    --iter 20 30 35 55 20 \
    --epochs 90 \
    --lr_b 1e-4 3.5e-4 8e-3 8e-3 0.18 \
    --noise 0.4 0.4 0.2 0.2 0.08 \
    --lr_f 0.08 \
    --beta 0.7 \
    --path CIFAR-10 \
    --scheduler --wdecay 1e-4
```
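The legacy command passes five values to each of `--C`, `--iter`, `--lr_b`, and `--noise`. Assuming these are per-layer settings (one entry per layer of the network, which is our reading rather than something this README states), the lists can easily drift out of sync when the command is edited by hand. The sketch below, with hypothetical names `PER_LAYER`, `SCALARS`, `FLAGS`, and `legacy_command`, rebuilds the same command from a config and checks that all per-layer lists have the same length.

```python
import shlex
from typing import Dict, List, Sequence, Union

Value = Union[int, float, str]

# Values copied from the legacy CIFAR-10 command above.
# Assumption: each list-valued option holds one value per layer.
PER_LAYER: Dict[str, List[Value]] = {
    "--C": [128, 128, 256, 256, 512],
    "--iter": [20, 30, 35, 55, 20],
    "--lr_b": [1e-4, 3.5e-4, 8e-3, 8e-3, 0.18],
    "--noise": [0.4, 0.4, 0.2, 0.2, 0.08],
}
SCALARS: Dict[str, Value] = {
    "--batch-size": 128,
    "--epochs": 90,
    "--lr_f": 0.08,
    "--beta": 0.7,
    "--path": "CIFAR-10",
    "--wdecay": 1e-4,
}
FLAGS: List[str] = ["--scheduler"]


def legacy_command(per_layer: Dict[str, List[Value]], scalars: Dict[str, Value], flags: Sequence[str]) -> List[str]:
    """Assemble the argv for main.py, refusing per-layer lists of unequal length."""
    lengths = {name: len(values) for name, values in per_layer.items()}
    if len(set(lengths.values())) > 1:
        raise ValueError(f"Per-layer options have mismatched lengths: {lengths}")
    argv = ["python", "main.py"]
    for name, value in scalars.items():
        argv += [name, str(value)]
    for name, values in per_layer.items():
        argv += [name, *(str(v) for v in values)]
    return argv + list(flags)


if __name__ == "__main__":
    print(" ".join(shlex.quote(arg) for arg in legacy_command(PER_LAYER, SCALARS, FLAGS)))
```

The option order differs from the command above, and `str()` renders `1e-4` as `0.0001`, but the values passed to `main.py` are the same.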