
[Inductor] Support tiled reductions #134277

@blaine-rister


🚀 The feature, motivation and pitch

There is an ongoing effort to expand Inductor’s block pointer support to cover discontiguous tensors and views (umbrella issue: #125077). Tiling is an important part of this, as it simplifies indexing expressions, creating more opportunities to use block pointers. See #132937 for an example.

Currently, Inductor supports tiling for pointwise kernels, but not for reductions (see https://github.com/pytorch/pytorch/blob/main/torch/_inductor/codegen/simd.py#L1728). This causes us to miss some easy block pointers.

For example, consider the test program:

import torch
import torch._inductor.config as config

# Ask Inductor to emit block pointers and to prefer N-D tiling.
config.triton.use_block_ptr = True
config.triton.prefer_nd_tiling = True

device = torch.device("cuda")
full = torch.randn((16, 32)).to(device)
# Discontiguous view: keep only 17 of the 32 columns, so rows are not contiguous in memory.
view = torch.as_strided(full, (16, 17), full.stride())

compiled = torch.compile(torch.sum)
result = compiled(view)

which generates the following Triton kernel:

@triton.jit
def triton_(in_ptr0, out_ptr0, xnumel, rnumel):
    xnumel = 1
    XBLOCK: tl.constexpr = 1
    rnumel = 272
    RBLOCK: tl.constexpr = 512
    xoffset = tl.program_id(0) * XBLOCK
    xindex = tl.full([1], xoffset, tl.int32)
    xmask = tl.full([RBLOCK], True, tl.int1)
    rindex = tl.arange(0, RBLOCK)[:]
    roffset = 0
    rmask = rindex < rnumel
    r0 = rindex % 17
    r1 = (rindex // 17)
    tmp0 = tl.load(in_ptr0 + (r0 + (32*r1)), rmask, other=0.0)
    tmp1 = tl.broadcast_to(tmp0, [RBLOCK])
    tmp3 = tl.where(rmask, tmp1, 0)
    tmp4 = triton_helpers.promote_to_tensor(tl.sum(tmp3, 0))
    tl.store(out_ptr0 + (tl.full([1], 0, tl.int32)), tmp4, None)
''', device_str='cuda')
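
(As an aside, the kernel above can be reproduced by dumping Inductor's generated code, e.g. with TORCH_LOGS="output_code" or programmatically via torch._inductor.utils.run_and_get_code; the snippet below is a sketch assuming that helper, and the exact kernel text will vary across PyTorch versions.)

from torch._inductor.utils import run_and_get_code

# Capture the generated source for the compiled sum over the strided view.
result, code = run_and_get_code(torch.compile(torch.sum), view)
print(code[0])  # contains the Triton kernel shown above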

We miss the opportunity for a block pointer because of the 1D modular indexing expression. (We can convert this pattern to a block pointer in certain cases, but not for the shape (16, 17).) If we supported 2D tiled reductions, we would instead get two simpler, independent reduction indices, which would map directly to a block pointer.
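
To make that concrete, here is a rough hand-written sketch, not actual Inductor output, of what a 2D-tiled version of this reduction could look like; the kernel name, the R0_BLOCK/R1_BLOCK constants, and the hard-coded shape/strides are purely illustrative:

import triton
import triton.language as tl

@triton.jit
def triton_tiled_(in_ptr0, out_ptr0, xnumel, r0_numel, r1_numel):
    # Two independent reduction dimensions instead of one flat rindex,
    # so there is no modular indexing standing in the way of a block pointer.
    R0_BLOCK: tl.constexpr = 16  # rows of the view
    R1_BLOCK: tl.constexpr = 32  # next power of two >= 17 columns
    # The (16, 17) view with strides (32, 1) maps directly onto a block pointer.
    block_ptr = tl.make_block_ptr(
        in_ptr0,
        shape=[16, 17],
        strides=[32, 1],
        offsets=[0, 0],
        block_shape=[R0_BLOCK, R1_BLOCK],
        order=[1, 0],
    )
    tmp0 = tl.load(block_ptr, boundary_check=[1], padding_option="zero")
    tmp1 = tl.sum(tmp0, 1)  # reduce over the column dimension
    tmp2 = tl.sum(tmp1, 0)  # reduce over the row dimension
    tl.store(out_ptr0, tmp2)

With the two reduction dimensions kept separate, the load becomes a single block-pointer access with a boundary check on the second dimension, instead of a masked gather through r0 + (32*r1).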

Besides expanding block pointer support, tiled reductions might have performance benefits in certain scenarios. I’m not sure exactly when it’s profitable, but since we already tile certain pointwise kernels, it seems like we should see similar benefits for reductions.

Would it be feasible to add this feature? In principle it doesn't seem much different from the pointwise case, but I'm guessing there are a lot of places in our code that assume reductions are tiled as (X, R). So we probably just need to trudge through them one at a time.

Alternatives

Right now, I can't think of any other way to use a block pointer in this example program.

Additional context

No response

cc @ezyang @chauhang @penguinwu @voznesenskym @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @ColinPeppler @amjames @desertfire
