[Inductor] Refactor "r" reduction prefix to {"r0_", "r1_"}. #142020
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/142020. Note: links to docs will display an error until the docs builds have been completed.

✅ No failures as of commit 3e133ae with merge base d3d1a78. This comment was automatically generated by Dr. CI and updates every 15 minutes.
torch/_inductor/codegen/simd.py (outdated)

```
def initialize_range_tree(self, pid_cache):
    no_r_dim = not self.inside_reduction or self.numels[-1] == 1
    prefixes = OrderedSet(["z", "y", "x", "r0_", "r1_"])
```
Move this to global scope for a slightly less expensive runtime.
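For illustration, a minimal sketch of the suggested change, assuming the surrounding module already imports `OrderedSet`; the constant name `ALL_PREFIXES` is hypothetical:

```
# Hypothetical sketch: hoist the set to module scope so it is built once
# at import time instead of on every call.
ALL_PREFIXES = OrderedSet(["z", "y", "x", "r0_", "r1_"])

def initialize_range_tree(self, pid_cache):
    no_r_dim = not self.inside_reduction or self.numels[-1] == 1
    prefixes = ALL_PREFIXES  # reuse the module-level constant
    ...
```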
Changed in 95704ec.
torch/_inductor/codegen/triton.py (outdated)

```
coeffs = [
    sympy_product(rn_numels[idx + 1 :]) for idx in range(len(rn_prefixes) - 1)
] + [sympy.Integer(1)]

def _flatten_reduction_inds(self, multi_inds: List[sympy.Expr]) -> sympy.Expr:
```
Use full words for function names.
Changed in 417bc9f.
torch/_inductor/codegen/triton.py (outdated)

```
coeffs = self._get_reduction_index_coeffs()
return sympy_dot(coeffs, multi_inds)

def codegen_reduction_inds(self, buffer) -> None:
```
Use full words for function names.
Changed in 417bc9f.
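For context, the coefficient computation in these hunks implements standard row-major linearization: each reduction index is weighted by the product of all later reduction numels, and the last coefficient is 1. A self-contained sketch (illustrative names, not the PR's exact code):

```
import sympy

def flatten_reduction_indices(indices, numels):
    # Row-major linearization: for numels [n0, n1], flat = i0 * n1 + i1.
    # Earlier coefficients are products of the trailing numels, matching
    # the coeffs list in the hunk above.
    coeffs = [
        sympy.prod(numels[idx + 1:]) for idx in range(len(numels) - 1)
    ] + [sympy.Integer(1)]
    return sum(c * i for c, i in zip(coeffs, indices))

r0, r1 = sympy.symbols("r0_index r1_index")
n0, n1 = sympy.symbols("r0_numel r1_numel")
# Matches the `rindex` alias in the generated kernels:
print(flatten_reduction_indices([r0, r1], [n0, n1]))  # r0_index*r1_numel + r1_index
```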
torch/_inductor/codegen/triton.py (outdated)

```
  def prefix_to_size_hint(self, prefix: str) -> Optional[int]:
-     size_hint_idx = {"X": 0, "Y": 1, "Z": 2, "R": -1}[prefix]
+     size_hint_idx = {"X": 0, "Y": 1, "Z": 2, "R0_": -1, "R1_": -2}[prefix]
```
Isn't -2 wrong? Or is the order swapped?

```
size_hints = [x, r0, r1]
size_hints[-2] is r0
size_hints[-1] is r1
```

However:

```
size_hints = [x, r0]
size_hints[-1] is r0
```

So you actually need a different index based on the tiling now...
I think you're right about this. This does seem wrong. I'll add some unit tests to confirm.
On second thought, this problem is exactly the reason why we made numels into a dict in #141751. Since size_hints is basically a rounded up version of numels, it seems like that should be a dict as well. That would make this code much simpler. I think I'll open a separate PR for that.
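For illustration, a minimal sketch of the dict-based lookup being suggested here (hypothetical names and shape, not the actual #142249 implementation):

```
from typing import Dict, Optional

def prefix_to_size_hint(size_hints: Dict[str, int], prefix: str) -> Optional[int]:
    # With size_hints keyed by prefix (mirroring the numels dict from #141751),
    # the lookup no longer depends on how many reduction dimensions exist,
    # avoiding the fragile negative-index arithmetic discussed above.
    return size_hints.get(prefix.lower())

# Both tilings resolve "R0_" correctly, with no positional assumptions:
assert prefix_to_size_hint({"x": 8, "r0_": 16, "r1_": 32}, "R0_") == 16
assert prefix_to_size_hint({"x": 8, "r0_": 16}, "R0_") == 16
```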
Created a separate PR for the size hints refactor #142249. I'll revisit this one once that lands.
The size_hints PR landed. Merged with this one in 9573db1.
Co-authored-by: Jason Ansel <[email protected]>
@pytorchbot merge -i (Initiating merge automatically since Phabricator Diff has merged, merging with -i because oss signals were bypassed internally)
15 similar comments
Merge started. Your change will be merged while ignoring the following 0 checks. Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
@pytorchbot merge -f "Merged internally and PR looks green to me"
The merge job was canceled or timed out. This most often happens if two merge requests were issued for the same PR, or if the merge job was waiting for more than 6 hours for tests to finish. In the latter case, please do not hesitate to reissue the merge command.
@clee2000 you are fast :P
Merge started. Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
These tests were broken by #142020. This PR updates the fixed configs accordingly.

Pull Request resolved: #143135
Approved by: https://github.com/jansel, https://github.com/huydhn
Fixes #134277 and #142317.

Sub-PRs containing refactors from this one:
- #141733
- #141738
- #141751 (based off the former)
- #142249
- #142020
- #143135

These refactor PRs should land before the main one.

# Feature

*Note: to minimize risk, multi-dimensional reductions are gated by the flag `config.triton.tile_reductions`, which defaults to False.*

Instead of having a single reduction dimension called `"r"`, we can now support 2D reductions with `"r0_"` and `"r1_"` dimensions. 2D reductions generate two nested loops, with different block pointer advancements in each loop body. Most of the implementation is generic to ND reductions, but for now the tiling algorithm sets a hard limit at 2D.

Here's an example of a 2D persistent reduction kernel:

```
@triton.jit
def triton_per_fused_sum_0(in_ptr0, out_ptr0, xnumel, r0_numel, r1_numel, XBLOCK : tl.constexpr):
    xnumel = 1
    r0_numel = 15
    R0_BLOCK: tl.constexpr = 16
    r1_numel = 15
    R1_BLOCK: tl.constexpr = 16
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None, None]
    xmask = tl.full([XBLOCK, R0_BLOCK, R1_BLOCK], True, tl.int1)
    r0_index = tl.arange(0, R0_BLOCK)[None, :, None]
    r0_offset = 0
    r0_mask = r0_index < r0_numel
    r1_index = tl.arange(0, R1_BLOCK)[None, None, :]
    r1_offset = 0
    r1_mask = r1_index < r1_numel
    rnumel = r0_numel * r1_numel
    RBLOCK: tl.constexpr = R0_BLOCK*R1_BLOCK
    roffset = r1_offset + (r0_offset*r1_numel)
    rindex = r1_index + (r0_index*r1_numel)
    r0_0 = r0_index
    r1_1 = r1_index
    tmp0 = tl.load(tl.make_block_ptr(in_ptr0, shape=[15, 15], strides=[30, 1], block_shape=[R0_BLOCK, R1_BLOCK], order=[1, 0], offsets=[r0_offset, r1_offset]), boundary_check=[0, 1], padding_option='zero')[None, :, :]
    tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK, R1_BLOCK])
    tmp3 = tl.where(r0_mask & r1_mask, tmp1, 0)
    tmp4 = tl.reshape(tmp3, [XBLOCK, RBLOCK])
    tmp5 = tl.sum(tmp4, 1)[:, None, None]
    tl.store(out_ptr0 + (tl.full([XBLOCK, 1, 1], 0, tl.int32)), tmp5, None)
''', device_str='cuda')
```

There are a few main differences between this kernel and what Inductor would generate without this PR:

- Instead of an `r`/`RBLOCK` dimension, we have two reduction dimensions: `r0_`/`R0_BLOCK` and `r1_`/`R1_BLOCK`.
- There are special size and indexing variables for reductions, which don't directly correspond to any kernel dimension (`rindex`, `rnumel`, `RBLOCK`, and `roffset`). These collapse N-D reduction sizes and indices into 1D. This simplifies the codegen for reductions, which sometimes want to access linear indices instead of N-dimensional ones. Doing things this way allows us to generate N-D loads and stores, but access this data as if it were 1D, minimizing the blast radius of this PR. Although this makes the code more verbose, it shouldn't have a perf impact, because the Triton compiler eliminates dead code.
- We generate the line `tmp4 = tl.reshape(tmp3, [XBLOCK, RBLOCK])` before performing the actual reduction. This reshapes N reduction dimensions into 1D, which allows us to reduce over all N dimensions at once, simplifying the codegen and letting the Triton compiler decide the order of processing under the hood. (A small NumPy check of this equivalence follows below.)
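As a quick sanity check on the reshape-then-reduce trick described in the last bullet, here is a small NumPy sketch (not from the PR; sizes are arbitrary) showing that collapsing the two reduction axes and reducing once matches reducing over both axes directly:

```
import numpy as np

XBLOCK, R0_BLOCK, R1_BLOCK = 4, 16, 16
vals = np.random.rand(XBLOCK, R0_BLOCK, R1_BLOCK).astype(np.float32)

# What the generated kernel does: flatten (r0_, r1_) into a single
# RBLOCK axis, then reduce once over that axis.
flat = vals.reshape(XBLOCK, R0_BLOCK * R1_BLOCK)
collapsed = flat.sum(axis=1)

# Equivalent direct reduction over both reduction axes.
direct = vals.sum(axis=(1, 2))

assert np.allclose(collapsed, direct)
```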
Here's an example of a looped reduction:

```
@triton.jit
def triton_red_fused_sum_0(in_ptr0, out_ptr0, xnumel, r0_numel, r1_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr, R1_BLOCK : tl.constexpr):
    xnumel = 3
    r0_numel = 43
    r1_numel = 129
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None, None]
    xmask = xindex < xnumel
    r0_base = tl.arange(0, R0_BLOCK)[None, :, None]
    r1_base = tl.arange(0, R1_BLOCK)[None, None, :]
    rnumel = r0_numel * r1_numel
    RBLOCK: tl.constexpr = R0_BLOCK*R1_BLOCK
    rbase = r1_base + (r0_base*r1_numel)
    x0 = xindex
    block_ptr0 = tl.make_block_ptr(in_ptr0, shape=[3, 43, 129], strides=[11094, 258, 1], block_shape=[XBLOCK, R0_BLOCK, R1_BLOCK], order=[2, 1, 0], offsets=[xoffset, 0, 0])
    _tmp2 = tl.full([XBLOCK, R0_BLOCK, R1_BLOCK], 0, tl.float32)
    for r0_offset in range(0, r0_numel, R0_BLOCK):
        r0_index = r0_offset + r0_base
        r0_mask = r0_index < r0_numel
        for r1_offset in range(0, r1_numel, R1_BLOCK):
            r1_index = r1_offset + r1_base
            r1_mask = r1_index < r1_numel
            roffset = r1_offset + (r0_offset*r1_numel)
            rindex = r1_index + (r0_index*r1_numel)
            r0_1 = r0_index
            r1_2 = r1_index
            tmp0 = tl.load(block_ptr0, boundary_check=[0, 1, 2], padding_option='zero', eviction_policy='evict_first')
            tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK, R1_BLOCK])
            tmp3 = _tmp2 + tmp1
            _tmp2 = tl.where(r0_mask & r1_mask & xmask, tmp3, _tmp2)
            block_ptr0 = tl.advance(block_ptr0, [0, 0, R1_BLOCK])
        block_ptr0 = tl.advance(block_ptr0, [0, R0_BLOCK, (-1)*R1_BLOCK*((128 + R1_BLOCK) // R1_BLOCK)])
    tmp4 = tl.reshape(_tmp2, [XBLOCK, RBLOCK])
    tmp2 = tl.sum(tmp4, 1)[:, None, None]
    tl.store(tl.make_block_ptr(out_ptr0, shape=[3], strides=[1], block_shape=[XBLOCK], order=[0], offsets=[xoffset]), tl.reshape(tmp2, [XBLOCK]).to(tl.float32), boundary_check=[0])
''', device_str='cuda')
```

In addition to the aforementioned changes to the persistent reduction, multidimensional looped reductions have a few more lines of code:

- They calculate indices inside the loop, using `r0_base` and `r1_base`. For compatibility with existing codegen, these are collapsed to the 1D variant `rbase`.
- Block pointer advancements are more nuanced for multidimensional loops. At the end of each loop body, we emit a `tl.advance` line which not only increments the pointer in its own dimension, but also undoes the cumulative increments of the previous loop level (see the plain-Python sketch below). This is equivalent to the usual practice in nested loops of starting with a fresh iteration variable at each level. Implementing this required refactoring the way we generate pointer advancements into a new `self.pointer_advancements` field of the kernel, which categorizes advancements by dimension.

The biggest difficulty in implementing this feature was that we represented tiling with a tuple like `(5, 2)`. In the existing codebase, the compiler can infer that the reduction dimension of `(5, 2)` is `2`, since reductions are always the last dimension. This became cumbersome now that we have to support multiple reduction dimensions, so I refactored tiling into a dict like `{"x": 5, "r0_": 2, "r1_": 4}`. This required quite a few code changes, but I don't think it makes the underlying logic much more complex. This will also make it easier to eventually support simultaneous pointwise and reduction tiling, like `{"x": 5, "y": 5, "r0_": 2, "r1_": 4}`. (This is not supported today, but we might want to do it eventually.)

The existing tiling algorithm generalized naturally to support reductions.
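To make the advance-and-rewind pattern above concrete, here is a plain-Python sketch (illustrative only; the flat `offset` stands in for the block pointer, and the sizes are made up) showing that advancing inside the inner loop and rewinding at the end of the outer loop body reproduces ordinary nested iteration:

```
R0_BLOCK, R1_BLOCK = 8, 32
r0_numel, r1_numel = 43, 129

offset = 0          # stands in for the block pointer's flat offset
visited = []

for r0_offset in range(0, r0_numel, R0_BLOCK):
    for r1_offset in range(0, r1_numel, R1_BLOCK):
        visited.append(offset)
        offset += R1_BLOCK  # inner advance, like tl.advance(ptr, [0, 0, R1_BLOCK])
    # Outer advance: step r0_ forward and undo everything the inner loop added,
    # mirroring the negative term in the kernel's second tl.advance.
    inner_steps = (r1_numel + R1_BLOCK - 1) // R1_BLOCK
    offset += R0_BLOCK * r1_numel - R1_BLOCK * inner_steps

# Equivalent to recomputing the offset from fresh loop variables each time.
expected = [
    r0_offset * r1_numel + r1_offset
    for r0_offset in range(0, r0_numel, R0_BLOCK)
    for r1_offset in range(0, r1_numel, R1_BLOCK)
]
assert visited == expected
```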
For pointwise kernels, we tile the pointwise dimensions (`"x"`, `"y"`) as is. For reduction kernels, we never tile the `"x"` dimension, and only tile the reduction dimensions (`"r0_"`, `"r1_"`). Thus we only ever tile pointwise OR reduction dimensions, but not both. In principle it seems possible to support both, but it would likely require changes to the kernel fusion and autotuning logic. I thought it best to keep this PR as minimal as possible, since it already touched a lot of different files.

Unfortunately, these changes weren't enough to get block pointers in some seemingly simple test cases. In some tests for `argmax` and `var_mean`, we already collapse reduction dimensions into 1D and generate modular indexing expressions, prior to tiling. So it's not trivial to figure out how to expand the collapsed reduction dimension back to a shape that would simplify the indexing. To address these cases, this PR adds a new feature to the `config.prefer_nd_tiling` option, which analyzes reads and writes in the kernel, using the same mod-div pattern matching logic that generates block pointers later on. By matching this pattern, we can solve for the tiling splits which *would* simplify the indexing expression, and then use that tiling to eliminate the modular indexing and emit a block pointer. This tiling mode is still off by default, but it's important for certain applications where we need to get as many block pointers as possible.

# Test plan

This touches pretty much anything that uses the Triton and Halide backends, so the existing CI provides good coverage. However, 2D reductions are gated behind a few feature flags like `config.prefer_nd_tiling` and `config.tile_reductions`, so this really only checks that the PR doesn't break 1D reductions.

In addition to existing CI tests, this PR also adds some new tests that specifically stress 2D reductions:

- `test_2d_reduction_odd_shapes`: test 2D reductions with a variety of ops and sizes. This covers the typical persistent and looped reductions.
- `test_2d_reduce_no_x_dim`: test 2D reductions with no x dimension.
- `test_2d_welford_reduction`: test 2D welford reductions with block pointers.
- `test_welford_non_block_pointer`: test a 2D welford reduction when block pointer analysis fails.
- `test_reduction_multiple_discontiguous_dims`: test reducing over more than one discontiguous dimension. We won't get a block pointer for this case, since that would require 3D tiling, but we're currently limited to 2D.
- `test_2d_reduction_multi_kernel`: test multi kernel autotuning on a 2D softmax kernel.
- `test_enable_tiled_reductions`: test that `config.triton.tile_reductions` enables/disables this feature.

Pull Request resolved: #137243
Approved by: https://github.com/jansel

Co-authored-by: Yueming Hao <[email protected]>
Co-authored-by: Jason Ansel <[email protected]>
Preparatory refactor for #137243.
# Feature

This PR changes the `RINDEX` / `"r"` symbol type to `(R0_INDEX, R1_INDEX)` and `("r0_", "r1_")`, respectively. This allows the relevant code to support 2D (often ND) reductions. Unlike the parent PR, this one does not change the tiling algorithm, so `"r1_"` is never used. However, it prepares other parts of the system to handle `"r1_"` once we start using it. This should significantly reduce the chances of hitting merge conflicts, making the parent PR much easier to land.

The only change to the generated Triton code is to rename `"rindex"` -> `"r0_index"`, `"RBLOCK"` -> `"R0_BLOCK"`, etc. To maintain compatibility with existing codegen, this also generates aliases to the old reduction variables, like `rindex = r0_index`. If we generated 2D reductions (which this PR will not do), the aliases would be more complicated and would collapse 2D multi-indices to linear indices. See some example kernels in the parent PR.

These aliases can be eliminated by the Triton compiler, and should not impact the final machine code running on the GPU. See the perf testing in the parent PR, which confirms the aliases do not impact perf.
# Test plan

The existing CI provides good coverage. This PR modifies the expected code in a few places, renaming reduction variables from `r.*` to `r0_.*`.

cc @jgong5 @mingfeima @XiaobingSuper @sanchitintel @ashokei @jingxu10 @voznesenskym @penguinwu @EikanWang @Guobing-Chen @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @ColinPeppler @amjames @desertfire @chauhang @aakhundov