NJT + Flex Attention #137711

@drisspg


Issues

I'm working on this support now in #136792. Here's a list of the current issues:

  1. The notebook (internal only) demonstrates the usage of a (1, 1, sum(seqlen), sum(seqlen)) block mask for NJT. Is this inefficient? I'd expect (1, 1, max_seqlen, max_seqlen) as an analogue to what is done for dense, but it's a bit tricky to implement the NJT adapter for this.
    • From offline discussion: with create_block_mask(..., _compile=True), the full (1, 1, sum(seqlen), sum(seqlen)) mask_tensor isn't materialized, so this is the recommended path.
    • Still some exploration to be done to see if this is the most efficient way to handle NJTs
  2. PR #137255 ("[FlexAttention] Adjust BlockMask if reusing the one created at larger seqlen") adds some logic that assumes a constant seqlen. This will need a workaround for NJT.
  3. The notebook example builds a seq_idx to map an index within sum(seqlen) -> the associated beginning offset. It doesn't account for the fact that Q_LEN / KV_LEN are rounded up to the nearest block size multiple, so out-of-bounds access occurs if sum(seqlen) is not a multiple of the block size.
  4. create_block_mask(..., _compile=True) throws torch._dynamo.exc.Unsupported: Unexpected type in sourceless builder torch.Tensor (investigating)
    • Fixed this by changing the NJT wrapper generator to close over seq_idx implicitly instead of explicitly.
  5. The way seq_idx is built assumes offsets[0] == 0. This may not be the case for some non-standard NJT views.
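
Items 3 and 5 can be illustrated with a minimal sketch. This is not the notebook's code or the adapter in #136792: the helper names (build_seq_starts, make_packed_mask_mod) are hypothetical, and plain Python lists stand in for tensors to keep the example dependency-free. It pads the lookup table up to a block-size multiple so rounded-up Q_LEN / KV_LEN indices stay in bounds, and it rebases on offsets[0] so a view with a non-zero first offset still works. A real mask_mod receives tensor indices and cannot branch like this, so treat it only as a statement of the intended logic.

```python
def build_seq_starts(offsets, block_size):
    """Map each packed position to the beginning offset of its sequence.

    Hypothetical helper. Positions are rebased on offsets[0], and the table is
    padded to a multiple of block_size; padded positions hold the sentinel -1.
    """
    base = offsets[0]                      # may be non-zero for some NJT views
    total = offsets[-1] - base             # sum(seqlen)
    padded_len = -(-total // block_size) * block_size  # round up
    seq_starts = [-1] * padded_len
    for begin, end in zip(offsets[:-1], offsets[1:]):
        for pos in range(begin - base, end - base):
            seq_starts[pos] = begin - base
    return seq_starts


def make_packed_mask_mod(mask_mod, offsets, block_size):
    """Wrap a per-sequence mask_mod so it operates on packed indices.

    seq_starts is captured implicitly by the closure rather than being
    passed in as an explicit argument.
    """
    seq_starts = build_seq_starts(offsets, block_size)

    def packed_mask_mod(b, h, q_idx, kv_idx):
        q_start = seq_starts[q_idx]
        kv_start = seq_starts[kv_idx]
        # Mask out padding and any query/key pair from different sequences.
        if q_start < 0 or kv_start < 0 or q_start != kv_start:
            return False
        # Delegate with sequence-local indices.
        return mask_mod(b, h, q_idx - q_start, kv_idx - kv_start)

    return packed_mask_mod


def causal(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx


# Two sequences of length 3 and 4, offsets starting at 2, block size 4.
packed_causal = make_packed_mask_mod(causal, [2, 5, 9], block_size=4)
```

Here sum(seqlen) is 7 and the table has 8 entries, so index 7 (which exists only because of rounding up to the block size) is masked out instead of reading out of bounds.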

Metadata

Labels: module: flex attention, module: higher order operators, module: nestedtensor, module: pt2-dispatcher, oncall: pt2, triaged
