
Conversation

@lucasgadams (Contributor) commented Jul 8, 2019

Motivation:
The forward method of MultiheadAttention has a key_padding_mask kwarg. The mask has shape (N, S), where N is the batch size and S is the sequence length, and it is applied before the attention softmax: positions where the mask is True have their attention scores set to float('-inf'). This lets you mask key position j from attention for every query position i in the sequence, and it is typically used to mask padded inputs, so that for each sample in a batch no encoder output depends on padding positions. Currently Transformer, TransformerEncoder, and TransformerEncoderLayer do not have this kwarg; they only accept (S,S), (T,T), and (T,S) masks that are applied identically across the batch to the source input, target output, and target-source memory attention, respectively. Those masks cannot express per-sample padding and are instead used for things like subsequent-position masking in language modeling, where attention from position i to position j is blocked.
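
As a rough sketch of how such a mask is built and passed to MultiheadAttention (shapes, sizes, and values below are arbitrary and only for illustration, not part of this PR):

```python
import torch
import torch.nn as nn

# Arbitrary sizes for illustration: batch N=2, sequence length S=5, embedding dim E=8.
N, S, E = 2, 5, 8
attn = nn.MultiheadAttention(embed_dim=E, num_heads=2)

x = torch.randn(S, N, E)        # MultiheadAttention expects (S, N, E) inputs
lengths = torch.tensor([5, 3])  # the second sequence has 2 padded positions

# key_padding_mask is (N, S); True marks padded key positions to be ignored.
key_padding_mask = torch.arange(S)[None, :] >= lengths[:, None]

out, weights = attn(x, x, x, key_padding_mask=key_padding_mask)
# After the softmax, the attention weights on the masked key positions are zero.
```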

This diff exposes key_padding_mask on the forward methods of Transformer, TransformerEncoder, and TransformerEncoderLayer, and it is ultimately passed through to MultiheadAttention's forward.

Open question: should we also allow a key_padding_mask for the decoder layer? Since padding usually sits at the end of each sentence in a batch and sentences are usually decoded left to right, people typically handle padding on decoder outputs by masking those positions at the loss layer. There may be scenarios where a decoder padding mask is needed, but I don't think they would be common, and people can always subclass and override the layers. We could also pass the input key_padding_mask to the memory <> decoder attention layer. I'm not sure that's necessary, though: the output at position i of each encoder attention layer won't depend on any masked positions in the input (even if position i is itself a masked position), so there isn't much point in masking position i again.

Summary:
Adds the key_padding_mask kwarg to Transformer, TransformerEncoder, and TransformerEncoderLayer forward methods.
The standard TransformerEncoderLayer uses a MultiheadAttention layer as self_attn. MultiheadAttention's forward method has a key_padding_mask kwarg that masks values such as padding on a per-sequence basis within a batch, in contrast to the attn_mask kwarg, which is usually of shape (S,S) and applied identically across the batch.

MultiheadAttention calls functional.multi_head_attention_forward, which takes the same key_padding_mask kwarg of shape (N,S); masked (True) positions are set to float('-inf') before the softmax.
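
As a usage sketch of the encoder-side kwarg this PR exposes (the example assumes the src_key_padding_mask name that the change ends up using; sizes are arbitrary):

```python
import torch
import torch.nn as nn

S, N, E = 10, 4, 32
encoder_layer = nn.TransformerEncoderLayer(d_model=E, nhead=4)
encoder = nn.TransformerEncoder(encoder_layer, num_layers=2)

src = torch.randn(S, N, E)
lengths = torch.tensor([10, 7, 9, 4])
# (N, S) boolean mask; True marks padded positions that no query should attend to.
src_key_padding_mask = torch.arange(S)[None, :] >= lengths[:, None]

memory = encoder(src, src_key_padding_mask=src_key_padding_mask)
# Encoder outputs at unpadded positions no longer depend on the padded inputs.
```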

Differential Revision: D16112263

@pytorchbot added the "module: nn (Related to torch.nn)" label Jul 8, 2019
@zhangguanheng66 (Contributor)

fix #22374 (comment)

@zhangguanheng66 (Contributor)

@sebamenabar I think this PR is relevant to your feature request. Feel free to add review comments. Thanks.

@zhangguanheng66 (Contributor)

@stephenroller @myleott @ngimel @DNGros @mttk here is a PR to add key_padding_mask to transformer_encoder. Just wondering whether we should do the same thing for transformer_decoder.

@zhangguanheng66 (Contributor)

@lucasgadams please update the text with motivation for the PR.

@zhangguanheng66 (Contributor) left a comment

After some discussion, I feel it makes sense to have key_padding_mask in the decoder as well (at least for the self-attention layer). We prefer to keep the API generic because we can't predict how it will be applied on the user side.

@sebamenabar

> After some discussion, I feel it makes sense to have key_padding_mask in the decoder as well (at least for the self-attention layer). We prefer to keep the API generic because we can't predict how it will be applied on the user side.

Agreed. I think it also makes sense to add a key_padding_mask to the decoder's multi-head attention over the encoder output, since we usually have padding on both the source and the target.

@lucasgadams (Contributor, Author)

OK, I have added the key_padding_mask kwarg to the Decoder and DecoderLayer, which pass it to the target-memory attention. One more question: do we want Transformer.forward to automatically pass the src key_padding_mask to both the encoder and the decoder? That would force people to use it on both if they use it on either in Transformer. I think that is probably fine, and it is what I have done in the latest diff, but let me know if you think we shouldn't do that automatically in Transformer. I will add some test cases for the decoder piece.

@zhangguanheng66 (Contributor)

> OK, I have added the key_padding_mask kwarg to the Decoder and DecoderLayer, which pass it to the target-memory attention. One more question: do we want Transformer.forward to automatically pass the src key_padding_mask to both the encoder and the decoder? That would force people to use it on both if they use it on either in Transformer. I think that is probably fine, and it is what I have done in the latest diff, but let me know if you think we shouldn't do that automatically in Transformer. I will add some test cases for the decoder piece.

See my comments. I don't think we want to pass a single key_padding_mask to both the encoder and the decoder. Thanks.

@lucasgadams (Contributor, Author)

I added the requested changes. Transformer.forward now takes src_key_padding_mask, tgt_key_padding_mask, and memory_key_padding_mask. The decoder takes key_padding_mask and memory_key_padding_mask. I modified the relevant tests and added some new deterministic ones for the decoder masks.
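
A quick sketch of passing the full set of masks to Transformer.forward after this change (kwarg names as described above; sizes are arbitrary, and the causal tgt_mask is included only to show how the two kinds of mask combine):

```python
import torch
import torch.nn as nn

S, T, N, E = 12, 9, 4, 32
model = nn.Transformer(d_model=E, nhead=4, num_encoder_layers=2, num_decoder_layers=2)

src = torch.randn(S, N, E)
tgt = torch.randn(T, N, E)

src_lengths = torch.tensor([12, 8, 10, 5])
tgt_lengths = torch.tensor([9, 6, 9, 4])

# (N, S) and (N, T) boolean masks; True marks padded positions.
src_key_padding_mask = torch.arange(S)[None, :] >= src_lengths[:, None]
tgt_key_padding_mask = torch.arange(T)[None, :] >= tgt_lengths[:, None]
# The memory keys come from the source, so the source padding mask is reused.
memory_key_padding_mask = src_key_padding_mask

# Standard causal mask so target position i cannot attend to later positions.
tgt_mask = model.generate_square_subsequent_mask(T)

out = model(src, tgt,
            tgt_mask=tgt_mask,
            src_key_padding_mask=src_key_padding_mask,
            tgt_key_padding_mask=tgt_key_padding_mask,
            memory_key_padding_mask=memory_key_padding_mask)
# out has shape (T, N, E)
```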

@sebamenabar

> I added the requested changes. Transformer.forward now takes src_key_padding_mask, tgt_key_padding_mask, and memory_key_padding_mask. The decoder takes key_padding_mask and memory_key_padding_mask. I modified the relevant tests and added some new deterministic ones for the decoder masks.

I think everything is covered now. The argument list looks long, but in my opinion all of the options are standard transformer use cases.

@zhangguanheng66 self-requested a review July 11, 2019 19:14

@zhangguanheng66 (Contributor) left a comment

.

@facebook-github-bot (Contributor) left a comment

@lucasgadams has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@zhangguanheng66 (Contributor) commented Jul 16, 2019

@pytorchbot retest this please

@pytorchbot (Collaborator)

Sorry, only maintainers are authorized to rebase other people's PRs. Feel free to try again on one of your PRs!

(To learn more about this bot, see Bot commands.)

@lucasgadams (Contributor, Author)

@pytorchbot rebase this please

@facebook-github-bot (Contributor) left a comment

@lucasgadams is landing this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

Summary: Adds the key_padding_mask kwargs to the forward methods of Transformer, TransformerEncoder, TransformerEncoderLayer, and the decoder modules, as described in the PR text above.
Pull Request resolved: pytorch#22588

Test Plan:
buck test mode/dev caffe2/test:nn -- 'test_transformerencoderlayer \(test_nn\.TestNN\)'
buck test mode/dev caffe2/test:nn -- 'test_Transformer_cell \(test_nn\.TestNN\)'
buck test mode/dev caffe2/test:nn -- 'test_transformer_args_check \(test_nn\.TestNN\)'

Differential Revision: D16112263

fbshipit-source-id: c56e9ff409f6666253cfc9b1d23656981e6729d1
@facebook-github-bot (Contributor)

@lucasgadams merged this pull request in c6fe864.
