
MPS 16Bit Not Working correctly #78168

@justusschock

🐛 Describe the bug

When I try to use half precision together with the new MPS backend, I get the following:

>>> import torch
>>> a = torch.rand(1, device='mps')
>>> a
tensor([0.4496], device='mps:0')
>>> a.item()
0.4495652914047241
>>> a.half()
tensor([4.4957e-01], device='mps:0', dtype=torch.float16)
>>> a.half().item()
0.084716796875

whereas double precision (float64) correctly fails with

"TypeError: Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64. Please use float32 instead."

I would expect half-precision to either work correctly or to fail with a similar error message.
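The value that `a.half().item()` should return can be computed without a GPU. IEEE 754 half precision (binary16) has a 10-bit fraction, so in the range [0.25, 0.5) the representable values are spaced 2**-12 apart. A correct conversion therefore lands within half a step (2**-13) of the float32 input. The sketch below uses Python's `struct` module, whose `'e'` format packs a float as binary16. It assumes PyTorch's float32-to-float16 conversion rounds to nearest, as `struct` does. The helper `expected_half` is hypothetical and not part of PyTorch.

```python
import struct

def expected_half(x: float) -> float:
    """Hypothetical helper: round x to IEEE 754 binary16 and back to a Python float."""
    return struct.unpack("<e", struct.pack("<e", x))[0]

# Values copied from the REPL session above.
reported_float32 = 0.4495652914047241   # a.item()
reported_half_item = 0.084716796875     # a.half().item() on MPS

reference = expected_half(reported_float32)
print(reference)  # 0.449462890625

# float16 spacing in [0.25, 0.5) is 2**-12, so a correctly rounded result
# stays within 2**-13 of the input. The MPS result is far outside that bound.
print(abs(reference - reported_float32) <= 2**-13)           # True
print(abs(reported_half_item - reported_float32) <= 2**-13)  # False
```

The MPS result is off by roughly 0.36. That is far more than rounding error, so the value is wrong, not just imprecise.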

Versions

Collecting environment information...
PyTorch version: 1.13.0.dev20220523
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A

OS: macOS 12.3.1 (arm64)
GCC version: Could not collect
Clang version: 13.1.6 (clang-1316.0.21.2.3)
CMake version: version 3.23.1
Libc version: N/A

Python version: 3.9.12 (main, Apr 5 2022, 01:52:34) [Clang 12.0.0 ] (64-bit runtime)
Python platform: macOS-12.3.1-arm64-arm-64bit
Is CUDA available: False
CUDA runtime version: No CUDA
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

Versions of relevant libraries:
[pip3] mypy==0.950
[pip3] mypy-extensions==0.4.3
[pip3] numpy==1.22.4
[pip3] pytorch-lightning==1.7.0.dev0
[pip3] torch==1.13.0.dev20220523
[pip3] torchmetrics==0.8.2
[pip3] torchtext==0.12.0
[pip3] torchvision==0.12.0
[conda] numpy 1.22.4 pypi_0 pypi
[conda] pytorch-lightning 1.7.0.dev0 dev_0
[conda] torch 1.13.0.dev20220523 pypi_0 pypi
[conda] torchmetrics 0.8.2 pypi_0 pypi
[conda] torchtext 0.12.0 pypi_0 pypi
[conda] torchvision 0.12.0 pypi_0 pypi

Labels

module: mps (Related to Apple Metal Performance Shaders framework)
triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)