
Problem relevant to PatchMerging in 3D SWIN-UNETR #4757

@wyli

Description


Discussed in #4753

Originally posted by GYDDHPY July 23, 2022
In the 3D forward function of PatchMerging, it seems x3 equals x6, which is a little weird.
Is this a special design for 3D images or just a code error?
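The eight slices in the forward function quoted below can be checked mechanically. This sketch transcribes each slice's (d, h, w) start offsets by hand ("0::2" → 0, "1::2" → 1) and compares them with the eight corners of a 2×2×2 block. The names QUOTED_OFFSETS and audit are illustrative only and are not part of MONAI.

```python
from itertools import product

# (d, h, w) start offsets of the slices x0..x7, transcribed by hand from the
# quoted forward(): "0::2" -> 0, "1::2" -> 1.
QUOTED_OFFSETS = [
    (0, 0, 0),  # x0
    (1, 0, 0),  # x1
    (0, 1, 0),  # x2
    (0, 0, 1),  # x3
    (1, 0, 1),  # x4
    (0, 1, 0),  # x5
    (0, 0, 1),  # x6
    (1, 1, 1),  # x7
]


def audit(offsets):
    """Return (pairs of slice indices with identical offsets, corners never sampled)."""
    n = len(offsets)
    dups = [(a, b) for a in range(n) for b in range(a + 1, n) if offsets[a] == offsets[b]]
    missing = sorted(set(product(range(2), repeat=3)) - set(offsets))
    return dups, missing


dups, missing = audit(QUOTED_OFFSETS)
print("duplicate slices:", [f"x{a} == x{b}" for a, b in dups])
# duplicate slices: ['x2 == x5', 'x3 == x6']
print("unused corners (d, h, w):", missing)
# unused corners (d, h, w): [(0, 1, 1), (1, 1, 0)]
```

By this check, x2 also duplicates x5, so each 2×2×2 block contributes only six distinct voxels and the corners (0, 1, 1) and (1, 1, 0) are never seen by the reduction layer.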

```python
from typing import Type

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import LayerNorm


class PatchMerging(nn.Module):
    """
    Patch merging layer based on: "Liu et al.,
    Swin Transformer: Hierarchical Vision Transformer using Shifted Windows
    <https://arxiv.org/abs/2103.14030>"
    https://github.com/microsoft/Swin-Transformer
    """

    def __init__(self, dim: int, norm_layer: Type[LayerNorm] = nn.LayerNorm, spatial_dims: int = 3) -> None:
        """
        Args:
            dim: number of feature channels.
            norm_layer: normalization layer.
            spatial_dims: number of spatial dims.
        """
        super().__init__()
        self.dim = dim
        if spatial_dims == 3:
            # 2x2x2 neighborhood: 8 * dim concatenated features -> 2 * dim
            self.reduction = nn.Linear(8 * dim, 2 * dim, bias=False)
            self.norm = norm_layer(8 * dim)
        elif spatial_dims == 2:
            # 2x2 neighborhood: 4 * dim concatenated features -> 2 * dim
            self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
            self.norm = norm_layer(4 * dim)

    def forward(self, x):
        x_shape = x.size()
        if len(x_shape) == 5:
            b, d, h, w, c = x_shape  # channels-last layout: (B, D, H, W, C)
            pad_input = (h % 2 == 1) or (w % 2 == 1) or (d % 2 == 1)
            if pad_input:
                # F.pad pairs run from the last dim backwards: C, W, H, D
                x = F.pad(x, (0, 0, 0, w % 2, 0, h % 2, 0, d % 2))
            x0 = x[:, 0::2, 0::2, 0::2, :]
            x1 = x[:, 1::2, 0::2, 0::2, :]
            x2 = x[:, 0::2, 1::2, 0::2, :]
            x3 = x[:, 0::2, 0::2, 1::2, :]
            x4 = x[:, 1::2, 0::2, 1::2, :]
            x5 = x[:, 0::2, 1::2, 0::2, :]  # same slice as x2
            x6 = x[:, 0::2, 0::2, 1::2, :]  # same slice as x3
            x7 = x[:, 1::2, 1::2, 1::2, :]
            x = torch.cat([x0, x1, x2, x3, x4, x5, x6, x7], -1)
        # The post truncates forward() here and omits the 2D branch; the
        # remaining steps are assumed to normalize, then project to 2 * dim.
        x = self.norm(x)
        x = self.reduction(x)
        return x
```
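For comparison, here is a minimal sketch of a 3D merge that gathers every corner of each 2×2×2 block exactly once. It works on a plain nested list vol[d][h][w] (channels omitted, even D, H, W assumed) so it runs without PyTorch; the helper name patch_merge_3d is hypothetical, and this is not necessarily how the layer was fixed upstream.

```python
from itertools import product


def patch_merge_3d(vol):
    """Gather each 2x2x2 block of vol[d][h][w] into a list of 8 values.

    Hypothetical helper for illustration: channels are omitted and D, H, W
    are assumed even (the real layer pads odd sizes first).
    """
    depth, height, width = len(vol), len(vol[0]), len(vol[0][0])
    # All eight distinct (d, h, w) start offsets, each used exactly once.
    offsets = list(product(range(2), repeat=3))
    return [
        [
            [
                [vol[2 * d + i][2 * h + j][2 * w + k] for i, j, k in offsets]
                for w in range(width // 2)
            ]
            for h in range(height // 2)
        ]
        for d in range(depth // 2)
    ]


# Label every voxel of a 4x4x4 volume with a unique id, then check coverage.
size = 4
vol = [[[(d * size + h) * size + w for w in range(size)] for h in range(size)] for d in range(size)]
merged = patch_merge_3d(vol)
flat = [v for plane in merged for row in plane for cell in row for v in cell]
print(len(merged), len(merged[0]), len(merged[0][0]), len(merged[0][0][0]))  # 2 2 2 8
print(sorted(flat) == list(range(size ** 3)))  # True
```

In the PyTorch layer the same idea is a single torch.cat over x[:, i::2, j::2, k::2, :] for every (i, j, k) in product(range(2), repeat=3). Changing which slices feed self.reduction also changes the meaning of its 8 * dim input features, so weights trained with the duplicated slices would likely not transfer unchanged.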
