[compile] Switch off inference_mode for fake prop while compiling #149072
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/149072
Note: Links to docs will display an error until the docs builds have been completed.
❌ 4 New Failures, 1 Unrelated Failure as of commit 78df81b with merge base dcc502f.
NEW FAILURES - The following jobs have failed:
BROKEN TRUNK - The following job failed but was already present on the merge base. 👉 Rebase onto the `viable/strict` branch to avoid these failures.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
torch/_subclasses/meta_utils.py
Outdated
    id=self.get_tensor_id(t),
    storage=storage,
-   is_inference=t.is_inference(),
+   is_inference=False if DISABLE_INFERENCE_MODE else t.is_inference(),
This line is the main change.
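For context, a minimal sketch of what this masking affects (illustrative only; DISABLE_INFERENCE_MODE is the flag introduced further down in this diff):

    import torch

    # Tensors created under inference_mode carry the is_inference() bit:
    with torch.inference_mode():
        t = torch.randn(2)
    print(t.is_inference())  # True

    # The changed line records is_inference=False in the fake-tensor descriptor
    # for such tensors when the flag is set, so fake propagation during compile
    # treats them like ordinary tensors instead of inference tensors.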
torch/_subclasses/meta_utils.py
Outdated
    assert a == b, f"{a} != {b}"

+ DISABLE_INFERENCE_MODE = False
Nit: can we make this a real config somewhere, just in case someone actually needs to flip it if we end up breaking inference code in a subtle way?
I guess the annoying part is that we don't have a global config for fake tensor... Maybe in dynamo config?
Added a config. The behavior is controlled in convert_frame.py.
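A sketch of how such an escape hatch might be used, assuming it lands as a Dynamo config option (the flag name below is illustrative, not confirmed by this PR):

    import torch

    def fn(x):
        return x.sin()

    # Hypothetical flag name for illustration; torch._dynamo.config.patch is the
    # usual way to flip a Dynamo config option for a scoped region, e.g. to
    # restore the old behavior if masking inference_mode breaks a model.
    with torch._dynamo.config.patch(fake_tensor_disable_inference_mode=False):
        out = torch.compile(fn)(torch.randn(4))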
    unimplemented_v2(
        gb_type="Encountered torch.is_inference_mode_enabled during tracing",
        context="",
        explanation="torch.is_inference_mode_enabled() is not supported",
On the explanation: our claim is basically that if you are using compile, we want people to use no_grad instead of inference_mode, since it gives the same perf under compile and is more composable. Should we mention that in these explanations?
That's a better message. I can do that.
Updated the graph break hint messages.
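For reference, the pattern the updated hint steers users toward (a minimal sketch; `fn` is just a stand-in function):

    import torch

    def fn(x):
        return x.sin() + x.cos()

    compiled = torch.compile(fn)
    x = torch.randn(4)

    # Preferred under compile: no_grad gives the same performance as
    # inference_mode once the graph is compiled, and composes better with
    # tracing. Wrapping the compiled call in inference_mode is what the new
    # graph break hint discourages.
    with torch.no_grad():
        out = compiled(x)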
bdhirsh left a comment
🤞
Seems like with the last change, CI is broken everywhere. This requires more work :(
jansel left a comment
Re-request review once tests are passing
    cnts = torch._dynamo.testing.CompileCounter()
-   opt_fn = torch.compile(fn, backend=cnts, fullgraph=True)
+   opt_fn = torch.compile(fn, backend=cnts, fullgraph=False)
I mentioned this offline, but there is some is_inference switching code in flash-attention: https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/layers/rotary.py#L419
Parts of this codebase are fullgraph-compileable and parts aren't (I don't remember which), so we should check that we're not regressing anything here.
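For illustration, a sketch of the kind of is_inference()-dependent branching being referred to (not the actual flash-attention code); with the new flag, compiled fake prop would always see is_inference() as False and trace the non-inference branch:

    import torch

    def save_for_backward_if_needed(x: torch.Tensor) -> torch.Tensor:
        # Hypothetical helper: inference tensors can never require grad, so
        # autograd bookkeeping (e.g. cloning inputs for backward) is skipped.
        if x.is_inference():
            return x
        return x.clone()

    with torch.inference_mode():
        y = save_for_backward_if_needed(torch.randn(3))  # inference branch in eager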
Closing in favor of #149321
Stack from ghstack (oldest at bottom):
cc @H-Huang @awgu @kwen2501 @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @c-p-i-o @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @kadeng @chauhang @amjames