@@ -568,8 +568,8 @@ def _sdpa_nested_preprocessing(query, key, value):
 
     output_nt_info = {
         "offsets": q_t.offsets(),
-        "_max_seqlen": q_t._get_max_seqlen(),
-        "_min_seqlen": q_t._get_min_seqlen(),
+        "max_seqlen": q_t._get_max_seqlen(),
+        "min_seqlen": q_t._get_min_seqlen(),
     }
 
     return (
@@ -710,7 +710,12 @@ def jagged_scaled_dot_product_attention(
             is_causal=is_causal,
             scale=scale,
         )
-        return nested_view_from_values_offsets(output, query.offsets())
+        return nested_view_from_values_offsets(
+            output,
+            query.offsets(),
+            min_seqlen=query._maybe_min_seqlen,  # type: ignore[attr-defined]
+            max_seqlen=query._maybe_max_seqlen,  # type: ignore[attr-defined]
+        )
 
     compute_logsumexp = query.requires_grad or key.requires_grad or value.requires_grad
@@ -766,9 +771,7 @@ def jagged_scaled_dot_product_attention(
         # Reshape output to convert nnz to batch_size and seq_len
         attention = nested_view_from_values_offsets(
             attention,  # output from flash_attn is [total_q, num_heads, head_size_og]
-            output_nt_info["offsets"],
-            min_seqlen=output_nt_info["_min_seqlen"],
-            max_seqlen=output_nt_info["_max_seqlen"],
+            **output_nt_info,
         ).transpose(1, 2)
         return _post_process_flash_output(attention, og_size)
     elif backend_choice == SDPBackend.EFFICIENT_ATTENTION:
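
Renaming the `output_nt_info` keys in `_sdpa_nested_preprocessing` (dropping the leading underscores) makes them line up with the keyword parameters of `nested_view_from_values_offsets`, so the flash and efficient-attention branches can splat the whole dict instead of unpacking it field by field. A minimal sketch of the mechanics, with `view_from_values_offsets` as a hypothetical stand-in sharing the same keyword names:

def view_from_values_offsets(values, offsets, min_seqlen=None, max_seqlen=None):
    # Hypothetical stand-in mirroring nested_view_from_values_offsets' keywords.
    return (values, offsets, min_seqlen, max_seqlen)

output_nt_info = {"offsets": [0, 3, 5], "max_seqlen": 3, "min_seqlen": 2}

# Before the rename the keys were "_min_seqlen"/"_max_seqlen", forcing each
# call site to unpack entries by hand; now the dict splats directly:
result = view_from_values_offsets([10, 20, 30, 40, 50], **output_nt_info)
assert result == ([10, 20, 30, 40, 50], [0, 3, 5], 2, 3)
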
@@ -807,25 +810,18 @@ def jagged_scaled_dot_product_attention(
         # Reshape output to convert nnz to batch_size and seq_len
         return nested_view_from_values_offsets(
             attention.squeeze(0),
-            output_nt_info["offsets"],
-            min_seqlen=output_nt_info["_min_seqlen"],
-            max_seqlen=output_nt_info["_max_seqlen"],
+            **output_nt_info,
         ).transpose(1, 2)
     elif backend_choice == SDPBackend.MATH:
         # save the offsets and shape of the inputs, so we can reshape the final output
         # query @ key = attn: [B, D1, j0, D'] @ [B, D1, D', j1] = [B, D1, j0, j1]
         # attn @ value = out: [B, D1, j0, j1] @ [B, D1, j1, D2] = [B, D1, j0, D2]
         offsets = query.offsets()
+        min_seqlen = query._maybe_min_seqlen
+        max_seqlen = query._maybe_max_seqlen
         d1 = query._size[1]
         d2 = value._size[-1]
 
-        min_seqlen_tensor = query._metadata_cache.get(
-            "min_seqlen", None
-        )  # type: ignore[attr-defined]
-        max_seqlen_tensor = query._metadata_cache.get(
-            "max_seqlen", None
-        )  # type: ignore[attr-defined]
-
         # convert jagged layout Nested Tensor to strided layout Nested Tensor
         # which supports the math implementation of SDPA
         def get_strided_layout_nested_tensor(jagged_layout_nt):
@@ -844,24 +840,14 @@ def get_strided_layout_nested_tensor(jagged_layout_nt):
             query, key, value, attn_mask, dropout_p, is_causal, scale=scale
         )[0]
 
-        from torch.nested._internal.nested_tensor import _load_val_from_tensor
-
         # convert strided layout Nested Tensor back to jagged layout Nested Tensor
         attn_out = attn_out.transpose(1, 2).contiguous().values()
         attn_out = attn_out.view(-1, d1, d2)
         attn_out = nested_view_from_values_offsets(
             attn_out,
             offsets,
-            min_seqlen=(
-                None
-                if min_seqlen_tensor is None
-                else _load_val_from_tensor(min_seqlen_tensor)
-            ),
-            max_seqlen=(
-                None
-                if max_seqlen_tensor is None
-                else _load_val_from_tensor(max_seqlen_tensor)
-            ),
+            min_seqlen=min_seqlen,
+            max_seqlen=max_seqlen,
         ).transpose(1, 2)
 
         return attn_out
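
For context, here is a hedged end-to-end usage sketch of jagged-layout SDPA through the public API; it assumes a PyTorch build where `torch.jagged` NestedTensors are accepted by `scaled_dot_product_attention`, which is the code path this patch touches:

import torch
import torch.nn.functional as F

# Two sequences with different lengths; per-sequence shape is
# (seq_len, num_heads, head_dim), with the ragged dim at batch index 1.
q = torch.nested.nested_tensor(
    [torch.randn(3, 2, 8), torch.randn(5, 2, 8)], layout=torch.jagged
)
# Share q's offsets so q/k/v have the same ragged structure.
k = torch.nested.nested_tensor_from_jagged(torch.randn_like(q.values()), q.offsets())
v = torch.nested.nested_tensor_from_jagged(torch.randn_like(q.values()), q.offsets())

# SDPA expects (batch, heads, seq, head_dim), so move heads ahead of the
# ragged sequence dim and back afterwards.
out = F.scaled_dot_product_attention(
    q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
).transpose(1, 2)
print(out.shape)  # e.g. torch.Size([2, j1, 2, 8]) with a symbolic ragged dim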