@@ -471,7 +471,7 @@ def __call__(
471471 padding_mode : Optional [Union [GridSamplePadMode , str ]] = None ,
472472 align_corners : Optional [bool ] = None ,
473473 dtype : Union [DtypeLike , torch .dtype ] = None ,
474- ) -> torch . Tensor :
474+ ) -> NdarrayOrTensor :
475475 """
476476 Args:
477477 img: channel first array, must have shape: [chns, H, W] or [chns, H, W, D].
@@ -526,13 +526,11 @@ def __call__(
526526 align_corners = self .align_corners if align_corners is None else align_corners ,
527527 reverse_indexing = True ,
528528 )
529- output : torch .Tensor = xform (
530- img_t .unsqueeze (0 ),
531- transform_t ,
532- spatial_size = output_shape ,
533- )
529+ output : torch .Tensor = xform (img_t .unsqueeze (0 ), transform_t , spatial_size = output_shape ).squeeze (0 )
534530 self ._rotation_matrix = transform
535- return output .squeeze (0 ).detach ().float ()
531+ out : NdarrayOrTensor
532+ out , * _ = convert_to_dst_type (output , dst = img , dtype = output .dtype )
533+ return out
536534
537535 def get_rotation_matrix (self ) -> Optional [np .ndarray ]:
538536 """
@@ -799,7 +797,7 @@ def __call__(
799797 padding_mode : Optional [Union [GridSamplePadMode , str ]] = None ,
800798 align_corners : Optional [bool ] = None ,
801799 dtype : Union [DtypeLike , torch .dtype ] = None ,
802- ) -> torch . Tensor :
800+ ) -> NdarrayOrTensor :
803801 """
804802 Args:
805803 img: channel first array, must have shape 2D: (nchannels, H, W), or 3D: (nchannels, H, W, D).
@@ -1290,7 +1288,7 @@ def __call__(
12901288 grid : Optional [NdarrayOrTensor ] = None ,
12911289 mode : Optional [Union [GridSampleMode , str ]] = None ,
12921290 padding_mode : Optional [Union [GridSamplePadMode , str ]] = None ,
1293- ) -> torch . Tensor :
1291+ ) -> NdarrayOrTensor :
12941292 """
12951293 Args:
12961294 img: shape must be (num_channels, H, W[, D]).
@@ -1344,8 +1342,9 @@ def __call__(
13441342 padding_mode = self .padding_mode .value if padding_mode is None else GridSamplePadMode (padding_mode ).value ,
13451343 align_corners = True ,
13461344 )[0 ]
1347-
1348- return out
1345+ out_val : NdarrayOrTensor
1346+ out_val , * _ = convert_to_dst_type (out , dst = img , dtype = out .dtype )
1347+ return out_val
13491348
13501349
13511350class Affine (Transform ):
@@ -1425,7 +1424,7 @@ def __call__(
14251424 spatial_size : Optional [Union [Sequence [int ], int ]] = None ,
14261425 mode : Optional [Union [GridSampleMode , str ]] = None ,
14271426 padding_mode : Optional [Union [GridSamplePadMode , str ]] = None ,
1428- ) -> Union [torch . Tensor , Tuple [torch . Tensor , NdarrayOrTensor ]]:
1427+ ) -> Union [NdarrayOrTensor , Tuple [NdarrayOrTensor , NdarrayOrTensor ]]:
14291428 """
14301429 Args:
14311430 img: shape must be (num_channels, H, W[, D]),
@@ -1589,7 +1588,7 @@ def __call__(
15891588 spatial_size : Optional [Union [Sequence [int ], int ]] = None ,
15901589 mode : Optional [Union [GridSampleMode , str ]] = None ,
15911590 padding_mode : Optional [Union [GridSamplePadMode , str ]] = None ,
1592- ) -> torch . Tensor :
1591+ ) -> NdarrayOrTensor :
15931592 """
15941593 Args:
15951594 img: shape must be (num_channels, H, W[, D]),
@@ -1615,7 +1614,7 @@ def __call__(
16151614 grid = self .get_identity_grid (sp_size )
16161615 if self ._do_transform :
16171616 grid = self .rand_affine_grid (grid = grid )
1618- out : torch . Tensor = self .resampler (
1617+ out : NdarrayOrTensor = self .resampler (
16191618 img = img , grid = grid , mode = mode or self .mode , padding_mode = padding_mode or self .padding_mode
16201619 )
16211620 return out
@@ -1727,7 +1726,7 @@ def __call__(
17271726 spatial_size : Optional [Union [Tuple [int , int ], int ]] = None ,
17281727 mode : Optional [Union [GridSampleMode , str ]] = None ,
17291728 padding_mode : Optional [Union [GridSamplePadMode , str ]] = None ,
1730- ) -> torch . Tensor :
1729+ ) -> NdarrayOrTensor :
17311730 """
17321731 Args:
17331732 img: shape must be (num_channels, H, W),
@@ -1756,7 +1755,7 @@ def __call__(
17561755 grid = CenterSpatialCrop (roi_size = sp_size )(grid [0 ])
17571756 else :
17581757 grid = create_grid (spatial_size = sp_size )
1759- out : torch . Tensor = self .resampler (
1758+ out : NdarrayOrTensor = self .resampler (
17601759 img , grid , mode = mode or self .mode , padding_mode = padding_mode or self .padding_mode
17611760 )
17621761 return out
@@ -1877,7 +1876,7 @@ def __call__(
18771876 spatial_size : Optional [Union [Tuple [int , int , int ], int ]] = None ,
18781877 mode : Optional [Union [GridSampleMode , str ]] = None ,
18791878 padding_mode : Optional [Union [GridSamplePadMode , str ]] = None ,
1880- ) -> torch . Tensor :
1879+ ) -> NdarrayOrTensor :
18811880 """
18821881 Args:
18831882 img: shape must be (num_channels, H, W, D),
@@ -1902,7 +1901,7 @@ def __call__(
19021901 offset = torch .as_tensor (self .rand_offset , device = self .device ).unsqueeze (0 )
19031902 grid [:3 ] += gaussian (offset )[0 ] * self .magnitude
19041903 grid = self .rand_affine_grid (grid = grid )
1905- out : torch . Tensor = self .resampler (
1904+ out : NdarrayOrTensor = self .resampler (
19061905 img , grid , mode = mode or self .mode , padding_mode = padding_mode or self .padding_mode
19071906 )
19081907 return out
0 commit comments