Support attn_mask in jagged_scaled_dot_product_attention #138993


Description


🚀 The feature, motivation and pitch

Currently, providing an attention mask argument to jagged_scaled_dot_product_attention is not supported (see https://github.com/pytorch/pytorch/blob/main/torch/nested/_internal/sdpa.py#L67).

I don't know how technically feasible this is, as I'm not familiar with the structure of the various SDPA backends, but theoretically I'm hoping we could allow passing a nested tensor attn_mask (strided layout only, not jagged, because by definition the per-sequence masks would have two ragged dimensions). Ideally this would be supported at training time, not just inference (i.e. with autograd support).
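To make the pitch concrete, here is a rough sketch (shapes, sizes, and variable names are made up) of the kind of call I'm hoping could be supported; today passing attn_mask with jagged inputs raises:

```python
import torch
import torch.nn.functional as F

# Jagged (layout=torch.jagged) q/k/v: one (seq_len_i, n_heads, head_dim)
# tensor per sequence, transposed so the ragged seq dim sits where SDPA
# expects it: (batch, n_heads, seq_len_i, head_dim).
seq_lens, n_heads, head_dim = (3, 5, 2), 8, 16
components = [torch.randn(n, n_heads, head_dim) for n in seq_lens]
q = torch.nested.nested_tensor(components, layout=torch.jagged).transpose(1, 2)
k, v = q.clone(), q.clone()

# Desired: a per-sequence boolean mask packed as a *strided*-layout nested
# tensor, since each (n_heads, seq_len_i, seq_len_i) mask has two ragged dims.
masks = [torch.ones(n_heads, n, n, dtype=torch.bool) for n in seq_lens]
attn_mask = torch.nested.nested_tensor(masks)  # strided layout

# Not supported today -- this is the call the feature request is about.
out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
```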

Alternatives

For use cases where attention masks are required but inputs are jagged, I believe this only leaves the option of converting everything to a dense padded tensor, which is not very efficient.
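For reference, this is roughly what that padded fallback looks like today (again with made-up shapes); it works, but it materializes padding and spends attention compute on it:

```python
import torch
import torch.nn.functional as F

seq_lens, n_heads, head_dim = [3, 5, 2], 8, 16
components = [torch.randn(n, n_heads, head_dim) for n in seq_lens]
nt = torch.nested.nested_tensor(components, layout=torch.jagged)

# Pad to a dense (batch, max_len, n_heads, head_dim) tensor, then move heads
# ahead of the (now padded) sequence dim for SDPA.
max_len = max(seq_lens)
dense = nt.to_padded_tensor(0.0).transpose(1, 2)  # (batch, n_heads, max_len, head_dim)

# Boolean key-padding mask, broadcastable to (batch, n_heads, max_len, max_len).
valid = torch.arange(max_len)[None, :] < torch.tensor(seq_lens)[:, None]
attn_mask = valid[:, None, None, :]

out = F.scaled_dot_product_attention(dense, dense, dense, attn_mask=attn_mask)
```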

Additional context

I'm new to the NJT APIs, so if there is already a better way to accomplish this, please let me know!

cc @cpuhrsch @jbschlosser @bhosmer @drisspg @soulitzer @davidberard98 @YuqingJ @erichan1 @mikaylagawarecki @crcrpar @mcarilli @janeyx99


Labels

    feature — A request for a proper, new feature.
    module: nestedtensor — NestedTensor tag, see issue #25032.
    module: sdpa — All things related to torch.nn.functional.scaled_dot_product_attention.
    triaged — This issue has been looked at by a team member, and triaged and prioritized into an appropriate module.
