@@ -1,6 +1,6 @@
 import math
 from functools import partial
-from typing import Any, Callable, Dict, List, Optional
+from typing import Any, Callable, List, Optional, Union
 
 import torch
 import torch.nn.functional as F
@@ -143,18 +143,21 @@ def shifted_window_attention(
         Tensor[N, H, W, C]: The output tensor after shifted window attention.
     """
     B, H, W, C = input.shape
+
+    # If window size is larger than feature size, there is no need to shift window
+    if window_size[0] >= H:
+        shift_size[0] = 0
+        window_size[0] = H
+    if window_size[1] >= W:
+        shift_size[1] = 0
+        window_size[1] = W
+
     # pad feature maps to multiples of window size
     pad_r = (window_size[1] - W % window_size[1]) % window_size[1]
     pad_b = (window_size[0] - H % window_size[0]) % window_size[0]
     x = F.pad(input, (0, 0, 0, pad_r, 0, pad_b))
     _, pad_H, pad_W, _ = x.shape
 
-    # If window size is larger than feature size, there is no need to shift window
-    if window_size[0] >= pad_H:
-        shift_size[0] = 0
-    if window_size[1] >= pad_W:
-        shift_size[1] = 0
-
     # cyclic shift
     if sum(shift_size) > 0:
        x = torch.roll(x, shifts=(-shift_size[0], -shift_size[1]), dims=(1, 2))
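For context on what the reordered check changes: clamping window_size to the unpadded feature size means an oversized window no longer triggers any padding at all, whereas the old code (which compared against pad_H/pad_W after padding) only disabled the cyclic shift, leaving the feature map padded and the attention computed over the zero rows/columns. A minimal sketch of the new ordering, using a hypothetical pad_for_windows helper rather than the patched function itself:

import torch
import torch.nn.functional as F

def pad_for_windows(x, window_size, shift_size):
    # Hypothetical helper mirroring the patched ordering: clamp first, pad second.
    # x: [B, H, W, C]; window_size / shift_size: two-element lists (mutated in
    # place, just as the patched function mutates its arguments).
    B, H, W, C = x.shape
    if window_size[0] >= H:
        shift_size[0] = 0
        window_size[0] = H
    if window_size[1] >= W:
        shift_size[1] = 0
        window_size[1] = W
    # With the window clamped, this padding is a no-op for any oversized dimension.
    pad_r = (window_size[1] - W % window_size[1]) % window_size[1]
    pad_b = (window_size[0] - H % window_size[0]) % window_size[0]
    return F.pad(x, (0, 0, 0, pad_r, 0, pad_b)), window_size, shift_size

x = torch.randn(1, 5, 5, 96)
x_padded, ws, ss = pad_for_windows(x, [7, 7], [3, 3])
print(x_padded.shape, ws, ss)  # torch.Size([1, 5, 5, 96]) [5, 5] [0, 0]

Under the pre-patch ordering the same 5x5 input would first be padded to 7x7 and only then have its shift zeroed; moving the check before the padding avoids that spurious padding.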
@@ -479,7 +482,7 @@ class SwinTransformer(nn.Module):
         embed_dim (int): Patch embedding dimension.
         depths (List(int)): Depth of each Swin Transformer layer.
         num_heads (List(int)): Number of attention heads in different layers.
-        window_size (List[int]): Window size.
+        window_size (int, List[int]): Window size.
         mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.0.
         stochastic_depth_prob (float): Stochastic depth rate. Default: 0.1.
         num_classes (int): Number of classes for classification head. Default: 1000.
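The docstring change lines up with the Union added to the typing import in the first hunk: it suggests window_size may now be passed as a bare int as well as a list. A sketch of the normalization that implies; _to_pair is a hypothetical name, not something this commit adds:

from typing import List, Union

def _to_pair(window_size: Union[int, List[int]]) -> List[int]:
    # Hypothetical helper: expand a scalar window size to [height, width].
    if isinstance(window_size, int):
        return [window_size, window_size]
    return list(window_size)

print(_to_pair(7))       # [7, 7]
print(_to_pair([8, 4]))  # [8, 4]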