|
| 1 | +# Owner(s): ["module: inductor"] |
| 2 | +import os |
| 3 | +import subprocess |
| 4 | +import sys |
| 5 | + |
| 6 | +import torch |
| 7 | +import torch._inductor.config |
| 8 | +from torch._inductor import metrics |
| 9 | +from torch._inductor.compiler_bisector import BisectionResult, CompilerBisector |
| 10 | +from torch._inductor.test_case import run_tests, TestCase |
| 11 | +from torch.testing._internal.common_utils import skipIfRocm |
| 12 | +from torch.testing._internal.triton_utils import requires_cuda_and_triton |
| 13 | + |
| 14 | + |
| 15 | +class TestTorchDeviceAssertTrigger(TestCase): |
| 16 | + def _run_assert_should_throw(self, device): |
| 17 | + def func(): |
| 18 | + a = torch.tensor([1.0, -2.0], device=device) |
| 19 | + result = torch.all(a > 0) |
| 20 | + assert result, "should throw" |
| 21 | + |
| 22 | + def test_fn(): |
| 23 | + torch._dynamo.reset() |
| 24 | + f_c = torch.compile(func) |
| 25 | + |
| 26 | + try: |
| 27 | + f_c() |
| 28 | + return False |
| 29 | + except Exception: |
| 30 | + return True |
| 31 | + |
| 32 | + bisect_result = CompilerBisector.do_bisect(test_fn) |
| 33 | + # do_bisect return None if all system is passed else return BisectionResult |
| 34 | + self.assertNotIsInstance(bisect_result, BisectionResult) |
| 35 | + |
| 36 | + def _run_assert_should_not_throw(self, device): |
| 37 | + def func(): |
| 38 | + a = torch.tensor([1.0, 2.0], device=device) |
| 39 | + result = torch.all(a > 0) |
| 40 | + assert result, "should throw" |
| 41 | + |
| 42 | + def test_fn(): |
| 43 | + torch._dynamo.reset() |
| 44 | + f_c = torch.compile(func) |
| 45 | + |
| 46 | + try: |
| 47 | + f_c() |
| 48 | + return True |
| 49 | + except Exception: |
| 50 | + return False |
| 51 | + |
| 52 | + bisect_result = CompilerBisector.do_bisect(test_fn) |
| 53 | + self.assertNotIsInstance(bisect_result, BisectionResult) |
| 54 | + |
| 55 | + def _run_assert_inline_expression_should_throw(self, device): |
| 56 | + def func(): |
| 57 | + a = torch.tensor([1.0, -2.0], device=device) |
| 58 | + assert torch.all(a > 0), "should throw" |
| 59 | + |
| 60 | + def test_fn(): |
| 61 | + torch._dynamo.reset() |
| 62 | + f_c = torch.compile(func) |
| 63 | + |
| 64 | + try: |
| 65 | + f_c() |
| 66 | + return False |
| 67 | + except Exception: |
| 68 | + return True |
| 69 | + |
| 70 | + bisect_result = CompilerBisector.do_bisect(test_fn) |
| 71 | + self.assertNotIsInstance(bisect_result, BisectionResult) |
| 72 | + |
| 73 | + def _run_assert_inline_expression_should_not_throw(self, device): |
| 74 | + def func(): |
| 75 | + a = torch.tensor([1.0, 2.0], device=device) |
| 76 | + assert torch.all(a > 0), "should throw" |
| 77 | + |
| 78 | + def test_fn(): |
| 79 | + torch._dynamo.reset() |
| 80 | + f_c = torch.compile(func) |
| 81 | + |
| 82 | + try: |
| 83 | + f_c() |
| 84 | + return True |
| 85 | + except Exception: |
| 86 | + return False |
| 87 | + |
| 88 | + bisect_result = CompilerBisector.do_bisect(test_fn) |
| 89 | + self.assertNotIsInstance(bisect_result, BisectionResult) |
| 90 | + |
| 91 | + @torch._inductor.config.patch(force_disable_caches=True) |
| 92 | + def test_assert_should_throw(self): |
| 93 | + device = "cpu" |
| 94 | + self._run_assert_should_throw(device) |
| 95 | + self._run_assert_inline_expression_should_throw(device) |
| 96 | + |
| 97 | + @torch._inductor.config.patch(force_disable_caches=True) |
| 98 | + def test_assert_should_not_throw(self): |
| 99 | + device = "cpu" |
| 100 | + self._run_assert_should_not_throw(device) |
| 101 | + self._run_assert_inline_expression_should_not_throw(device) |
| 102 | + |
| 103 | + @torch._inductor.config.patch(force_disable_caches=True, cpp_wrapper=True) |
| 104 | + def test_assert_should_throw_cpp_wrapper(self): |
| 105 | + device = "cpu" |
| 106 | + self._run_assert_should_throw(device) |
| 107 | + self._run_assert_inline_expression_should_throw(device) |
| 108 | + |
| 109 | + @torch._inductor.config.patch(force_disable_caches=True, cpp_wrapper=True) |
| 110 | + def test_assert_should_not_throw_cpp_wrapper(self): |
| 111 | + device = "cpu" |
| 112 | + self._run_assert_should_not_throw(device) |
| 113 | + self._run_assert_inline_expression_should_not_throw(device) |
| 114 | + |
| 115 | + @requires_cuda_and_triton |
| 116 | + @skipIfRocm |
| 117 | + @torch._inductor.config.patch(force_disable_caches=True) |
| 118 | + def test_assert_fusion(self): |
| 119 | + torch._logging.set_logs(inductor_metrics=True) |
| 120 | + |
| 121 | + def func(): |
| 122 | + a = torch.tensor([1.0, 2.0], device="cuda") |
| 123 | + result = torch.all(a > 0) |
| 124 | + assert result, "should throw" |
| 125 | + |
| 126 | + torch._dynamo.reset() |
| 127 | + f_c = torch.compile(func, backend="inductor") |
| 128 | + metrics.reset() |
| 129 | + self.assertEqual(metrics.generated_kernel_count, 0) |
| 130 | + f_c() |
| 131 | + self.assertEqual(metrics.generated_kernel_count, 1) |
| 132 | + torch._logging.set_logs() |
| 133 | + |
| 134 | + @requires_cuda_and_triton |
| 135 | + @skipIfRocm |
| 136 | + @torch._inductor.config.patch(force_disable_caches=True) |
| 137 | + def test_run_assert_triton(self): |
| 138 | + should_throw = """ |
| 139 | +import torch |
| 140 | +import torch._dynamo |
| 141 | +
|
| 142 | +def func_should_throw(): |
| 143 | + a = torch.tensor([1.0, -2.0], device='cuda') |
| 144 | + result = torch.all(a > 0) |
| 145 | + assert result, "should throw" |
| 146 | +
|
| 147 | +def test_fn(): |
| 148 | + torch._dynamo.reset() |
| 149 | + f_c = torch.compile(func_should_throw, backend="inductor") |
| 150 | +
|
| 151 | + try: |
| 152 | + f_c() |
| 153 | + torch.cuda.synchronize() |
| 154 | + return False |
| 155 | + except Exception as e: |
| 156 | + return True |
| 157 | +
|
| 158 | +result = test_fn() |
| 159 | +print(f"Test result: {result}") |
| 160 | +""" |
| 161 | + |
| 162 | + should_not_throw = """ |
| 163 | +import torch |
| 164 | +import torch._dynamo |
| 165 | +
|
| 166 | +def func_should_not_throw(): |
| 167 | + a = torch.tensor([1.0, 2.0], device='cuda') |
| 168 | + result = torch.all(a > 0) |
| 169 | + assert result, "should throw" |
| 170 | +
|
| 171 | +def test_fn(): |
| 172 | + torch._dynamo.reset() |
| 173 | + f_c = torch.compile(func_should_not_throw, backend="inductor") |
| 174 | +
|
| 175 | + try: |
| 176 | + f_c() |
| 177 | + torch.cuda.synchronize() |
| 178 | + return True |
| 179 | + except Exception as e: |
| 180 | + return False |
| 181 | +
|
| 182 | +result = test_fn() |
| 183 | +print(f"Test result: {result}") |
| 184 | +""" |
| 185 | + for script in [should_not_throw, should_throw]: |
| 186 | + p = subprocess.run( |
| 187 | + [sys.executable, "-c", script], |
| 188 | + cwd=os.path.dirname(os.path.realpath(__file__)), |
| 189 | + capture_output=True, |
| 190 | + text=True, |
| 191 | + ) |
| 192 | + |
| 193 | + output = p.stdout + "\n" + p.stderr |
| 194 | + |
| 195 | + self.assertIn("Test result: True", output) |
| 196 | + |
| 197 | + if p.returncode != 0: |
| 198 | + self.fail( |
| 199 | + f"Subprocess failed with return code {p.returncode}. Output: {output}" |
| 200 | + ) |
| 201 | + |
| 202 | + |
| 203 | +if __name__ == "__main__": |
| 204 | + run_tests() |
0 commit comments