### 🚀 The feature, motivation and pitch
There is an ongoing effort to expand Inductor’s block pointer support to cover discontiguous tensors and views (umbrella issue: #125077). Tiling is an important part of this, as it simplifies indexing expressions, creating more opportunities to use block pointers. See #132937 for an example.
Currently, Inductor supports tiling for pointwise kernels, but not for reductions (see https://github.com/pytorch/pytorch/blob/main/torch/_inductor/codegen/simd.py#L1728). Because of this, we miss some easy block pointers.
For example, consider this test program:

```python
import torch
import torch._inductor.config as config

config.triton.use_block_ptr = True
config.triton.prefer_nd_tiling = True

device = torch.device("cuda")
full = torch.randn((16, 32)).to(device)
view = torch.as_strided(full, (16, 17), full.stride())
compiled = torch.compile(torch.sum)
result = compiled(view)
```
which generates the following Triton kernel:

```python
@triton.jit
def triton_(in_ptr0, out_ptr0, xnumel, rnumel):
    xnumel = 1
    XBLOCK: tl.constexpr = 1
    rnumel = 272
    RBLOCK: tl.constexpr = 512
    xoffset = tl.program_id(0) * XBLOCK
    xindex = tl.full([1], xoffset, tl.int32)
    xmask = tl.full([RBLOCK], True, tl.int1)
    rindex = tl.arange(0, RBLOCK)[:]
    roffset = 0
    rmask = rindex < rnumel
    r0 = rindex % 17
    r1 = (rindex // 17)
    tmp0 = tl.load(in_ptr0 + (r0 + (32*r1)), rmask, other=0.0)
    tmp1 = tl.broadcast_to(tmp0, [RBLOCK])
    tmp3 = tl.where(rmask, tmp1, 0)
    tmp4 = triton_helpers.promote_to_tensor(tl.sum(tmp3, 0))
    tl.store(out_ptr0 + (tl.full([1], 0, tl.int32)), tmp4, None)
```
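As an aside, generated code like the above can be captured by enabling Inductor's output-code logging before compiling, either via the environment variable `TORCH_LOGS="output_code"` or programmatically:

```python
import torch._logging

# Log Inductor's generated output code, including Triton kernels.
torch._logging.set_logs(output_code=True)
```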
We miss the opportunity for a block pointer because of the 1D modular indexing expression. (We can convert this pattern to a block pointer in certain cases, but not for the shape (16,17).) If we supported 2D tiled reductions, we would get a pair of simpler 2D indices, which would map to a block pointer.
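To illustrate, here is a rough, hand-written sketch of the kind of kernel that 2D reduction tiling could enable. This is not Inductor output; the kernel name, block sizes, and single-program launch are assumptions. The point is that the strided `(16, 17)` view becomes a single `tl.make_block_ptr`, with `boundary_check` taking over the role of `rmask`:

```python
import triton
import triton.language as tl

@triton.jit
def triton_red_2d_sketch(in_ptr0, out_ptr0):
    # Hypothetical 2D reduction tiles (illustrative names, not Inductor's).
    R0_BLOCK: tl.constexpr = 16
    R1_BLOCK: tl.constexpr = 32
    # The (16, 17) view of the (16, 32) tensor maps to a plain block pointer:
    # logical shape (16, 17), strides (32, 1) from the underlying storage.
    block_ptr = tl.make_block_ptr(
        base=in_ptr0,
        shape=[16, 17],
        strides=[32, 1],
        offsets=[0, 0],
        block_shape=[R0_BLOCK, R1_BLOCK],
        order=[1, 0],
    )
    # boundary_check masks the out-of-bounds columns (17 < R1_BLOCK == 32).
    tmp0 = tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero")
    # Reduce over both tile dimensions.
    tmp1 = tl.sum(tl.sum(tmp0, 1), 0)
    tl.store(out_ptr0, tmp1)
```

With 2D tiles, the `r0 = rindex % 17` / `r1 = rindex // 17` pair disappears, and the indexing becomes exactly the strided-view description that block pointers encode.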
Besides expanding block pointer support, tiled reductions might have performance benefits in certain scenarios. I’m not sure exactly when it’s profitable, but since we already tile certain pointwise kernels, it seems like we should see similar benefits for reductions.
Would it be feasible to add this feature? In principle it doesn’t seem much different from the pointwise case, but I’m guessing there are a lot of places in our code which assume reductions are (X,R). So we probably just need to trudge through them one at a time.
### Alternatives
Right now, I can't think of any other way to use a block pointer in this example program.
### Additional context
No response
cc @ezyang @chauhang @penguinwu @voznesenskym @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @ColinPeppler @amjames @desertfire