
Rand3DElasticd does not use the correct device #5211

@razorx89

I am trying to run a random elastic deformation on the GPU during the training step in a distributed PyTorch Lightning environment. When creating the Rand3DElasticd instance, I leave the device parameter at its default (None), because at that point I do not know which device the training will run on. Rand3DElastic already looks up the device of the supplied tensor in __call__. Rand3DElasticd, however, uses the configured device of its wrapped Rand3DElastic instance, which is None. As a result, grid creation and resampling always happen on the CPU.

Current code in Rand3DElasticd.__call__:

_device = self.rand_3d_elastic.device
grid = create_grid(spatial_size=sp_size, device=_device, backend="torch")

Device lookup in Rand3DElastic.__call__:

_device = img.device if isinstance(img, torch.Tensor) else self.device
grid = create_grid(spatial_size=sp_size, device=_device, backend="torch")
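A minimal sketch of one possible fix, assuming the dictionary transform should resolve the device the same way the Rand3DElastic snippet does: use the device of an input tensor when there is one, otherwise fall back to the configured device. `resolve_device` and its `is_tensor` parameter are hypothetical names, not part of MONAI; the predicate keeps the sketch free of a torch import, and inside MONAI it would be `isinstance(x, torch.Tensor)`.

```python
from typing import Any, Callable, Hashable, Mapping, Optional, Sequence


def resolve_device(
    data: Mapping[Hashable, Any],
    keys: Sequence[Hashable],
    configured: Optional[Any],
    is_tensor: Callable[[Any], bool],
) -> Optional[Any]:
    """Hypothetical helper: choose the device for the shared deformation grid.

    Mirrors the array transform: prefer the device of an input tensor,
    otherwise return the configured device (which may be None).
    """
    for key in keys:
        value = data.get(key)
        if value is not None and is_tensor(value):
            # Follow the data: build the grid where the images already live.
            return value.device
    # No tensor input (e.g. NumPy arrays): keep the configured device.
    # If that is None, torch allocates on its default device, normally the CPU.
    return configured
```

In Rand3DElasticd.__call__, the returned value would take the place of `self.rand_3d_elastic.device` in the `create_grid` call, so a transform constructed with `device=None` follows whatever GPU the Lightning worker has moved the batch to.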
