
Conversation

@zhangguanheng66
Contributor

@zhangguanheng66 zhangguanheng66 commented Jun 3, 2019

The changes include:

  1. Allow key/value to have a different number of features from the query. This supports the case where key and value have different feature dimensions (see the usage sketch below).
  2. Support three separate proj_weights in addition to a single in_proj_weight. The proj_weights for key and value may have different dimensions from that of the query, so three separate proj_weights are necessary. When key and value have the same dimension as the query, a single large proj_weight is preferred for performance reasons. However, whether a single large weight or three separate weights is faster is a size-dependent decision.
  3. Give an option to use static k and v in the multihead_attn operator (see saved_k and saved_v). Those static key/value tensors can now be reused when training the model.
  4. Add more test cases to cover the new arguments.

Note: current users should not be affected by the changes.
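
As a usage sketch (not part of the PR text; the kdim/vdim constructor arguments follow the module attributes visible in the diff below, e.g. self.kdim), key and value can now carry their own feature sizes:

import torch
import torch.nn as nn

# Hypothetical usage sketch: query has 8 features, key has 16, value has 24.
embed_dim, num_heads = 8, 2
mha = nn.MultiheadAttention(embed_dim, num_heads, kdim=16, vdim=24)

L, S, N = 5, 7, 3                    # target length, source length, batch size
query = torch.rand(L, N, embed_dim)  # (L, N, E)
key = torch.rand(S, N, 16)           # (S, N, kdim)
value = torch.rand(S, N, 24)         # (S, N, vdim)

attn_output, attn_weights = mha(query, key, value)
print(attn_output.shape)             # torch.Size([5, 3, 8])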

@zhangguanheng66 zhangguanheng66 requested a review from cpuhrsch June 3, 2019 16:20
@pytorchbot pytorchbot added the oncall: jit and module: nn labels Jun 3, 2019

self.in_proj_weight = Parameter(torch.empty(3 * embed_dim, embed_dim))
self.q_proj_weight = Parameter(torch.Tensor(embed_dim, embed_dim))
self.k_proj_weight = Parameter(torch.Tensor(embed_dim, self.kdim))
Contributor

When you load older models / older models' state dicts, the model will break / compute wrong results. What are you going to do about that?

you can do what batchnorm did. see https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/batchnorm.py#L16 and look at _version = 2, and https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/batchnorm.py#L90-L103

Contributor Author

Thanks for pointing this out.

Added a _load_from_state_dict() function to detect in_proj_weight. If in_proj_weight exists in the state_dict, it is split into three separate weights. I assume users know how to use state_dict() and load_state_dict() to resolve this type of version conflict.

Tested on the word language model:

  1. Load the model saved with the old module, which has in_proj_weight.
    model = torch.load("old.pt")
  2. Generate the state_dict.
    state_dict = model.state_dict()
  3. Create a model based on the new module, which has three separate weights.
    new_model = new_module()
  4. Map the state onto the new model.
    new_model.load_state_dict(state_dict)
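
A minimal sketch of the conversion described above (hypothetical helper name; the actual in-tree hook may be structured differently): if an old checkpoint still carries the combined in_proj_weight, split it into the three separate projection weights before loading.

def split_in_proj_weight(state_dict, prefix=''):
    # Hypothetical migration helper: split the old combined (3*E, E) weight
    # into the q/k/v projection weights expected by the new module layout.
    key = prefix + 'in_proj_weight'
    if key in state_dict:
        q_w, k_w, v_w = state_dict.pop(key).chunk(3, dim=0)
        state_dict[prefix + 'q_proj_weight'] = q_w
        state_dict[prefix + 'k_proj_weight'] = k_w
        state_dict[prefix + 'v_proj_weight'] = v_w
    return state_dict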

@ezyang ezyang added facebook and removed facebook labels Jun 5, 2019
be ignored by the attention.
need_weights: output attn_output_weights.
attn_mask: mask that prevents attention to certain positions.
use_chunk_proj_weight: use in_proj_weight insteady of q_proj_weight, k_proj_weight, v_proj_weight.
Contributor

Nit: "instead"

Also, this means that passing the flags for q, k, v weights is now the default? That's breaking backwards compatibility. Is that worth it?

Collaborator

It preserves backwards compatibility (use_chunk_proj_weight=True will use in_proj_weight, like before), but the name is confusing.

Contributor Author

I have added more docs for use_chunk_proj_weight.

the embedding dimension.
- key_padding_mask: :math:`(N, S)`, ByteTensor, where N is the batch size, S is the source sequence length.
- attn_mask: :math:`(L, S)` where L is the target sequence length, S is the source sequence length.
- saved_k: :math:`(N*num_heads, S, E/num_heads)`, where S is the source sequence length,
Contributor

Why is it called "saved"?

Contributor Author

We can use static_k and static_v instead.
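
For context, a hedged usage sketch of passing such pre-projected key/value tensors through the functional op (keyword names as they appear in the final torch.nn.functional API; shapes follow the docstring above):

import torch
import torch.nn as nn
import torch.nn.functional as F

# Sketch only: reuse pre-projected (static) key/value tensors of shape
# (N*num_heads, S, E/num_heads).
embed_dim, num_heads, N, L, S = 8, 2, 3, 5, 7
mha = nn.MultiheadAttention(embed_dim, num_heads)
query = torch.rand(L, N, embed_dim)
key = value = torch.rand(S, N, embed_dim)

static_k = torch.rand(N * num_heads, S, embed_dim // num_heads)
static_v = torch.rand(N * num_heads, S, embed_dim // num_heads)
attn_output, attn_weights = F.multi_head_attention_forward(
    query, key, value, embed_dim, num_heads,
    mha.in_proj_weight, mha.in_proj_bias,
    mha.bias_k, mha.bias_v, mha.add_zero_attn,
    mha.dropout, mha.out_proj.weight, mha.out_proj.bias,
    static_k=static_k, static_v=static_v)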

@zhangguanheng66
Contributor Author

fix #21518

Contributor

@facebook-github-bot facebook-github-bot left a comment

@zhangguanheng66 has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

q_proj_weight=None, # type: Optional[Tensor]
k_proj_weight=None, # type: Optional[Tensor]
v_proj_weight=None, # type: Optional[Tensor]
static_k=None, # type: Optional[Tensor]
Contributor

I think we had talked about maybe moving this into a separate function by splitting multi_head_attention_forward out into two, one part that does the projections and the other that consumes them. Users that want to do their own projections can then call the latter.
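
To make the suggestion concrete, here is a rough sketch of such a split (hypothetical helper names, not the torch.nn API): one function performs the input projections, and a second consumes already-projected tensors, which is what static_k/static_v would feed into.

import torch
import torch.nn.functional as F

def project_qkv(query, key, value, q_proj_weight, k_proj_weight, v_proj_weight, num_heads):
    # Hypothetical first half: apply the input projections and reshape each
    # result to (batch * num_heads, seq_len, head_dim).
    tgt_len, bsz, embed_dim = query.shape
    head_dim = embed_dim // num_heads
    q = F.linear(query, q_proj_weight).view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1)
    k = F.linear(key, k_proj_weight).view(-1, bsz * num_heads, head_dim).transpose(0, 1)
    v = F.linear(value, v_proj_weight).view(-1, bsz * num_heads, head_dim).transpose(0, 1)
    return q, k, v

def attend(q, k, v):
    # Hypothetical second half: consume projected q/k/v, so callers that supply
    # their own (static) projections can skip project_qkv entirely.
    scores = torch.bmm(q, k.transpose(1, 2)) / (q.size(-1) ** 0.5)
    return torch.bmm(torch.softmax(scores, dim=-1), v)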

Contributor Author

But it may not really help with the proj_weights layout issue. We still have to maintain the two layouts at the same time, right?

Collaborator

If you (or the user) pass separate q_proj_weight, k_proj_weight, v_proj_weight, the layout is pretty flexible, and two layouts are not really necessary. But this implies that (in the worst case) 3 gemms will be called instead of 1, which may or may not hurt performance.

Contributor Author

@ngimel Yes. @cpuhrsch has this performance concern, so we keep the single proj_weight option here and only use the three-gemm option when key/value have a different embedding dimension from the query.

Collaborator

FWIW, we used 3 separate gemms in the MLPerf submission, because then the .contiguous() calls here

pytorch/torch/nn/functional.py

Lines 3276 to 3280 in 6d1f0da

q = q.contiguous().view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1)
if k is not None:
    k = k.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1)
if v is not None:
    v = v.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1)
become no-ops, which compensates for the gemms potentially being less efficient. Overall, which approach is preferable is size-dependent.
Alternatively, the pointers and strides of the projection weights can be checked, and if they point to a contiguous matrix, a single gemm can be called.
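
A minimal sketch of that pointer/stride check (hypothetical helper, not code from the PR): the three weights can share one gemm only if they are contiguous, same-shaped, and laid out back-to-back in one buffer.

import torch

def proj_weights_form_one_matrix(q_w: torch.Tensor, k_w: torch.Tensor, v_w: torch.Tensor) -> bool:
    # Hypothetical check: True if the q/k/v projection weights describe one
    # contiguous (3 * E, E) matrix that a single gemm could consume.
    if not (q_w.shape == k_w.shape == v_w.shape):
        return False
    if not (q_w.is_contiguous() and k_w.is_contiguous() and v_w.is_contiguous()):
        return False
    block_bytes = q_w.numel() * q_w.element_size()
    return (q_w.data_ptr() + block_bytes == k_w.data_ptr()
            and k_w.data_ptr() + block_bytes == v_w.data_ptr())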

V_fc = np.concatenate((V_fc, np.repeat(bias_v, V_fc.shape[0], axis=0)), axis=1)
attn_mask = np.concatenate((attn_mask, np.ones([1, 1])), axis=1)
if attn_mask is not None:
    attn_mask = np.concatenate((attn_mask, np.ones([1, 1])), axis=1)
Collaborator

@ngimel ngimel Jun 17, 2019

What are these np functions everywhere and how will they work with cuda tensors?
Edit: sorry, nevermind, did not notice this was specifically for a test.

Contributor Author

They are used to re-implement the multi-head attention computation in the test and verify the results of the nn.MultiheadAttention module.

Contributor

@facebook-github-bot facebook-github-bot left a comment

@cpuhrsch has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

Contributor

@facebook-github-bot facebook-github-bot left a comment

@zhangguanheng66 has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

k, v = linear(key, _w, _b).chunk(2, dim=-1)

else:
    _b = in_proj_bias
Contributor

@cpuhrsch cpuhrsch Jul 2, 2019

Since this is getting quite complicated, it could be good to add some comments.

@zhangguanheng66 zhangguanheng66 changed the title from "Update MultiheadAttention module to integrate fairseq version with torch.nn version" to "Update MultiheadAttention module support key/value with different number of features and allow static key/value" Jul 2, 2019
Contributor

@facebook-github-bot facebook-github-bot left a comment

@zhangguanheng66 has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@facebook-github-bot
Contributor

@zhangguanheng66 merged this pull request in bb0f299.

xzhu1900 pushed a commit to xzhu1900/pytorch that referenced this pull request Jul 5, 2019
…ber of features and allow static key/value (pytorch#21288)

Summary:
The changes include:

1. Allow key/value to have a different number of features from the query. This supports the case where key and value have different feature dimensions.
2. Support three separate proj_weights in addition to a single in_proj_weight. The proj_weights for key and value may have different dimensions from that of the query, so three separate proj_weights are necessary. When key and value have the same dimension as the query, a single large proj_weight is preferred for performance reasons. However, whether a single large weight or three separate weights is faster is a size-dependent decision.
3. Give an option to use static k and v in the multihead_attn operator (see saved_k and saved_v). Those static key/value tensors can now be reused when training the model.
4. Add more test cases to cover the new arguments.

Note: current users should not be affected by the changes.
Pull Request resolved: pytorch#21288

Differential Revision: D15738808

Pulled By: zhangguanheng66

fbshipit-source-id: 288b995787ad55fba374184b3d15b5c6fe9abb5c
@zhangguanheng66 zhangguanheng66 deleted the k_v_diff_dim branch July 12, 2019 17:06