
Commit 44186a0

isuruf authored and pytorchmergebot committed
1 parent 29ca448 commit 44186a0

20 files changed (+589 −579 lines)
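
Note: the diffs below all follow one theme — Inductor's sympy-expression printers now emit `//` and `%` with precedence-aware parenthesization, so expected strings across the test suite drop redundant parentheses, and guards print as evaluable Python relations instead of `Eq`/`Ne` calls. As a quick plain-Python sanity check (illustrative, not part of the commit), the removed parentheses carry no meaning:

    # The parentheses dropped in the expected strings below do not change values.
    s0 = 7
    assert (s0 // 2) == s0 // 2                        # bare '//' needs no parens
    assert ((2 * s0) % 3) == 2 * s0 % 3                # '*' and '%' share precedence, left-assoc
    assert (((2 * s0) % 3) == 0) == (2 * s0 % 3 == 0)  # '%' binds tighter than '=='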

test/distributed/test_inductor_collectives.py

Lines changed: 2 additions & 2 deletions
@@ -580,8 +580,8 @@ def example(inp, *, tag, ranks, group_size):
             .check_regex(
                 "torch.ops._c10d_functional.all_to_all_single.default\\("
                 "arg\\d+_\\d+, "
-                "\\[\\(s\\d+ // \\d\\), \\(s\\d+ // \\d\\)\\], "
-                "\\[\\(s\\d+ // \\d\\), \\(s\\d+ // \\d\\)\\]"
+                "\\[s\\d+ // \\d, s\\d+ // \\d\\], "
+                "\\[s\\d+ // \\d, s\\d+ // \\d\\]"
             )
             .run(code)
         )
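
A quick illustration (assumed example output, not from the test run): the updated pattern matches the new unparenthesized split-size printout and rejects the old form. The test doubles each backslash because its pattern lives inside a regular Python string; a raw string is used here instead.

    import re

    pat = re.compile(r"\[s\d+ // \d, s\d+ // \d\]")
    new_out = "torch.ops._c10d_functional.all_to_all_single.default(arg0_1, [s0 // 2, s0 // 2], [s0 // 2, s0 // 2])"
    old_out = "torch.ops._c10d_functional.all_to_all_single.default(arg0_1, [(s0 // 2), (s0 // 2)], [(s0 // 2), (s0 // 2)])"
    assert pat.search(new_out) is not None  # new printer output matches
    assert pat.search(old_out) is None      # old parenthesized form does not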

test/dynamo/test_export.py

Lines changed: 1 addition & 1 deletion
@@ -3570,7 +3570,7 @@ def forward(self, pred, x):
             "cast_symbool_to_symint_guardless(L['pred']) == 1",
         ]
         false_guard_code = [
-            "Ne(cast_symbool_to_symint_guardless(L['pred']), 1)",
+            "cast_symbool_to_symint_guardless(L['pred']) != 1",
         ]
         test_symbool_guards(
             f,
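
The old expected guard used sympy's default rendering of `Ne`, which is a function call and not directly evaluable as Python guard source; the new one is an ordinary relational expression. A minimal sympy sketch of the difference (assuming only stock sympy; the symbol name is illustrative):

    import sympy

    p = sympy.Symbol("p", integer=True)
    guard = sympy.Ne(p, 1)
    print(sympy.sstr(guard))  # -> Ne(p, 1): a call, not valid guard source
    guard_src = f"{sympy.sstr(guard.lhs)} != {sympy.sstr(guard.rhs)}"
    print(guard_src)          # -> p != 1: eval()s against a scope providing p
    assert eval(guard_src, {"p": 0}) is True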

test/dynamo/test_logging.py

Lines changed: 1 addition & 1 deletion
@@ -668,7 +668,7 @@ def f(x, y, z):
             """\
 +- LAMBDA_GUARD: L['x'].size()[0] == 2*L['z'].size()[0] # return x + torch.cat([y, z]) # #:# in # #:# in #
 +- LAMBDA_GUARD: L['y'].size()[0] == L['z'].size()[0] # duck sizing added this equality because these variables had the same size 3 (to avoid this specialization, set torch.fx.experimental._config.use_duck_shape = False)
-+- LAMBDA_GUARD: Eq(Mod(2*L['z'].size()[0], 3), 0) # if x.size(0) % 3 == 0: # #:# in # #:# in #
++- LAMBDA_GUARD: ((2*L['z'].size()[0]) % 3) == 0 # if x.size(0) % 3 == 0: # #:# in # #:# in #
 +- LAMBDA_GUARD: 2 <= L['z'].size()[0] # return x + torch.cat([y, z]) # #:# in # (user code shown is first use of this value--the guard itself is not due user code but due to 0/1 specialization in the framework; to avoid specialization try torch._dynamo.mark_unbacked(tensor, dim))""", # noqa: B950
         )
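
Same pattern for the divisibility guard: `Eq(Mod(...), 0)` is sympy call syntax, while the new text is plain Python that evaluates directly. A small sketch (the symbol name `n` is illustrative):

    import sympy

    n = sympy.Symbol("n", integer=True)
    print(sympy.sstr(sympy.Eq(sympy.Mod(2 * n, 3), 0)))  # -> Eq(Mod(2*n, 3), 0)
    guard_src = "((2*n) % 3) == 0"                        # the new printed form
    assert eval(guard_src, {"n": 3}) and not eval(guard_src, {"n": 4})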

test/dynamo/test_misc.py

Lines changed: 2 additions & 2 deletions
@@ -10457,7 +10457,7 @@ def test_shape_env_equal_evaluate_expr_divisible(self):
 ShapeEnv not equal: field values don't match:
 
 ==> axioms: values don't match.
->  Left: {0 < Mod(s0, 3): False, 0 <= Mod(s0, 3): True, Eq(0, Mod(s0, 3)): True, Eq(Mod(s0, 3), 0): True, Mod(s0, 3) < 0: False, Mod(s0, 3) <= 0: True, Ne(0, Mod(s0, 3)): False, Ne(Mod(s0, 3), 0): False}
+>  Left: {(Mod(s0, 3)) < 0: False, (Mod(s0, 3)) <= 0: True, 0 < (Mod(s0, 3)): False, 0 <= (Mod(s0, 3)): True, Eq(0, Mod(s0, 3)): True, Eq(Mod(s0, 3), 0): True, Ne(0, Mod(s0, 3)): False, Ne(Mod(s0, 3), 0): False}
 >  Right: {}
 ==> divisible: values don't match.
 >  Left: {Mod(s0, 3)}
@@ -10576,7 +10576,7 @@ def test_shape_env_equal_runtime_assert(self):
 ShapeEnv not equal: field values don't match:
 
 ==> axioms: values don't match.
->  Left: {0 < PythonMod(u0, 3): False, 0 <= PythonMod(u0, 3): True, Eq(0, PythonMod(u0, 3)): True, Eq(PythonMod(u0, 3), 0): True, Ne(0, PythonMod(u0, 3)): False, Ne(PythonMod(u0, 3), 0): False, PythonMod(u0, 3) < 0: False, PythonMod(u0, 3) <= 0: True}
+>  Left: {(PythonMod(u0, 3)) < 0: False, (PythonMod(u0, 3)) <= 0: True, 0 < (PythonMod(u0, 3)): False, 0 <= (PythonMod(u0, 3)): True, Eq(0, PythonMod(u0, 3)): True, Eq(PythonMod(u0, 3), 0): True, Ne(0, PythonMod(u0, 3)): False, Ne(PythonMod(u0, 3), 0): False}
 >  Right: {}
 ==> deferred_runtime_asserts: values don't match.
 >  Left: {u0: [Eq(PythonMod(u0, 3), 0)]}
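
For orientation (not part of the diff): every entry in the `axioms` table above follows from the recorded divisibility fact `Mod(s0, 3) == 0`; only the printed key format, and hence the string-sorted order, changed. With `m` standing in for the Mod term:

    m = 0  # what Mod(s0, 3) equals once s0 is known divisible by 3
    assert (m < 0) is False and (m <= 0) is True
    assert (0 < m) is False and (0 <= m) is True
    assert (0 == m) is True and (m == 0) is True
    assert (0 != m) is False and (m != 0) is False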

test/export/test_export.py

Lines changed: 2 additions & 2 deletions
@@ -3259,9 +3259,9 @@ def forward(self, x, fixes):
             (torch.tensor(20),),
             fixes=[
                 # Could not guard on data-dependent expression Eq((u0//2), 0)
-                "torch._check(((i//2)) != 0)",
+                "torch._check((i // 2) != 0)",
                 # Could not guard on data-dependent expression Eq((u0//2), 1)
-                "torch._check(((i//2)) != 1)",
+                "torch._check((i // 2) != 1)",
             ],
         )
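
`torch._check` is the helper these fixes rely on: it asserts a condition at runtime and, during export, lets the symbolic shape machinery discharge the data-dependent guards named in the comments (`Eq(u0//2, 0)` / `Eq(u0//2, 1)`). In eager mode it behaves like a plain assertion; a minimal sketch (the concrete value mirrors `torch.tensor(20)` above):

    import torch

    i = 20
    torch._check((i // 2) != 0)  # passes; would raise if i // 2 == 0
    torch._check((i // 2) != 1)  # passes; would raise if i // 2 == 1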

test/inductor/test_cuda_repro.py

Lines changed: 5 additions & 5 deletions
@@ -1426,12 +1426,12 @@ def triton_poi_fused_add_reflection_pad2d_0(in_ptr0, in_ptr1, out_ptr0, xnumel,
     xoffset = tl.program_id(0) * XBLOCK
     xindex = xoffset + tl.arange(0, XBLOCK)[:]
     xmask = xindex < xnumel
-    x0 = xindex % 20
-    x1 = (xindex // 20) % 20
-    x2 = (xindex // 400)
+    x0 = (xindex % 20)
+    x1 = ((xindex // 20) % 20)
+    x2 = xindex // 400
     x3 = xindex
-    tmp0 = tl.load(in_ptr0 + (99 + ((-1)*(tl_math.abs((-9) + (tl_math.abs((-5) + x0))))) + ((-10)*(tl_math.abs((-9) + (tl_math.abs((-5) + x1))))) + (100*x2)), xmask, eviction_policy='evict_last')
-    tmp1 = tl.load(in_ptr1 + (99 + ((-1)*(tl_math.abs((-9) + (tl_math.abs((-5) + x0))))) + ((-10)*(tl_math.abs((-9) + (tl_math.abs((-5) + x1))))) + (100*x2)), xmask, eviction_policy='evict_last')
+    tmp0 = tl.load(in_ptr0 + (99 + ((-1)*tl_math.abs((-9) + tl_math.abs((-5) + x0))) + ((-10)*tl_math.abs((-9) + tl_math.abs((-5) + x1))) + 100*x2), xmask, eviction_policy='evict_last')
+    tmp1 = tl.load(in_ptr1 + (99 + ((-1)*tl_math.abs((-9) + tl_math.abs((-5) + x0))) + ((-10)*tl_math.abs((-9) + tl_math.abs((-5) + x1))) + 100*x2), xmask, eviction_policy='evict_last')
     tmp2 = tmp0 + tmp1
     tl.store(out_ptr0 + (x3), tmp2, xmask)""", # noqa: B950
 )
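
The reparenthesized index arithmetic is value-identical: `x0`, `x1`, `x2` still decompose the flat `xindex` over a 20×20 inner tile of the reflection-padded output. A plain-Python check of the decomposition (illustrative bounds):

    for xindex in range(800):  # any multiple of 400 works for this check
        x0 = xindex % 20            # fastest-moving coordinate
        x1 = (xindex // 20) % 20    # middle coordinate
        x2 = xindex // 400          # slowest-moving coordinate
        assert xindex == 400 * x2 + 20 * x1 + x0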

test/inductor/test_indexing.py

Lines changed: 37 additions & 2 deletions
@@ -20,7 +20,9 @@
 from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_CPU, HAS_GPU
 from torch.utils._sympy.functions import (
     FloorDiv,
+    Mod,
     ModularIndexing,
+    PythonMod,
     RoundDecimal,
     RoundToInt,
 )
@@ -236,7 +238,7 @@ def f(x):
         triton_code = run_and_get_triton_code(f, x)
         # Make sure the 2 load uses simpified indexing rather than something like
         # tl.load(in_ptr0 + ((5504*x1) + (x0 // 2)),
-        self.assertEqual(2, triton_code.count("tl.load(in_ptr0 + ((x2 // 2)),"))
+        self.assertEqual(2, triton_code.count("tl.load(in_ptr0 + (x2 // 2),"))
         if DO_PERF_TEST:
             ms = benchmarker.benchmark_gpu(lambda: f(x))
             print(f"{ms=:.03f}")
@@ -313,6 +315,39 @@ def test_print_round(self):
         self.assertExpectedInline(cexpr(expr), """std::lrint((1.0/2.0)*x)""")
         self.assertExpectedInline(texpr(expr), """libdevice.llrint((1/2)*x)""")
 
+    def test_print_mod(self):
+        x = sympy.Symbol("x", integer=True)
+        expr = Mod(x - 1, 2)
+        self.assertExpectedInline(pexpr(expr), """((-1) + x) % 2""")
+        self.assertExpectedInline(cexpr(expr), """((-1L) + x) % 2L""")
+        self.assertExpectedInline(texpr(expr), """((-1) + x) % 2""")
+
+        expr = (x - 10) % x
+        self.assertExpectedInline(pexpr(expr), """(-10) % x""")
+        self.assertExpectedInline(cexpr(expr), """(-10L) % x""")
+        self.assertExpectedInline(texpr(expr), """(-10) % x""")
+
+    def test_print_mod_index(self):
+        x = sympy.Symbol("x", integer=True)
+        ks = sympy.Symbol("ks", integer=True)
+        expr = ModularIndexing(x - 10, ks, ks)
+        self.assertExpectedInline(pexpr(expr), """((((-10) + x) // ks) % ks)""")
+        self.assertExpectedInline(
+            cexpr(expr),
+            """(static_cast<int64_t>(c10::div_floor_integer("""
+            """static_cast<int64_t>((-10L) + x), static_cast<int64_t>(ks))) % static_cast<int64_t>(ks))""",
+        )
+        self.assertExpectedInline(texpr(expr), """((((-10) + x) // ks) % ks)""")
+
+    def test_print_python_mod(self):
+        x = sympy.Symbol("x", integer=True)
+        expr = PythonMod(x - 10, x)
+        self.assertExpectedInline(pexpr(expr), """((-10) + x) % x""")
+        self.assertExpectedInline(cexpr(expr), """((-10L) + x) % x""")
+        self.assertExpectedInline(
+            texpr(expr), """triton_helpers.remainder_integer((-10) + x, x)"""
+        )
+
     @parametrize("ndigits", [-1, 0, 1])
     def test_print_round_decimal(self, ndigits):
         expr = RoundDecimal(sympy.Symbol("x", integer=True) / 2, ndigits)
@@ -330,7 +365,7 @@ def test_print_floor_div(self):
         s1 = sympy.Symbol("s1", integer=True)
         s2 = sympy.Symbol("s2", integer=True)
         expr = FloorDiv(s1, s2)
-        self.assertEqual(pexpr(expr), "(s1 // s2)")
+        self.assertEqual(pexpr(expr), "s1 // s2")
         self.assertEqual(
             cexpr(expr),
             "c10::div_floor_integer(static_cast<int64_t>(s1), static_cast<int64_t>(s2))",

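Background for the new `test_print_python_mod` expectations: Python's `%` takes the sign of the divisor, while C-family (and Triton) integer `%` takes the sign of the dividend, which is why the Triton printer routes `PythonMod` through `triton_helpers.remainder_integer` rather than a bare `%`. A plain-Python illustration of the mismatch:

    import math

    a, b = -10, 3
    assert a % b == 2                  # Python: result has the divisor's sign
    assert int(math.fmod(a, b)) == -1  # C-style truncating remainder
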
test/inductor/test_memory_planning.py

Lines changed: 3 additions & 5 deletions
@@ -58,13 +58,11 @@ def test_python_wrapper(self):
         result, code = run_and_get_cpp_code(compiled, *args)
 
         FileCheck().check(
-            "pool1 = empty_strided_"
-            + GPU_TYPE
-            + "(((4*s0*s1) + (align(4*(s0*s0))), ), (1, )"
+            "pool1 = empty_strided_" + GPU_TYPE + "((4*s0*s1 + align(4*s0*s0), ), (1, )"
         ).check_next(
             "buf0 = alloc_from_pool(pool1, 0, torch.float32, (s0, s0), (s0, 1))"
         ).check(
-            "buf1 = alloc_from_pool(pool1, align(4*(s0*s0)),"
+            "buf1 = alloc_from_pool(pool1, align(4*s0*s0),"
         ).run(
             code
         )
@@ -103,7 +101,7 @@ def test_aoti(self):
         )
 
         FileCheck().check(
-            "int64_t int_array_2[] = {24L + (align(12L*s0)), };"
+            "int64_t int_array_2[] = {24L + align(12L*s0), };"
         ).check_next("int64_t int_array_3[] = {1L, };").check_next(
             "AtenTensorHandle pool1_handle;"
         ).check_next(
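
A sketch of the pool arithmetic being checked (the `align` below is a stand-in with an assumed 64-byte alignment; the real helper lives in Inductor's generated wrapper code): buf0, a float32 `(s0, s0)` tensor, sits at pool offset 0, and buf1 starts at the next aligned offset, so the pool needs `4*s0*s1 + align(4*s0*s0)` bytes.

    def align(nbytes: int, alignment: int = 64) -> int:
        # round up to the next multiple of `alignment` (assumed value)
        return (nbytes + alignment - 1) // alignment * alignment

    s0, s1 = 10, 3
    buf0_bytes = 4 * s0 * s0         # float32 (s0, s0) at pool offset 0
    buf1_offset = align(buf0_bytes)  # buf1 starts on an aligned boundary
    pool_bytes = 4 * s0 * s1 + buf1_offset
    assert pool_bytes >= buf0_bytes + 4 * s0 * s1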

test/inductor/test_padding.py

Lines changed: 1 addition & 1 deletion
@@ -487,7 +487,7 @@ def test_LinearAndSoftmax_codegen(self, bias=True):
 
         # make sure the load for softmax is aligned
         self.assertTrue(
-            "tl.load(in_ptr0 + (r1 + (30528*x0))" in forward_wrapper,
+            "tl.load(in_ptr0 + (r1 + 30528*x0)" in forward_wrapper,
             f"forward_wrapper: {forward_wrapper}",
        )
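
The stride 30528 in the expected load is what makes the softmax rows aligned: assuming Inductor's usual 128-byte alignment target, each float32 row then starts on a 128-byte boundary. A one-line arithmetic check:

    assert (30528 * 4) % 128 == 0  # row stride in bytes is 128-byte aligned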

test/inductor/test_torchinductor.py

Lines changed: 3 additions & 3 deletions
@@ -12505,8 +12505,8 @@ def f(a, b):
         self.assertExpectedInline(
             "\n".join(lines),
             """\
-tmp0 = tl.load(in_ptr0 + (x1 + (512*x0) + (262144*r2)), rmask, eviction_policy='evict_last', other=0.0)
-tmp1 = tl.load(in_ptr1 + (x3 + (262144*r2)), rmask, eviction_policy='evict_first', other=0.0)""",
+tmp0 = tl.load(in_ptr0 + (x1 + 512*x0 + 262144*r2), rmask, eviction_policy='evict_last', other=0.0)
+tmp1 = tl.load(in_ptr1 + (x3 + 262144*r2), rmask, eviction_policy='evict_first', other=0.0)""",
         )
 
     @config.patch("triton.use_block_ptr", True)
@@ -12538,7 +12538,7 @@ def f(a, b):
         self.assertExpectedInline(
             "\n".join(lines),
             """\
-tmp0 = tl.reshape(tl.broadcast_to(tl.load(block_ptr0, boundary_check=[2], padding_option='zero', eviction_policy='evict_last')[:, None, :, :], [((511 + XBLOCK) // 512), ((1) * ((1) <= (((511 + XBLOCK) // 512))) + (((511 + XBLOCK) // 512)) * ((((511 + XBLOCK) // 512)) < (1))), ((512) * ((512) <= (XBLOCK)) + (XBLOCK) * ((XBLOCK) < (512))), RBLOCK]), [XBLOCK, RBLOCK])
+tmp0 = tl.reshape(tl.broadcast_to(tl.load(block_ptr0, boundary_check=[2], padding_option='zero', eviction_policy='evict_last')[:, None, :, :], [(511 + XBLOCK) // 512, ((1) * ((1) <= ((511 + XBLOCK) // 512)) + ((511 + XBLOCK) // 512) * (((511 + XBLOCK) // 512) < (1))), ((512) * ((512) <= (XBLOCK)) + (XBLOCK) * ((XBLOCK) < (512))), RBLOCK]), [XBLOCK, RBLOCK])
 tmp1 = tl.load(block_ptr1, boundary_check=[1], padding_option='zero', eviction_policy='evict_first')""", # noqa: B950 line too long
         )
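
The dense expression in the broadcast shape uses a branch-free min, `a*(a <= b) + b*(b < a)`, to clamp block dimensions (e.g. min(512, XBLOCK)); only the parentheses around the `//` terms changed here. A quick check that the idiom really computes min (booleans coerce to 0/1):

    def branchless_min(a: int, b: int) -> int:
        return a * (a <= b) + b * (b < a)

    for a in (1, 7, 512):
        for b in (1, 8, 511, 512, 1024):
            assert branchless_min(a, b) == min(a, b)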
