
MultiheadAttention and DDP incompatibility #26698

@roderickObrist


Hi Guys,

First of all, thanks for a great code base. I know I'm not following the bug report guidelines, but I thought you guys might appreciate a heads-up.

nn.MultiheadAttention lets you give the keys and values dimensions (kdim, vdim) that differ from the query's embed_dim. When this happens, the flag _qkv_same_embed_dim becomes False.

    # MultiheadAttention constructor, activation.py:690
    self.in_proj_weight = Parameter(torch.empty(3 * embed_dim, embed_dim))

    if self._qkv_same_embed_dim is False:
        self.q_proj_weight = Parameter(torch.Tensor(embed_dim, embed_dim))
        self.k_proj_weight = Parameter(torch.Tensor(embed_dim, self.kdim))
        self.v_proj_weight = Parameter(torch.Tensor(embed_dim, self.vdim))

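As you can see, self.in_proj_weight is initialised regardless. When _qkv_same_embed_dim is False, the forward pass uses q_proj_weight, k_proj_weight and v_proj_weight instead, so in_proj_weight never receives a gradient.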
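A quick single-process check of this, assuming PyTorch is installed (the sizes below are arbitrary): run one forward/backward pass through an nn.MultiheadAttention whose kdim and vdim differ from embed_dim, then list the trainable parameters whose .grad is still None. On versions whose constructor still creates in_proj_weight unconditionally, the list should contain in_proj_weight; on versions where it is not created in this case, the list should be empty.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
embed_dim, kdim, vdim = 8, 4, 6
mha = nn.MultiheadAttention(embed_dim, num_heads=2, kdim=kdim, vdim=vdim)

# Default layout is (sequence_length, batch_size, features).
query = torch.randn(5, 3, embed_dim)
key = torch.randn(7, 3, kdim)
value = torch.randn(7, 3, vdim)

attn_output, _ = mha(query, key, value)
attn_output.sum().backward()

# Trainable parameters that took no part in computing the loss.
unused = [name for name, p in mha.named_parameters()
          if p.requires_grad and p.grad is None]
print("_qkv_same_embed_dim:", mha._qkv_same_embed_dim)  # False
print("parameters without a gradient:", unused)
```

Note that _qkv_same_embed_dim is a private attribute; it is read here only to confirm which code path was taken.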

In DDP, every parameter that requires a gradient is expected to receive one, i.e. to be used when calculating the loss. Because in_proj_weight never does, the next iteration throws RuntimeError: Expected to have finished reduction in the prior iteration before starting a new one.....
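DDP hooks into each parameter's gradient computation and only finishes the gradient all-reduce once every parameter has reported ready. The following is a deliberately simplified, pure-Python toy model of that bookkeeping (ToyReducer and its methods are hypothetical names, not DDP's internal Reducer API). It shows why a parameter that never gets a gradient leaves the reduction unfinished, so the next iteration fails. It also shows why DDP's real find_unused_parameters=True option avoids the error: parameters missing from the forward graph are marked ready up front.

```python
class ToyReducer:
    """Greatly simplified model of DDP's gradient bookkeeping (NOT the real Reducer)."""

    def __init__(self, param_names, find_unused_parameters=False):
        self.param_names = set(param_names)
        self.find_unused_parameters = find_unused_parameters
        self.ready = set()
        self.pending = False  # True while an iteration's reduction is unfinished

    def prepare_for_backward(self, used_params):
        # Called at the end of every forward pass.
        if self.pending:
            raise RuntimeError(
                "Expected to have finished reduction in the prior iteration "
                "before starting a new one."
            )
        self.pending = True
        self.ready = set()
        if self.find_unused_parameters:
            # Parameters absent from the forward graph are marked ready up front.
            self.ready |= self.param_names - set(used_params)
        self._maybe_finish()

    def mark_ready(self, name):
        # Stands in for the hook that fires when a parameter's gradient is computed.
        self.ready.add(name)
        self._maybe_finish()

    def _maybe_finish(self):
        if self.ready == self.param_names:
            self.pending = False  # the gradient all-reduce would happen here


PARAMS = ["in_proj_weight", "in_proj_bias", "q_proj_weight", "k_proj_weight",
          "v_proj_weight", "out_proj.weight", "out_proj.bias"]
# The kdim/vdim case described above: in_proj_weight is never used.
USED = [p for p in PARAMS if p != "in_proj_weight"]


def simulate(find_unused_parameters, steps=2):
    reducer = ToyReducer(PARAMS, find_unused_parameters)
    try:
        for _ in range(steps):
            reducer.prepare_for_backward(USED)  # forward pass
            for name in USED:                   # backward: hooks fire only for used params
                reducer.mark_ready(name)
    except RuntimeError as err:
        return str(err)
    return "ok"


print(simulate(find_unused_parameters=False))
# Expected to have finished reduction in the prior iteration before starting a new one.
print(simulate(find_unused_parameters=True))  # ok
```

The real option is not free: with find_unused_parameters=True, DDP traverses the autograd graph from the outputs every iteration to find the unused parameters.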

To reproduce the bug, create an nn.MultiheadAttention whose key/value dimensions differ from embed_dim and wrap it in DDP.
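A single-process reproduction sketch, assuming a CPU PyTorch build with the gloo backend available. It starts a one-rank process group through a temporary file, wraps the attention module in DDP and runs two training steps; run is a hypothetical helper name. With find_unused_parameters=False, versions that still create the unused in_proj_weight should fail on the second step with the error quoted above, while find_unused_parameters=True should train on any version.

```python
import tempfile

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP


def run(find_unused_parameters: bool, steps: int = 2) -> str:
    """Train a DDP-wrapped MultiheadAttention for a few steps and report the outcome."""
    # One-process "cluster" on CPU; rendezvous through an empty temporary file.
    with tempfile.NamedTemporaryFile(delete=False) as f:
        init_method = "file://" + f.name
    dist.init_process_group("gloo", init_method=init_method, rank=0, world_size=1)
    try:
        torch.manual_seed(0)
        mha = nn.MultiheadAttention(8, num_heads=2, kdim=4, vdim=6)
        model = DDP(mha, find_unused_parameters=find_unused_parameters)
        for _ in range(steps):
            query = torch.randn(5, 3, 8)  # (target_len, batch, embed_dim)
            key = torch.randn(7, 3, 4)    # (source_len, batch, kdim)
            value = torch.randn(7, 3, 6)  # (source_len, batch, vdim)
            attn_output, _ = model(query, key, value)
            attn_output.sum().backward()
        return "ok"
    except RuntimeError as err:
        # Keep only the first line of DDP's long error message.
        return "RuntimeError: " + str(err).split("\n", 1)[0]
    finally:
        dist.destroy_process_group()


print(run(find_unused_parameters=False))  # version-dependent
print(run(find_unused_parameters=True))   # expected: ok
```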

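To fix the bug, only create in_proj_weight when the query, key and value dimensions match, and register it as None otherwise: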

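    # MultiheadAttention constructor, activation.py:690
    if self._qkv_same_embed_dim is False:
        # Separate projections are used; registering in_proj_weight as None keeps it
        # out of parameters(), so DDP no longer waits for a gradient that never comes.
        self.q_proj_weight = Parameter(torch.Tensor(embed_dim, embed_dim))
        self.k_proj_weight = Parameter(torch.Tensor(embed_dim, self.kdim))
        self.v_proj_weight = Parameter(torch.Tensor(embed_dim, self.vdim))
        self.register_parameter('in_proj_weight', None)
    else:
        # Packed projection; register the separate weights as None so the attributes still exist.
        self.in_proj_weight = Parameter(torch.empty(3 * embed_dim, embed_dim))
        self.register_parameter('q_proj_weight', None)
        self.register_parameter('k_proj_weight', None)
        self.register_parameter('v_proj_weight', None)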
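To sanity-check the pattern outside PyTorch's own source, here is a hypothetical stand-alone module (QKVProjection is an illustrative name, not a torch.nn class). It applies the same constructor logic but computes only the three input projections, not full attention. After one backward pass every registered parameter has a gradient, which is the property DDP relies on.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Parameter


class QKVProjection(nn.Module):
    """Hypothetical module: only the q/k/v input projections of multi-head attention."""

    def __init__(self, embed_dim, kdim=None, vdim=None):
        super().__init__()
        self.kdim = kdim if kdim is not None else embed_dim
        self.vdim = vdim if vdim is not None else embed_dim
        self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim

        if self._qkv_same_embed_dim is False:
            self.q_proj_weight = Parameter(torch.empty(embed_dim, embed_dim))
            self.k_proj_weight = Parameter(torch.empty(embed_dim, self.kdim))
            self.v_proj_weight = Parameter(torch.empty(embed_dim, self.vdim))
            self.register_parameter("in_proj_weight", None)  # not a parameter at all
        else:
            self.in_proj_weight = Parameter(torch.empty(3 * embed_dim, embed_dim))
            self.register_parameter("q_proj_weight", None)
            self.register_parameter("k_proj_weight", None)
            self.register_parameter("v_proj_weight", None)

        # parameters() skips entries registered as None, so only real weights are initialised.
        for weight in self.parameters():
            nn.init.xavier_uniform_(weight)

    def forward(self, query, key, value):
        if self._qkv_same_embed_dim:
            w_q, w_k, w_v = self.in_proj_weight.chunk(3)  # packed weight, split per input
        else:
            w_q, w_k, w_v = self.q_proj_weight, self.k_proj_weight, self.v_proj_weight
        return F.linear(query, w_q), F.linear(key, w_k), F.linear(value, w_v)


proj = QKVProjection(8, kdim=4, vdim=6)
q, k, v = proj(torch.randn(5, 8), torch.randn(7, 4), torch.randn(7, 6))
(q.sum() + k.sum() + v.sum()).backward()

names = [name for name, _ in proj.named_parameters()]
missing_grad = [name for name, p in proj.named_parameters() if p.grad is None]
print(names)         # ['q_proj_weight', 'k_proj_weight', 'v_proj_weight']
print(missing_grad)  # []
```

With equal dimensions (QKVProjection(8)), named_parameters() yields only in_proj_weight, so in both configurations the set of registered parameters matches the set actually used in forward.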

Labels: module: nn, triaged
