Skip to content

Conversation

@anijain2305
Copy link
Contributor

@anijain2305 anijain2305 commented Oct 25, 2024

Stack from ghstack (oldest at bottom):

This brings some unsoundness in guards. Earlier we were skipping empty nn module hooks dict guard only on inbuilt nn modules, but as seen in #138386, there could be still be significant guard overhead. With this PR, we reduce the guard eval latency from 420 us to 280 us (1.5x reduction).

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @kadeng @chauhang @amjames @rec

@pytorch-bot
Copy link

pytorch-bot bot commented Oct 25, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/138942

Note: Links to docs will display an error until the docs builds have been completed.

❌ 1 New Failure, 1 Unrelated Failure

As of commit 5ca3b52 with merge base f9ae3fa (image):

NEW FAILURE - The following job has failed:

BROKEN TRUNK - The following job failed but were present on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

This comment was automatically generated by Dr. CI and updates every 15 minutes.

This brings some unsoundness in guards. Earlier we were skipping empty nn module hooks dict guard only on inbuilt nn modules, but as seen in #138386, there could be still be significant guard overhead. With this PR, we reduce the guard eval latency from 420 us to 280 us (1.5x reduction).


cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames rec

[ghstack-poisoned]
This brings some unsoundness in guards. Earlier we were skipping empty nn module hooks dict guard only on inbuilt nn modules, but as seen in #138386, there could be still be significant guard overhead. With this PR, we reduce the guard eval latency from 420 us to 280 us (1.5x reduction).


cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames rec

[ghstack-poisoned]
This brings some unsoundness in guards. Earlier we were skipping empty nn module hooks dict guard only on inbuilt nn modules, but as seen in #138386, there could be still be significant guard overhead. With this PR, we reduce the guard eval latency from 420 us to 280 us (1.5x reduction).


cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames rec

[ghstack-poisoned]
anijain2305 added a commit that referenced this pull request Oct 26, 2024
ghstack-source-id: 2ff244b
Pull Request resolved: #138942
@ezyang
Copy link
Contributor

ezyang commented Oct 28, 2024

Are you going to allow controlling the unsoundness via config?

@anijain2305
Copy link
Contributor Author

Are you going to allow controlling the unsoundness via config?

Yes, skip_nnmodule_hook_guards controls this.

@ezyang
Copy link
Contributor

ezyang commented Oct 28, 2024

I don't see any reference to it in this PR, is there some nontrivial interaction with earlier PRs in the stack?

@anijain2305
Copy link
Contributor Author

anijain2305 commented Oct 28, 2024

I don't see any reference to it in this PR, is there some nontrivial interaction with earlier PRs in the stack?

Its not obvious from this PR. The change in the PR causes empty nn module hooks for the user defined nn module to insert an EMPTY_NN_MODULE_HOOKS_DICT guard (see snippet), which relies on the skip_nnmodule_hook_guards flag to skip (see permalink).

image

def EMPTY_NN_MODULE_HOOKS_DICT(self, guard):
"""Special guard to skip guards on empty hooks. This is controlled by skip_nnmodule_hook_guards"""
if config.skip_nnmodule_hook_guards:
# This is unsafe if you add/remove a hook on nn module variable
return
self.SEQUENCE_LENGTH(guard)

For due diligence, let me also add a test.

anijain2305

This comment was marked as resolved.

This brings some unsoundness in guards. Earlier we were skipping empty nn module hooks dict guard only on inbuilt nn modules, but as seen in #138386, there could be still be significant guard overhead. With this PR, we reduce the guard eval latency from 420 us to 280 us (1.5x reduction).


cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames rec

[ghstack-poisoned]
This brings some unsoundness in guards. Earlier we were skipping empty nn module hooks dict guard only on inbuilt nn modules, but as seen in #138386, there could be still be significant guard overhead. With this PR, we reduce the guard eval latency from 420 us to 280 us (1.5x reduction).


cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames rec

[ghstack-poisoned]
This brings some unsoundness in guards. Earlier we were skipping empty nn module hooks dict guard only on inbuilt nn modules, but as seen in #138386, there could be still be significant guard overhead. With this PR, we reduce the guard eval latency from 420 us to 280 us (1.5x reduction).


cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames rec

[ghstack-poisoned]
@anijain2305
Copy link
Contributor Author

@pytorchbot merge -f "unrelated CI failure"

@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Please use -f as last resort and instead consider -i/--ignore-current to continue the merge ignoring current failures. This will allow currently pending tests to finish and report signal before the merge.

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

rahulsingh-intel pushed a commit to rahulsingh-intel/pytorch that referenced this pull request Oct 29, 2024
This brings some unsoundness in guards. Earlier we were skipping empty nn module hooks dict guard only on inbuilt nn modules, but as seen in pytorch#138386, there could be still be significant guard overhead. With this PR, we reduce the guard eval latency from 420 us to 280 us (1.5x reduction).

Pull Request resolved: pytorch#138942
Approved by: https://github.com/ezyang, https://github.com/jansel
ghstack dependencies: pytorch#139040, pytorch#138954
@ezyang
Copy link
Contributor

ezyang commented Nov 1, 2024

This seems to result in a modest compile time improvement in HF inference for MobileBertForMaskedLM, probably because processing all the extra guards is also expensive.

image

Samples: https://fburl.com/scuba/torch_open_source_signpost/5snsgy3x

rahulsingh-intel pushed a commit to rahulsingh-intel/pytorch that referenced this pull request Nov 5, 2024
This brings some unsoundness in guards. Earlier we were skipping empty nn module hooks dict guard only on inbuilt nn modules, but as seen in pytorch#138386, there could be still be significant guard overhead. With this PR, we reduce the guard eval latency from 420 us to 280 us (1.5x reduction).

Pull Request resolved: pytorch#138942
Approved by: https://github.com/ezyang, https://github.com/jansel
ghstack dependencies: pytorch#139040, pytorch#138954
@github-actions github-actions bot deleted the gh/anijain2305/561/head branch December 1, 2024 02:20
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

Projects

None yet

Development

Successfully merging this pull request may close these issues.

6 participants