Mirror of https://github.com/pytorch/pytorch.git, synced 2025-11-04 16:04:58 +08:00
	Revert "Add an SDPA dispatcher for nested tensors with jagged layouts (#114164)"
This reverts commit aafa8233a4a1f336014cb122d16941e5b593706c.
Reverted https://github.com/pytorch/pytorch/pull/114164 on behalf of https://github.com/malfet due to Broke ROCM, see aafa8233a4 ([comment](https://github.com/pytorch/pytorch/pull/114164#issuecomment-1839798986))
@@ -43,9 +43,6 @@ class NestedTensor(torch.Tensor):
     _stride: Tuple[int, ...]
     # Indicates that the nth dimension is ragged
     _ragged_idx: int
-    # SDPA Metadata
-    _max_seqlen: int
-    _min_seqlen: int

     @staticmethod
     def __new__(
@@ -87,18 +84,12 @@ class NestedTensor(torch.Tensor):
         # (create a new one if needed).
         ragged_source = offsets if lengths is None else lengths
         ragged_size = get_tensor_symint(ragged_source, coeff=1)
-        self._ragged_idx = kwargs.get("_ragged_idx", 1)
         B = offsets.shape[0] - 1
-        Ds = values.shape[: self._ragged_idx - 1] + values.shape[self._ragged_idx :]
-
-        nested_size = [B]
-        nested_size.extend(Ds[: self._ragged_idx - 1])
-        nested_size.append(ragged_size)
-        nested_size.extend(Ds[self._ragged_idx - 1 :])
-        self._size = tuple(nested_size)
-
+        Ds = values.shape[1:]
+        self._size = (B, ragged_size, *Ds)
         stride = values.stride()
-        self._strides = (ragged_size * stride[self._ragged_idx - 1], *stride)
+        self._strides = (ragged_size * stride[0], *stride)
+        self._ragged_idx = 1

         if values.requires_grad:
             raise ValueError(
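The lines removed here had generalized the nested size so the ragged dimension can sit at any position after batch, while the restored lines hard-code position 1. A minimal standalone sketch of the removed construction (a hypothetical helper, using a plain string in place of the symbolic SymInt ragged size):

```python
from typing import Tuple

def build_nested_size(B, values_shape: Tuple[int, ...], ragged_size, ragged_idx: int = 1):
    # values has no batch dim, so its dim i corresponds to nested dim i + 1;
    # drop the dim that becomes ragged, then splice ragged_size back in.
    Ds = values_shape[: ragged_idx - 1] + values_shape[ragged_idx:]
    nested_size = [B]
    nested_size.extend(Ds[: ragged_idx - 1])
    nested_size.append(ragged_size)
    nested_size.extend(Ds[ragged_idx - 1 :])
    return tuple(nested_size)

# ragged_idx=1 reduces to the restored form (B, ragged_size, *values_shape[1:]):
assert build_nested_size(3, (10, 8), "j1", ragged_idx=1) == (3, "j1", 8)
# ragged_idx=2 moves the ragged dim one slot later:
assert build_nested_size(3, (10, 8), "j1", ragged_idx=2) == (3, 10, "j1")
```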
@@ -109,27 +100,6 @@ class NestedTensor(torch.Tensor):
         self._offsets = offsets
         self._lengths = lengths

-        # SDPA metadata
-        def get_sdpa_extreme_seqlen(func, tensor):
-            return int(func(tensor).item())
-
-        # Note: Not using kwargs.get to avoid execution of get_sdpa_extreme_seqlen
-        # unless it is really needed
-        self._max_seqlen = (
-            kwargs["_max_seqlen"]
-            if "_max_seqlen" in kwargs
-            else get_sdpa_extreme_seqlen(
-                torch.max, offsets.diff() if lengths is None else lengths
-            )
-        )
-        self._min_seqlen = (
-            kwargs["_min_seqlen"]
-            if "_min_seqlen" in kwargs
-            else get_sdpa_extreme_seqlen(
-                torch.min, offsets.diff() if lengths is None else lengths
-            )
-        )
-
     def values(self):
         return self._values

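For a contiguous jagged layout, the removed metadata derives per-sequence lengths as `offsets.diff()`; the explicit `if "_max_seqlen" in kwargs` check (rather than `kwargs.get`) avoids computing that reduction when a cached value was passed in, as the deleted comment notes. A small illustration, assuming only `torch`:

```python
import torch

# Sequence i occupies values[offsets[i] : offsets[i + 1]] in the packed buffer,
# so consecutive differences of offsets are exactly the sequence lengths.
offsets = torch.tensor([0, 3, 7, 8])
seqlens = offsets.diff()                      # tensor([3, 4, 1])

max_seqlen = int(torch.max(seqlens).item())   # 4
min_seqlen = int(torch.min(seqlens).item())   # 1
```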
@@ -165,9 +135,6 @@ class NestedTensor(torch.Tensor):
         ctx = {
             "requires_grad": self.requires_grad,
             "ragged_size": self._size[self._ragged_idx],
-            "max_seqlen": self._max_seqlen,
-            "min_seqlen": self._min_seqlen,
-            "ragged_idx": self._ragged_idx,
         }
         inner_tensors = ["_values", "_offsets"]
         if self._lengths is not None:
@@ -220,9 +187,6 @@ class NestedTensor(torch.Tensor):
             offsets=offsets,
             lengths=lengths,
             requires_grad=meta["requires_grad"],
-            _max_seqlen=meta["max_seqlen"],
-            _min_seqlen=meta["min_seqlen"],
-            _ragged_idx=meta["ragged_idx"],
         )

     @classmethod
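These two hunks are the two halves of the tensor-subclass (un)flattening protocol, so the revert deletes the same three metadata keys from both the `ctx` dict packed on the way down and the kwargs consumed on the way up. A schematic sketch of why the two sides must stay in sync (hypothetical plain functions standing in for `__tensor_flatten__`/`__tensor_unflatten__`, not the file's API):

```python
import torch

# Every key that flatten packs into ctx must be consumed by unflatten,
# so an edit to one side (as in the revert above) mirrors on the other.
def flatten(values, offsets, requires_grad):
    ctx = {"requires_grad": requires_grad, "ragged_size": values.shape[0]}
    return {"_values": values, "_offsets": offsets}, ctx

def unflatten(inner_tensors, meta):
    return inner_tensors["_values"], inner_tensors["_offsets"], meta["requires_grad"]

inner, meta = flatten(torch.randn(8, 4), torch.tensor([0, 3, 8]), False)
assert unflatten(inner, meta)[2] is False
```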
@@ -258,55 +222,35 @@ class ViewBufferFromNested(torch.autograd.Function):
     @staticmethod
     def forward(ctx, x: NestedTensor):  # type: ignore[override]
         ctx.save_for_backward(x.offsets())
-        ctx.max_seqlen = x._max_seqlen
-        ctx.min_seqlen = x._min_seqlen
-        ctx._ragged_idx = x._ragged_idx
         return x.values()

     @staticmethod
     def backward(ctx, gO: torch.Tensor):  # type: ignore[override]
         (offsets,) = ctx.saved_tensors
-        return NestedTensor(
-            gO,
-            offsets=offsets,
-            _max_seqlen=ctx.max_seqlen,
-            _min_seqlen=ctx.min_seqlen,
-            _ragged_idx=ctx._ragged_idx,
-        )
+        return NestedTensor(gO, offsets=offsets)


 # Not actually a view!
 class ViewNestedFromBuffer(torch.autograd.Function):
     @staticmethod
-    def forward(ctx, values: torch.Tensor, offsets: torch.Tensor, max_seqlen: int, min_seqlen: int):  # type: ignore[override]
-        return NestedTensor(
-            values.detach(),
-            offsets=offsets,
-            _max_seqlen=max_seqlen,
-            _min_seqlen=min_seqlen,
-        )
+    def forward(ctx, values: torch.Tensor, offsets: torch.Tensor):  # type: ignore[override]
+        return NestedTensor(values.detach(), offsets=offsets)

     @staticmethod
     def backward(ctx, gO: NestedTensor):  # type: ignore[override]
-        return gO.values(), None, None, None
+        return gO.values(), None, None


 # Not actually a view!
 # NOTE: @jbschlosser is working on making it a view
 class ViewNonContiguousNestedFromBuffer(torch.autograd.Function):
     @staticmethod
-    def forward(ctx, values: torch.Tensor, offsets: torch.Tensor, lengths: torch.Tensor, max_seqlen: int, min_seqlen: int):  # type: ignore[override]
-        return NestedTensor(
-            values.detach(),
-            offsets=offsets,
-            lengths=lengths,
-            _max_seqlen=max_seqlen,
-            _min_seqlen=min_seqlen,
-        )
+    def forward(ctx, values: torch.Tensor, offsets: torch.Tensor, lengths: torch.Tensor):  # type: ignore[override]
+        return NestedTensor(values.detach(), offsets=offsets, lengths=lengths)

     @staticmethod
     def backward(ctx, gO: NestedTensor):  # type: ignore[override]
-        return gO.values(), None, None, None, None
+        return gO.values(), None, None


 # Need to make it obvious that users should be passing in offsets
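The changing number of trailing `None`s follows the `torch.autograd.Function` contract: `backward` returns one value per `forward` input, with `None` for non-differentiable arguments, so dropping the two seqlen parameters drops their placeholders. A minimal unrelated example of the rule:

```python
import torch

class ScaleBy(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, factor: float):
        ctx.factor = factor
        return x * factor

    @staticmethod
    def backward(ctx, gO):
        # One return per forward() input: a gradient for x, None for `factor`.
        return gO * ctx.factor, None

x = torch.ones(3, requires_grad=True)
ScaleBy.apply(x, 2.0).sum().backward()
assert torch.equal(x.grad, torch.full((3,), 2.0))
```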
@@ -359,10 +303,7 @@ def jagged_from_list(
             ]
         )

-    max_seqlen = max([t.shape[0] for t in tensors])
-    min_seqlen = min([t.shape[0] for t in tensors])
-
-    return ViewNestedFromBuffer.apply(values, offsets, max_seqlen, min_seqlen), offsets  # type: ignore[call-overload]
+    return ViewNestedFromBuffer.apply(values, offsets), offsets  # type: ignore[call-overload]


 def jagged_from_tensor_and_lengths(
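For reference, `jagged_from_list` packs variable-length tensors into one values buffer with boundary offsets, and the removed lines reduced over exactly those per-tensor lengths. A hedged sketch of the packing (standalone, not the function's actual body):

```python
import torch

tensors = [torch.randn(3, 8), torch.randn(5, 8), torch.randn(2, 8)]

values = torch.cat(tensors, dim=0)                     # packed buffer, shape (10, 8)
lengths = torch.tensor([t.shape[0] for t in tensors])  # tensor([3, 5, 2])
offsets = torch.cat([torch.zeros(1, dtype=torch.int64), lengths.cumsum(0)])

assert offsets.tolist() == [0, 3, 8, 10]
# The deleted lines computed max/min over these lengths before the revert:
assert (max(t.shape[0] for t in tensors), min(t.shape[0] for t in tensors)) == (5, 2)
```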
@@ -413,28 +354,16 @@ def jagged_from_tensor_and_lengths(
     if offsets[0] + length_list[0] != orig_dim:
         is_contiguous = False

-    actual_max_seqlen = int(torch.max(lengths).item())
-    min_seqlen = int(torch.min(lengths).item())
-
     if is_contiguous:
         return (
             ViewNestedFromBuffer.apply(
-                values[offsets[0] : offsets[-1]],
-                offsets - offsets[0],
-                actual_max_seqlen,
-                min_seqlen,
+                values[offsets[0] : offsets[-1]], offsets - offsets[0]
             ),
             offsets,
             None,
         )

-    return (
-        ViewNonContiguousNestedFromBuffer.apply(
-            values, offsets, length_list, actual_max_seqlen, min_seqlen
-        ),
-        offsets,
-        length_list,
-    )  # type: ignore[call-overload]
+    return ViewNonContiguousNestedFromBuffer.apply(values, offsets, length_list), offsets, length_list  # type: ignore[call-overload]


 def buffer_from_jagged(jagged):
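The two return paths mirror the two layouts `jagged_from_tensor_and_lengths` can produce: a contiguous jagged tensor described by offsets alone (lengths recoverable as `offsets.diff()`), and a non-contiguous one whose sequences start at dense row boundaries but end early, so lengths must travel alongside offsets. A small sketch of the distinction (names hypothetical):

```python
import torch

# Contiguous: each sequence ends where the next begins.
offsets = torch.tensor([0, 3, 7, 8])
assert torch.equal(offsets.diff(), torch.tensor([3, 4, 1]))

# Non-contiguous: carved out of a dense (B=3, T=5, D) buffer with early
# endings, leaving gaps; offsets.diff() would overcount, so lengths are kept.
T = 5
offsets_nc = torch.arange(4) * T          # tensor([ 0,  5, 10, 15])
lengths_nc = torch.tensor([3, 5, 2])
assert not torch.equal(offsets_nc.diff(), lengths_nc)
```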