Description
🐛 Describe the bug
Conv1d with out_channels > 65536 produces different outputs on MPS than on CPU. Here is a seeded example:
import torch
import torch.nn as nn
torch.manual_seed(0)
conv = nn.Conv1d(1, 65537, 3, padding=1)
x = torch.ones([1, 1, 3])
y_cpu = conv.to("cpu")(x.to("cpu"))
y_mps = conv.to("mps")(x.to("mps"))
print(y_cpu)
print(y_mps)
print("Equal:", torch.equal(y_cpu, y_mps.to("cpu")))
The output:
tensor([[[ 0.373, 0.369, 0.844],
[-0.426, -0.851, -1.006],
[ 0.309, 0.298, 0.349],
...,
[ 0.040, 0.351, 0.010],
[-0.451, -0.558, -0.234],
[-0.573, -0.567, -0.568]]], grad_fn=<ConvolutionBackward0>)
tensor([[[ 0.163, 0.169, 0.169],
[-0.426, -0.851, -1.006],
[ 0.309, 0.298, 0.349],
...,
[ 0.040, 0.351, 0.010],
[-0.451, -0.558, -0.234],
[ 0.000, 0.000, 0.000]]], device='mps:0', grad_fn=<ConvolutionBackward0>)
Equal: False
The output shows that the MPS and CPU results differ (in the rows printed here, the first and last channels). With out_channels <= 65536, the results match.
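To narrow down which channels differ, a check along these lines may help. It is only a sketch: the helper names `mismatched_channels` and `check` and the `1e-5` tolerance are my own assumptions, not part of PyTorch, and it compares with a tolerance rather than `torch.equal` so that ordinary floating-point rounding between backends is not counted as a mismatch. On a machine without MPS it falls back to comparing CPU against CPU.

```python
import torch
import torch.nn.functional as F


def mismatched_channels(y_ref, y_test, atol=1e-5):
    """Return the output-channel indices where y_test differs from y_ref.

    Both tensors are expected to have shape (batch, channels, length).
    The tolerance is an assumed value, not one taken from PyTorch.
    """
    diff = (y_ref - y_test.to(y_ref.device)).abs()
    # A channel is "bad" if any element over batch or length exceeds atol.
    bad = (diff > atol).any(dim=2).any(dim=0)
    return torch.nonzero(bad).flatten().tolist()


def check(out_channels, device):
    """Run the same 1x1 conv on CPU and on `device`, then compare channels."""
    torch.manual_seed(0)
    weight = torch.randn(out_channels, 1, 1)
    x = torch.ones(1, 1, 1)
    y_cpu = F.conv1d(x, weight)
    y_dev = F.conv1d(x.to(device), weight.to(device))
    return mismatched_channels(y_cpu, y_dev)


if __name__ == "__main__":
    # Fall back to CPU when MPS is unavailable so the script still runs.
    device = "mps" if torch.backends.mps.is_available() else "cpu"
    for n in (65535, 65536, 65537):
        bad = check(n, device)
        print(f"out_channels={n} device={device} mismatches={len(bad)} first={bad[:5]}")
```

Sweeping the three values around the boundary should show whether mismatches appear only once out_channels exceeds 65536, and at which channel indices.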
--
Edit: I tested Conv2d too, and it has the same problem.
--
Edit2: Here is a more controlled test script:
import torch
import torch.nn.functional as F
torch.manual_seed(0)
out_channels = 65537
weight = torch.randn(out_channels, 1, 1)
x = torch.ones([1, 1, 1])
print(F.conv1d(x.to('cpu'), weight.to('cpu'))) # tensor([[[-1.126], [-1.152], [-0.251], ..., [ 0.275], [ 0.159], [-0.037]]])
print(F.conv1d(x.to('mps'), weight.to('mps'))) # tensor([[[-0.037], [-1.152], [-0.251], ..., [ 0.275], [ 0.159], [-0.564]]], device='mps:0')
Versions
PyTorch version: 2.3.1
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A
OS: macOS 14.4.1 (arm64)
GCC version: Could not collect
Clang version: 15.0.0 (clang-1500.3.9.4)
CMake version: version 3.29.3
Libc version: N/A
Python version: 3.11.9 (main, Apr 3 2024, 20:18:58) [Clang 15.0.0 (clang-1500.3.9.4)] (64-bit runtime)
Python platform: macOS-14.4.1-arm64-arm-64bit
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
CPU:
Apple M3 Max
Versions of relevant libraries:
[pip3] numpy==1.26.4
[pip3] onnx==1.16.0
[pip3] onnx2torch==1.5.14
[pip3] onnxruntime==1.18.0
[pip3] open-clip-torch==2.24.0
[pip3] optree==0.11.0
[pip3] pytorch-lightning==2.2.3
[pip3] pytorch-metric-learning==2.5.0
[pip3] rotary-embedding-torch==0.6.2
[pip3] torch==2.3.1
[pip3] torch-audiomentations==0.11.1
[pip3] torch-pitch-shift==1.2.4
[pip3] torch-snake==0.1.0
[pip3] torch-tb-profiler==0.4.3
[pip3] torchaudio==2.3.1
[pip3] torchdiffeq==0.2.3
[pip3] torchinfo==1.8.0
[pip3] torchmetrics==1.3.2
[pip3] torchsde==0.2.6
[pip3] torchtext==0.18.0
[pip3] torchvision==0.18.1
[conda] Could not collect
cc @ezyang @gchanan @zou3519 @kadeng @msaroufim @kulinseth @albanD @malfet @DenisVieriu97 @jhavukainen