Refactor NJT min / max seqlen handling for convenience (#138130)

An annoying pattern has been emerging: pulling out the NJT min / max seqlen ints if they exist, without computing / caching them if they don't. This PR introduces private convenience accessors (`_maybe_min_seqlen` / `_maybe_max_seqlen`) to simplify this handling and avoid redundant checks.
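
The pattern in question, sketched for the min side (a simplified illustration; `nt` stands for any jagged-layout NestedTensor, and `_load_val_from_tensor` / `_min_seqlen_tensor` are the internals that appear in the diffs below):

    # Before: each call site re-implemented the check against the raw cache tensor.
    min_seqlen = None
    if nt._min_seqlen_tensor is not None:
        min_seqlen = _load_val_from_tensor(nt._min_seqlen_tensor)

    # After: a single property encapsulates the check and, unlike _get_min_seqlen(),
    # never computes or caches a value that isn't already present.
    min_seqlen = nt._maybe_min_seqlen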
Pull Request resolved: https://github.com/pytorch/pytorch/pull/138130
Approved by: https://github.com/soulitzer
Joel Schlosser
2024-10-17 10:48:36 -04:00
committed by PyTorch MergeBot
parent 66478d0cf7
commit ecc5e05854
3 changed files with 32 additions and 47 deletions


@@ -214,6 +214,18 @@ class NestedTensor(torch.Tensor):
     def _min_seqlen(self):
         return self._get_min_seqlen()
 
+    # Convenience accessors that return a min / max seqlen if one is present and do NOT
+    # compute / cache them if they're not.
+    @property
+    def _maybe_max_seqlen(self) -> Optional[int]:
+        mt = self._max_seqlen_tensor
+        return None if mt is None else _load_val_from_tensor(mt)
+
+    @property
+    def _maybe_min_seqlen(self) -> Optional[int]:
+        mt = self._min_seqlen_tensor
+        return None if mt is None else _load_val_from_tensor(mt)
+
     def __repr__(self):  # type: ignore[override]
         # We should implement this in torch/_tensor_str.py instead
         grad_fn_str = (
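
For contrast, a minimal usage sketch of the two accessor flavors (behavior as described by the comment in the hunk above; `nt` is a jagged-layout NestedTensor):

    nt._get_min_seqlen()   # computes and caches the min seqlen if it isn't cached yet
    nt._maybe_min_seqlen   # returns the cached int if present, else None; no compute, no cache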


@@ -293,17 +293,11 @@ def jagged_binary_pointwise(func, *args, **kwargs):
             mismatch_error_msg.format(func.__name__, a.shape, b.shape)
         )
 
-    from .nested_tensor import _load_val_from_tensor, nested_from_padded
+    from .nested_tensor import nested_from_padded
 
     # handle broadcasting via padded dense -> jagged conversion
-    min_seqlen = None
-    if nt._min_seqlen_tensor is not None:
-        min_seqlen = _load_val_from_tensor(nt._min_seqlen_tensor)
-
-    max_seqlen = None
-    if nt._max_seqlen_tensor is not None:
-        max_seqlen = _load_val_from_tensor(nt._max_seqlen_tensor)
-
+    min_seqlen = nt._maybe_min_seqlen
+    max_seqlen = nt._maybe_max_seqlen
     padded_max_S = max_seqlen
     total_L = nt._values.shape[nt._ragged_idx - 1]
     if padded_max_S is None:
@@ -993,17 +987,10 @@ def matmul_default(func, *args, **kwargs):
         assert a.is_nested and not b.is_nested
         nt, t = a, b
 
-        from .nested_tensor import _load_val_from_tensor, nested_from_padded
+        from .nested_tensor import nested_from_padded
 
         # convert NT -> padded dense
-        min_seqlen = None
-        if nt._min_seqlen_tensor is not None:
-            min_seqlen = _load_val_from_tensor(nt._min_seqlen_tensor)
-
-        max_seqlen = None
-        if nt._max_seqlen_tensor is not None:
-            max_seqlen = _load_val_from_tensor(nt._max_seqlen_tensor)
-
+        min_seqlen = nt._maybe_min_seqlen
+        max_seqlen = nt._maybe_max_seqlen
         padded_max_S = max_seqlen
         total_L = nt._values.shape[nt._ragged_idx - 1]
         if padded_max_S is None:


@@ -568,8 +568,8 @@ def _sdpa_nested_preprocessing(query, key, value):
     output_nt_info = {
         "offsets": q_t.offsets(),
-        "_max_seqlen": q_t._get_max_seqlen(),
-        "_min_seqlen": q_t._get_min_seqlen(),
+        "max_seqlen": q_t._get_max_seqlen(),
+        "min_seqlen": q_t._get_min_seqlen(),
     }
 
     return (
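
Note that dropping the leading underscores makes the dict keys line up with the keyword parameters of `nested_view_from_values_offsets`, which is what lets the call sites in the hunks below splat the dict directly instead of unpacking it field by field. A sketch of the resulting pattern:

    # offsets / min_seqlen / max_seqlen now match the callee's kwargs exactly:
    out = nested_view_from_values_offsets(values, **output_nt_info)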
@@ -710,7 +710,12 @@ def jagged_scaled_dot_product_attention(
             is_causal=is_causal,
             scale=scale,
         )
-        return nested_view_from_values_offsets(output, query.offsets())
+        return nested_view_from_values_offsets(
+            output,
+            query.offsets(),
+            min_seqlen=query._maybe_min_seqlen,  # type: ignore[attr-defined]
+            max_seqlen=query._maybe_max_seqlen,  # type: ignore[attr-defined]
+        )
 
     compute_logsumexp = query.requires_grad or key.requires_grad or value.requires_grad
 
@@ -766,9 +771,7 @@ def jagged_scaled_dot_product_attention(
         # Reshape output to convert nnz to batch_size and seq_len
         attention = nested_view_from_values_offsets(
             attention,  # output from flash_attn is [total_q, num_heads, head_size_og]
-            output_nt_info["offsets"],
-            min_seqlen=output_nt_info["_min_seqlen"],
-            max_seqlen=output_nt_info["_max_seqlen"],
+            **output_nt_info,
         ).transpose(1, 2)
         return _post_process_flash_output(attention, og_size)
     elif backend_choice == SDPBackend.EFFICIENT_ATTENTION:
@@ -807,25 +810,18 @@ def jagged_scaled_dot_product_attention(
         # Reshape output to convert nnz to batch_size and seq_len
         return nested_view_from_values_offsets(
             attention.squeeze(0),
-            output_nt_info["offsets"],
-            min_seqlen=output_nt_info["_min_seqlen"],
-            max_seqlen=output_nt_info["_max_seqlen"],
+            **output_nt_info,
         ).transpose(1, 2)
     elif backend_choice == SDPBackend.MATH:
         # save the offsets and shape of the inputs, so we can reshape the final output
         # query @ key = attn: [B, D1, j0, D'] @ [B, D1, D' j1] = [B, D1, j0, j1]
         # attn @ value = out: [B, D1, j0, j1] @ [B, D1, j1, D2] = [B, D1, j0, D2]
         offsets = query.offsets()
+        min_seqlen = query._maybe_min_seqlen
+        max_seqlen = query._maybe_max_seqlen
         d1 = query._size[1]
         d2 = value._size[-1]
-        min_seqlen_tensor = query._metadata_cache.get(
-            "min_seqlen", None
-        )  # type: ignore[attr-defined]
-        max_seqlen_tensor = query._metadata_cache.get(
-            "max_seqlen", None
-        )  # type: ignore[attr-defined]
 
         # convert jagged layout Nested Tensor to strided layout Nested Tensor
         # which support the math implementation of SDPA
         def get_strided_layout_nested_tensor(jagged_layout_nt):
@@ -844,24 +840,14 @@ def jagged_scaled_dot_product_attention(
             query, key, value, attn_mask, dropout_p, is_causal, scale=scale
         )[0]
 
-        from torch.nested._internal.nested_tensor import _load_val_from_tensor
-
         # convert strided layout Nested Tensor back to jagged layout Nested Tensor
         attn_out = attn_out.transpose(1, 2).contiguous().values()
         attn_out = attn_out.view(-1, d1, d2)
         attn_out = nested_view_from_values_offsets(
             attn_out,
             offsets,
-            min_seqlen=(
-                None
-                if min_seqlen_tensor is None
-                else _load_val_from_tensor(min_seqlen_tensor)
-            ),
-            max_seqlen=(
-                None
-                if max_seqlen_tensor is None
-                else _load_val_from_tensor(max_seqlen_tensor)
-            ),
+            min_seqlen=min_seqlen,
+            max_seqlen=max_seqlen,
         ).transpose(1, 2)
 
         return attn_out