Closed
Labels
module: nn (Related to torch.nn), triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)
Description
Hi Guys,
First of all, thanks for a great code base. I know I'm not following the bug guidelines, but I thought you guys might appreciate a heads-up.
nn.MultiheadAttention allows you to provide different key/value dimensions (kdim/vdim). When this happens, the flag _qkv_same_embed_dim becomes False.
# MultiheadAttention constructor, activation.py:690
self.in_proj_weight = Parameter(torch.empty(3 * embed_dim, embed_dim))
if self._qkv_same_embed_dim is False:
    self.q_proj_weight = Parameter(torch.Tensor(embed_dim, embed_dim))
    self.k_proj_weight = Parameter(torch.Tensor(embed_dim, self.kdim))
    self.v_proj_weight = Parameter(torch.Tensor(embed_dim, self.vdim))
As you can see, self.in_proj_weight is initialised regardless.
In DDP, every parameter that requires gradients is expected to receive one when the loss is computed. Since in_proj_weight is registered but never used on the separate-projection path, DDP throws: RuntimeError: Expected to have finished reduction in the prior iteration before starting a new one...
To reproduce the bug, construct a MultiheadAttention with different query/key/value dimensions, wrap it in DDP, and run a couple of training iterations, as in the sketch below.
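A minimal sketch of the failure on an affected build; the single-process gloo group, shapes, and hyperparameters are illustrative only, not part of the report:

# Repro sketch: single-process gloo group, illustrative values only.
import os
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP

os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)

# kdim/vdim differ from embed_dim, so _qkv_same_embed_dim is False,
# yet in_proj_weight is still registered as a parameter.
attn = nn.MultiheadAttention(embed_dim=8, num_heads=2, kdim=16, vdim=16)
model = DDP(attn)

q = torch.randn(4, 1, 8)   # (seq_len, batch, embed_dim)
k = torch.randn(4, 1, 16)  # (seq_len, batch, kdim)
v = torch.randn(4, 1, 16)  # (seq_len, batch, vdim)

for _ in range(2):
    out, _ = model(q, k, v)
    out.sum().backward()
# The second iteration raises:
# RuntimeError: Expected to have finished reduction in the prior iteration ...

Passing find_unused_parameters=True to DDP works around the error, but the root cause is the never-used in_proj_weight parameter.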
To fix the bug:
# MultiheadAttention constructor, activation.py:690
if self._qkv_same_embed_dim is False:
    self.q_proj_weight = Parameter(torch.Tensor(embed_dim, embed_dim))
    self.k_proj_weight = Parameter(torch.Tensor(embed_dim, self.kdim))
    self.v_proj_weight = Parameter(torch.Tensor(embed_dim, self.vdim))
    self.register_parameter('in_proj_weight', None)
else:
    self.in_proj_weight = Parameter(torch.empty(3 * embed_dim, embed_dim))
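With the two registrations made mutually exclusive, in_proj_weight disappears from the module's parameter list, so DDP's reducer no longer waits for its gradient. A quick sanity check against a build with the fix applied (a sketch, not part of the proposed patch):

# With the fix, register_parameter('in_proj_weight', None) means the
# attribute exists but named_parameters() no longer yields it.
attn = nn.MultiheadAttention(embed_dim=8, num_heads=2, kdim=16, vdim=16)
params = dict(attn.named_parameters())
assert "in_proj_weight" not in params
assert "q_proj_weight" in params and "k_proj_weight" in params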