This directory contains an implementation of FixMatch for ImageNet with RandAugment and CT Augment.
- Follow these instructions to prepare splits of semi-supervised data for ImageNet training.
- Make sure Python 3, pip and virtualenv are installed:
    sudo apt install python3-dev python3-virtualenv
- Install all software dependencies needed for TensorFlow 2.1.
  - If you plan to run the code on GPUs, you have to install the NVIDIA driver, CUDA and cuDNN; see the TensorFlow documentation. Check which versions of these libraries work with TensorFlow 2.1 before installing them.
- Create a Python virtual environment and install all necessary dependencies:
    virtualenv -p python3 ~/.venv3/fixmatch_imagenet
    source ~/.venv3/fixmatch_imagenet/bin/activate
    pip install tensorflow==2.1
    pip install tensorflow-addons
    pip install absl-py
    pip install easydict
  - It is recommended to avoid --system-site-packages, especially if you already have TensorFlow installed on the system. Otherwise you may get a binary incompatibility between TensorFlow and TensorFlow Addons; see tensorflow/addons#676 (comment) for context.
- Make sure that ${SSL_IMAGENET_DIR} points to the directory with the semi-supervised ImageNet training data.
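Before launching a long training run, it can help to confirm that the environment variable is actually set. The following is a minimal sketch of such a check; the helper name `check_data_dir` is hypothetical and not part of this codebase, and the directory test only makes sense for local paths (not for data stored on Google Cloud Storage).

```python
import os


def check_data_dir(var_name="SSL_IMAGENET_DIR"):
    """Return the directory named by the environment variable, or raise.

    Hypothetical helper for illustration; assumes a local filesystem path.
    """
    path = os.environ.get(var_name)
    if not path:
        raise RuntimeError(f"{var_name} is not set")
    if not os.path.isdir(path):
        raise RuntimeError(f"{var_name}={path!r} is not an existing directory")
    return path
```

Calling `check_data_dir()` from a Python shell inside the activated virtual environment either returns the path or raises an error describing the problem.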
The provided codebase supports training on GPUs and Cloud TPUs.
Train a fully supervised model with CT Augmentation:
mkdir -p ${HOME}/models/supervised
python -B supervised.py \
--imagenet_data="${SSL_IMAGENET_DIR}" \
--model_dir="${HOME}/models/supervised" \
--steps_per_run=100 \
--dataset="imagenet" \
--hparams="\"bfloat16\": false, \"num_epochs\": 200, \"augment\": { \"type\": \"cta\" }"
Train a FixMatch model on 10% of the ImageNet data with RandAugment (random magnitude):
mkdir -p ${HOME}/models/fixmatch
python -B fixmatch.py \
--imagenet_data="${SSL_IMAGENET_DIR}" \
--model_dir="${HOME}/models/fixmatch" \
--steps_per_run=100 \
--per_worker_batch_size=8 \
--dataset="imagenet128116.1" \
--hparams="\"bfloat16\": false, \"num_epochs\": 3000, \"learning_rate\": { \"warmup_epochs\": 50, \"decay_epochs\": 500 }, \"augment\": { \"type\": \"randaugment\" }"
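Escaping the quotes in `--hparams` by hand is error-prone. The sketch below builds the same kind of string from a Python dictionary. It assumes, based only on the commands shown above, that the flag takes the body of a JSON object without the outer braces; `format_hparams` is a hypothetical helper, not part of this codebase.

```python
import json
import shlex


def format_hparams(overrides):
    """Serialize a dict as JSON and drop the outer braces.

    Assumption: --hparams expects '"key": value, ...' without the
    surrounding '{' and '}', as in the example commands.
    """
    text = json.dumps(overrides)
    return text[1:-1]


overrides = {
    "bfloat16": False,
    "num_epochs": 3000,
    "learning_rate": {"warmup_epochs": 50, "decay_epochs": 500},
    "augment": {"type": "randaugment"},
}

# shlex.quote wraps the value in single quotes so the shell passes it verbatim.
flag = "--hparams=" + shlex.quote(format_hparams(overrides))
print(flag)
```

The printed flag can be pasted into the `python -B fixmatch.py` command in place of the hand-escaped version.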
Running TPU training is very similar to running GPU training, with the following exceptions:
- You have to create a Cloud TPU instance and pass its address to the training code via the --tpu argument. See also the Cloud TPU documentation.
- Both the dataset and the model have to be stored on Google Cloud Storage. See the Cloud TPU resnet example for a possible setup.
- You can use the bfloat16 data type when training on TPU by specifying it in the hyperparameters list. This usually results in faster training with similar accuracy.
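The storage requirement above can be checked before starting a job. This is a minimal sketch, assuming Google Cloud Storage locations are given as `gs://` URLs; `check_tpu_paths` is a hypothetical helper, and the training scripts themselves may or may not perform a similar check.

```python
def check_tpu_paths(data_dir, model_dir, tpu=None):
    """Return a list of problems; the list is empty if the paths look usable.

    Hypothetical pre-flight check: when a TPU address is given, both the
    dataset and the model directory are expected to be gs:// paths.
    """
    problems = []
    if tpu:
        for name, path in (("imagenet_data", data_dir), ("model_dir", model_dir)):
            if not path.startswith("gs://"):
                problems.append(
                    f"--{name} should be a gs:// path when --tpu is set: {path}"
                )
    return problems
```

For GPU training (no `tpu` value) the function returns an empty list, since local paths are acceptable there.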
Example of training a FixMatch model on TPU:
python -B fixmatch.py \
--imagenet_data="${SSL_IMAGENET_DIR}" \
--tpu="${CLOUD_TPU_INSTANCE}" \
--model_dir="${MODEL_DIR}" \
--steps_per_run=1000 \
--per_worker_batch_size=32 \
--dataset="imagenet128116.1" \
--hparams="\"bfloat16\": true, \"num_epochs\": 3000, \"learning_rate\": { \"warmup_epochs\": 50, \"decay_epochs\": 500 }, \"augment\": { \"type\": \"randaugment\" }"
Training is controlled by the following command line arguments:
- --steps_per_run controls how many training steps are run between evaluations. It also controls how often CT augmentation parameters are updated when CTA is used. 1000 steps is recommended for TPU training.
- --per_worker_batch_size controls the supervised batch size per worker (GPU or TPU core); the default is 128. Note that the unsupervised batch size is computed by multiplying per_worker_batch_size by the uratio hyperparameter.
- --hparams is a list of hyperparameters in JSON format. DEFAULT_COMMON_HPARAMS in training.py and DEFAULT_FIXMATCH_HPARAMS in fixmatch.py contain the default values of all hyperparameters.
  - Note that the number of epochs (num_epochs hyperparameter) is measured in epochs of supervised examples, so each epoch on 10% of the ImageNet data is 10 times shorter than an epoch on the full ImageNet data.
- --dataset controls the dataset used for training. It has to be imagenet for supervised training. For semi-supervised training on 10% of the ImageNet data it should be imagenet128116.${split}, where ${split} is the split of the semi-supervised data and can be one of 1, 2, 3, 4 or 5.
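The batch size and epoch arithmetic described above can be sketched as follows. Several things here are assumptions rather than facts about this codebase: that global batch sizes are the per-worker sizes multiplied by the number of workers, that the 10% split holds 128116 labeled examples (inferred from the dataset name), that a partial final batch counts as a step, and the value `uratio=5`, which is illustrative and not necessarily the default.

```python
import math

FULL_IMAGENET_TRAIN = 1_281_167  # size of the ImageNet-1k training set
TEN_PERCENT_SPLIT = 128_116      # assumption: inferred from "imagenet128116"


def batch_sizes(per_worker_batch_size, num_workers, uratio):
    """Per-worker and global batch sizes (global values assume data parallelism)."""
    unsup_per_worker = per_worker_batch_size * uratio
    return {
        "supervised_per_worker": per_worker_batch_size,
        "unsupervised_per_worker": unsup_per_worker,
        "supervised_global": per_worker_batch_size * num_workers,
        "unsupervised_global": unsup_per_worker * num_workers,
    }


def steps_per_epoch(num_labeled, global_supervised_batch):
    """Steps for one pass over the labeled examples (rounding up is an assumption)."""
    return math.ceil(num_labeled / global_supervised_batch)


sizes = batch_sizes(per_worker_batch_size=32, num_workers=8, uratio=5)
print(sizes)
print(steps_per_epoch(TEN_PERCENT_SPLIT, sizes["supervised_global"]))
print(steps_per_epoch(FULL_IMAGENET_TRAIN, sizes["supervised_global"]))
```

Under these assumptions an epoch on the 10% split takes roughly a tenth of the steps of an epoch on the full training set, which is why num_epochs values such as 3000 are practical for semi-supervised runs.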