
Conversation

pytorch-bot bot commented Mar 12, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/149072

Note: Links to docs will display an error until the docs builds have been completed.

❌ 4 New Failures, 1 Unrelated Failure

As of commit 78df81b with merge base dcc502f:

NEW FAILURES - The following jobs have failed:

BROKEN TRUNK - The following job failed but was also present on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

This comment was automatically generated by Dr. CI and updates every 15 minutes.

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames

[ghstack-poisoned]
@anijain2305 added the `topic: not user facing` label Mar 12, 2025
@anijain2305 added the `keep-going` label (Don't stop on first failure, keep running tests until the end) Mar 14, 2025
@pytorch-bot added the `oncall: distributed` label (Add this issue/PR to distributed oncall triage queue) Mar 14, 2025
  id=self.get_tensor_id(t),
  storage=storage,
- is_inference=t.is_inference(),
+ is_inference=False if DISABLE_INFERENCE_MODE else t.is_inference(),
@anijain2305 (Contributor Author) commented:

This line is the main change.
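For readers skimming the diff, here is a minimal, self-contained sketch of what this gating does. The flag name matches the diff, but the standalone function is hypothetical; the real change lives inside the fake-tensor describer:

```python
# Standalone illustration of the gated is_inference propagation above.
# When DISABLE_INFERENCE_MODE is set, every tensor is described as a
# regular (non-inference) tensor, so fake prop never sees inference-mode
# semantics even if the real input was created under torch.inference_mode().
import torch

DISABLE_INFERENCE_MODE = True  # module-level flag, as in the diff

def describe_is_inference(t: torch.Tensor) -> bool:
    # Mirrors the changed line.
    return False if DISABLE_INFERENCE_MODE else t.is_inference()

with torch.inference_mode():
    x = torch.ones(3)

print(x.is_inference())          # True: x really is an inference tensor
print(describe_is_inference(x))  # False: fake prop treats it as a normal tensor
```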

  assert a == b, f"{a} != {b}"


+ DISABLE_INFERENCE_MODE = False
A reviewer (Contributor) commented:

nit: can we make this a real config somewhere, just in case someone actually needs to flip it if we end up breaking inference code in a subtle way?

I guess the annoying part is that we don't have a global config for fake tensor... Maybe in the dynamo config?

@anijain2305 (Contributor Author) replied:

Added a config. The behavior is controlled in convert_frame.py.
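For context, a rough sketch of what such a toggle could look like. The names below (`fake_tensor_disable_inference_mode`, `maybe_disable_inference_mode`) are made up for illustration, and the actual wiring in convert_frame.py may differ:

```python
import contextlib
import torch

# Imagined config entry; the real PR adds something similar under a dynamo config.
fake_tensor_disable_inference_mode = True

@contextlib.contextmanager
def maybe_disable_inference_mode():
    # If the flag is set and compilation was entered under inference mode,
    # run fake prop with inference mode forced off.
    if fake_tensor_disable_inference_mode and torch.is_inference_mode_enabled():
        with torch.inference_mode(False):
            yield
    else:
        yield

with torch.inference_mode():
    with maybe_disable_inference_mode():
        assert not torch.is_inference_mode_enabled()
```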

unimplemented_v2(
gb_type="Encountered torch.is_inference_mode_enabled during tracing",
context="",
explanation="torch.is_inference_mode_enabled() is not supported",
A reviewer (Contributor) commented:

on the explanation - our claim is basically that if you are using compile, we want people to use no_grad instead of inference_mode since it gives the same perf under compile and is more composable. Should we mention that in these explanations?

@anijain2305 (Contributor Author) replied:

That's a better message. I can do that.

@anijain2305 (Contributor Author) commented:

Updated the graph break hint messages.
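For illustration, the restructuring the hint steers users toward looks roughly like this (`fn` is a placeholder; the claim about equal performance under compile comes from the review discussion above):

```python
import torch

def fn(x):
    return x * 2

compiled = torch.compile(fn)

# Under torch.compile, prefer no_grad over inference_mode: the review
# above argues it gives the same perf once compiled and composes better.
with torch.no_grad():            # instead of: with torch.inference_mode():
    out = compiled(torch.ones(4))
```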

cc H-Huang awgu kwen2501 wanchaol fegin fduwjj wz337 wconstab d4l3k c-p-i-o voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames

[ghstack-poisoned]
@anijain2305 requested a review from bdhirsh March 14, 2025 18:10
@anijain2305 changed the title from "[compile] Switch off inference_mode while compiling" to "[compile] Switch off inference_mode for fake prop while compiling" Mar 14, 2025
@bdhirsh (Contributor) left a review comment:

🤞

…mpiling"

cc H-Huang awgu kwen2501 wanchaol fegin fduwjj wz337 wconstab d4l3k c-p-i-o voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames

[ghstack-poisoned]
…mpiling"

cc H-Huang awgu kwen2501 wanchaol fegin fduwjj wz337 wconstab d4l3k c-p-i-o voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames

[ghstack-poisoned]
…mpiling"

cc H-Huang awgu kwen2501 wanchaol fegin fduwjj wz337 wconstab d4l3k c-p-i-o voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames

[ghstack-poisoned]
@anijain2305 marked this pull request as draft March 15, 2025 00:13
@anijain2305 (Contributor Author) commented:

Seems like with the last change, CI is broken everywhere. This requires more work :(

…mpiling"

cc H-Huang awgu kwen2501 wanchaol fegin fduwjj wz337 wconstab d4l3k c-p-i-o voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames

[ghstack-poisoned]
@jansel (Contributor) left a review:

Re-request review once tests are passing

  cnts = torch._dynamo.testing.CompileCounter()
- opt_fn = torch.compile(fn, backend=cnts, fullgraph=True)
+ opt_fn = torch.compile(fn, backend=cnts, fullgraph=False)
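The test is loosened because querying inference-mode state now graph-breaks during tracing. A minimal sketch of the kind of function that forces the change (`fn` here is hypothetical, not the test's actual function):

```python
import torch
import torch._dynamo.testing

def fn(x):
    # A trace-time query like this now causes a graph break, so the test
    # can no longer insist on a single graph via fullgraph=True.
    if torch.is_inference_mode_enabled():
        return x + 1
    return x - 1

cnts = torch._dynamo.testing.CompileCounter()
opt_fn = torch.compile(fn, backend=cnts, fullgraph=False)
print(opt_fn(torch.ones(3)))
```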

A reviewer (Contributor) commented:

I mentioned this offline, but there is some is_inference switching code in flash-attention: https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/layers/rotary.py#L419

Parts of this codebase are fullgraph compilable and parts aren't (I don't remember which), so we should check that we're not regressing anything here.
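To make the concern concrete, here is a simplified, hypothetical stand-in for that rotary.py pattern (the real code is more involved): branching on inference state mid-function is exactly the shape of code that would hit the new graph break, and therefore could no longer compile with fullgraph=True.

```python
import torch

def rotary_like(x: torch.Tensor) -> torch.Tensor:
    # Under inference mode there is no autograd graph to preserve, so the
    # kernel can write in place; otherwise it must stay out-of-place.
    if torch.is_inference_mode_enabled():
        return x.mul_(2.0)
    return x * 2.0

with torch.inference_mode():
    y = rotary_like(torch.ones(4))   # in-place path
z = rotary_like(torch.ones(4))       # out-of-place path
```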

@anijain2305 (Contributor Author) commented:

Closing in favor of #149321

jurgen-paul pushed a commit to jurgen-paul/pytorch.git.file that referenced this pull request Mar 19, 2025
@github-actions bot deleted the gh/anijain2305/699/head branch April 20, 2025 02:20