Import MultiheadAttention to PyTorch #18334
Conversation
Force-pushed from f3c2eb6 to c5f7615
soumith left a comment:
Looks like a WIP PR. Next time, prefix the title with [WIP] so that reviewers don't end up reviewing prematurely :)

Thanks @soumith. I will address your review comments soon.
Force-pushed from c5f7615 to 4a604c7
Force-pushed from 4a604c7 to 35eb09a
Force-pushed from 35eb09a to 0eb643f
@zhangguanheng66 - just because it deserves explicit mention: since this is a straight-up port from fairseq, before making any major changes to the code, make sure you write a ton of tests so we don't get lost.
Force-pushed from 0eb643f to c3e60e1
Force-pushed from c3e60e1 to a438789
@cpuhrsch @soumith Thanks for the feedback. I updated the code accordingly. Functions that are not tested have been removed. A unit test has been added in test.nn. An additional test was conducted (see D14577966 for more details). Instead of fairseq.MultiheadAttention, torch.nn.MultiheadAttention is used, and the corresponding unit test in pytorch_translate works fine.
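For reference, a minimal shape check along the lines of that unit test looks roughly like the following (a sketch only; the sizes are illustrative and not taken from the actual test in test.nn):

```python
import torch
import torch.nn as nn

# Illustrative sizes, not the ones used in the actual test.
embed_dim, num_heads = 16, 4
tgt_len, src_len, batch = 5, 7, 3

mha = nn.MultiheadAttention(embed_dim, num_heads)

# Inputs use the (seq_len, batch, embed_dim) layout expected by this module.
query = torch.randn(tgt_len, batch, embed_dim)
key = torch.randn(src_len, batch, embed_dim)
value = torch.randn(src_len, batch, embed_dim)

attn_output, attn_weights = mha(query, key, value)

assert attn_output.shape == (tgt_len, batch, embed_dim)
# Returned weights are averaged over heads: (batch, tgt_len, src_len).
assert attn_weights.shape == (batch, tgt_len, src_len)
```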
cpuhrsch left a comment:
I'm approving this under the assumption that the most recent comments will also be resolved.
Pinging @soumith in case something is still missing.
Force-pushed from a438789 to ea31a70
If you don't mind, before this lands I'd like to page some folks in the NLP community to make sure that they don't need any more features from this.
soumith left a comment:
The documentation for forward is still missing, and forward takes a lot of options. Please flesh it out.
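For context, the options in question (key_padding_mask, need_weights, attn_mask) can be exercised roughly as follows. This is a sketch with made-up sizes; mask dtypes and exact semantics have shifted slightly across PyTorch versions, so check the docs for your version:

```python
import torch
import torch.nn as nn

embed_dim, num_heads = 16, 4
tgt_len, src_len, batch = 5, 7, 3
mha = nn.MultiheadAttention(embed_dim, num_heads)

query = torch.randn(tgt_len, batch, embed_dim)
key = torch.randn(src_len, batch, embed_dim)
value = torch.randn(src_len, batch, embed_dim)

# key_padding_mask: (batch, src_len); True marks padded source positions to ignore.
key_padding_mask = torch.zeros(batch, src_len, dtype=torch.bool)
key_padding_mask[:, -1] = True  # e.g. treat the last source position as padding

# attn_mask: (tgt_len, src_len); an additive float mask on the attention scores.
attn_mask = torch.zeros(tgt_len, src_len)

attn_output, attn_weights = mha(
    query, key, value,
    key_padding_mask=key_padding_mask,
    need_weights=True,   # also return head-averaged attention weights
    attn_mask=attn_mask,
)
```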
@soumith Sure. Let me know if any new features are necessary. I will work on the documentation for the forward function at the same time. @srush @kyunghyuncho @myleott @glample More unit tests on your end are welcome (I have two unit tests on my side).
@mansimov @jasonleeinf care to take a quick look at this PR?
myleott left a comment:
I didn't review the logic carefully, but let's test this in fairseq. Ideally we'll replace the multihead attention implementation in fairseq with this one.
Force-pushed from 30dab32 to d5931d5
Force-pushed from d5931d5 to f7b9bab
facebook-github-bot left a comment:
@zhangguanheng66 is landing this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
Summary: Import MultiheadAttention into the core PyTorch framework. Users can now import MultiheadAttention directly from torch.nn. See "Attention Is All You Need" for more details on the MultiheadAttention function.
Pull Request resolved: pytorch#18334
Differential Revision: D14577966
fbshipit-source-id: b18d945ea461c07948d2f33f5b497ca51591d0ce
Force-pushed from f7b9bab to c25c547
The doc string is there, but it seems to be missing the …
@zhangguanheng66 merged this pull request in 4b20fc8.
Missing documentation for …
We had a PR to update the attn_mask. See here (#20071). |
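For readers wondering what attn_mask looks like in practice, one common pattern is an additive causal mask like the sketch below; the exact accepted dtypes and semantics are what #20071 refined, so defer to the docs for your PyTorch version:

```python
import torch

# Additive causal mask for self-attention: entries above the diagonal are -inf,
# so position i can only attend to positions <= i once the mask is added to the
# attention scores.
seq_len = 5
attn_mask = torch.triu(torch.full((seq_len, seq_len), float('-inf')), diagonal=1)
```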
Summary: Import MultiheadAttention into the core PyTorch framework. Users can now import MultiheadAttention directly from torch.nn. See "Attention Is All You Need" for more details on the MultiheadAttention function.
Pull Request resolved: pytorch#18334
Differential Revision: D14577966
Pulled By: zhangguanheng66
fbshipit-source-id: 756c0deff623f3780651d9f9a70ce84516c806d3
What is the intuition behind adding the "MultiheadAttention" block under activation.py? @zhangguanheng66