This is the code repository for the paper *Pre-trained Universal Medical Image Transformer* (arXiv).
This repository is continually updated for ongoing work. To reproduce the results from the paper, switch to the `submit` branch; note that the submodules need to be checked out as well. Pre-trained model weights can be found on the release page.
Mamba is recommended for managing virtual environments.
```shell
mamba env create -n pumit -f environment.yaml
mamba activate pumit
echo "export PYTHONPATH=$PWD:\$PYTHONPATH" > $CONDA_PREFIX/etc/conda/activate.d/env_vars.sh
BUILD_MONAI=1 pip install --no-build-isolation -e third-party/MONAI
```

Download datasets and put them under the `./datasets` folder, e.g.:
```
datasets
├── AbdomenCT-1K
├── ACDC
├── AMOS22
├── BCV
├── BrainPTM-2021
├── BraTS2023
...
```

Then run the pre-processing script: `python scripts/data/process.py <dataset name>`. More details can be found in the script.
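If all datasets should be processed in one go, the per-dataset invocation above can be scripted. A minimal sketch, assuming (as an illustration — check `scripts/data/process.py` for the actual interface) that the script takes the dataset folder name as its only positional argument; `preprocessing_commands` is our helper name, not part of the repository:

```python
from pathlib import Path

def preprocessing_commands(datasets_dir: str = "datasets") -> list[list[str]]:
    """Build one hypothetical `process.py` invocation per dataset folder."""
    return [
        ["python", "scripts/data/process.py", d.name]
        for d in sorted(Path(datasets_dir).iterdir())
        if d.is_dir()
    ]

# each returned entry can be passed to subprocess.run(...)
```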
```shell
python scripts/tokenizer/main.py -c conf/tokenizer/simple/main.yaml \
  --data.dl_conf.train_batch_size 8 \
  --data.dl_conf.num_workers 10 \
  --training.benchmark true \
  --model.quantize.mode soft \
  --model.quantize.num_embeddings 1024 \
  --loss.entropy_weight 1 \
  --loss.quant_weight 0.03
```

Note that you may need to adjust the batch size according to the number of GPUs.
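The batch-size note is about keeping the effective (global) batch size constant across hardware setups. A small sketch of that arithmetic, under the assumption that `--data.dl_conf.train_batch_size` is a per-GPU value (the helper below is illustrative, not part of the codebase):

```python
def per_gpu_batch_size(global_batch: int, num_gpus: int) -> int:
    """Per-GPU batch size that keeps the global batch size constant.

    Hypothetical helper: assumes the training flag sets a per-GPU value,
    so global batch = per-GPU batch * number of GPUs.
    """
    if global_batch % num_gpus:
        raise ValueError("global batch size must divide evenly across GPUs")
    return global_batch // num_gpus

# e.g. a global batch of 64 over 8 GPUs gives 8 per GPU; over 4 GPUs, 16
```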
```shell
scripts/model/mim-b.zsh \
  --data.dl_conf.train_batch_size 14 \
  --data.dl_conf.num_workers 10 \
  --model.tokenizer.path <tokenizer checkpoint path>
```

Assume that the pre-trained checkpoint is placed at `./pre-trained/pumit.ckpt`.
Execute the scripts under `scripts/downstream/medmnistv2` to train and evaluate each model.
Download the BTCV data from the official challenge, download the train/validation split file from SMIT's repository, and organize the files as follows:
```
data
├── BTCV
│   ├── smit.json
│   ├── Testing
│   └── Training
```
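Before launching fine-tuning, the layout can be sanity-checked with a few lines of Python (a hypothetical helper, not part of the repository):

```python
from pathlib import Path

def missing_btcv_entries(root: str = "data/BTCV") -> list[str]:
    """Return the expected BTCV entries that are absent under `root`."""
    expected = ["smit.json", "Testing", "Training"]
    return [name for name in expected if not (Path(root) / name).exists()]

# an empty list means the layout matches the tree above
```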
Then run fine-tuning and inference:
```shell
scripts/downstream/btcv/pumit-b.zsh --data.num_workers 10 --data.ratio 1 --trainer.logger.name pumit-b --data.train_batch_size 4
scripts/downstream/btcv/test-b.zsh --data.num_workers 10 --ckpt_path <output checkpoint path> --trainer.logger.name pumit-b
```

First, run the pre-processing script to convert the DICOM series into NIfTI format: `python scripts/downstream/chaos/preprocess.py`
Then run fine-tuning and inference:
```shell
scripts/downstream/chaos/pumit-b.zsh --data.num_workers 10 --data.ratio 1 --trainer.logger.name pumit-b --data.train_batch_size 8
scripts/downstream/chaos/predict-b.zsh --data.num_workers 10 --ckpt_path <output checkpoint path> --trainer.logger.name pumit-b
```