Skip to content

Conversation

@delock
Copy link
Collaborator

@delock delock commented Sep 28, 2025

This PR fix a bug that in some place get_accelerator().current_device() are used instead of get_accelerator().current_device_name(). This would be mostly fine but on CPU this won't work

torch.empty(3, device=get_accelerator().current_device() <-- won't work other than CUDA device
torch.empty(3, device=torch.device(get_accelerator().current_device())) <-- works for GPU device, but won't work for CPU
torch.empty(3, device=torch.device(get_accelerator().current_device_name())) <-- works for both GPU device and CPU
torch.empty(3, device=get_accelerator().current_device_name()) <-- this also works, but not as formal as the last one.

This bug is exposed when I tried to run AutoTP training on Xeon server for debug purpose.

@delock delock force-pushed the gma/fix_device_name branch from dcd8b56 to 462b2c0 Compare September 28, 2025 10:17
@tohtana tohtana merged commit 66c7031 into master Sep 28, 2025
12 checks passed
@tohtana tohtana deleted the gma/fix_device_name branch September 28, 2025 17:19
mauryaavinash95 pushed a commit to DataStates/DeepSpeed that referenced this pull request Oct 4, 2025
This PR fix a bug that in some place get_accelerator().current_device()
are used instead of get_accelerator().current_device_name(). This would
be mostly fine but on CPU this won't work

`torch.empty(3, device=get_accelerator().current_device()` <-- won't
work other than CUDA device
`torch.empty(3,
device=torch.device(get_accelerator().current_device()))` <-- works for
GPU device, but won't work for CPU
`torch.empty(3,
device=torch.device(get_accelerator().current_device_name()))` <-- works
for both GPU device and CPU
`torch.empty(3, device=get_accelerator().current_device_name())` <--
this also works, but not as formal as the last one.

This bug is exposed when I tried to run AutoTP training on Xeon server
for debug purpose.

---------

Signed-off-by: Guokai Ma <[email protected]>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants