Skip to content

Commit fb65555

Browse files
committed
Forced Fourier class to output contiguous() tensors.
1 parent f1ef3e8 commit fb65555

File tree

1 file changed

+6
-6
lines changed

1 file changed

+6
-6
lines changed

monai/transforms/utils.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1649,12 +1649,12 @@ def shift_fourier(x: NdarrayOrTensor, spatial_dims: int) -> NdarrayOrTensor:
16491649
k: NdarrayOrTensor
16501650
if isinstance(x, torch.Tensor):
16511651
if hasattr(torch.fft, "fftshift"): # `fftshift` is new in torch 1.8.0
1652-
k = torch.fft.fftshift(torch.fft.fftn(x, dim=dims), dim=dims)
1652+
k = torch.fft.fftshift(torch.fft.fftn(x, dim=dims), dim=dims).contiguous()
16531653
else:
16541654
# if using old PyTorch, will convert to numpy array and return
1655-
k = np.fft.fftshift(np.fft.fftn(x.cpu().numpy(), axes=dims), axes=dims)
1655+
k = np.ascontiguousarray(np.fft.fftshift(np.fft.fftn(x.cpu().numpy(), axes=dims), axes=dims))
16561656
else:
1657-
k = np.fft.fftshift(np.fft.fftn(x, axes=dims), axes=dims)
1657+
k = np.ascontiguousarray(np.fft.fftshift(np.fft.fftn(x, axes=dims), axes=dims))
16581658
return k
16591659

16601660
@staticmethod
@@ -1674,12 +1674,12 @@ def inv_shift_fourier(k: NdarrayOrTensor, spatial_dims: int, n_dims: int | None
16741674
out: NdarrayOrTensor
16751675
if isinstance(k, torch.Tensor):
16761676
if hasattr(torch.fft, "ifftshift"): # `ifftshift` is new in torch 1.8.0
1677-
out = torch.fft.ifftn(torch.fft.ifftshift(k, dim=dims), dim=dims, norm="backward").real
1677+
out = torch.fft.ifftn(torch.fft.ifftshift(k, dim=dims), dim=dims, norm="backward").real.contiguous()
16781678
else:
16791679
# if using old PyTorch, will convert to numpy array and return
1680-
out = np.fft.ifftn(np.fft.ifftshift(k.cpu().numpy(), axes=dims), axes=dims).real
1680+
out = np.ascontiguousarray(np.fft.ifftn(np.fft.ifftshift(k.cpu().numpy(), axes=dims), axes=dims).real)
16811681
else:
1682-
out = np.fft.ifftn(np.fft.ifftshift(k, axes=dims), axes=dims).real
1682+
out = np.ascontiguousarray(np.fft.ifftn(np.fft.ifftshift(k, axes=dims), axes=dims).real)
16831683
return out
16841684

16851685

0 commit comments

Comments
 (0)