🚀 The feature, motivation and pitch
Currently, providing an attention mask argument to jagged_scaled_dot_product_attention is not supported (see https://github.com/pytorch/pytorch/blob/main/torch/nested/_internal/sdpa.py#L67).
I don't know how technically feasible this is, as I'm not familiar with the structure of the various SDPA backends, but theoretically I'm hoping we could allow passing a nested tensor attn_mask (strided layout only, not jagged, because by definition the mask nested tensor would have two ragged dimensions). Ideally this would be supported at training time, not just inference (i.e. with autograd support).
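To make the request concrete, here is a hypothetical sketch of what the desired call could look like. Nothing below works today (the jagged SDPA path rejects any non-None attn_mask), and the sequence lengths, shapes, and causal mask construction are purely illustrative assumptions:

```python
import torch
import torch.nn.functional as F

n_heads, head_dim = 4, 16
seq_lens = [3, 5]  # ragged sequence lengths per batch element

def make_njt():
    # (B, ragged_seq, n_heads, head_dim) jagged nested tensor, transposed to
    # (B, n_heads, ragged_seq, head_dim) as SDPA expects.
    nt = torch.nested.nested_tensor(
        [torch.randn(L, n_heads, head_dim) for L in seq_lens],
        layout=torch.jagged,
        requires_grad=True,
    )
    return nt.transpose(1, 2)

q, k, v = make_njt(), make_njt(), make_njt()

# Proposed: a strided-layout nested tensor mask with one (L_i, L_i) mask per
# batch element (two ragged dims per component, hence strided rather than jagged).
attn_mask = torch.nested.nested_tensor(
    [torch.tril(torch.ones(L, L)).bool() for L in seq_lens]
)

# Hypothetical: this currently raises because attn_mask is not None; the request
# is for it to be supported, ideally with autograd so it works during training.
out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
```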
Alternatives
For use cases where attention masks are required but inputs are jagged, I believe the only remaining option is to convert everything to a dense padded tensor, which is not very efficient.
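For reference, a minimal sketch of that padded-tensor workaround, assuming a simple per-sequence key-padding mask; names and shapes are illustrative:

```python
import torch
import torch.nn.functional as F

n_heads, head_dim = 4, 16
seq_lens = [3, 5]
max_len = max(seq_lens)
batch = len(seq_lens)

# Variable-length sequences, densified into one zero-padded tensor.
seqs = [torch.randn(L, n_heads, head_dim) for L in seq_lens]
padded = torch.zeros(batch, max_len, n_heads, head_dim)
for i, s in enumerate(seqs):
    padded[i, : s.shape[0]] = s
q = k = v = padded.transpose(1, 2)  # (B, n_heads, max_len, head_dim)

# Boolean mask: True where the key position is a real (non-padding) token,
# broadcast to (B, 1, 1, max_len) for SDPA.
key_padding = torch.arange(max_len)[None, :] < torch.tensor(seq_lens)[:, None]
attn_mask = key_padding[:, None, None, :]

out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
# Compute and memory now scale with max_len rather than the true lengths,
# which is the inefficiency noted above.
```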
Additional context
I'm new to the NJT APIs, so if there is a better way to accomplish this that already exists, please let me know!
cc @cpuhrsch @jbschlosser @bhosmer @drisspg @soulitzer @davidberard98 @YuqingJ @erichan1 @mikaylagawarecki @crcrpar @mcarilli @janeyx99