Commit ecc5e05

jbschlosser authored and pytorchmergebot committed
Refactor NJT min / max seqlen handling for convenience (#138130)
There's an annoying pattern emerging for pulling out the NJT min / max seqlen ints if they exist, without computing / caching them if they don't. This PR introduces private convenience accessors to simplify this handling and avoid redundant checks.

Pull Request resolved: #138130
Approved by: https://github.com/soulitzer
1 parent 66478d0 commit ecc5e05

File tree

3 files changed (+32 −47):
  torch/nested/_internal/nested_tensor.py
  torch/nested/_internal/ops.py
  torch/nested/_internal/sdpa.py
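For illustration, here is a minimal standalone sketch of the pattern being refactored. FakeNJT is a hypothetical stand-in for the internal NestedTensor class, and this _load_val_from_tensor mimics the real helper of the same name:

from typing import Optional

import torch


def _load_val_from_tensor(t: torch.Tensor) -> int:
    # Mimics the NJT helper that unwraps a cached scalar tensor into an int.
    return int(t.item())


class FakeNJT:
    # Hypothetical stand-in for the internal NestedTensor metadata handling.
    def __init__(self, max_seqlen_tensor: Optional[torch.Tensor] = None):
        self._max_seqlen_tensor = max_seqlen_tensor

    # Before this PR: every call site repeated this None-check dance.
    def max_seqlen_old_style(self) -> Optional[int]:
        max_seqlen = None
        if self._max_seqlen_tensor is not None:
            max_seqlen = _load_val_from_tensor(self._max_seqlen_tensor)
        return max_seqlen

    # After: a single property centralizes it, without computing / caching on a miss.
    @property
    def _maybe_max_seqlen(self) -> Optional[int]:
        mt = self._max_seqlen_tensor
        return None if mt is None else _load_val_from_tensor(mt)


assert FakeNJT()._maybe_max_seqlen is None
assert FakeNJT(torch.tensor(7))._maybe_max_seqlen == 7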

torch/nested/_internal/nested_tensor.py

Lines changed: 12 additions & 0 deletions

@@ -214,6 +214,18 @@ def _max_seqlen(self):
     def _min_seqlen(self):
         return self._get_min_seqlen()
 
+    # Convenience accessors that return a min / max seqlen if one is present and do NOT
+    # compute / cache them if they're not.
+    @property
+    def _maybe_max_seqlen(self) -> Optional[int]:
+        mt = self._max_seqlen_tensor
+        return None if mt is None else _load_val_from_tensor(mt)
+
+    @property
+    def _maybe_min_seqlen(self) -> Optional[int]:
+        mt = self._min_seqlen_tensor
+        return None if mt is None else _load_val_from_tensor(mt)
+
     def __repr__(self):  # type: ignore[override]
         # We should implement this in torch/_tensor_str.py instead
         grad_fn_str = (
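A hedged usage sketch of the new accessors (they are private internals, so this assumes a PyTorch build that includes this commit):

import torch

nt = torch.nested.nested_tensor(
    [torch.randn(2, 8), torch.randn(5, 8)], layout=torch.jagged
)

# _maybe_max_seqlen reads the lazy cache without populating it, so it may be None here.
print(nt._maybe_max_seqlen)

# _get_max_seqlen() computes and caches the value; afterwards the maybe-accessor sees it.
nt._get_max_seqlen()
print(nt._maybe_max_seqlen)  # 5 for the inputs above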

torch/nested/_internal/ops.py

Lines changed: 6 additions & 19 deletions

@@ -293,17 +293,11 @@ def jagged_binary_pointwise(func, *args, **kwargs):
             mismatch_error_msg.format(func.__name__, a.shape, b.shape)
         )
 
-    from .nested_tensor import _load_val_from_tensor, nested_from_padded
+    from .nested_tensor import nested_from_padded
 
     # handle broadcasting via padded dense -> jagged conversion
-    min_seqlen = None
-    if nt._min_seqlen_tensor is not None:
-        min_seqlen = _load_val_from_tensor(nt._min_seqlen_tensor)
-
-    max_seqlen = None
-    if nt._max_seqlen_tensor is not None:
-        max_seqlen = _load_val_from_tensor(nt._max_seqlen_tensor)
-
+    min_seqlen = nt._maybe_min_seqlen
+    max_seqlen = nt._maybe_max_seqlen
     padded_max_S = max_seqlen
     total_L = nt._values.shape[nt._ragged_idx - 1]
     if padded_max_S is None:

@@ -993,17 +987,10 @@ def _padded_impl(a, b):
         assert a.is_nested and not b.is_nested
         nt, t = a, b
 
-        from .nested_tensor import _load_val_from_tensor, nested_from_padded
-
-        # convert NT -> padded dense
-        min_seqlen = None
-        if nt._min_seqlen_tensor is not None:
-            min_seqlen = _load_val_from_tensor(nt._min_seqlen_tensor)
-
-        max_seqlen = None
-        if nt._max_seqlen_tensor is not None:
-            max_seqlen = _load_val_from_tensor(nt._max_seqlen_tensor)
+        from .nested_tensor import nested_from_padded
 
+        min_seqlen = nt._maybe_min_seqlen
+        max_seqlen = nt._maybe_max_seqlen
         padded_max_S = max_seqlen
         total_L = nt._values.shape[nt._ragged_idx - 1]
         if padded_max_S is None:
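Both hunks feed these values into a padded-dense round trip: the NJT is padded out to max_seqlen, the op runs on the dense tensor, and the result is re-packed as jagged. A rough standalone sketch of that idea using only public APIs (not the internal code above):

import torch

nt = torch.nested.nested_tensor(
    [torch.randn(3, 4), torch.randn(5, 4)], layout=torch.jagged
)
dense = torch.randn(1, 1, 4)  # broadcasts against the padded form

# Pad to the max sequence length, apply the pointwise op on dense tensors...
padded = torch.nested.to_padded_tensor(nt, 0.0)  # shape [2, 5, 4]
result_padded = padded + dense
# ...the real kernel then converts back to jagged via nested_from_padded,
# forwarding min_seqlen / max_seqlen so the output keeps the cached metadata.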

torch/nested/_internal/sdpa.py

Lines changed: 14 additions & 28 deletions

@@ -568,8 +568,8 @@ def _sdpa_nested_preprocessing(query, key, value):
 
     output_nt_info = {
         "offsets": q_t.offsets(),
-        "_max_seqlen": q_t._get_max_seqlen(),
-        "_min_seqlen": q_t._get_min_seqlen(),
+        "max_seqlen": q_t._get_max_seqlen(),
+        "min_seqlen": q_t._get_min_seqlen(),
     }
 
     return (

@@ -710,7 +710,12 @@ def jagged_scaled_dot_product_attention(
             is_causal=is_causal,
             scale=scale,
         )
-        return nested_view_from_values_offsets(output, query.offsets())
+        return nested_view_from_values_offsets(
+            output,
+            query.offsets(),
+            min_seqlen=query._maybe_min_seqlen,  # type: ignore[attr-defined]
+            max_seqlen=query._maybe_max_seqlen,  # type: ignore[attr-defined]
+        )
 
    compute_logsumexp = query.requires_grad or key.requires_grad or value.requires_grad

@@ -766,9 +771,7 @@ def jagged_scaled_dot_product_attention(
         # Reshape output to convert nnz to batch_size and seq_len
         attention = nested_view_from_values_offsets(
             attention,  # output from flash_attn is [total_q, num_heads, head_size_og]
-            output_nt_info["offsets"],
-            min_seqlen=output_nt_info["_min_seqlen"],
-            max_seqlen=output_nt_info["_max_seqlen"],
+            **output_nt_info,
         ).transpose(1, 2)
         return _post_process_flash_output(attention, og_size)
     elif backend_choice == SDPBackend.EFFICIENT_ATTENTION:

@@ -807,25 +810,18 @@ def jagged_scaled_dot_product_attention(
         # Reshape output to convert nnz to batch_size and seq_len
         return nested_view_from_values_offsets(
             attention.squeeze(0),
-            output_nt_info["offsets"],
-            min_seqlen=output_nt_info["_min_seqlen"],
-            max_seqlen=output_nt_info["_max_seqlen"],
+            **output_nt_info,
         ).transpose(1, 2)
     elif backend_choice == SDPBackend.MATH:
         # save the offsets and shape of the inputs, so we can reshape the final output
         # query @ key = attn: [B, D1, j0, D'] @ [B, D1, D' j1] = [B, D1, j0, j1]
         # attn @ value = out: [B, D1, j0, j1] @ [B, D1, j1, D2] = [B, D1, j0, D2]
         offsets = query.offsets()
+        min_seqlen = query._maybe_min_seqlen
+        max_seqlen = query._maybe_max_seqlen
         d1 = query._size[1]
         d2 = value._size[-1]
 
-        min_seqlen_tensor = query._metadata_cache.get(
-            "min_seqlen", None
-        )  # type: ignore[attr-defined]
-        max_seqlen_tensor = query._metadata_cache.get(
-            "max_seqlen", None
-        )  # type: ignore[attr-defined]
-
         # convert jagged layout Nested Tensor to strided layout Nested Tensor
         # which support the math implementation of SDPA
         def get_strided_layout_nested_tensor(jagged_layout_nt):

@@ -844,24 +840,14 @@ def get_strided_layout_nested_tensor(jagged_layout_nt):
             query, key, value, attn_mask, dropout_p, is_causal, scale=scale
         )[0]
 
-        from torch.nested._internal.nested_tensor import _load_val_from_tensor
-
         # convert strided layout Nested Tensor back to jagged layout Nested Tensor
        attn_out = attn_out.transpose(1, 2).contiguous().values()
        attn_out = attn_out.view(-1, d1, d2)
        attn_out = nested_view_from_values_offsets(
            attn_out,
            offsets,
-            min_seqlen=(
-                None
-                if min_seqlen_tensor is None
-                else _load_val_from_tensor(min_seqlen_tensor)
-            ),
-            max_seqlen=(
-                None
-                if max_seqlen_tensor is None
-                else _load_val_from_tensor(max_seqlen_tensor)
-            ),
+            min_seqlen=min_seqlen,
+            max_seqlen=max_seqlen,
         ).transpose(1, 2)
 
         return attn_out
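Note the key rename in _sdpa_nested_preprocessing: dropping the leading underscores makes the dict keys match the keyword parameters of nested_view_from_values_offsets, which is what lets the later call sites splat it with **output_nt_info. A tiny generic sketch of that refactor, where make_view is a hypothetical stand-in:

from typing import Optional

import torch


def make_view(
    values: torch.Tensor,
    offsets: torch.Tensor,
    min_seqlen: Optional[int] = None,
    max_seqlen: Optional[int] = None,
):
    # Hypothetical stand-in for nested_view_from_values_offsets.
    return values, offsets, min_seqlen, max_seqlen


info = {"offsets": torch.tensor([0, 3, 8]), "min_seqlen": 3, "max_seqlen": 5}
values = torch.randn(8, 4)

# Before: make_view(values, info["offsets"], min_seqlen=info["_min_seqlen"], ...)
# After the rename, the dict splats straight into the keyword parameters:
out = make_view(values, **info)
assert out[3] == 5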
