@@ -689,8 +689,11 @@ class MultiheadAttention(Module):
689689 \text{where} head_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)
690690
691691 Args:
692- embed_dim: total dimension of the model
693- num_heads: parallel attention layers, or heads
692+ embed_dim: total dimension of the model.
693+ num_heads: parallel attention heads.
694+ add_bias_kv: add bias to the key and value sequences at dim=0.
695+ add_zero_attn: add a new batch of zeros to the key and
696+ value sequences at dim=1.
694697
695698 Examples::
696699
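The "Examples::" block itself is cut off by this hunk; a minimal constructor sketch consistent with the Args documented above would look like the following (the concrete embed_dim/num_heads values are illustrative, and embed_dim must be divisible by num_heads):

    import torch.nn as nn

    # add_bias_kv / add_zero_attn follow the argument list in this docstring
    multihead_attn = nn.MultiheadAttention(embed_dim=256, num_heads=8,
                                           add_bias_kv=True, add_zero_attn=True)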
@@ -741,19 +744,37 @@ def _reset_parameters(self):
741744 @weak_script_method
742745 def forward(self, query, key, value, key_padding_mask=None, incremental_state=None,
743746             need_weights=True, static_kv=False, attn_mask=None):
744- """
745- Inputs of forward function
746- query: [target length, batch size, embed dim]
747- key: [sequence length, batch size, embed dim]
748- value: [sequence length, batch size, embed dim]
749- key_padding_mask: if True, mask padding based on batch size
750- incremental_state: if provided, previous time steps are cashed
751- need_weights: output attn_output_weights
752- static_kv: key and value are static
753-
754- Outputs of forward function
755- attn_output: [target length, batch size, embed dim]
756- attn_output_weights: [batch size, target length, sequence length]
747+ r"""
748+ Args:
749+ query, key, value: map a query and a set of key-value pairs to an output.
750+ See "Attention Is All You Need" for more details.
751+ key_padding_mask: if provided, specified padding elements in the key will
752+ be ignored by the attention.
753+ incremental_state: if provided, previous time steps are cached.
754+ need_weights: output attn_output_weights.
755+ static_kv: if True, key and value are static. The key and value in previous
756+ states will be used.
757+ attn_mask: mask that prevents attention to certain positions.
758+
759+ Shape:
760+ - Inputs:
761+
762+ - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is
763+ the embedding dimension.
764+ - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is
765+ the embedding dimension.
766+ - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is
767+ the embedding dimension.
768+ - key_padding_mask: :math:`(N, S)`, ByteTensor, where N is the batch size, S is the source sequence length.
769+ - incremental_state: a dictionary used for storing states.
770+ - attn_mask: :math:`(L, L)` where L is the target sequence length.
771+
772+ - Outputs:
773+
774+ - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size,
775+ E is the embedding dimension.
776+ - attn_output_weights: :math:`(N, L, S)` where N is the batch size,
777+ L is the target sequence length, S is the source sequence length.
757778 """
758779 qkv_same = query.data_ptr() == key.data_ptr() == value.data_ptr()
759780 kv_same = key.data_ptr() == value.data_ptr()
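A hedged end-to-end sketch of calling forward() with the shapes documented in the new docstring (L = target length, S = source length, N = batch size, E = embedding dimension). It exercises only the tensor arguments, leaves incremental_state and static_kv at their defaults, and uses a bool key_padding_mask (the docstring above says ByteTensor; the accepted dtype depends on the PyTorch version):

    import torch
    import torch.nn as nn

    L, S, N, E = 10, 12, 4, 256                  # target len, source len, batch, embed dim
    attn = nn.MultiheadAttention(E, num_heads=8)

    query = torch.randn(L, N, E)                 # (L, N, E)
    key = torch.randn(S, N, E)                   # (S, N, E)
    value = torch.randn(S, N, E)                 # (S, N, E)
    key_padding_mask = torch.zeros(N, S, dtype=torch.bool)  # (N, S); True entries are ignored

    attn_output, attn_output_weights = attn(
        query, key, value,
        key_padding_mask=key_padding_mask,
        need_weights=True)

    print(attn_output.shape)          # torch.Size([10, 4, 256])  -> (L, N, E)
    print(attn_output_weights.shape)  # torch.Size([4, 10, 12])   -> (N, L, S)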