
Commit fc00bfd

Guanheng Zhang authored and facebook-github-bot committed
Update MultiheadAttention documentation (#20071)

Summary: Add documentation for add_bias_kv, add_zero_attn, and attn_mask.
Pull Request resolved: #20071
Differential Revision: D15213034
Pulled By: zhangguanheng66
fbshipit-source-id: c3db4b9e8527863420ba3ce6abf6098d3b0fb7a7
1 parent ecdeef3 commit fc00bfd

File tree: 1 file changed (+36, -15 lines)


torch/nn/modules/activation.py

Lines changed: 36 additions & 15 deletions
@@ -689,8 +689,11 @@ class MultiheadAttention(Module):
         \text{where} head_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)
 
     Args:
-        embed_dim: total dimension of the model
-        num_heads: parallel attention layers, or heads
+        embed_dim: total dimension of the model.
+        num_heads: parallel attention heads.
+        add_bias_kv: add bias to the key and value sequences at dim=0.
+        add_zero_attn: add a new batch of zeros to the key and
+                       value sequences at dim=1.
 
     Examples::
 
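To make the newly documented constructor arguments concrete, here is a minimal usage sketch; the sizes and the True flags below are illustrative choices, not part of the patch:

import torch
import torch.nn as nn

embed_dim, num_heads = 256, 8  # illustrative sizes; embed_dim must be divisible by num_heads

# Per the docstring above: add_bias_kv adds a bias to the key/value sequences at dim=0,
# and add_zero_attn appends a new batch of zeros to the key/value sequences at dim=1.
multihead_attn = nn.MultiheadAttention(embed_dim, num_heads,
                                       add_bias_kv=True,
                                       add_zero_attn=True)

query = torch.randn(10, 4, embed_dim)  # (target length, batch size, embed_dim)
key = torch.randn(12, 4, embed_dim)    # (source length, batch size, embed_dim)
value = torch.randn(12, 4, embed_dim)
attn_output, attn_output_weights = multihead_attn(query, key, value)
# attn_output: (10, 4, 256); the returned weights' last (source) dimension also
# counts the bias/zero positions appended by the two flags above.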
@@ -741,19 +744,37 @@ def _reset_parameters(self):
     @weak_script_method
     def forward(self, query, key, value, key_padding_mask=None, incremental_state=None,
                 need_weights=True, static_kv=False, attn_mask=None):
-        """
-        Inputs of forward function
-            query: [target length, batch size, embed dim]
-            key: [sequence length, batch size, embed dim]
-            value: [sequence length, batch size, embed dim]
-            key_padding_mask: if True, mask padding based on batch size
-            incremental_state: if provided, previous time steps are cashed
-            need_weights: output attn_output_weights
-            static_kv: key and value are static
-
-        Outputs of forward function
-            attn_output: [target length, batch size, embed dim]
-            attn_output_weights: [batch size, target length, sequence length]
+        r"""
+        Args:
+            query, key, value: map a query and a set of key-value pairs to an output.
+                See "Attention Is All You Need" for more details.
+            key_padding_mask: if provided, specified padding elements in the key will
+                be ignored by the attention.
+            incremental_state: if provided, previous time steps are cached.
+            need_weights: output attn_output_weights.
+            static_kv: if true, key and value are static. The key and value in previous
+                states will be used.
+            attn_mask: mask that prevents attention to certain positions.
+
+        Shape:
+            - Inputs:
+
+            - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is
+              the embedding dimension.
+            - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is
+              the embedding dimension.
+            - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is
+              the embedding dimension.
+            - key_padding_mask: :math:`(N, S)`, ByteTensor, where N is the batch size, S is the source sequence length.
+            - incremental_state: a dictionary used for storing states.
+            - attn_mask: :math:`(L, L)` where L is the target sequence length.
+
+            - Outputs:
+
+            - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size,
+              E is the embedding dimension.
+            - attn_output_weights: :math:`(N, L, S)` where N is the batch size,
+              L is the target sequence length, S is the source sequence length.
         """
         qkv_same = query.data_ptr() == key.data_ptr() == value.data_ptr()
         kv_same = key.data_ptr() == value.data_ptr()
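And a hedged sketch of the forward() call with the two masks described in the new Shape section, using self-attention so that L == S and the (L, L) attn_mask shape applies. The causal-mask construction and the additive float attn_mask are assumptions of this sketch, not spelled out in the patch:

import torch
import torch.nn as nn

E, H = 256, 8                  # embedding dimension and number of heads (illustrative)
L, N = 10, 4                   # target length and batch size; self-attention, so S == L

multihead_attn = nn.MultiheadAttention(E, H)
x = torch.randn(L, N, E)       # query, key, and value are all (L, N, E) here

# key_padding_mask: (N, S) -- nonzero entries mark key positions the attention ignores.
# The docstring above says ByteTensor; newer PyTorch releases expect a bool mask instead.
key_padding_mask = torch.zeros(N, L, dtype=torch.uint8)
key_padding_mask[:, -2:] = 1   # pretend the last two source positions are padding

# attn_mask: (L, L) -- here an additive float mask with -inf above the diagonal,
# i.e. a causal mask that prevents attention to future positions.
attn_mask = torch.triu(torch.full((L, L), float('-inf')), diagonal=1)

attn_output, attn_output_weights = multihead_attn(
    x, x, x, key_padding_mask=key_padding_mask, attn_mask=attn_mask)
# attn_output: (L, N, E) == (10, 4, 256)
# attn_output_weights: (N, L, S) == (4, 10, 10)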
