PyTorch implementation of DMPNet: Dynamic Message Propagation Network for RGB-D Salient Object Detection.
- Python 3.7
- PyTorch 1.8.0
- Torchvision 0.9.0
- CUDA 11.1
This is the PyTorch implementation of DMPNet. It has been trained and tested on Linux (Ubuntu 20, CUDA 11.1, Python 3.7, PyTorch 1.8), and it can also run on Windows.
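To compare a local setup against the versions listed above, a small check script can help. This is a minimal sketch and not part of the DMPNet repository: the helper names (`numeric_parts`, `matches_tested`, `report_environment`) are hypothetical, and treating a matching major.minor version as "OK" is an assumption of the sketch, not a statement of what DMPNet strictly requires. The CUDA value it reports is the version PyTorch was built against (`torch.version.cuda`), which may differ from the system-wide toolkit.

```python
import platform
import re
from typing import Tuple

# Versions listed in this README as the tested environment.
TESTED = {"Python": "3.7", "PyTorch": "1.8.0", "Torchvision": "0.9.0", "CUDA": "11.1"}


def numeric_parts(version: str) -> Tuple[int, ...]:
    """Parse the leading numeric part of a version, e.g. '1.8.0+cu111' -> (1, 8, 0)."""
    match = re.match(r"\d+(?:\.\d+)*", version)
    if match is None:
        raise ValueError(f"unrecognized version string: {version!r}")
    return tuple(int(p) for p in match.group(0).split("."))


def matches_tested(installed: str, tested: str, parts: int = 2) -> bool:
    """Assumption: agreeing on the first `parts` components (major.minor) counts as a match."""
    return numeric_parts(installed)[:parts] == numeric_parts(tested)[:parts]


def report_environment() -> None:
    """Print each detected version next to the tested one."""
    found = {"Python": platform.python_version()}
    try:
        import torch
        import torchvision

        found["PyTorch"] = torch.__version__
        found["Torchvision"] = torchvision.__version__
        # torch.version.cuda is None for CPU-only builds of PyTorch
        if torch.version.cuda is not None:
            found["CUDA"] = torch.version.cuda
        print("CUDA device available:", torch.cuda.is_available())
    except ImportError as exc:
        print("Could not import PyTorch/Torchvision:", exc)

    for name, tested in TESTED.items():
        installed = found.get(name)
        if installed is None:
            print(f"{name}: not found (tested with {tested})")
        else:
            status = "OK" if matches_tested(installed, tested) else "differs"
            print(f"{name}: {installed} (tested with {tested}) -> {status}")


if __name__ == "__main__":
    report_environment()
```

A version reported as "differs" does not necessarily mean the code will fail; it only flags a departure from the tested setup.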
- Download the pre-trained ImageNet backbones (password: bd0z): resnet101/resnet50, densenet161, vgg16 and vgg_conv1. Put them in the 'pretrained' folder (vgg_conv1 is already there).
- Download the training dataset (password: uw24) and modify 'train_root' and 'train_list' in main.py.
- Start training with:
  python main.py --mode=train --arch=resnet --network=resnet101 --train_root=xx/dataset/RGBDcollection --train_list=xx/dataset/RGBDcollection/train.lst
- Download the testing dataset (password: uw24) and put it in the 'dataset/test/' folder.
- Download the trained DMPNet PyTorch model (password: kas7) and set 'model' in main.py to the path where you saved it.
- Start testing with:
  python main.py --mode=test --arch=resnet --network=resnet101 --model=xx/JLDCF_resnet101.pth --sal_mode=LFSD --test_folder=test/LFSD
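For repeated runs it can be convenient to assemble the training and testing commands in one place. The sketch below assumes only the command-line flags that appear in the two commands above (`--mode`, `--arch`, `--network`, `--train_root`, `--train_list`, `--model`, `--sal_mode`, `--test_folder`); the helpers `build_command` and `run` are hypothetical and not part of the repository, and the `xx/...` paths are the same placeholders used above, to be replaced with real locations.

```python
import shlex
import subprocess
import sys
from typing import Dict, List

# Placeholder values copied from the README commands; replace "xx" with real paths.
TRAIN_OPTIONS = {
    "train_root": "xx/dataset/RGBDcollection",
    "train_list": "xx/dataset/RGBDcollection/train.lst",
}
TEST_OPTIONS = {
    "model": "xx/JLDCF_resnet101.pth",
    "sal_mode": "LFSD",
    "test_folder": "test/LFSD",
}


def build_command(mode: str, options: Dict[str, str],
                  arch: str = "resnet", network: str = "resnet101") -> List[str]:
    """Return the argument list for `python main.py` in train or test mode."""
    if mode not in ("train", "test"):
        raise ValueError(f"mode must be 'train' or 'test', got {mode!r}")
    # sys.executable reuses the interpreter (and environment) running this script
    cmd = [sys.executable, "main.py", f"--mode={mode}",
           f"--arch={arch}", f"--network={network}"]
    # Each remaining option becomes a --key=value flag, in insertion order
    cmd.extend(f"--{key}={value}" for key, value in options.items())
    return cmd


def run(cmd: List[str], dry_run: bool = True) -> None:
    """Print the shell-quoted command; execute it only when dry_run is False."""
    print(" ".join(shlex.quote(part) for part in cmd))
    if not dry_run:
        # check=True raises CalledProcessError if main.py exits with an error
        subprocess.run(cmd, check=True)


if __name__ == "__main__":
    run(build_command("train", TRAIN_OPTIONS))
    run(build_command("test", TEST_OPTIONS))
```

Keeping `dry_run=True` as the default means the script only prints the commands until you explicitly opt in to running them, which makes it easy to check the paths first.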