
Conversation

@blaine-rister (Contributor) commented Aug 7, 2024

Fixes #125077

Feature

This PR creates a new Inductor config, config.triton.prefer_nd_tiling, which is disabled by default. When enabled, it encourages the generated Triton code to use as many tiling dimensions as possible. This simplifies indexing expressions for discontiguous tensors, producing expressions like 5 * x + 8 * y instead of 5 * (x // 7) + 8 * (y % 9), which in turn lets Inductor find more block pointers than it normally would. We should now see simplified indexing expressions as long as:

  1. All discontiguous reads/writes have the same shape.
  2. The number of discontiguous dimensions is less than config.triton.max_tiles.
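
For concreteness, here's a minimal sketch of the kind of program that produces the kernels below. The exact shapes and the compiled function are my own illustration, chosen to match the generated kernels (a (3, 7) logical shape with a row stride of 9), not code taken from the PR:

import torch
import torch._inductor.config as inductor_config

# Enable the new config added by this PR (off by default).
inductor_config.triton.prefer_nd_tiling = True

# Discontiguous views: shape (3, 7), strides (9, 1), because each is a
# column slice of a contiguous (3, 9) base tensor.
a = torch.randn(3, 9, device="cuda")[:, :7]
b = torch.randn(3, 9, device="cuda")[:, :7]

# Elementwise add of views; Inductor generates kernels like the ones below.
compiled_add = torch.compile(lambda x, y: x + y)
out = compiled_add(a, b)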

Here's an example kernel (elementwise add of views) with ND tiling disabled:

@triton.jit
def triton_(in_ptr0, in_ptr1, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 21
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x0 = xindex % 7
    x1 = (xindex // 7)
    x2 = xindex
    tmp0 = tl.load(in_ptr0 + (x0 + (9*x1)), xmask)
    tmp1 = tl.load(in_ptr1 + (x0 + (9*x1)), xmask)
    tmp2 = tmp0 + tmp1
    tl.store(tl.make_block_ptr(out_ptr0, shape=[21], strides=[1], block_shape=[XBLOCK], order=[0], offsets=[xoffset]), tl.broadcast_to(tmp2, [XBLOCK]).to(tl.float32), boundary_check=[0])

And here's the version with it enabled:

@triton.jit
def triton_(in_ptr0, in_ptr1, out_ptr0, ynumel, xnumel, YBLOCK : tl.constexpr, XBLOCK : tl.constexpr):
    ynumel = 3
    xnumel = 7
    yoffset = tl.program_id(1) * YBLOCK
    yindex = yoffset + tl.arange(0, YBLOCK)[None, :]
    ymask = yindex < ynumel
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    x1 = xindex
    y0 = yindex
    tmp0 = tl.load(tl.make_block_ptr(in_ptr0, shape=[7, 3], strides=[1, 9], block_shape=[XBLOCK, YBLOCK], order=[1, 0], offsets=[xoffset, yoffset]), boundary_check=[0, 1], eviction_policy='evict_last')
    tmp1 = tl.load(tl.make_block_ptr(in_ptr1, shape=[7, 3], strides=[1, 9], block_shape=[XBLOCK, YBLOCK], order=[1, 0], offsets=[xoffset, yoffset]), boundary_check=[0, 1], eviction_policy='evict_last')
    tmp2 = tmp0 + tmp1
    tl.store(tl.make_block_ptr(out_ptr0, shape=[7, 3], strides=[1, 7], block_shape=[XBLOCK, YBLOCK], order=[1, 0], offsets=[xoffset, yoffset]), tl.broadcast_to(tmp2, [XBLOCK, YBLOCK]).to(tl.float32), boundary_check=[0, 1])

With this feature enabled, we get a discontiguous strided block pointer. Previously, this would only have worked for specific shapes, like powers of 2 or multiples of the maximum block size. With this PR, we can support arbitrary shapes so long as we have enough tiles to cover all discontiguous dimensions.
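
As a quick sanity check (plain Python, my own illustration), the flattened 1D indexing from the first kernel and the 2D tiled indexing from the second touch exactly the same set of element offsets; the 2D form just avoids the division and modulo:

# 1D kernel: offset = x0 + 9*x1, where x0 = i % 7 and x1 = i // 7
offsets_1d = {(i % 7) + 9 * (i // 7) for i in range(21)}
# 2D kernel: block pointer with shape [7, 3] and strides [1, 9]
offsets_2d = {1 * x + 9 * y for x in range(7) for y in range(3)}
assert offsets_1d == offsets_2d  # same 21 addresses, simpler index math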

Test plan

This PR adds some tests for pointwise ops with discontiguous tensors.

  • Test that we can generate block pointers for views with odd shapes like (5,7), (9,3,5), etc.
  • Test that we can generate block pointers for a single discontiguous dim in 3D and 4D tensors.
  • Test that we generate a 2D tiling for a 5D tensor with two discontiguous dims. This case doesn't generate a block pointer, but it checks that the output code is at least correct.

This PR also parametrizes some existing tests to run with and without triton.prefer_nd_tiling. That way, we ensure this feature doesn't break existing usage.
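
A rough sketch of that parametrization pattern; the helper name and shapes here are hypothetical, not the PR's actual tests:

import torch
from torch._inductor import config

def check_pointwise_view_add(prefer_nd_tiling: bool):
    # Hypothetical test body: compare compiled output against eager for a
    # discontiguous view, under the given tiling mode.
    with config.patch({"triton.prefer_nd_tiling": prefer_nd_tiling}):
        base = torch.randn(9, 3, 5, device="cuda")
        a = base[:, :, :4]  # odd, discontiguous shape
        b = base.clone()[:, :, :4]
        expected = a + b
        actual = torch.compile(lambda x, y: x + y)(a, b)
        torch.testing.assert_close(actual, expected)

# Run the same test body with and without ND tiling.
for flag in (False, True):
    check_pointwise_view_add(flag)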

Since this setting isn't enabled in most tests, I also created #132935 to test what happens when triton.prefer_nd_tiling=True by default. None of the failures there seem related to invalid tiling, so I think this feature is safe to merge.

Limitations and follow-ups

I can see two main improvements that would expand the usefulness of this feature:

  1. This feature currently only works for pointwise kernels, since reductions are never tiled. As a follow-up, we could enable tiled reductions to extend these benefits to reduction kernels.

  2. The usefulness of this feature depends on config.triton.max_tiles. This is currently 2 by default, although it can be increased to 3 in certain cases. To support more discontiguous dims, we might consider expanding support for 3D tiling, or even supporting ND tiling by mapping an ND "virtual" launch grid onto Triton's 3D launch grid (sketched below).
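
For illustration, here's a minimal sketch of that virtual-grid idea (all names are assumed): dimensions past the second are flattened into Triton's third hardware grid axis, then recovered with div/mod inside the kernel:

def virtual_grid(nd_sizes):
    # Fold an N-D launch grid into Triton's 3 hardware axes by flattening
    # every dimension past the second into axis 2.
    g0, g1, *rest = nd_sizes
    g2 = 1
    for s in rest:
        g2 *= s
    return (g0, g1, g2)

def unflatten_axis2(pid2, rest_sizes):
    # Recover the original per-dimension program ids from the flat axis-2
    # program id (the same div/mod chain the kernel would emit).
    ids = []
    for s in reversed(rest_sizes):
        ids.append(pid2 % s)
        pid2 //= s
    return ids[::-1]

# A 5D grid (4, 8, 2, 3, 5) launches as (4, 8, 30); inside the kernel,
# pid2 = 17 maps back to (1, 0, 2), since 17 = 1*(3*5) + 0*5 + 2.
assert virtual_grid((4, 8, 2, 3, 5)) == (4, 8, 30)
assert unflatten_axis2(17, [2, 3, 5]) == [1, 0, 2]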

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @ColinPeppler @amjames @desertfire @chauhang

@pytorch-bot (bot) commented Aug 7, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/132937

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 8543226 with merge base b01402b:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@blaine-rister changed the title from "[Inductor] Create option to force higher-dimensional tiling" to "[Inductor] Add config option to force higher-dimensional tiling" on Aug 7, 2024
@blaine-rister marked this pull request as ready for review August 8, 2024 01:53
@blaine-rister (Contributor, Author) commented:

@pytorchbot merge

@pytorch-bot added the ciflow/trunk label (Trigger trunk jobs on your pull request) Aug 8, 2024
@pytorchmergebot (Collaborator) commented:

Merge failed

Reason: This PR needs a release notes: label
If your changes are user facing and intended to be a part of release notes, please use a label starting with release notes:.

If not, please add the topic: not user facing label.

To add a label, you can comment to pytorchbot, for example
@pytorchbot label "topic: not user facing"

For more information, see
https://github.com/pytorch/pytorch/wiki/PyTorch-AutoLabel-Bot#why-categorize-for-release-notes-and-how-does-it-work.


@blaine-rister (Contributor, Author) commented:

@pytorchbot merge

@pytorchmergebot (Collaborator) commented:

Merge failed

Reason: This PR needs a release notes: label (same as the previous attempt; see above).

@blaine-rister (Contributor, Author) commented:

@pytorchbot merge

@jfix71 (Contributor) commented Aug 8, 2024

@blaine-rister NBD, but wondering: numel/index/mask are no longer used; might be nice to rm them? Not to block landing the PR, more of a question.

@pytorchmergebot (Collaborator) commented:

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team


@blaine-rister (Contributor, Author) commented Aug 8, 2024

@blaine-rister NBD, but wondering: numel/index/mask are no longer used; might be nice to rm them? Not to block landing the PR, more of a question.

I've wondered about this as well. The way the Triton codegen is currently structured makes it a bit awkward to remove dead code. We generate these numels at the beginning, without knowing whether they are actually going to be used. The numels are treated as an artifact of codegen and are not directly visible as ops in the IR, so we can't use standard dead code elimination to remove them.
https://github.com/pytorch/pytorch/blob/main/torch/_inductor/codegen/triton.py#L2701

The Triton compiler eliminates dead code under the hood, so there shouldn't be a perf impact from leaving some dead code in the kernel. It does make the kernels a bit harder to read and debug, though. We might be able to post-process the kernels and determine whether numels were actually used, or update some data structure tracking their uses as we're generating code. But there's a question of whether the debugging benefit justifies the additional complexity.
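
To make that concrete, here's a rough, hypothetical sketch of such a post-processing check. Note that a real pass would have to iterate to a fixed point: in the 2D kernel above, xmask is the only reader of xnumel, but xmask is itself dead:

import re

def dead_numels(kernel_src: str) -> list[str]:
    # A numel is (first-order) dead if its only occurrences are the kernel
    # signature and the constant reassignment, i.e. no line ever reads it.
    dead = []
    for name in ("xnumel", "ynumel", "znumel", "rnumel"):
        if not re.search(rf"\b{name}\b", kernel_src):
            continue  # not present in this kernel at all
        reads = 0
        for line in kernel_src.splitlines():
            stripped = line.strip()
            if stripped.startswith(f"{name} =") or stripped.startswith("def "):
                continue  # declaration/reassignment, not a read
            reads += len(re.findall(rf"\b{name}\b", line))
        if reads == 0:
            dead.append(name)
    return dead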

cc @eellison @shunting314 do you have any thoughts on this?

@eellison (Contributor) commented Aug 8, 2024

@blaine no strong opinion, but yeah, not especially harmful right now. Maybe file an issue? Could be a ramp-up task.

@blaine-rister (Contributor, Author) commented Aug 9, 2024

@blaine no strong opinion, but yeah, not especially harmful right now. Maybe file an issue? Could be a ramp-up task.

@nandesuka expressed an interest in this. I'll create an issue for it.
