Description
🐛 Describe the bug
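The in-place bitwise XOR operator (^=) produces incorrect results on the MPS backend when applied to a column slice of a tensor, which is a non-contiguous view. The result differs from other backends such as CPU. Instead of updating only the selected column, the operation modifies elements outside the slice (column 1 of rows 0 to 4) and leaves some selected elements unchanged (column 0 of rows 5 to 9). The equivalent out-of-place expression produces the correct result on MPS.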
import torch
# On CPU
zeros = torch.zeros((10, 2), dtype=torch.int16, device="cpu")
zeros[:, 0] ^= 1
print(zeros)  # Expected (correct) output:
# tensor([[1, 0],
#         [1, 0],
#         [1, 0],
#         [1, 0],
#         [1, 0],
#         [1, 0],
#         [1, 0],
#         [1, 0],
#         [1, 0],
#         [1, 0]], dtype=torch.int16)
# On MPS
zeros = torch.zeros((10, 2), dtype=torch.int16, device="mps")
zeros[:, 0] ^= 1
print(zeros)  # Incorrect output:
# tensor([[1, 1],
#         [1, 1],
#         [1, 1],
#         [1, 1],
#         [1, 1],
#         [0, 0],
#         [0, 0],
#         [0, 0],
#         [0, 0],
#         [0, 0]], device='mps:0', dtype=torch.int16)
# Workaround: compute the XOR out of place, then assign it back to the column
zeros = torch.zeros((10, 2), dtype=torch.int16, device="mps")
zeros[:, 0] = zeros[:, 0] ^ 1
print(zeros)  # Correct output:
# tensor([[1, 0],
#         [1, 0],
#         [1, 0],
#         [1, 0],
#         [1, 0],
#         [1, 0],
#         [1, 0],
#         [1, 0],
#         [1, 0],
#         [1, 0]], device='mps:0', dtype=torch.int16)
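The incorrect MPS output has a recognizable shape. In a row-major (10, 2) tensor, the column view zeros[:, 0] covers 10 elements at storage offsets 0, 2, ..., 18 (stride 2). If the 10 XOR results were instead written to storage offsets 0 through 9, rows 0 to 4 would become [1, 1] and rows 5 to 9 would stay [0, 0], which is exactly what MPS prints. The pure-Python sketch below models that hypothesis alongside the correct strided update and the out-of-place workaround. It is a simplified illustration under the assumption that the stride is being ignored, not the actual MPS kernel, and all function names in it are hypothetical.

```python
# Simplified model of a strided tensor view (pure Python, no PyTorch).
# Assumption for illustration: the MPS in-place XOR behaves as if the
# column view had stride 1. This is a hypothesis, not the real kernel.

ROWS, COLS = 10, 2  # shape of `zeros` in the report


def make_storage():
    """Flat row-major buffer backing a (ROWS, COLS) tensor of zeros."""
    return [0] * (ROWS * COLS)


def view_positions(offset, stride, size):
    """Storage positions covered by a 1-D strided view."""
    return [offset + i * stride for i in range(size)]


def xor_inplace(storage, offset, stride, size, value):
    """In-place XOR that reads and writes through the view's stride."""
    for pos in view_positions(offset, stride, size):
        storage[pos] ^= value


def xor_inplace_ignoring_stride(storage, offset, stride, size, value):
    """Hypothetical buggy variant: the stride is ignored (treated as 1)."""
    xor_inplace(storage, offset, 1, size, value)


def xor_then_assign(storage, offset, stride, size, value):
    """Workaround model: XOR into a new contiguous buffer, then copy the
    buffer back through the view's stride (like zeros[:, 0] = zeros[:, 0] ^ 1)."""
    positions = view_positions(offset, stride, size)
    result = [storage[pos] ^ value for pos in positions]
    for pos, v in zip(positions, result):
        storage[pos] = v


def as_rows(storage):
    """Reshape the flat buffer into ROWS lists of COLS elements."""
    return [storage[r * COLS:(r + 1) * COLS] for r in range(ROWS)]


# zeros[:, 0] starts at storage offset 0 and steps by COLS (= 2) elements.
correct = make_storage()
xor_inplace(correct, offset=0, stride=COLS, size=ROWS, value=1)

buggy = make_storage()
xor_inplace_ignoring_stride(buggy, offset=0, stride=COLS, size=ROWS, value=1)

workaround = make_storage()
xor_then_assign(workaround, offset=0, stride=COLS, size=ROWS, value=1)

assert as_rows(correct) == [[1, 0]] * ROWS  # matches the CPU output
assert as_rows(buggy) == [[1, 1]] * 5 + [[0, 0]] * 5  # matches the MPS output
assert as_rows(workaround) == as_rows(correct)  # matches the workaround output
```

Under this model, the out-of-place form avoids the problem because its final step copies into the view using the view's stride. Whether that is the actual cause on MPS would need to be confirmed in the backend code.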
Versions
PyTorch version: 2.5.1
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A
OS: macOS 15.2 (arm64)
GCC version: Could not collect
Clang version: 16.0.0 (clang-1600.0.26.4)
CMake version: Could not collect
Libc version: N/A
Python version: 3.12.3 | packaged by conda-forge | (main, Apr 15 2024, 18:35:20) [Clang 16.0.6 ] (64-bit runtime)
Python platform: macOS-15.2-arm64-arm-64bit
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
CPU:
Apple M1 Max
Versions of relevant libraries:
[pip3] numpy==2.1.2
[pip3] onnx==1.17.0
[pip3] torch==2.5.1
[pip3] torchaudio==2.5.1
[pip3] torchvision==0.20.1
[conda] numpy 2.1.2 py312h801f5e3_0 conda-forge
[conda] pytorch 2.5.1 py3.12_0 pytorch
[conda] torchaudio 2.5.1 py312_cpu pytorch
[conda] torchvision 0.20.1 py312_cpu pytorch
cc @ezyang @gchanan @zou3519 @kadeng @msaroufim @kulinseth @albanD @malfet @DenisVieriu97 @jhavukainen