Skip to content

Commit 0f17aa9

Browse files
authored
2231 Fixes tutorial 353 (#2954)
* fixes tutorial 353 Signed-off-by: Wenqi Li <[email protected]> * adding type tests Signed-off-by: Wenqi Li <[email protected]> * improves type checks Signed-off-by: Wenqi Li <[email protected]> * fixes flake8 Signed-off-by: Wenqi Li <[email protected]> * fixes as channel first Signed-off-by: Wenqi Li <[email protected]> * type test option Signed-off-by: Wenqi Li <[email protected]> * ndarray support Signed-off-by: Wenqi Li <[email protected]> * fixes unit tests Signed-off-by: Wenqi Li <[email protected]> update Signed-off-by: Wenqi Li <[email protected]> * bash option for windows test Signed-off-by: Wenqi Li <[email protected]> * fixes unit tests Signed-off-by: Wenqi Li <[email protected]> * enhance norm intensity tests Signed-off-by: Wenqi Li <[email protected]> * fixes merge tests Signed-off-by: Wenqi Li <[email protected]>
1 parent 2f4b582 commit 0f17aa9

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

48 files changed

+195
-233
lines changed

.github/workflows/pythonapp.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,7 @@ jobs:
152152
python -c 'import torch; print(torch.__version__); print(torch.rand(5,3))'
153153
python -c "import monai; monai.config.print_config()"
154154
./runtests.sh --min
155+
shell: bash
155156
env:
156157
QUICKTEST: True
157158

monai/transforms/spatial/array.py

Lines changed: 17 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -471,7 +471,7 @@ def __call__(
471471
padding_mode: Optional[Union[GridSamplePadMode, str]] = None,
472472
align_corners: Optional[bool] = None,
473473
dtype: Union[DtypeLike, torch.dtype] = None,
474-
) -> torch.Tensor:
474+
) -> NdarrayOrTensor:
475475
"""
476476
Args:
477477
img: channel first array, must have shape: [chns, H, W] or [chns, H, W, D].
@@ -526,13 +526,11 @@ def __call__(
526526
align_corners=self.align_corners if align_corners is None else align_corners,
527527
reverse_indexing=True,
528528
)
529-
output: torch.Tensor = xform(
530-
img_t.unsqueeze(0),
531-
transform_t,
532-
spatial_size=output_shape,
533-
)
529+
output: torch.Tensor = xform(img_t.unsqueeze(0), transform_t, spatial_size=output_shape).squeeze(0)
534530
self._rotation_matrix = transform
535-
return output.squeeze(0).detach().float()
531+
out: NdarrayOrTensor
532+
out, *_ = convert_to_dst_type(output, dst=img, dtype=output.dtype)
533+
return out
536534

537535
def get_rotation_matrix(self) -> Optional[np.ndarray]:
538536
"""
@@ -799,7 +797,7 @@ def __call__(
799797
padding_mode: Optional[Union[GridSamplePadMode, str]] = None,
800798
align_corners: Optional[bool] = None,
801799
dtype: Union[DtypeLike, torch.dtype] = None,
802-
) -> torch.Tensor:
800+
) -> NdarrayOrTensor:
803801
"""
804802
Args:
805803
img: channel first array, must have shape 2D: (nchannels, H, W), or 3D: (nchannels, H, W, D).
@@ -1290,7 +1288,7 @@ def __call__(
12901288
grid: Optional[NdarrayOrTensor] = None,
12911289
mode: Optional[Union[GridSampleMode, str]] = None,
12921290
padding_mode: Optional[Union[GridSamplePadMode, str]] = None,
1293-
) -> torch.Tensor:
1291+
) -> NdarrayOrTensor:
12941292
"""
12951293
Args:
12961294
img: shape must be (num_channels, H, W[, D]).
@@ -1344,8 +1342,9 @@ def __call__(
13441342
padding_mode=self.padding_mode.value if padding_mode is None else GridSamplePadMode(padding_mode).value,
13451343
align_corners=True,
13461344
)[0]
1347-
1348-
return out
1345+
out_val: NdarrayOrTensor
1346+
out_val, *_ = convert_to_dst_type(out, dst=img, dtype=out.dtype)
1347+
return out_val
13491348

13501349

13511350
class Affine(Transform):
@@ -1425,7 +1424,7 @@ def __call__(
14251424
spatial_size: Optional[Union[Sequence[int], int]] = None,
14261425
mode: Optional[Union[GridSampleMode, str]] = None,
14271426
padding_mode: Optional[Union[GridSamplePadMode, str]] = None,
1428-
) -> Union[torch.Tensor, Tuple[torch.Tensor, NdarrayOrTensor]]:
1427+
) -> Union[NdarrayOrTensor, Tuple[NdarrayOrTensor, NdarrayOrTensor]]:
14291428
"""
14301429
Args:
14311430
img: shape must be (num_channels, H, W[, D]),
@@ -1589,7 +1588,7 @@ def __call__(
15891588
spatial_size: Optional[Union[Sequence[int], int]] = None,
15901589
mode: Optional[Union[GridSampleMode, str]] = None,
15911590
padding_mode: Optional[Union[GridSamplePadMode, str]] = None,
1592-
) -> torch.Tensor:
1591+
) -> NdarrayOrTensor:
15931592
"""
15941593
Args:
15951594
img: shape must be (num_channels, H, W[, D]),
@@ -1615,7 +1614,7 @@ def __call__(
16151614
grid = self.get_identity_grid(sp_size)
16161615
if self._do_transform:
16171616
grid = self.rand_affine_grid(grid=grid)
1618-
out: torch.Tensor = self.resampler(
1617+
out: NdarrayOrTensor = self.resampler(
16191618
img=img, grid=grid, mode=mode or self.mode, padding_mode=padding_mode or self.padding_mode
16201619
)
16211620
return out
@@ -1727,7 +1726,7 @@ def __call__(
17271726
spatial_size: Optional[Union[Tuple[int, int], int]] = None,
17281727
mode: Optional[Union[GridSampleMode, str]] = None,
17291728
padding_mode: Optional[Union[GridSamplePadMode, str]] = None,
1730-
) -> torch.Tensor:
1729+
) -> NdarrayOrTensor:
17311730
"""
17321731
Args:
17331732
img: shape must be (num_channels, H, W),
@@ -1756,7 +1755,7 @@ def __call__(
17561755
grid = CenterSpatialCrop(roi_size=sp_size)(grid[0])
17571756
else:
17581757
grid = create_grid(spatial_size=sp_size)
1759-
out: torch.Tensor = self.resampler(
1758+
out: NdarrayOrTensor = self.resampler(
17601759
img, grid, mode=mode or self.mode, padding_mode=padding_mode or self.padding_mode
17611760
)
17621761
return out
@@ -1877,7 +1876,7 @@ def __call__(
18771876
spatial_size: Optional[Union[Tuple[int, int, int], int]] = None,
18781877
mode: Optional[Union[GridSampleMode, str]] = None,
18791878
padding_mode: Optional[Union[GridSamplePadMode, str]] = None,
1880-
) -> torch.Tensor:
1879+
) -> NdarrayOrTensor:
18811880
"""
18821881
Args:
18831882
img: shape must be (num_channels, H, W, D),
@@ -1902,7 +1901,7 @@ def __call__(
19021901
offset = torch.as_tensor(self.rand_offset, device=self.device).unsqueeze(0)
19031902
grid[:3] += gaussian(offset)[0] * self.magnitude
19041903
grid = self.rand_affine_grid(grid=grid)
1905-
out: torch.Tensor = self.resampler(
1904+
out: NdarrayOrTensor = self.resampler(
19061905
img, grid, mode=mode or self.mode, padding_mode=padding_mode or self.padding_mode
19071906
)
19081907
return out

monai/transforms/spatial/dictionary.py

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -820,10 +820,6 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, N
820820
if do_resampling:
821821
d[key] = self.rand_affine.resampler(d[key], grid, mode=mode, padding_mode=padding_mode)
822822

823-
# if not doing transform and spatial size is unchanged, only need to do convert to torch
824-
else:
825-
d[key], *_ = convert_data_type(d[key], torch.Tensor, dtype=torch.float32, device=device)
826-
827823
return d
828824

829825
def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
@@ -1442,10 +1438,7 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, N
14421438
self.randomize()
14431439
d = dict(data)
14441440
angle: Union[Sequence[float], float] = self.x if d[self.keys[0]].ndim == 3 else (self.x, self.y, self.z)
1445-
rotator = Rotate(
1446-
angle=angle,
1447-
keep_size=self.keep_size,
1448-
)
1441+
rotator = Rotate(angle=angle, keep_size=self.keep_size)
14491442
for key, mode, padding_mode, align_corners, dtype in self.key_iterator(
14501443
d, self.mode, self.padding_mode, self.align_corners, self.dtype
14511444
):
@@ -1460,7 +1453,6 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, N
14601453
)
14611454
rot_mat = rotator.get_rotation_matrix()
14621455
else:
1463-
d[key], *_ = convert_data_type(d[key], torch.Tensor)
14641456
rot_mat = np.eye(d[key].ndim)
14651457
self.push_transform(
14661458
d,

monai/utils/type_conversion.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -248,11 +248,14 @@ def convert_data_type(
248248
return data, orig_type, orig_device
249249

250250

251-
def convert_to_dst_type(src: Any, dst: NdarrayOrTensor) -> Tuple[NdarrayOrTensor, type, Optional[torch.device]]:
251+
def convert_to_dst_type(
252+
src: Any, dst: NdarrayOrTensor, dtype: Optional[Union[DtypeLike, torch.dtype]] = None
253+
) -> Tuple[NdarrayOrTensor, type, Optional[torch.device]]:
252254
"""
253-
If `dst` is `torch.Tensor` or its subclass, convert `src` to `torch.Tensor` with the same data type as `dst`,
254-
if `dst` is `numpy.ndarray` or its subclass, convert to `numpy.ndarray` with the same data type as `dst`,
255+
If `dst` is an instance of `torch.Tensor` or its subclass, convert `src` to `torch.Tensor` with the same data type as `dst`,
256+
if `dst` is an instance of `numpy.ndarray` or its subclass, convert to `numpy.ndarray` with the same data type as `dst`,
255257
otherwise, convert to the type of `dst` directly.
258+
`dtype` is an optional argument if the target `dtype` is different from the original `dst`'s data type.
256259
257260
See Also:
258261
:func:`convert_data_type`
@@ -261,11 +264,14 @@ def convert_to_dst_type(src: Any, dst: NdarrayOrTensor) -> Tuple[NdarrayOrTensor
261264
if isinstance(dst, torch.Tensor):
262265
device = dst.device
263266

267+
if dtype is None:
268+
dtype = dst.dtype
269+
264270
output_type: Any
265271
if isinstance(dst, torch.Tensor):
266272
output_type = torch.Tensor
267273
elif isinstance(dst, np.ndarray):
268274
output_type = np.ndarray
269275
else:
270276
output_type = type(dst)
271-
return convert_data_type(data=src, output_type=output_type, device=device, dtype=dst.dtype)
277+
return convert_data_type(data=src, output_type=output_type, device=device, dtype=dtype)

tests/test_affine_grid.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,7 @@ def test_affine_grid(self, input_param, input_data, expected_val):
115115
result, _ = g(**input_data)
116116
if "device" in input_data:
117117
self.assertEqual(result.device, input_data[device])
118-
assert_allclose(result, expected_val, rtol=1e-4, atol=1e-4)
118+
assert_allclose(result, expected_val, type_test=False, rtol=1e-4, atol=1e-4)
119119

120120

121121
if __name__ == "__main__":

tests/test_as_channel_first.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ def test_value(self, in_type, input_param, expected_shape):
3434
if isinstance(test_data, torch.Tensor):
3535
test_data = test_data.cpu().numpy()
3636
expected = np.moveaxis(test_data, input_param["channel_dim"], 0)
37-
assert_allclose(expected, result)
37+
assert_allclose(result, expected, type_test=False)
3838

3939

4040
if __name__ == "__main__":

tests/test_ensure_type.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ def test_array_input(self):
2929
if dtype == "NUMPY":
3030
self.assertTrue(result.dtype == np.float32)
3131
self.assertTrue(isinstance(result, torch.Tensor if dtype == "tensor" else np.ndarray))
32-
assert_allclose(result, test_data)
32+
assert_allclose(result, test_data, type_test=False)
3333
self.assertTupleEqual(result.shape, (2, 2))
3434

3535
def test_single_input(self):
@@ -43,7 +43,7 @@ def test_single_input(self):
4343
if isinstance(test_data, bool):
4444
self.assertFalse(result)
4545
else:
46-
assert_allclose(result, test_data)
46+
assert_allclose(result, test_data, type_test=False)
4747
self.assertEqual(result.ndim, 0)
4848

4949
def test_string(self):

tests/test_ensure_typed.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ def test_array_input(self):
3434
if dtype == "NUMPY":
3535
self.assertTrue(result.dtype == np.float32)
3636
self.assertTrue(isinstance(result, torch.Tensor if dtype == "tensor" else np.ndarray))
37-
assert_allclose(result, test_data)
37+
assert_allclose(result, test_data, type_test=False)
3838
self.assertTupleEqual(result.shape, (2, 2))
3939

4040
def test_single_input(self):
@@ -48,7 +48,7 @@ def test_single_input(self):
4848
if isinstance(test_data, bool):
4949
self.assertFalse(result)
5050
else:
51-
assert_allclose(result, test_data)
51+
assert_allclose(result, test_data, type_test=False)
5252
self.assertEqual(result.ndim, 0)
5353

5454
def test_string(self):

tests/test_flip.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -34,12 +34,10 @@ def test_correct_results(self, _, spatial_axis):
3434
for p in TEST_NDARRAYS:
3535
im = p(self.imt[0])
3636
flip = Flip(spatial_axis=spatial_axis)
37-
expected = []
38-
for channel in self.imt[0]:
39-
expected.append(np.flip(channel, spatial_axis))
37+
expected = [np.flip(channel, spatial_axis) for channel in self.imt[0]]
4038
expected = np.stack(expected)
4139
result = flip(im)
42-
assert_allclose(expected, result)
40+
assert_allclose(result, p(expected))
4341

4442

4543
if __name__ == "__main__":

tests/test_flipd.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -33,12 +33,10 @@ def test_invalid_cases(self, _, spatial_axis, raises):
3333
def test_correct_results(self, _, spatial_axis):
3434
for p in TEST_NDARRAYS:
3535
flip = Flipd(keys="img", spatial_axis=spatial_axis)
36-
expected = []
37-
for channel in self.imt[0]:
38-
expected.append(np.flip(channel, spatial_axis))
36+
expected = [np.flip(channel, spatial_axis) for channel in self.imt[0]]
3937
expected = np.stack(expected)
4038
result = flip({"img": p(self.imt[0])})["img"]
41-
assert_allclose(expected, result)
39+
assert_allclose(result, p(expected))
4240

4341

4442
if __name__ == "__main__":

0 commit comments

Comments
 (0)