[annotation][export] Add metadata hook for all nodes created in runtime_assert pass #173970
Conversation
🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/173970

Note: Links to docs will display an error until the docs builds have been completed.

❗ 1 Active SEV — there is 1 currently active SEV. If your PR is affected, please view it below.

❌ 1 New Failure, 4 Unrelated Failures — as of commit 32bd86a with merge base 58a92ff:

- NEW FAILURE — the following job has failed:
- FLAKY — the following job failed but was likely due to flakiness present on trunk:
- BROKEN TRUNK — the following jobs failed but were present on the merge base: 👉 Rebase onto the `viable/strict` branch to avoid these failures.

This comment was automatically generated by Dr. CI and updates every 15 minutes.
openreg test failed
@pytorchbot merge -i
Merge started. Your change will be merged while ignoring the following 5 checks:

- pull / linux-docs / build-docs-python-false
- pull / linux-jammy-py3.10-clang18-asan / test (openreg, 1, 1, linux.4xlarge)
- inductor / unit-test / inductor-halide-test / test (inductor-halide, 1, 1, linux.12xlarge)
- trunk / linux-jammy-cuda13.0-py3.10-gcc11 / test (default, 3, 5, linux.g6.4xlarge.experimental.nvidia.gpu)
- trunk / linux-jammy-cuda12.8-py3.10-gcc11 / test (default, 3, 5, linux.g6.4xlarge.experimental.nvidia.gpu)

Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
[annotation][export] Add metadata hook for all nodes created in runtime_assert pass (pytorch#173970)

Reland of pytorch#169497

Fixes the case below: all nodes, including runtime assertion nodes, should have annotations.

```python
import torch


class Foo(torch.nn.Module):
    # check sym ops only get computed once
    @torch._dynamo.disable()
    def forward(self, x, y):
        if (
            x.shape[0] ** 2 - y.shape[0] ** 2 >= 4  # 16
            and x.shape[0] ** 2 - y.shape[0] ** 2 <= 20
            and x.shape[0] ** 2 - y.shape[0] ** 2 != 15
        ):
            return x * 2, y * 2


inputs = (torch.randn(5), torch.randn(3))
shapes = {"x": (torch.export.Dim("dx"),), "y": (torch.export.Dim("dy"),)}
with torch.fx.traceback.preserve_node_meta():
    ep = torch.export.export(
        Foo(),
        inputs,
        dynamic_shapes=shapes,
        prefer_deferred_runtime_asserts_over_guards=True,
    )
ep.module().print_readable()
```

```
class GraphModule(torch.nn.Module):
    def forward(self, x, y):
        x: "f32[s77]"; y: "f32[s17]";
        x, y, = fx_pytree.tree_flatten_spec(([x, y], {}), self._in_spec)

        # No stacktrace found for following nodes
        _guards_fn = self._guards_fn(x, y);  _guards_fn = None

        # Annotation: {'_torchdynamo_disable': True, '_torchdynamo_disable_recursive': True, '_torchdynamo_disable_method': 'dispatch_trace'} No stacktrace found for following nodes
        sym_size_int_2: "Sym(s77)" = torch.ops.aten.sym_size.int(x, 0)
        sym_size_int_3: "Sym(s17)" = torch.ops.aten.sym_size.int(y, 0)

        # Annotation: {'_torchdynamo_disable': True, '_torchdynamo_disable_recursive': True, '_torchdynamo_disable_method': 'forward'} File: /data/users/shangdiy/pytorch/test.py:12 in forward, code: return x * 2, y * 2
        mul: "f32[s77]" = torch.ops.aten.mul.Tensor(x, 2);  x = None
        pow_7: "Sym(s17**2)" = sym_size_int_3 ** 2;  sym_size_int_3 = None
        add: "Sym(s17**2 + 4)" = 4 + pow_7
        pow_8: "Sym(s77**2)" = sym_size_int_2 ** 2;  sym_size_int_2 = None
        le_1: "Sym(s17**2 + 4 <= s77**2)" = add <= pow_8;  add = None
        _assert_scalar_default = torch.ops.aten._assert_scalar.default(le_1, "Runtime assertion failed for expression s17**2 + 4 <= s77**2 on node 'le_1'");  le_1 = _assert_scalar_default = None
        add_1: "Sym(s17**2 + 20)" = 20 + pow_7
        le_2: "Sym(s77**2 <= s17**2 + 20)" = pow_8 <= add_1;  add_1 = None
        _assert_scalar_default_1 = torch.ops.aten._assert_scalar.default(le_2, "Runtime assertion failed for expression s77**2 <= s17**2 + 20 on node 'le_2'");  le_2 = _assert_scalar_default_1 = None
        add_2: "Sym(s17**2 + 15)" = 15 + pow_7;  pow_7 = None
        eq: "Sym(Eq(s77**2, s17**2 + 15))" = pow_8 == add_2;  pow_8 = add_2 = None
        sym_not: "Sym(Ne(s77**2, s17**2 + 15))" = torch.sym_not(eq);  eq = None
        _assert_scalar_default_2 = torch.ops.aten._assert_scalar.default(sym_not, "Runtime assertion failed for expression Ne(s77**2, s17**2 + 15) on node 'sym_not'");  sym_not = _assert_scalar_default_2 = None
        mul_1: "f32[s17]" = torch.ops.aten.mul.Tensor(y, 2);  y = None
        return pytree.tree_unflatten((mul, mul_1), self._out_spec)
```

```
python test/dynamo/test_higher_order_ops.py -k test_concat_unbacked_shape_tensor -k test_tensor_and_unbacked_symbol_closure -k test_tensor_with_unbacked_shape_closure
```

Pull Request resolved: pytorch#173970
Approved by: https://github.com/pianpwk
Reland of #169497
Fixes the case below: all nodes, including runtime assertion nodes, should have annotations.
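The idea of a metadata hook can be sketched in plain Python: the graph registers a callback that runs for every node it creates, so a pass that inserts new nodes (such as runtime assertions) cannot forget to annotate them. This is a minimal illustration only; the `Node`, `Graph`, and hook names below are hypothetical and do not reflect PyTorch's actual internal API.

```python
# Minimal sketch of a node-creation metadata hook (hypothetical classes,
# not PyTorch's real torch.fx API).

class Node:
    def __init__(self, name):
        self.name = name
        self.meta = {}  # per-node metadata dict, like fx Node.meta


class Graph:
    def __init__(self):
        self.nodes = []
        self._creation_hooks = []

    def register_creation_hook(self, hook):
        # Every hook runs once for each node created afterwards.
        self._creation_hooks.append(hook)

    def create_node(self, name):
        node = Node(name)
        for hook in self._creation_hooks:
            hook(node)
        self.nodes.append(node)
        return node


g = Graph()
# A pass installs a hook so every node it creates carries an annotation.
g.register_creation_hook(
    lambda n: n.meta.setdefault("custom", {}).update({"pass": "runtime_assert"})
)
assert_node = g.create_node("_assert_scalar_default")
print(assert_node.meta)  # → {'custom': {'pass': 'runtime_assert'}}
```

With this pattern, annotation happens at node-creation time rather than as a separate fix-up step, which is why assertion nodes inserted by the pass show up with `Annotation:` metadata in the printed graph above.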
cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @kadeng @chauhang @amjames @Lucaskabela @jataylo