
[mps] [PyTorch 2.0] LayerNorm crashes when input is in float16 #96113

@pcuenca

Description


🐛 Describe the bug

As stated in the title, the following crashes on the mps device when the input tensor is float16:

import torch
import torch.nn as nn

ln = nn.LayerNorm((768,), elementwise_affine=True).to("mps")
ln(torch.randn(1, 77, 768).to("mps", dtype=torch.float16))
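Until the mps kernel handles a float16 input directly, a possible workaround (a sketch under assumption, not a confirmed fix for the mps backend; the `FP32LayerNorm` name is hypothetical) is to upcast the input to float32 for the normalization and cast the result back to the input dtype:

```python
import torch
import torch.nn as nn


class FP32LayerNorm(nn.LayerNorm):
    """LayerNorm that computes in float32 and casts back to the input dtype.

    Hypothetical workaround: upcasting avoids backends that lack a
    float16 layer_norm kernel, at the cost of an extra cast.
    """

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out = super().forward(x.to(torch.float32))
        return out.to(x.dtype)


# CPU demonstration with a float16 input shaped like the report's example;
# on an Apple machine the module and input could be moved to "mps" instead.
ln = FP32LayerNorm((768,), elementwise_affine=True)
y = ln(torch.randn(1, 77, 768, dtype=torch.float16))
print(y.dtype)
```

The cast costs one extra copy per call, but keeps the module a drop-in replacement for `nn.LayerNorm` since only `forward` is overridden.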

Versions

Collecting environment information...
PyTorch version: 2.1.0.dev20230306
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A

OS: macOS 13.2 (arm64)
GCC version: Could not collect
Clang version: 14.0.0 (clang-1400.0.29.202)
CMake version: Could not collect
Libc version: N/A

Python version: 3.9.13 | packaged by conda-forge | (main, May 27 2022, 17:01:00) [Clang 13.0.1 ] (64-bit runtime)
Python platform: macOS-13.2-arm64-arm-64bit
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Apple M1 Max

Versions of relevant libraries:
[pip3] functorch==0.2.0a0+b2c9d60
[pip3] mypy-extensions==0.4.3
[pip3] numpy==1.23.2
[pip3] torch==2.1.0.dev20230306
[pip3] torchinfo==1.7.0
[conda] numpy 1.23.2 pypi_0 pypi
[conda] torch 2.1.0.dev20230306 pypi_0 pypi
[conda] torchinfo 1.7.0 pypi_0 pypi

cc @ezyang @gchanan @zou3519 @kulinseth @albanD @malfet @DenisVieriu97 @razarmehr @abhudev

Labels

module: mps (Related to Apple Metal Performance Shaders framework), triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)
