enhance the sliding window devices #5345

@wyli

Description

When the input tensor is on a CUDA device:

inputs = data_type(torch.ones((3, 16, 15, 7))).to(device="cuda")

and sliding_window_inference is called with device="cpu":

output = sliding_window_inference(inputs, ..., sw_device=None, device="cpu")

then the output should be on device="cpu". However, the final step is currently implemented as:

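final_output = convert_to_dst_type(final_output, inputs)[0]  # type: ignore

convert_to_dst_type converts the result to the type and device of its second argument, so the output is moved back to the device of inputs (CUDA in this example) and the device="cpu" argument is effectively ignored.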
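
A minimal sketch, in plain PyTorch, of how the final conversion could honour the device argument. finalize_output is a hypothetical helper, not MONAI's API; it falls back to the device of inputs only when no output device was requested. A fix inside MONAI would presumably still go through convert_to_dst_type so that MetaTensor inputs keep their type; this sketch only shows the device/dtype rule.

```python
import torch


def finalize_output(final_output: torch.Tensor, inputs: torch.Tensor, device=None) -> torch.Tensor:
    """Hypothetical helper (not part of MONAI): cast the aggregated result to the
    dtype of `inputs`, but place it on the requested `device`."""
    # Only fall back to the input's device when the caller did not ask for one.
    target_device = inputs.device if device is None else torch.device(device)
    return final_output.to(device=target_device, dtype=inputs.dtype)


# Use CUDA when available to mirror the issue; the result should land on the CPU either way.
inputs = torch.zeros(1, 1, 5, 5, device="cuda" if torch.cuda.is_available() else "cpu")
result = finalize_output(inputs.double() * 2, inputs, device="cpu")
print(result.device, result.dtype)  # cpu torch.float32
```

With device=None the helper keeps the current behaviour (output on the device of inputs), so callers that never pass device would see no change.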


To reproduce the issue:

>>> import torch
>>> from monai.inferers import sliding_window_inference
>>> from monai.data import MetaTensor
>>> a = MetaTensor(torch.zeros(1,1,5,5)).to("cuda")
>>> sliding_window_inference(a, (3,3), 1, lambda x: x, sw_device="cpu", device="cpu").device
device(type='cuda', index=0)

The final device should be 'cpu'.
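
To make the expected device flow concrete, here is a hypothetical, heavily simplified stand-in for sliding_window_inference (2D only, fixed 50% overlap, uniform averaging, unlike the configurable overlap and blending of the real function). It is a sketch under those assumptions, not MONAI's implementation: each window is moved to sw_device for the predictor, results are accumulated on device, and the final output is left on device rather than moved back to inputs.device.

```python
import itertools

import torch


def _starts(size: int, roi: int, step: int) -> list:
    # Window start positions along one axis; the last window is aligned to the edge.
    assert size >= roi, "this sketch assumes the image is at least as large as the ROI"
    starts = list(range(0, size - roi + 1, step))
    if starts[-1] != size - roi:
        starts.append(size - roi)
    return starts


def toy_sliding_window(inputs, roi_size, predictor, sw_device=None, device=None):
    """Hypothetical 2D sliding window over an (N, C, H, W) tensor with 50% overlap
    and uniform averaging; shows where sw_device and device are used."""
    sw_device = inputs.device if sw_device is None else torch.device(sw_device)
    device = inputs.device if device is None else torch.device(device)
    *_, h, w = inputs.shape
    rh, rw = roi_size
    output, count = None, torch.zeros((1, 1, h, w), device=device)
    for y, x in itertools.product(_starts(h, rh, max(rh // 2, 1)), _starts(w, rw, max(rw // 2, 1))):
        # Each window is moved to sw_device before calling the predictor...
        pred = predictor(inputs[..., y:y + rh, x:x + rw].to(sw_device))
        # ...and its prediction is accumulated on `device`.
        pred = pred.to(device)
        if output is None:
            output = torch.zeros((*pred.shape[:-2], h, w), dtype=pred.dtype, device=device)
        output[..., y:y + rh, x:x + rw] += pred
        count[..., y:y + rh, x:x + rw] += 1
    # The averaged result stays on `device`; it is NOT converted back to inputs.device.
    return output / count


a = torch.arange(25, dtype=torch.float32).reshape(1, 1, 5, 5)
if torch.cuda.is_available():
    a = a.to("cuda")
out = toy_sliding_window(a, (3, 3), lambda x: x, sw_device="cpu", device="cpu")
print(out.device)  # cpu
```

With an identity predictor the averaged output equals the input, which makes the sketch easy to check; the point of interest is only that out.device follows device, which is the behaviour this issue asks for.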