Skip to content

Conversation

@angelayi
Copy link
Contributor

Summary:
We previously incorrectly handled the following graph, specifically for the node w.3 in block0:

 graph(%x.1 : Float(3, strides=[1], requires_grad=0, device=cpu),
       %y.1 : int):
   %2 : __torch__.___torch_mangle_1.M = prim::CreateObject()
   %3 : int = prim::Constant[value=20](), scope: M:: # /data/users/angelayi/pytorch/test/export/test_converter.py:747:34
   %4 : int = prim::Constant[value=10](), scope: M:: # /data/users/angelayi/pytorch/test/export/test_converter.py:746:34
   %5 : int = prim::Constant[value=1](), scope: M::
   %w.1 : int = prim::GetAttr[name="w"](%2), scope: M::
   %7 : int = aten::mul(%w.1, %4), scope: M:: # /data/users/angelayi/pytorch/test/export/test_converter.py:746:25
    = prim::SetAttr[name="w"](%2, %7), scope: M::
   %h.1 : int = prim::GetAttr[name="h"](%2), scope: M::
   %9 : int = aten::mul(%h.1, %3), scope: M:: # /data/users/angelayi/pytorch/test/export/test_converter.py:747:25
    = prim::SetAttr[name="h"](%2, %9), scope: M::
   %10 : bool = aten::gt(%y.1, %4), scope: M:: # /data/users/angelayi/pytorch/test/export/test_converter.py:749:19
   %res.37 : Tensor = prim::If(%10), scope: M:: # /data/users/angelayi/pytorch/test/export/test_converter.py:749:16
     block0():
       %w.3 : int = prim::GetAttr[name="w"](%2), scope: M::
       %res.1 : Tensor = aten::add(%x.1, %w.3, %5), scope: M:: # <string>:5:9
       -> (%res.1)
     block1():
       %h.3 : int = prim::GetAttr[name="h"](%2), scope: M::
       %res.3 : Tensor = aten::add(%x.1, %h.3, %5), scope: M:: # <string>:5:9
       -> (%res.3)
   %16 : bool = aten::lt(%y.1, %4), scope: M:: # /data/users/angelayi/pytorch/test/export/test_converter.py:754:19
   %res : Tensor = prim::If(%16), scope: M:: # /data/users/angelayi/pytorch/test/export/test_converter.py:754:16
     block0():
       %w : int = prim::GetAttr[name="w"](%2), scope: M::
       %res.15 : Tensor = aten::add(%res.37, %w, %5), scope: M:: # <string>:5:9
       -> (%res.15)
     block1():
       %h : int = prim::GetAttr[name="h"](%2), scope: M::
       %res.21 : Tensor = aten::add(%res.37, %h, %5), scope: M:: # <string>:5:9
       -> (%res.21)
   return (%res)

Test Plan: CI

Differential Revision: D63399064

@pytorch-bot
Copy link

pytorch-bot bot commented Sep 25, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/136648

Note: Links to docs will display an error until the docs builds have been completed.

❌ 1 New Failure, 1 Unrelated Failure

As of commit 96c2827 with merge base 99eb47f (image):

NEW FAILURE - The following job has failed:

FLAKY - The following job failed but was likely due to flakiness present on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot
Copy link
Contributor

This pull request was exported from Phabricator. Differential Revision: D63399064


# Mapping from torchscript node output name to attribute fully qualified name
self.name_to_attribute_fqn: Dict[str, str] = {}
self.name_to_attribute_fqn: Dict[str, str] = name_to_attribute_fqn
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am starting to loose track of the meaning of

name_to_param, name_to_buffer, name_to_node, name_to_constant, name_to_attribute_fqn.

Do you mind adding some comments on so that we don't confuse our future selves?

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Sep 26, 2024
@facebook-github-bot
Copy link
Contributor

This pull request was exported from Phabricator. Differential Revision: D63399064

angelayi added a commit to angelayi/pytorch that referenced this pull request Sep 26, 2024
Summary:
Pull Request resolved: pytorch#136648

We previously incorrectly handled the following graph, specifically for the node `w.3` in `block0`:
```
 graph(%x.1 : Float(3, strides=[1], requires_grad=0, device=cpu),
       %y.1 : int):
   %2 : __torch__.___torch_mangle_1.M = prim::CreateObject()
   %3 : int = prim::Constant[value=20](), scope: M:: # /data/users/angelayi/pytorch/test/export/test_converter.py:747:34
   %4 : int = prim::Constant[value=10](), scope: M:: # /data/users/angelayi/pytorch/test/export/test_converter.py:746:34
   %5 : int = prim::Constant[value=1](), scope: M::
   %w.1 : int = prim::GetAttr[name="w"](%2), scope: M::
   %7 : int = aten::mul(%w.1, %4), scope: M:: # /data/users/angelayi/pytorch/test/export/test_converter.py:746:25
    = prim::SetAttr[name="w"](%2, %7), scope: M::
   %h.1 : int = prim::GetAttr[name="h"](%2), scope: M::
   %9 : int = aten::mul(%h.1, %3), scope: M:: # /data/users/angelayi/pytorch/test/export/test_converter.py:747:25
    = prim::SetAttr[name="h"](%2, %9), scope: M::
   %10 : bool = aten::gt(%y.1, %4), scope: M:: # /data/users/angelayi/pytorch/test/export/test_converter.py:749:19
   %res.37 : Tensor = prim::If(%10), scope: M:: # /data/users/angelayi/pytorch/test/export/test_converter.py:749:16
     block0():
       %w.3 : int = prim::GetAttr[name="w"](%2), scope: M::
       %res.1 : Tensor = aten::add(%x.1, %w.3, %5), scope: M:: # <string>:5:9
       -> (%res.1)
     block1():
       %h.3 : int = prim::GetAttr[name="h"](%2), scope: M::
       %res.3 : Tensor = aten::add(%x.1, %h.3, %5), scope: M:: # <string>:5:9
       -> (%res.3)
   %16 : bool = aten::lt(%y.1, %4), scope: M:: # /data/users/angelayi/pytorch/test/export/test_converter.py:754:19
   %res : Tensor = prim::If(%16), scope: M:: # /data/users/angelayi/pytorch/test/export/test_converter.py:754:16
     block0():
       %w : int = prim::GetAttr[name="w"](%2), scope: M::
       %res.15 : Tensor = aten::add(%res.37, %w, %5), scope: M:: # <string>:5:9
       -> (%res.15)
     block1():
       %h : int = prim::GetAttr[name="h"](%2), scope: M::
       %res.21 : Tensor = aten::add(%res.37, %h, %5), scope: M:: # <string>:5:9
       -> (%res.21)
   return (%res)
```

Test Plan: CI

Reviewed By: SherlockNoMad

Differential Revision: D63399064
@facebook-github-bot
Copy link
Contributor

This pull request was exported from Phabricator. Differential Revision: D63399064

Summary:
Pull Request resolved: pytorch#136648

We previously incorrectly handled the following graph, specifically for the node `w.3` in `block0`:
```
 graph(%x.1 : Float(3, strides=[1], requires_grad=0, device=cpu),
       %y.1 : int):
   %2 : __torch__.___torch_mangle_1.M = prim::CreateObject()
   %3 : int = prim::Constant[value=20](), scope: M:: # /data/users/angelayi/pytorch/test/export/test_converter.py:747:34
   %4 : int = prim::Constant[value=10](), scope: M:: # /data/users/angelayi/pytorch/test/export/test_converter.py:746:34
   %5 : int = prim::Constant[value=1](), scope: M::
   %w.1 : int = prim::GetAttr[name="w"](%2), scope: M::
   %7 : int = aten::mul(%w.1, %4), scope: M:: # /data/users/angelayi/pytorch/test/export/test_converter.py:746:25
    = prim::SetAttr[name="w"](%2, %7), scope: M::
   %h.1 : int = prim::GetAttr[name="h"](%2), scope: M::
   %9 : int = aten::mul(%h.1, %3), scope: M:: # /data/users/angelayi/pytorch/test/export/test_converter.py:747:25
    = prim::SetAttr[name="h"](%2, %9), scope: M::
   %10 : bool = aten::gt(%y.1, %4), scope: M:: # /data/users/angelayi/pytorch/test/export/test_converter.py:749:19
   %res.37 : Tensor = prim::If(%10), scope: M:: # /data/users/angelayi/pytorch/test/export/test_converter.py:749:16
     block0():
       %w.3 : int = prim::GetAttr[name="w"](%2), scope: M::
       %res.1 : Tensor = aten::add(%x.1, %w.3, %5), scope: M:: # <string>:5:9
       -> (%res.1)
     block1():
       %h.3 : int = prim::GetAttr[name="h"](%2), scope: M::
       %res.3 : Tensor = aten::add(%x.1, %h.3, %5), scope: M:: # <string>:5:9
       -> (%res.3)
   %16 : bool = aten::lt(%y.1, %4), scope: M:: # /data/users/angelayi/pytorch/test/export/test_converter.py:754:19
   %res : Tensor = prim::If(%16), scope: M:: # /data/users/angelayi/pytorch/test/export/test_converter.py:754:16
     block0():
       %w : int = prim::GetAttr[name="w"](%2), scope: M::
       %res.15 : Tensor = aten::add(%res.37, %w, %5), scope: M:: # <string>:5:9
       -> (%res.15)
     block1():
       %h : int = prim::GetAttr[name="h"](%2), scope: M::
       %res.21 : Tensor = aten::add(%res.37, %h, %5), scope: M:: # <string>:5:9
       -> (%res.21)
   return (%res)
```

Test Plan: CI

Reviewed By: SherlockNoMad

Differential Revision: D63399064
@facebook-github-bot
Copy link
Contributor

This pull request was exported from Phabricator. Differential Revision: D63399064

@angelayi
Copy link
Contributor Author

angelayi commented Oct 2, 2024

@pytorchbot merge -f "errors seem unrelated?"

@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Please use -f as last resort and instead consider -i/--ignore-current to continue the merge ignoring current failures. This will allow currently pending tests to finish and report signal before the merge.

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

AnantGulati pushed a commit to AnantGulati/pytorch that referenced this pull request Oct 2, 2024
Summary:
We previously incorrectly handled the following graph, specifically for the node `w.3` in `block0`:
```
 graph(%x.1 : Float(3, strides=[1], requires_grad=0, device=cpu),
       %y.1 : int):
   %2 : __torch__.___torch_mangle_1.M = prim::CreateObject()
   %3 : int = prim::Constant[value=20](), scope: M:: # /data/users/angelayi/pytorch/test/export/test_converter.py:747:34
   %4 : int = prim::Constant[value=10](), scope: M:: # /data/users/angelayi/pytorch/test/export/test_converter.py:746:34
   %5 : int = prim::Constant[value=1](), scope: M::
   %w.1 : int = prim::GetAttr[name="w"](%2), scope: M::
   %7 : int = aten::mul(%w.1, %4), scope: M:: # /data/users/angelayi/pytorch/test/export/test_converter.py:746:25
    = prim::SetAttr[name="w"](%2, %7), scope: M::
   %h.1 : int = prim::GetAttr[name="h"](%2), scope: M::
   %9 : int = aten::mul(%h.1, %3), scope: M:: # /data/users/angelayi/pytorch/test/export/test_converter.py:747:25
    = prim::SetAttr[name="h"](%2, %9), scope: M::
   %10 : bool = aten::gt(%y.1, %4), scope: M:: # /data/users/angelayi/pytorch/test/export/test_converter.py:749:19
   %res.37 : Tensor = prim::If(%10), scope: M:: # /data/users/angelayi/pytorch/test/export/test_converter.py:749:16
     block0():
       %w.3 : int = prim::GetAttr[name="w"](%2), scope: M::
       %res.1 : Tensor = aten::add(%x.1, %w.3, %5), scope: M:: # <string>:5:9
       -> (%res.1)
     block1():
       %h.3 : int = prim::GetAttr[name="h"](%2), scope: M::
       %res.3 : Tensor = aten::add(%x.1, %h.3, %5), scope: M:: # <string>:5:9
       -> (%res.3)
   %16 : bool = aten::lt(%y.1, %4), scope: M:: # /data/users/angelayi/pytorch/test/export/test_converter.py:754:19
   %res : Tensor = prim::If(%16), scope: M:: # /data/users/angelayi/pytorch/test/export/test_converter.py:754:16
     block0():
       %w : int = prim::GetAttr[name="w"](%2), scope: M::
       %res.15 : Tensor = aten::add(%res.37, %w, %5), scope: M:: # <string>:5:9
       -> (%res.15)
     block1():
       %h : int = prim::GetAttr[name="h"](%2), scope: M::
       %res.21 : Tensor = aten::add(%res.37, %h, %5), scope: M:: # <string>:5:9
       -> (%res.21)
   return (%res)
```

Test Plan: CI

Differential Revision: D63399064

Pull Request resolved: pytorch#136648
Approved by: https://github.com/SherlockNoMad
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants