Add key_padding_mask kwarg to Transformer #22588
Conversation
fix #22374 (comment)

@sebamenabar I think this PR is relevant to your feature request. Feel free to add review comments. Thanks.

@stephenroller @myleott @ngimel @DNGros @mttk a PR to add key_padding to transformer_encoder. Just wondering if we should do the same thing for transformer_decoder.

@lucasgadams please update the text with motivation for the PR.
Force-pushed from c3c6380 to 74d55ca
Force-pushed from 74d55ca to 1f09027
Force-pushed from 1f09027 to 5bb9001
zhangguanheng66 left a comment
After some discussions, I feel it makes some sense to have the key_padding in the decoder (at least for the self-attention layer). We prefer to maintain a generic API because we can't predict the applications on the user side.
Agreed, I think it also makes sense to add the key_padding_mask on the decoder's multihead attention that looks at the encoder output, since we usually have padding on both the source and the target.

OK, I have added the key_padding_mask kwarg to the Decoder and DecoderLayer, which will pass it in to the target-memory attention. One more question is whether we want to automatically pass the src key_padding_mask to both the encoder and decoder in Transformer's forward method. This means people will be forced to use it on both if they use it on either in Transformer. I think this is probably fine, and so that is what I have done in the latest diff. But let me know if you think we don't want to do that automatically in Transformer. I will add some test cases for the decoder piece.
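For illustration only (this code is not from the PR): a minimal layer-level sketch of where the two decoder-side padding masks end up. The kwarg names follow the API as eventually released (tgt_key_padding_mask for the decoder self-attention keys, memory_key_padding_mask for the encoder-decoder attention keys); the sizes are toy values, and inputs use the sequence-first (S, N, E) layout.

```python
import torch
import torch.nn as nn

S, T, N, E = 6, 4, 2, 16                     # source len, target len, batch, embed dim
decoder_layer = nn.TransformerDecoderLayer(d_model=E, nhead=4)

tgt = torch.rand(T, N, E)                    # decoder input, (T, N, E)
memory = torch.rand(S, N, E)                 # encoder output, (S, N, E)

# True marks padded positions that attention should ignore.
tgt_key_padding_mask = torch.zeros(N, T, dtype=torch.bool)
tgt_key_padding_mask[1, -1] = True           # last target token of sample 1 is padding
memory_key_padding_mask = torch.zeros(N, S, dtype=torch.bool)
memory_key_padding_mask[1, -2:] = True       # last two source tokens of sample 1 are padding

out = decoder_layer(
    tgt, memory,
    tgt_key_padding_mask=tgt_key_padding_mask,        # masks keys in decoder self-attention
    memory_key_padding_mask=memory_key_padding_mask,  # masks keys in encoder-decoder attention
)
print(out.shape)                             # torch.Size([4, 2, 16])
```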
Force-pushed from 5bb9001 to a4b629e
See my comments. I don't think we want to pass a single key_padding to both encoder and decoder. Thanks.

I added the requested changes. Transformer.forward now uses src_key_padding_mask, tgt_key_padding_mask, and memory_key_padding_mask. Decoder uses key_padding_mask and memory_key_padding_mask. Modified the relevant tests and added some new deterministic ones for the decoder masks.
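To make the final argument set concrete, here is a minimal end-to-end sketch (not taken from the PR) of nn.Transformer with the three padding-mask kwargs named above. The make_padding_mask helper is illustrative, the sizes are toy values, and inputs use the sequence-first (S, N, E) layout.

```python
import torch
import torch.nn as nn

def make_padding_mask(lengths, max_len):
    # Illustrative helper: (N, max_len) boolean mask, True at padded positions.
    positions = torch.arange(max_len).unsqueeze(0)   # (1, max_len)
    return positions >= lengths.unsqueeze(1)         # (N, max_len)

S, T, N, E = 10, 7, 3, 32                            # source len, target len, batch, embed dim
model = nn.Transformer(d_model=E, nhead=4)

src = torch.rand(S, N, E)                            # (S, N, E)
tgt = torch.rand(T, N, E)                            # (T, N, E)

src_lengths = torch.tensor([10, 8, 5])               # true lengths before padding
tgt_lengths = torch.tensor([7, 6, 4])

src_key_padding_mask = make_padding_mask(src_lengths, S)   # (N, S), masks encoder self-attention keys
tgt_key_padding_mask = make_padding_mask(tgt_lengths, T)   # (N, T), masks decoder self-attention keys
# The memory is the encoder output, which has the same length as the source,
# so its padding mask is typically just the source padding mask again.
memory_key_padding_mask = src_key_padding_mask              # (N, S)

out = model(
    src, tgt,
    src_key_padding_mask=src_key_padding_mask,
    tgt_key_padding_mask=tgt_key_padding_mask,
    memory_key_padding_mask=memory_key_padding_mask,
)
print(out.shape)                                     # torch.Size([7, 3, 32])
```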
Force-pushed from a4b629e to ad0360a
I think everything is covered now. The argument list looks very long, but in my opinion all of the options are standard use cases of the transformer.
zhangguanheng66 left a comment
.
Force-pushed from ad0360a to 715aea1
facebook-github-bot left a comment
@lucasgadams has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
Force-pushed from 715aea1 to 238ee45
@pytorchbot retest this please

Sorry, only maintainers are authorized to rebase other people's PRs. Feel free to try again on one of your PRs! (To learn more about this bot, see Bot commands.)

@pytorchbot rebase this please
Force-pushed from 8a39fbe to 958dfa3
facebook-github-bot left a comment
@lucasgadams is landing this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
Summary:
Motivation:
The forward method of MultiheadAttention has a kwarg, key_padding_mask. This mask is of shape (N, S), where N is the batch size and S is the sequence length. It is applied prior to the attention softmax: positions where the mask is True are set to float('-inf'). This allows you to mask position j from attention for every position i in the input sequence. It's typically used to mask padded inputs, so for each sample in a batch we can make sure no encoder outputs depend on padding inputs. Currently the Transformer, TransformerEncoder, and TransformerEncoderLayer do not have this kwarg; they only have options for (S, S), (T, T), and (T, S) masks, which are applied equally across the batch to the source input, target output, and target-source memory respectively. These masks can't be used for padding and are instead used for things like subsequent masking in language modeling, by masking the attention of position i to position j.
This diff exposes key_padding_mask on the Transformer, TransformerEncoder, and TransformerEncoderLayer forward methods; it is ultimately passed to MultiheadAttention's forward.
Open question: should we also allow a key_padding_mask for the decoder layer? Since padding is usually at the end of each sentence in a batch, and sentences are usually decoded from left to right, people typically deal with padding on decoded outputs by masking those outputs at the loss layer. There might be some scenarios where it's needed, though I don't think it would be common, and people can still subclass and override the layers. We could also pass the input key_padding_mask to the memory <> decoder attention layer. I'm not sure that's necessary, though, because the output of position i from each encoder attention layer won't depend on any masked positions in the input (even if position i is itself a masked position), so there's not really any point in masking position i again.
Adds the key_padding_mask kwarg to the Transformer, TransformerEncoder, and TransformerEncoderLayer forward methods.
The standard TransformerEncoderLayer uses a MultiheadAttention layer as self_attn. MultiheadAttention's forward method has a key_padding_mask kwarg that allows masking of values, such as padding, per sequence in a batch, in contrast to the attn_mask kwarg, which is usually of shape (S, S) and applied equally across the batch.
MultiheadAttention calls functional.multi_head_attention_forward, which has the same key_padding_mask kwarg of shape (N, S); masked (True) values are set to float('-inf'). (A short usage sketch follows below.)
Pull Request resolved: pytorch#22588
Test Plan:
buck test mode/dev caffe2/test:nn -- 'test_transformerencoderlayer \(test_nn\.TestNN\)'
buck test mode/dev caffe2/test:nn -- 'test_Transformer_cell \(test_nn\.TestNN\)'
buck test mode/dev caffe2/test:nn -- 'test_transformer_args_check \(test_nn\.TestNN\)'
Differential Revision: D16112263
fbshipit-source-id: c56e9ff409f6666253cfc9b1d23656981e6729d1
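As a concrete illustration of the behavior described in the summary above (not part of the PR itself), the sketch below builds a boolean (N, S) key_padding_mask for nn.MultiheadAttention and checks that attention weights on masked key positions come out as zero. The sizes are toy values, and a recent PyTorch build is assumed so that a bool mask is accepted.

```python
import torch
import torch.nn as nn

N, S, E, H = 2, 5, 16, 4                     # batch, sequence length, embed dim, heads
mha = nn.MultiheadAttention(embed_dim=E, num_heads=H)

query = torch.rand(S, N, E)                  # (S, N, E), sequence-first layout
key = value = torch.rand(S, N, E)

# (N, S) boolean mask: the second sample's last two key positions are padding.
key_padding_mask = torch.tensor([
    [False, False, False, False, False],
    [False, False, False, True,  True ],
])

out, attn_weights = mha(query, key, value, key_padding_mask=key_padding_mask)

# Because the masked scores are set to float('-inf') before the softmax,
# the attention weights on padded key positions are exactly zero, so the
# outputs for sample 1 never depend on its padded inputs.
print(attn_weights[1, :, 3:].abs().max().item())    # 0.0
```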
Force-pushed from 958dfa3 to 2cd03ec
@lucasgadams merged this pull request in c6fe864.