
Conversation

Contributor

blaine-rister commented May 28, 2024

Summary

Inductor currently uses modulo and division to compute indices into certain multi-dimensional tensors, such as those arising from row padding. This PR matches on that indexing pattern, replacing it with an N-D block pointer. This should be more efficient than computing indices with division and modulo, and it can easily map to DMAs on non-GPU hardware targets.

Because the 1D block size needs to map to an integer block shape in ND, we need to know that the ND block size evenly divides the size of the iteration range. This PR only generates ND block pointers when it can guarantee that the iteration order and the number of elements loaded are unchanged. This means that the number of elements in a slice of the iteration range must be either:

  • A power of 2. Since Triton block sizes are powers of 2, any integer power of 2 either divides the block size or is greater than it. In the latter case, CeilDiv(x, y) rounds up to 1.
  • A multiple of the maximum block size. Since block sizes are powers of 2, the maximum block size is a multiple of every possible block size.

Note that a slice of the iteration range does not include the leading dimension, so we can support an arbitrary leading dimension, e.g. a shape like (5, 8). A minimal sketch of this check is given below.
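The sketch below is illustrative only, not Inductor's actual code; the helper names and the MAX_XBLOCK constant are assumptions for this example (the real limit comes from Inductor's block size config).

```
import math

def is_power_of_2(n: int) -> bool:
    return n > 0 and (n & (n - 1)) == 0

# Assumed maximum 1D block size for this sketch; the real value comes from
# Inductor's config.
MAX_XBLOCK = 1024

def can_use_nd_block_ptr(dims: list[int]) -> bool:
    # Check every slice of the iteration range, i.e. the number of elements in
    # each run of trailing dims. The leading dim is excluded, so it may be
    # arbitrary.
    for i in range(1, len(dims)):
        slice_numel = math.prod(dims[i:])
        if not (is_power_of_2(slice_numel) or slice_numel % MAX_XBLOCK == 0):
            return False
    return True

print(can_use_nd_block_ptr([5, 8]))       # True: arbitrary leading dim, power-of-2 slice
print(can_use_nd_block_ptr([32, 16, 8]))  # True: slices of 128 and 8 elements
print(can_use_nd_block_ptr([4, 6]))       # False: falls back to mod/div indexing
```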

Feature proposal and discussion: #125077
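For contrast with the example kernel below, the modulo/division indexing that this PR matches on looks roughly like the following. This is an illustrative reconstruction of the pattern, not code copied from Inductor's output:

```
@triton.jit
def triton_(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 4096
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x0 = xindex % 8            # innermost dim of the logical [32, 16, 8] view
    x1 = (xindex // 8) % 16    # middle dim
    x2 = xindex // 128         # outermost dim
    tmp0 = tl.load(in_ptr0 + (x0 + 32 * x1 + 1024 * x2), xmask)
    tmp1 = tmp0 + tmp0
    tl.store(out_ptr0 + xindex, tmp1, xmask)
```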

Example kernel:

```
@triton.jit
def triton_(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 4096
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    tmp0 = tl.reshape(tl.load(tl.make_block_ptr(in_ptr0, shape=[32, 16, 8], strides=[1024, 32, 1], block_shape=[32 * (32 <= ((127 + XBLOCK) // 128)) + ((127 + XBLOCK) // 128) * (((127 + XBLOCK) // 128) < 32), 16 * (16 <= ((7 + XBLOCK) // 8)) + ((7 + XBLOCK) // 8) * (((7 + XBLOCK) // 8) < 16), 8 * (8 <= XBLOCK) + XBLOCK * (XBLOCK < 8)], order=[0, 1, 2], offsets=[(xoffset // 128), (xoffset // 8) % 16, xoffset % 8]), boundary_check=[0, 1, 2]), [XBLOCK])
    tmp1 = tmp0 + tmp0
    tl.store(tl.make_block_ptr(out_ptr0, shape=[4096], strides=[1], block_shape=[XBLOCK], order=[0], offsets=[xoffset]), tl.broadcast_to(tmp1, [XBLOCK]).to(tl.float32))
''', device_str='cuda')
```
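The block_shape expressions above look dense, but each term is just min(dim, CeilDiv(XBLOCK, inner_numel)) spelled out with comparisons, where inner_numel is the number of elements in the dims to its right. A quick sanity check of that reading in plain Python (the helper names here are for illustration only):

```
def ceil_div(x, y):
    return (x + y - 1) // y

def generated_min(a, b):
    # Matches the pattern "a * (a <= b) + b * (b < a)" in the kernel above.
    return a * (a <= b) + b * (b < a)

for xblock in (8, 128, 1024):
    dim0 = generated_min(32, ceil_div(xblock, 128))  # 32 * (32 <= ((127 + XBLOCK) // 128)) + ...
    dim1 = generated_min(16, ceil_div(xblock, 8))    # 16 * (16 <= ((7 + XBLOCK) // 8)) + ...
    dim2 = generated_min(8, xblock)                  # 8 * (8 <= XBLOCK) + XBLOCK * (XBLOCK < 8)
    assert dim0 * dim1 * dim2 == xblock
    print(xblock, (dim0, dim1, dim2))
```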

Test Plan

This PR adds a new CI test script to cover this feature. The tests can be grouped into a few main categories (an illustrative sketch follows the list):

  • Can we generate strided block pointers for the appropriate shapes?
    • Powers of 2
    • Non-power of 2, but multiple of the maximum block size
    • Arbitrary leading dimensions, with power of 2 inner dimensions
    • Weird strides and offsets
    • Reductions
    • Symbolic shapes that are multiples of the maximum block size (wasn't able to trace this through dynamo)
    • Broadcasts (some variables are missing from the indexing expression)
  • Do we still compile other cases correctly, even if we don't expect to be able to generate block pointers?
    • Unsupported static shapes
    • Unsupported symbolic shapes
  • Mixing and matching these cases:
    • Pointwise and reduction in the same kernel
  • Sanity check the test harness
    • Do we raise an exception if the expected number of block pointers and the actual number are different?
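For illustration, a standalone check in the spirit of these tests might look like the sketch below. It is not the actual test code from this PR: the function, tensor shapes, and the simple presence check are assumptions for this example, whereas the real harness compares exact block-pointer counts as described above.

```
import torch
from torch._inductor import config
from torch._inductor.utils import run_and_get_code

def f(x, y):
    return x + y

# A row-padded view: shape (16, 8) with strides (32, 1), so the load index
# would otherwise be computed with division and modulo. Assumes a CUDA device.
full = torch.randn(16, 32, device="cuda")
x = full[:, :8]
y = torch.randn(16, 8, device="cuda")

with config.patch("triton.use_block_ptr", True):
    result, source_codes = run_and_get_code(torch.compile(f), x, y)

code = "\n".join(source_codes)
# The PR's harness asserts an exact expected number of block pointers and
# raises if the count differs; this sketch only checks that at least one
# tl.make_block_ptr was emitted.
assert "tl.make_block_ptr" in code
torch.testing.assert_close(result, x + y)
```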

Follow-ups

There are a few important cases which this PR can't handle. I'm hoping these can be deferred to follow-up PRs:

  • Handle non-divisible shapes
    • Change the tiling algorithm to generate a 2D (X,Y) blocking, if doing so enables block pointers to be emitted.
    • Pad unsupported loads up to the nearest divisible size, then mask/slice out the extra elements? This is probably the best solution, but I'm not yet sure how to go about it in triton.
  • Take advantage of this analysis when triton.use_block_ptr=False. I'm guessing we can still avoid % and / without requiring block pointers. Maybe we could compute block indices with arange and broadcast instead?
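As a quick demonstration of that last idea in plain PyTorch (not Triton), broadcasting per-dimension aranges reproduces the same flat offsets that the modulo/division indexing computes, using the [32, 16, 8] layout with strides [1024, 32, 1] from the example kernel above:

```
import torch

xindex = torch.arange(4096)
offsets_mod_div = (xindex // 128) * 1024 + ((xindex // 8) % 16) * 32 + (xindex % 8)

z = torch.arange(32)[:, None, None]   # outermost dim
y = torch.arange(16)[None, :, None]   # middle dim
x = torch.arange(8)[None, None, :]    # innermost dim
offsets_broadcast = (z * 1024 + y * 32 + x).reshape(-1)

assert torch.equal(offsets_mod_div, offsets_broadcast)
```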

Differential Revision: D56739375

cc @jgong5 @mingfeima @XiaobingSuper @sanchitintel @ashokei @jingxu10 @voznesenskym @penguinwu @EikanWang @Guobing-Chen @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @peterbell10 @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @ColinPeppler @amjames @desertfire @chauhang


pytorch-bot bot commented May 28, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/127342

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (4 Unrelated Failures)

As of commit 2b5a2f1 with merge base 732b4e9:

BROKEN TRUNK - The following jobs failed but were present on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot
Contributor

This pull request was exported from Phabricator. Differential Revision: D56739375

facebook-github-bot pushed a commit that referenced this pull request May 28, 2024
…iv (#127342)



@facebook-github-bot
Contributor

@blaine-rister has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.


@blaine-rister
Contributor Author

@jansel @shunting314 I think I've addressed your comments from last time. Could you please take another look at this?

pytorch-bot bot added the `ciflow/trunk` label (Trigger trunk jobs on your pull request) on Jun 14, 2024
@blaine-rister
Contributor Author

@pytorchbot merge

@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.


bghira commented Nov 10, 2024

It seems that in 2.4.1 this causes a problem because `is_power_of_2` is not available there.

