🐛 Describe the bug
I have a mostly ready PR implementing prologue fusion at #134532 (although not yet ready for review).
One of the use cases @Chillee pointed out was fusing a gather into an MM. For example:

```python
torch.set_default_device("cuda")
x = torch.rand([2048, 2048], dtype=torch.float16)
y = torch.rand([2048, 2048], dtype=torch.float16)
index = torch.randperm(2048, device="cuda")

def foo(x, y, index):
    return x[index] @ y
```
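For completeness, here is a sketch of how one might drive this repro through Inductor's Triton mm template path; the compile mode and correctness check are my own assumptions for illustration, not taken from the original report:

```python
# Sketch (assumed driver): compile foo with max-autotune so the matmul goes
# through Inductor's Triton template, where the prologue fusion from the PR
# can fuse the gather into the mm kernel.
compiled_foo = torch.compile(foo, mode="max-autotune-no-cudagraphs")

out = compiled_foo(x, y, index)  # with fp16 inputs this hits the assertion below
torch.testing.assert_close(out, foo(x, y, index), rtol=1e-2, atol=1e-2)
```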
On an A100, when this is run with fp32, everything works successfully. However, when run with fp16, I get the following error:

```
python: /home/eellison/triton/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp:852: virtual mlir::LogicalResult {anonymous}::AsyncCopyGlobalToLocalOpConversion::matchAndRewrite(mlir::triton::gpu::AsyncCopyGlobalToLocalOp, mlir::ConvertOpToLLVMPattern<mlir::triton::gpu::AsyncCopyGlobalToLocalOp>::OpAdaptor, mlir::ConversionPatternRewriter&) const: Assertion `byteWidth == 16 || byteWidth == 8 || byteWidth == 4' failed.
```
For this generated kernel.
The main loop of the mm kernel is:
```python
for k_idx in range(0, tl.cdiv(K, BLOCK_K)):
    a_k_idx_vals = offs_k[None, :] + (k_idx * BLOCK_K)
    idx_m = offs_a_m[:, None]
    idx_n = a_k_idx_vals
    xindex = idx_n + (2048*idx_m)
    tmp0 = tl.load(in_ptr1 + (tl.broadcast_to(idx_m, xindex.shape)), None, eviction_policy='evict_last')
    tmp1 = tl.full(xindex.shape, 2048, tl.int32)
    tmp2 = tmp0 + tmp1
    tmp3 = tmp0 < 0
    tmp4 = tl.where(tmp3, tmp2, tmp0)
    tl.device_assert((0 <= tl.broadcast_to(tmp4, xindex.shape)) & (tl.broadcast_to(tmp4, xindex.shape) < 2048), "index out of bounds: 0 <= tl.broadcast_to(tmp4, xindex.shape) < 2048")
    tmp6 = tl.load(in_ptr2 + (tl.broadcast_to(idx_n + (2048*tmp4), xindex.shape)), None, eviction_policy='evict_last')
    a = tmp6.to(tl.float16)
    B_ptr = B + ((offs_k[:, None] + (k_idx * BLOCK_K)) * stride_bk + offs_b_n[None, :] * stride_bn)
    b = tl.load(B_ptr)
    acc += tl.dot(a, b, allow_tf32=ALLOW_TF32)
```

In this case, the tmp0 load is loop-invariant. When you update the kernel to hoist the load, it runs successfully. It also gives a 14% speedup over max-autotune-no-cudagraphs without prologue fusion and 25% over eager.
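For reference, a hand-edited sketch of what the hoisted main loop could look like (illustrative only, not the exact kernel that was benchmarked): everything that depends only on idx_m is computed once before the K loop.

```python
# Sketch of the hoisted variant: the gather index load, negative-index wrap,
# and bounds check depend only on idx_m, so they move out of the K loop.
idx_m = offs_a_m[:, None]
tmp0 = tl.load(in_ptr1 + idx_m, None, eviction_policy='evict_last')
tmp1 = tl.full(idx_m.shape, 2048, tl.int32)
tmp4 = tl.where(tmp0 < 0, tmp0 + tmp1, tmp0)  # wrap negative indices
tl.device_assert((0 <= tmp4) & (tmp4 < 2048), "index out of bounds: 0 <= tmp4 < 2048")

for k_idx in range(0, tl.cdiv(K, BLOCK_K)):
    idx_n = offs_k[None, :] + (k_idx * BLOCK_K)
    # only the k-dependent part of the address is computed inside the loop
    tmp6 = tl.load(in_ptr2 + (idx_n + (2048 * tmp4)), None, eviction_policy='evict_last')
    a = tmp6.to(tl.float16)
    B_ptr = B + ((offs_k[:, None] + (k_idx * BLOCK_K)) * stride_bk + offs_b_n[None, :] * stride_bn)
    b = tl.load(B_ptr)
    acc += tl.dot(a, b, allow_tf32=ALLOW_TF32)
```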
So there are two issues here:
- the codegen failure
- Can Triton improve its loop-invariant code motion to hoist this load? Note: we benchmark this prologue fusion, so in the worst case where we can't prove the hoist, we would simply not fuse instead of generating a terrible mm.
While it is doable for Inductor to generate the code with the load already hoisted, this code is generated through Triton templates. Ideally I would like to keep the current single per-buffer prologue fusion API of {{load_input("A", "a", ("idx_m", "idx_n"))}} instead of something more tedious that involves separate APIs for initializing the indices.
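To make the shape of that API concrete, here is a hypothetical, heavily simplified template fragment (not the actual mm template source in torch/_inductor); the point is only that the template author writes one per-buffer hook and Inductor splices in whatever fused prologue that buffer needs:

```python
# Hypothetical, simplified mm-template fragment (illustration only).
# {{load_input(...)}} is the per-buffer hook: Inductor expands it into the
# index math and loads for input "A" (including any fused prologue, such as
# the gather above) and binds the result to the name "a".
MM_TEMPLATE_MAIN_LOOP = r"""
    for k_idx in range(0, tl.cdiv(K, BLOCK_K)):
        {{load_input("A", "a", ("idx_m", "idx_n"))}}
        b = tl.load(B + (offs_k[:, None] + k_idx * BLOCK_K) * stride_bk
                      + offs_b_n[None, :] * stride_bn)
        acc += tl.dot(a, b, allow_tf32=ALLOW_TF32)
"""
```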
Not sure whether to file this here or against Triton; filing it here for now.
cc @ezyang @chauhang @penguinwu @voznesenskym @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @ColinPeppler @amjames @desertfire @bertmaher @int3 @davidberard98 @nmacchioni @embg @peterbell10 @aakhundov @htyu
Versions
master