
MPS backend produces worse training results than other backends #92615

@paulkre


🐛 Describe the bug

Training a model on the MPS backend produces much worse results than on other backends (e.g. CPU or CUDA). To be clear, this is not about training speed, but about the quality metrics (loss, perplexity) of the model after training. For example, running the word_language_model training from the pytorch/examples repository on either the CPU or the CUDA backend yields loss and ppl values similar to this:

$ python main.py --cuda --epochs 1

| epoch   1 |   200/ 2983 batches | lr 20.00 | ms/batch 18.51 | loss  7.63 | ppl  2063.29
| epoch   1 |   400/ 2983 batches | lr 20.00 | ms/batch 17.47 | loss  6.86 | ppl   950.96
| epoch   1 |   600/ 2983 batches | lr 20.00 | ms/batch 17.47 | loss  6.48 | ppl   653.41
| epoch   1 |   800/ 2983 batches | lr 20.00 | ms/batch 17.46 | loss  6.29 | ppl   539.53
| epoch   1 |  1000/ 2983 batches | lr 20.00 | ms/batch 17.50 | loss  6.14 | ppl   465.56
| epoch   1 |  1200/ 2983 batches | lr 20.00 | ms/batch 17.54 | loss  6.07 | ppl   430.74
| epoch   1 |  1400/ 2983 batches | lr 20.00 | ms/batch 17.52 | loss  5.95 | ppl   384.86
| epoch   1 |  1600/ 2983 batches | lr 20.00 | ms/batch 17.55 | loss  5.96 | ppl   387.05
| epoch   1 |  1800/ 2983 batches | lr 20.00 | ms/batch 17.54 | loss  5.82 | ppl   337.38
| epoch   1 |  2000/ 2983 batches | lr 20.00 | ms/batch 17.52 | loss  5.80 | ppl   329.33
| epoch   1 |  2200/ 2983 batches | lr 20.00 | ms/batch 17.54 | loss  5.67 | ppl   289.19
| epoch   1 |  2400/ 2983 batches | lr 20.00 | ms/batch 17.53 | loss  5.67 | ppl   290.80
| epoch   1 |  2600/ 2983 batches | lr 20.00 | ms/batch 17.53 | loss  5.66 | ppl   285.86
| epoch   1 |  2800/ 2983 batches | lr 20.00 | ms/batch 17.54 | loss  5.55 | ppl   256.72
-----------------------------------------------------------------------------------------
| end of epoch   1 | time: 54.77s | valid loss  5.53 | valid ppl   252.11
-----------------------------------------------------------------------------------------
=========================================================================================
| End of training | test loss  5.44 | test ppl   230.14
=========================================================================================
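For context, the example selects its compute device from the CLI flags roughly like this (a paraphrased sketch of the pattern in pytorch/examples/word_language_model/main.py, not a verbatim copy):

```python
import argparse
import torch

parser = argparse.ArgumentParser()
parser.add_argument("--cuda", action="store_true")
parser.add_argument("--mps", action="store_true")
args = parser.parse_args()

# Pick the backend from the flags; everything downstream of this choice
# (model, data, training loop) is identical across backends.
if args.cuda:
    device = torch.device("cuda")
elif args.mps:
    device = torch.device("mps")
else:
    device = torch.device("cpu")

print(f"training on: {device}")
```

So the only difference between the two runs is the device the tensors live on.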

Running the same training with MPS enabled consistently results in significantly worse values for loss and ppl (tested with PyTorch v1.13.1 and v2.0.0.dev20230119):

$ python main.py --mps --epochs 1

| epoch   1 |   200/ 2983 batches | lr 20.00 | ms/batch 115.26 | loss  7.99 | ppl  2959.03
| epoch   1 |   400/ 2983 batches | lr 20.00 | ms/batch 114.31 | loss  7.52 | ppl  1849.67
| epoch   1 |   600/ 2983 batches | lr 20.00 | ms/batch 114.39 | loss  7.38 | ppl  1603.87
| epoch   1 |   800/ 2983 batches | lr 20.00 | ms/batch 113.73 | loss  7.30 | ppl  1475.25
| epoch   1 |  1000/ 2983 batches | lr 20.00 | ms/batch 113.39 | loss  7.26 | ppl  1421.42
| epoch   1 |  1200/ 2983 batches | lr 20.00 | ms/batch 113.48 | loss  7.25 | ppl  1406.03
| epoch   1 |  1400/ 2983 batches | lr 20.00 | ms/batch 113.49 | loss  7.18 | ppl  1317.59
| epoch   1 |  1600/ 2983 batches | lr 20.00 | ms/batch 113.44 | loss  7.19 | ppl  1330.15
| epoch   1 |  1800/ 2983 batches | lr 20.00 | ms/batch 114.70 | loss  7.16 | ppl  1280.67
| epoch   1 |  2000/ 2983 batches | lr 20.00 | ms/batch 113.67 | loss  7.16 | ppl  1285.51
| epoch   1 |  2200/ 2983 batches | lr 20.00 | ms/batch 113.82 | loss  7.16 | ppl  1288.31
| epoch   1 |  2400/ 2983 batches | lr 20.00 | ms/batch 113.25 | loss  7.13 | ppl  1247.44
| epoch   1 |  2600/ 2983 batches | lr 20.00 | ms/batch 113.25 | loss  7.18 | ppl  1313.82
| epoch   1 |  2800/ 2983 batches | lr 20.00 | ms/batch 113.55 | loss  7.17 | ppl  1301.76
-----------------------------------------------------------------------------------------
| end of epoch   1 | time: 362.77s | valid loss  7.01 | valid ppl  1112.40
-----------------------------------------------------------------------------------------
=========================================================================================
| End of training | test loss  6.93 | test ppl  1019.79
=========================================================================================

Increasing the number of epochs only widens the gap between the two training configurations.
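To help narrow down where the divergence creeps in, a minimal check (hypothetical, not part of the example; the sizes merely mirror the example's defaults of emsize=200, nhid=200, nlayers=2, bptt=35, batch_size=20) is to run one identical LSTM forward/backward pass on both backends and compare results:

```python
import copy
import torch

torch.manual_seed(0)
# Same weights on both devices: build once on CPU, then deep-copy to MPS.
lstm_cpu = torch.nn.LSTM(input_size=200, hidden_size=200, num_layers=2)
lstm_mps = copy.deepcopy(lstm_cpu).to("mps")

x = torch.randn(35, 20, 200)  # (seq_len, batch, features)
out_cpu, _ = lstm_cpu(x)
out_mps, _ = lstm_mps(x.to("mps"))
print("max |forward diff|:", (out_cpu - out_mps.cpu()).abs().max().item())

out_cpu.sum().backward()
out_mps.sum().backward()
for (name, p_cpu), p_mps in zip(lstm_cpu.named_parameters(), lstm_mps.parameters()):
    print(name, "max |grad diff|:", (p_cpu.grad - p_mps.grad.cpu()).abs().max().item())
```

Forward differences on the order of float32 rounding noise would point at the backward/optimizer path instead; large forward differences would implicate the MPS LSTM kernels themselves.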

Versions

Latest:

PyTorch version: 1.13.1
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A

OS: macOS 13.1 (arm64)
GCC version: Could not collect
Clang version: 14.0.0 (clang-1400.0.29.202)
CMake version: version 3.23.2
Libc version: N/A

Python version: 3.10.9 (main, Jan 11 2023, 09:18:18) [Clang 14.0.6 ] (64-bit runtime)
Python platform: macOS-13.1-arm64-arm-64bit
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

Versions of relevant libraries:
[pip3] numpy==1.23.5
[pip3] torch==1.13.1
[conda] numpy                     1.23.5          py310hb93e574_0
[conda] numpy-base                1.23.5          py310haf87e8b_0
[conda] pytorch                   1.13.1                 py3.10_0    pytorch

Nightly:

PyTorch version: 2.0.0.dev20230119
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A

OS: macOS 13.1 (arm64)
GCC version: Could not collect
Clang version: 14.0.0 (clang-1400.0.29.202)
CMake version: version 3.23.2
Libc version: N/A

Python version: 3.10.9 (main, Jan 11 2023, 09:18:18) [Clang 14.0.6 ] (64-bit runtime)
Python platform: macOS-13.1-arm64-arm-64bit
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

Versions of relevant libraries:
[pip3] numpy==1.23.5
[pip3] torch==2.0.0.dev20230119
[conda] numpy                     1.23.5          py310hb93e574_0
[conda] numpy-base                1.23.5          py310haf87e8b_0
[conda] pytorch                   2.0.0.dev20230119        py3.10_0    pytorch-nightly
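For reference, MPS support on either build can be confirmed with the standard availability checks:

```python
import torch

# True if this PyTorch build was compiled with MPS support
print(torch.backends.mps.is_built())
# True if MPS is actually usable on this machine (Apple silicon, macOS 12.3+)
print(torch.backends.mps.is_available())
```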

cc @kulinseth @albanD @malfet @DenisVieriu97 @razarmehr @abhudev

Metadata

Labels

module: correctness (silent) - issue that returns an incorrect result silently
module: mps - related to the Apple Metal Performance Shaders framework
triaged - this issue has been looked at by a team member, and triaged and prioritized into an appropriate module
