Expand dynamic dims support for traceable subclasses (#114311)

Continuation of #112185, following the design in this [doc](https://docs.google.com/document/d/1ipSxcTzEMMOAPvxP-YJlD5JBZZmIGgh8Q34ixtOUCRo).

Summary:
* Introduce `SubclassSymbolicPolicy` containing separate dynamic dim / constraint policies for the outer and inner tensors
    * Expand the automatic dynamic algorithm to recurse into inner tensors and produce one of these for a subclass instance
    * Maintain legacy behavior for subclasses by recursively calling `mark_dynamic()` on inner tensors *of the same dim as outer* when `mark_dynamic(outer, ...)` is called
    * Addresses this: 6a86cf00ad/torch/_dynamo/variables/builder.py (L1750)
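    * Rough sketch of that legacy `mark_dynamic()` recursion (the helper name is hypothetical; this is not the actual Dynamo implementation):
    ```python
    import torch

    def mark_dynamic_with_subclass_recursion(t: torch.Tensor, index: int) -> None:
        # Mark the requested dim dynamic as usual.
        torch._dynamo.mark_dynamic(t, index)
        # For traceable subclasses, recurse into inner tensors -- but only those
        # with the same number of dims as the outer tensor, so that `index`
        # refers to the same dim on both.
        if hasattr(t, "__tensor_flatten__"):
            attrs, _ctx = t.__tensor_flatten__()
            for attr in attrs:
                inner = getattr(t, attr)
                if inner.dim() == t.dim():
                    mark_dynamic_with_subclass_recursion(inner, index)
    ```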
* Add `outer_size` and `outer_stride` arguments to `__tensor_unflatten__()` so that you can find out what symbols were allocated for the outer size / stride (you are expected to return a tensor that compares equal to the outer symbols)
    * Signatures now:
    ```python
    # attrs is a list of inner tensor attributes on x; inner_tensor = getattr(x, attr)
    # ctx is anything useful for rebuilding the class we want to guard on
    attrs, ctx = x.__tensor_flatten__()
    ...
    # inner_tensors is a dict of {attr -> tensor}
    # ctx is taken unmodified from flattening and (eventually) guarded on
    # outer_size is the expected size of the output; possibly symbolic
    # outer_stride is the expected strides of the output; possibly symbolic
    y = MySubclass.__tensor_unflatten__(inner_tensors, ctx, outer_size, outer_stride)

    # at the __tensor_unflatten__() call-site in PT2, we assert y.shape == outer_size and y.stride() == outer_stride
    # the assert simplifies symbols when there are relationships between outer and inner symbols
    ```
    * Size info is needed at least for `NestedTensor`; stride info is needed at least for `DTensor`
    * Punting on `outer_storage_offset` because storage_offset handling is horribly broken in PT2 right now
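    * Minimal toy wrapper adopting the new signature (illustrative sketch only; `ToyWrapperTensor` is hypothetical and omits `__torch_dispatch__` for brevity):
    ```python
    from typing import Dict

    import torch

    class ToyWrapperTensor(torch.Tensor):
        @staticmethod
        def __new__(cls, inner: torch.Tensor):
            # Wrapper with the same metadata as the wrapped tensor.
            return torch.Tensor._make_wrapper_subclass(
                cls, inner.shape, strides=inner.stride(), dtype=inner.dtype, device=inner.device
            )

        def __init__(self, inner: torch.Tensor):
            self._inner = inner

        def __tensor_flatten__(self):
            # Inner tensor attribute names, plus ctx that PT2 will (eventually) guard on.
            return ["_inner"], {}

        @staticmethod
        def __tensor_unflatten__(inner_tensors: Dict, ctx, outer_size, outer_stride):
            # outer_size / outer_stride may hold SymInts under PT2; the returned
            # tensor's shape and strides must compare equal to them. For this
            # wrapper they trivially match the inner tensor's metadata.
            return ToyWrapperTensor(inner_tensors["_inner"])
    ```
    * Roughly, fakification in PT2 flattens the real subclass and calls `__tensor_unflatten__()` with fake inner tensors and the (possibly symbolic) `outer_size` / `outer_stride`, so a wrapper like this should not bake concrete sizes from the original tensor into the result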
* ~~Add new `__tensor_mark_dynamic__()` to allow overriding the behavior of mark_dynamic on a per-subclass basis~~ (booted to future work)
* ~~Add guards for tensor subclasses by calling `__tensor_flatten__()` in the guard to test equality on `ctx`~~
    * Now handled in #114469
* Next PR: add TENSOR_MATCH guards on inner tensors

Pull Request resolved: https://github.com/pytorch/pytorch/pull/114311
Approved by: https://github.com/ezyang, https://github.com/drisspg, https://github.com/voznesenskym, https://github.com/bdhirsh
Author: Joel Schlosser
Date: 2023-12-05 12:50:59 -05:00
Committed by: PyTorch MergeBot
Parent: 259a99669d
Commit: 22704426c3
21 changed files with 453 additions and 258 deletions

Excerpt of the `NestedTensor` changes from the diff:

```diff
@@ -130,6 +130,10 @@ class NestedTensor(torch.Tensor):
             )
         )
 
+        # collapsed ragged dim must always be dynamic
+        torch._dynamo.mark_dynamic(self, self._ragged_idx)
+        torch._dynamo.mark_dynamic(self._values, self._ragged_idx - 1)
+
     def values(self):
         return self._values
 
@@ -164,7 +168,6 @@ class NestedTensor(torch.Tensor):
     def __tensor_flatten__(self):
         ctx = {
             "requires_grad": self.requires_grad,
-            "ragged_size": self._size[self._ragged_idx],
             "max_seqlen": self._max_seqlen,
             "min_seqlen": self._min_seqlen,
             "ragged_idx": self._ragged_idx,
@@ -175,37 +178,13 @@ class NestedTensor(torch.Tensor):
         return inner_tensors, ctx
 
     @staticmethod
-    def __tensor_unflatten__(inner_tensors: Dict, meta):
+    def __tensor_unflatten__(inner_tensors: Dict, meta, outer_size, outer_stride):
         assert len(inner_tensors) >= 2 and len(inner_tensors) <= 3
         values = inner_tensors["_values"]
         offsets = inner_tensors["_offsets"]
         lengths = inner_tensors.get("_lengths", None)
+        ragged_idx = meta["ragged_idx"]
 
-        # NOTE [ Storing symbolic values as plain attributes on subclasses ]
-        #
-        # When a subclass like NestedTensor stores a "size-like" value (which
-        # can either be Symintified or not) into meta, it's responsible for:
-        #
-        # (1) Propagating that symint during torch dispatch when performing
-        #     operations, i.e. torch dispatch plays the role of a meta kernel.
-        #
-        # (2) Facilitating the behavior around symbolic -> non-symbolic
-        #     conversions and vice versa, see below.
-        #
-        # [ non-symbolic -> symbolic (fakification in meta_utils) ]
-        #
-        # __tensor_unflatten__ is passed symbolic dense tensors and meta from
-        # non-symbolic subclasses. In this case, the subclass is responsible for
-        # intercepting meta["ragged_size"] for example and replacing it with the
-        # symintified version.
-        #
-        # [ symbolic -> non-symbolic ]
-        #
-        # __tensor_unflatten__ is passed non-symbolic dense tensors and with
-        # meta extracted from fake subclasses. In this case the subclass gets
-        # propagated the meta["ragged_size"] which is still a symint and the
-        # subclass is responsible for making sure that the symint doesn't leak.
-        #
         # Note that we cannot simply check if is_fake(values) because
         # during aot autograd, FunctionalTensors are not fake but hold
         # symbolic sizes.
@@ -213,7 +192,8 @@ class NestedTensor(torch.Tensor):
         if has_free_symbols(ragged_source) or has_free_symbols(values):
             # Associate offsets or lengths (possibly fake, possibly functionalized)
             # with the ragged_size.
-            _tensor_symint_registry[ragged_source] = meta["ragged_size"]
+            ragged_size = outer_size[ragged_idx]
+            _tensor_symint_registry[ragged_source] = ragged_size
 
         return NestedTensor(
             values,
@@ -222,7 +202,7 @@ class NestedTensor(torch.Tensor):
             requires_grad=meta["requires_grad"],
             _max_seqlen=meta["max_seqlen"],
             _min_seqlen=meta["min_seqlen"],
-            _ragged_idx=meta["ragged_idx"],
+            _ragged_idx=ragged_idx,
         )
 
     @classmethod
```