[jit] lower batchmm to non-diff optimization #19987
Conversation
lower batchmm to non-diff optimization gh-metadata: pytorch pytorch 19987 gh/wanchaol/1/head
torch/csrc/jit/graph_executor.cpp
Outdated
for (const auto& pass : getCustomPasses()) {
  pass(graph);
}
// decomposition pass, decompose certain ops that will be used in the following
can you move this comment to the diff above?
… optimization" lower batchmm to non-diff optimization gh-metadata: pytorch pytorch 19987 gh/wanchaol/1/head
lower batchmm to non-diff optimization gh-metadata: pytorch pytorch 19987 gh/wanchaol/1/head
|
This is not the meaning of non-differentiable optimization passes! The point was that after the differentiable optimizations the graph can still be run with autograd enabled, not necessarily be symbolically differentiated. Why did we change this?
@apaszke hmmm ok, but I believe the graph after the non-differentiable optimization passes can still be run with autograd enabled; the backward will still go through autograd. The reason we need this is that, in the stacked diffs above, we want to lower the addmm/linear decomposition in a later pass, after the custom fusion passes and before fusion, rather than before them.
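As a minimal sketch of that point (the function name and shapes below are made up for illustration): a scripted function whose optimized graph may contain non-differentiable rewrites such as the BatchMM ops can still be executed with autograd enabled and differentiated with an ordinary `backward()`:

```python
import torch

# Hypothetical example: even if the optimized graph of this scripted function
# contains non-differentiable rewrites such as the BatchMM ops, the rewritten
# ops are still ordinary JIT ops at runtime, so running with autograd enabled
# and calling backward() works as usual.
@torch.jit.script
def summed_mms(x, a, b):
    return (x.mm(a) + x.mm(b)).sum()

x = torch.randn(8, 8, requires_grad=True)
a = torch.randn(8, 8)
b = torch.randn(8, 8)
summed_mms(x, a, b).backward()  # gradient flows through the optimized graph
print(x.grad.shape)             # torch.Size([8, 8])
```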
Stack from ghstack:
Summary:
`batchmm` is actually a non-differentiable optimization pass: it transforms the graph by replacing `mm`s with `prim::BatchMM` Side/Reduce ops, and the registered prim ops then execute the `mm`s. This goes through autograd, not autodiff, so we can postpone the batchmm pass until right before the fusion pass. This makes the separation between differentiable and non-differentiable optimizations clearer, and also serves as the first step toward decomposing `addmm` after the custom fusion pass.
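For reference, a rough sketch of the kind of pattern the BatchMM pass targets and how one might inspect the optimized graph for it. The function and shapes are illustrative, and `graph_for` is a debug helper whose output depends on the PyTorch build and executor:

```python
import torch

# Hypothetical example: several mms whose results are summed is the tree-reduce
# pattern BatchMM coalesces; many mms sharing one operand is the "side" pattern.
@torch.jit.script
def mm_tree(x, a, b, c):
    return x.mm(a) + x.mm(b) + x.mm(c)

inputs = [torch.randn(16, 16) for _ in range(4)]
mm_tree(*inputs)                   # run once so an optimized plan is recorded
print(mm_tree.graph_for(*inputs))  # look for the batched / tree-reduced mm ops
```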
Test Plan:
Test that the LSTM graph has not changed as a result of this change.
Differential Revision: D15190356