
[annotation][export] Add metadata hook for all nodes created in runtime_assert pass #173970

Closed
yushangdi wants to merge 1 commit into main from sy_export_annotation_2

Conversation

@yushangdi (Contributor) commented Jan 30, 2026

Reland of #169497

Fixes the case below: all nodes, including runtime assertion nodes, should have annotations.

```
import torch

class Foo(torch.nn.Module):
    # check sym ops only get computed once
    @torch._dynamo.disable()
    def forward(self, x, y):
        if (
            x.shape[0] ** 2 - y.shape[0] ** 2 >= 4  # 16
            and x.shape[0] ** 2 - y.shape[0] ** 2 <= 20
            and x.shape[0] ** 2 - y.shape[0] ** 2 != 15
        ):
            return x * 2, y * 2

inputs = (torch.randn(5), torch.randn(3))
shapes = {"x": (torch.export.Dim("dx"),), "y": (torch.export.Dim("dy"),)}
with torch.fx.traceback.preserve_node_meta():
    ep = torch.export.export(
        Foo(),
        inputs,
        dynamic_shapes=shapes,
        prefer_deferred_runtime_asserts_over_guards=True,
    )

ep.module().print_readable()
```


```
class GraphModule(torch.nn.Module):
    def forward(self, x, y):
        x: "f32[s77]"; y: "f32[s17]";

        x, y, = fx_pytree.tree_flatten_spec(([x, y], {}), self._in_spec)
        # No stacktrace found for following nodes
        _guards_fn = self._guards_fn(x, y);  _guards_fn = None

        # Annotation: {'_torchdynamo_disable': True, '_torchdynamo_disable_recursive': True, '_torchdynamo_disable_method': 'dispatch_trace'} No stacktrace found for following nodes
        sym_size_int_2: "Sym(s77)" = torch.ops.aten.sym_size.int(x, 0)
        sym_size_int_3: "Sym(s17)" = torch.ops.aten.sym_size.int(y, 0)

        # Annotation: {'_torchdynamo_disable': True, '_torchdynamo_disable_recursive': True, '_torchdynamo_disable_method': 'forward'} File: /data/users/shangdiy/pytorch/test.py:12 in forward, code: return x * 2, y * 2
        mul: "f32[s77]" = torch.ops.aten.mul.Tensor(x, 2);  x = None
        pow_7: "Sym(s17**2)" = sym_size_int_3 ** 2;  sym_size_int_3 = None
        add: "Sym(s17**2 + 4)" = 4 + pow_7
        pow_8: "Sym(s77**2)" = sym_size_int_2 ** 2;  sym_size_int_2 = None
        le_1: "Sym(s17**2 + 4 <= s77**2)" = add <= pow_8;  add = None
        _assert_scalar_default = torch.ops.aten._assert_scalar.default(le_1, "Runtime assertion failed for expression s17**2 + 4 <= s77**2 on node 'le_1'");  le_1 = _assert_scalar_default = None
        add_1: "Sym(s17**2 + 20)" = 20 + pow_7
        le_2: "Sym(s77**2 <= s17**2 + 20)" = pow_8 <= add_1;  add_1 = None
        _assert_scalar_default_1 = torch.ops.aten._assert_scalar.default(le_2, "Runtime assertion failed for expression s77**2 <= s17**2 + 20 on node 'le_2'");  le_2 = _assert_scalar_default_1 = None
        add_2: "Sym(s17**2 + 15)" = 15 + pow_7;  pow_7 = None
        eq: "Sym(Eq(s77**2, s17**2 + 15))" = pow_8 == add_2;  pow_8 = add_2 = None
        sym_not: "Sym(Ne(s77**2, s17**2 + 15))" = torch.sym_not(eq);  eq = None
        _assert_scalar_default_2 = torch.ops.aten._assert_scalar.default(sym_not, "Runtime assertion failed for expression Ne(s77**2, s17**2 + 15) on node 'sym_not'");  sym_not = _assert_scalar_default_2 = None
        mul_1: "f32[s17]" = torch.ops.aten.mul.Tensor(y, 2);  y = None
        return pytree.tree_unflatten((mul, mul_1), self._out_spec)
```

Test command:

```
python test/dynamo/test_higher_order_ops.py -k test_concat_unbacked_shape_tensor -k test_tensor_and_unbacked_symbol_closure -k test_tensor_with_unbacked_shape_closure
```
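The idea behind the fix can be illustrated with a minimal, framework-agnostic sketch (the `Node`, `Graph`, and `set_metadata_hook` names below are hypothetical stand-ins, not the torch.fx API): a pass installs a hook for its duration, and the hook stamps the current annotation onto every node created while it is active, so later-inserted runtime-assert nodes carry the same metadata as the nodes they guard.

```python
from contextlib import contextmanager


class Node:
    """Hypothetical stand-in for an FX node: a name plus a meta dict."""
    def __init__(self, name):
        self.name = name
        self.meta = {}


class Graph:
    """Toy graph whose create_node runs all registered metadata hooks."""
    def __init__(self):
        self.nodes = []
        self._meta_hooks = []

    def create_node(self, name):
        node = Node(name)
        for hook in self._meta_hooks:
            hook(node)  # each hook may populate node.meta
        self.nodes.append(node)
        return node


@contextmanager
def set_metadata_hook(graph, hook):
    """Install a hook for the duration of a pass, then remove it."""
    graph._meta_hooks.append(hook)
    try:
        yield
    finally:
        graph._meta_hooks.remove(hook)


def insert_runtime_asserts(graph, annotation):
    # While the pass runs, every freshly created node -- including the
    # assertion nodes themselves -- inherits the annotation via the hook.
    hook = lambda node: node.meta.setdefault("custom", dict(annotation))
    with set_metadata_hook(graph, hook):
        graph.create_node("sym_size")
        graph.create_node("_assert_scalar")


g = Graph()
insert_runtime_asserts(g, {"_torchdynamo_disable": True})
assert all(n.meta["custom"] == {"_torchdynamo_disable": True} for n in g.nodes)
```

Nodes created after the `with` block exits no longer receive the annotation, which mirrors why the hook must be active for the whole runtime_assert pass rather than only the user-visible ops.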

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @kadeng @chauhang @amjames @Lucaskabela @jataylo

@pytorch-bot bot commented Jan 30, 2026

This PR needs a `release notes:` label

If your changes are user facing and intended to be part of the release notes, please use a label starting with `release notes:`.

If not, please add the `topic: not user facing` label.

To add a label, you can comment to pytorchbot, for example:

```
@pytorchbot label "topic: not user facing"
```

For more information, see
https://github.com/pytorch/pytorch/wiki/PyTorch-AutoLabel-Bot#why-categorize-for-release-notes-and-how-does-it-work.

@pytorch-bot bot commented Jan 30, 2026

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/173970

Note: Links to docs will display an error until the docs builds have been completed.

❗ 1 Active SEV

There is 1 currently active SEV. If your PR is affected, please view it below:

❌ 1 New Failure, 4 Unrelated Failures

As of commit 32bd86a with merge base 58a92ff:

NEW FAILURE - The following job has failed:

FLAKY - The following job failed but was likely due to flakiness present on trunk:

BROKEN TRUNK - The following jobs failed but were present on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot bot added the `release notes: fx` label Jan 30, 2026
@yushangdi requested a review from pianpwk January 30, 2026 22:52
@yushangdi added the `ciflow/trunk` label Jan 31, 2026
@yushangdi (Contributor, Author) commented:

The openreg test failed with `python -c 'import torch; torch._C._crash_if_csrc_asan(3)'`; this seems unrelated to the changes in this PR.

@yushangdi (Contributor, Author) commented:

@pytorchbot merge -i

@pytorchmergebot (Collaborator)

radeksm pushed a commit to radeksm/pytorch that referenced this pull request Feb 20, 2026
…me_assert pass (pytorch#173970)


Pull Request resolved: pytorch#173970
Approved by: https://github.com/pianpwk

Labels

ciflow/inductor, ciflow/trunk, Merged, module: dynamo, release notes: fx
