@zhangguanheng66
Contributor

Fix issue #26698.

When the query, key, and value dimensions differ, nn.MultiheadAttention is incompatible with DDP because the in_proj_weight attribute is created but never used in that case. This PR fixes that and adds a distributed unit test.
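A minimal sketch of the failure pattern, not the actual nn.MultiheadAttention code: by default DDP expects every registered parameter to receive a gradient during backward, so a parameter that is registered but never read in forward() causes trouble. `ToyProjection`, its `always_create_packed` flag, and `params_without_grad` are hypothetical names introduced only for this illustration.

```python
import torch
import torch.nn as nn


class ToyProjection(nn.Module):
    """Hypothetical stand-in for the projection weights of an attention layer."""

    def __init__(self, embed_dim, kdim, always_create_packed):
        super().__init__()
        self.same_dim = kdim == embed_dim
        if self.same_dim or always_create_packed:
            # Packed weight; forward() only reads it when the dimensions match.
            self.in_proj_weight = nn.Parameter(torch.randn(2 * embed_dim, embed_dim))
        else:
            # Keep the attribute name, but register no tensor under it.
            self.register_parameter("in_proj_weight", None)
        if not self.same_dim:
            self.q_proj_weight = nn.Parameter(torch.randn(embed_dim, embed_dim))
            self.k_proj_weight = nn.Parameter(torch.randn(embed_dim, kdim))

    def forward(self, query, key):
        if self.same_dim:
            w_q, w_k = self.in_proj_weight.chunk(2)
        else:
            w_q, w_k = self.q_proj_weight, self.k_proj_weight
        return query @ w_q.t() + key @ w_k.t()


def params_without_grad(module, *inputs):
    """Run one forward/backward pass and list parameters that got no gradient."""
    module(*inputs).sum().backward()
    return [name for name, p in module.named_parameters() if p.grad is None]


query, key = torch.randn(3, 8), torch.randn(3, 4)  # embed_dim=8, kdim=4

# Old behaviour (assumed): the packed weight is always registered.
buggy = ToyProjection(8, 4, always_create_packed=True)
# New behaviour (assumed): the packed weight is registered only when it is used.
fixed = ToyProjection(8, 4, always_create_packed=False)

unused_buggy = params_without_grad(buggy, query, key)
unused_fixed = params_without_grad(fixed, query, key)
print(unused_buggy)  # ['in_proj_weight']
print(unused_fixed)  # []
```

The same helper can be pointed at any module before wrapping it in DDP to spot parameters that would never receive a gradient.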

@pytorchbot added the oncall: distributed and module: nn labels on Sep 25, 2019
Contributor

@facebook-github-bot facebook-github-bot left a comment

@zhangguanheng66 has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

Contributor

@pietern pietern left a comment

Marking this as changes needed, because the test in test_c10d should be removed.

Contributor

@cpuhrsch cpuhrsch left a comment

Can this break backwards compatibility in any sense?

@zhangguanheng66
Contributor Author

zhangguanheng66 commented Oct 4, 2019

Can this break backwards compatibility in any sense?

I think it should be fine for BC. In the new version, we only create an in_proj_weight parameter when _qkv_same_embed_dim is True.

An old trained model always has an in_proj_weight parameter, but under the new code in_proj_weight is not used if self._qkv_same_embed_dim is False (because None is passed to the F.multi_head_attention_forward function).
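As a rough check of the behaviour described above, the sketch below compares the registered parameter names of two nn.MultiheadAttention instances. It assumes a PyTorch release that includes this change; the sizes (embed_dim=8, kdim=4, vdim=6) and the helper `param_names` are arbitrary choices made for illustration.

```python
import torch.nn as nn


def param_names(module):
    """Return the sorted names of all registered (non-None) parameters."""
    return sorted(name for name, _ in module.named_parameters())


# Query, key, and value share one embedding size: packed projection weight.
same = nn.MultiheadAttention(embed_dim=8, num_heads=2)
# Key and value have their own sizes: separate q/k/v projection weights.
diff = nn.MultiheadAttention(embed_dim=8, num_heads=2, kdim=4, vdim=6)

same_names = param_names(same)
diff_names = param_names(diff)
print(same._qkv_same_embed_dim, same_names)
print(diff._qkv_same_embed_dim, diff_names)
```

Under that assumption, in_proj_weight should appear only in the first list, while q_proj_weight, k_proj_weight, and v_proj_weight should appear only in the second.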

@facebook-github-bot
Contributor

@zhangguanheng66 merged this pull request in eb93200.

thiagocrepaldi pushed a commit to thiagocrepaldi/pytorch that referenced this pull request Feb 4, 2020
)

Summary:
Fix issue pytorch#26698.

With different query/keys/value dimensions, `nn.MultiheadAttention` has DDP incompatibility issue because in that case `in_proj_weight` attribute is created but not used. Fix it and add a distributed unit test.
Pull Request resolved: pytorch#26826

Differential Revision: D17583807

Pulled By: zhangguanheng66

fbshipit-source-id: c393584c331ed4f57ebaf2d4015ef04589c973f6