
NJT + Flex *cross* attention #140598

Description

@schmidt-jake

🚀 The feature, motivation and pitch

I have a use case where I would like to perform (multi-head) cross-attention between NJT queries and keys/values using the flex attention API, where the queries have a different sequence-length structure than the keys/values. Currently, create_nested_block_mask and flex_attention inspect only the query NJT (https://github.com/pytorch/pytorch/blob/main/torch/nn/attention/flex_attention.py#L1195) to determine the sequence-length structure for the query, key, and value. I propose expanding the signature of _nested_mod_func_adapter to accept separate q_nt and kv_nt arguments, and modifying the logic there to accommodate queries whose sequence lengths differ from those of the keys/values. A sketch of the intended usage follows.
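For illustration, here is a minimal sketch of what this proposal would enable. The `kv_nt` keyword argument to create_nested_block_mask is the proposed addition and is hypothetical here; everything else follows the existing NJT + flex attention pattern, with placeholder shapes and a trivial mask_mod:

```python
import torch
from torch.nn.attention.flex_attention import create_nested_block_mask, flex_attention

B, H, D = 2, 4, 16

def build_njt(seq_lens):
    # Build an NJT of shape (B, S*, H, D), jagged in dim 1, then move heads
    # to dim 1 to get the (B, H, S*, D) layout flex_attention expects.
    components = [torch.randn(s, H, D) for s in seq_lens]
    return torch.nested.nested_tensor(components, layout=torch.jagged).transpose(1, 2)

query = build_njt([3, 7])  # per-example query lengths
key = build_njt([5, 2])    # per-example key/value lengths: a *different* jagged structure
value = build_njt([5, 2])

def allow_all(b, h, q_idx, kv_idx):
    # Placeholder mask_mod that keeps every (q, kv) pair.
    return q_idx >= 0

# Today, the key/value sequence structure is derived from the single query NJT,
# so the mismatched structures above cannot be expressed:
#   block_mask = create_nested_block_mask(allow_all, 1, 1, query)

# Proposed (hypothetical) extension: pass both jagged structures explicitly.
block_mask = create_nested_block_mask(allow_all, 1, 1, q_nt=query, kv_nt=key)
out = flex_attention(query, key, value, block_mask=block_mask)
```

With both structures available, _nested_mod_func_adapter could translate q_idx against the query offsets and kv_idx against the key/value offsets independently, rather than assuming a shared structure.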

Alternatives

No response

Additional context

Potentially related issues:

cc @cpuhrsch @jbschlosser @bhosmer @drisspg @soulitzer @davidberard98 @YuqingJ @ezyang @chauhang @penguinwu @zou3519 @bdhirsh @yf225 @Chillee @yanboliang @BoyuanFeng @ydwu4

Metadata

Labels

- feature (A request for a proper, new feature.)
- module: flex attention
- module: nestedtensor (NestedTensor tag; see issue #25032)
- module: pt2-dispatcher (PT2 dispatcher-related issues, e.g. aotdispatch, functionalization, faketensor, custom-op)
- oncall: pt2
- triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)
