Skip to content

[MPS] Allow float16 input to float32 LayerNorm#96430

Closed
malfet wants to merge 1 commit intomasterfrom
malfet/mps-smaller-batch-norm-fix
Closed

[MPS] Allow float16 input to float32 LayerNorm#96430
malfet wants to merge 1 commit intomasterfrom
malfet/mps-smaller-batch-norm-fix

Conversation

@malfet
Copy link
Contributor

@malfet malfet commented Mar 9, 2023

Only for forward pass

Subset of #96208

Create constant with scalar using input_mps_dtype and use
reciprocalWithTensor instead of divisionWithPrimaryTensor:1.0 secondaryTensor:

Fixes #96113

Only for forward pass

Subset of #96208

Create constant with scalar using `input_mps_dtype` and use
`reciprocalWithTensor` instead of `divisionWithPrimaryTensor:1.0
secondaryTensor:`

Fixes #96113
@malfet malfet requested a review from kulinseth as a code owner March 9, 2023 18:16
@pytorch-bot
Copy link

pytorch-bot bot commented Mar 9, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/96430

Note: Links to docs will display an error until the docs builds have been completed.

❌ 1 Failures

As of commit 7ebe509:

NEW FAILURES - The following jobs have failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot pytorch-bot bot added ciflow/mps Run MPS tests (subset of trunk) release notes: mps Release notes category labels Mar 9, 2023
@malfet
Copy link
Contributor Author

malfet commented Mar 9, 2023

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Mar 9, 2023
@malfet
Copy link
Contributor Author

malfet commented Mar 9, 2023

@pytorchbot merge -f "MPS tests are green"

@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

@malfet malfet deleted the malfet/mps-smaller-batch-norm-fix branch March 9, 2023 22:16
cyyever pushed a commit to cyyever/pytorch_private that referenced this pull request Mar 12, 2023
Only for forward pass

Subset of pytorch/pytorch#96208

Create constant with scalar using `input_mps_dtype` and use
`reciprocalWithTensor` instead of `divisionWithPrimaryTensor:1.0
secondaryTensor:`

Fixes pytorch/pytorch#96113

Pull Request resolved: pytorch/pytorch#96430
Approved by: https://github.com/kulinseth
ydwu4 added a commit to ydwu4/pytorch that referenced this pull request Mar 13, 2023
Only for forward pass

Subset of pytorch#96208

Create constant with scalar using `input_mps_dtype` and use
`reciprocalWithTensor` instead of `divisionWithPrimaryTensor:1.0
secondaryTensor:`

Fixes pytorch#96113

Pull Request resolved: pytorch#96430
Approved by: https://github.com/kulinseth
@malfet malfet added this to the 2.0.1 milestone Apr 7, 2023
brkirch pushed a commit to brkirch/pytorch that referenced this pull request Apr 15, 2023
Only for forward pass

Subset of pytorch#96208

Create constant with scalar using `input_mps_dtype` and use
`reciprocalWithTensor` instead of `divisionWithPrimaryTensor:1.0
secondaryTensor:`

Fixes pytorch#96113

Pull Request resolved: pytorch#96430
Approved by: https://github.com/kulinseth
malfet added a commit that referenced this pull request Apr 18, 2023
Only for forward pass

Subset of #96208

Create constant with scalar using `input_mps_dtype` and use
`reciprocalWithTensor` instead of `divisionWithPrimaryTensor:1.0
secondaryTensor:`

Fixes #96113

Pull Request resolved: #96430
Approved by: https://github.com/kulinseth

(cherry picked from commit 075a494)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

ciflow/mps Run MPS tests (subset of trunk) ciflow/trunk Trigger trunk jobs on your pull request Merged release notes: mps Release notes category

Projects

None yet

Development

Successfully merging this pull request may close these issues.

[mps] [PyTorch 2.0] LayerNorm crashes when input is in float16

3 participants