[MPS] Empty float32 input to nn.BatchNorm2d mistaken as float64 #134423

@hvaara

Description

🐛 Describe the bug

This was discovered while annotating failures in #134184. For a similar issue that surfaced back when test_module.py was included in the test suite, see #100914.

Minimal repro:

```python
import torch

dtype = torch.float32

mod_cpu = torch.nn.BatchNorm2d(3, device='cpu')
mod_mps = torch.nn.BatchNorm2d(3, device='mps')

inp_cpu = torch.randn(0, 3, 2, 2, device='cpu', dtype=dtype)
inp_mps = inp_cpu.detach().clone().to('mps')

res_cpu = mod_cpu(inp_cpu)  # passes
res_mps = mod_mps(inp_mps)  # TypeError: Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64. Please use float32 instead.
```
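Until the MPS path is fixed, one possible stopgap is to short-circuit empty inputs before they reach the MPS kernel. This is a minimal sketch that assumes skipping normalization is acceptable when the batch is empty (there are no values to normalize). `EmptySafeBatchNorm2d` is a hypothetical subclass for illustration, not a PyTorch API, and the snippet falls back to CPU when MPS is unavailable.

```python
import torch


class EmptySafeBatchNorm2d(torch.nn.BatchNorm2d):
    """Hypothetical workaround (not part of PyTorch): bypass the kernel for empty batches."""

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        if input.numel() == 0:
            # Nothing to normalize: return an empty tensor with the same shape,
            # dtype and device as the input. Unlike the stock module in training
            # mode, this path does not increment num_batches_tracked.
            return input.clone()
        # Non-empty batches take the regular BatchNorm2d path.
        return super().forward(input)


# Use MPS when present; otherwise exercise the same code on CPU.
device = "mps" if torch.backends.mps.is_available() else "cpu"

mod = EmptySafeBatchNorm2d(3, device=device)
empty = torch.randn(0, 3, 2, 2, device=device, dtype=torch.float32)
out = mod(empty)
print(out.shape, out.dtype, out.device)
```

Subclassing rather than wrapping keeps the parameter and buffer names identical to `torch.nn.BatchNorm2d`, so the module could be swapped into an existing model without changing its `state_dict` keys.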

Versions

```
PyTorch version: 2.5.0a0+gitc19005d
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A

OS: macOS 14.6.1 (arm64)
GCC version: Could not collect
Clang version: 15.0.0 (clang-1500.3.9.4)
CMake version: version 3.30.1
Libc version: N/A

Python version: 3.12.4 | packaged by conda-forge | (main, Jun 17 2024, 10:13:44) [Clang 16.0.6 ] (64-bit runtime)
Python platform: macOS-14.6.1-arm64-arm-64bit
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Apple M3 Max

Versions of relevant libraries:
[pip3] flake8==6.1.0
[pip3] flake8-bugbear==23.3.23
[pip3] flake8-comprehensions==3.15.0
[pip3] flake8-executable==2.1.3
[pip3] flake8-logging-format==0.9.0
[pip3] flake8-pyi==23.3.1
[pip3] flake8-simplify==0.19.3
[pip3] mypy==1.10.0
[pip3] mypy-extensions==1.0.0
[pip3] numpy==2.0.1
[pip3] optree==0.12.1
[pip3] torch==2.5.0a0+gitc19005d
[pip3] torch-tb-profiler==0.4.3
[pip3] torchvision==0.20.0a0+0d80848
[pip3] triton==3.0.0
[conda] numpy 2.0.1 pypi_0 pypi
[conda] optree 0.12.1 pypi_0 pypi
[conda] torch 2.5.0a0+gitc19005d dev_0
[conda] torch-tb-profiler 0.4.3 pypi_0 pypi
[conda] torchfix 0.4.0 pypi_0 pypi
[conda] torchvision 0.20.0a0+0d80848 dev_0
[conda] triton 3.0.0 pypi_0 pypi
```

cc @malfet @kulinseth @albanD @DenisVieriu97 @jhavukainen

Metadata

Labels

- module: empty tensor
- module: error checking (bugs related to incorrect/lacking error checking)
- module: mps (related to the Apple Metal Performance Shaders framework)
- triaged (this issue has been looked at by a team member, and triaged and prioritized into an appropriate module)
