-
Notifications
You must be signed in to change notification settings - Fork 26.3k
[Inductor] Generalize tiling algorithm to handle fused reductions #144041
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/144041
Note: Links to docs will display an error until the docs builds have been completed. ✅ No FailuresAs of commit 6cfdda3 with merge base a174ee2 ( This comment was automatically generated by Dr. CI and updates every 15 minutes. |
b05cf37 to
cb1179f
Compare
|
@pytorchbot merge |
Merge startedYour change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team |
|
The merge job was canceled or timed out. This most often happen if two merge requests were issued for the same PR, or if merge job was waiting for more than 6 hours for tests to finish. In later case, please do not hesitate to reissue the merge command |
|
@pytorchbot merge |
Merge startedYour change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team |
Issue
This PR cleans up an edge case that wasn't handled by #137243. The existing tiling code assumes that
node.get_ranges()is a reliable source of pointwise and reduction numels. This is true for pointwise kernels, but the situation is more complicated with reductions. Since reductions change the number of elements in a tensor, not all ops within a reduction kernel will have the same number of iterations. For example,var_meanfuses pointwise division with the output of reduction sum, and the division lacks the corresponding reduction ranges.Fix
Instead of getting numels from
node.get_ranges(), explicitly pass the global pointwise and reduction numels to the relevant tiling functions. InSIMDKernel.complete_partial_tiling, we solve for the missing numel by diving the global numel by the partial tiling's numel. This ensures all tilings have the correct global numel.Also, in
SIMDKernel.is_compatible, add the global reduction numel to node ranges that are missing it. For example,{"x": 8, "r0_": 8}is compatible with a node of ranges([8], [])when we havereduction_numel=8.Finally, this PR generalizes some of the existing codegen to handle multiple reduction dims. We already had code to ignore reduction splits for pointwise kernels, but it only worked for 1D reductions. Now it can handle ND.
Test plan
This PR parametrizes the existing CI test for
var_meanto also run with tiled reductions. It also adds a new test checking thatvar_meangenerates 2D tilings (with tiled reduction enabled). These new tests would fail on the current main branch.cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @ColinPeppler @amjames @desertfire @chauhang @aakhundov