Description
🚀 The feature, motivation and pitch
I ran the following program to see what Triton code Inductor generates for a discontiguous tensor:
```python
import sys
import os
import logging
import torch
from torch._inductor import config as inductor_config

# Enable debug logging
os.environ["TORCH_COMPILE_DEBUG"] = "1"
torch._logging.set_logs(inductor=logging.DEBUG)

# Log to stdout
handler = logging.StreamHandler(sys.stdout)
for logger in torch._dynamo.logging.get_loggers():
    logger.addHandler(handler)

inductor_config.triton.use_block_ptr = True

def foo(x, y):
    return x + y

device = torch.device('cuda')
orig_size = (32, 32)
view_size = (32, 8)
orig = torch.randn(orig_size).to(device)
view = torch.as_strided(orig, view_size, orig.stride())
compiled_foo = torch.compile(foo, backend="inductor")
compiled_foo(view, view)
```
The generated kernel was:
```python
@triton.jit
def triton_(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 256
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x0 = xindex % 8
    x1 = (xindex // 8)
    tmp0 = tl.load(in_ptr0 + (x0 + (32*x1)), xmask)
    tmp1 = tmp0 + tmp0
    tl.store(tl.make_block_ptr(out_ptr0, shape=[256], strides=[1], block_shape=[XBLOCK], order=[0], offsets=[xoffset]), tl.broadcast_to(tmp1, [XBLOCK]).to(tl.float32), boundary_check=[0])
```
It seems that Inductor generates a block pointer for the output but falls back to standard pointers for the input. If I don't call torch.as_strided on the input, I see block pointers for both.
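To see what the fallback load actually addresses, the index expression from the generated kernel can be checked in plain Python (no Triton or GPU needed). This sketch assumes the same 32x32 buffer and 32x8 view as the repro. It compares the kernel's x0 + 32*x1 offsets, where x0 = i % 8 and x1 = i // 8, against the offsets implied by the view's strides (32, 1). The helper names kernel_offset and strided_offset are illustrative only.

```python
# Plain-Python model of the input addressing in the generated kernel.
# Assumes a 32x32 row-major buffer viewed as 32x8 with strides (32, 1),
# matching torch.as_strided(orig, (32, 8), orig.stride()) in the repro.
ROWS, COLS = 32, 8              # view shape
ROW_STRIDE, COL_STRIDE = 32, 1  # view strides, inherited from the 32x32 buffer
XNUMEL = ROWS * COLS            # 256 elements, as in the kernel


def kernel_offset(xindex):
    """Offset computed by the Inductor kernel: x0 + 32*x1 via div/mod."""
    x0 = xindex % COLS   # column within the view
    x1 = xindex // COLS  # row within the view
    return x0 + ROW_STRIDE * x1


def strided_offset(row, col):
    """Offset of view element (row, col) computed directly from its strides."""
    return row * ROW_STRIDE + col * COL_STRIDE


kernel_offsets = [kernel_offset(i) for i in range(XNUMEL)]
view_offsets = [strided_offset(r, c) for r in range(ROWS) for c in range(COLS)]

# Both enumerate the same addresses in the same row-major order.
print(kernel_offsets == view_offsets)        # True
# After 8 elements the address jumps to the next row, skipping 24 padding elements.
print(kernel_offsets[7], kernel_offsets[8])  # 7 32
```

The division and modulo are how a flat 1-D index recovers the row and column. The row stride of 32 is what makes the load non-contiguous, which is presumably why a 1-D block pointer with strides=[1] cannot describe the input.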
I am wondering if it's possible for Inductor to generate something like this instead:
```python
@triton.jit
def triton_(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xoffset = tl.program_id(0) * XBLOCK
    # Input: 32x8 view whose rows are 32 elements apart in memory.
    tmp0 = tl.load(tl.make_block_ptr(in_ptr0, shape=[32, 8], strides=[32, 1], block_shape=[32, XBLOCK], order=[1, 0], offsets=[0, xoffset]), boundary_check=[1]).to(tl.float32)
    tmp1 = tmp0 + tmp0
    # Output: contiguous 32x8 buffer, so its row stride is 8.
    tl.store(tl.make_block_ptr(out_ptr0, shape=[32, 8], strides=[8, 1], block_shape=[32, XBLOCK], order=[1, 0], offsets=[0, xoffset]), tl.broadcast_to(tmp1, [32, XBLOCK]).to(tl.float32), boundary_check=[1])
```
This would use the strides argument of tl.make_block_ptr to express that the input tensor is discontiguous. On GPUs, this could avoid the division and modulo in the address calculation, which might yield some performance benefit. The win is probably much bigger for accelerators with simpler memory systems, such as MTIA, where this code maps naturally onto DMA engines. Without it, such accelerators may have a hard time handling the padding between the rows of a tensor.
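The addressing that such a 2-D block pointer describes can be modeled in plain Python. This sketch assumes a simplified reading of tl.make_block_ptr: element (i, j) of a block sits at (offsets[0] + i) * strides[0] + (offsets[1] + j) * strides[1] from the base pointer, and boundary_check is modeled as an in-bounds flag. The helpers block_ptr_offsets and covered_addresses are hypothetical and are not Triton APIs.

```python
# Simplified, hypothetical model of the addresses a 2-D block pointer covers.
# This is NOT Triton's implementation, only an illustration of the arithmetic.
from itertools import product


def block_ptr_offsets(shape, strides, block_shape, offsets):
    """Return (addrs, in_bounds) for every element of one block, row-major.

    in_bounds plays the role of boundary_check: elements outside `shape`
    would be masked off rather than loaded.
    """
    addrs, in_bounds = [], []
    for i, j in product(range(block_shape[0]), range(block_shape[1])):
        r, c = offsets[0] + i, offsets[1] + j
        addrs.append(r * strides[0] + c * strides[1])  # multiply-add only, no div/mod
        in_bounds.append(r < shape[0] and c < shape[1])
    return addrs, in_bounds


def covered_addresses(xblock, rows=32, cols=8, row_stride=32):
    """Addresses loaded by all programs for block_shape [rows, xblock]."""
    covered = []
    num_programs = -(-cols // xblock)  # ceil(cols / xblock), like a launch grid
    for pid in range(num_programs):
        addrs, ok = block_ptr_offsets([rows, cols], [row_stride, 1],
                                      [rows, xblock], [0, pid * xblock])
        covered += [a for a, keep in zip(addrs, ok) if keep]
    return covered


expected = sorted(r * 32 + c for r in range(32) for c in range(8))
for xblock in (4, 8, 16):
    print(xblock, sorted(covered_addresses(xblock)) == expected)  # True each time
```

Every address comes from a multiply-add on row and column coordinates, and the row stride of 32 carries the padding between rows, which is the kind of information a DMA descriptor would need. With XBLOCK = 16, half of each block lies beyond the 8 columns and is dropped by the modeled boundary check instead of being read.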
Is this feature feasible? The main change I see is that XBLOCK would refer to the columns of the input matrix rather than to the linear index. It would also be possible to block on rows.
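The two blocking choices (columns versus rows) can be compared with a small plain-Python planner. Everything here is hypothetical: tile_offsets, coverage and next_pow2 are illustrative helpers. Padding a dimension up to a power of two reflects the usual Triton requirement that block dimensions be powers of two; how Inductor would actually choose block shapes is an assumption.

```python
# Hypothetical tiling planner for a 2-D (rows x cols) iteration space.
# "Block on columns": block_shape [rows, XBLOCK], as in the proposal.
# "Block on rows":    block_shape [YBLOCK, cols], the alternative.
from collections import Counter
from itertools import product


def next_pow2(n):
    """Smallest power of two >= n (block dimensions are usually powers of two)."""
    return 1 << (n - 1).bit_length()


def tile_offsets(shape, block_shape):
    """Start offsets of every tile, i.e. one entry per program in the grid."""
    return [(r, c)
            for r in range(0, shape[0], block_shape[0])
            for c in range(0, shape[1], block_shape[1])]


def coverage(shape, block_shape):
    """Count how often each in-bounds element is visited across all tiles."""
    hits = Counter()
    for r0, c0 in tile_offsets(shape, block_shape):
        for i, j in product(range(block_shape[0]), range(block_shape[1])):
            r, c = r0 + i, c0 + j
            if r < shape[0] and c < shape[1]:  # modeled boundary check
                hits[(r, c)] += 1
    return hits


shape = (32, 8)
XBLOCK, YBLOCK = 4, 8
by_cols = [next_pow2(shape[0]), XBLOCK]  # [32, 4]
by_rows = [YBLOCK, next_pow2(shape[1])]  # [8, 8]
for name, block in (("columns", by_cols), ("rows", by_rows)):
    hits = coverage(shape, block)
    # name, block shape, number of programs, visit counts, elements covered
    print(name, block, len(tile_offsets(shape, block)), set(hits.values()), len(hits))
```

Both plans visit each of the 256 elements exactly once; they differ only in which dimension the grid walks (2 programs along the columns versus 4 along the rows here). Blocking on rows keeps whole rows in one tile, which may suit row-oriented DMA transfers, but that is a design judgment rather than something this model shows.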
Alternatives
In principle, the Triton compiler could recognize this pattern under the hood. But it seems like that would require reading a whole number of rows, i.e. XBLOCK must be a multiple of the row length. The analysis could also get complex when division and modulo are involved. I'm wondering if it makes more sense to handle this in Inductor.
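The row-length constraint can be illustrated with a hypothetical check of whether a flat block [xoffset, xoffset + XBLOCK) of a row-major view maps onto a rectangular 2-D tile. tile_of_linear_block and all_blocks_rectangular are illustrative helpers, not part of Triton or Inductor. In this toy model a block is also rectangular when XBLOCK divides the row length, giving 1 x XBLOCK tiles.

```python
# Hypothetical check: does each flat block [xoffset, xoffset + XBLOCK) of a
# row-major (rows x cols) view cover a rectangular 2-D tile? If so, a 2-D
# block pointer could describe it directly.


def tile_of_linear_block(xoffset, xblock, cols, numel):
    """Return (row0, col0, nrows, ncols) if the block is a rectangle, else None."""
    idx = range(xoffset, min(xoffset + xblock, numel))
    coords = {(i // cols, i % cols) for i in idx}  # the div/mod from the kernel
    rows = sorted({r for r, _ in coords})
    cs = sorted({c for _, c in coords})
    # A rectangle has contiguous row and column ranges and contains every pair.
    is_rect = (rows == list(range(rows[0], rows[-1] + 1))
               and cs == list(range(cs[0], cs[-1] + 1))
               and len(coords) == len(rows) * len(cs))
    return (rows[0], cs[0], len(rows), len(cs)) if is_rect else None


def all_blocks_rectangular(xblock, rows=32, cols=8):
    numel = rows * cols
    return all(tile_of_linear_block(x, xblock, cols, numel) is not None
               for x in range(0, numel, xblock))


for cols in (8, 6):
    print(cols, [all_blocks_rectangular(xb, cols=cols) for xb in (4, 8, 16)])
# 8 [True, True, True]
# 6 [False, False, False]
```

With 8 columns every power-of-two XBLOCK either divides or is a multiple of the row length, so each block is a rectangle. With a 6-column view, power-of-two blocks straddle rows irregularly, and a compiler-side analysis would have to reject the pattern or split it.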
Instead of block pointers, it's also possible to simplify the address calculation for standard pointers, for example:
```python
x0 = tl.broadcast_to(tl.expand_dims(xoffset + tl.arange(0, XBLOCK), axis=0), [32, XBLOCK])  # column index
x1 = tl.broadcast_to(tl.expand_dims(tl.arange(0, 32), axis=1), [32, XBLOCK])  # row index
tl.load(in_ptr0 + x0 + x1 * 32, mask=x0 < 8)
```
This could more easily be converted to a block representation inside the Triton compiler.
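A rough idea of the pattern match such a conversion needs: given a 2-D grid of pointer offsets, check that it is affine, offset[i][j] == base + i*s0 + j*s1, and recover (base, s0, s1), which are the arguments a 2-D block pointer would take. This is a hypothetical plain-Python sketch; broadcast_offsets models in_ptr0 + x0 + x1 * 32 from the standard-pointer snippet while ignoring the mask, and recover_affine is not an actual Triton compiler pass.

```python
# Hypothetical sketch of recovering block-pointer parameters from offsets.


def broadcast_offsets(xoffset, xblock, rows=32, row_stride=32):
    """Plain-Python model of x0 + x1 * 32 broadcast to [rows, xblock]."""
    x0 = [xoffset + j for j in range(xblock)]  # columns (expand_dims axis=0)
    x1 = list(range(rows))                     # rows (expand_dims axis=1)
    return [[c + r * row_stride for c in x0] for r in x1]


def recover_affine(grid):
    """Return (base, s0, s1) if grid[i][j] == base + i*s0 + j*s1, else None."""
    base = grid[0][0]
    s0 = grid[1][0] - base if len(grid) > 1 else 0     # row stride
    s1 = grid[0][1] - base if len(grid[0]) > 1 else 0  # column stride
    for i, row in enumerate(grid):
        for j, off in enumerate(row):
            if off != base + i * s0 + j * s1:
                return None
    return base, s0, s1


print(recover_affine(broadcast_offsets(xoffset=4, xblock=4)))  # (4, 32, 1)
```

Here the offsets come straight from a broadcast of two aranges, so the strides fall out by differencing. The div/mod form in the generated kernel hides the same structure behind a 1-D index, which is part of why the analysis is harder there.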
Additional context
No response
cc @chauhang @penguinwu @voznesenskym @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @chenyang78 @kadeng @muchulee8 @amjames @aakhundov @ezyang @msaroufim @bdhirsh @anijain2305 @peterbell10 @yf225 @ColinPeppler @desertfire
cc @shunting314 based on offline conversations. We were hoping for input from @jansel .