Commit afabfc4

Update on "Refactor layout constraint selection logic"
Significant cleanup of the code (it has gotten bad over time). This PR:
- does some deduplication
- cleans up the "lazy registration path", which seems to never get hit anymore...

Test Plan:
- tests + CI

[ghstack-poisoned]
2 parents b415282 + 4e80e2b commit afabfc4
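
Reviewer note (not part of the diff): the net effect is that an op's layout constraint is now read from a single tag, chosen by priority, instead of ad-hoc checks spread across graph.py and lowering.py. Below is a minimal sketch of how an op author opts in, using the same lib.define(..., tags=...) pattern as the test changes in this commit; the "mylib" namespace and "my_add" op are hypothetical.

import torch

lib = torch.library.Library("mylib", "FRAGMENT")  # hypothetical test namespace
lib.define(
    "my_add(Tensor x, Tensor y) -> Tensor",
    # Inductor now picks the layout constraint straight from this tag.
    tags=[torch._C.Tag.needs_exact_strides],
)

@torch.library.impl(lib, "my_add", "CompositeExplicitAutograd")
def my_add(x, y):
    return x + y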

File tree

7 files changed: +52 -61 lines changed

- test/inductor/test_triton_kernels.py
- torch/_inductor/config.py
- torch/_inductor/graph.py
- torch/_inductor/lowering.py
- torch/_library/custom_ops.py
- torch/_library/utils.py
- torch/fx/experimental/proxy_tensor.py

test/inductor/test_triton_kernels.py

Lines changed: 2 additions & 0 deletions
@@ -3451,6 +3451,7 @@ def impl2(x):
 
         lib.define(
             "add_op(Tensor x, Tensor y) -> Tensor",
+            tags=[torch._C.Tag.needs_exact_strides],
         )
 
         def impl(x, y):
@@ -3464,6 +3465,7 @@ def meta(x, y):
 
         lib.define(
             "add_out_op(Tensor x, Tensor y, Tensor(a!) out) -> ()",
+            tags=[torch._C.Tag.needs_exact_strides],
         )
 
         def impl_out(x, y, out):

torch/_inductor/config.py

Lines changed: 1 addition & 1 deletion
@@ -126,7 +126,7 @@ def prologue_fusion_enabled() -> bool:
 # If the custom op does not have a layout constraint tag already
 # then we assume the following applies.
 custom_op_default_layout_constraint: Literal[
-    "needs_fixed_stride_order", "flexible_layout"
+    "needs_exact_strides", "needs_fixed_stride_order", "flexible_layout"
 ] = "needs_fixed_stride_order"
 
 # The default layout constraint for user-defined triton kernels.
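
Reviewer note (not part of the diff): since "needs_exact_strides" is now a legal value here, the treatment of untagged custom ops can be tightened per run; a minimal sketch of setting the config, assuming the usual module-attribute style.

import torch._inductor.config as inductor_config

# Untagged custom ops are then lowered as if tagged needs_exact_strides
# (the shipped default stays "needs_fixed_stride_order").
inductor_config.custom_op_default_layout_constraint = "needs_exact_strides"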

torch/_inductor/graph.py

Lines changed: 21 additions & 27 deletions
@@ -80,11 +80,13 @@
     FALLBACK_ALLOW_LIST,
     fallback_handler,
     fallback_node_due_to_unsupported_type,
+    get_layout_constraint_tag,
     lowerings,
     make_fallback,
     maybe_layout_constraints,
     needs_realized_inputs,
     require_contiguous,
+    tag_to_layout_constraint,
     unsupported_output_tensor,
 )
 from .runtime import autotune_cache
@@ -1149,34 +1151,26 @@ def call_function(self, target: Callable, args: Any, kwargs: dict[str, Any]) ->
                 error.operator_str(target, args, kwargs),
             )
 
-            # use contiguous unless the (custom) op asks something else
-            # explicitly
-            if torch._library.utils.needs_exact_strides(target):
-                decided_constraint = constrain_to_fake_tensors
-            elif torch._C.Tag.needs_fixed_stride_order in target.tags:
-                decided_constraint = constrain_to_fx_strides  # type: ignore[assignment]
-            elif torch._C.Tag.flexible_layout in target.tags:
-                decided_constraint = None  # type: ignore[assignment]
-            else:
-                # If there are no tags, we do different things depending on
-                # if it's a builtin ATen/prim ops or custom ops.
-                # For ATen ops, we require_contiguous to fix https://github.com/pytorch/pytorch/issues/140452
-                # For custom ops, we constrain_to_fx_strides to maintain the
-                # behavior of PyTorch 2.5: https://github.com/pytorch/pytorch/issues/148356
+            tag = get_layout_constraint_tag(target, with_default=False)
+            if (
+                tag is None
+                and torch._library.utils.is_builtin(target)
+                and self.is_backward
+            ):
+                # for implicit fallback ATen ops during backward, if there
+                # is no layout constraint tag, we conservatively require contiguous
+                # input since some eager kernels do not
+                # support non-contiguous inputs. Otherwise they may silently cause
+                # accuracy problems. Check https://github.com/pytorch/pytorch/issues/140452
+                # We only do this For ATen ops and for backward.
                 #
-                # For ATen ops, only apply the constraint for backward
-                # ops since fwd ops should work for any strides.
-                if torch._library.utils.is_builtin(target) and self.is_backward:
-                    decided_constraint = require_contiguous  # type: ignore[assignment]
-                else:
-                    # maybe_layout_constraints will decide the layout constraint for the custom op
-                    # lazily
-                    decided_constraint = None  # type: ignore[assignment]
-
-            # for implicitly fallback ops, we conservatively requires
-            # contiguous input since some eager kernels does not
-            # support non-contiguous inputs. They may silently cause
-            # accuracy problems. Check https://github.com/pytorch/pytorch/issues/140452
+                # TODO: should really switch to "needs_fixed_stride" constraint on these
+                # and identify them one by one.
+                decided_constraint = require_contiguous  # type: ignore[assignment]
+            else:
+                tag = get_layout_constraint_tag(target, with_default=True)
+                decided_constraint = tag_to_layout_constraint(tag)
+
             make_fallback(target, layout_constraint=decided_constraint)
 
         elif get_decompositions([target]):
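
Reviewer note (not part of the diff): the branch above boils down to one conservative special case plus a tag-to-constraint lookup. A paraphrased sketch, not the literal implementation, using the helpers imported from torch._inductor.lowering:

import torch
from torch._inductor.lowering import (
    get_layout_constraint_tag,
    require_contiguous,
    tag_to_layout_constraint,
)


def pick_fallback_constraint(target: torch._ops.OpOverload, is_backward: bool):
    # Untagged builtin ATen op in a backward graph: stay conservative.
    if (
        get_layout_constraint_tag(target, with_default=False) is None
        and torch._library.utils.is_builtin(target)
        and is_backward
    ):
        return require_contiguous
    # Everything else: default the tag, then map it to a constraint.
    tag = get_layout_constraint_tag(target, with_default=True)
    return tag_to_layout_constraint(tag)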

torch/_inductor/lowering.py

Lines changed: 24 additions & 21 deletions
@@ -156,37 +156,40 @@ def maybe_layout_constraints(fn: Callable[..., Any]) -> Optional[Callable[..., A
         return None
     if fn in _maybe_layout_constraints:
         return _maybe_layout_constraints[fn]
-    # OpOverload with custom lowerings override tag-based layout constraints
-    if fn in lowerings:
-        _maybe_layout_constraints[fn] = None
-        return None
-    # We lazily register tag-based layout constraints.
-
-    def handle_layout_constraint_tag(tag):
-        if tag is torch._C.Tag.needs_fixed_stride_order:
-            _maybe_layout_constraints[fn] = constrain_to_fx_strides
-            return _maybe_layout_constraints[fn]
-        elif tag is torch._C.Tag.flexible_layout:
-            _maybe_layout_constraints[fn] = None
-            return None
-        else:
-            raise AssertionError(f"Unknown layout constraint tag: {tag}")
+    return None
+
 
-    tag = get_layout_constraint_tag(fn)
-    return handle_layout_constraint_tag(tag)
+tags_by_priority = [
+    torch._C.Tag.needs_exact_strides,
+    torch._C.Tag.needs_fixed_stride_order,
+    torch._C.Tag.flexible_layout,
+]
 
 
-def get_layout_constraint_tag(fn):
+def get_layout_constraint_tag(fn, *, with_default=True):
     tags_by_priority = [
+        torch._C.Tag.needs_exact_strides,
         torch._C.Tag.needs_fixed_stride_order,
         torch._C.Tag.flexible_layout,
     ]
     for tag in tags_by_priority:
         if tag in fn.tags:
             return tag
-    if torch._library.utils.is_builtin(fn):
-        return torch._C.Tag.flexible_layout
-    return getattr(torch._C.Tag, config.custom_op_default_layout_constraint)
+    if with_default:
+        if torch._library.utils.is_builtin(fn):
+            return torch._C.Tag.flexible_layout
+        return getattr(torch._C.Tag, config.custom_op_default_layout_constraint)
+    return None
+
+
+def tag_to_layout_constraint(tag):
+    if tag == torch._C.Tag.needs_exact_strides:
+        return constrain_to_fake_tensors
+    if tag == torch._C.Tag.needs_fixed_stride_order:
+        return constrain_to_fx_strides
+    if tag == torch._C.Tag.flexible_layout:
+        return None
+    raise AssertionError(f"Unknown layout constraint tag: {tag}")
 
 
 def assert_nyi(cond, msg):
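
Reviewer note (not part of the diff): a minimal sketch of how the split helpers compose, assuming torch.ops.aten.mm.default carries no layout-constraint tag (any untagged builtin behaves the same).

import torch
from torch._inductor.lowering import get_layout_constraint_tag, tag_to_layout_constraint

op = torch.ops.aten.mm.default  # assumed to be an untagged builtin

# with_default=False lets callers (e.g. graph.py above) distinguish "no tag at all".
print(get_layout_constraint_tag(op, with_default=False))  # expected: None

# with_default=True applies the fallback: flexible_layout for builtins,
# config.custom_op_default_layout_constraint for custom ops.
tag = get_layout_constraint_tag(op, with_default=True)
print(tag, tag_to_layout_constraint(tag))  # expected: flexible_layout and None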

torch/_library/custom_ops.py

Lines changed: 1 addition & 1 deletion
@@ -615,7 +615,7 @@ def _register_to_dispatcher(self, tags: Sequence[_C.Tag]) -> None:
 
         lib.define(
             schema_str,
-            tags=[_C.Tag.pt2_compliant_tag, *tags],
+            tags=[_C.Tag.pt2_compliant_tag, _C.Tag.needs_fixed_stride_order, *tags],
         )
         self._opoverload = utils.lookup_op(self._qualname)
 
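
Reviewer note (not part of the diff): every op registered through torch.library.custom_op now carries needs_fixed_stride_order out of the box, which tag_to_layout_constraint above maps to constrain_to_fx_strides. A minimal sketch with a hypothetical op to illustrate checking the tag:

import torch


@torch.library.custom_op("mylib::scale", mutates_args=())
def scale(x: torch.Tensor, factor: float) -> torch.Tensor:
    return x * factor


op = torch.ops.mylib.scale.default
# The dispatcher registration now adds needs_fixed_stride_order by default.
print(torch._C.Tag.needs_fixed_stride_order in op.tags)  # expected: True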

torch/_library/utils.py

Lines changed: 0 additions & 10 deletions
@@ -215,16 +215,6 @@ def zip_schema(
     return
 
 
-def needs_exact_strides(op: torch._ops.OpOverload):
-    if torch._C.Tag.needs_exact_strides in op.tags:
-        return True
-    if torch._C.Tag.flexible_layout in op.tags:
-        return False
-    if torch._C.Tag.needs_fixed_stride_order in op.tags:
-        return False
-    return not is_builtin(op)
-
-
 def hop_schema_from_fx_node(node):
     from torchgen.gen_schema_utils import FunctionSchemaGen
 
torch/fx/experimental/proxy_tensor.py

Lines changed: 3 additions & 1 deletion
@@ -1169,7 +1169,9 @@ def _should_save_eager_input_vals(
             f"propagate the FakeTensor vals. Please file an issue."
         )
     if isinstance(target, torch._ops.OpOverload):
-        return torch._library.utils.needs_exact_strides(target)
+        from torch._inductor.lowering import get_layout_constraint_tag
+
+        return get_layout_constraint_tag(target)
     return False
 
 