Continuation of #112185, following the design in this [doc](https://docs.google.com/document/d/1ipSxcTzEMMOAPvxP-YJlD5JBZZmIGgh8Q34ixtOUCRo).
Summary:
* Introduce `SubclassSymbolicPolicy` containing separate dynamic dim / constraint policies for the outer and inner tensors
* Expand the automatic dynamic algorithm to recurse into inner tensors and produce a `SubclassSymbolicPolicy` for each subclass instance
* Maintain legacy behavior for subclasses by recursively calling `mark_dynamic()` on inner tensors *of the same dim as outer* when `mark_dynamic(outer, ...)` is called
* Addresses this: 6a86cf00ad/torch/_dynamo/variables/builder.py (L1750)
* Add `outer_size` and `outer_stride` arguments to `__tensor_unflatten__()` so that you can find out which symbols were allocated for the outer size / stride (you are expected to return a tensor that compares equal to these outer symbols); a minimal example of adopting the new signature is sketched after this list
* Signatures are now:
```python
# attrs is a list of inner tensor attributes on x; inner_tensor = getattr(x, attr)
# ctx is anything useful for rebuilding the class we want to guard on
attrs, ctx = x.__tensor_flatten__()
...
# inner_tensors is a dict of {attr -> tensor}
# ctx is taken unmodified from flattening and (eventually) guarded on
# outer_size is the expected size of the output; possibly symbolic
# outer_stride is the expected strides of the output; possibly symbolic
y = MySubclass.__tensor_unflatten__(inner_tensors, ctx, outer_size, outer_stride)
# at the __tensor_unflatten__() call-site in PT2, we assert y.shape == outer_size and y.stride() == outer_stride
# the assert simplifies symbols when there are relationships between outer and inner symbols
```
* Size info is needed for `NestedTensor` at least; stride info is needed for `DTensor` at least
* Punting on `outer_storage_offset` because storage_offset handling is horribly broken in PT2 right now
* ~~Add new `__tensor_mark_dynamic__()` to allow overriding the behavior of mark_dynamic on a per-subclass basis~~ (booted to future work)
* ~~Add guards for tensor subclasses by calling `__tensor_flatten__()` in the guard to test equality on `ctx`~~
* Now handled in #114469
* Next PR: add TENSOR_MATCH guards on inner tensors
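
As a concrete illustration of the new `__tensor_unflatten__()` signature, here is a minimal sketch of a hypothetical wrapper subclass (the class and attribute names below are illustrative and not part of this PR; `__torch_dispatch__` is omitted for brevity):

```python
import torch

class WrapperTensor(torch.Tensor):
    """Toy subclass wrapping a single inner tensor, for illustration only."""

    @staticmethod
    def __new__(cls, inner, requires_grad=False):
        return torch.Tensor._make_wrapper_subclass(  # type: ignore[attr-defined]
            cls,
            inner.shape,
            strides=inner.stride(),
            dtype=inner.dtype,
            device=inner.device,
            requires_grad=requires_grad,
        )

    def __init__(self, inner, requires_grad=False):
        self._inner = inner

    def __tensor_flatten__(self):
        # attr names of inner tensors on self, plus ctx to (eventually) guard on
        return ["_inner"], {"requires_grad": self.requires_grad}

    @staticmethod
    def __tensor_unflatten__(inner_tensors, ctx, outer_size, outer_stride):
        # outer_size / outer_stride may contain symbols allocated for the outer
        # tensor; the returned tensor must compare equal to them. Here the outer
        # metadata simply mirrors the inner tensor's, so nothing extra is needed.
        inner = inner_tensors["_inner"]
        assert len(outer_size) == inner.dim()
        return WrapperTensor(inner, requires_grad=ctx["requires_grad"])
```

For `NestedTensor` below, `outer_size[ragged_idx]` is additionally used to associate the ragged singleton symint with the (possibly fake) offsets / lengths tensor.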
Pull Request resolved: https://github.com/pytorch/pytorch/pull/114311
Approved by: https://github.com/ezyang, https://github.com/drisspg, https://github.com/voznesenskym, https://github.com/bdhirsh
422 lines · 15 KiB · Python
from typing import Tuple

import torch
from torch._C import DispatchKey, DispatchKeySet
from torch._prims_common import is_expandable_to
from torch.fx.experimental.symbolic_shapes import has_free_symbols
from torch.utils.weak import WeakTensorKeyDictionary

from typing import *  # noqa: F403

_tensor_id_counter = 0
_tensor_symint_registry = WeakTensorKeyDictionary()


def get_tensor_symint(tensor, *, coeff=1):
    global _tensor_id_counter
    if tensor not in _tensor_symint_registry:
        _tensor_symint_registry[tensor] = torch._C._get_singleton_int(
            _tensor_id_counter, coeff
        )
        _tensor_id_counter += 1
    return _tensor_symint_registry[tensor]
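

# Illustrative usage of the registry above (a sketch; `offsets` is a
# hypothetical int64 tensor of cumulative sequence lengths, not defined here):
#
#   offsets = torch.tensor([0, 2, 5, 7], dtype=torch.int64)
#   s1 = get_tensor_symint(offsets)  # allocates a fresh singleton SymInt
#   s2 = get_tensor_symint(offsets)  # cache hit: the same SymInt is returned
#   assert s1 is s2
#
# Because the registry is a WeakTensorKeyDictionary, entries live only as long
# as the keying offsets / lengths tensor does.

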
class NestedTensor(torch.Tensor):
    _values: torch.Tensor  # type: ignore[assignment]
    _offsets: torch.Tensor
    _lengths: Optional[torch.Tensor]
    # NOTE [ Singleton ints for ragged sizes and strides ]
    #
    # Jagged layout tensors are tensors that represent an n-dim tensor with a
    # ragged dimension, but are backed by an (n-1)-dim tensor underneath, e.g.,
    # a jagged tensor with outer shape [B, x, D] is represented internally by a
    # tensor with shape [sum(x), D], where we introduce what we call a singleton
    # (or skolem) symbol, denoted "x" here (sometimes denoted "*"), to represent
    # the ragged dimension; sum(x) is the size of the inner tensor's first dim,
    # or equivalently the sum of the constituent tensors' varying lengths.
    #
    # We also use singleton ints to represent the strides of this tensor.
    # For example, a jagged tensor with shape [B, x, D] can be strided in two
    # ways: [xD, D, 1] and [x, 1, sum(x)], where xD represents x multiplied by D
    _size: Tuple[int, ...]
    _strides: Tuple[int, ...]
    # Indicates that the nth dimension is ragged
    _ragged_idx: int
    # SDPA Metadata
    _max_seqlen: int
    _min_seqlen: int
    @staticmethod
    def __new__(
        cls,
        values,
        offsets,
        *,
        lengths=None,
        **kwargs,
    ):
        ks = DispatchKeySet(DispatchKey.NestedTensor)
        ks = ks.add(DispatchKey.AutogradNestedTensor)
        r = torch.Tensor._make_wrapper_subclass(  # type: ignore[attr-defined]
            cls,
            (0,),
            (0,),
            0,
            torch.contiguous_format,
            values.dtype,
            torch.jagged,
            values.device,
            False,
            kwargs.get("requires_grad", False),
            "sizes",
            False,
            True,  # dispatch_layout
            ks,
        )
        return r

    def __init__(self, values, offsets, *, lengths=None, **kwargs):
        super().__init__()
        # Only support jagged for now.
        assert offsets is not None
        assert offsets.ndim == 1
        assert not isinstance(values, NestedTensor)

        # Query cache for the symint associated with offsets or lengths
        # (create a new one if needed).
        ragged_source = offsets if lengths is None else lengths
        ragged_size = get_tensor_symint(ragged_source, coeff=1)
        self._ragged_idx = kwargs.get("_ragged_idx", 1)
        B = offsets.shape[0] - 1
        Ds = values.shape[: self._ragged_idx - 1] + values.shape[self._ragged_idx :]

        nested_size = [B]
        nested_size.extend(Ds[: self._ragged_idx - 1])
        nested_size.append(ragged_size)
        nested_size.extend(Ds[self._ragged_idx - 1 :])
        self._size = tuple(nested_size)

        stride = values.stride()
        self._strides = (ragged_size * stride[self._ragged_idx - 1], *stride)

        if values.requires_grad:
            raise ValueError(
                "NestedTensor values cannot require grad, please "
                "detach before passing to NestedTensor constructor"
            )
        self._values = values
        self._offsets = offsets
        self._lengths = lengths

        # SDPA metadata
        def get_sdpa_extreme_seqlen(func, tensor):
            return int(func(tensor).item())

        # Note: Not using kwargs.get to avoid execution of get_sdpa_extreme_seqlen
        # unless it is really needed
        self._max_seqlen = (
            kwargs["_max_seqlen"]
            if "_max_seqlen" in kwargs
            else get_sdpa_extreme_seqlen(
                torch.max, offsets.diff() if lengths is None else lengths
            )
        )
        self._min_seqlen = (
            kwargs["_min_seqlen"]
            if "_min_seqlen" in kwargs
            else get_sdpa_extreme_seqlen(
                torch.min, offsets.diff() if lengths is None else lengths
            )
        )

        # collapsed ragged dim must always be dynamic
        torch._dynamo.mark_dynamic(self, self._ragged_idx)
        torch._dynamo.mark_dynamic(self._values, self._ragged_idx - 1)
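
    # Worked example of the size / stride computation above (a sketch, not
    # executed here): for values of shape [7, 5], offsets = [0, 2, 5, 7], and
    # the default _ragged_idx == 1:
    #   B  = 3 (len(offsets) - 1)
    #   Ds = (5,)
    #   _size    = (3, j0, 5)      where j0 denotes the singleton symint for offsets
    #   _strides = (5 * j0, 5, 1)  i.e. values.stride() prefixed with j0 times the
    #                              stride of the ragged dim of values
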
    def values(self):
        return self._values

    def offsets(self):
        return self._offsets

    def lengths(self):
        return self._lengths

    def __repr__(self):
        # We should implement this in torch/_tensor_str.py instead
        grad_fn_str = (
            f", requires_grad={self.requires_grad}" if self.requires_grad else ""
        )
        if self.grad_fn:
            grad_fn_str = f", grad_fn={self.grad_fn}"
        return f"NestedTensor(size={self._size}, offsets={self._offsets}{grad_fn_str}, contiguous={self._lengths is None})"

    def __reduce_ex__(self, proto):
        state = torch._utils._get_obj_state(self)

        # SymNodes are not serializable
        assert "_size" in state and "_strides" in state
        state = dict(state)
        del state["_size"]
        del state["_strides"]

        func = NestedTensor
        args = (self._values, self._offsets)
        return (torch._tensor._rebuild_from_type_v2, (func, type(self), args, state))

    def __tensor_flatten__(self):
        ctx = {
            "requires_grad": self.requires_grad,
            "max_seqlen": self._max_seqlen,
            "min_seqlen": self._min_seqlen,
            "ragged_idx": self._ragged_idx,
        }
        inner_tensors = ["_values", "_offsets"]
        if self._lengths is not None:
            inner_tensors.append("_lengths")
        return inner_tensors, ctx

    @staticmethod
    def __tensor_unflatten__(inner_tensors: Dict, meta, outer_size, outer_stride):
        assert len(inner_tensors) >= 2 and len(inner_tensors) <= 3
        values = inner_tensors["_values"]
        offsets = inner_tensors["_offsets"]
        lengths = inner_tensors.get("_lengths", None)
        ragged_idx = meta["ragged_idx"]

        # Note that we cannot simply check if is_fake(values) because
        # during aot autograd, FunctionalTensors are not fake but hold
        # symbolic sizes.
        ragged_source = offsets if lengths is None else lengths
        if has_free_symbols(ragged_source) or has_free_symbols(values):
            # Associate offsets or lengths (possibly fake, possibly functionalized)
            # with the ragged_size.
            ragged_size = outer_size[ragged_idx]
            _tensor_symint_registry[ragged_source] = ragged_size

        return NestedTensor(
            values,
            offsets=offsets,
            lengths=lengths,
            requires_grad=meta["requires_grad"],
            _max_seqlen=meta["max_seqlen"],
            _min_seqlen=meta["min_seqlen"],
            _ragged_idx=ragged_idx,
        )
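
    # Sketch of the flatten / unflatten round trip PT2 performs (illustrative;
    # `nt` is a hypothetical NestedTensor, e.g. built via jagged_from_list below):
    #
    #   attrs, ctx = nt.__tensor_flatten__()   # ["_values", "_offsets"], metadata dict
    #   inner = {a: getattr(nt, a) for a in attrs}
    #   nt2 = NestedTensor.__tensor_unflatten__(inner, ctx, nt.size(), nt.stride())
    #   # PT2 asserts nt2.shape == outer_size and nt2.stride() == outer_stride
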
    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        kwargs = {} if kwargs is None else kwargs

        # Lazy import to avoid circular dependency
        from .ops import lookup_jagged

        fn = lookup_jagged(func, *args, **kwargs)
        if fn is not None:
            return fn(*args, **kwargs)

        raise NotImplementedError(func)

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}

        from .ops import jagged_torch_function

        try:
            return jagged_torch_function(func, *args, **kwargs)
        except NotImplementedError:
            pass
        with torch._C.DisableTorchFunctionSubclass():
            return func(*args, **kwargs)


# Not actually a view!
class ViewBufferFromNested(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x: NestedTensor):  # type: ignore[override]
        ctx.save_for_backward(x.offsets())
        ctx.max_seqlen = x._max_seqlen
        ctx.min_seqlen = x._min_seqlen
        ctx._ragged_idx = x._ragged_idx
        return x.values()

    @staticmethod
    def backward(ctx, gO: torch.Tensor):  # type: ignore[override]
        (offsets,) = ctx.saved_tensors
        return NestedTensor(
            gO,
            offsets=offsets,
            _max_seqlen=ctx.max_seqlen,
            _min_seqlen=ctx.min_seqlen,
            _ragged_idx=ctx._ragged_idx,
        )


# Not actually a view!
class ViewNestedFromBuffer(torch.autograd.Function):
    @staticmethod
    def forward(ctx, values: torch.Tensor, offsets: torch.Tensor, max_seqlen: int, min_seqlen: int):  # type: ignore[override]
        return NestedTensor(
            values.detach(),
            offsets=offsets,
            _max_seqlen=max_seqlen,
            _min_seqlen=min_seqlen,
        )

    @staticmethod
    def backward(ctx, gO: NestedTensor):  # type: ignore[override]
        return gO.values(), None, None, None


# Not actually a view!
# NOTE: @jbschlosser is working on making it a view
class ViewNonContiguousNestedFromBuffer(torch.autograd.Function):
    @staticmethod
    def forward(ctx, values: torch.Tensor, offsets: torch.Tensor, lengths: torch.Tensor, max_seqlen: int, min_seqlen: int):  # type: ignore[override]
        return NestedTensor(
            values.detach(),
            offsets=offsets,
            lengths=lengths,
            _max_seqlen=max_seqlen,
            _min_seqlen=min_seqlen,
        )

    @staticmethod
    def backward(ctx, gO: NestedTensor):  # type: ignore[override]
        return gO.values(), None, None, None, None


# Need to make it obvious that users should be passing in offsets
def jagged_from_list(
    tensors: List[torch.Tensor],
    offsets: Optional[torch.Tensor],
    dtype=None,
    device=None,
) -> Tuple[NestedTensor, torch.Tensor]:
    """Constructs a NestedTensor backed by jagged layout from a list of tensors"""

    if not len(set(t.dtype for t in tensors)) == 1:  # noqa: C401
        raise RuntimeError(
            "When constructing a nested tensor, all tensors in list must have the same dtype"
        )
    if not len(set(t.device for t in tensors)) == 1:  # noqa: C401
        raise RuntimeError(
            "When constructing a nested tensor, all tensors in list must be on the same device"
        )

    # Check that the NT is representable by the jagged layout.
    # Jagged layout represents (B, *, D_0, D_1, ..., D_N), where the only
    # raggedness allowed is for the single dim immediately adjacent to the batch dim.
    sizes = [t.shape for t in tensors]
    non_first_sizes = [s[1:] for s in sizes]
    at_most_first_ragged = all(s == non_first_sizes[0] for s in non_first_sizes)
    if not at_most_first_ragged:
        raise RuntimeError(
            "Cannot represent given tensor list as a nested tensor with the jagged layout. "
            "Note that the jagged layout only represents shapes of the form "
            "(B, *, D_0, D_1, ..., D_N), with only * allowed to be ragged."
        )

    # Set properties appropriately.
    values = torch.cat(tensors, dim=0)
    to_kwargs = {}
    if device is not None:
        to_kwargs["device"] = device
    if dtype is not None:
        to_kwargs["dtype"] = dtype
    values = values.to(**to_kwargs)

    # Calculate jagged offsets if not provided.
    if offsets is None:
        # Jagged layout specifies that offsets are stored as int64 on the same device as values.
        offsets = torch.cat(
            [
                torch.zeros(1, dtype=torch.int64, device=values.device),
                torch.tensor([s[0] for s in sizes], device=values.device).cumsum(dim=0),
            ]
        )

    max_seqlen = max([t.shape[0] for t in tensors])
    min_seqlen = min([t.shape[0] for t in tensors])

    return ViewNestedFromBuffer.apply(values, offsets, max_seqlen, min_seqlen), offsets  # type: ignore[call-overload]
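

# Example usage (a sketch, not executed at import time):
#
#   ts = [torch.randn(2, 5), torch.randn(3, 5), torch.randn(4, 5)]
#   nt, offsets = jagged_from_list(ts, offsets=None)
#   nt.size()   -> (3, j0, 5), where j0 denotes the ragged singleton symint
#   offsets     -> tensor([0, 2, 5, 9])

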
def jagged_from_tensor_and_lengths(
    tensor: torch.Tensor, starts: torch.Tensor, lengths: torch.Tensor
) -> Tuple[NestedTensor, torch.Tensor, Optional[torch.Tensor]]:
    """Constructs a NestedTensor backed by jagged layout from a tensor, starts of sequences, and sequence lengths"""
    batch_size = tensor.shape[0]
    if is_expandable_to(starts.shape, (batch_size,)) and is_expandable_to(
        lengths.shape, (batch_size,)
    ):
        start_list = starts.expand(batch_size)
        length_list = lengths.expand(batch_size)
    else:
        raise RuntimeError(
            "When constructing a jagged nested tensor using narrow(), "
            "your start and length must be Tensors that broadcast to input.shape[0]"
        )

    # Calculate jagged offsets
    assert (
        len(tensor.shape) >= 2
    ), "tensor must at least be 2D for the nested narrow op to work"
    max_seq_len = tensor.shape[1]
    offset_lengths = max_seq_len * torch.arange(
        0, batch_size, dtype=torch.int64, device=tensor.device
    )
    # Jagged layout specifies that offsets are stored as int64 on the same device as values.
    offsets = torch.cat(
        [
            start_list + offset_lengths,
            (start_list[-1] + offset_lengths[-1] + length_list[-1]).unsqueeze(0),
        ]
    )

    # Reshape buffer to flatten the 1st and 2nd dimension (view used to enforce non-copy)
    if len(tensor.shape) > 2:
        values = tensor.view(-1, *tensor.shape[2:])
    else:
        values = tensor.view(-1)

    # Check if offsets and lengths make it possibly contiguous and return a regular NT
    is_contiguous = True
    orig_dim = tensor.shape[1]
    if torch.any(length_list[1:-1].ne(orig_dim)):
        is_contiguous = False
    if torch.any(offsets[1:-2].diff().ne(orig_dim)):
        is_contiguous = False
    if offsets[0] + length_list[0] != orig_dim:
        is_contiguous = False

    actual_max_seqlen = int(torch.max(lengths).item())
    min_seqlen = int(torch.min(lengths).item())

    if is_contiguous:
        return (
            ViewNestedFromBuffer.apply(
                values[offsets[0] : offsets[-1]],
                offsets - offsets[0],
                actual_max_seqlen,
                min_seqlen,
            ),
            offsets,
            None,
        )

    return (
        ViewNonContiguousNestedFromBuffer.apply(
            values, offsets, length_list, actual_max_seqlen, min_seqlen
        ),
        offsets,
        length_list,
    )  # type: ignore[call-overload]
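

# Example usage (a sketch; mirrors how a narrow()-style call on a dense batch
# would supply per-batch starts and lengths):
#
#   dense = torch.randn(3, 10, 5)
#   starts = torch.tensor([0, 1, 2])
#   lengths = torch.tensor([4, 6, 5])
#   nt, offsets, out_lengths = jagged_from_tensor_and_lengths(dense, starts, lengths)
#   # nt is non-contiguous here, so out_lengths is returned alongside offsets

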
def buffer_from_jagged(jagged):
    return ViewBufferFromNested.apply(jagged)
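

# Round-trip sketch tying the helpers together (illustrative only):
#
#   nt, offsets = jagged_from_list([torch.randn(2, 5), torch.randn(3, 5)], None)
#   values = buffer_from_jagged(nt)  # the packed (sum(x), 5) buffer
#   assert values.shape == (5, 5)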