Description
🐛 Describe the bug
Hello,
I'm on a MacBook Air M2 and would like to use MPS.
However, I've encountered an issue with torch.count_nonzero:
import torch
device = "mps"
X = torch.as_tensor([1,2],device=device)
stacked_X = torch.stack([X/2, X/2], dim=-1) # without this line there is no bug
empty_tensor = torch.empty(0, 2, 2, device=device)
mask = X > empty_tensor
print(mask)
print(torch.count_nonzero(mask))
Result:
tensor([], device='mps:0', size=(0, 2, 2), dtype=torch.bool)
tensor(4575657222465388544, device='mps:0')
The first print statement shows that mask is empty, but torch.count_nonzero returns a large, seemingly arbitrary integer instead of 0.
If I remove the stacked_X line, the code works correctly.
Also, when the device is set to "cpu", the code works as expected:
tensor([], size=(0, 2, 2), dtype=torch.bool)
tensor(0)
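A note that may help narrow this down: the large integer printed on MPS is not a max-int value. Reinterpreted as raw bytes, it is exactly the pair of float32 values 0.5 and 1.0, which are the contents of X/2. The sketch below only decodes the number with the standard library. The interpretation (that the result buffer for an empty input is never written and still holds stale data from the stacked_X temporaries) is an assumption, not something confirmed from the PyTorch source.

```python
import struct

# Value printed by torch.count_nonzero(mask) on MPS in the report above.
garbage = 4575657222465388544

# Reinterpret the 64-bit integer as two little-endian float32 values.
low, high = struct.unpack("<2f", struct.pack("<q", garbage))

print(hex(garbage))  # 0x3f8000003f000000
print(low, high)     # 0.5 1.0
```

This would also be consistent with the bug disappearing when the stacked_X line is removed, since no buffer holding X/2 would exist to be reused.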
torch.count_nonzero on MPS appears to be the root cause, because when the code ends with:
print(mask)
print(torch.count_nonzero(mask.cpu()))
The results are correct:
tensor([], device='mps:0', size=(0, 2, 2), dtype=torch.bool)
tensor(0)
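Until this is fixed, a possible workaround that avoids the CPU round-trip is to short-circuit empty tensors before calling torch.count_nonzero. This is a minimal sketch, not an official fix: safe_count_nonzero is a hypothetical helper name, and it assumes the wrong result only occurs for tensors with zero elements, which is all the report above demonstrates.

```python
import torch


def safe_count_nonzero(t: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: count non-zero elements, guarding the empty case."""
    if t.numel() == 0:
        # An empty tensor has no non-zero elements, so skip the kernel entirely.
        return torch.zeros((), dtype=torch.int64, device=t.device)
    return torch.count_nonzero(t)


# Fall back to CPU so the snippet also runs on machines without MPS.
device = "mps" if torch.backends.mps.is_available() else "cpu"

X = torch.as_tensor([1, 2], device=device)
stacked_X = torch.stack([X / 2, X / 2], dim=-1)
mask = X > torch.empty(0, 2, 2, device=device)

print(safe_count_nonzero(mask))
```

The helper returns a 0-dim int64 tensor on the input's device, matching the dtype and shape that torch.count_nonzero returns when called without dim.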
Thank you
Versions
PyTorch version: 2.5.1
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A
OS: macOS 15.1 (arm64)
GCC version: Could not collect
Clang version: 16.0.0 (clang-1600.0.26.4)
CMake version: Could not collect
Libc version: N/A
Python version: 3.11.10 | packaged by conda-forge | (main, Oct 16 2024, 01:26:25) [Clang 17.0.6 ] (64-bit runtime)
Python platform: macOS-15.1-arm64-arm-64bit
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
CPU:
Apple M2
Versions of relevant libraries:
[pip3] numpy==2.1.2
[pip3] torch==2.5.1
[pip3] torch_mas==0.1
[pip3] torchaudio==2.5.1
[pip3] torchvision==0.20.1
[conda] numpy 2.1.2 pypi_0 pypi
[conda] torch 2.5.1 pypi_0 pypi
[conda] torch-mas 0.1 dev_0
[conda] torchaudio 2.5.1 pypi_0 pypi
[conda] torchvision 0.20.1 pypi_0 pypi