Add an SDPA dispatcher for nested tensors with jagged layouts (#114164)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/114164
Approved by: https://github.com/jbschlosser
This commit was authored by Antoni Viros on 2023-12-04 19:10:17 +00:00 and committed by PyTorch MergeBot.
parent 43e3242490
commit aafa8233a4
7 changed files with 1177 additions and 123 deletions

View File

@ -43,6 +43,9 @@ class NestedTensor(torch.Tensor):
_stride: Tuple[int, ...]
# Indicates that the nth dimension is ragged
_ragged_idx: int
# SDPA Metadata
# Longest sequence length in the batch: taken from kwargs when supplied,
# otherwise reduced from offsets.diff() (or `lengths`) in __init__.
_max_seqlen: int
# Shortest sequence length in the batch, computed the same way.
_min_seqlen: int
@staticmethod
def __new__(
@ -84,12 +87,18 @@ class NestedTensor(torch.Tensor):
# (create a new one if needed).
# Ragged structure comes from `offsets` unless explicit per-component
# `lengths` are given (the non-contiguous case).
ragged_source = offsets if lengths is None else lengths
ragged_size = get_tensor_symint(ragged_source, coeff=1)
# NOTE(review): this span is a diff rendering with the +/- markers
# stripped, so pre-change and post-change lines appear side by side.
self._ragged_idx = kwargs.get("_ragged_idx", 1)
B = offsets.shape[0] - 1
# NOTE(review): the next two lines are the pre-change (removed) size
# computation, which hard-wired the ragged dimension to dim 1.
Ds = values.shape[1:]
self._size = (B, ragged_size, *Ds)
# Post-change version: splice `ragged_size` into the size tuple at
# `_ragged_idx`, with the remaining dims of `values` on either side.
Ds = values.shape[: self._ragged_idx - 1] + values.shape[self._ragged_idx :]
nested_size = [B]
nested_size.extend(Ds[: self._ragged_idx - 1])
nested_size.append(ragged_size)
nested_size.extend(Ds[self._ragged_idx - 1 :])
self._size = tuple(nested_size)
stride = values.stride()
# NOTE(review): the next two lines are the pre-change stride/ragged-idx
# assignments; the line after them is the post-change replacement that
# scales the stride of the actual ragged dimension.
self._strides = (ragged_size * stride[0], *stride)
self._ragged_idx = 1
self._strides = (ragged_size * stride[self._ragged_idx - 1], *stride)
if values.requires_grad:
raise ValueError(
@ -100,6 +109,27 @@ class NestedTensor(torch.Tensor):
# Store the jagged components: `offsets` delimits each batch item in the
# flat values buffer; `lengths` (when not None) gives per-item lengths
# for the non-contiguous layout.
self._offsets = offsets
self._lengths = lengths
# SDPA metadata
def get_sdpa_extreme_seqlen(func, tensor):
    """Reduce `tensor` with `func` (e.g. torch.max / torch.min) and
    return the resulting scalar as a plain Python int."""
    extreme = func(tensor)
    return int(extreme.item())
# Note: Not using kwargs.get to avoid execution of get_sdpa_extreme_seqlen
# unless it is really needed
# Max/min sequence length over the batch: taken from kwargs when the
# caller already knows them, otherwise reduced from the per-component
# lengths (offsets.diff() in the contiguous case, `lengths` otherwise).
self._max_seqlen = (
kwargs["_max_seqlen"]
if "_max_seqlen" in kwargs
else get_sdpa_extreme_seqlen(
torch.max, offsets.diff() if lengths is None else lengths
)
)
self._min_seqlen = (
kwargs["_min_seqlen"]
if "_min_seqlen" in kwargs
else get_sdpa_extreme_seqlen(
torch.min, offsets.diff() if lengths is None else lengths
)
)
# Accessor for the flat, packed values tensor backing this nested tensor.
def values(self):
return self._values
@ -135,6 +165,9 @@ class NestedTensor(torch.Tensor):
# Metadata dict serialized alongside the inner tensors (presumably part
# of __tensor_flatten__ — the enclosing def lies outside this hunk;
# TODO confirm). Keys mirror what the unflatten path reads from `meta`.
ctx = {
"requires_grad": self.requires_grad,
"ragged_size": self._size[self._ragged_idx],
"max_seqlen": self._max_seqlen,
"min_seqlen": self._min_seqlen,
"ragged_idx": self._ragged_idx,
}
# The lengths tensor is only included when present (non-contiguous case).
inner_tensors = ["_values", "_offsets"]
if self._lengths is not None:
@ -187,6 +220,9 @@ class NestedTensor(torch.Tensor):
# Tail of the reconstruction call (the enclosing def lies outside this
# hunk): forwards the flattened metadata back into the NestedTensor
# constructor, mirroring the keys written by the flatten path.
offsets=offsets,
lengths=lengths,
requires_grad=meta["requires_grad"],
_max_seqlen=meta["max_seqlen"],
_min_seqlen=meta["min_seqlen"],
_ragged_idx=meta["ragged_idx"],
)
@classmethod
@ -222,35 +258,55 @@ class ViewBufferFromNested(torch.autograd.Function):
@staticmethod
def forward(ctx, x: NestedTensor):  # type: ignore[override]
    """Unwrap a NestedTensor into its flat values buffer.

    Stashes the ragged-dim index and SDPA min/max seqlen on `ctx` as
    plain attributes (they are not tensors), and saves the offsets
    tensor so backward can rebuild an identical NestedTensor.
    """
    ctx._ragged_idx = x._ragged_idx
    ctx.min_seqlen = x._min_seqlen
    ctx.max_seqlen = x._max_seqlen
    ctx.save_for_backward(x.offsets())
    return x.values()
@staticmethod
def backward(ctx, gO: torch.Tensor):  # type: ignore[override]
    """Rebuild the NestedTensor gradient from the flat values gradient.

    Uses the offsets saved in forward plus the stashed SDPA seqlen and
    ragged-idx metadata so the reconstructed gradient has the same
    jagged structure as the forward input.
    """
    (offsets,) = ctx.saved_tensors
    # NOTE(review): the diff rendering had left the pre-change
    # `return NestedTensor(gO, offsets=offsets)` line in place directly
    # above the post-change return, producing two consecutive return
    # statements; only the post-change version is kept here.
    return NestedTensor(
        gO,
        offsets=offsets,
        _max_seqlen=ctx.max_seqlen,
        _min_seqlen=ctx.min_seqlen,
        _ragged_idx=ctx._ragged_idx,
    )
# Not actually a view!
class ViewNestedFromBuffer(torch.autograd.Function):
    """Autograd-aware construction of a contiguous NestedTensor from a
    flat values buffer, offsets, and precomputed SDPA seqlen bounds.

    NOTE(review): the stripped diff rendering retained the pre-change
    two-argument `forward` and its return, plus the pre-change
    three-element `backward` return, alongside the post-change versions;
    only the post-change code is kept here.
    """

    @staticmethod
    def forward(ctx, values: torch.Tensor, offsets: torch.Tensor, max_seqlen: int, min_seqlen: int):  # type: ignore[override]
        # detach() so the wrapped buffer carries no grad history;
        # gradient flow is handled by this Function itself.
        return NestedTensor(
            values.detach(),
            offsets=offsets,
            _max_seqlen=max_seqlen,
            _min_seqlen=min_seqlen,
        )

    @staticmethod
    def backward(ctx, gO: NestedTensor):  # type: ignore[override]
        # One grad slot per forward input; only `values` receives a grad.
        return gO.values(), None, None, None
# Not actually a view!
# NOTE: @jbschlosser is working on making it a view
class ViewNonContiguousNestedFromBuffer(torch.autograd.Function):
    """Autograd-aware construction of a non-contiguous NestedTensor from
    a flat values buffer, offsets, explicit per-component lengths, and
    precomputed SDPA seqlen bounds.

    NOTE(review): the stripped diff rendering retained the pre-change
    three-argument `forward` and its return, plus the pre-change
    four-element `backward` return, alongside the post-change versions;
    only the post-change code is kept here.
    """

    @staticmethod
    def forward(ctx, values: torch.Tensor, offsets: torch.Tensor, lengths: torch.Tensor, max_seqlen: int, min_seqlen: int):  # type: ignore[override]
        # detach() so the wrapped buffer carries no grad history;
        # gradient flow is handled by this Function itself.
        return NestedTensor(
            values.detach(),
            offsets=offsets,
            lengths=lengths,
            _max_seqlen=max_seqlen,
            _min_seqlen=min_seqlen,
        )

    @staticmethod
    def backward(ctx, gO: NestedTensor):  # type: ignore[override]
        # One grad slot per forward input; only `values` receives a grad.
        return gO.values(), None, None, None, None
# Need to make it obvious that users should be passing in offsets
@ -303,7 +359,10 @@ def jagged_from_list(
]
)
# NOTE(review): the next line is the pre-change return of
# jagged_from_list (its def lies in the hunk header above); the
# post-change version below also computes the batch's max/min component
# lengths and threads them through ViewNestedFromBuffer.
return ViewNestedFromBuffer.apply(values, offsets), offsets # type: ignore[call-overload]
max_seqlen = max([t.shape[0] for t in tensors])
min_seqlen = min([t.shape[0] for t in tensors])
return ViewNestedFromBuffer.apply(values, offsets, max_seqlen, min_seqlen), offsets # type: ignore[call-overload]
def jagged_from_tensor_and_lengths(
@ -354,16 +413,28 @@ def jagged_from_tensor_and_lengths(
# A leading/trailing trim means the flat buffer is not contiguous with
# the offsets.
if offsets[0] + length_list[0] != orig_dim:
is_contiguous = False
# Extreme sequence lengths for SDPA, reduced from the explicit lengths.
actual_max_seqlen = int(torch.max(lengths).item())
min_seqlen = int(torch.min(lengths).item())
if is_contiguous:
return (
ViewNestedFromBuffer.apply(
# NOTE(review): the next line is the pre-change two-argument call
# left in by the stripped diff markers; the four post-change
# argument lines follow it.
values[offsets[0] : offsets[-1]], offsets - offsets[0]
values[offsets[0] : offsets[-1]],
offsets - offsets[0],
actual_max_seqlen,
min_seqlen,
),
offsets,
None,
)
# NOTE(review): the next line is the pre-change non-contiguous return;
# the post-change multi-line version below adds the seqlen arguments.
return ViewNonContiguousNestedFromBuffer.apply(values, offsets, length_list), offsets, length_list # type: ignore[call-overload]
return (
ViewNonContiguousNestedFromBuffer.apply(
values, offsets, length_list, actual_max_seqlen, min_seqlen
),
offsets,
length_list,
) # type: ignore[call-overload]
def buffer_from_jagged(jagged):