Skip to content

Commit 3c622fb

Browse files
peterbell10pytorchmergebot
authored andcommitted
[inductor] Fix var_to_range in IndexPropagation (#130984)
The current code assumes that indirect variables will be created by the same `IndexPropagation` instance, however that isn't true in the case of masked sub-blocks where we take in variables from the parent block. This fixes the issue by moving the var range information up to the `LoopBody` object where it can be shared by all sub-blocks. Pull Request resolved: #130984 Approved by: https://github.com/lezcano
1 parent b556d31 commit 3c622fb

File tree

2 files changed

+23
-8
lines changed

2 files changed

+23
-8
lines changed

torch/_inductor/index_propagation.py

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -189,7 +189,12 @@ class IndexPropagation:
189189
190190
"""
191191

192-
def __init__(self, inner: Any, iter_ranges: Dict[sympy.Symbol, sympy.Expr]):
192+
def __init__(
193+
self,
194+
inner: Any,
195+
iter_ranges: Dict[sympy.Symbol, sympy.Expr],
196+
indirect_var_ranges: Dict[sympy.Symbol, sympy.Expr],
197+
):
193198
self._inner = inner
194199
self.shape_env = V.graph.sizevars.shape_env
195200

@@ -199,6 +204,9 @@ def __init__(self, inner: Any, iter_ranges: Dict[sympy.Symbol, sympy.Expr]):
199204
self.var_to_range = tuple(
200205
itertools.chain(self.shape_env.var_to_range.items(), var_to_range.items())
201206
)
207+
# NOTE: this is intentionally kept as a reference so the caller can
208+
# update it in-place
209+
self.indirect_var_ranges = indirect_var_ranges
202210

203211
axioms = []
204212
for x, s in iter_ranges.items():
@@ -306,10 +314,17 @@ def statically_true(self, e):
306314
to perform wrap_expr and in CSEProxy.check_bounds to elide upper / lower bounds also
307315
for indirect_indexing
308316
"""
317+
var_to_range = (
318+
*self.var_to_range,
319+
*(
320+
(k, ValueRanges(0, upper_bound(v) - 1))
321+
for k, v in self.indirect_var_ranges.items()
322+
),
323+
)
309324
evaluated = self.shape_env._maybe_evaluate_static(
310325
e,
311326
axioms=self.axioms,
312-
var_to_range=self.var_to_range,
327+
var_to_range=var_to_range,
313328
)
314329
return bool(evaluated)
315330

@@ -351,9 +366,4 @@ def wrap_expr(expr):
351366
indirect_var = self.fallback(
352367
"indirect_indexing", (index, size, check), {}
353368
).value
354-
assert (
355-
indirect_var not in self.var_to_range
356-
), f"{indirect_var} should've been created in the fallback."
357-
indirect_range = (indirect_var, ValueRanges(0, upper_bound(size) - 1))
358-
self.var_to_range = self.var_to_range + (indirect_range,)
359369
return indirect_var

torch/_inductor/ir.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6479,6 +6479,7 @@ def __init__(self, fn, args, var_ranges):
64796479
self.submodules = {"get_index": self.get_index}
64806480
self.subblocks = {}
64816481
self.indirect_vars = []
6482+
self.indirect_var_ranges: Dict[sympy.Symbol, sympy.Expr] = {}
64826483
self.root_block = LoopBodyBlock(self, fn, args)
64836484
self.indexing = None
64846485

@@ -6531,7 +6532,9 @@ def add_submodule(self, block, prefix):
65316532

65326533
def add_indirect(self, size):
65336534
var = sympy_index_symbol_with_prefix(SymT.INDIRECT, len(self.indirect_vars))
6535+
assert var not in self.indirect_var_ranges
65346536
self.indirect_vars.append(var)
6537+
self.indirect_var_ranges[var] = size
65356538
return var
65366539

65376540
def replace_indirect(self, old, new):
@@ -6712,7 +6715,9 @@ def output(result):
67126715
CaptureIndexing(proxy_ops), self.body.var_ranges
67136716
)
67146717
if config.constant_and_index_propagation:
6715-
handler = IndexPropagation(handler, self.body.var_ranges)
6718+
handler = IndexPropagation(
6719+
handler, self.body.var_ranges, self.body.indirect_var_ranges
6720+
)
67166721

67176722
with V.set_ops_handler(handler):
67186723
# This indirection is just a cute way to get IndexPropagation to

0 commit comments

Comments
 (0)