Conversation
By up/down casting weights to input types.
Extend unittests to support float16 input.
Fixes #96113
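The idea of up/down casting the affine parameters to the input dtype can be sketched outside of MPSGraph. This is a hypothetical NumPy analogue, not the actual kernel code; the function name `batch_norm_cast` and the use of NumPy are illustrative assumptions:

```python
import numpy as np

def batch_norm_cast(x, weight, bias, eps=1e-5):
    """Sketch of a batch norm whose float32 affine parameters are
    cast down to the input dtype, so every op in the graph sees a
    single dtype (the approach this PR takes with castMPSTensor)."""
    w = weight.astype(x.dtype)  # float32 -> float16 when x is half
    b = bias.astype(x.dtype)
    # Accumulate statistics in float32 for accuracy, then cast back:
    mean = x.mean(axis=0, dtype=np.float32).astype(x.dtype)
    var = x.var(axis=0, dtype=np.float32).astype(x.dtype)
    inv_std = np.reciprocal(np.sqrt(var + np.asarray(eps, dtype=x.dtype)))
    return (x - mean) * inv_std * w + b

x = np.random.randn(8, 4).astype(np.float16)
out = batch_norm_cast(x, np.ones(4, np.float32), np.zeros(4, np.float32))
assert out.dtype == np.float16  # no silent upcast to float32
```

Without the `astype` calls on `weight` and `bias`, the multiply would promote the whole result to float32, which is the dtype mismatch this PR avoids.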
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/96208
Note: Links to docs will display an error until the docs builds have been completed.
❌ 1 Failure as of commit 2cc290c. NEW FAILURES: the following jobs have failed:
This comment was automatically generated by Dr. CI and updates every 15 minutes.
```diff
                                    secondaryTensor:momentumTensor
                                               name:nil];
-  MPSGraphTensor* scaledRunningMean = [mpsGraph multiplicationWithPrimaryTensor:runningMeanTensor
+  MPSGraphTensor* scaledRunningMean = [mpsGraph multiplicationWithPrimaryTensor:castMPSTensor(mpsGraph, runningMeanTensor, input_mps_dtype)
```
All these casts shouldn't be necessary if the initial ranked placeholder is created with the correct data type; we shouldn't be adding any casts here in the code.
```objc
                                               name:nil];
  const auto inputTensorType = [inputTensor dataType];
  MPSGraphTensor* outputTensor = [mpsGraph normalizationWithTensor:inputTensor
                                                        meanTensor:castMPSTensor(mpsGraph, saveMeanTensor, inputTensorType)
```
The normalization shouldn't need casts here; the computation graph for normalization should be performed in the type requested by the user. Even if it's a pass-through, it can add spurious casts that should instead have been fixed earlier in the graph by adding casts at the proper places. My main concern is that it may lead to actual casts in the future, hidden from the next person modifying the code. I'm curious: do you need the casts here for this crash?
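The reviewer's pass-through concern can be illustrated with a small sketch (hypothetical helper, not the actual `castMPSTensor` implementation): a cast helper that returns the tensor unchanged when the dtype already matches adds no node, whereas an unconditional cast inserts a spurious op even in the matching case.

```python
import numpy as np

def cast_if_needed(tensor, dtype):
    """Hypothetical analogue of castMPSTensor: only materialize a
    cast when the dtypes actually differ."""
    if tensor.dtype == dtype:
        return tensor            # pass-through: no extra node added
    return tensor.astype(dtype)  # a real cast

x = np.ones(3, dtype=np.float16)
assert cast_if_needed(x, np.float16) is x            # true no-op
assert cast_if_needed(x, np.float32).dtype == np.float32
```

In a lazily built graph like MPSGraph, the same guard keeps matching-dtype calls from cluttering the graph with identity casts, which is the behavior the reviewer is asking about.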
…atch norm error
As we've discussed, let's split it into a smaller change, and then land a bigger one.
Only for forward pass. Subset of #96208.
Create constant with scalar using `input_mps_dtype` and use `reciprocalWithTensor` instead of `divisionWithPrimaryTensor:1.0 secondaryTensor:`.
Fixes #96113
Pull Request resolved: #96430
Approved by: https://github.com/kulinseth
(cherry picked from commit 075a494)
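The `reciprocalWithTensor` change avoids materializing a `1.0` constant whose dtype can disagree with a float16 operand. The dtype-promotion hazard can be sketched in NumPy (an analogy only; MPSGraph's promotion rules are not NumPy's):

```python
import numpy as np

x = np.array([2.0, 4.0], dtype=np.float16)

# Dividing by an explicit float32 constant promotes the result:
one_f32 = np.ones(1, dtype=np.float32)   # stand-in for a 1.0 constant node
promoted = one_f32 / x
assert promoted.dtype == np.float32      # silent upcast of the whole result

# Taking the reciprocal of the tensor itself stays in float16,
# analogous to preferring reciprocalWithTensor over a 1.0 constant:
recip = np.reciprocal(x)
assert recip.dtype == np.float16
```

Creating the constant with `input_mps_dtype` (as the commit message describes) fixes the same mismatch from the other direction: the constant is made to agree with the tensor instead of dropping the constant entirely.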
Looks like this PR hasn't been updated in a while, so we're going to go ahead and mark it as stale.