
PYTORCH_ENABLE_MPS_FALLBACK does not appear to work for nn.Conv1d #134416

@a2aaron

🐛 Describe the bug

It looks like #129207 addressed an issue with the MPS implementation of nn.Conv1d: the implementation would silently return incorrect results when running a convolution with more than 65536 output channels. The fix in that issue was to temporarily have nn.Conv1d raise NotImplementedError in this situation and to suggest setting PYTORCH_ENABLE_MPS_FALLBACK=1 so that the operation could fall back to the CPU.

However, this fallback does not seem to work. Running the linked issue's minimal example (slightly trimmed down) raises an error regardless of whether PYTORCH_ENABLE_MPS_FALLBACK is set:

# in main.py
import torch
import torch.nn as nn
import os

print(os.environ.get("PYTORCH_ENABLE_MPS_FALLBACK")) # Prints 1 if the variable is set, None otherwise

torch.manual_seed(0)
conv = nn.Conv1d(1, 65537, 3, padding=1)

x = torch.ones([1, 1, 3])
y_mps = conv.to("mps")(x.to("mps")) # Fails with NotImplementedError

Run with something like PYTORCH_ENABLE_MPS_FALLBACK=1 python3.11 main.py

Attempting to run this code always fails with the following, even when PYTORCH_ENABLE_MPS_FALLBACK is set to 1:

Traceback (most recent call last):
  File "[path to main.py]", line 13, in <module>
    y_mps = conv.to("mps")(x.to("mps"))  # Fails with NotImplementedError
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/homebrew/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/homebrew/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/homebrew/lib/python3.11/site-packages/torch/nn/modules/conv.py", line 373, in forward
    return self._conv_forward(input, self.weight, self.bias)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/homebrew/lib/python3.11/site-packages/torch/nn/modules/conv.py", line 368, in _conv_forward
    return F.conv1d(
           ^^^^^^^^^
NotImplementedError: Output channels > 65536 not supported at the MPS device. As a temporary fix, you can set the environment variable `PYTORCH_ENABLE_MPS_FALLBACK=1` to use the CPU as a fallback for this op. WARNING: this will be slower than running natively on MPS.
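
While the automatic fallback is not honored for this op, one possible workaround is to do the CPU round trip by hand. The following is only a minimal sketch, not part of PyTorch: `conv1d_cpu_fallback` and `MPS_MAX_OUT_CHANNELS` are hypothetical names, the 65536 limit is taken from the error message above, and the sketch assumes the layer uses the default `padding_mode="zeros"`.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Limit quoted in the NotImplementedError message (assumed, not queried from PyTorch).
MPS_MAX_OUT_CHANNELS = 65536


def conv1d_cpu_fallback(conv: nn.Conv1d, x: torch.Tensor) -> torch.Tensor:
    """Apply `conv` to `x`, computing on the CPU when the MPS limit would be hit."""
    needs_fallback = (
        x.device.type == "mps" and conv.out_channels > MPS_MAX_OUT_CHANNELS
    )
    if not needs_fallback:
        return conv(x)
    if conv.padding_mode != "zeros":
        raise NotImplementedError("this sketch only handles padding_mode='zeros'")
    # Copy the parameters to the CPU without moving the module itself.
    bias = conv.bias.cpu() if conv.bias is not None else None
    y = F.conv1d(
        x.cpu(), conv.weight.cpu(), bias,
        conv.stride, conv.padding, conv.dilation, conv.groups,
    )
    # Return the result on the caller's original device.
    return y.to(x.device)


# Usage: runs on MPS when available, otherwise takes the plain CPU path.
device = "mps" if torch.backends.mps.is_available() else "cpu"
conv = nn.Conv1d(1, 65537, 3, padding=1).to(device)
x = torch.ones(1, 1, 3, device=device)
y = conv1d_cpu_fallback(conv, x)
print(y.shape)
```

The helper leaves the module on its original device and copies only the weight and bias, so smaller convolutions in the same model keep running on MPS. The copies cost time and memory on every call, so this is a stopgap rather than a fix.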

Versions

Collecting environment information...
PyTorch version: 2.5.0.dev20240824
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A

OS: macOS 14.5 (arm64)
GCC version: Could not collect
Clang version: 15.0.0 (clang-1500.3.9.4)
CMake version: version 3.29.0
Libc version: N/A

Python version: 3.11.5 (main, Aug 24 2023, 15:09:32) [Clang 14.0.0 (clang-1400.0.29.202)] (64-bit runtime)
Python platform: macOS-14.5-arm64-arm-64bit
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Apple M1 Pro

Versions of relevant libraries:
[pip3] numpy==2.1.0
[pip3] torch==2.5.0.dev20240824
[pip3] torchaudio==2.4.0.dev20240824
[pip3] torchsde==0.2.6
[pip3] torchvision==0.20.0.dev20240824
[conda] Could not collect

cc @kulinseth @albanD @malfet @DenisVieriu97 @jhavukainen

Labels: module: mps, module: regression, triaged
