Update MultiheadAttention module to support key/value with different number of features and allow static key/value #21288
Conversation
torch/nn/modules/activation.py
Outdated
self.in_proj_weight = Parameter(torch.empty(3 * embed_dim, embed_dim))
self.q_proj_weight = Parameter(torch.Tensor(embed_dim, embed_dim))
self.k_proj_weight = Parameter(torch.Tensor(embed_dim, self.kdim))
When you load older models / older models' state dicts, the model will break / compute wrong results. What are you going to do about that?
You can do what batchnorm did: see https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/batchnorm.py#L16 and look at _version = 2, and https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/batchnorm.py#L90-L103
Thanks for pointing this out.
Added a _load_from_state_dict() function to detect in_proj_weight. If in_proj_weight exists in the state_dict, it is split into three separate weights. I assume users know how to use state_dict() and load_state_dict() to resolve this type of version conflict.
Tested on the word language model:
- Load the model saved by the old module, which has in_proj_weight: model = torch.load("old.pt")
- Generate the state_dict: state_dict = model.state_dict()
- Create a model based on the new module, which has three separate weights: new_model = new_module()
- Map the state onto the new model: new_model.load_state_dict(state_dict)
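For reference, a minimal sketch of such a _load_from_state_dict() hook, assuming the old checkpoint stacks the q/k/v projections along dim 0 of in_proj_weight (a sketch of the idea, not necessarily the merged code):

```python
def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                          missing_keys, unexpected_keys, error_msgs):
    # If an old-style fused weight is present, split it into the three
    # per-projection weights expected by the new module.
    key = prefix + 'in_proj_weight'
    if key in state_dict:
        w = state_dict.pop(key)
        q_w, k_w, v_w = w.chunk(3, dim=0)  # (3*E, E) -> three (E, E) slices
        state_dict[prefix + 'q_proj_weight'] = q_w
        state_dict[prefix + 'k_proj_weight'] = k_w
        state_dict[prefix + 'v_proj_weight'] = v_w
    super(MultiheadAttention, self)._load_from_state_dict(
        state_dict, prefix, local_metadata, strict,
        missing_keys, unexpected_keys, error_msgs)
```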
torch/nn/functional.py
Outdated
be ignored by the attention.
need_weights: output attn_output_weights.
attn_mask: mask that prevents attention to certain positions.
use_chunk_proj_weight: use in_proj_weight insteady of q_proj_weight, k_proj_weight, v_proj_weight.
Nit: "instead"
Also, this means that passing the flags for q, k, v weights is now the default? That's breaking backwards compatibility. Is that worth it?
It preserves backwards compatibility (use_chunk_proj_weight=True will use in_proj_weight, like before) but the name is confusing.
I have added more docs for use_chunk_proj_weight.
torch/nn/functional.py
Outdated
the embedding dimension.
- key_padding_mask: :math:`(N, S)`, ByteTensor, where N is the batch size, S is the source sequence length.
- attn_mask: :math:`(L, S)` where L is the target sequence length, S is the source sequence length.
- saved_k: :math:`(N*num_heads, S, E/num_heads)`, where S is the source sequence length,
Why is it called "saved"?
We can use "static_k, static_v" instead.
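For context, a small sketch of what these static tensors look like, based only on the shape documented above (hypothetical values; not the full functional call):

```python
import torch

N, S, E, num_heads = 4, 20, 64, 8   # batch, source length, embed dim, heads
head_dim = E // num_heads

# Pre-projected key/value in the documented (N*num_heads, S, E/num_heads) layout.
static_k = torch.randn(N * num_heads, S, head_dim)
static_v = torch.randn(N * num_heads, S, head_dim)
# Passing these as static_k/static_v lets the attention skip the key/value
# projections and reuse the same tensors across calls.
```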
fix #21518
facebook-github-bot
left a comment
@zhangguanheng66 has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
q_proj_weight=None,  # type: Optional[Tensor]
k_proj_weight=None,  # type: Optional[Tensor]
v_proj_weight=None,  # type: Optional[Tensor]
static_k=None,  # type: Optional[Tensor]
I think we had talked about maybe moving this into a separate function by splitting multi_head_attention_forward out into two, one part that does the projections and the other that consumes them. Users that want to do their own projections can then call the latter.
But it may not really help with the proj_weights layout issue. We still have to maintain the two layouts at the same time, right?
If you (or user) pass separate q_proj_weight, k_proj_weight, v_proj_weight the layout is pretty flexible, and 2 layouts are not really necessary. But this implies that (worst case) 3 gemms will be called instead of 1, which may or may not hurt performance.
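A small sketch of the tradeoff being discussed, assuming query/key/value share the same embedding dimension so both layouts apply (illustrative shapes only):

```python
import torch
import torch.nn.functional as F

L, N, E = 10, 4, 64                      # target length, batch, embed dim
query = torch.randn(L, N, E)
in_proj_weight = torch.randn(3 * E, E)   # fused layout: q/k/v stacked on dim 0

# One fused gemm, then chunk the output into q, k, v.
q1, k1, v1 = F.linear(query, in_proj_weight).chunk(3, dim=-1)

# Three separate gemms; each weight could have its own input dimension,
# at the cost of (worst case) three kernel launches instead of one.
q_w, k_w, v_w = in_proj_weight.chunk(3, dim=0)
q2, k2, v2 = F.linear(query, q_w), F.linear(query, k_w), F.linear(query, v_w)

assert torch.allclose(q1, q2) and torch.allclose(k1, k2) and torch.allclose(v1, v2)
```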
fwiw, we used 3 separate gemms in the mlperf submission, because then the .contiguous() calls here become no-ops:
pytorch/torch/nn/functional.py
Lines 3276 to 3280 in 6d1f0da
q = q.contiguous().view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1)
if k is not None:
    k = k.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1)
if v is not None:
    v = v.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1)
Alternatively, pointers and strides can be checked for the projection weights and, if they point to a contiguous matrix, a single gemm can be called.
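A hypothetical version of that check (not code from the PR): three weights are fusable into one gemm if they are contiguous slices of a single underlying matrix:

```python
import torch

def can_fuse_into_one_gemm(q_w, k_w, v_w):
    # All three must be contiguous and have the same input dimension.
    if not (q_w.is_contiguous() and k_w.is_contiguous() and v_w.is_contiguous()):
        return False
    if not (q_w.size(1) == k_w.size(1) == v_w.size(1)):
        return False
    # k must start exactly where q ends in memory, and v where k ends.
    return (q_w.data_ptr() + q_w.numel() * q_w.element_size() == k_w.data_ptr()
            and k_w.data_ptr() + k_w.numel() * k_w.element_size() == v_w.data_ptr())

w = torch.randn(3 * 64, 64)
q_w, k_w, v_w = w.chunk(3, dim=0)
print(can_fuse_into_one_gemm(q_w, k_w, v_w))  # True: views into one matrix
```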
V_fc = np.concatenate((V_fc, np.repeat(bias_v, V_fc.shape[0], axis=0)), axis=1)
attn_mask = np.concatenate((attn_mask, np.ones([1, 1])), axis=1)
if attn_mask is not None:
    attn_mask = np.concatenate((attn_mask, np.ones([1, 1])), axis=1)
What are these np functions everywhere and how will they work with cuda tensors?
Edit: sorry, nevermind, did not notice this was specifically for a test.
They are used to re-implement the multi-head attention function in numpy and verify the results of the nn.MultiheadAttention module.
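For illustration, the kind of numpy reference such a test relies on (a toy sketch; the real test also covers masks, bias_k/bias_v, and the projections):

```python
import numpy as np

def ref_attention(Q, K, V):
    # Q: (batch, L, head_dim); K, V: (batch, S, head_dim).
    scores = Q @ K.transpose(0, 2, 1) / np.sqrt(Q.shape[-1])
    scores -= scores.max(axis=-1, keepdims=True)   # stabilize softmax
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ V                             # (batch, L, head_dim)
```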
facebook-github-bot
left a comment
@cpuhrsch has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
facebook-github-bot
left a comment
@zhangguanheng66 has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
k, v = linear(key, _w, _b).chunk(2, dim=-1)

else:
    _b = in_proj_bias
Since this is getting quite complicated, it could be good to add some comments.
facebook-github-bot
left a comment
@zhangguanheng66 has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
@zhangguanheng66 merged this pull request in bb0f299.
…ber of features and allow static key/value (pytorch#21288)

Summary: The changes include:
1. Allow key/value to have a different number of features than query. It supports the case when key and value have different feature dimensions.
2. Support three separate proj_weights, in addition to a single in_proj_weight. The proj_weights of key and value may have different dimensions than that of query, so three separate proj_weights are necessary. In case key and value have the same dimension as query, it is preferred to use a single large proj_weight for performance reasons. However, it should be noted that using a single large weight or three separate weights is a size-dependent decision.
3. Give an option to use static k and v in the multihead_attn operator (see saved_k and saved_v). Those static key/value tensors can now be re-used when training the model.
4. Add more test cases to cover the arguments.

Note: current users should not be affected by the changes.

Pull Request resolved: pytorch#21288
Differential Revision: D15738808
Pulled By: zhangguanheng66
fbshipit-source-id: 288b995787ad55fba374184b3d15b5c6fe9abb5c
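To illustrate point 1 above, a usage sketch with the new kdim/vdim options (illustrative shapes; tensors are (seq, batch, feature)):

```python
import torch
import torch.nn as nn

embed_dim, num_heads = 64, 8
mha = nn.MultiheadAttention(embed_dim, num_heads, kdim=32, vdim=48)

query = torch.randn(10, 4, embed_dim)   # target sequence
key = torch.randn(20, 4, 32)            # key features differ from query's
value = torch.randn(20, 4, 48)          # value features differ again

attn_output, attn_weights = mha(query, key, value)
print(attn_output.shape)  # torch.Size([10, 4, 64])
```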