Rand3DElasticd does not use correct device #5211
Description
I am trying to apply a random elastic deformation on the GPU during the training step in a PyTorch Lightning distributed environment. When creating the `Rand3DElasticd` instance, I leave the `device` parameter at its default (`None`), because at that point I do not yet know which device the training will run on. `Rand3DElastic` already looks up the device of the supplied tensor in `__call__`. `Rand3DElasticd`, however, uses the device configured on its internal `Rand3DElastic` instance, which is `None` in this case. As a result, grid creation and resampling always happen on the CPU.
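To illustrate why a configured device of `None` ends up on the CPU, the sketch below builds the same kind of coordinate grid twice: once with the transform's configured device (`None`, as in the dictionary path) and once with the input tensor's device (as in the array path). `make_identity_grid` is a hypothetical stand-in for MONAI's `create_grid(..., backend="torch")`, not its actual implementation, and the example assumes the global default device has not been changed (for example via `torch.set_default_device`).

```python
import torch


def make_identity_grid(spatial_size, device=None):
    # Hypothetical stand-in for create_grid(..., backend="torch"):
    # centered coordinates per axis plus a trailing row of ones (homogeneous coordinates).
    axes = [torch.linspace(-(d - 1) / 2.0, (d - 1) / 2.0, d, device=device) for d in spatial_size]
    coords = torch.meshgrid(*axes, indexing="ij")
    return torch.stack([*coords, torch.ones_like(coords[0])])


# Image on the training device (GPU if one is available).
img = torch.zeros(1, 4, 4, 4, device="cuda" if torch.cuda.is_available() else "cpu")

# Dictionary path: the device comes from the transform's configuration, which is None here,
# so torch allocates on the default device (the CPU unless changed globally).
configured_device = None
dict_grid = make_identity_grid((4, 4, 4), device=configured_device)

# Array path: the device is read from the input tensor, mirroring Rand3DElastic.__call__.
array_device = img.device if isinstance(img, torch.Tensor) else configured_device
array_grid = make_identity_grid((4, 4, 4), device=array_device)

# dict_grid stays on the CPU, while array_grid follows img onto its device.
print(dict_grid.device, array_grid.device)
```

On a GPU machine, only the array-style grid lands next to the image; the dictionary-style grid forces a later host/device transfer or CPU-side resampling.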
MONAI/monai/transforms/spatial/dictionary.py, lines 1137 to 1138 in a3f3504:

```python
_device = self.rand_3d_elastic.device
grid = create_grid(spatial_size=sp_size, device=_device, backend="torch")
```
MONAI/monai/transforms/spatial/array.py, lines 2856 to 2857 in a3f3504:

```python
_device = img.device if isinstance(img, torch.Tensor) else self.device
grid = create_grid(spatial_size=sp_size, device=_device, backend="torch")
```
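One possible direction for a fix is to resolve the device in `Rand3DElasticd.__call__` the same way the array transform does, falling back to the configured device only when the input is not a tensor. The helper below is a hypothetical sketch of that logic, not MONAI's API or the upstream fix; `resolve_grid_device` is an invented name, and the sketch assumes the device should be taken from the first of the transform's keys that is present in the data dictionary.

```python
from typing import Hashable, Mapping, Optional, Sequence, Union

import torch

DeviceLike = Optional[Union[str, torch.device]]


def resolve_grid_device(
    data: Mapping[Hashable, object],
    keys: Sequence[Hashable],
    configured_device: DeviceLike = None,
) -> DeviceLike:
    """Hypothetical helper: pick the grid device like Rand3DElastic.__call__ does.

    Uses the device of the first key's value if it is a torch.Tensor;
    otherwise falls back to the device configured on the transform.
    """
    first_key = next((k for k in keys if k in data), None)
    if first_key is not None and isinstance(data[first_key], torch.Tensor):
        return data[first_key].device
    return configured_device


# Sketch of how the dictionary transform could use it (not the upstream code):
#   _device = resolve_grid_device(d, self.keys, self.rand_3d_elastic.device)
#   grid = create_grid(spatial_size=sp_size, device=_device, backend="torch")

sample = {"image": torch.zeros(1, 8, 8, 8), "label": torch.zeros(1, 8, 8, 8)}
print(resolve_grid_device(sample, ["image", "label"]))
```

Following the array transform, the tensor's device takes precedence over the configured one, so a transform created with `device=None` would automatically follow whatever device Lightning has moved the batch to.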