mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-27 09:04:53 +08:00
Add an SDPA dispatcher for nested tensors with jagged layouts (#114164)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/114164 Approved by: https://github.com/jbschlosser
This commit is contained in:
committed by
PyTorch MergeBot
parent
43e3242490
commit
aafa8233a4
@ -43,6 +43,9 @@ class NestedTensor(torch.Tensor):
|
||||
_stride: Tuple[int, ...]
|
||||
# Indicates that the nth dimension is ragged
|
||||
_ragged_idx: int
|
||||
# SDPA Metadata
|
||||
_max_seqlen: int
|
||||
_min_seqlen: int
|
||||
|
||||
@staticmethod
|
||||
def __new__(
|
||||
@ -84,12 +87,18 @@ class NestedTensor(torch.Tensor):
|
||||
# (create a new one if needed).
|
||||
ragged_source = offsets if lengths is None else lengths
|
||||
ragged_size = get_tensor_symint(ragged_source, coeff=1)
|
||||
self._ragged_idx = kwargs.get("_ragged_idx", 1)
|
||||
B = offsets.shape[0] - 1
|
||||
Ds = values.shape[1:]
|
||||
self._size = (B, ragged_size, *Ds)
|
||||
Ds = values.shape[: self._ragged_idx - 1] + values.shape[self._ragged_idx :]
|
||||
|
||||
nested_size = [B]
|
||||
nested_size.extend(Ds[: self._ragged_idx - 1])
|
||||
nested_size.append(ragged_size)
|
||||
nested_size.extend(Ds[self._ragged_idx - 1 :])
|
||||
self._size = tuple(nested_size)
|
||||
|
||||
stride = values.stride()
|
||||
self._strides = (ragged_size * stride[0], *stride)
|
||||
self._ragged_idx = 1
|
||||
self._strides = (ragged_size * stride[self._ragged_idx - 1], *stride)
|
||||
|
||||
if values.requires_grad:
|
||||
raise ValueError(
|
||||
@ -100,6 +109,27 @@ class NestedTensor(torch.Tensor):
|
||||
self._offsets = offsets
|
||||
self._lengths = lengths
|
||||
|
||||
# SDPA metadata
|
||||
def get_sdpa_extreme_seqlen(func, tensor):
|
||||
return int(func(tensor).item())
|
||||
|
||||
# Note: Not using kwargs.get to avoid execution of get_sdpa_extreme_seqlen
|
||||
# unless it is really needed
|
||||
self._max_seqlen = (
|
||||
kwargs["_max_seqlen"]
|
||||
if "_max_seqlen" in kwargs
|
||||
else get_sdpa_extreme_seqlen(
|
||||
torch.max, offsets.diff() if lengths is None else lengths
|
||||
)
|
||||
)
|
||||
self._min_seqlen = (
|
||||
kwargs["_min_seqlen"]
|
||||
if "_min_seqlen" in kwargs
|
||||
else get_sdpa_extreme_seqlen(
|
||||
torch.min, offsets.diff() if lengths is None else lengths
|
||||
)
|
||||
)
|
||||
|
||||
def values(self):
|
||||
return self._values
|
||||
|
||||
@ -135,6 +165,9 @@ class NestedTensor(torch.Tensor):
|
||||
ctx = {
|
||||
"requires_grad": self.requires_grad,
|
||||
"ragged_size": self._size[self._ragged_idx],
|
||||
"max_seqlen": self._max_seqlen,
|
||||
"min_seqlen": self._min_seqlen,
|
||||
"ragged_idx": self._ragged_idx,
|
||||
}
|
||||
inner_tensors = ["_values", "_offsets"]
|
||||
if self._lengths is not None:
|
||||
@ -187,6 +220,9 @@ class NestedTensor(torch.Tensor):
|
||||
offsets=offsets,
|
||||
lengths=lengths,
|
||||
requires_grad=meta["requires_grad"],
|
||||
_max_seqlen=meta["max_seqlen"],
|
||||
_min_seqlen=meta["min_seqlen"],
|
||||
_ragged_idx=meta["ragged_idx"],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@ -222,35 +258,55 @@ class ViewBufferFromNested(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, x: NestedTensor): # type: ignore[override]
|
||||
ctx.save_for_backward(x.offsets())
|
||||
ctx.max_seqlen = x._max_seqlen
|
||||
ctx.min_seqlen = x._min_seqlen
|
||||
ctx._ragged_idx = x._ragged_idx
|
||||
return x.values()
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, gO: torch.Tensor): # type: ignore[override]
|
||||
(offsets,) = ctx.saved_tensors
|
||||
return NestedTensor(gO, offsets=offsets)
|
||||
return NestedTensor(
|
||||
gO,
|
||||
offsets=offsets,
|
||||
_max_seqlen=ctx.max_seqlen,
|
||||
_min_seqlen=ctx.min_seqlen,
|
||||
_ragged_idx=ctx._ragged_idx,
|
||||
)
|
||||
|
||||
|
||||
# Not actually a view!
|
||||
class ViewNestedFromBuffer(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, values: torch.Tensor, offsets: torch.Tensor): # type: ignore[override]
|
||||
return NestedTensor(values.detach(), offsets=offsets)
|
||||
def forward(ctx, values: torch.Tensor, offsets: torch.Tensor, max_seqlen: int, min_seqlen: int): # type: ignore[override]
|
||||
return NestedTensor(
|
||||
values.detach(),
|
||||
offsets=offsets,
|
||||
_max_seqlen=max_seqlen,
|
||||
_min_seqlen=min_seqlen,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, gO: NestedTensor): # type: ignore[override]
|
||||
return gO.values(), None, None
|
||||
return gO.values(), None, None, None
|
||||
|
||||
|
||||
# Not actually a view!
|
||||
# NOTE: @jbschlosser is working on making it a view
|
||||
class ViewNonContiguousNestedFromBuffer(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, values: torch.Tensor, offsets: torch.Tensor, lengths: torch.Tensor): # type: ignore[override]
|
||||
return NestedTensor(values.detach(), offsets=offsets, lengths=lengths)
|
||||
def forward(ctx, values: torch.Tensor, offsets: torch.Tensor, lengths: torch.Tensor, max_seqlen: int, min_seqlen: int): # type: ignore[override]
|
||||
return NestedTensor(
|
||||
values.detach(),
|
||||
offsets=offsets,
|
||||
lengths=lengths,
|
||||
_max_seqlen=max_seqlen,
|
||||
_min_seqlen=min_seqlen,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, gO: NestedTensor): # type: ignore[override]
|
||||
return gO.values(), None, None
|
||||
return gO.values(), None, None, None, None
|
||||
|
||||
|
||||
# Need to make it obvious that users should be passing in offsets
|
||||
@ -303,7 +359,10 @@ def jagged_from_list(
|
||||
]
|
||||
)
|
||||
|
||||
return ViewNestedFromBuffer.apply(values, offsets), offsets # type: ignore[call-overload]
|
||||
max_seqlen = max([t.shape[0] for t in tensors])
|
||||
min_seqlen = min([t.shape[0] for t in tensors])
|
||||
|
||||
return ViewNestedFromBuffer.apply(values, offsets, max_seqlen, min_seqlen), offsets # type: ignore[call-overload]
|
||||
|
||||
|
||||
def jagged_from_tensor_and_lengths(
|
||||
@ -354,16 +413,28 @@ def jagged_from_tensor_and_lengths(
|
||||
if offsets[0] + length_list[0] != orig_dim:
|
||||
is_contiguous = False
|
||||
|
||||
actual_max_seqlen = int(torch.max(lengths).item())
|
||||
min_seqlen = int(torch.min(lengths).item())
|
||||
|
||||
if is_contiguous:
|
||||
return (
|
||||
ViewNestedFromBuffer.apply(
|
||||
values[offsets[0] : offsets[-1]], offsets - offsets[0]
|
||||
values[offsets[0] : offsets[-1]],
|
||||
offsets - offsets[0],
|
||||
actual_max_seqlen,
|
||||
min_seqlen,
|
||||
),
|
||||
offsets,
|
||||
None,
|
||||
)
|
||||
|
||||
return ViewNonContiguousNestedFromBuffer.apply(values, offsets, length_list), offsets, length_list # type: ignore[call-overload]
|
||||
return (
|
||||
ViewNonContiguousNestedFromBuffer.apply(
|
||||
values, offsets, length_list, actual_max_seqlen, min_seqlen
|
||||
),
|
||||
offsets,
|
||||
length_list,
|
||||
) # type: ignore[call-overload]
|
||||
|
||||
|
||||
def buffer_from_jagged(jagged):
|
||||
|
||||
Reference in New Issue
Block a user