Skip to content

Conversation

@xmfan
Copy link
Member

@xmfan xmfan commented Jan 4, 2025

This error started popping up in HUD CA benchmarks:

 File "/data/users/xmfan/core/b/pytorch/torch/_dynamo/compiled_autograd.py", line 371, in dce
    self.fx_tracer.graph.eliminate_dead_code(is_impure)
  File "/data/users/xmfan/core/b/pytorch/torch/fx/graph.py", line 1862, in eliminate_dead_code
    self.lint()
  File "/data/users/xmfan/core/b/pytorch/torch/fx/graph.py", line 1753, in lint
    raise RuntimeError(f"Node redefined name {node.name}!")
RuntimeError: Node redefined name aot0_expand!

We added CA initial capture's renaming (#133148) to help debug issues with AOT backward, but it errors out when we have multiple instances of the same AOT backward. This likely only showed up now because of increased hierarchical graph reuse. I fix it by adding a postfix counter to the node name

Stack from ghstack (oldest at bottom):

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @ColinPeppler @amjames @desertfire @chauhang @aakhundov

[ghstack-poisoned]
@pytorch-bot
Copy link

pytorch-bot bot commented Jan 4, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/144202

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit ad878c3 with merge base 0431d47 (image):
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

xmfan added a commit that referenced this pull request Jan 4, 2025
@xmfan xmfan added the topic: not user facing topic category label Jan 5, 2025
[ghstack-poisoned]
xmfan added a commit that referenced this pull request Jan 6, 2025
@xmfan xmfan changed the title fix rename_aot_dispatcher_nodes when AOT bwd graph is reused multiple times [ca] fix rename_aot_dispatcher_nodes when AOT bwd graph is reused multiple times Jan 6, 2025
@xmfan xmfan changed the title [ca] fix rename_aot_dispatcher_nodes when AOT bwd graph is reused multiple times [ca] dedup node names when AOT bwd graph is reused multiple times Jan 6, 2025
@xmfan xmfan marked this pull request as ready for review January 6, 2025 21:47
@xmfan xmfan requested review from bdhirsh and jansel January 6, 2025 21:49
raise StopIteration

ca_node.name = f"aot{aot_id}_{aot_node.name}"
ca_node.name = f"aot{aot_id}{aot_id_postfix}_{aot_node.name}"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

tiny nit but it might be clearer to put the "_" in this fstring instead of prefixing it directly in aot_id_postfix above

@xmfan
Copy link
Member Author

xmfan commented Jan 7, 2025

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Jan 7, 2025
@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

@github-actions github-actions bot deleted the gh/xmfan/151/head branch February 9, 2025 02:10
aostrowski-hbn pushed a commit to HabanaAI/pytorch-fork that referenced this pull request May 21, 2025
Cherry-pick: pytorch#144202
Change-Id: I9f6eb8ff7b1ba601149186939d2666d0a23e1bb0
jedrzejmyrcha pushed a commit to HabanaAI/pytorch-fork that referenced this pull request Jul 29, 2025
Cherry-pick: pytorch#144202
Change-Id: I9f6eb8ff7b1ba601149186939d2666d0a23e1bb0
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

5 participants