
Conversation

@zhangguanheng66
Contributor

Add documentation for add_bias_kv, add_zero_attn, and attn_mask.

@pytorchbot added the module: nn (Related to torch.nn) label on May 2, 2019
@zhangguanheng66 zhangguanheng66 requested a review from soumith May 2, 2019 17:26
@zhangguanheng66
Contributor Author

Part of this PR should address your comments. Thanks. @tshrjn

@zhangguanheng66
Contributor Author

Fixes #20023

@rmcavoy

rmcavoy commented May 3, 2019

This doesn't actually fully convert the docstrings to the necessary form. Most of the PyTorch docstrings are standardized to display the input sizes in parentheses using a particular naming convention (i.e. N, C, etc.). Names of sizes that fall outside the standardized naming convention are usually discussed in the description of the input. The docstrings also sometimes include the type of the input, i.e. Tensor, etc. This docstring is also very light on detail about what each piece means; it should probably tell the user what a Key, Query, and Value are without the user needing to guess which math symbols correspond to which input and without rereading "Attention Is All You Need" to divine naming choices.

key: [sequence length, batch size, embed dim].
value: [sequence length, batch size, embed dim].
key_padding_mask: if True, mask padding based on batch size.
incremental_state: if provided, previous time steps are cashed.

Suggested change
incremental_state: if provided, previous time steps are cashed.
incremental_state: if provided, previous time steps are cached.


Also, what does this actually accomplish, and why is it important?

@tshrjn tshrjn May 3, 2019


I think incremental_state is useful at decoding time, i.e. when we are generating output incrementally, since we don't want to recompute most of the operations from the last step.
static_kv is also somehow related to that.

But @zhangguanheng66 or someone else can answer these questions in more detail.

Contributor Author


Thanks @tshrjn. incremental_state is just a dictionary used for storing previous states. See facebookresearch/fairseq#166 for more details.
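For later readers, here is a minimal sketch of the caching idea; it is not the actual fairseq or PyTorch implementation, and the dictionary layout and the append_to_cache helper below are hypothetical:

import torch

# Hypothetical cache layout: incremental_state maps a layer key to the
# key/value projections computed at all previous time steps, so that at
# step t only the newest step has to be projected.
def append_to_cache(incremental_state, layer_key, new_k, new_v):
    cached = incremental_state.get(layer_key)
    if cached is not None:
        new_k = torch.cat([cached["prev_key"], new_k], dim=0)
        new_v = torch.cat([cached["prev_value"], new_v], dim=0)
    incremental_state[layer_key] = {"prev_key": new_k, "prev_value": new_v}
    return new_k, new_v                 # (t, N, E): history up to the current step

incremental_state = {}
N, E = 2, 8
for t in range(3):
    k_t = torch.randn(1, N, E)          # projection of the newest token only
    v_t = torch.randn(1, N, E)
    k, v = append_to_cache(incremental_state, "decoder.layers.0.self_attn", k_t, v_t)
    print(k.shape)                      # (1, 2, 8), then (2, 2, 8), then (3, 2, 8)

At each step only the newest token's key/value projection has to be computed; everything earlier comes out of the cache, which is why this matters for incremental generation.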


I would suggest adding that information about what the incremental state does (and why it is useful) to the documentation.

Contributor Author


static_kv defaults to False. Unless users are familiar with this function, they will most likely not touch it. I want to keep it as short as possible.


Then add such a warning to the function (or remove the functionality entirely). Short documentation does not equal concise and clear documentation.

incremental_state: if provided, previous time steps are cashed.
need_weights: output attn_output_weights.
static_kv: key and value are static.
attn_mask: mask to avoid attn to learn from certain positions.

Suggested change
attn_mask: mask to avoid attn to learn from certain positions.
attn_mask: mask that prevents attention from attending to certain positions.

Also, the size and type of this object should be added using the standard parenthesis-based format.
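To illustrate a concrete size and type, here is a minimal sketch of a square attn_mask, assuming an additive float mask in which masked positions hold -inf and allowed positions hold 0 (the accepted dtypes may differ across PyTorch versions):

import torch

L = 5                                   # query (target) sequence length
# Causal mask of shape (L, L): position i may only attend to positions <= i.
# Masked entries are -inf, allowed entries are 0; the mask is added to the
# attention scores before the softmax.
attn_mask = torch.triu(torch.full((L, L), float('-inf')), diagonal=1)
print(attn_mask)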

query: [target length, batch size, embed dim].
key: [sequence length, batch size, embed dim].
value: [sequence length, batch size, embed dim].
key_padding_mask: if True, mask padding based on batch size.

If false, what is the mask padding size based on?

Contributor Author


See the updates in the latest commit.

@zhangguanheng66
Contributor Author

zhangguanheng66 commented May 3, 2019

This doesn't actually fully convert the docstrings to the necessary form. Most of the PyTorch docstrings are standardized to display the input sizes in parentheses using a particular naming convention (i.e. N, C, etc.). Names of sizes that fall outside the standardized naming convention are usually discussed in the description of the input. The docstrings also sometimes include the type of the input, i.e. Tensor, etc. This docstring is also very light on detail about what each piece means; it should probably tell the user what a Key, Query, and Value are without the user needing to guess which math symbols correspond to which input and without rereading "Attention Is All You Need" to divine naming choices.

I made some updates on the docs of the forward functions. Let me know if you have any other questions. Thanks for reviewing.

Args:
query, key, value: map a query and a set of key-value pairs to an output.
See "Attention Is All You Need" for more details.
key_padding_mask: if provided, padding elements can be excluded from by

Excluded from what?


I.e., put the context of what it is excluded from in the documentation.

Contributor Author


Padding elements can be excluded from the key by passing a binary ByteTensor.


Maybe a clearer meaning would be to say that the mask causes the attention to ignore padding elements.
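A minimal usage sketch of that behaviour, assuming the current torch.nn.MultiheadAttention interface in which key_padding_mask is an (N, S) boolean tensor and True marks a padded key position to ignore (the code under review used a ByteTensor):

import torch
import torch.nn as nn

S, N, E = 4, 2, 8                       # sequence length, batch size, embed dim
attn = nn.MultiheadAttention(embed_dim=E, num_heads=2)

x = torch.randn(S, N, E)                # self-attention: query == key == value

# (N, S) mask: True marks padded key positions that attention should ignore.
# Here the second sequence in the batch has its last two positions padded.
key_padding_mask = torch.tensor([[False, False, False, False],
                                 [False, False, True,  True]])

out, weights = attn(x, x, x, key_padding_mask=key_padding_mask)
print(weights[1, :, 2:])                # all zeros: padded keys receive no attention weight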

Contributor Author


See the updates in the latest commit.

incremental_state: if provided, previous time steps are cached.
need_weights: output attn_output_weights.
static_kv: key and value are static.
attn_mask: mask that prevents attention from certain positions.

Suggested change
attn_mask: mask that prevents attention from certain positions.
attn_mask: mask that prevents attention to certain positions.

Shape:
- Inputs:
- query: :math:`(T, B, E)` where T is target length, B is batch size, E is

@rmcavoy rmcavoy May 3, 2019


Suggested change
- query: :math:`(T, B, E)` where T is target length, B is batch size, E is
- query: :math:`(L, N, E)` where L is the sequence length of the query, N is the batch size, E is the

N is the standard notation in PyTorch documentation for the batch size; see https://pytorch.org/docs/stable/nn.html#recurrent-layers. Also, I would use L for the sequence length of the query, as the word 'target' is somewhat ambiguous in this context: it is not necessarily clear why you are referring to the query as a target.

Contributor Author


For self-attention, the query, key, and value are the same. However, this is not necessarily true for multi-head attention in general. In a sequence-to-sequence problem, the query and key may have different lengths. For example, the query fed to the decoder of a transformer may have a different length from the query fed to the encoder.
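A small sketch of that encoder-decoder case, assuming the public nn.MultiheadAttention call convention, in which the query length L and the key/value length S differ:

import torch
import torch.nn as nn

L, S, N, E = 3, 7, 2, 8                 # query length, key/value length, batch size, embed dim
attn = nn.MultiheadAttention(embed_dim=E, num_heads=4)

query = torch.randn(L, N, E)            # e.g. decoder states
key = torch.randn(S, N, E)              # e.g. encoder outputs
value = torch.randn(S, N, E)

attn_output, attn_weights = attn(query, key, value)
print(attn_output.shape)                # torch.Size([3, 2, 8])  -> (L, N, E)
print(attn_weights.shape)               # torch.Size([2, 3, 7])  -> (N, L, S)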


This change should also be applied to the shapes listed below.


I am not suggesting that you force them to have the same name and integer value. I just think "target" and T are ambiguous choices for this name. Instead, you can use L for the sequence length of the query and S for the sequence length of the value and key (similar to the choices made in the documentation of nn.RNN).


Suggested change
meaning ignoring during the attention operation.
which means that it will be ignored during the attention operation.


@rmcavoy rmcavoy May 3, 2019


Suggested change
key_padding_mask: if provided, padding elements can be excluded from the key,
key_padding_mask (ByteTensor): if provided, specified padding elements in the key will be ignored by the attention.

and then delete the rest.


Suggested change
from previous states.

@zhangguanheng66 force-pushed the modify_multihead_attn branch from 348f3bb to 206aa0a on May 3, 2019 15:51
embed_dim: total dimension of the model
num_heads: parallel attention layers, or heads
embed_dim: total dimension of the model.
num_heads: parallel attention layers, or heads.

@rmcavoy rmcavoy May 3, 2019


Suggested change
num_heads: parallel attention layers, or heads.
num_heads: Number of parallel attention layers (aka heads).
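For context, a minimal sketch of these two constructor arguments, assuming the public nn.MultiheadAttention constructor; embed_dim must be divisible by num_heads so that each head works on embed_dim // num_heads dimensions:

import torch.nn as nn

embed_dim, num_heads = 512, 8           # each head attends over 512 // 8 = 64 dimensions
attn = nn.MultiheadAttention(embed_dim, num_heads)
# nn.MultiheadAttention(512, 7) would fail: embed_dim must be divisible by num_heads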

Shape:
- Inputs:
- query: :math:`(L, N, E)` where L is target length, N is batch size, E is

@rmcavoy rmcavoy May 3, 2019


Suggested change
- query: :math:`(L, N, E)` where L is target length, N is batch size, E is
- query: :math:`(L, N, E)` where L is the query sequence length, N is the batch size, E is the

Target is still an ambiguous term.

- query: :math:`(L, N, E)` where L is target length, N is batch size, E is
embedding dimension.
- key: :math:`(S, N, E)`, ByteTensor, where S is sequence length, N is batch size, E is

Suggested change
- key: :math:`(S, N, E)`, ByteTensor, where S is sequence length, N is batch size, E is
- key: :math:`(S, N, E)`, ByteTensor, where S is the sequence length, N is the batch size, E is an

embedding dimension.
- key: :math:`(S, N, E)`, ByteTensor, where S is sequence length, N is batch size, E is
embedding dimension.
- value: :math:`(S, N, E)` where S is sequence length, N is batch size, E is

@rmcavoy rmcavoy May 3, 2019


Suggested change
- value: :math:`(S, N, E)` where S is sequence length, N is batch size, E is
- value: :math:`(S, N, E)` where S is the key sequence length, N is the batch size, E is the

embedding dimension.
- value: :math:`(S, N, E)` where S is sequence length, N is batch size, E is
embedding dimension.
- key_padding_mask: :math:`(N, S)` where N is batch size, S is sequence length.

@rmcavoy rmcavoy May 3, 2019


Suggested change
- key_padding_mask: :math:`(N, S)` where N is batch size, S is sequence length.
- key_padding_mask: :math:`(N, S)` where N is the batch size and S is the key sequence length.

embedding dimension.
- key_padding_mask: :math:`(N, S)` where N is batch size, S is sequence length.
- incremental_state: a dictionary used for storing states.
- attn_mask: :math:`(L, L)` where L is target length.

Suggested change
- attn_mask: :math:`(L, L)` where L is target length.
- attn_mask: :math:`(L, L)` where L is the query sequence length.

- Outputs:
- attn_output: :math:`(L, N, E)` where L is target length, N is batch size,

Suggested change
- attn_output: :math:`(L, N, E)` where L is target length, N is batch size,
- attn_output: :math:`(L, N, E)` where L is the query sequence length and N is the batch size,

- attn_output: :math:`(L, N, E)` where L is target length, N is batch size,
E is embedding dimension.
- attn_output_weights: :math:`(N, L, S)` where N is batch size,

Suggested change
- attn_output_weights: :math:`(N, L, S)` where N is batch size,
- attn_output_weights: :math:`(N, L, S)` where N is the batch size,

- attn_output: :math:`(L, N, E)` where L is target length, N is batch size,
E is embedding dimension.
- attn_output_weights: :math:`(N, L, S)` where N is batch size,
L is target length, S is sequence length.

Suggested change
L is target length, S is sequence length.
L is the query sequence length, and S is the key sequence length.

Contributor Author


I would suggest keeping the "target" sequence and "source" sequence, since they are not treated the same in the attention layer. In the attention layer, the query, as the target, is mapped with the key and value sequences.

@rmcavoy

rmcavoy commented May 3, 2019

@zhangguanheng66 If you actually referred to one as the source sequence length and the other as the target sequence length, then it might make some sense to use the word target. On the other hand, changing the terminology to source and target is not logically consistent, as you already have those arrays named key and query. Being consistent with naming schemes is important: if a quantity is named query, then it is much clearer to refer to its sequence length as the "query sequence length" instead of the vague and ambiguous "target length", and similarly, if there are two sequence lengths, you should refer to the key-value sequence length as the "key-value sequence length" and not just "sequence length", which could reference either one of the two sequences. As written, it is not unambiguously clear that "target length" even refers to the length of a sequence at all.

@rmcavoy

rmcavoy commented May 3, 2019

Also, could you add all the proper articles (like "a" and "the") in their proper places in the docstrings? I didn't spend 5 minutes correcting your grammar just to have you delete the corrections because you really like the word target.

@zhangguanheng66
Contributor Author

Also, could you add all the proper articles (like "a" and "the") in their proper places in the docstrings? I didn't spend 5 minutes correcting your grammar just to have you delete the corrections because you really like the word target.

I will update this in the next commit.

@zhangguanheng66
Contributor Author

@rmcavoy, thanks for reviewing the PR. If you have further comments, feel free to create a new PR and put the proposed work there. I think that is more efficient, and I'm happy to take a look.

Contributor

@facebook-github-bot facebook-github-bot left a comment


@zhangguanheng66 is landing this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@facebook-github-bot
Contributor

@zhangguanheng66 merged this pull request in fc00bfd.

@zhangguanheng66 zhangguanheng66 deleted the modify_multihead_attn branch May 6, 2019 15:00