
Add key_padding_mask argument to Transformer module #22374

@sebamenabar

Description


🚀 Feature

Add key_padding_mask as an argument to the Transformer/TransformerEncoder/TransformerDecoder forward methods.

Motivation

The current implementation of the Transformer only exposes the attn_mask parameter of the underlying MultiheadAttention module. As far as I can tell, attn_mask applies to the batch as a whole and cannot vary per sample, so it cannot be used to ignore padding positions that differ from sequence to sequence. It would be useful to also expose the key_padding_mask parameter, which marks padded positions per sample, for batches of sequences padded to a common length.
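To illustrate the difference, here is a minimal sketch using nn.MultiheadAttention directly (assuming a recent PyTorch where both masks accept boolean tensors): attn_mask is a single (L, S) mask shared by every sample in the batch, while key_padding_mask has shape (N, S) and can differ per sample.

import torch
import torch.nn as nn

S, N, E = 5, 2, 16                        # sequence length, batch size, embed dim
mha = nn.MultiheadAttention(E, num_heads=4)
x = torch.randn(S, N, E)                  # sequence-first layout (S, N, E)

# One mask for the whole batch, e.g. a causal mask of shape (S, S).
causal = torch.triu(torch.ones(S, S), diagonal=1).bool()

# Per-sample padding mask of shape (N, S): True marks padded key positions.
# Sample 0 has 5 real tokens, sample 1 has only 3.
pad = torch.tensor([[False, False, False, False, False],
                    [False, False, False, True,  True ]])

out, _ = mha(x, x, x, attn_mask=causal, key_padding_mask=pad)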

Pitch

I want the Transformer not to pay attention to padding elements in a sequence.
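For reference, the (N, S) boolean mask that key_padding_mask expects is easy to build from per-sample lengths; the helper name below (lengths_to_padding_mask) is just illustrative.

import torch

def lengths_to_padding_mask(lengths, max_len=None):
    # True marks padded positions that attention should ignore.
    max_len = max_len if max_len is not None else int(lengths.max())
    positions = torch.arange(max_len, device=lengths.device)   # (max_len,)
    return positions.unsqueeze(0) >= lengths.unsqueeze(1)      # (N, max_len) bool

lengths = torch.tensor([5, 3, 4])
mask = lengths_to_padding_mask(lengths)   # shape (3, 5)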

Alternatives

Modify TransformerEncoderLayer

class TransformerEncoderLayer(nn.Module):
    ...
    def forward(self, src, src_attn_mask=None, src_padding_mask=None):
        # Forward the per-sample padding mask to self-attention
        # alongside the shared attention mask.
        src2 = self.self_attn(src, src, src,
                              attn_mask=src_attn_mask,
                              key_padding_mask=src_padding_mask)[0]
    ...

Repeat for all relevant modules (Transformer*)
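For example, TransformerEncoder would only need to thread the new argument down to its layers. The sketch below uses the parameter names from the snippet above; it is not the existing nn.TransformerEncoder API.

import copy
import torch.nn as nn

class TransformerEncoder(nn.Module):
    def __init__(self, encoder_layer, num_layers, norm=None):
        super().__init__()
        # Stack of independent copies of the given encoder layer.
        self.layers = nn.ModuleList(
            [copy.deepcopy(encoder_layer) for _ in range(num_layers)])
        self.norm = norm

    def forward(self, src, src_attn_mask=None, src_padding_mask=None):
        output = src
        for layer in self.layers:
            # Pass both the shared attention mask and the per-sample
            # padding mask to every layer.
            output = layer(output,
                           src_attn_mask=src_attn_mask,
                           src_padding_mask=src_padding_mask)
        if self.norm is not None:
            output = self.norm(output)
        return output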

I know this can be achieved with custom Transformer/Encoder/Decoder modules, but this seems like a common enough need that it could easily be supported out of the box.
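For completeness, this is roughly what the custom-module workaround looks like today: a self-contained encoder layer (the name PaddingAwareEncoderLayer is made up for this sketch) that simply forwards key_padding_mask to MultiheadAttention.

import torch
import torch.nn as nn
import torch.nn.functional as F

class PaddingAwareEncoderLayer(nn.Module):
    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.linear2 = nn.Linear(dim_feedforward, d_model)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, src, src_attn_mask=None, src_padding_mask=None):
        # Self-attention with both masks, then the usual residual,
        # layer norm, and position-wise feed-forward sublayers.
        src2 = self.self_attn(src, src, src,
                              attn_mask=src_attn_mask,
                              key_padding_mask=src_padding_mask)[0]
        src = self.norm1(src + self.dropout(src2))
        src2 = self.linear2(self.dropout(F.relu(self.linear1(src))))
        src = self.norm2(src + self.dropout(src2))
        return src

Such a layer drops into the TransformerEncoder-style stack sketched above, but supporting key_padding_mask natively would avoid this boilerplate.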

Labels

    feature (A request for a proper, new feature)
    module: nn (Related to torch.nn)
    triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)
