Files
pytorch/torch/nested/_internal/nested_tensor.py
Joel Schlosser 22704426c3 Expand dynamic dims support for traceable subclasses (#114311)
Continuation of #112185, following the design in this [doc](https://docs.google.com/document/d/1ipSxcTzEMMOAPvxP-YJlD5JBZZmIGgh8Q34ixtOUCRo).

Summary:
* Introduce `SubclassSymbolicPolicy` containing separate dynamic dim / constraint policies for the outer and inner tensors
    * Expand the automatic dynamic algorithm to recurse into inner tensors and produce one of these for a subclass instance
    * Maintain legacy behavior for subclasses by recursively calling `mark_dynamic()` on inner tensors *of the same dim as outer* when `mark_dynamic(outer, ...)` is called
    * Addresses this: `torch/_dynamo/variables/builder.py`, line 1750 (at commit 6a86cf00ad)
* Add `outer_size` and `outer_stride` arguments to `__tensor_unflatten__()` so that you can find out what symbols were allocated for the outer size / stride (you are expected to return a tensor whose size and stride compare equal to these outer symbols)
    * Signatures now:
    ```python
    # attrs is a list of inner tensor attributes on x; inner_tensor = getattr(x, attr)
    # ctx is anything useful for rebuilding the class we want to guard on
    attrs, ctx = x.__tensor_flatten__()
    ...
    # inner_tensors is a dict of {attr -> tensor}
    # ctx is taken unmodified from flattening and (eventually) guarded on
    # outer_size is the expected size of the output; possibly symbolic
    # outer_stride is the expected strides of the output; possibly symbolic
    y = MySubclass.__tensor_unflatten__(inner_tensors, ctx, outer_size, outer_stride)

    # at the __tensor_unflatten__() call-site in PT2, we assert y.shape == outer_size and y.stride() == outer_stride
    # the assert simplifies symbols when there are relationships between outer and inner symbols
    ```
    * Size info needed for `NestedTensor` at least, stride info needed for `DTensor` at least
    * Punting on `outer_storage_offset` because storage_offset handling is horribly broken in PT2 right now
* ~~Add new `__tensor_mark_dynamic__()` to allow overriding the behavior of mark_dynamic on a per-subclass basis~~ (booted to future work)
* ~~Add guards for tensor subclasses by calling `__tensor_flatten__()` in the guard to test equality on `ctx`~~
    * Now handled in #114469
* Next PR: add TENSOR_MATCH guards on inner tensors

Pull Request resolved: https://github.com/pytorch/pytorch/pull/114311
Approved by: https://github.com/ezyang, https://github.com/drisspg, https://github.com/voznesenskym, https://github.com/bdhirsh
2023-12-05 21:09:25 +00:00

422 lines
15 KiB
Python

from typing import Tuple
import torch
from torch._C import DispatchKey, DispatchKeySet
from torch._prims_common import is_expandable_to
from torch.fx.experimental.symbolic_shapes import has_free_symbols
from torch.utils.weak import WeakTensorKeyDictionary
from typing import * # noqa: F403
# Next id passed to torch._C._get_singleton_int when a tensor is first
# registered; incremented by get_tensor_symint() after each new allocation.
_tensor_id_counter = 0
# Cache mapping a tensor (offsets or lengths) to the singleton int that stands
# for its ragged size. Keyed by tensor; judging by the container's name the
# keys are held weakly -- confirm in torch.utils.weak.
_tensor_symint_registry = WeakTensorKeyDictionary()
def get_tensor_symint(tensor, *, coeff=1):
    """Return the singleton int associated with ``tensor``, creating it if needed.

    The first time a tensor is seen, a new singleton int is requested from
    ``torch._C._get_singleton_int`` using the current module-level id counter
    and ``coeff``, stored in the registry, and the counter is advanced.
    Later calls with the same tensor return the cached value (``coeff`` is
    ignored in that case).
    """
    global _tensor_id_counter

    # Fast path: this tensor already has a symint registered.
    if tensor in _tensor_symint_registry:
        return _tensor_symint_registry[tensor]

    fresh_symint = torch._C._get_singleton_int(_tensor_id_counter, coeff)
    _tensor_symint_registry[tensor] = fresh_symint
    _tensor_id_counter += 1
    return fresh_symint
class NestedTensor(torch.Tensor):
_values: torch.Tensor # type: ignore[assignment]
_offsets: torch.Tensor
_lengths: Optional[torch.Tensor]
# NOTE [ Singleton ints for ragged sizes and strides ]
#
# Jagged layout tensors are tensors that represent a n-dim tensor with a
# ragged dimension, but are backed by an (n-1)-dim tensor underneath, e.g.,
# a jagged tensor with outer shape [B, x, D] is represented internally by a
# tensor with shape [sum(x), D] where we introduce what we call a singleton
# (or skolem) denoted as "x" here (but sometimes denoted with "*" to
# represent the ragged dimension, and sum(x) represents the dim of the inner
# tensor or equivalently the sum of all the sizes of the constituent
# tensors' varying lengths.
#
# We also use singleton ints to represent the strides of this tensor.
# For example, a jagged tensor with shape [B, x, D] can be strided in two
# ways: [xD, D, 1] and [x, 1, sum(x)], where xD represents x multiplied by D
_size: Tuple[int, ...]
_stride: Tuple[int, ...]
# Indicates that the nth dimension is ragged
_ragged_idx: int
# SDPA Metadata
_max_seqlen: int
_min_seqlen: int
@staticmethod
def __new__(
cls,
values,
offsets,
*,
lengths=None,
**kwargs,
):
ks = DispatchKeySet(DispatchKey.NestedTensor)
ks = ks.add(DispatchKey.AutogradNestedTensor)
r = torch.Tensor._make_wrapper_subclass( # type: ignore[attr-defined]
cls,
(0,),
(0,),
0,
torch.contiguous_format,
values.dtype,
torch.jagged,
values.device,
False,
kwargs.get("requires_grad", False),
"sizes",
False,
True, # dispatch_layout
ks,
)
return r
def __init__(self, values, offsets, *, lengths=None, **kwargs):
super().__init__()
# Only support jagged for now.
assert offsets is not None
assert offsets.ndim == 1
assert not isinstance(values, NestedTensor)
# Query cache for the symint associated with offsets or lengths
# (create a new one if needed).
ragged_source = offsets if lengths is None else lengths
ragged_size = get_tensor_symint(ragged_source, coeff=1)
self._ragged_idx = kwargs.get("_ragged_idx", 1)
B = offsets.shape[0] - 1
Ds = values.shape[: self._ragged_idx - 1] + values.shape[self._ragged_idx :]
nested_size = [B]
nested_size.extend(Ds[: self._ragged_idx - 1])
nested_size.append(ragged_size)
nested_size.extend(Ds[self._ragged_idx - 1 :])
self._size = tuple(nested_size)
stride = values.stride()
self._strides = (ragged_size * stride[self._ragged_idx - 1], *stride)
if values.requires_grad:
raise ValueError(
"NestedTensor values cannot require grad, please "
"detach before passing to NestedTensor constructor"
)
self._values = values
self._offsets = offsets
self._lengths = lengths
# SDPA metadata
def get_sdpa_extreme_seqlen(func, tensor):
return int(func(tensor).item())
# Note: Not using kwargs.get to avoid execution of get_sdpa_extreme_seqlen
# unless it is really needed
self._max_seqlen = (
kwargs["_max_seqlen"]
if "_max_seqlen" in kwargs
else get_sdpa_extreme_seqlen(
torch.max, offsets.diff() if lengths is None else lengths
)
)
self._min_seqlen = (
kwargs["_min_seqlen"]
if "_min_seqlen" in kwargs
else get_sdpa_extreme_seqlen(
torch.min, offsets.diff() if lengths is None else lengths
)
)
# collapsed ragged dim must always be dynamic
torch._dynamo.mark_dynamic(self, self._ragged_idx)
torch._dynamo.mark_dynamic(self._values, self._ragged_idx - 1)
def values(self):
return self._values
def offsets(self):
return self._offsets
def lengths(self):
return self._lengths
def __repr__(self):
# We should implement this in torch/_tensor_str.py instead
grad_fn_str = (
f", requires_grad={self.requires_grad}" if self.requires_grad else ""
)
if self.grad_fn:
grad_fn_str = f", grad_fn={self.grad_fn}"
return f"NestedTensor(size={self._size}, offsets={self._offsets}{grad_fn_str}, contiguous={self._lengths is None})"
def __reduce_ex__(self, proto):
state = torch._utils._get_obj_state(self)
# SymNodes are not serializable
assert "_size" in state and "_strides" in state
state = dict(state)
del state["_size"]
del state["_strides"]
func = NestedTensor
args = (self._values, self._offsets)
return (torch._tensor._rebuild_from_type_v2, (func, type(self), args, state))
def __tensor_flatten__(self):
ctx = {
"requires_grad": self.requires_grad,
"max_seqlen": self._max_seqlen,
"min_seqlen": self._min_seqlen,
"ragged_idx": self._ragged_idx,
}
inner_tensors = ["_values", "_offsets"]
if self._lengths is not None:
inner_tensors.append("_lengths")
return inner_tensors, ctx
@staticmethod
def __tensor_unflatten__(inner_tensors: Dict, meta, outer_size, outer_stride):
assert len(inner_tensors) >= 2 and len(inner_tensors) <= 3
values = inner_tensors["_values"]
offsets = inner_tensors["_offsets"]
lengths = inner_tensors.get("_lengths", None)
ragged_idx = meta["ragged_idx"]
# Note that we cannot simply check if is_fake(values) because
# during aot autograd, FunctionalTensors are not fake but hold
# symbolic sizes.
ragged_source = offsets if lengths is None else lengths
if has_free_symbols(ragged_source) or has_free_symbols(values):
# Associate offsets or lengths (possibly fake, possibly functionalized)
# with the ragged_size.
ragged_size = outer_size[ragged_idx]
_tensor_symint_registry[ragged_source] = ragged_size
return NestedTensor(
values,
offsets=offsets,
lengths=lengths,
requires_grad=meta["requires_grad"],
_max_seqlen=meta["max_seqlen"],
_min_seqlen=meta["min_seqlen"],
_ragged_idx=ragged_idx,
)
@classmethod
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
kwargs = {} if kwargs is None else kwargs
# Lazy import to avoid circular dependency
from .ops import lookup_jagged
fn = lookup_jagged(func, *args, **kwargs)
if fn is not None:
return fn(*args, **kwargs)
raise NotImplementedError(func)
@classmethod
def __torch_function__(cls, func, types, args=(), kwargs=None):
if kwargs is None:
kwargs = {}
from .ops import jagged_torch_function
try:
return jagged_torch_function(func, *args, **kwargs)
except NotImplementedError:
pass
with torch._C.DisableTorchFunctionSubclass():
return func(*args, **kwargs)
# Not actually a view!
class ViewBufferFromNested(torch.autograd.Function):
    """Autograd function returning the values buffer of a NestedTensor.

    The backward pass wraps the incoming gradient in a NestedTensor that
    reuses the offsets, min/max seqlen and ragged index of the forward input.
    """

    @staticmethod
    def forward(ctx, x: NestedTensor):  # type: ignore[override]
        # Record everything backward() needs to rebuild a NestedTensor.
        ctx._ragged_idx = x._ragged_idx
        ctx.min_seqlen = x._min_seqlen
        ctx.max_seqlen = x._max_seqlen
        ctx.save_for_backward(x.offsets())
        return x.values()

    @staticmethod
    def backward(ctx, gO: torch.Tensor):  # type: ignore[override]
        (saved_offsets,) = ctx.saved_tensors
        rebuild_kwargs = {
            "offsets": saved_offsets,
            "_max_seqlen": ctx.max_seqlen,
            "_min_seqlen": ctx.min_seqlen,
            "_ragged_idx": ctx._ragged_idx,
        }
        return NestedTensor(gO, **rebuild_kwargs)
# Not actually a view!
class ViewNestedFromBuffer(torch.autograd.Function):
    """Autograd function wrapping a values buffer and offsets in a NestedTensor.

    The values are detached before being wrapped; the backward pass returns
    the gradient's values buffer for ``values`` and None for the other inputs.
    """

    @staticmethod
    def forward(ctx, values: torch.Tensor, offsets: torch.Tensor, max_seqlen: int, min_seqlen: int):  # type: ignore[override]
        detached_values = values.detach()
        return NestedTensor(
            detached_values,
            offsets=offsets,
            _max_seqlen=max_seqlen,
            _min_seqlen=min_seqlen,
        )

    @staticmethod
    def backward(ctx, gO: NestedTensor):  # type: ignore[override]
        # One gradient slot per forward input: values, offsets, max_seqlen, min_seqlen.
        grad_values = gO.values()
        return (grad_values, None, None, None)
# Not actually a view!
# NOTE: @jbschlosser is working on making it a view
class ViewNonContiguousNestedFromBuffer(torch.autograd.Function):
    """Autograd function wrapping values, offsets and lengths in a NestedTensor.

    Same as ViewNestedFromBuffer but additionally passes ``lengths`` to the
    NestedTensor constructor. The backward pass returns the gradient's values
    buffer for ``values`` and None for the other inputs.
    """

    @staticmethod
    def forward(ctx, values: torch.Tensor, offsets: torch.Tensor, lengths: torch.Tensor, max_seqlen: int, min_seqlen: int):  # type: ignore[override]
        detached_values = values.detach()
        return NestedTensor(
            detached_values,
            offsets=offsets,
            lengths=lengths,
            _max_seqlen=max_seqlen,
            _min_seqlen=min_seqlen,
        )

    @staticmethod
    def backward(ctx, gO: NestedTensor):  # type: ignore[override]
        # One gradient slot per forward input:
        # values, offsets, lengths, max_seqlen, min_seqlen.
        grad_values = gO.values()
        return (grad_values, None, None, None, None)
# Need to make it obvious that users should be passing in offsets
def jagged_from_list(
    tensors: List[torch.Tensor],
    offsets: Optional[torch.Tensor],
    dtype=None,
    device=None,
) -> Tuple[NestedTensor, torch.Tensor]:
    """Constructs a NestedTensor backed by jagged layout from a list of tensors

    The tensors are concatenated along dim 0 to form the values buffer, which
    is then converted to ``device`` / ``dtype`` when those are given. If
    ``offsets`` is None it is computed from the first-dim sizes of the inputs.

    Returns:
        A ``(nested_tensor, offsets)`` tuple.

    Raises:
        RuntimeError: if the tensors do not share one dtype, do not share one
            device, or differ in any dimension other than the first.
    """
    if len({t.dtype for t in tensors}) != 1:
        raise RuntimeError(
            "When constructing a nested tensor, all tensors in list must have the same dtype"
        )
    if len({t.device for t in tensors}) != 1:
        raise RuntimeError(
            "When constructing a nested tensor, all tensors in list must be on the same device"
        )

    # Check that the NT is representable by the jagged layout.
    # Jagged layout represents (B, *, D_0, D_1, ..., D_N), where the only
    # raggedness allowed is for the single dim immediately adjacent to the batch dim.
    shapes = [t.shape for t in tensors]
    trailing_shapes = [shape[1:] for shape in shapes]
    reference_trailing = trailing_shapes[0]
    if any(trailing != reference_trailing for trailing in trailing_shapes):
        raise RuntimeError(
            "Cannot represent given tensor list as a nested tensor with the jagged layout. "
            "Note that the jagged layout only represents shapes of the form "
            "(B, *, D_0, D_1, ..., D_N), with only * allowed to be ragged."
        )

    # Build the values buffer, forwarding only the conversion options that were set.
    to_kwargs = {
        option: setting
        for option, setting in (("device", device), ("dtype", dtype))
        if setting is not None
    }
    values = torch.cat(tensors, dim=0).to(**to_kwargs)

    # Calculate jagged offsets if not provided.
    if offsets is None:
        # Jagged layout specifies that offsets are stored as int64 on the same device as values.
        leading_zero = torch.zeros(1, dtype=torch.int64, device=values.device)
        first_dim_sizes = torch.tensor(
            [shape[0] for shape in shapes], device=values.device
        )
        offsets = torch.cat([leading_zero, first_dim_sizes.cumsum(dim=0)])

    seq_lens = [t.shape[0] for t in tensors]
    max_seqlen = max(seq_lens)
    min_seqlen = min(seq_lens)

    return ViewNestedFromBuffer.apply(values, offsets, max_seqlen, min_seqlen), offsets  # type: ignore[call-overload]
def jagged_from_tensor_and_lengths(
    tensor: torch.Tensor, starts: torch.Tensor, lengths: torch.Tensor
) -> Tuple[NestedTensor, torch.Tensor, Optional[torch.Tensor]]:
    """Constructs a NestedTensor backed by jagged layout from a tensor, starts of sequences, and sequence lengths

    Returns:
        A ``(nested_tensor, offsets, lengths)`` tuple. When the selected
        regions turn out to be contiguous, the nested tensor is built from a
        slice of the flattened buffer with offsets shifted to start at zero,
        and the third element is None; the returned ``offsets`` are the
        unshifted ones in both cases.

    Raises:
        RuntimeError: if ``starts`` or ``lengths`` cannot be expanded to
            ``(tensor.shape[0],)``.
    """
    batch_size = tensor.shape[0]
    broadcastable = is_expandable_to(
        starts.shape, (batch_size,)
    ) and is_expandable_to(lengths.shape, (batch_size,))
    if not broadcastable:
        raise RuntimeError(
            "When constructing a jagged nested tensor using narrow(), "
            "your start and length must be Tensors that broadcast to input.shape[0]"
        )
    start_list = starts.expand(batch_size)
    length_list = lengths.expand(batch_size)

    # Calculate jagged offsets
    assert (
        len(tensor.shape) >= 2
    ), "tensor must at least be 2D for the nested narrow op to work"
    max_seq_len = tensor.shape[1]
    offset_lengths = max_seq_len * torch.arange(
        0, batch_size, dtype=torch.int64, device=tensor.device
    )
    # Jagged layout specifies that offsets are stored as int64 on the same device as values.
    sequence_starts = start_list + offset_lengths
    final_end = (start_list[-1] + offset_lengths[-1] + length_list[-1]).unsqueeze(0)
    offsets = torch.cat([sequence_starts, final_end])

    # Reshape buffer to flatten the 1st and 2nd dimension (view used to enforce non-copy)
    values = (
        tensor.view(-1, *tensor.shape[2:])
        if len(tensor.shape) > 2
        else tensor.view(-1)
    )

    # Check if offsets and lengths make it possibly contiguous and return a regular NT.
    # All three checks are evaluated; any one of them being true rules out contiguity.
    orig_dim = tensor.shape[1]
    contiguity_violations = [
        bool(torch.any(length_list[1:-1].ne(orig_dim))),
        bool(torch.any(offsets[1:-2].diff().ne(orig_dim))),
        bool(offsets[0] + length_list[0] != orig_dim),
    ]
    is_contiguous = not any(contiguity_violations)

    actual_max_seqlen = int(torch.max(lengths).item())
    min_seqlen = int(torch.min(lengths).item())

    if not is_contiguous:
        return (
            ViewNonContiguousNestedFromBuffer.apply(
                values, offsets, length_list, actual_max_seqlen, min_seqlen
            ),
            offsets,
            length_list,
        )  # type: ignore[call-overload]

    return (
        ViewNestedFromBuffer.apply(
            values[offsets[0] : offsets[-1]],
            offsets - offsets[0],
            actual_max_seqlen,
            min_seqlen,
        ),
        offsets,
        None,
    )
def buffer_from_jagged(jagged):
    """Return the values buffer of ``jagged`` via ViewBufferFromNested."""
    values_buffer = ViewBufferFromNested.apply(jagged)
    return values_buffer