Triton/Inductor Gather Prologue Fusion Issues #134535

@eellison

Description

🐛 Describe the bug

I have a mostly complete PR implementing prologue fusion at #134532 (though it is not yet ready for review).

One of the use cases @Chillee pointed out was fusing a gather into an MM. For example:

import torch

torch.set_default_device("cuda")
x = torch.rand([2048, 2048], dtype=torch.float16)
y = torch.rand([2048, 2048], dtype=torch.float16)
index = torch.randperm(2048, device="cuda")

def foo(x, y, index):
    return x[index] @ y
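
(A sketch of how this gets exercised; max-autotune-no-cudagraphs is the baseline referenced below, though hitting the gather prologue fusion path requires the PR above.)

compiled_foo = torch.compile(foo, mode="max-autotune-no-cudagraphs")
out = compiled_foo(x, y, index)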

On an A100, when this is run in fp32, everything works successfully. However, when run in fp16, I get the following error:

python: /home/eellison/triton/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp:852: virtual mlir::LogicalResult {anonymous}::AsyncCopyGlobalToLocalOpConversion::matchAndRewrite(mlir::triton::gpu::AsyncCopyGlobalToLocalOp, mlir::ConvertOpToLLVMPattern<mlir::triton::gpu::AsyncCopyGlobalToLocalOp>::OpAdaptor, mlir::ConversionPatternRewriter&) const: Assertion `byteWidth == 16 || byteWidth == 8 || byteWidth == 4' failed.

The failure occurs in the generated kernel.

The main loop of the mm kernel is:

    for k_idx in range(0, tl.cdiv(K, BLOCK_K)):

        a_k_idx_vals = offs_k[None, :] + (k_idx * BLOCK_K)
        idx_m = offs_a_m[:, None]
        idx_n = a_k_idx_vals
        xindex = idx_n + (2048*idx_m)
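        # the gather-index load below depends only on idx_m, so it is loop invariant across k_idx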
        tmp0 = tl.load(in_ptr1 + (tl.broadcast_to(idx_m, xindex.shape)), None, eviction_policy='evict_last')
        tmp1 = tl.full(xindex.shape, 2048, tl.int32)
        tmp2 = tmp0 + tmp1
        tmp3 = tmp0 < 0
        tmp4 = tl.where(tmp3, tmp2, tmp0)
        tl.device_assert((0 <= tl.broadcast_to(tmp4, xindex.shape)) & (tl.broadcast_to(tmp4, xindex.shape) < 2048), "index out of bounds: 0 <= tl.broadcast_to(tmp4, xindex.shape) < 2048")
        tmp6 = tl.load(in_ptr2 + (tl.broadcast_to(idx_n + (2048*tmp4), xindex.shape)), None, eviction_policy='evict_last')
        a = tmp6.to(tl.float16)

        B_ptr = B + ((offs_k[:, None] + (k_idx * BLOCK_K)) * stride_bk + offs_b_n[None, :] * stride_bn)

        b = tl.load(B_ptr)

        acc += tl.dot(a, b, allow_tf32=ALLOW_TF32)

In this case, the tmp0 load is loop invariant. When you update the kernel to hoist the load, it runs successfully. It also gives a 14% speedup over max-autotune-no-cudagraphs without prologue fusion and a 25% speedup over eager.
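
For reference, a minimal hand-edited sketch of the hoisted variant (the bounds device_assert is omitted for brevity, and the tmp6 load has to stay inside the loop since it depends on k_idx):

    # hoisted once before the K loop: the gather index depends only on idx_m
    idx_m = offs_a_m[:, None]
    tmp0 = tl.load(in_ptr1 + idx_m, None, eviction_policy='evict_last')
    tmp1 = tl.full(tmp0.shape, 2048, tl.int32)
    tmp4 = tl.where(tmp0 < 0, tmp0 + tmp1, tmp0)

    for k_idx in range(0, tl.cdiv(K, BLOCK_K)):
        idx_n = offs_k[None, :] + (k_idx * BLOCK_K)
        # the gathered tile of A is still loaded per K tile
        tmp6 = tl.load(in_ptr2 + (idx_n + (2048 * tmp4)), None, eviction_policy='evict_last')
        a = tmp6.to(tl.float16)

        b = tl.load(B + ((offs_k[:, None] + (k_idx * BLOCK_K)) * stride_bk + offs_b_n[None, :] * stride_bn))

        acc += tl.dot(a, b, allow_tf32=ALLOW_TF32)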

So there are two issues here:

  1. The codegen failure.

  2. Can Triton improve its loop-invariant code motion to hoist this load? Note: we benchmark this prologue fusion, so in the worst case where we can't prove a hoist, we would simply not fuse instead of generating a terrible mm.

While it is doable for Inductor to generate the hoisted code, this code is generated through Triton templates. Ideally I would like to keep the current single per-buffer prologue fusion API of {{load_input("A", "a", ("idx_m", "idx_n"))}} instead of something more tedious that involves separate APIs for initializing the indices.
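
For illustration only, this is roughly where such a per-buffer hook sits inside the templated K loop (a sketch modeled on the generated kernel above, not the actual template source):

    for k_idx in range(0, tl.cdiv(K, BLOCK_K)):
        # the hook expands into the fused gather prologue shown earlier, defining `a`
        {{load_input("A", "a", ("idx_m", "idx_n"))}}
        b = tl.load(B_ptr)
        acc += tl.dot(a, b, allow_tf32=ALLOW_TF32)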

Not sure whether to file this here or in Triton; filing here for now.

cc @ezyang @chauhang @penguinwu @voznesenskym @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @ColinPeppler @amjames @desertfire @bertmaher @int3 @davidberard98 @nmacchioni @embg @peterbell10 @aakhundov @htyu

Versions

master


Labels

module: inductor, oncall: pt2, triaged, upstream triton
