
SwinUNETR: Issue with Half Precision Training #4914

@BardiaKh

Description

Describe the bug
When the model and its input are both converted to half precision with .half(), the forward pass crashes with:

expected scalar type Float but found Half

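This is because the attn_mask created inside the model is a torch.float (float32) tensor. Combining it with the half-precision attention scores promotes them to float32, so the subsequent multiplication by v, which is a torch.half tensor, fails with the dtype mismatch above.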
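The failure can be replayed outside SwinUNETR with plain PyTorch tensors. This is a minimal sketch, not MONAI's code: the tensor shapes are hypothetical, and it assumes, as described above, that a float32 mask is combined with half-precision attention scores before the multiplication by v. Casting the scores back to v.dtype is shown as one possible fix, not necessarily the change adopted in MONAI.

```python
import torch


def mixed_dtype_demo():
    """Replay the dtype interaction with plain tensors (all shapes are hypothetical)."""
    attn = torch.zeros(2, 3, 4, 4, dtype=torch.half)  # half-precision attention scores
    v = torch.zeros(2, 3, 4, 8, dtype=torch.half)     # half-precision values
    mask = torch.zeros(2, 1, 4, 4)                    # mask built with the default dtype (float32)

    # Type promotion: float16 + float32 -> float32, so the scores silently leave half precision.
    promoted = attn + mask

    # Matrix multiplication does not promote dtypes, so float32 @ float16 is rejected.
    try:
        promoted @ v
        error = None
    except RuntimeError as exc:
        error = str(exc)

    # One possible fix: cast the combined scores back to v's dtype before multiplying.
    fixed = promoted.to(v.dtype)
    return promoted.dtype, error, fixed.dtype


promoted_dtype, error, fixed_dtype = mixed_dtype_demo()
print(promoted_dtype)  # torch.float32
print(fixed_dtype)     # torch.float16
print(error)
```

The promotion happens silently at the addition, so the error only surfaces one step later at the multiplication with v.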

To Reproduce
Steps to reproduce the behavior:

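from monai.networks.nets import SwinUNETR

# params: the SwinUNETR constructor arguments; dataloader: any loader yielding image batches
model = SwinUNETR(**params).cuda().half()
for img in dataloader:
    img = img.cuda().half()
    model(img)  # RuntimeError: expected scalar type Float but found Half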
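Calling model.half() in the reproduction does not prevent the error because Module.half() only converts registered parameters and buffers; a tensor created inside forward keeps the default dtype. The toy module below is hypothetical (ToyWindowBlock is not a MONAI class) and assumes the attn_mask is created during the forward pass rather than stored on the module.

```python
import torch
from torch import nn


class ToyWindowBlock(nn.Module):
    """Hypothetical stand-in for a block that builds its mask at call time (assumption)."""

    def __init__(self, dim: int = 8):
        super().__init__()
        self.proj = nn.Linear(dim, dim)                 # parameters: converted by .half()
        self.register_buffer("bias", torch.zeros(dim))  # buffers: converted by .half()

    def forward(self, x: torch.Tensor):
        # Created fresh on every call with the default dtype (float32),
        # so Module.half() never has a chance to convert it.
        default_mask = torch.zeros(x.shape[0], x.shape[1], device=x.device)
        # Taking the dtype from the input keeps the mask in the caller's precision.
        matched_mask = torch.zeros(x.shape[0], x.shape[1], device=x.device, dtype=x.dtype)
        return default_mask.dtype, matched_mask.dtype


block = ToyWindowBlock().half()
x = torch.zeros(2, 8, dtype=torch.half)
print(block.proj.weight.dtype, block.bias.dtype)  # torch.float16 torch.float16
print(block(x))  # (torch.float32, torch.float16)
```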

Environment

Output of python -c 'import monai; monai.config.print_debug_info()' (run with the relevant Python executable):

Printing MONAI config...
================================
MONAI version: 0.9.1
Numpy version: 1.23.2
Pytorch version: 1.10.1
MONAI flags: HAS_EXT = False, USE_COMPILED = False, USE_META_DICT = False
MONAI rev id: 356d2d2f41b473f588899d705bbc682308cee52c
MONAI __file__: /home/m253231/anaconda3/envs/PyTorchDefault/lib/python3.9/site-packages/monai/__init__.py

Optional dependencies:
Pytorch Ignite version: NOT INSTALLED or UNKNOWN VERSION.
Nibabel version: 3.2.1
scikit-image version: 0.18.2
Pillow version: 8.3.1
Tensorboard version: 2.6.0
gdown version: NOT INSTALLED or UNKNOWN VERSION.
TorchVision version: 0.11.2
tqdm version: 4.62.2
lmdb version: NOT INSTALLED or UNKNOWN VERSION.
psutil version: 5.8.0
pandas version: 1.3.5
einops version: 0.3.2
transformers version: NOT INSTALLED or UNKNOWN VERSION.
mlflow version: NOT INSTALLED or UNKNOWN VERSION.
pynrrd version: NOT INSTALLED or UNKNOWN VERSION.

For details about installing the optional dependencies, please visit:
    https://docs.monai.io/en/latest/installation.html#installing-the-recommended-dependencies


================================
Printing system config...
================================
System: Linux
Linux version: CentOS Linux 7 (Core)
Platform: Linux-3.10.0-1160.66.1.el7.x86_64-x86_64-with-glibc2.17
Processor: x86_64
Machine: x86_64
Python version: 3.9.6
Process name: python
Command: ['python', '-c', 'import monai; monai.config.print_debug_info()']
Open files: []
Num physical CPUs: 48
Num logical CPUs: 96
Num usable CPUs: 96
CPU usage (%): [6.4, 13.8, 11.8, 8.0, 33.7, 6.6, 0.0, 0.3, 2.4, 0.7, 1.4, 2.4, 6.9, 4.5, 6.6, 76.9, 21.8, 2.1, 1.7, 14.5, 0.7, 12.4, 20.1, 32.1, 35.2, 0.7, 25.6, 13.5, 6.6, 14.2, 0.3, 14.8, 0.3, 8.3, 9.4, 8.3, 24.8, 13.1, 81.1, 14.5, 70.0, 17.2, 14.5, 14.2, 4.8, 23.2, 18.8, 23.3, 14.8, 12.4, 14.8, 14.8, 13.8, 14.8, 14.2, 14.9, 0.7, 14.2, 28.6, 14.2, 13.8, 14.9, 14.5, 11.8, 12.1, 14.5, 14.5, 5.9, 14.8, 47.2, 13.8, 12.1, 32.3, 14.5, 18.3, 3.8, 10.4, 1.0, 0.3, 0.3, 100.0, 80.0, 13.2, 11.0, 13.8, 11.4, 7.6, 10.3, 10.3, 11.7, 0.0, 1.0, 0.3, 14.5, 29.2, 13.8]
CPU freq. (MHz): 2025
Load avg. in last 1, 5, 15 mins (%): [16.9, 17.0, 11.4]
Disk usage (%): 38.4
Avg. sensor temp. (Celsius): UNKNOWN for given OS
Total physical memory (GB): 1007.6
Available memory (GB): 465.2
Used memory (GB): 535.6

================================
Printing GPU config...
================================
Num GPUs: 4
Has CUDA: True
CUDA version: 11.3
cuDNN enabled: True
cuDNN version: 8200
Current device: 0
Library compiled for CUDA architectures: ['sm_37', 'sm_50', 'sm_60', 'sm_61', 'sm_70', 'sm_75', 'sm_80', 'sm_86', 'compute_37']
GPU 0 Name: NVIDIA A100-SXM4-80GB
GPU 0 Is integrated: False
GPU 0 Is multi GPU board: False
GPU 0 Multi processor count: 108
GPU 0 Total memory (GB): 79.2
GPU 0 CUDA capability (maj.min): 8.0
GPU 1 Name: NVIDIA A100-SXM4-80GB
GPU 1 Is integrated: False
GPU 1 Is multi GPU board: False
GPU 1 Multi processor count: 108
GPU 1 Total memory (GB): 79.2
GPU 1 CUDA capability (maj.min): 8.0
GPU 2 Name: NVIDIA A100-SXM4-80GB
GPU 2 Is integrated: False
GPU 2 Is multi GPU board: False
GPU 2 Multi processor count: 108
GPU 2 Total memory (GB): 79.2
GPU 2 CUDA capability (maj.min): 8.0
GPU 3 Name: NVIDIA A100-SXM4-80GB
GPU 3 Is integrated: False
GPU 3 Is multi GPU board: False
GPU 3 Multi processor count: 108
GPU 3 Total memory (GB): 79.2
GPU 3 CUDA capability (maj.min): 8.0
