
Commit 378edb0

karthickai authored and pytorchmergebot committed
[Inductor] Add DeviceAssert op to enable device-side assertion in torch.compile (#160677)
This PR introduces a device_assert op to trigger device-side assertions within torch.compile. The implementation follows the suggestion in this comment on #147282.

Changes included:
- Implemented the device_assert op and overrode has_side_effect to return True so dead-code elimination does not remove it.
- Removed the assert_async_msg_decomp and functional_assert_async_msg_decomp decompositions to disable the default no-op assert decomposition inside Inductor.
- Added a lowering for torch.ops.aten._assert_async.msg that converts assert calls into the ops handler.
- Implemented the codegen method for the device_assert op, generating both C++ and Triton code.
- Added test cases covering both "should throw" and "should not throw" scenarios.

Fixes #147282

Pull Request resolved: #160677
Approved by: https://github.com/mlazos
1 parent d2db6c8 commit 378edb0
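
In practical terms, a minimal usage sketch (the function below is illustrative, not part of the PR; it assumes a CUDA device and the Triton backend): a plain Python assert on a tensor condition inside a compiled function is now lowered to a device-side assertion instead of being silently dropped by the old no-op decompositions.

import torch

def guarded(x):
    # Dynamo rewrites this assert into aten._assert_async.msg in the graph;
    # with this PR, Inductor lowers it to a device-side assert (tl.device_assert
    # on CUDA, a C++ throw with the cpp_wrapper backend) instead of a no-op.
    assert torch.all(x > 0), "should throw"
    return x * 2

compiled = torch.compile(guarded, backend="inductor")
compiled(torch.tensor([1.0, 2.0], device="cuda"))       # passes
try:
    compiled(torch.tensor([1.0, -2.0], device="cuda"))
    torch.cuda.synchronize()  # device-side asserts surface asynchronously
except Exception as e:
    print(f"assertion fired: {e}")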

11 files changed (+298, -16 lines)

Lines changed: 204 additions & 0 deletions (new test file)
@@ -0,0 +1,204 @@
# Owner(s): ["module: inductor"]
import os
import subprocess
import sys

import torch
import torch._inductor.config
from torch._inductor import metrics
from torch._inductor.compiler_bisector import BisectionResult, CompilerBisector
from torch._inductor.test_case import run_tests, TestCase
from torch.testing._internal.common_utils import skipIfRocm
from torch.testing._internal.triton_utils import requires_cuda_and_triton


class TestTorchDeviceAssertTrigger(TestCase):
    def _run_assert_should_throw(self, device):
        def func():
            a = torch.tensor([1.0, -2.0], device=device)
            result = torch.all(a > 0)
            assert result, "should throw"

        def test_fn():
            torch._dynamo.reset()
            f_c = torch.compile(func)

            try:
                f_c()
                return False
            except Exception:
                return True

        bisect_result = CompilerBisector.do_bisect(test_fn)
        # do_bisect returns None if every subsystem passes, otherwise a BisectionResult
        self.assertNotIsInstance(bisect_result, BisectionResult)

    def _run_assert_should_not_throw(self, device):
        def func():
            a = torch.tensor([1.0, 2.0], device=device)
            result = torch.all(a > 0)
            assert result, "should throw"

        def test_fn():
            torch._dynamo.reset()
            f_c = torch.compile(func)

            try:
                f_c()
                return True
            except Exception:
                return False

        bisect_result = CompilerBisector.do_bisect(test_fn)
        self.assertNotIsInstance(bisect_result, BisectionResult)

    def _run_assert_inline_expression_should_throw(self, device):
        def func():
            a = torch.tensor([1.0, -2.0], device=device)
            assert torch.all(a > 0), "should throw"

        def test_fn():
            torch._dynamo.reset()
            f_c = torch.compile(func)

            try:
                f_c()
                return False
            except Exception:
                return True

        bisect_result = CompilerBisector.do_bisect(test_fn)
        self.assertNotIsInstance(bisect_result, BisectionResult)

    def _run_assert_inline_expression_should_not_throw(self, device):
        def func():
            a = torch.tensor([1.0, 2.0], device=device)
            assert torch.all(a > 0), "should throw"

        def test_fn():
            torch._dynamo.reset()
            f_c = torch.compile(func)

            try:
                f_c()
                return True
            except Exception:
                return False

        bisect_result = CompilerBisector.do_bisect(test_fn)
        self.assertNotIsInstance(bisect_result, BisectionResult)

    @torch._inductor.config.patch(force_disable_caches=True)
    def test_assert_should_throw(self):
        device = "cpu"
        self._run_assert_should_throw(device)
        self._run_assert_inline_expression_should_throw(device)

    @torch._inductor.config.patch(force_disable_caches=True)
    def test_assert_should_not_throw(self):
        device = "cpu"
        self._run_assert_should_not_throw(device)
        self._run_assert_inline_expression_should_not_throw(device)

    @torch._inductor.config.patch(force_disable_caches=True, cpp_wrapper=True)
    def test_assert_should_throw_cpp_wrapper(self):
        device = "cpu"
        self._run_assert_should_throw(device)
        self._run_assert_inline_expression_should_throw(device)

    @torch._inductor.config.patch(force_disable_caches=True, cpp_wrapper=True)
    def test_assert_should_not_throw_cpp_wrapper(self):
        device = "cpu"
        self._run_assert_should_not_throw(device)
        self._run_assert_inline_expression_should_not_throw(device)

    @requires_cuda_and_triton
    @skipIfRocm
    @torch._inductor.config.patch(force_disable_caches=True)
    def test_assert_fusion(self):
        torch._logging.set_logs(inductor_metrics=True)

        def func():
            a = torch.tensor([1.0, 2.0], device="cuda")
            result = torch.all(a > 0)
            assert result, "should throw"

        torch._dynamo.reset()
        f_c = torch.compile(func, backend="inductor")
        metrics.reset()
        self.assertEqual(metrics.generated_kernel_count, 0)
        f_c()
        self.assertEqual(metrics.generated_kernel_count, 1)
        torch._logging.set_logs()

    @requires_cuda_and_triton
    @skipIfRocm
    @torch._inductor.config.patch(force_disable_caches=True)
    def test_run_assert_triton(self):
        should_throw = """
import torch
import torch._dynamo

def func_should_throw():
    a = torch.tensor([1.0, -2.0], device='cuda')
    result = torch.all(a > 0)
    assert result, "should throw"

def test_fn():
    torch._dynamo.reset()
    f_c = torch.compile(func_should_throw, backend="inductor")

    try:
        f_c()
        torch.cuda.synchronize()
        return False
    except Exception:
        return True

result = test_fn()
print(f"Test result: {result}")
"""

        should_not_throw = """
import torch
import torch._dynamo

def func_should_not_throw():
    a = torch.tensor([1.0, 2.0], device='cuda')
    result = torch.all(a > 0)
    assert result, "should throw"

def test_fn():
    torch._dynamo.reset()
    f_c = torch.compile(func_should_not_throw, backend="inductor")

    try:
        f_c()
        torch.cuda.synchronize()
        return True
    except Exception:
        return False

result = test_fn()
print(f"Test result: {result}")
"""
        for script in [should_not_throw, should_throw]:
            p = subprocess.run(
                [sys.executable, "-c", script],
                cwd=os.path.dirname(os.path.realpath(__file__)),
                capture_output=True,
                text=True,
            )

            output = p.stdout + "\n" + p.stderr

            self.assertIn("Test result: True", output)

            if p.returncode != 0:
                self.fail(
                    f"Subprocess failed with return code {p.returncode}. Output: {output}"
                )


if __name__ == "__main__":
    run_tests()

torch/_inductor/codegen/cpp.py

Lines changed: 4 additions & 0 deletions
@@ -1119,6 +1119,10 @@ def sign(x):
         code.writeline("()")
         return code

+    @staticmethod
+    def device_assert_async(cond, msg):
+        return f'({cond} ? 0 : (throw std::runtime_error("{msg}"), 0))'
+

 CppOverrides._initialize_pointwise_overrides("cpp")

torch/_inductor/codegen/halide.py

Lines changed: 4 additions & 0 deletions
@@ -566,6 +566,10 @@ def masked(mask, body, other):
     def frexp(x):
         raise NotImplementedError("frexp")

+    @staticmethod
+    def device_assert_async(cond, msg):
+        raise NotImplementedError("device_assert_async")
+

 HalideOverrides._initialize_pointwise_overrides("halide")

torch/_inductor/codegen/triton.py

Lines changed: 4 additions & 0 deletions
@@ -1592,6 +1592,10 @@ def frexp(x):
         V.kernel.cse.put(cache_key, (mantissa, exponent))
         return (mantissa, exponent)

+    @staticmethod
+    def device_assert_async(cond, msg):
+        return f"tl.device_assert({cond}, {repr(msg)})"
+

 class HelperFunctions:
     """An ordered set of helper functions."""

torch/_inductor/decomposition.py

Lines changed: 0 additions & 13 deletions
@@ -158,19 +158,6 @@ def _embedding_dense_backward(
     )


-# TODO: for now, inductor doesn't handle asserts
-# because the condition is symbol -> tensor in the graph.
-@register_decomposition([aten._assert_async.msg])
-def assert_async_msg_decomp(tensor: torch.Tensor, msg: str) -> None:
-    return
-
-
-# Following `assert_async_msg_decomp` and implement as non-op.
-@register_decomposition([aten._functional_assert_async.msg])
-def functional_assert_async_msg_decomp(tensor: torch.Tensor, msg: str) -> None:
-    return
-
-
 @register_decomposition([aten.sym_constrain_range_for_size.default])
 def sym_constrain_range_for_size(
     symbol: torch.SymInt,

torch/_inductor/dtype_propagation.py

Lines changed: 4 additions & 0 deletions
@@ -373,6 +373,10 @@ def placeholder(self, index: int) -> torch.dtype:
             f"{type(self).__name__}: ops.placeholder should not appear here"
         )

+    @staticmethod
+    def device_assert_async(cond, msg: str) -> torch.dtype:
+        return torch.bool
+

 if TYPE_CHECKING:

torch/_inductor/ir.py

Lines changed: 16 additions & 1 deletion
@@ -1094,7 +1094,10 @@ def constant_to_device(self, device: torch.device) -> IRNode:
         loader = self.make_loader()
         loader = patch.object(ConstantBuffer, "override_device", device)(loader)
         return Pointwise(
-            device=device, dtype=self.dtype, inner_fn=loader, ranges=self.ranges
+            device=device,
+            dtype=self.dtype,
+            inner_fn=loader,
+            ranges=self.ranges,
         )

@@ -4423,6 +4426,17 @@ class ComputedBuffer(OperationBuffer):
     """

     data: Loops
+    _force_realize: ClassVar[bool] = False
+
+    @staticmethod
+    @contextlib.contextmanager
+    def force_realize() -> Iterator[None]:
+        old_value = ComputedBuffer._force_realize
+        try:
+            ComputedBuffer._force_realize = True
+            yield
+        finally:
+            ComputedBuffer._force_realize = old_value

     def get_computed_buffer_name(self) -> Optional[str]:
         """
@@ -4497,6 +4511,7 @@ def make_loader(self) -> Callable[[Sequence[Expr]], OpsValue]:
             not self.get_reduction_type()
             and self.name not in V.graph.mutated_buffers
             and self.num_reads() == 0
+            and not self._force_realize
         ):
             # inline this op rather than generating ops.load()
             return self.data.make_loader()
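
The ComputedBuffer changes exist so the boolean condition buffer is not inlined away: as the hunk above shows, make_loader() normally inlines a non-reduction buffer with zero reads, while the assert lowering (next file) needs the condition to stay materialized. A minimal standalone sketch of the same class-level-flag-plus-contextmanager pattern (hypothetical Buffer class, not the Inductor one):

import contextlib

class Buffer:
    _force_realize = False  # class-level switch, mirroring ComputedBuffer._force_realize

    @staticmethod
    @contextlib.contextmanager
    def force_realize():
        old = Buffer._force_realize
        try:
            Buffer._force_realize = True
            yield
        finally:
            Buffer._force_realize = old

    def can_inline(self) -> bool:
        # mirrors the extra "and not self._force_realize" guard in make_loader()
        return not self._force_realize

buf = Buffer()
print(buf.can_inline())        # True: normally eligible for inlining
with Buffer.force_realize():
    print(buf.can_inline())    # False: kept realized while the assert lowering runs
print(buf.can_inline())        # True: flag restored on exit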

torch/_inductor/lowering.py

Lines changed: 33 additions & 0 deletions
@@ -1329,6 +1329,39 @@ def inner_fn(idx):
     )


+def _assert_async(cond, msg):
+    cond.realize()
+    cond = to_dtype(cond, torch.bool)
+
+    def inner_fn(index):
+        if hasattr(cond.data, "data") and hasattr(cond.data.data, "force_realize"):
+            with cond.data.data.force_realize():
+                cond_loader = cond.make_loader()
+                return ops.device_assert_async(cond_loader(index), msg)
+        else:
+            cond_loader = cond.make_loader()
+            return ops.device_assert_async(cond_loader(index), msg)
+
+    assertion_op = Pointwise.create(
+        device=cond.get_device(),
+        dtype=cond.get_dtype(),
+        inner_fn=inner_fn,
+        ranges=list(cond.get_size()),
+    )
+    assertion_op.realize()
+    return assertion_op
+
+
+@register_lowering(aten._assert_async.msg)
+def lower_assert_async(cond, msg):
+    return _assert_async(cond, msg)
+
+
+@register_lowering(aten._functional_assert_async.msg)
+def lower_assert_functional_async(cond, msg):
+    return _assert_async(cond, msg)
+
+
 @register_lowering(
     quantized_decomposed.dequantize_per_channel, type_promotion_kind=None
 )
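
Both registered lowerings funnel into _assert_async, which realizes the condition, casts it to bool, and emits one ops.device_assert_async call per element of the condition via a Pointwise node. A small sketch of exercising the aten._assert_async.msg path directly (assuming torch._assert_async accepts a message argument, which dispatches to the .msg overload):

import torch

def check(x):
    cond = torch.all(x > 0)
    # This is the op the lowering above is registered for; under torch.compile
    # it now produces an in-kernel assertion rather than decomposing to a no-op.
    torch._assert_async(cond, "x must be positive")
    return x + 1

compiled = torch.compile(check, backend="inductor")
print(compiled(torch.tensor([1.0, 2.0])))   # passes on CPU
# compiled(torch.tensor([1.0, -2.0]))       # would raise from the generated kernel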

torch/_inductor/ops_handler.py

Lines changed: 12 additions & 0 deletions
@@ -706,6 +706,9 @@ def placeholder(self, index: int) -> T:
         """This is a fake op used in analysis but not codegen"""
         raise NotImplementedError

+    def device_assert_async(self, cond: T, msg: str) -> T:
+        raise NotImplementedError
+

 _ignore_op_re = re.compile(r"_.*|paren").fullmatch

@@ -788,6 +791,9 @@ def {target}(self, {", ".join(args)}):
             if target in OP_NAMES:
                 setattr(cls, target, impl)

+    def device_assert_async(self, cond, msg):
+        return None
+

 DefaultHandler._init_cls()

@@ -933,6 +939,9 @@ def sort(dtypes, values, stable, descending):
     def indirect_indexing(index_var, size, check=True, wrap_neg=True) -> sympy.Symbol:
         return sympy_index_symbol(str(index_var))

+    def device_assert_async(self, cond, msg):
+        return None
+

 class KernelFormatterHandler(DefaultHandler):
     def __init__(self, parent_handler: OpsHandler[Any]):

@@ -999,6 +1008,9 @@ def getvalue(self, result):
         self._output.writeline(f"return {result}")
         return self._output.getvalue()

+    def device_assert_async(self, cond, msg: str):
+        return f"ops.device_assert_async({cond}, {msg})"
+

 class WrapperHandler(DefaultHandler):
     def __init__(self, inner: OpsHandler[Any]):
