Revert "Add an SDPA dispatcher for nested tensors with jagged layouts (#114164)"

This reverts commit aafa8233a4a1f336014cb122d16941e5b593706c.

Reverted https://github.com/pytorch/pytorch/pull/114164 on behalf of https://github.com/malfet due to Broke ROCM, see aafa8233a4 ([comment](https://github.com/pytorch/pytorch/pull/114164#issuecomment-1839798986))
PyTorch MergeBot committed 2023-12-05 00:35:20 +00:00
parent aa6920c542
commit 5cfda9b7f8

7 changed files with 123 additions and 1177 deletions


@@ -43,9 +43,6 @@ class NestedTensor(torch.Tensor):
     _stride: Tuple[int, ...]
     # Indicates that the nth dimension is ragged
     _ragged_idx: int
-    # SDPA Metadata
-    _max_seqlen: int
-    _min_seqlen: int
 
     @staticmethod
     def __new__(
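The three removed lines above are the class-level declarations for the cached SDPA metadata that this revert drops. As a hypothetical illustration (not part of the patch), the values those fields cached are recoverable from `offsets` alone, since consecutive differences of offsets give the per-sample sequence lengths:

    import torch

    offsets = torch.tensor([0, 3, 5, 9])     # 3 packed sequences
    lengths = offsets.diff()                  # tensor([3, 2, 4])
    max_seqlen = int(lengths.max().item())    # 4 -- what _max_seqlen cached
    min_seqlen = int(lengths.min().item())    # 2 -- what _min_seqlen cached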
@@ -87,18 +84,12 @@ class NestedTensor(torch.Tensor):
         # (create a new one if needed).
         ragged_source = offsets if lengths is None else lengths
         ragged_size = get_tensor_symint(ragged_source, coeff=1)
-        self._ragged_idx = kwargs.get("_ragged_idx", 1)
         B = offsets.shape[0] - 1
-        Ds = values.shape[: self._ragged_idx - 1] + values.shape[self._ragged_idx :]
-
-        nested_size = [B]
-        nested_size.extend(Ds[: self._ragged_idx - 1])
-        nested_size.append(ragged_size)
-        nested_size.extend(Ds[self._ragged_idx - 1 :])
-        self._size = tuple(nested_size)
-
+        Ds = values.shape[1:]
+        self._size = (B, ragged_size, *Ds)
         stride = values.stride()
-        self._strides = (ragged_size * stride[self._ragged_idx - 1], *stride)
+        self._strides = (ragged_size * stride[0], *stride)
+        self._ragged_idx = 1
 
         if values.requires_grad:
             raise ValueError(
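This hunk swaps the generalized size/stride computation, which supported an arbitrary `_ragged_idx` passed through kwargs, back to the original code that hard-codes dimension 1 as the ragged dimension. A minimal sketch of the restored bookkeeping, using concrete numbers in place of the symbolic `ragged_size`:

    import torch

    values = torch.randn(9, 8)                # packed buffer: 9 total rows, 8 features
    offsets = torch.tensor([0, 3, 5, 9])

    B = offsets.shape[0] - 1                  # batch size: 3
    Ds = values.shape[1:]                     # non-ragged dims: (8,)
    ragged_size = values.shape[0]             # symbolic in the real code; 9 here
    size = (B, ragged_size, *Ds)              # (3, 9, 8), dim 1 is ragged

    stride = values.stride()                  # (8, 1)
    strides = (ragged_size * stride[0], *stride)  # (72, 8, 1)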
@@ -109,27 +100,6 @@ class NestedTensor(torch.Tensor):
         self._offsets = offsets
         self._lengths = lengths
 
-        # SDPA metadata
-        def get_sdpa_extreme_seqlen(func, tensor):
-            return int(func(tensor).item())
-
-        # Note: Not using kwargs.get to avoid execution of get_sdpa_extreme_seqlen
-        # unless it is really needed
-        self._max_seqlen = (
-            kwargs["_max_seqlen"]
-            if "_max_seqlen" in kwargs
-            else get_sdpa_extreme_seqlen(
-                torch.max, offsets.diff() if lengths is None else lengths
-            )
-        )
-        self._min_seqlen = (
-            kwargs["_min_seqlen"]
-            if "_min_seqlen" in kwargs
-            else get_sdpa_extreme_seqlen(
-                torch.min, offsets.diff() if lengths is None else lengths
-            )
-        )
 
     def values(self):
         return self._values
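The removed block computed `_max_seqlen`/`_min_seqlen` lazily, and its comment explains why it avoids `kwargs.get`: `dict.get` evaluates its default argument eagerly, so the reduction over all lengths would run even when the caller already supplied the value. A small standalone demonstration of that pitfall (hypothetical, not from the patch):

    def expensive():
        print("computed anyway")
        return 42

    kwargs = {"_max_seqlen": 7}

    kwargs.get("_max_seqlen", expensive())    # prints: the default is built eagerly
    kwargs["_max_seqlen"] if "_max_seqlen" in kwargs else expensive()  # no print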
@@ -165,9 +135,6 @@ class NestedTensor(torch.Tensor):
         ctx = {
             "requires_grad": self.requires_grad,
             "ragged_size": self._size[self._ragged_idx],
-            "max_seqlen": self._max_seqlen,
-            "min_seqlen": self._min_seqlen,
-            "ragged_idx": self._ragged_idx,
         }
         inner_tensors = ["_values", "_offsets"]
         if self._lengths is not None:
@@ -220,9 +187,6 @@ class NestedTensor(torch.Tensor):
             offsets=offsets,
             lengths=lengths,
             requires_grad=meta["requires_grad"],
-            _max_seqlen=meta["max_seqlen"],
-            _min_seqlen=meta["min_seqlen"],
-            _ragged_idx=meta["ragged_idx"],
         )
 
     @classmethod
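These two hunks change in lockstep because they are the two halves of the subclass serialization contract: whatever metadata `__tensor_flatten__` records in `ctx` must be consumed when the tensor is rebuilt, so dropping `max_seqlen`/`min_seqlen`/`ragged_idx` from the flatten side also removes the matching constructor kwargs on the unflatten side. A toy sketch of that round trip (simplified names, not the real API surface):

    import torch

    def flatten(values, offsets, requires_grad):
        # the revert shrinks this metadata dict
        ctx = {"requires_grad": requires_grad}
        return {"_values": values, "_offsets": offsets}, ctx

    def unflatten(inner, ctx):
        # must consume exactly what flatten produced
        return inner["_values"], inner["_offsets"], ctx["requires_grad"]

    inner, ctx = flatten(torch.randn(9, 8), torch.tensor([0, 3, 5, 9]), False)
    values, offsets, rg = unflatten(inner, ctx)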
@@ -258,55 +222,35 @@ class ViewBufferFromNested(torch.autograd.Function):
     @staticmethod
     def forward(ctx, x: NestedTensor):  # type: ignore[override]
         ctx.save_for_backward(x.offsets())
-        ctx.max_seqlen = x._max_seqlen
-        ctx.min_seqlen = x._min_seqlen
-        ctx._ragged_idx = x._ragged_idx
         return x.values()
 
     @staticmethod
     def backward(ctx, gO: torch.Tensor):  # type: ignore[override]
         (offsets,) = ctx.saved_tensors
-        return NestedTensor(
-            gO,
-            offsets=offsets,
-            _max_seqlen=ctx.max_seqlen,
-            _min_seqlen=ctx.min_seqlen,
-            _ragged_idx=ctx._ragged_idx,
-        )
+        return NestedTensor(gO, offsets=offsets)
 
 
 # Not actually a view!
 class ViewNestedFromBuffer(torch.autograd.Function):
     @staticmethod
-    def forward(ctx, values: torch.Tensor, offsets: torch.Tensor, max_seqlen: int, min_seqlen: int):  # type: ignore[override]
-        return NestedTensor(
-            values.detach(),
-            offsets=offsets,
-            _max_seqlen=max_seqlen,
-            _min_seqlen=min_seqlen,
-        )
+    def forward(ctx, values: torch.Tensor, offsets: torch.Tensor):  # type: ignore[override]
+        return NestedTensor(values.detach(), offsets=offsets)
 
     @staticmethod
     def backward(ctx, gO: NestedTensor):  # type: ignore[override]
-        return gO.values(), None, None, None
+        return gO.values(), None, None
 
 
 # Not actually a view!
 # NOTE: @jbschlosser is working on making it a view
 class ViewNonContiguousNestedFromBuffer(torch.autograd.Function):
     @staticmethod
-    def forward(ctx, values: torch.Tensor, offsets: torch.Tensor, lengths: torch.Tensor, max_seqlen: int, min_seqlen: int):  # type: ignore[override]
-        return NestedTensor(
-            values.detach(),
-            offsets=offsets,
-            lengths=lengths,
-            _max_seqlen=max_seqlen,
-            _min_seqlen=min_seqlen,
-        )
+    def forward(ctx, values: torch.Tensor, offsets: torch.Tensor, lengths: torch.Tensor):  # type: ignore[override]
+        return NestedTensor(values.detach(), offsets=offsets, lengths=lengths)
 
     @staticmethod
     def backward(ctx, gO: NestedTensor):  # type: ignore[override]
-        return gO.values(), None, None, None, None
+        return gO.values(), None, None
 
 
 # Need to make it obvious that users should be passing in offsets
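The trailing-`None` changes in both `backward` methods follow from autograd's arity rule: `torch.autograd.Function.backward` returns one gradient per `forward` input, with `None` for non-differentiable arguments, so removing the two integer seqlen parameters from `forward` also removes their `None` slots from `backward`. A minimal self-contained example of the rule (hypothetical, not from the patch):

    import torch

    class Scale(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x: torch.Tensor, factor: int):
            ctx.factor = factor
            return x * factor

        @staticmethod
        def backward(ctx, gO: torch.Tensor):
            # one slot per forward input: grad for x, None for the int
            return gO * ctx.factor, None

    x = torch.ones(3, requires_grad=True)
    Scale.apply(x, 4).sum().backward()
    print(x.grad)  # tensor([4., 4., 4.])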
@@ -359,10 +303,7 @@ def jagged_from_list(
         ]
     )
 
-    max_seqlen = max([t.shape[0] for t in tensors])
-    min_seqlen = min([t.shape[0] for t in tensors])
-
-    return ViewNestedFromBuffer.apply(values, offsets, max_seqlen, min_seqlen), offsets  # type: ignore[call-overload]
+    return ViewNestedFromBuffer.apply(values, offsets), offsets  # type: ignore[call-overload]
 
 
 def jagged_from_tensor_and_lengths(
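`jagged_from_list` packs a list of variable-length tensors into a single values buffer plus an offsets tensor; after the revert it no longer precomputes the seqlen extremes before calling `ViewNestedFromBuffer.apply`. A hypothetical illustration of the inputs it hands to that function:

    import torch

    tensors = [torch.randn(3, 8), torch.randn(2, 8), torch.randn(4, 8)]
    values = torch.cat(tensors, dim=0)                     # (9, 8) packed buffer
    lengths = torch.tensor([t.shape[0] for t in tensors])  # tensor([3, 2, 4])
    offsets = torch.cat([torch.zeros(1, dtype=torch.int64), lengths.cumsum(0)])
    # offsets == tensor([0, 3, 5, 9])

    # the two removed lines computed these eagerly at construction time:
    max_seqlen = max(t.shape[0] for t in tensors)  # 4
    min_seqlen = min(t.shape[0] for t in tensors)  # 2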
@@ -413,28 +354,16 @@ def jagged_from_tensor_and_lengths(
     if offsets[0] + length_list[0] != orig_dim:
         is_contiguous = False
 
-    actual_max_seqlen = int(torch.max(lengths).item())
-    min_seqlen = int(torch.min(lengths).item())
-
     if is_contiguous:
         return (
             ViewNestedFromBuffer.apply(
-                values[offsets[0] : offsets[-1]],
-                offsets - offsets[0],
-                actual_max_seqlen,
-                min_seqlen,
+                values[offsets[0] : offsets[-1]], offsets - offsets[0]
             ),
             offsets,
             None,
         )
 
-    return (
-        ViewNonContiguousNestedFromBuffer.apply(
-            values, offsets, length_list, actual_max_seqlen, min_seqlen
-        ),
-        offsets,
-        length_list,
-    )  # type: ignore[call-overload]
+    return ViewNonContiguousNestedFromBuffer.apply(values, offsets, length_list), offsets, length_list  # type: ignore[call-overload]
 
 
 def buffer_from_jagged(jagged):
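`jagged_from_tensor_and_lengths` picks between the two view functions based on a contiguity check: the jagged view over the source buffer is contiguous only if each sample starts exactly where the previous one ends. A standalone sketch of that test (hypothetical helper, not the patch's code):

    import torch

    def gaps_free(offsets, lengths, total_rows):
        ends = offsets + lengths
        return bool((ends[:-1] == offsets[1:]).all()) and int(ends[-1]) == total_rows

    offsets = torch.tensor([0, 3, 5])
    print(gaps_free(offsets, torch.tensor([3, 2, 4]), 9))  # True  -> ViewNestedFromBuffer
    print(gaps_free(offsets, torch.tensor([2, 2, 4]), 9))  # False -> ViewNonContiguousNestedFromBuffer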