Description
🐛 Describe the bug
It looks like #129207 addressed an issue with the MPS implementation of nn.Conv1d: the implementation would silently return incorrect results when running a convolution with more than 65536 output channels. The fix there was to temporarily have nn.Conv1d raise NotImplementedError in this situation and to suggest setting PYTORCH_ENABLE_MPS_FALLBACK=1 so that the operation could fall back to the CPU.
However, this fallback does not seem to work. Running the linked issue's minimal example (slightly trimmed down) raises an error regardless of whether PYTORCH_ENABLE_MPS_FALLBACK is set:
# in main.py
import torch
import torch.nn as nn
import os
print(os.environ.get("PYTORCH_ENABLE_MPS_FALLBACK")) # Prints 1 if the variable is set, None otherwise
torch.manual_seed(0)
conv = nn.Conv1d(1, 65537, 3, padding=1)
x = torch.ones([1, 1, 3])
y_mps = conv.to("mps")(x.to("mps")) # Fails with NotImplementedError
Run with something like PYTORCH_ENABLE_MPS_FALLBACK=1 python3.11 main.py
Attempting to run this code always fails with the following, even when PYTORCH_ENABLE_MPS_FALLBACK is set to 1:
Traceback (most recent call last):
  File "[path to main.py]", line 13, in <module>
    y_mps = conv.to("mps")(x.to("mps")) # Fails with NotImplementedError
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/homebrew/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/homebrew/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/homebrew/lib/python3.11/site-packages/torch/nn/modules/conv.py", line 373, in forward
    return self._conv_forward(input, self.weight, self.bias)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/homebrew/lib/python3.11/site-packages/torch/nn/modules/conv.py", line 368, in _conv_forward
    return F.conv1d(
           ^^^^^^^^^
NotImplementedError: Output channels > 65536 not supported at the MPS device. As a temporary fix, you can set the environment variable `PYTORCH_ENABLE_MPS_FALLBACK=1` to use the CPU as a fallback for this op. WARNING: this will be slower than running natively on MPS.
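Until the built-in fallback works, one possible workaround is to do the CPU fallback by hand. The sketch below is only an illustration: `run_with_cpu_fallback` is a hypothetical helper, not a PyTorch API, and it assumes the failure surfaces as a Python `NotImplementedError`, as in the traceback above. It relies only on the `.to()` method and `.device` attribute that `nn.Module` and `torch.Tensor` expose, so it does not import torch itself.

```python
import copy


def run_with_cpu_fallback(module, x):
    """Call module(x); if the backend raises NotImplementedError, retry on the CPU.

    Hypothetical helper for illustration, not part of PyTorch.
    """
    try:
        return module(x)
    except NotImplementedError:
        # nn.Module.to() moves parameters in place, so work on a copy to keep
        # the caller's module on its original device.
        cpu_module = copy.deepcopy(module).to("cpu")
        y = cpu_module(x.to("cpu"))
        # Move the result back so downstream code still sees the original device.
        return y.to(x.device)
```

In the script above, the failing line would then read `y_mps = run_with_cpu_fallback(conv.to("mps"), x.to("mps"))`. This copies the weights and the input to the CPU on every call, so it would likely be slower than a working native fallback.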
Versions
Collecting environment information...
PyTorch version: 2.5.0.dev20240824
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A
OS: macOS 14.5 (arm64)
GCC version: Could not collect
Clang version: 15.0.0 (clang-1500.3.9.4)
CMake version: version 3.29.0
Libc version: N/A
Python version: 3.11.5 (main, Aug 24 2023, 15:09:32) [Clang 14.0.0 (clang-1400.0.29.202)] (64-bit runtime)
Python platform: macOS-14.5-arm64-arm-64bit
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
CPU:
Apple M1 Pro
Versions of relevant libraries:
[pip3] numpy==2.1.0
[pip3] torch==2.5.0.dev20240824
[pip3] torchaudio==2.4.0.dev20240824
[pip3] torchsde==0.2.6
[pip3] torchvision==0.20.0.dev20240824
[conda] Could not collect