-
Notifications
You must be signed in to change notification settings - Fork 26.3k
[inductor] Fix bugs in emulate_precision_casts #163520
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/163520
Note: Links to docs will display an error until the docs builds have been completed. ✅ No FailuresAs of commit 49b28d4 with merge base 51152ef ( This comment was automatically generated by Dr. CI and updates every 15 minutes. |
eellison
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nice!
| if ( | ||
| not isinstance(func, torch._ops.OpOverload) | ||
| or torch.Tag.pointwise not in func.tags | ||
| ): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
👍
| if not output_low_precision: | ||
| for input_node in last_node.all_input_nodes: | ||
| val = input_node.meta.get("val") if hasattr(input_node, "meta") else None | ||
| if isinstance(val, torch.Tensor) and val.dtype in low_pr_fp: | ||
| output_low_precision = True | ||
| break |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thinking out loud:
For something like
x: bfloat16
y = x.to(float32)
This would set low_precision_pointwise_barrier on the x.to(float32) output. I guess this is okay because its actual dtype later in lowering will be float32, so we'll ignore it.
And the decomps themselves should be upcasting intermediaries to be fp32, so those will also get ignored, e.g. gelu here.
|
Starting merge as part of PR stack under #163482 |
This reverts commit a8cd437. See #163481 (comment) This PR might also cause issues with cudagraphs. Pull Request resolved: #163737 Approved by: https://github.com/ezyang ghstack dependencies: #163386, #163398, #163387, #163414, #163415, #163419, #163434, #163393, #163412, #163422, #163481, #163520, #163482
Fixes pytorch#163449 Pull Request resolved: pytorch#163520 Approved by: https://github.com/eellison ghstack dependencies: pytorch#163386, pytorch#163398, pytorch#163387, pytorch#163414, pytorch#163415, pytorch#163419, pytorch#163434, pytorch#163393, pytorch#163412, pytorch#163422, pytorch#163481
Fixes pytorch#163457 Pull Request resolved: pytorch#163482 Approved by: https://github.com/eellison ghstack dependencies: pytorch#163386, pytorch#163398, pytorch#163387, pytorch#163414, pytorch#163415, pytorch#163419, pytorch#163434, pytorch#163393, pytorch#163412, pytorch#163422, pytorch#163481, pytorch#163520
This reverts commit a8cd437. See pytorch#163481 (comment) This PR might also cause issues with cudagraphs. Pull Request resolved: pytorch#163737 Approved by: https://github.com/ezyang ghstack dependencies: pytorch#163386, pytorch#163398, pytorch#163387, pytorch#163414, pytorch#163415, pytorch#163419, pytorch#163434, pytorch#163393, pytorch#163412, pytorch#163422, pytorch#163481, pytorch#163520, pytorch#163482
This reverts commit a8cd437. See #163481 (comment) This PR might also cause issues with cudagraphs. Pull Request resolved: #163737 Approved by: https://github.com/ezyang ghstack dependencies: #163386, #163398, #163387, #163414, #163415, #163419, #163434, #163393, #163412, #163422, #163481, #163520, #163482
Fixes #163449 ghstack-source-id: e173d84 Pull-Request: pytorch/pytorch#163520
Stack from ghstack (oldest at bottom):
Fixes #163449
cc @ezyang @EikanWang @jgong5 @wenzhe-nrv @voznesenskym @penguinwu @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @jiayisunx @ipiszy @chenyang78 @kadeng @muchulee8 @amjames @chauhang @aakhundov @coconutruben