SwinUNETR: Shape error in sliding window attention #4693

@razorx89

Describe the bug
The SwinUNETR architecture downsamples each spatial dimension by a total factor of up to 32, so input sizes that are multiples of 32 should be valid. However, not all such image sizes currently work, because some of them trigger reshape errors in the sliding window attention. The error seems to occur only for 2D images.

To Reproduce

from monai.networks.nets import SwinUNETR
import torch

def compute(img_size):
    x = torch.zeros((1, 3) + img_size, dtype=torch.float32).to("cuda")
    model = SwinUNETR(
        in_channels=3,
        out_channels=1,
        feature_size=24,
        img_size=img_size,
        spatial_dims=len(img_size),
    ).to("cuda")
    try:
        model(x)
        print(img_size, "success")
    except Exception:  # avoid a bare except, which would also swallow KeyboardInterrupt
        print(img_size, "error")

for img_size in [
    (480, 480),
    (640, 640),
    (640, 480),
    (480, 640),
    (96, 96, 96),
    (96, 96, 32),
    (192, 192, 32),
    (192, 96, 32),
    (192, 96, 96),
    (480, 640, 32),
    (640, 480, 32),
    (192, 192),
    (192, 96),
    (192, 32),
    (96, 32),
    (96, 96),
]:
    compute(img_size)

Output:

(480, 480) success
(640, 640) success
(640, 480) error
(480, 640) error
(96, 96, 96) success
(96, 96, 32) success
(192, 192, 32) success
(192, 96, 32) success
(192, 96, 96) success
(480, 640, 32) success
(640, 480, 32) success
(192, 192) success
(192, 96) error
(192, 32) error
(96, 32) error
(96, 96) success
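
To make the pattern in this output easier to see, here is a small sketch that needs neither MONAI nor a GPU. It copies the results above into a dictionary and checks them against one hypothesis: that exactly the non-square 2D sizes fail. The helper name `is_non_square_2d` is made up for illustration.

```python
# Results copied from the output above: image size -> did the forward pass succeed?
results = {
    (480, 480): True,
    (640, 640): True,
    (640, 480): False,
    (480, 640): False,
    (96, 96, 96): True,
    (96, 96, 32): True,
    (192, 192, 32): True,
    (192, 96, 32): True,
    (192, 96, 96): True,
    (480, 640, 32): True,
    (640, 480, 32): True,
    (192, 192): True,
    (192, 96): False,
    (192, 32): False,
    (96, 32): False,
    (96, 96): True,
}


def is_non_square_2d(size):
    """Hypothetical helper: True for 2D sizes whose height and width differ."""
    return len(size) == 2 and size[0] != size[1]


# Sizes where the hypothesis "non-square 2D fails, everything else works" is wrong.
mismatches = [size for size, ok in results.items() if ok == is_non_square_2d(size)]
print("sizes contradicting the hypothesis:", mismatches)
```

For the sizes tested here the list is empty: every square 2D size and every 3D size succeeds, and every non-square 2D size fails. That holds only for this sample and says nothing yet about the cause.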

Expected behavior
Any image whose spatial dimensions are all multiples of 32 (k*32) should work.

Environment

monai.config.print_debug_info()
================================
Printing MONAI config...
================================
MONAI version: 0.9.0
Numpy version: 1.22.0
Pytorch version: 1.12.0a0+2c916ef
MONAI flags: HAS_EXT = False, USE_COMPILED = False
MONAI rev id: af0e0e9f757558d144b655c63afcea3a4e0a06f5
MONAI __file__: /opt/conda/lib/python3.8/site-packages/monai/__init__.py

Optional dependencies:
Pytorch Ignite version: NOT INSTALLED or UNKNOWN VERSION.
Nibabel version: 3.2.2
scikit-image version: NOT INSTALLED or UNKNOWN VERSION.
Pillow version: 9.2.0
Tensorboard version: 2.9.1
gdown version: NOT INSTALLED or UNKNOWN VERSION.
TorchVision version: 0.13.0a0
tqdm version: 4.64.0
lmdb version: 1.3.0
psutil version: 5.9.0
pandas version: 1.4.3
einops version: 0.4.1
transformers version: NOT INSTALLED or UNKNOWN VERSION.
mlflow version: NOT INSTALLED or UNKNOWN VERSION.
pynrrd version: NOT INSTALLED or UNKNOWN VERSION.

For details about installing the optional dependencies, please visit:
    https://docs.monai.io/en/latest/installation.html#installing-the-recommended-dependencies


================================
Printing system config...
================================
System: Linux
Linux version: Ubuntu 20.04.4 LTS
Platform: Linux-5.15.0-40-generic-x86_64-with-glibc2.10
Processor: x86_64
Machine: x86_64
Python version: 3.8.12
Process name: python
Command: ['python', '-c', 'import monai; monai.config.print_debug_info()']
Open files: []
Num physical CPUs: 48
Num logical CPUs: 48
Num usable CPUs: 48
CPU usage (%): [5.4, 5.9, 4.9, 4.9, 5.4, 5.4, 4.9, 4.9, 4.4, 4.9, 4.9, 4.9, 5.4, 5.4, 5.4, 4.9, 4.9, 5.4, 4.9, 5.4, 5.4, 4.9, 5.9, 5.4, 5.4, 4.9, 5.4, 4.9, 4.9, 5.4, 4.9, 4.9, 5.4, 5.4, 5.4, 4.9, 5.4, 5.4, 4.9, 5.4, 4.9, 5.4, 5.4, 5.4, 4.9, 4.9, 4.9, 100.0]
CPU freq. (MHz): 2
Load avg. in last 1, 5, 15 mins (%): [1.7, 4.7, 3.4]
Disk usage (%): 19.3
Avg. sensor temp. (Celsius): UNKNOWN for given OS
Total physical memory (GB): 1007.8
Available memory (GB): 996.0
Used memory (GB): 5.5

================================
Printing GPU config...
================================
Num GPUs: 6
Has CUDA: True
CUDA version: 11.6
cuDNN enabled: True
cuDNN version: 8303
Current device: 0
Library compiled for CUDA architectures: ['sm_52', 'sm_60', 'sm_61', 'sm_70', 'sm_75', 'sm_80', 'sm_86', 'compute_86']
GPU 0 Name: NVIDIA RTX A6000
GPU 0 Is integrated: False
GPU 0 Is multi GPU board: False
GPU 0 Multi processor count: 84
GPU 0 Total memory (GB): 47.5
GPU 0 CUDA capability (maj.min): 8.6
GPU 1 Name: NVIDIA RTX A6000
GPU 1 Is integrated: False
GPU 1 Is multi GPU board: False
GPU 1 Multi processor count: 84
GPU 1 Total memory (GB): 47.5
GPU 1 CUDA capability (maj.min): 8.6
GPU 2 Name: NVIDIA RTX A6000
GPU 2 Is integrated: False
GPU 2 Is multi GPU board: False
GPU 2 Multi processor count: 84
GPU 2 Total memory (GB): 47.5
GPU 2 CUDA capability (maj.min): 8.6
GPU 3 Name: NVIDIA RTX A6000
GPU 3 Is integrated: False
GPU 3 Is multi GPU board: False
GPU 3 Multi processor count: 84
GPU 3 Total memory (GB): 47.5
GPU 3 CUDA capability (maj.min): 8.6
GPU 4 Name: NVIDIA RTX A6000
GPU 4 Is integrated: False
GPU 4 Is multi GPU board: False
GPU 4 Multi processor count: 84
GPU 4 Total memory (GB): 47.5
GPU 4 CUDA capability (maj.min): 8.6
GPU 5 Name: NVIDIA RTX A6000
GPU 5 Is integrated: False
GPU 5 Is multi GPU board: False
GPU 5 Multi processor count: 84
GPU 5 Total memory (GB): 47.5
GPU 5 CUDA capability (maj.min): 8.6

Additional context

The error always occurs in the sliding window attention. Traceback for (640, 480):

Traceback (most recent call last):
  File "/home/saros-dataset-training/experiments/foo.py", line 15, in compute
    model(x)
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1111, in _call_impl
    return forward_call(*input, **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/monai/networks/nets/swin_unetr.py", line 280, in forward
    hidden_states_out = self.swinViT(x_in, self.normalize)
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1111, in _call_impl
    return forward_call(*input, **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/monai/networks/nets/swin_unetr.py", line 971, in forward
    x1 = self.layers1[0](x0.contiguous())
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1111, in _call_impl
    return forward_call(*input, **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/monai/networks/nets/swin_unetr.py", line 855, in forward
    x = blk(x, attn_mask)
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1111, in _call_impl
    return forward_call(*input, **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/monai/networks/nets/swin_unetr.py", line 653, in forward
    x = self.forward_part1(x, mask_matrix)
  File "/opt/conda/lib/python3.8/site-packages/monai/networks/nets/swin_unetr.py", line 590, in forward_part1
    x_windows = window_partition(shifted_x, window_size)
  File "/opt/conda/lib/python3.8/site-packages/monai/networks/nets/swin_unetr.py", line 323, in window_partition
    x = x.view(b, h // window_size[0], window_size[0], w // window_size[1], window_size[1], c)
RuntimeError: shape '[1, 46, 7, 34, 7, 24]' is invalid for input of size 1887600
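
The numbers in this error message can be checked by hand. The sketch below is an inference from the traceback only and has not been verified against the MONAI source. It assumes a patch size of 2 (believed to be the SwinUNETR default), so a (640, 480) image reaches the first Swin block as a 320 x 240 feature map with 24 channels. The window size of 7 is visible in the target shape. The helper `pad_to_multiple` is hypothetical.

```python
from math import prod

# Values copied from the RuntimeError above.
target_shape = (1, 46, 7, 34, 7, 24)
input_numel = 1887600

# Assumption: patch size 2, so (640, 480) becomes a 320 x 240 feature map.
h, w = 640 // 2, 480 // 2
win = 7   # window size, as seen in the target shape
c = 24    # feature_size from the reproduction script


def pad_to_multiple(n, k):
    """Hypothetical helper: padding needed to round n up to a multiple of k."""
    return (k - n % k) % k


pad_h = pad_to_multiple(h, win)  # 320 -> 322
pad_w = pad_to_multiple(w, win)  # 240 -> 245

# Element count if each axis receives its own padding.
correct_numel = (h + pad_h) * (w + pad_w) * c
# Element count if the two padding amounts are applied to the opposite axes.
swapped_numel = (h + pad_w) * (w + pad_h) * c

print("view target numel:", prod(target_shape))
print("correctly padded :", correct_numel)
print("swapped padding  :", swapped_numel, "matches error:", swapped_numel == input_numel)
```

Under these assumptions, only the swapped case reproduces the reported input size of 1887600, which corresponds to a 325 x 242 tensor rather than the expected 322 x 245. A swap would also be consistent with square 2D inputs working, since both padding amounts are then equal. This is still a hypothesis that needs to be confirmed in `forward_part1`.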
