Refactor NJT min / max seqlen handling for convenience (#138130)
There's an annoying pattern emerging for pulling out the NJT min / max seqlen ints if they exist, without computing / caching them if they don't. This PR introduces private convenience accessors to simplify this handling and avoid redundant checks.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/138130
Approved by: https://github.com/soulitzer
committed by PyTorch MergeBot
parent 66478d0cf7
commit ecc5e05854
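For context, here is the repeated call-site pattern this PR removes and the one-liner that replaces it — a minimal sketch assuming an NJT `nt` with the private `_min_seqlen_tensor` / `_max_seqlen_tensor` attributes and the `_load_val_from_tensor` helper that appear in the diff below:

```python
# Before: every call site unpacked the cached seqlen tensors by hand.
min_seqlen = None
if nt._min_seqlen_tensor is not None:
    min_seqlen = _load_val_from_tensor(nt._min_seqlen_tensor)

max_seqlen = None
if nt._max_seqlen_tensor is not None:
    max_seqlen = _load_val_from_tensor(nt._max_seqlen_tensor)

# After: the new private properties encapsulate the same check.
min_seqlen = nt._maybe_min_seqlen
max_seqlen = nt._maybe_max_seqlen
```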
torch/nested/_internal/nested_tensor.py

@@ -214,6 +214,18 @@ class NestedTensor(torch.Tensor):
     def _min_seqlen(self):
         return self._get_min_seqlen()
 
+    # Convenience accessors that return a min / max seqlen if one is present and do NOT
+    # compute / cache them if they're not.
+    @property
+    def _maybe_max_seqlen(self) -> Optional[int]:
+        mt = self._max_seqlen_tensor
+        return None if mt is None else _load_val_from_tensor(mt)
+
+    @property
+    def _maybe_min_seqlen(self) -> Optional[int]:
+        mt = self._min_seqlen_tensor
+        return None if mt is None else _load_val_from_tensor(mt)
+
     def __repr__(self):  # type: ignore[override]
         # We should implement this in torch/_tensor_str.py instead
         grad_fn_str = (
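To make the semantics of the new accessors concrete: per the commit message, `_get_min_seqlen()` / `_get_max_seqlen()` always produce an int, computing and caching it when no cached value exists, while the `_maybe_*` properties only report what is already cached. A short sketch of how a caller might use the distinction (hypothetical caller code, not from the PR):

```python
def describe_max_seqlen(nt) -> str:
    # Lazy path: _maybe_max_seqlen returns the cached int or None;
    # it never computes or caches anything.
    cached = nt._maybe_max_seqlen
    if cached is not None:
        return f"max seqlen (cached): {cached}"
    # Eager path: _get_max_seqlen() always returns an int, computing
    # and caching it on a cache miss.
    return f"max seqlen (computed): {nt._get_max_seqlen()}"
```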
torch/nested/_internal/ops.py

@@ -293,17 +293,11 @@ def jagged_binary_pointwise(func, *args, **kwargs):
             mismatch_error_msg.format(func.__name__, a.shape, b.shape)
         )
 
-        from .nested_tensor import _load_val_from_tensor, nested_from_padded
+        from .nested_tensor import nested_from_padded
 
         # handle broadcasting via padded dense -> jagged conversion
-        min_seqlen = None
-        if nt._min_seqlen_tensor is not None:
-            min_seqlen = _load_val_from_tensor(nt._min_seqlen_tensor)
-
-        max_seqlen = None
-        if nt._max_seqlen_tensor is not None:
-            max_seqlen = _load_val_from_tensor(nt._max_seqlen_tensor)
-
+        min_seqlen = nt._maybe_min_seqlen
+        max_seqlen = nt._maybe_max_seqlen
         padded_max_S = max_seqlen
         total_L = nt._values.shape[nt._ragged_idx - 1]
         if padded_max_S is None:
@@ -993,17 +987,10 @@ def matmul_default(func, *args, **kwargs):
         assert a.is_nested and not b.is_nested
         nt, t = a, b
 
-        from .nested_tensor import _load_val_from_tensor, nested_from_padded
-
-        # convert NT -> padded dense
-        min_seqlen = None
-        if nt._min_seqlen_tensor is not None:
-            min_seqlen = _load_val_from_tensor(nt._min_seqlen_tensor)
-
-        max_seqlen = None
-        if nt._max_seqlen_tensor is not None:
-            max_seqlen = _load_val_from_tensor(nt._max_seqlen_tensor)
-
+        from .nested_tensor import nested_from_padded
+
+        min_seqlen = nt._maybe_min_seqlen
+        max_seqlen = nt._maybe_max_seqlen
         padded_max_S = max_seqlen
         total_L = nt._values.shape[nt._ragged_idx - 1]
         if padded_max_S is None:
torch/nested/_internal/sdpa.py

@@ -568,8 +568,8 @@ def _sdpa_nested_preprocessing(query, key, value):
 
     output_nt_info = {
         "offsets": q_t.offsets(),
-        "_max_seqlen": q_t._get_max_seqlen(),
-        "_min_seqlen": q_t._get_min_seqlen(),
+        "max_seqlen": q_t._get_max_seqlen(),
+        "min_seqlen": q_t._get_min_seqlen(),
     }
 
     return (
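Dropping the leading underscores from the `output_nt_info` keys is what lets the later hunks splat the dict straight into `nested_view_from_values_offsets(...)`: each key now matches a keyword parameter name exactly. A toy illustration with a stand-in function (names hypothetical, not the real signature):

```python
def view_from_values_offsets(values, offsets, min_seqlen=None, max_seqlen=None):
    # Stand-in for nested_view_from_values_offsets.
    return (values, offsets, min_seqlen, max_seqlen)

output_nt_info = {"offsets": [0, 3, 7], "max_seqlen": 4, "min_seqlen": 3}
# Every key matches a parameter name, so the whole dict can be passed as kwargs.
result = view_from_values_offsets("values", **output_nt_info)
```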
@@ -710,7 +710,12 @@ def jagged_scaled_dot_product_attention(
             is_causal=is_causal,
             scale=scale,
         )
-        return nested_view_from_values_offsets(output, query.offsets())
+        return nested_view_from_values_offsets(
+            output,
+            query.offsets(),
+            min_seqlen=query._maybe_min_seqlen,  # type: ignore[attr-defined]
+            max_seqlen=query._maybe_max_seqlen,  # type: ignore[attr-defined]
+        )
 
     compute_logsumexp = query.requires_grad or key.requires_grad or value.requires_grad
 
@@ -766,9 +771,7 @@ def jagged_scaled_dot_product_attention(
         # Reshape output to convert nnz to batch_size and seq_len
         attention = nested_view_from_values_offsets(
             attention,  # output from flash_attn is [total_q, num_heads, head_size_og]
-            output_nt_info["offsets"],
-            min_seqlen=output_nt_info["_min_seqlen"],
-            max_seqlen=output_nt_info["_max_seqlen"],
+            **output_nt_info,
         ).transpose(1, 2)
         return _post_process_flash_output(attention, og_size)
     elif backend_choice == SDPBackend.EFFICIENT_ATTENTION:
@@ -807,25 +810,18 @@ def jagged_scaled_dot_product_attention(
         # Reshape output to convert nnz to batch_size and seq_len
         return nested_view_from_values_offsets(
             attention.squeeze(0),
-            output_nt_info["offsets"],
-            min_seqlen=output_nt_info["_min_seqlen"],
-            max_seqlen=output_nt_info["_max_seqlen"],
+            **output_nt_info,
         ).transpose(1, 2)
     elif backend_choice == SDPBackend.MATH:
         # save the offsets and shape of the inputs, so we can reshape the final output
         # query @ key = attn: [B, D1, j0, D'] @ [B, D1, D' j1] = [B, D1, j0, j1]
         # attn @ value = out: [B, D1, j0, j1] @ [B, D1, j1, D2] = [B, D1, j0, D2]
         offsets = query.offsets()
+        min_seqlen = query._maybe_min_seqlen
+        max_seqlen = query._maybe_max_seqlen
         d1 = query._size[1]
         d2 = value._size[-1]
 
-        min_seqlen_tensor = query._metadata_cache.get(
-            "min_seqlen", None
-        )  # type: ignore[attr-defined]
-        max_seqlen_tensor = query._metadata_cache.get(
-            "max_seqlen", None
-        )  # type: ignore[attr-defined]
-
         # convert jagged layout Nested Tensor to strided layout Nested Tensor
         # which support the math implementation of SDPA
         def get_strided_layout_nested_tensor(jagged_layout_nt):
@@ -844,24 +840,14 @@ def jagged_scaled_dot_product_attention(
             query, key, value, attn_mask, dropout_p, is_causal, scale=scale
         )[0]
 
-        from torch.nested._internal.nested_tensor import _load_val_from_tensor
-
         # convert strided layout Nested Tensor back to jagged layout Nested Tensor
         attn_out = attn_out.transpose(1, 2).contiguous().values()
         attn_out = attn_out.view(-1, d1, d2)
         attn_out = nested_view_from_values_offsets(
             attn_out,
             offsets,
-            min_seqlen=(
-                None
-                if min_seqlen_tensor is None
-                else _load_val_from_tensor(min_seqlen_tensor)
-            ),
-            max_seqlen=(
-                None
-                if max_seqlen_tensor is None
-                else _load_val_from_tensor(max_seqlen_tensor)
-            ),
+            min_seqlen=min_seqlen,
+            max_seqlen=max_seqlen,
         ).transpose(1, 2)
 
         return attn_out
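As a quick end-to-end check of the lazy behavior, a sketch assuming (per the commit message) that the seqlen cache starts empty for a freshly constructed jagged NJT:

```python
import torch

nt = torch.nested.nested_tensor(
    [torch.randn(3, 8), torch.randn(5, 8)], layout=torch.jagged
)
print(nt._maybe_max_seqlen)  # expected: None (nothing cached, nothing computed)
print(nt._get_max_seqlen())  # computes from the ragged structure and caches: 5
print(nt._maybe_max_seqlen)  # expected: 5 (now cached)
```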