Expand dynamic dims support for traceable subclasses (#114311)
Continuation of #112185, following the design in this [doc](https://docs.google.com/document/d/1ipSxcTzEMMOAPvxP-YJlD5JBZZmIGgh8Q34ixtOUCRo).
Summary:
* Introduce `SubclassSymbolicPolicy` containing separate dynamic dim / constraint policies for the outer and inner tensors
* Expand the automatic dynamic algorithm to recurse into inner tensors and produce one of these for a subclass instance
* Maintain legacy behavior for subclasses by recursively calling `mark_dynamic()` on inner tensors *of the same dim as the outer tensor* when `mark_dynamic(outer, ...)` is called
* Addresses this: 6a86cf00ad/torch/_dynamo/variables/builder.py (L1750)
* Add `outer_size` and `outer_stride` arguments to `__tensor_unflatten__()` so that you can find out what symbols were allocated for the outer size / stride (you are expected to return a tensor that compares equal to the outer symbols)
* Signatures now (a toy subclass sketch following this contract appears after this list):
```python
# attrs is a list of inner tensor attributes on x; inner_tensor = getattr(x, attr)
# ctx is anything useful for rebuilding the class we want to guard on
attrs, ctx = x.__tensor_flatten__()
...
# inner_tensors is a dict of {attr -> tensor}
# ctx is taken unmodified from flattening and (eventually) guarded on
# outer_size is the expected size of the output; possibly symbolic
# outer_stride is the expected strides of the output; possibly symbolic
y = MySubclass.__tensor_unflatten__(inner_tensors, ctx, outer_size, outer_stride)
# at the __tensor_unflatten__() call-site in PT2, we assert y.shape == outer_size and y.stride() == outer_stride
# the assert simplifies symbols when there are relationships between outer and inner symbols
```
* Size info is needed for `NestedTensor` at least; stride info is needed for `DTensor` at least
* Punting on `outer_storage_offset` because storage_offset handling is horribly broken in PT2 right now
* ~~Add new `__tensor_mark_dynamic__()` to allow overriding the behavior of mark_dynamic on a per-subclass basis~~ (booted to future work)
* ~~Add guards for tensor subclasses by calling `__tensor_flatten__()` in the guard to test equality on `ctx`~~
* Now handled in #114469
* Next PR: add TENSOR_MATCH guards on inner tensors
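To make the new contract concrete, here is a minimal sketch of a toy wrapper subclass implementing the updated `__tensor_flatten__()` / `__tensor_unflatten__()` signatures. The `ToyWrapperTensor` class, its `_inner` attribute, and the trailing `mark_dynamic()` usage are illustrative assumptions, not part of this PR; a real traceable subclass would also need `__torch_dispatch__`.

```python
import torch
import torch._dynamo


class ToyWrapperTensor(torch.Tensor):
    """Hypothetical wrapper subclass, shown only to illustrate the protocol."""

    @staticmethod
    def __new__(cls, inner: torch.Tensor):
        # The outer tensor mirrors the inner tensor's metadata.
        return torch.Tensor._make_wrapper_subclass(
            cls,
            inner.shape,
            strides=inner.stride(),
            dtype=inner.dtype,
            device=inner.device,
            requires_grad=inner.requires_grad,
        )

    def __init__(self, inner: torch.Tensor):
        self._inner = inner

    def __tensor_flatten__(self):
        # attrs: names of inner tensor attributes; ctx: metadata to guard on
        return ["_inner"], {"requires_grad": self.requires_grad}

    @staticmethod
    def __tensor_unflatten__(inner_tensors, ctx, outer_size, outer_stride):
        # outer_size / outer_stride may hold SymInts allocated for the outer
        # tensor; the returned tensor must compare equal to them (PT2 asserts
        # this at the call site).
        return ToyWrapperTensor(inner_tensors["_inner"])


# Legacy behavior preserved by this PR: mark_dynamic() on the outer subclass
# recurses into inner tensors of the same dim.
t = ToyWrapperTensor(torch.randn(4, 8))
torch._dynamo.mark_dynamic(t, 0)
```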
Pull Request resolved: https://github.com/pytorch/pytorch/pull/114311
Approved by: https://github.com/ezyang, https://github.com/drisspg, https://github.com/voznesenskym, https://github.com/bdhirsh
Committed by: PyTorch MergeBot
Parent: 259a99669d
Commit: 22704426c3
@@ -130,6 +130,10 @@ class NestedTensor(torch.Tensor):
         )
         )
 
+        # collapsed ragged dim must always be dynamic
+        torch._dynamo.mark_dynamic(self, self._ragged_idx)
+        torch._dynamo.mark_dynamic(self._values, self._ragged_idx - 1)
+
     def values(self):
         return self._values
 
@@ -164,7 +168,6 @@ class NestedTensor(torch.Tensor):
     def __tensor_flatten__(self):
        ctx = {
            "requires_grad": self.requires_grad,
-           "ragged_size": self._size[self._ragged_idx],
            "max_seqlen": self._max_seqlen,
            "min_seqlen": self._min_seqlen,
            "ragged_idx": self._ragged_idx,
@@ -175,37 +178,13 @@ class NestedTensor(torch.Tensor):
        return inner_tensors, ctx
 
    @staticmethod
-   def __tensor_unflatten__(inner_tensors: Dict, meta):
+   def __tensor_unflatten__(inner_tensors: Dict, meta, outer_size, outer_stride):
        assert len(inner_tensors) >= 2 and len(inner_tensors) <= 3
        values = inner_tensors["_values"]
        offsets = inner_tensors["_offsets"]
        lengths = inner_tensors.get("_lengths", None)
+       ragged_idx = meta["ragged_idx"]
 
-       # NOTE [ Storing symbolic values as plain attributes on subclasses ]
-       #
-       # When a subclass like NestedTensor stores a "size-like" value (which
-       # can either be Symintified or not) into meta, it's responsible for:
-       #
-       # (1) Propagating that symint during torch dispatch when performing
-       #     operations, i.e. torch dispatch plays the role of a meta kernel.
-       #
-       # (2) Facilitating the behavior around symbolic -> non-symbolic
-       #     conversions and vice versa, see below.
-       #
-       # [ non-symbolic -> symbolic (fakification in meta_utils) ]
-       #
-       # __tensor_unflatten__ is passed symbolic dense tensors and meta from
-       # non-symbolic subclasses. In this case, the subclass is responsible for
-       # intercepting meta["ragged_size"] for example and replacing it with the
-       # symintified version.
-       #
-       # [ symbolic -> non-symbolic ]
-       #
-       # __tensor_unflatten__ is passed non-symbolic dense tensors with
-       # meta extracted from fake subclasses. In this case the subclass gets
-       # propagated the meta["ragged_size"] which is still a symint and the
-       # subclass is responsible for making sure that the symint doesn't leak.
-       #
        # Note that we cannot simply check if is_fake(values) because
        # during aot autograd, FunctionalTensors are not fake but hold
        # symbolic sizes.
@@ -213,7 +192,8 @@ class NestedTensor(torch.Tensor):
        if has_free_symbols(ragged_source) or has_free_symbols(values):
            # Associate offsets or lengths (possibly fake, possibly functionalized)
            # with the ragged_size.
-           _tensor_symint_registry[ragged_source] = meta["ragged_size"]
+           ragged_size = outer_size[ragged_idx]
+           _tensor_symint_registry[ragged_source] = ragged_size
 
        return NestedTensor(
            values,
@@ -222,7 +202,7 @@ class NestedTensor(torch.Tensor):
            requires_grad=meta["requires_grad"],
            _max_seqlen=meta["max_seqlen"],
            _min_seqlen=meta["min_seqlen"],
-           _ragged_idx=meta["ragged_idx"],
+           _ragged_idx=ragged_idx,
        )
 
    @classmethod