Closed
Labels
module: flex attention, module: higher order operators (torch.cond and similar), module: nestedtensor (NestedTensor tag, see issue #25032), module: pt2-dispatcher (PT2 dispatcher-related issues, e.g. aotdispatch, functionalization, faketensor, custom-op), oncall: pt2, triaged (this issue has been looked at by a team member and triaged and prioritized into an appropriate module)
Description
Issues
I'm working on this support now in #136792. Here's a list of the current issues:
- The notebook (internal only) demonstrates the usage of a `(1, 1, sum(seqlen), sum(seqlen))` block mask for NJT. Is this inefficient? I'd expect `(1, 1, max_seqlen, max_seqlen)` as an analogue to what is done for dense, but it's a bit tricky to implement the NJT adapter for this.
  - From offline discussion: if `_compile=True` is used for `create_block_mask()`, the full `(1, 1, sum(seqlen), sum(seqlen))` mask tensor isn't materialized; this is good and recommended.
  - Still some exploration to be done to see if this is the most efficient way to handle NJTs.
- "[FlexAttention] Adjust BlockMask if reusing the one created at larger seqlen" (#137255) adds some logic that assumes a constant seqlen. This will have to be hacked around some for NJT.
- The notebook example builds a `seq_idx` to map an index within `sum(seqlen)` -> the associated beginning offset. It doesn't account for the fact that `Q_LEN` / `KV_LEN` are rounded up to the nearest block size multiple, so out-of-bounds access occurs if `sum(seqlen)` is not a multiple of the block size.
  - Turns out this is a real bug in `create_block_mask()`; should do the mod trick to avoid this ("FlexAttention: create_block_mask() passes out-of-range indices to mask_mod", #137801).
- `create_block_mask(..., _compile=True)` throws `torch._dynamo.exc.Unsupported: Unexpected type in sourceless builder torch.Tensor` (investigating).
  - Fixed this by changing the NJT wrapper generator to close over `seq_idx` implicitly instead of explicitly.
- The way `seq_idx` is built assumes `offsets[0] == 0`. This may not be the case for some non-standard NJT views.
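To make the `seq_idx` and rounding issues above concrete, here is a minimal sketch of a document mask over packed ragged sequences, including the "mod trick" to guard against the rounded-up `Q_LEN`/`KV_LEN` indices. The construction details (building `seq_idx` from per-sequence lengths, the wrap-around guard) are illustrative assumptions, not the actual notebook or PyTorch implementation:

```python
import torch

# Ragged "batch": three sequences of lengths 3, 5, 2 packed along one dim,
# described by NJT-style offsets. offsets[0] != 0 here to show the
# non-standard-view case called out above.
offsets = torch.tensor([2, 5, 10, 12])
lengths = offsets[1:] - offsets[:-1]       # tensor([3, 5, 2])
total_len = int(offsets[-1] - offsets[0])  # sum(seqlen) == 10

# seq_idx[i] = which sequence token i (within sum(seqlen)) belongs to.
# Building it from lengths rather than raw offsets keeps it correct even
# when offsets[0] != 0.
seq_idx = torch.repeat_interleave(torch.arange(len(lengths)), lengths)

def document_mask_mod(b, h, q_idx, kv_idx):
    # Q_LEN/KV_LEN are rounded up to a block-size multiple inside
    # create_block_mask(), so q_idx/kv_idx can exceed sum(seqlen).
    # The "mod trick": wrap indices before indexing seq_idx, avoiding
    # out-of-bounds access (the bug tracked in #137801).
    q = q_idx % total_len
    kv = kv_idx % total_len
    return seq_idx[q] == seq_idx[kv]

# On a recent PyTorch, this mask_mod would then be handed to
# torch.nn.attention.flex_attention.create_block_mask(
#     document_mask_mod, B=1, H=1, Q_LEN=total_len, KV_LEN=total_len).
```

Attention is allowed only within a sequence: tokens 0-2 attend to each other, 3-7 to each other, and so on, and indices past `sum(seqlen)` wrap instead of faulting.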
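On the `sourceless builder` Dynamo error: one plausible reading of "close over `seq_idx` implicitly instead of explicitly" is the difference between passing the tensor into the generator as an argument versus capturing it as a closure cell. The helper names below are hypothetical and this is only a sketch of that pattern, not the actual fix in #136792:

```python
import torch

seq_idx = torch.tensor([0, 0, 1, 1, 1])

# Explicit style: the tensor is threaded through as a parameter of the
# generator, so the traced mask_mod references a locally bound tensor.
# A pattern like this is the sort of thing that can trip Dynamo's
# "Unexpected type in sourceless builder torch.Tensor" path.
def make_mask_mod_explicit(idx_tensor=seq_idx):
    def mask_mod(b, h, q_idx, kv_idx):
        return idx_tensor[q_idx] == idx_tensor[kv_idx]
    return mask_mod

# Implicit style: the inner function captures the module-level tensor as a
# closure variable, which Dynamo handles via its closure support.
def make_mask_mod_implicit():
    def mask_mod(b, h, q_idx, kv_idx):
        return seq_idx[q_idx] == seq_idx[kv_idx]
    return mask_mod
```

Both variants compute the same mask eagerly; the difference only matters to how `torch.compile` sources the captured tensor.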