🚀 The feature, motivation and pitch
I have a use case where I would like to perform (multi-head) cross-attention between NJT queries and keys using the flex attention API, where the queries have a different sequence length structure than the keys/values. Currently, create_nested_block_mask and flex_attention look only at the query NJT (https://github.com/pytorch/pytorch/blob/main/torch/nn/attention/flex_attention.py#L1195) to determine the sequence length structure for the query, key, and value. I propose modifying/expanding the signature of _nested_mod_func_adapter to accept a separate q_nt and kv_nt, and adjusting the logic there to accommodate queries whose sequence lengths differ from those of the keys/values.
Alternatives
No response
Additional context
Potentially related issues:
cc @cpuhrsch @jbschlosser @bhosmer @drisspg @soulitzer @davidberard98 @YuqingJ @ezyang @chauhang @penguinwu @zou3519 @bdhirsh @yf225 @Chillee @yanboliang @BoyuanFeng @ydwu4