
Commit 2267787

YUNQIUGUO authored and facebook-github-bot committed
[AOTI] Fix a special case compile time data type codegen for sym int variables (#138106)
Summary: This change unblocks the CFR AOTI lowering runtime error. TL;DR: In this model, one Triton kernel expects a scalar input dtype of i64 but receives an i32. The reason is that `auto` can infer a smaller data type when the variable passed in is, e.g., an i32, which causes a CUDA illegal memory access (IMA). The original problematic kernel is `triton_poi_fused_add_ge_logical_and_logical_or_lt_46_grid_100`, whose third input was declared as `auto var_402 = u0`. This diff explicitly declares symbolic arguments as i64 at compile time for i64 Triton kernel inputs, instead of using `auto var_x = {arg}` in the cpp wrapper code.

Test Plan: Verified in FLB locally:

```
PYTORCH_NO_CUDA_MEMORY_CACHING=1 AOT_INDUCTOR_DEBUG_INTERMEDIATE_VALUE_PRINTER=3 TORCH_LOGS="output_code" TORCHINDUCTOR_MAX_AUTOTUNE=1 TORCH_SHOW_CPP_STACKTRACES=1 CUDA_LAUNCH_BLOCKING=1 ~/fbsource/buck-out/v2/gen/fbcode/98e643f8bb44fe9d/hpc/new/models/feed/benchmark/__feed_lower_benchmark__/feed_lower_benchmark.par --skip-eager --skip-flop-estimation --lower-backend="AOT_INDUCTOR" --sync-mode=0 --precision bf16 --output-precision bf16 --lower-presets="ifr_cint;disable_new_lowering_weights;disable_dper_passes:passes=fuse_parallel_linear_no_weight_change" --remove-unexpected-type-cast=False --load="manifold://ads_storage_fblearner/tree/user/facebook/fblearner/predictor/924293663/0/gpu_lowering/input.merge"
```

Differential Revision: D64490039
1 parent 620039c commit 2267787
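
To make the summary concrete, here is a minimal Python sketch of the codegen idea; the `decl_line` helper below is hypothetical and only illustrates the behavior, while the real change lives in `generate_args_decl` in `torch/_inductor/codegen/cpp_wrapper_gpu.py`:

```
# Minimal sketch, assuming a hypothetical decl_line helper (not the actual Inductor API).
signature2dtype = {"i32": "int32_t", "i64": "int64_t", "fp32": "float"}

def decl_line(var_name: str, expr: str, arg_signature=None) -> str:
    # Declare symbolic scalar args with the dtype from the Triton signature
    # instead of `auto`, so a 32-bit value cannot narrow an i64 kernel input.
    if arg_signature in signature2dtype:
        return f"{signature2dtype[arg_signature]} {var_name} = {expr};"
    return f"auto {var_name} = {expr};"  # previous behavior, kept as the fallback

print(decl_line("var_402", "u0", "i64"))  # -> int64_t var_402 = u0;
print(decl_line("var_402", "u0"))         # -> auto var_402 = u0;
```

With an explicit `int64_t` declaration, a `u0` that arrives as a 32-bit value is widened before being passed to the kernel, rather than silently kept at 32 bits by `auto`.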

File tree

2 files changed: +87 -6 lines changed

test/inductor/test_aot_inductor.py

Lines changed: 49 additions & 0 deletions
@@ -14,6 +14,7 @@
 import torch._inductor
 import torch._inductor.config
 import torch.nn as nn
+from torch._dynamo import config as dynamo_config
 from torch._dynamo.testing import rand_strided, same
 from torch._dynamo.utils import counters
 from torch._inductor import config
@@ -3608,6 +3609,54 @@ def forward(self, x):
         example_inputs = (torch.randn(8, device=self.device),)
         self.check_model(Model(), example_inputs)
 
+    @dynamo_config.patch({"capture_scalar_outputs": True})
+    def test_sym_i64_input_codegen(self):
+        if self.device != "cuda":
+            raise unittest.SkipTest("requires CUDA")
+
+        from torch.testing._internal.triton_utils import add_kernel
+
+        class Model(torch.nn.Module):
+            def __init__(self) -> None:
+                super().__init__()
+
+            def forward(self, x):
+                x_symint = x.item()
+                a = torch.ones(x_symint, device="cuda")
+                b = torch.ones(x_symint, device="cuda")
+                out = torch.zeros_like(a)
+                # unbacked symint in grid
+                add_kernel[(1, 1, x_symint)](a, b, out, x_symint, 32)
+                return out
+
+        example_inputs = (
+            torch.randint(high=1024, size=(1,), device=self.device, dtype=torch.int32),
+        )
+        # This simple unit test case model generates two triton kernels:
+        # 1. triton_poi_fused_ones_1:
+        #    triton_meta={'signature': {'out_ptr0': '*fp32', 'xnumel': 'i64'}
+        # 2. add_kernel:
+        #    triton_meta={'signature': {'in_ptr0': '*fp32', 'in_ptr1': '*fp32', 'out_ptr': '*fp32', 'n_elements': 'i64'}
+        # Input u0 is initially defined as int32_t; verify that every kernel var arg downstream
+        # is explicitly declared with its dtype in the cpp wrapper codegen code.
+        expected_scalar_args = [
+            "int64_t var_1 = u0;",
+            "int64_t var_3 = u0;",
+            "int64_t var_5 = u0;",
+            "int64_t var_9 = u0;",
+        ]
+        # Check that the new codegen behavior matches expectations.
+        result, code = run_and_get_cpp_code(
+            AOTIRunnerUtil.compile, Model(), example_inputs
+        )
+        for scalar_line in expected_scalar_args:
+            FileCheck().check_count(
+                scalar_line,
+                1,
+            ).run(code)
+
+        self.check_model(Model(), example_inputs)
 
 
 common_utils.instantiate_parametrized_tests(AOTInductorTestsTemplate)

torch/_inductor/codegen/cpp_wrapper_gpu.py

Lines changed: 38 additions & 6 deletions
@@ -1,7 +1,7 @@
 # mypy: allow-untyped-defs
 import functools
 import os
-from itertools import chain, count
+from itertools import chain, count, zip_longest
 from typing import Any, Callable, List, Optional, Tuple, TYPE_CHECKING, Union
 
 import sympy
@@ -286,9 +286,17 @@ def generate_load_kernel_once(
             self.writeline("}")
         return kernel_var_name
 
-    def generate_args_decl(self, call_args, arg_types):
+    def generate_args_decl(self, call_args, arg_types, arg_signatures):
         new_args = []
-        for arg, arg_type in zip(call_args, arg_types):
+
+        # Add more cases for other types as needed
+        signature2dtype = {
+            "i32": "int32_t",
+            "i64": "int64_t",
+            "fp32": "float",
+        }
+
+        def process_args(arg, arg_type, arg_signature=None):
             var_name = f"var_{next(self.arg_var_id)}"
             if isinstance(arg_type, torch_dtype):
                 if arg.endswith(".item()"):
@@ -312,10 +320,26 @@ def generate_args_decl(self, call_args, arg_types):
                 self.writeline(f"int {var_name} = {self.expr_printer(arg)};")
             elif arg_type in (sympy.Float, float):
                 self.writeline(f"float {var_name} = {self.expr_printer(arg)};")
+            # For symbolic call arguments, examine the arg signatures from triton meta
+            # to explicitly cast to the right type.
+            # Reason: `auto` can infer an unexpected type that mismatches the kernel
+            # input signature.
+            elif (
+                isinstance(arg_type, type(SymbolicCallArg))
+                and arg_signature is not None
+                and arg_signature in signature2dtype.keys()
+            ):
+                self.writeline(
+                    f"{signature2dtype[arg_signature]} {var_name} = {self.expr_printer(arg)};"
+                )
             else:
                 self.writeline(f"auto {var_name} = {self.expr_printer(arg)};")
             new_args.append(f"&{var_name}")
 
+        for arg, arg_type, arg_signature in zip_longest(
+            call_args, arg_types, arg_signatures
+        ):
+            process_args(arg, arg_type, arg_signature)
+
         return ", ".join(new_args)
 
     def generate_default_grid(
@@ -392,18 +416,26 @@ def generate_kernel_call(
             # args with value 1 are added into equal_to_1 and constants
             # in triton_meta (in the Python codegen) which makes them
             # inlined in the PTX and compiled CUBIN
+            arg_signatures = []
             if (
                 triton_meta is not None
-                and "configs" in triton_meta
-                and triton_meta["configs"]
+                and triton_meta.get("configs")
+                and triton_meta.get("signature")
             ):
                 equal_to_1 = triton_meta["configs"][0].equal_to_1
                 call_args = [
                     arg for i, arg in enumerate(call_args) if i not in equal_to_1
                 ]
                 arg_types = [t for i, t in enumerate(arg_types) if i not in equal_to_1]
+                # extract the arg signatures from triton_meta
+                arg_signatures = triton_meta["signature"].values()
+                arg_signatures = [
+                    v for i, v in enumerate(arg_signatures) if i not in equal_to_1
+                ]
 
-            call_args_str = self.generate_args_decl(call_args, arg_types)
+            call_args_str = self.generate_args_decl(
+                call_args, arg_types, arg_signatures
+            )
             kernel_args_var = f"kernel_args_var_{next(self.kernel_callsite_id)}"
             self.writeline(f"void* {kernel_args_var}[] = {{{call_args_str}}};")
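
One design note on the `zip_longest` loop above: when a kernel has no usable `triton_meta` (so `arg_signatures` stays an empty list), the missing signatures are padded with `None` and `process_args` falls back to the old `auto` declaration. A small standalone sketch with illustrative values only:

```
from itertools import zip_longest

# Illustrative values only; in the wrapper codegen these come from the kernel
# call site and from triton_meta["signature"].
call_args = ["buf0", "buf1", "out0", "u0"]
arg_types = ["*fp32", "*fp32", "*fp32", "SymbolicCallArg"]
arg_signatures = []  # e.g. no usable triton_meta for this kernel

# zip_longest pads the missing signatures with None, so every argument is
# still processed and takes the `auto var_x = ...;` fallback path.
for arg, arg_type, sig in zip_longest(call_args, arg_types, arg_signatures):
    print(arg, arg_type, sig)
```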

0 commit comments
