-
Notifications
You must be signed in to change notification settings - Fork 26.3k
[MPS] Add native im2col #135706
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[MPS] Add native im2col #135706
Conversation
It's called from `torch.unfold` and one of the few remaining vestiges in MPSFallback
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/135706
Note: Links to docs will display an error until the docs builds have been completed. ✅ You can merge normally! (2 Unrelated Failures)As of commit 478c8e9 with merge base fc88ba2 ( FLAKY - The following jobs failed but were likely due to flakiness present on trunk:
This comment was automatically generated by Dr. CI and updates every 15 minutes. |
Attention! native_functions.yaml was changedIf you are adding a new function or defaulted argument to native_functions.yaml, you cannot use it from pre-existing Python frontend code until our FC window passes (two weeks). Split your PR into two PRs, one which adds the new C++ functionality, and one that makes use of it from Python, and land them two weeks apart. See https://github.com/pytorch/pytorch/wiki/PyTorch's-Python-Frontend-Backward-and-Forward-Compatibility-Policy#forwards-compatibility-fc for more info. Caused by: |
This reverts commit 5cf8956.
| } | ||
| }); | ||
| if (!batched_input) { | ||
| output.resize_({n_output_plane, output_length}); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
| output.resize_({n_output_plane, output_length}); | |
| output.squeeze_(0); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
same on the cuda code + moving it out of the AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND3 block
| INSTANTIATE_IM2COL(float2); | ||
| INSTANTIATE_IM2COL(half); | ||
| INSTANTIATE_IM2COL(half2); | ||
| #if __METAL_VERSION__ >= 310 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What's going to happen if I try to call this with bfloat on older metal? Do I get a nice error?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, see
pytorch/aten/src/ATen/native/mps/OperationUtils.mm
Lines 59 to 60 in 835e7bb
| TORCH_CHECK_TYPE(is_macos_13_or_newer(MacOSVersion::MACOS_VER_14_0_PLUS), | |
| "MPS bfloat16 type is supported on MacOS 14.0 or newer."); |
Which will be triggered from here
pytorch/aten/src/ATen/native/mps/OperationUtils.mm
Lines 215 to 217 in 835e7bb
| case ScalarType::BFloat16: | |
| checkSupportsBFloat16(); | |
| return "bfloat"; |
albanD
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sounds good !
|
@pytorchbot merge -f "Lint and MPS are green" |
Merge startedYour change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Please use Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team |
After #135706 `getGatherScatterScalarType` returns exactly the same results as `scalarToMetalTypeString` , so delete the function and call `scalarToMetalTypeString`
After #135706 `getGatherScatterScalarType` returns exactly the same results as `scalarToMetalTypeString` , so delete the function and call `scalarToMetalTypeString` Pull Request resolved: #136295 Approved by: https://github.com/kit1980
It's called from `torch.unfold` and one of the few remaining vestiges in `MPSFallback.mm` Strongly inspired by CUDA implementation from https://github.com/pytorch/pytorch/blob/09519eb1959978af7f2f909acf768485772596ab/aten/src/ATen/native/cuda/im2col.cuh#L40-L61 Pull Request resolved: pytorch#135706 Approved by: https://github.com/albanD
After pytorch#135706 `getGatherScatterScalarType` returns exactly the same results as `scalarToMetalTypeString` , so delete the function and call `scalarToMetalTypeString` Pull Request resolved: pytorch#136295 Approved by: https://github.com/kit1980
It's called from
torch.unfoldand one of the few remaining vestiges inMPSFallback.mmStrongly inspired by CUDA implementation from
pytorch/aten/src/ATen/native/cuda/im2col.cuh
Lines 40 to 61 in 09519eb