Skip to content

[MPS] Fix batchnorm for mixed types#96208

Closed
malfet wants to merge 1 commit intomainfrom
malfet/mps-fix-batchnorm-mixed-types
Closed

[MPS] Fix batchnorm for mixed types#96208
malfet wants to merge 1 commit intomainfrom
malfet/mps-fix-batchnorm-mixed-types

Conversation

@malfet
Copy link
Contributor

@malfet malfet commented Mar 7, 2023

By up/down casting weights to input types

Extend unittests to support float16 input

Fixes #96113

By up/down casting weights to input types

Extend unittests to support float16 input

Fixes #96113
@malfet malfet requested a review from kulinseth as a code owner March 7, 2023 18:31
@pytorch-bot
Copy link

pytorch-bot bot commented Mar 7, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/96208

Note: Links to docs will display an error until the docs builds have been completed.

❌ 1 Failures

As of commit 2cc290c:

NEW FAILURES - The following jobs have failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot pytorch-bot bot added ciflow/mps Run MPS tests (subset of trunk) release notes: mps Release notes category labels Mar 7, 2023
@malfet malfet added the topic: bug fixes topic category label Mar 7, 2023
secondaryTensor:momentumTensor
name:nil];
MPSGraphTensor* scaledRunningMean = [mpsGraph multiplicationWithPrimaryTensor:runningMeanTensor
MPSGraphTensor* scaledRunningMean = [mpsGraph multiplicationWithPrimaryTensor:castMPSTensor(mpsGraph, runningMeanTensor, input_mps_dtype)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

All these casts shouldn't be necessary. If the initial Ranked placeholder is created with correct data type. We shouldn't be adding any casts here in the code.

name:nil];
const auto inputTensorType = [inputTensor dataType];
MPSGraphTensor* outputTensor = [mpsGraph normalizationWithTensor: inputTensor
meanTensor: castMPSTensor(mpsGraph, saveMeanTensor, inputTensorType)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The normalization shouldn't need casts here. The computation graph for normalization should be performed in the type which is requested by the user. Even if it's a pass through, it can add spurious casts which should have been fixed earlier in the Graph to add casts at proper places. My main concern is that it may lead to actual casts in future and will be hidden from the next person modifying the code. I am curious do you need the casts here for this crash ?

ethanliu1206 pushed a commit to kulinseth/pytorch that referenced this pull request Mar 9, 2023
@malfet
Copy link
Contributor Author

malfet commented Mar 9, 2023

As we've discussed, let's split it into a smaller change, and then land a bigger one

malfet added a commit that referenced this pull request Mar 9, 2023
Only for forward pass

Subset of #96208

Create constant with scalar using `input_mps_dtype` and use
`reciprocalWithTensor` instead of `divisionWithPrimaryTensor:1.0
secondaryTensor:`

Fixes #96113
pytorchmergebot pushed a commit that referenced this pull request Mar 9, 2023
Only for forward pass

Subset of #96208

Create constant with scalar using `input_mps_dtype` and use
`reciprocalWithTensor` instead of `divisionWithPrimaryTensor:1.0
secondaryTensor:`

Fixes #96113

Pull Request resolved: #96430
Approved by: https://github.com/kulinseth
cyyever pushed a commit to cyyever/pytorch_private that referenced this pull request Mar 12, 2023
Only for forward pass

Subset of pytorch/pytorch#96208

Create constant with scalar using `input_mps_dtype` and use
`reciprocalWithTensor` instead of `divisionWithPrimaryTensor:1.0
secondaryTensor:`

Fixes pytorch/pytorch#96113

Pull Request resolved: pytorch/pytorch#96430
Approved by: https://github.com/kulinseth
ydwu4 added a commit to ydwu4/pytorch that referenced this pull request Mar 13, 2023
Only for forward pass

Subset of pytorch#96208

Create constant with scalar using `input_mps_dtype` and use
`reciprocalWithTensor` instead of `divisionWithPrimaryTensor:1.0
secondaryTensor:`

Fixes pytorch#96113

Pull Request resolved: pytorch#96430
Approved by: https://github.com/kulinseth
brkirch pushed a commit to brkirch/pytorch that referenced this pull request Apr 15, 2023
Only for forward pass

Subset of pytorch#96208

Create constant with scalar using `input_mps_dtype` and use
`reciprocalWithTensor` instead of `divisionWithPrimaryTensor:1.0
secondaryTensor:`

Fixes pytorch#96113

Pull Request resolved: pytorch#96430
Approved by: https://github.com/kulinseth
malfet added a commit that referenced this pull request Apr 18, 2023
Only for forward pass

Subset of #96208

Create constant with scalar using `input_mps_dtype` and use
`reciprocalWithTensor` instead of `divisionWithPrimaryTensor:1.0
secondaryTensor:`

Fixes #96113

Pull Request resolved: #96430
Approved by: https://github.com/kulinseth

(cherry picked from commit 075a494)
@github-actions
Copy link
Contributor

Looks like this PR hasn't been updated in a while so we're going to go ahead and mark this as Stale.
Feel free to remove the Stale label if you feel this was a mistake.
If you are unable to remove the Stale label please contact a maintainer in order to do so.
If you want the bot to never mark this PR stale again, add the no-stale label.
Stale pull requests will automatically be closed after 30 days of inactivity.

@github-actions github-actions bot added the Stale label Jun 16, 2023
@github-actions github-actions bot closed this Jul 16, 2023
@github-actions github-actions bot deleted the malfet/mps-fix-batchnorm-mixed-types branch September 6, 2024 02:01
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

ciflow/mps Run MPS tests (subset of trunk) release notes: mps Release notes category Stale topic: bug fixes topic category

Projects

None yet

Development

Successfully merging this pull request may close these issues.

[mps] [PyTorch 2.0] LayerNorm crashes when input is in float16

2 participants