From ecc5e05854bc7b744bdab901bba960f4a72bb45f Mon Sep 17 00:00:00 2001
From: Joel Schlosser
Date: Thu, 17 Oct 2024 10:48:36 -0400
Subject: [PATCH] Refactor NJT min / max seqlen handling for convenience
 (#138130)

There's an annoying pattern emerging for pulling out the NJT min / max
seqlen ints if they exist, without computing / caching them if they don't.
This PR introduces private convenience properties to simplify handling this
and avoid redundant checks.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/138130
Approved by: https://github.com/soulitzer
---
 torch/nested/_internal/nested_tensor.py | 12 +++++++
 torch/nested/_internal/ops.py           | 25 ++++-----
 torch/nested/_internal/sdpa.py          | 42 +++++++++----------
 3 files changed, 32 insertions(+), 47 deletions(-)

diff --git a/torch/nested/_internal/nested_tensor.py b/torch/nested/_internal/nested_tensor.py
index d766cbd7d8bd..42b96f7bcd27 100644
--- a/torch/nested/_internal/nested_tensor.py
+++ b/torch/nested/_internal/nested_tensor.py
@@ -214,6 +214,18 @@ class NestedTensor(torch.Tensor):
     def _min_seqlen(self):
         return self._get_min_seqlen()
 
+    # Convenience accessors that return a min / max seqlen if one is present and do NOT
+    # compute / cache them if they're not.
+    @property
+    def _maybe_max_seqlen(self) -> Optional[int]:
+        mt = self._max_seqlen_tensor
+        return None if mt is None else _load_val_from_tensor(mt)
+
+    @property
+    def _maybe_min_seqlen(self) -> Optional[int]:
+        mt = self._min_seqlen_tensor
+        return None if mt is None else _load_val_from_tensor(mt)
+
     def __repr__(self):  # type: ignore[override]
         # We should implement this in torch/_tensor_str.py instead
         grad_fn_str = (
diff --git a/torch/nested/_internal/ops.py b/torch/nested/_internal/ops.py
index 1cf3f5853487..b3657be718f6 100644
--- a/torch/nested/_internal/ops.py
+++ b/torch/nested/_internal/ops.py
@@ -293,17 +293,11 @@ def jagged_binary_pointwise(func, *args, **kwargs):
             mismatch_error_msg.format(func.__name__, a.shape, b.shape)
         )
 
-    from .nested_tensor import _load_val_from_tensor, nested_from_padded
+    from .nested_tensor import nested_from_padded
 
     # handle broadcasting via padded dense -> jagged conversion
-    min_seqlen = None
-    if nt._min_seqlen_tensor is not None:
-        min_seqlen = _load_val_from_tensor(nt._min_seqlen_tensor)
-
-    max_seqlen = None
-    if nt._max_seqlen_tensor is not None:
-        max_seqlen = _load_val_from_tensor(nt._max_seqlen_tensor)
-
+    min_seqlen = nt._maybe_min_seqlen
+    max_seqlen = nt._maybe_max_seqlen
     padded_max_S = max_seqlen
     total_L = nt._values.shape[nt._ragged_idx - 1]
     if padded_max_S is None:
@@ -993,17 +987,10 @@ def matmul_default(func, *args, **kwargs):
         assert a.is_nested and not b.is_nested
         nt, t = a, b
 
-        from .nested_tensor import _load_val_from_tensor, nested_from_padded
-
-        # convert NT -> padded dense
-        min_seqlen = None
-        if nt._min_seqlen_tensor is not None:
-            min_seqlen = _load_val_from_tensor(nt._min_seqlen_tensor)
-
-        max_seqlen = None
-        if nt._max_seqlen_tensor is not None:
-            max_seqlen = _load_val_from_tensor(nt._max_seqlen_tensor)
+        from .nested_tensor import nested_from_padded
 
+        min_seqlen = nt._maybe_min_seqlen
+        max_seqlen = nt._maybe_max_seqlen
         padded_max_S = max_seqlen
         total_L = nt._values.shape[nt._ragged_idx - 1]
         if padded_max_S is None:
diff --git a/torch/nested/_internal/sdpa.py b/torch/nested/_internal/sdpa.py
index 578904af9469..2bf7fcac70f8 100644
--- a/torch/nested/_internal/sdpa.py
+++ b/torch/nested/_internal/sdpa.py
@@ -568,8 +568,8 @@ def _sdpa_nested_preprocessing(query, key, value):
 
     output_nt_info = {
"offsets": q_t.offsets(), - "_max_seqlen": q_t._get_max_seqlen(), - "_min_seqlen": q_t._get_min_seqlen(), + "max_seqlen": q_t._get_max_seqlen(), + "min_seqlen": q_t._get_min_seqlen(), } return ( @@ -710,7 +710,12 @@ def jagged_scaled_dot_product_attention( is_causal=is_causal, scale=scale, ) - return nested_view_from_values_offsets(output, query.offsets()) + return nested_view_from_values_offsets( + output, + query.offsets(), + min_seqlen=query._maybe_min_seqlen, # type: ignore[attr-defined] + max_seqlen=query._maybe_max_seqlen, # type: ignore[attr-defined] + ) compute_logsumexp = query.requires_grad or key.requires_grad or value.requires_grad @@ -766,9 +771,7 @@ def jagged_scaled_dot_product_attention( # Reshape output to convert nnz to batch_size and seq_len attention = nested_view_from_values_offsets( attention, # output from flash_attn is [total_q, num_heads, head_size_og] - output_nt_info["offsets"], - min_seqlen=output_nt_info["_min_seqlen"], - max_seqlen=output_nt_info["_max_seqlen"], + **output_nt_info, ).transpose(1, 2) return _post_process_flash_output(attention, og_size) elif backend_choice == SDPBackend.EFFICIENT_ATTENTION: @@ -807,25 +810,18 @@ def jagged_scaled_dot_product_attention( # Reshape output to convert nnz to batch_size and seq_len return nested_view_from_values_offsets( attention.squeeze(0), - output_nt_info["offsets"], - min_seqlen=output_nt_info["_min_seqlen"], - max_seqlen=output_nt_info["_max_seqlen"], + **output_nt_info, ).transpose(1, 2) elif backend_choice == SDPBackend.MATH: # save the offsets and shape of the inputs, so we can reshape the final output # query @ key = attn: [B, D1, j0, D'] @ [B, D1, D' j1] = [B, D1, j0, j1] # attn @ value = out: [B, D1, j0, j1] @ [B, D1, j1, D2] = [B, D1, j0, D2] offsets = query.offsets() + min_seqlen = query._maybe_min_seqlen + max_seqlen = query._maybe_max_seqlen d1 = query._size[1] d2 = value._size[-1] - min_seqlen_tensor = query._metadata_cache.get( - "min_seqlen", None - ) # type: ignore[attr-defined] - max_seqlen_tensor = query._metadata_cache.get( - "max_seqlen", None - ) # type: ignore[attr-defined] - # convert jagged layout Nested Tensor to strided layout Nested Tensor # which support the math implementation of SDPA def get_strided_layout_nested_tensor(jagged_layout_nt): @@ -844,24 +840,14 @@ def jagged_scaled_dot_product_attention( query, key, value, attn_mask, dropout_p, is_causal, scale=scale )[0] - from torch.nested._internal.nested_tensor import _load_val_from_tensor - # convert strided layout Nested Tensor back to jagged layout Nested Tensor attn_out = attn_out.transpose(1, 2).contiguous().values() attn_out = attn_out.view(-1, d1, d2) attn_out = nested_view_from_values_offsets( attn_out, offsets, - min_seqlen=( - None - if min_seqlen_tensor is None - else _load_val_from_tensor(min_seqlen_tensor) - ), - max_seqlen=( - None - if max_seqlen_tensor is None - else _load_val_from_tensor(max_seqlen_tensor) - ), + min_seqlen=min_seqlen, + max_seqlen=max_seqlen, ).transpose(1, 2) return attn_out