-
Notifications
You must be signed in to change notification settings - Fork 26.3k
[JIT]optimize matmul memory usage for certain cases #23433
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
facebook-github-bot
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@ailzhang has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
torch/csrc/jit/symbolic_script.cpp
Outdated
| grad_other = AD_matmul_special_fold(grad_output, self, self_size) | ||
| grad_other = grad_other.squeeze(-1) | ||
| elif dim1 >= 3 and dim2 == 2: | ||
| grad_other = AD_matmul_special_fold(grad_output, self, self_size) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why can't we put those two cases into AD_matmul_size? That way the grad_self part might benefit from the optimization too. Also, can you please describe why would this decrease memory usage?
facebook-github-bot
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@ailzhang has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
wanchaol
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks good, can you add AD checks to common_method_invocations to check if the nodes in the if-else branches is there?
|
FYI https://github.com/pytorch/pytorch/blob/master/test/common_methods_invocations.py#L477-L478 already covers tests for these if-else branches. We should be good as long as that passes. |
What I mean is that you can add node names to the |
facebook-github-bot
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@ailzhang has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
Fixes #21406