enhance the sliding window devices #5345
Closed
Description
When `inputs` is on a CUDA device:

inputs = data_type(torch.ones((3, 16, 15, 7))).to(device="cuda")

and the sliding window inference is called with device="cpu":

output = sliding_window_inference(inputs, ..., sw_device=None, device="cpu")

the output should be on device="cpu". But the current implementation converts the output back to the input's device (Line 282 in e37de69):

final_output = convert_to_dst_type(final_output, inputs)[0]  # type: ignore
To reproduce the issue:
>>> import torch
>>> from monai.inferers import sliding_window_inference
>>> from monai.data import MetaTensor
>>> a = MetaTensor(torch.zeros(1,1,5,5)).to("cuda")
>>> sliding_window_inference(a, (3,3), 1, lambda x: x, sw_device="cpu", device="cpu").device
device(type='cuda', index=0)

The final device should be 'cpu'.
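The expected behaviour can be summarized as a small device-resolution rule. The sketch below is a hypothetical helper (`resolve_output_device` is not part of MONAI) that mirrors the documented defaults: `sw_device` and `device` each fall back to the input's device when left as `None`, and the stitched output should end up on `device`:

```python
def resolve_output_device(input_device, sw_device=None, device=None):
    """Hypothetical sketch of the device defaults for sliding_window_inference.

    sw_device: where the per-window computation runs (default: input's device).
    device:    where the stitched final output should live (default: input's device).
    Returns (sw_device, output_device) after applying the fallbacks.
    """
    sw = sw_device if sw_device is not None else input_device
    out = device if device is not None else input_device
    return sw, out

# Repro case from this issue: CUDA input, sw_device="cpu", device="cpu".
sw, out = resolve_output_device("cuda:0", sw_device="cpu", device="cpu")
print(out)  # expected "cpu"; the current implementation instead returns the input's device
```

Under this rule, `convert_to_dst_type(final_output, inputs)` at line 282 is too aggressive: it should match the input's type and metadata but keep the output on the user-requested `device` rather than moving it back to `inputs.device`.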