Update MultiheadAttention documentations #20071
Conversation
|
Part of this PR should address your comments. Thanks, @tshrjn. |
|
Fixes #20023 |
|
This doesn't actually fully convert the docstrings to the necessary form. Most PyTorch docstrings are standardized to display the input sizes in parentheses using a particular naming convention (i.e. N, C, etc.). Names of sizes that fall outside the standardized naming convention are usually discussed in the description of the input. The docstrings also sometimes include the type of the input, i.e. Tensor, etc. This docstring is also very light on detail about what each piece means; it should probably tell the user what a key, query, and value are, without the user needing to guess which math symbols correspond to which input and without rereading "Attention Is All You Need" to divine the naming choices. |
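As a rough sketch of the convention described above (shape in parentheses, the standard letters such as N for batch size, and the tensor type spelled out), a docstring fragment could look something like the following. The wording and the parameter list are illustrative assumptions, not the text merged in this PR.

```python
# Illustrative only: a docstring fragment in the standardized PyTorch style.
# The wording below is a sketch, not the docstring that was merged.
def forward(self, query, key, value, key_padding_mask=None, attn_mask=None):
    r"""
    Args:
        query (Tensor): the sequence the attention output is computed for.
            Shape: :math:`(L, N, E)`, where L is the query sequence length,
            N is the batch size, and E is the embedding dimension.
        key (Tensor): the sequence the query is matched against.
            Shape: :math:`(S, N, E)`, where S is the key sequence length.
        value (Tensor): the sequence averaged by the attention weights.
            Shape: :math:`(S, N, E)`.
        key_padding_mask (ByteTensor, optional): marks padded key positions
            that the attention should ignore. Shape: :math:`(N, S)`.
        attn_mask (Tensor, optional): prevents attention to certain
            positions. Shape: :math:`(L, L)`.
    """
```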
torch/nn/modules/activation.py
Outdated
| key: [sequence length, batch size, embed dim]. | ||
| value: [sequence length, batch size, embed dim]. | ||
| key_padding_mask: if True, mask padding based on batch size. | ||
| incremental_state: if provided, previous time steps are cashed. |
| incremental_state: if provided, previous time steps are cashed. | |
| incremental_state: if provided, previous time steps are cached. |
Also, what does this actually accomplish, and why is it important?
I think incremental_state is useful at decoding time, i.e. when we are generating output incrementally, since we don't want to recompute most of the operations from the previous steps.
static_kv is also somehow related to that.
But @zhangguanheng66 or someone else can answer these questions in more detail.
Thanks @tshrjn. incremental_state is just a dictionary used for storing previous states. See facebookresearch/fairseq#166 for more details.
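For readers who have not seen this pattern, a minimal sketch of the kind of caching such a dictionary enables during step-by-step decoding. The dictionary key, helper name, and sizes below are made up for illustration; this is not fairseq's or PyTorch's actual implementation.

```python
import torch

def append_to_cache(new_kv, cache):
    # new_kv: (1, N, E), the projection of only the newest decoded token.
    if "prev_key" in cache:
        # Reuse everything computed at earlier steps and append one new row,
        # instead of re-projecting the whole prefix at every step.
        cache["prev_key"] = torch.cat([cache["prev_key"], new_kv], dim=0)
    else:
        cache["prev_key"] = new_kv
    return cache["prev_key"]  # full key sequence seen so far: (t, N, E)

incremental_state = {}
step1 = append_to_cache(torch.randn(1, 2, 8), incremental_state)  # (1, 2, 8)
step2 = append_to_cache(torch.randn(1, 2, 8), incremental_state)  # (2, 2, 8)
```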
I would suggest adding that information about what the incremental state does (and why it is useful) to the documentation.
static_kv defaults to False. Unless users are familiar with this function, they will most likely not touch it. I want to keep the docs as short as possible.
Then add such a warning to the function (or remove the functionality entirely). Short documentation does not equal concise and clear documentation.
torch/nn/modules/activation.py
Outdated
| incremental_state: if provided, previous time steps are cashed. | ||
| need_weights: output attn_output_weights. | ||
| static_kv: key and value are static. | ||
| attn_mask: mask to avoid attn to learn from certain positions. |
| attn_mask: mask to avoid attn to learn from certain positions. | |
| attn_mask: mask that prevents attention from attending to certain positions. |
Also, the size and type of this object should be documented using the standard parenthesis-based format.
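For example, attn_mask here is an (L, L) tensor over query positions. The snippet below builds a causal mask with torch.triu as one common instance; the additive convention (0 where attention is allowed, -inf where it is blocked) is an assumption of this sketch, not something specified in this PR.

```python
import torch

L = 5  # query sequence length
# Causal (L, L) mask: position i may attend only to positions <= i.
# 0.0 marks allowed positions; -inf marks blocked ones (additive convention).
attn_mask = torch.triu(torch.full((L, L), float("-inf")), diagonal=1)
print(attn_mask)
```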
torch/nn/modules/activation.py
Outdated
| query: [target length, batch size, embed dim]. | ||
| key: [sequence length, batch size, embed dim]. | ||
| value: [sequence length, batch size, embed dim]. | ||
| key_padding_mask: if True, mask padding based on batch size. |
If false, what is the mask padding size based on?
See the updates in the latest commit.
I made some updates to the docs of the forward functions. Let me know if you have any other questions. Thanks for reviewing. |
torch/nn/modules/activation.py
Outdated
| Args: | ||
| query, key, value: map a query and a set of key-value pairs to an output. | ||
| See "Attention Is All You Need" for more details. | ||
| key_padding_mask: if provided, padding elements can be excluded from by |
Excluded from what?
I.e. put the context of what it is excluded from in the documentation.
Padding elements can be excluded from the key by passing a binary ByteTensor
Maybe a clearer meaning would be to say that the mask causes the attention to ignore padding elements.
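To make "the attention ignores padding elements" concrete, here is a small sketch of building an (N, S) key_padding_mask from per-example lengths. The lengths, sizes, and the nonzero-means-ignore convention are assumptions for illustration only.

```python
import torch

lengths = torch.tensor([3, 5, 2])         # true lengths of 3 sequences
S = 5                                     # padded key sequence length
positions = torch.arange(S).unsqueeze(0)  # (1, S)
# Entries set to 1 mark padded key positions the attention should skip.
key_padding_mask = (positions >= lengths.unsqueeze(1)).to(torch.uint8)
print(key_padding_mask)
# tensor([[0, 0, 0, 1, 1],
#         [0, 0, 0, 0, 0],
#         [0, 0, 1, 1, 1]], dtype=torch.uint8)
```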
See the updates in the latest commit.
torch/nn/modules/activation.py
Outdated
| incremental_state: if provided, previous time steps are cached. | ||
| need_weights: output attn_output_weights. | ||
| static_kv: key and value are static. | ||
| attn_mask: mask that prevents attention from certain positions. |
| attn_mask: mask that prevents attention from certain positions. | |
| attn_mask: mask that prevents attention to certain positions. |
torch/nn/modules/activation.py
Outdated
| Shape: | ||
| - Inputs: | ||
| - query: :math:`(T, B, E)` where T is target length, B is batch size, E is |
| - query: :math:`(T, B, E)` where T is target length, B is batch size, E is | |
| - query: :math:`(L, N, E)` where L is the sequence length of the query, N is the batch size, E is the |
N is the standard notation in the PyTorch documentation for the batch size; see https://pytorch.org/docs/stable/nn.html#recurrent-layers. Also, I would use L for the sequence length of the query, as the word 'target' is somewhat ambiguous in this context: it is not necessarily clear why you are referring to the query as a target.
For self-attention, the query, key, and value are the same. However, this is not necessarily true for multi-head attention in general. In a sequence-to-sequence transformer, the query and key may have different lengths. For example, the query fed to the decoder of a transformer may have a different length from the query fed to the encoder.
This change should also be applied to the shapes listed below.
I am not suggesting that you force them to have the same name and integer value. I just think target and T are ambiguous choices for this name. Instead, you can use L for the sequence length of the query and S for the sequence length of the value and key (similar to the choices made in the nn.RNN documentation).
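To tie the letters together, a quick sketch that feeds a query of length L and a key/value of length S through nn.MultiheadAttention and checks the output shapes against this notation; the concrete sizes are arbitrary and chosen only for illustration.

```python
import torch
import torch.nn as nn

E, num_heads = 16, 4          # embedding dimension and number of heads
L, S, N = 7, 9, 2             # query length, key/value length, batch size

mha = nn.MultiheadAttention(embed_dim=E, num_heads=num_heads)
query = torch.randn(L, N, E)  # (L, N, E)
key = torch.randn(S, N, E)    # (S, N, E)
value = torch.randn(S, N, E)  # (S, N, E)

attn_output, attn_weights = mha(query, key, value)
print(attn_output.shape)      # torch.Size([7, 2, 16]) -> (L, N, E)
print(attn_weights.shape)     # torch.Size([2, 7, 9])  -> (N, L, S)
```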
torch/nn/modules/activation.py
Outdated
| meaning ignoring during the attention operation. | |
| which means that it will be ignored during the attention operation. |
torch/nn/modules/activation.py
Outdated
| key_padding_mask: if provided, padding elements can be excluded from the key, | |
| key_padding_mask (ByteTensor): if provided, specified padding elements in the key will be ignored by the attention. |
and then delete the rest.
torch/nn/modules/activation.py
Outdated
| from previous states. | |
348f3bb to 206aa0a
torch/nn/modules/activation.py
Outdated
| embed_dim: total dimension of the model | ||
| num_heads: parallel attention layers, or heads | ||
| embed_dim: total dimension of the model. | ||
| num_heads: parallel attention layers, or heads. |
| num_heads: parallel attention layers, or heads. | |
| num_heads: Number of parallel attention layers (aka heads). |
torch/nn/modules/activation.py
Outdated
| Shape: | ||
| - Inputs: | ||
| - query: :math:`(L, N, E)` where L is target length, N is batch size, E is |
| - query: :math:`(L, N, E)` where L is target length, N is batch size, E is | |
| - query: :math:`(L, N, E)` where L is the query sequence length, N is the batch size, E is the |
Target is still an ambiguous term.
torch/nn/modules/activation.py
Outdated
| - query: :math:`(L, N, E)` where L is target length, N is batch size, E is | ||
| embedding dimension. | ||
| - key: :math:`(S, N, E)`, ByteTensor, where S is sequence length, N is batch size, E is |
| - key: :math:`(S, N, E)`, ByteTensor, where S is sequence length, N is batch size, E is | |
| - key: :math:`(S, N, E)`, ByteTensor, where S is the sequence length, N is the batch size, E is an |
torch/nn/modules/activation.py
Outdated
| embedding dimension. | ||
| - key: :math:`(S, N, E)`, ByteTensor, where S is sequence length, N is batch size, E is | ||
| embedding dimension. | ||
| - value: :math:`(S, N, E)` where S is sequence length, N is batch size, E is |
| - value: :math:`(S, N, E)` where S is sequence length, N is batch size, E is | |
| - value: :math:`(S, N, E)` where S is the key sequence length, N is the batch size, E is the |
torch/nn/modules/activation.py
Outdated
| embedding dimension. | ||
| - value: :math:`(S, N, E)` where S is sequence length, N is batch size, E is | ||
| embedding dimension. | ||
| - key_padding_mask: :math:`(N, S)` where N is batch size, S is sequence length. |
| - key_padding_mask: :math:`(N, S)` where N is batch size, S is sequence length. | |
| - key_padding_mask: :math:`(N, S)` where N is the batch size and S is the key sequence length. |
torch/nn/modules/activation.py
Outdated
| embedding dimension. | ||
| - key_padding_mask: :math:`(N, S)` where N is batch size, S is sequence length. | ||
| - incremental_state: a dictionary used for storing states. | ||
| - attn_mask: :math:`(L, L)` where L is target length. |
| - attn_mask: :math:`(L, L)` where L is target length. | |
| - attn_mask: :math:`(L, L)` where L is the query sequence length. |
torch/nn/modules/activation.py
Outdated
| - Outputs: | ||
| - attn_output: :math:`(L, N, E)` where L is target length, N is batch size, |
| - attn_output: :math:`(L, N, E)` where L is target length, N is batch size, | |
| - attn_output: :math:`(L, N, E)` where L is the query sequence length and N is the batch size, |
torch/nn/modules/activation.py
Outdated
| - attn_output: :math:`(L, N, E)` where L is target length, N is batch size, | ||
| E is embedding dimension. | ||
| - attn_output_weights: :math:`(N, L, S)` where N is batch size, |
| - attn_output_weights: :math:`(N, L, S)` where N is batch size, | |
| - attn_output_weights: :math:`(N, L, S)` where N is the batch size, |
torch/nn/modules/activation.py
Outdated
| - attn_output: :math:`(L, N, E)` where L is target length, N is batch size, | ||
| E is embedding dimension. | ||
| - attn_output_weights: :math:`(N, L, S)` where N is batch size, | ||
| L is target length, S is sequence length. |
| L is target length, S is sequence length. | |
| L is the query sequence length, and S is the key sequence length. |
I would suggest keeping the "target" sequence and "source" sequence, since they are not treated the same in the attention layer. In the attention layer, the query, as the target, is mapped against the key and value sequences.
|
@zhangguanheng66 If you actually referred to one as the source sequence length and the other as the target sequence length, then it might make some sense to use the word target. On the other hand, changing the terminology to source and target is not logically consistent, as you already have those arrays named key and query. Being consistent with naming schemes is important: if a quantity is named query, then it is much clearer to refer to its sequence length as the "query sequence length" instead of the vague and ambiguous "target length". Similarly, if there are two sequence lengths, you should refer to the key-value sequence length as the "key-value sequence length" and not "sequence length", which could reference either of the two sequences. As currently written, it is not unambiguously clear that "target length" even refers to the length of a sequence at all. |
|
Also, you could add all the proper articles (like "a" and "the") in their proper places in the docstrings. I didn't spend 5 minutes correcting your grammar just to have you delete the corrections because you really like the word "target". |
I will update this in the next commit. |
|
@rmcavoy, thanks for reviewing the PR. If you have further comments, feel free to create a new PR and put the proposed work there. I think that is more efficient, and I'm happy to take a look. |
facebook-github-bot left a comment
@zhangguanheng66 is landing this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
|
@zhangguanheng66 merged this pull request in fc00bfd. |
Add documentation to add_bias_kv, add_zero_attn, and attn_mask.