Expand dynamic dims support for traceable subclasses (#114311)
Continuation of #112185, following the design in this [doc](https://docs.google.com/document/d/1ipSxcTzEMMOAPvxP-YJlD5JBZZmIGgh8Q34ixtOUCRo).
Summary:
* Introduce `SubclassSymbolicContext` (called `SubclassSymbolicPolicy` in the design doc), containing separate dynamic dim / constraint policies for the outer and inner tensors
* Expand the automatic dynamic algorithm to recurse into inner tensors and produce one of these for a subclass instance
* Maintain legacy behavior for subclasses by recursively calling `mark_dynamic()` on inner tensors *of the same dim as the outer tensor* when `mark_dynamic(outer, ...)` is called (see the sketch after this list)
  * Addresses this: 6a86cf00ad/torch/_dynamo/variables/builder.py (L1750)
* Add `outer_size` and `outer_stride` arguments to `__tensor_unflatten__()` so that you can find out what symbols were allocated for the outer size / stride (you are expected to return a tensor that compares equal to the outer symbols)
  * The signatures are now:
```python
# attrs is a list of inner tensor attributes on x; inner_tensor = getattr(x, attr)
# ctx is anything useful for rebuilding the class we want to guard on
attrs, ctx = x.__tensor_flatten__()
...
# inner_tensors is a dict of {attr -> tensor}
# ctx is taken unmodified from flattening and (eventually) guarded on
# outer_size is the expected size of the output; possibly symbolic
# outer_stride is the expected strides of the output; possibly symbolic
y = MySubclass.__tensor_unflatten__(inner_tensors, ctx, outer_size, outer_stride)
# at the __tensor_unflatten__() call-site in PT2, we assert y.shape == outer_size and y.stride() == outer_stride
# the assert simplifies symbols when there are relationships between outer and inner symbols
```
  * Size info is needed for `NestedTensor` at least; stride info is needed for `DTensor` at least
  * Punting on `outer_storage_offset` because storage_offset handling is horribly broken in PT2 right now
* ~~Add new `__tensor_mark_dynamic__()` to allow overriding the behavior of mark_dynamic on a per-subclass basis~~ (booted to future work)
* ~~Add guards for tensor subclasses by calling `__tensor_flatten__()` in the guard to test equality on `ctx`~~
  * Now handled in #114469
* Next PR: add TENSOR_MATCH guards on inner tensors
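To make the `mark_dynamic()` behavior above concrete, here is a minimal sketch. The `WrapperTensor` class, its attributes, and the shapes are illustrative assumptions (loosely modeled on the `ScaledTensor` test helper in the diff below), not part of this PR's API:

```python
import torch
import torch._dynamo

# Illustrative wrapper subclass: `_data` has the same dim as the outer tensor,
# `_scale` does not.
class WrapperTensor(torch.Tensor):
    def __new__(cls, data, scale):
        return torch.Tensor._make_wrapper_subclass(
            cls, data.size(), strides=data.stride(), dtype=data.dtype, device=data.device
        )

    def __init__(self, data, scale):
        self._data = data
        self._scale = scale

    def __tensor_flatten__(self):
        return ["_data", "_scale"], None

    @staticmethod
    def __tensor_unflatten__(inner_tensors, ctx, outer_size, outer_stride):
        # outer_size / outer_stride are the (possibly symbolic) expected outer metadata
        return WrapperTensor(inner_tensors["_data"], inner_tensors["_scale"])

    @classmethod
    def __torch_dispatch__(cls, func, types, args, kwargs=None):
        # minimal dispatch: unwrap only the first argument (enough for this sketch)
        kwargs = kwargs or {}
        t = args[0]
        return WrapperTensor(func(t._data, *args[1:], **kwargs), t._scale)


t = WrapperTensor(torch.randn(2, 4), torch.randn(6))
# Marks dim 0 dynamic on the outer tensor and, per the legacy-compat behavior in this PR,
# recursively on inner tensors with the same dim as the outer (here: `_data`, not `_scale`).
torch._dynamo.mark_dynamic(t, 0)
```

Inner tensors whose dim differs from the outer tensor (here `_scale`) are intentionally left untouched, which is why the differently-sized-inner-tensor test in the diff below still expects recompiles.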
Pull Request resolved: https://github.com/pytorch/pytorch/pull/114311
Approved by: https://github.com/ezyang, https://github.com/drisspg, https://github.com/voznesenskym, https://github.com/bdhirsh
Committed by: PyTorch MergeBot
Parent: 259a99669d
Commit: 22704426c3
@@ -969,6 +969,7 @@ coverage_ignore_functions = [
     "is_concrete_int",
     "is_contiguous",
     "is_non_overlapping_and_dense_indicator",
+    "is_singleton",
     "is_symbol_binding_fx_node",
     "is_symbolic",
     "parallel_and",
@@ -2869,6 +2870,7 @@ coverage_ignore_classes = [
     "SymbolicContext",
     "StatelessSymbolicContext",
     "StatefulSymbolicContext",
+    "SubclassSymbolicContext",
     # torch.fx.experimental.unification.match
     "Dispatcher",
     "VarDispatcher",

@@ -696,7 +696,7 @@ class AutogradFunctionTests(torch._dynamo.test_case.TestCase):
         )

     @staticmethod
-    def __tensor_unflatten__(tensors, metadatas):
+    def __tensor_unflatten__(tensors, metadatas, outer_size, outer_stride):
         return FooTensor(tensors["_data"], metadatas[0], metadatas[1])

     @classmethod

@@ -63,6 +63,52 @@ class SigmoidToExpSubclass(torch.Tensor):
         return super().__torch_function__(func, types, args, kwargs)


+# Wrapper subclass with two inner tensors: data and scale
+# data has same shape as outer, and scale has single dim size
+class ScaledTensor(torch.Tensor):
+    def __new__(
+        cls,
+        data: torch.Tensor,
+        scale: torch.Tensor,
+    ):
+        return torch.Tensor._make_wrapper_subclass(
+            cls,
+            data.size(),
+            strides=data.stride(),
+            storage_offset=data.storage_offset(),
+            dtype=data.dtype,
+            layout=data.layout,
+            requires_grad=data.requires_grad,
+            device=data.device,
+        )
+
+    def __init__(self, data: torch.Tensor, scale: torch.Tensor):
+        self._data = data
+        self._scale = scale
+
+    def __tensor_flatten__(self):
+        ctx = {}
+        return ["_data", "_scale"], ctx
+
+    @staticmethod
+    def __tensor_unflatten__(inner_tensors, metadata, outer_size, outer_stride):
+        assert len(inner_tensors) == 2
+        return ScaledTensor(inner_tensors["_data"], inner_tensors["_scale"])
+
+    @classmethod
+    def __torch_dispatch__(cls, func, types, args, kwargs=None):
+        scaled_tensor = args[0]
+        out = func(scaled_tensor._data, *args[1:], **kwargs)
+        return ScaledTensor(out, scaled_tensor._scale)
+
+    def __repr__(self):
+        return f"{self._data.__repr__()}\n{self._scale.__repr__()}"
+
+
+def func(a):
+    return a.sin()
+
+
 class EagerRecordGraphAndInputs:
     def __init__(self):
         self.graphs = []
@@ -77,6 +123,21 @@ class EagerRecordGraphAndInputs:
 GLOBAL_TEST_SUBCLASSES = {MockSubclass, DummyNDim, SigmoidToExpSubclass}


+# Returns True if the function recompiles between inputs1 and inputs2 with the
+# specified dynamic setting.
+def _recompiles_for_inputs(fn, inputs1, inputs2, dynamic=True):
+    compile_count = [0]
+
+    def counter(gm, example_inputs):
+        compile_count[0] += 1
+        return gm
+
+    compiled_f = torch.compile(fn, fullgraph=True, backend=counter, dynamic=dynamic)
+    compiled_f(*inputs1)
+    compiled_f(*inputs2)
+    return compile_count[0] > 1
+
+
 class SubclassTests(torch._dynamo.test_case.TestCase):
     @classmethod
     def setUpClass(cls):
@@ -608,7 +669,7 @@ class GraphModule(torch.nn.Module):
         return ["inner_elem"], None

     @staticmethod
-    def __tensor_unflatten__(inner_tensors, _):
+    def __tensor_unflatten__(inner_tensors, _, outer_size, outer_stride):
         return DoubleSizeMaybeAddGeThreeTensor(inner_tensors["inner_elem"])

     def __repr__(self):
@@ -688,6 +749,42 @@ class GraphModule(torch.nn.Module):
         self.assertEqual(lower_bound_str, expected_lower_bound)
         self.assertEqual(upper_bound_str, expected_upper_bound)

+    def test_wrapper_subclass_with_same_sized_inner_tensor(self):
+        # shouldn't recompile for different sizes when dynamic=True
+        sub1 = ScaledTensor(torch.randn(2, 4), torch.randn(6))
+        sub2 = ScaledTensor(torch.randn(3, 5), torch.randn(7))
+        self.assertFalse(_recompiles_for_inputs(func, (sub1,), (sub2,), dynamic=True))
+
+        # should recompile for different data size when dynamic=False
+        sub1 = ScaledTensor(torch.randn(2, 4), torch.randn(6))
+        sub2 = ScaledTensor(torch.randn(3, 5), torch.randn(6))
+        self.assertTrue(_recompiles_for_inputs(func, (sub1,), (sub2,), dynamic=False))
+
+        # avoid recompile using manual mark_dynamic() for different data size
+        sub1 = ScaledTensor(torch.randn(2, 4), torch.randn(6))
+        # NB: mark_dynamic() on outer tensor should translate to inner tensors of the same size
+        torch._dynamo.mark_dynamic(sub1, 0)
+        torch._dynamo.mark_dynamic(sub1, 1)
+        sub2 = ScaledTensor(torch.randn(3, 5), torch.randn(6))
+        self.assertFalse(_recompiles_for_inputs(func, (sub1,), (sub2,), dynamic=False))
+
+    # Broken because we don't guard properly on inner tensors yet.
+    # TODO: Enable this when we do
+    @unittest.expectedFailure
+    def test_wrapper_subclass_with_differently_sized_inner_tensor(self):
+        # should recompile for different scale size when dynamic=False
+        sub1 = ScaledTensor(torch.randn(2, 4), torch.randn(3))
+        sub2 = ScaledTensor(torch.randn(2, 4), torch.randn(5))
+        self.assertTrue(_recompiles_for_inputs(func, (sub1,), (sub2,), dynamic=False))
+
+        # still recompiles using manual mark_dynamic() on outer for different scale size
+        sub1 = ScaledTensor(torch.randn(2, 4), torch.randn(3))
+        # NB: mark_dynamic() on outer tensor doesn't translate to inner tensors of different size
+        torch._dynamo.mark_dynamic(sub1, 0)
+        torch._dynamo.mark_dynamic(sub1, 1)
+        sub2 = ScaledTensor(torch.randn(2, 4), torch.randn(5))
+        self.assertTrue(_recompiles_for_inputs(func, (sub1,), (sub2,), dynamic=False))
+
     def test_recompile_with_symbool_inputs(self):
         def f(pred: bool):
             if pred:
@@ -832,18 +929,9 @@ class TestNestedTensor(torch._dynamo.test_case.TestCase):
         )
         return jagged_from_tensor_and_lengths(values_tensor, starts, lengths)

-    def _check_recompiles(self, fn, inputs1, inputs2, recompiles):
-        compile_count = [0]
-
-        def counter(gm, example_inputs):
-            compile_count[0] += 1
-            return gm
-
-        compiled_f = torch.compile(fn, fullgraph=True, backend=counter, dynamic=True)
-        out = compiled_f(*inputs1)
-        self.assertEqual(compile_count[0], 1)
-        out = compiled_f(*inputs2)
-        self.assertEqual(compile_count[0], 2 if recompiles else 1)
+    def _check_recompiles(self, fn, inputs1, inputs2, expected_recompiles):
+        actual_recompiles = _recompiles_for_inputs(fn, inputs1, inputs2)
+        self.assertEqual(actual_recompiles, expected_recompiles)

     def test_unary_does_not_recompile(self):
         nt1, _ = self._get_jagged_tensor(((2, 3, 4), 3), None)
@@ -857,9 +945,11 @@ class TestNestedTensor(torch._dynamo.test_case.TestCase):
             else:
                 return nt1.sin()

-        # Basic binary
-        nt1, offsets = self._get_jagged_tensor(((2, 3, 4), 3), None)
-        nt2, _ = self._get_jagged_tensor(((2, 3, 4), 3), offsets)
+        # NB: If we have shape e.g. (3, j0, 3), duck sizing will give us (s0, s1, s0).
+        # This causes a recompile later on when it realizes the batch and last dim
+        # should not always be equal. To avoid that, we use (3, j0, 5) here.
+        nt1, offsets = self._get_jagged_tensor(((2, 3, 4), 5), None)
+        nt2, _ = self._get_jagged_tensor(((2, 3, 4), 5), offsets)
         nt3, offsets = self._get_jagged_tensor(((3, 4, 5), 4), None)
         nt4, _ = self._get_jagged_tensor(((3, 4, 5), 4), offsets)
         self._check_recompiles(binary, (nt1, nt2), (nt3, nt4), False)
@@ -872,9 +962,9 @@ class TestNestedTensor(torch._dynamo.test_case.TestCase):
                 return nt1.sin()

         # Binary recompiles because singleton ints no longer match
-        nt1, offsets = self._get_jagged_tensor(((2, 3, 4), 3), None)
-        nt2, _ = self._get_jagged_tensor(((2, 3, 4), 3), offsets)
-        nt3, _ = self._get_jagged_tensor(((2, 3, 4), 3), None)
+        nt1, offsets = self._get_jagged_tensor(((2, 3, 4), 5), None)
+        nt2, _ = self._get_jagged_tensor(((2, 3, 4), 5), offsets)
+        nt3, _ = self._get_jagged_tensor(((2, 3, 4), 5), None)
         self._check_recompiles(binary, (nt1, nt2), (nt1, nt3), True)

     # TODO: cannot parametrize this test class with device for some reason
@@ -909,7 +999,10 @@ class TestNestedTensor(torch._dynamo.test_case.TestCase):
         self._test_autograd("inductor")

     def test_unbind(self):
-        nt, _ = self._get_jagged_tensor(((2, 3, 4), 3), None)
+        # NB: If we have shape e.g. (3, j0, 3), duck sizing will give us (s0, s1, s0).
+        # This causes a recompile later on when it realizes the batch and last dim
+        # should not always be equal. To avoid that, we use (3, j0, 5) here.
+        nt, _ = self._get_jagged_tensor(((2, 3, 4), 5), None)
         nt2, _ = self._get_jagged_tensor(((2, 3, 5), 2), None)
         nt3, _ = self._get_jagged_tensor(((2, 3, 4, 5), 3), None)

@@ -32,6 +32,7 @@ from torch.nn.utils.rnn import PackedSequence
 from torch.testing._internal.common_device_type import instantiate_device_type_tests, toleranceOverride, tol
 from torch.testing._internal.common_methods_invocations import op_db
 from torch.testing._internal.common_modules import module_db, modules
+from torch.testing._internal.common_utils import parametrize, instantiate_parametrized_tests
 from torch.testing._internal.control_flow_opinfo_db import control_flow_opinfo_db
 from torch.testing._internal.optests import _test_aot_autograd_forwards_backwards_helper, aot_autograd_check
 from functorch import (
@@ -1330,9 +1331,6 @@ def forward(self, primals_1):
             return [(x,), (x,)]

         self.verify_aot_autograd(f, partial(inp_callable, req_grad=False), test_mutation=True)
-        self.verify_aot_autograd(f, partial(inp_callable, req_grad=False), test_mutation=True, make_inputs_subclasses=True)
-        with self.assertRaisesRegex(AssertionError, "which is currently unsupported in the subclass use case"):
-            self.verify_aot_autograd(f, partial(inp_callable, req_grad=True), test_mutation=True, make_inputs_subclasses=True)
         fw_graph = self.verify_aot_autograd(f, partial(inp_callable, req_grad=True), test_mutation=True)
         # TODO: make this test run with dynamic shapes so it is more meaningful
         # metadata output order: (a_updated_meta, out1_meta, out2_meta, out3_meta)
@@ -1346,6 +1344,21 @@ def forward(self, primals_1):
         unsqueeze = torch.ops.aten.unsqueeze.default(transpose, 0)
         return [transpose, squeeze, transpose_1, unsqueeze, mul]""")

+    @parametrize("req_grad", [False, True])
+    def test_subclass_metadata_mutation(self, req_grad):
+        def f(a):
+            a.transpose_(1, 0)
+            tmp = a.mul(2)
+            return tmp.transpose(1, 0)
+
+        def inp_callable(req_grad):
+            x = torch.ones(1, 2, 4, requires_grad=req_grad).clone()
+            return [(x,), (x,)]
+
+        # See https://github.com/pytorch/pytorch/issues/114975
+        with self.assertRaisesRegex(RuntimeError, "Metadata mutations are currently not allowed on tensor subclasses"):
+            self.verify_aot_autograd(f, partial(inp_callable, req_grad=req_grad), test_mutation=True, make_inputs_subclasses=True)
+
     def test_input_data_and_metadata_mutation(self):
         def f(a):
             a.t_()
@@ -1879,7 +1892,7 @@ def forward(self, primals_1, primals_2):

         self.verify_aot_autograd(f, partial(inp_callable, req_grad=False), test_mutation=True)

-        with self.assertRaisesRegex(RuntimeError, "is a tensor subclass. This is not supported today"):
+        with self.assertRaisesRegex(RuntimeError, "Metadata mutations are currently not allowed on tensor subclasses"):
             self.verify_aot_autograd(f, partial(inp_callable, req_grad=False), test_mutation=True, make_inputs_subclasses=True)

         fw_graph = self.verify_aot_autograd(f, partial(inp_callable, req_grad=True), test_mutation=True)
@@ -3606,6 +3619,9 @@ def forward(self, tangents_1, tangents_2):
         self.assertEqual(b_ref_base.grad.a, b_test_base.grad.a)
         self.assertEqual(b_ref_base.grad.b, b_test_base.grad.b)

+    # NB: Metadata mutation for subclasses is currently broken and disabled
+    # See https://github.com/pytorch/pytorch/issues/114975
+    @unittest.expectedFailure
     def test_aot_dispatch_input_metadata_mutation(self):
         def f(a, b):
             a.t_()
@@ -3651,6 +3667,9 @@ def forward(self, tangents_1, tangents_2):
         self.assertEqual(b_ref_base.grad.a, b_test_base.grad.a)
         self.assertEqual(b_ref_base.grad.b, b_test_base.grad.b)

+    # NB: Metadata mutation for subclasses is currently broken and disabled
+    # See https://github.com/pytorch/pytorch/issues/114975
+    @unittest.expectedFailure
     def test_aot_dispatch_input_data_and_metadata_mutation(self):
         def f(a, b):
             a.t_()
@@ -3742,6 +3761,7 @@ def forward(self, tangents_1, tangents_2):
         self.assertEqual(a_ref_base.grad.a, a_test_base.grad.a)
         self.assertEqual(a_ref_base.grad.b, a_test_base.grad.b)

+
 class TestAOTModuleSimplified(AOTTestCase):
     def test_aot_module_simplified(self):
         class MockModule(torch.nn.Module):
@@ -4139,6 +4159,7 @@ class TestEagerFusionModuleInfo(AOTTestCase):
         _test_aot_autograd_module_helper(self, device, dtype, training, module_info, dynamic=True)


+instantiate_parametrized_tests(TestAOTAutograd)
 only_for = ("cpu")
 instantiate_device_type_tests(
     TestPythonKey,

@@ -42,7 +42,7 @@ def strip_end(s, suffix):
 def show_guards(gm):
     names = [strip_end(n, "_1") for n in fx_placeholder_targets(gm)]
     return "\n".join(
-        gm.shape_env.produce_guards(fx_placeholder_vals(gm), names, _simplified=True, constraint_inputs=None)
+        gm.shape_env.produce_guards(fx_placeholder_vals(gm), names, _simplified=True, input_contexts=None)
     )


@@ -1,6 +1,7 @@
 from typing import TYPE_CHECKING

 import torch
+from torch.utils._python_dispatch import is_traceable_wrapper_subclass
 from . import allowed_functions
 from .eval_frame import DisableContext, innermost_fn, RunOnlyContext
 from .exc import IncorrectUsage
@@ -158,6 +159,19 @@ def forbid_in_graph(fn):
     return fn


+# Helper function to flatten a tensor subclass and apply a function to
+# all inner tensors that match the outer dim. Used to reduce duplication
+# across the various marking APIs.
+def _apply_func_to_inner_tensors_of_same_dim(func, t, *args):
+    assert is_traceable_wrapper_subclass(t)
+
+    attrs, ctx = t.__tensor_flatten__()
+    for attr in attrs:
+        inner = getattr(t, attr)
+        if inner.dim() == t.dim():
+            func(inner, *args)
+
+
 @forbid_in_graph
 def mark_dynamic(t, index):
     """
@@ -183,6 +197,11 @@ def mark_dynamic(t, index):
    before torch.compile.

    """
+    if is_traceable_wrapper_subclass(t):
+        # default behavior: mirror mark_dynamic() on all inner tensors with same dim as t
+        # TODO: Make this configurable via a supported public API
+        _apply_func_to_inner_tensors_of_same_dim(mark_dynamic, t, index)
+
    if isinstance(index, int):
        if not hasattr(t, "_dynamo_dynamic_indices"):
            t._dynamo_dynamic_indices = set()
@@ -201,6 +220,11 @@ def maybe_mark_dynamic(t, index):
    Mark a tensor as having a dynamic dim, but don't enforce it (i.e., if this
    dimension ends up getting specialized, don't error).
    """
+    if is_traceable_wrapper_subclass(t):
+        # default behavior: mirror maybe_mark_dynamic() on all inner tensors with same dim as t
+        # TODO: Make this configurable via a supported public API
+        _apply_func_to_inner_tensors_of_same_dim(maybe_mark_dynamic, t, index)
+
    if isinstance(index, int):
        if not hasattr(t, "_dynamo_weak_dynamic_indices"):
            t._dynamo_weak_dynamic_indices = set()
@@ -223,6 +247,11 @@ def mark_static(t, index=None):

    This has lower precedence than mark_dynamic.
    """
+    if is_traceable_wrapper_subclass(t):
+        # default behavior: mirror mark_static() on all inner tensors with same dim as t
+        # TODO: Make this configurable via a supported public API
+        _apply_func_to_inner_tensors_of_same_dim(mark_static, t, index)
+
    if isinstance(index, int):
        if not hasattr(t, "_dynamo_static_indices"):
            t._dynamo_static_indices = set()
@@ -261,7 +290,7 @@ def _allow_in_graph_einops():

    try:
        # requires einops > 0.6.1, torch >= 2.0
-       from einops._torch_specific import ( # noqa: F401
+       from einops._torch_specific import ( # type: ignore[attr-defined] # noqa: F401
            _ops_were_registered_in_torchdynamo,
        )


@@ -627,7 +627,7 @@ class GuardBuilder(GuardBuilderBase):
         output_graph = self.check_fn_manager.output_graph
         # NB: self.output_graph can be None in the debug_nops tests
         fs = output_graph.tracked_fakes
-        constraint_inputs = [a.constraint_dims for a in fs]
+        input_contexts = [a.symbolic_context for a in fs]

         def get_sources(t_id, dim):
             # Looks up base sources mapped to a tensor id and uses them to create
@@ -670,7 +670,7 @@ class GuardBuilder(GuardBuilderBase):
         guards = output_graph.shape_env.produce_guards(
             [a.fake for a in fs],
             [a.source for a in fs],
-            constraint_inputs=constraint_inputs,
+            input_contexts=input_contexts,
             equalities_inputs=equalities_inputs,
             source_ref=self.source_ref,
             # Export keeps static.

@@ -26,10 +26,10 @@ from torch._streambase import _EventBase, _StreamBase
 from torch._subclasses.fake_tensor import FakeTensor, is_fake, maybe_get_fake_mode
 from torch.fx.experimental.symbolic_shapes import (
     _constrain_range_for_size,
-    DimConstraint,
     DimDynamic,
     RelaxedUnspecConstraint,
     StatefulSymbolicContext,
+    SubclassSymbolicContext,
     SymbolicContext,
 )
 from torch.fx.immutable_collections import immutable_list
@@ -1579,7 +1579,7 @@ class TrackedFake:
     fake: Union[FakeTensor, SymInt]
     source: Source
     # Is None when fake is SymInt
-    constraint_dims: Optional[DimList[DimConstraint]]
+    symbolic_context: Optional[SymbolicContext]

     def __hash__(self) -> int:
         return hash((self.fake, self.source.name()))
@@ -1592,13 +1592,40 @@ class TrackedFake:

 # Performs automatic dynamic dim determination.
 # Returns a SymbolicContext
-def _automatic_dynamic(e, tx, source, static_shapes) -> SymbolicContext:
+def _automatic_dynamic(
+    e, tx, source, static_shapes, outer_only=False
+) -> SymbolicContext:
     name = source.name()
     prior_policy = tx.output.tracing_context.tensor_to_context.get(e, None)
     source_to_symint_node_cache = (
         prior_policy.source_to_symint_node_cache if prior_policy else None
     )

+    if is_traceable_wrapper_subclass(e) and not outer_only:
+        # Get symbolic context for outer tensor
+        outer_context = _automatic_dynamic(
+            e, tx, source, static_shapes, outer_only=True
+        )
+
+        # Get symbolic contexts for inner tensors
+        attrs, _ = type(e).__tensor_flatten__(e)
+        inner_contexts = {}  # mapping from attr -> symbolic context
+        for attr in attrs:
+            inner_tensor = getattr(e, attr)
+            inner_source = AttrSource(source, attr)
+            inner_context = _automatic_dynamic(
+                inner_tensor, tx, inner_source, static_shapes
+            )
+            inner_contexts[attr] = inner_context
+
+        return SubclassSymbolicContext(
+            dynamic_sizes=outer_context.dynamic_sizes,
+            constraint_sizes=outer_context.constraint_sizes,
+            tensor_source=outer_context.tensor_source,
+            source_to_symint_node_cache=outer_context.source_to_symint_node_cache,
+            inner_contexts=inner_contexts,
+        )
+
     if static_shapes:
         return StatefulSymbolicContext(
             dynamic_sizes=[DimDynamic.STATIC] * e.dim(),
@@ -1609,7 +1636,9 @@ def _automatic_dynamic(e, tx, source, static_shapes) -> SymbolicContext:

     # We preserve the dynamism of inputs. For example, when users call
     # make_fx(torch.cond, tracing_mode="symbolic")(*args), inputs have SymInt sizes.
-    if any(isinstance(s, SymInt) for s in e.size()):
+    from torch.fx.experimental.symbolic_shapes import is_singleton
+
+    if any(isinstance(s, SymInt) and not is_singleton(s) for s in e.size()):
         return StatefulSymbolicContext(
             dynamic_sizes=[
                 DimDynamic.DYNAMIC if isinstance(s, SymInt) else DimDynamic.STATIC
@@ -1734,7 +1763,12 @@ def _automatic_dynamic(e, tx, source, static_shapes) -> SymbolicContext:
         constraint_dims.append(constraint_dim)

         # Now, figure out if the dim is dynamic/duck/static
-        if constraint_dim is not None or marked_dynamic or marked_weak_dynamic:
+        if (
+            constraint_dim is not None
+            or marked_dynamic
+            or marked_weak_dynamic
+            or is_singleton(e.shape[i])
+        ):
             # NB: We could assert static_shapes is False here, but it
             # seems better to allow the user to override symbolic_context in this
             # case
@@ -1768,20 +1802,13 @@ def wrap_to_fake_tensor_and_record(e, tx, *, source: Optional[Source], is_tensor
             e, is_tensor, guard_source=source.guard_source()
         )

-        symbolic_context = None
-        if not e.is_nested:
-            # TODO: We should probably support this for nested tensors too
-            symbolic_context = _automatic_dynamic(e, tx, source, static_shapes)
-
-        if symbolic_context:
-            tx.output.tracing_context.tensor_to_context[e] = symbolic_context
+        symbolic_context = _automatic_dynamic(e, tx, source, static_shapes)

         log.debug(
-            "wrap_to_fake %s %s %s %s",
+            "wrap_to_fake %s %s %s",
             source.name(),
             tuple(e.shape),
-            symbolic_context.dynamic_sizes if symbolic_context is not None else None,
-            symbolic_context.constraint_sizes if symbolic_context is not None else None,
+            symbolic_context,
         )
         fake_e = wrap_fake_exception(
             lambda: tx.fake_mode.from_tensor(
@@ -1790,22 +1817,36 @@ def wrap_to_fake_tensor_and_record(e, tx, *, source: Optional[Source], is_tensor
                 symbolic_context=symbolic_context,
             )
         )
-        if is_tensor and not (static_shapes and source.is_nn_module()):
-            # TODO: just store the whole symbolic_context here
-            tx.output.tracked_fakes.append(
-                TrackedFake(
-                    fake_e,
-                    source,
-                    symbolic_context.constraint_sizes
-                    if symbolic_context is not None
-                    else None,
-                )
-            )
-            tx.output.tracked_fakes_id_to_source[id(e)].append(source)
-            tx.output.tensor_weakref_to_sizes_strides[e] = {
-                "size": fake_e.size(),
-                "stride": fake_e.stride(),
-            }
+
+        # list of (fake_tensor, real_tensor, source, symbolic_context)
+        tracking_info = [(fake_e, e, source, symbolic_context)]
+        if is_traceable_wrapper_subclass(fake_e):
+            attrs, _ = fake_e.__tensor_flatten__()
+            for attr in attrs:
+                fake_inner = getattr(fake_e, attr)
+                inner = getattr(e, attr)
+                tracking_info.append(
+                    (
+                        fake_inner,
+                        inner,
+                        AttrSource(source, attr),
+                        symbolic_context.inner_contexts[attr],
+                    )
+                )
+
+        for fake, real, source, symbolic_context in tracking_info:
+            tx.output.tracing_context.tensor_to_context[real] = symbolic_context
+            tx.output.tensor_weakref_to_sizes_strides[real] = {
+                "size": fake.size(),
+                "stride": fake.stride(),
+            }
+
+            if is_tensor and not (static_shapes and source.is_nn_module()):
+                tx.output.tracked_fakes.append(
+                    TrackedFake(fake, source, symbolic_context)
+                )
+                tx.output.tracked_fakes_id_to_source[id(real)].append(source)
+
         return fake_e
     else:
         return e

@@ -127,6 +127,10 @@ def run_functionalized_fw_and_collect_metadata(
         mutates_metadata = has_metadata_mutation(
             f_arg, arg, check_only_storage_mutation=False
         )
+        if mutates_metadata and is_traceable_wrapper_subclass(arg):
+            raise RuntimeError(
+                "Metadata mutations are currently not allowed on tensor subclasses"
+            )
         mutates_storage_metadata = has_metadata_mutation(
             f_arg, arg, check_only_storage_mutation=True
         )

|
|||||||
import collections
|
import collections
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Any, Callable, Dict, List, NewType, Optional, Set, Union
|
from typing import Any, Callable, Dict, List, NewType, Optional, Set, Tuple, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.utils._pytree as pytree
|
import torch.utils._pytree as pytree
|
||||||
@ -138,9 +138,12 @@ class SubclassCreationMeta:
|
|||||||
# so holding onto this at runtime shouldn't leak memory)
|
# so holding onto this at runtime shouldn't leak memory)
|
||||||
original_subclass: torch.Tensor
|
original_subclass: torch.Tensor
|
||||||
# meta and inner_keys are produced by the subclass's __tensor_flatten__.
|
# meta and inner_keys are produced by the subclass's __tensor_flatten__.
|
||||||
# We need to keep them around to plumb them into __tensor_unflatten__.
|
# We need to keep them around along with outer_size / outer_stride to plumb them
|
||||||
|
# into __tensor_unflatten__.
|
||||||
meta: Any
|
meta: Any
|
||||||
inner_keys: List[Any]
|
inner_keys: List[Any]
|
||||||
|
outer_size: Tuple[int, ...]
|
||||||
|
outer_stride: Tuple[int, ...]
|
||||||
|
|
||||||
def creation_fn(self, all_args, *, is_runtime: bool):
|
def creation_fn(self, all_args, *, is_runtime: bool):
|
||||||
curr_args = all_args[
|
curr_args = all_args[
|
||||||
@ -149,8 +152,13 @@ class SubclassCreationMeta:
|
|||||||
assert len(curr_args) == len(
|
assert len(curr_args) == len(
|
||||||
self.inner_keys
|
self.inner_keys
|
||||||
), f"inner_keys: {str(self.inner_keys)}. len(curr_args): {len(curr_args)}"
|
), f"inner_keys: {str(self.inner_keys)}. len(curr_args): {len(curr_args)}"
|
||||||
|
# NB: Sometimes we have real inner tensors and symbolic metadata.
|
||||||
|
# TODO: Resolve this so we always have matching real / symbolic tensors / metadata.
|
||||||
out = type(self.original_subclass).__tensor_unflatten__( # type: ignore[attr-defined]
|
out = type(self.original_subclass).__tensor_unflatten__( # type: ignore[attr-defined]
|
||||||
dict(zip(self.inner_keys, curr_args)), self.meta
|
dict(zip(self.inner_keys, curr_args)),
|
||||||
|
self.meta,
|
||||||
|
self.outer_size,
|
||||||
|
self.outer_stride,
|
||||||
)
|
)
|
||||||
if not is_runtime:
|
if not is_runtime:
|
||||||
# After wrapping up the inner dense tensors into a subclass, we need to make sure that our new wrapper
|
# After wrapping up the inner dense tensors into a subclass, we need to make sure that our new wrapper
|
||||||
|
|||||||
@@ -54,6 +54,8 @@ def create_subclass_meta(
                 original_subclass=a,
                 meta=meta,
                 inner_keys=attrs,
+                outer_size=a.shape,
+                outer_stride=a.stride(),
             )
         )
     else:

@@ -1537,27 +1537,13 @@ def make_contiguous_strides_for(
     if not shape:
         return ()

-    # TODO: Move this somewhere central?
-    def _is_singleton(s):
-        # check for SingletonSymNode
-        if not isinstance(s, torch.SymInt):
-            return False
-        if s.node.singleton_int() is not None:
-            return True
-
-        # check for SymInt wrapping a SingletonSymNode (fake-ifying causes this)
-        return (
-            s.node.is_symbolic()
-            and s.node.hint is not None
-            and isinstance(s.node.hint, torch.SymInt)
-            and s.node.hint.node.singleton_int() is not None
-        )
+    from torch.fx.experimental.symbolic_shapes import is_singleton

     multiplier = 1
     strides = []
     for l in reversed(shape):
         strides.append(multiplier)
-        multiplier *= l if _is_singleton(l) else sym_max(l, 1)
+        multiplier *= l if is_singleton(l) else sym_max(l, 1)

     result = tuple(reversed(strides))


@@ -186,8 +186,6 @@ class MetaConverter:
         source: Optional[Source] = None,
         symbolic_context: Optional["SymbolicContext"] = None,
     ):
-        from torch._subclasses.fake_tensor import FakeTensor
-
         if source is None:
             from torch._dynamo.source import ConstantSource

@@ -235,10 +233,11 @@ class MetaConverter:
             maybe_suppress = shape_env.suppress_guards

         def sym_sizes_strides_storage_offset(
-            t, src
+            t, src, symbolic_context=symbolic_context
         ) -> Tuple[Tuple[int, ...], Tuple[int, ...], int]:
             if shape_env is not None:
-                if isinstance(t, FakeTensor) and t.fake_mode.shape_env is shape_env:
+                fake_mode = torch._subclasses.fake_tensor.maybe_get_fake_mode(t)
+                if fake_mode is not None and fake_mode.shape_env is shape_env:
                     # Don't reallocate the sizes; the shape envs are the same,
                     # so reuse the old sizes/strides/etc
                     return (t.size(), t.stride(), t.storage_offset())
@@ -246,16 +245,25 @@ class MetaConverter:
                 return shape_env.create_symbolic_sizes_strides_storage_offset(
                     t,
                     src,
-                    # Assume that the set of dims that are dynamic are the same between
-                    # the wrapper tensor and any inner tensors.
-                    # We can revisit this if this assumption does not hold
-                    # for any important subclasses later.
                     symbolic_context=symbolic_context,
                 )
             else:
                 assert symbolic_context is None
                 return (t.size(), t.stride(), t.storage_offset())

+        def empty_create(inner_t, inner_src, symbolic_context=symbolic_context):
+            (
+                inner_sizes,
+                inner_strides,
+                inner_storage_offset,
+            ) = sym_sizes_strides_storage_offset(inner_t, inner_src, symbolic_context)
+            return torch.empty_strided(
+                inner_sizes,
+                inner_strides,
+                dtype=inner_t.dtype,
+                device="meta",
+            )
+
         # see expired-storages
         self.check_expired_count += 1
         if self.check_expired_count >= self.check_expired_frequency:
@@ -443,99 +451,45 @@ class MetaConverter:

             else:
                 is_leaf = safe_is_leaf(t)
-                if not t.is_nested:
-                    # Nested tensor subclasses have special logic for
-                    # creating symbolic size/strides/storage_offset
-                    (
-                        sizes,
-                        strides,
-                        storage_offset,
-                    ) = sym_sizes_strides_storage_offset(t, source)
-
-                def empty_create(inner_t, inner_src):
-                    (
-                        inner_sizes,
-                        inner_strides,
-                        inner_storage_offset,
-                    ) = sym_sizes_strides_storage_offset(inner_t, inner_src)
-                    return torch.empty_strided(
-                        inner_sizes,
-                        inner_strides,
-                        dtype=inner_t.dtype,
-                        device="meta",
-                    )
+
+                from torch.fx.experimental.symbolic_shapes import (
+                    SubclassSymbolicContext,
+                )
+
+                (
+                    sizes,
+                    strides,
+                    storage_offset,
+                ) = sym_sizes_strides_storage_offset(t, source, symbolic_context)

                 # If we have a subclass that desugars into dense tensors,
                 # perform our callback on each inner tensor.
                 if is_traceable_wrapper_subclass(t):
                     # Note: transform_subclass will use __tensor_unflatten__ to generate
-                    # a fresh subclass wrapper, which is why sizes/strides are not passed in
-                    # to the creation function here.
-                    # We assume that if the inner tensors of the subclass are given symbolic sizes,
-                    # their sizes will be used to construct the (symbolic) sizes of the wrapper tensor.
+                    # a fresh subclass wrapper. We assume that if the inner tensors of
+                    # the subclass are given symbolic sizes, their sizes will be used
+                    # to construct the (symbolic) sizes of the wrapper tensor.
                     from torch._dynamo.source import AttrSource

-                    if t.is_nested:
-                        # Avoid circular import
-                        from torch._dynamo.source import (
-                            TensorProperty,
-                            TensorPropertySource,
-                        )
-
-                        # For nested tensors, manually do transform_subclass
-                        # so we can insert some special processing on ctx
-                        attrs, ctx = t.__tensor_flatten__()
-                        transformed_tensors_dict = {}
-                        orig_shape_env = None
-                        for attr in attrs:
-                            inner_t = getattr(t, attr)
-                            if orig_shape_env is None:
-                                orig_shape_env = (
-                                    inner_t.fake_mode.shape_env
-                                    if isinstance(inner_t, FakeTensor)
-                                    else None
-                                )
-                            transformed_tensors_dict[attr] = callback(
-                                lambda: empty_create(
-                                    inner_t, AttrSource(source, attr)
-                                )
-                            )
-                        # We expect JaggedTensor to have a 'ragged_size' in
-                        # its context
-                        assert isinstance(ctx, dict)
-                        assert "ragged_size" in ctx
-                        assert isinstance(t._size[1], torch.SymInt)
-                        if orig_shape_env is shape_env:
-                            # It's already fake and the shape envs line up, reuse the old size
-                            # Do not assert singleton_int; it may already
-                            # be a variable
-                            ctx["ragged_size"] = t._size[1]
-                        else:
-                            assert t._size[1].node.singleton_int() is not None
-                            # Replace the eager ragged size with our freshly
-                            # allocated jagged size that has a source
-                            ctx["ragged_size"] = shape_env.create_symintnode(
-                                shape_env.create_symbol(
-                                    t._size[1],
-                                    TensorPropertySource(
-                                        source, TensorProperty.SIZE, 1
-                                    ),
-                                ),
-                                hint=t._size[1],
-                            )
-                        r = type(t).__tensor_unflatten__(
-                            transformed_tensors_dict, ctx
-                        )
-                    else:
-                        r = transform_subclass(
-                            t,
-                            lambda attr, inner_t: callback(
-                                lambda: empty_create(
-                                    inner_t,
-                                    AttrSource(source, attr),
-                                )
-                            ),
-                        )
+                    assert symbolic_context is None or isinstance(
+                        symbolic_context, SubclassSymbolicContext
+                    )
+                    r = transform_subclass(
+                        t,
+                        lambda attr, inner_t: callback(
+                            lambda: empty_create(
+                                inner_t,
+                                AttrSource(source, attr),
+                                symbolic_context=(
+                                    None
+                                    if symbolic_context is None
+                                    else symbolic_context.inner_contexts[attr]
+                                ),
+                            )
+                        ),
+                        outer_size=sizes,
+                        outer_stride=strides,
+                    )
                 else:
                     r = callback(
                         lambda: torch.empty_strided(

@@ -389,7 +389,7 @@ class AsyncCollectiveTensor(torch.Tensor):
         return self.elem.tolist()

     @staticmethod
-    def __tensor_unflatten__(inner_tensors, meta):
+    def __tensor_unflatten__(inner_tensors, meta, outer_size, outer_stride):
         assert meta is None
         elem = inner_tensors["elem"]
         return AsyncCollectiveTensor(elem)

@@ -255,7 +255,7 @@ class DTensor(torch.Tensor):  # pyre-ignore[13]: pyre is bad at __new__
         return ["_local_tensor"], (self._spec, self.requires_grad)

     @staticmethod
-    def __tensor_unflatten__(inner_tensors, flatten_spec):
+    def __tensor_unflatten__(inner_tensors, flatten_spec, outer_size, outer_stride):
         assert (
             flatten_spec is not None
         ), "Expecting spec to be not None from `__tensor_flatten__` return value!"
@@ -265,10 +265,10 @@ class DTensor(torch.Tensor):  # pyre-ignore[13]: pyre is bad at __new__
             local_tensor,
             spec.mesh,
             spec.placements,
-            shape=spec.tensor_meta.shape,
+            shape=outer_size,
             dtype=spec.tensor_meta.dtype,
             requires_grad=requires_grad,
-            stride=spec.tensor_meta.stride,
+            stride=outer_stride,
         )

     __torch_function__ = torch._C._disabled_torch_function_impl

@ -62,8 +62,9 @@ __all__ = [
|
|||||||
"has_symbolic_sizes_strides", "create_contiguous", "ShapeEnv", "is_concrete_int",
|
"has_symbolic_sizes_strides", "create_contiguous", "ShapeEnv", "is_concrete_int",
|
||||||
"guard_int", "guard_float", "guard_scalar", "canonicalize_bool_expr",
|
"guard_int", "guard_float", "guard_scalar", "canonicalize_bool_expr",
|
||||||
"hint_int", "SYMPY_INTERP", "free_symbols", "is_symbol_binding_fx_node",
|
"hint_int", "SYMPY_INTERP", "free_symbols", "is_symbol_binding_fx_node",
|
||||||
"is_concrete_bool", "SHAPEENV_EVENT_KEY", "CURRENT_NODE_KEY",
|
"is_concrete_bool", "is_singleton", "SHAPEENV_EVENT_KEY", "CURRENT_NODE_KEY",
|
||||||
"has_free_symbols", "sym_eq", "SymbolicContext", "StatelessSymbolicContext", "StatefulSymbolicContext"
|
"has_free_symbols", "sym_eq", "SymbolicContext", "StatelessSymbolicContext",
|
||||||
|
"StatefulSymbolicContext", "SubclassSymbolicContext"
|
||||||
]
|
]
|
||||||
|
|
||||||
# FX node metadata keys for symbolic shape FX graph.
|
# FX node metadata keys for symbolic shape FX graph.
|
||||||
@ -203,6 +204,21 @@ def is_concrete_bool(a: Union[bool, SymBool]):
|
|||||||
|
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
def is_singleton(s):
|
||||||
|
# check for SingletonSymNode
|
||||||
|
if not isinstance(s, torch.SymInt):
|
||||||
|
return False
|
||||||
|
if s.node.singleton_int() is not None:
|
||||||
|
return True
|
||||||
|
|
||||||
|
# check for symbolic variable wrapping a SingletonSymNode (fake-ifying causes this)
|
||||||
|
return (
|
||||||
|
s.node.is_symbolic()
|
||||||
|
and s.node.hint is not None
|
||||||
|
and isinstance(s.node.hint, torch.SymInt)
|
||||||
|
and s.node.hint.node.singleton_int() is not None
|
||||||
|
)
|
||||||
|
|
||||||
def _iterate_exprs(val: Union[SymInt, torch.Tensor]) -> Iterable[sympy.Basic]:
|
def _iterate_exprs(val: Union[SymInt, torch.Tensor]) -> Iterable[sympy.Basic]:
|
||||||
if isinstance(val, SymTypes):
|
if isinstance(val, SymTypes):
|
||||||
# This allow applies to the jagged layout NestedTensor case as
|
# This allow applies to the jagged layout NestedTensor case as
|
||||||
@ -851,6 +867,20 @@ class StatefulSymbolicContext(StatelessSymbolicContext):
|
|||||||
object.__setattr__(self, 'source_to_symint_node_cache', {})
|
object.__setattr__(self, 'source_to_symint_node_cache', {})
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class SubclassSymbolicContext(StatefulSymbolicContext):
|
||||||
|
"""
|
||||||
|
The correct symbolic context for a given inner tensor of a traceable tensor subclass
|
||||||
|
may differ from that of the outer symbolic context. This structure allows for this
|
||||||
|
flexibility, with inner symbolic contexts mapped via attr -> symbolic context.
|
||||||
|
"""
|
||||||
|
inner_contexts: Dict[str, SymbolicContext] = None
|
||||||
|
|
||||||
|
def __post_init__(self):
|
||||||
|
if self.inner_contexts is None:
|
||||||
|
self.inner_contexts = {}
|
||||||
|
|
||||||
|
|
||||||
def is_symbolic(val: Union[int, SymInt, float, SymFloat, bool, SymBool]) -> bool:
|
def is_symbolic(val: Union[int, SymInt, float, SymFloat, bool, SymBool]) -> bool:
|
||||||
if isinstance(val, (int, float, bool)):
|
if isinstance(val, (int, float, bool)):
|
||||||
return False
|
return False
|
||||||
@ -1864,7 +1894,7 @@ class ShapeEnv:
|
|||||||
# FakeTensorMeta for two reasons:
|
# FakeTensorMeta for two reasons:
|
||||||
# 1. this is all the information we need when recording ShapeEnvEvents.
|
# 1. this is all the information we need when recording ShapeEnvEvents.
|
||||||
# 2. it works even if each TrackedFake changes its metadata.
|
# 2. it works even if each TrackedFake changes its metadata.
|
||||||
return TrackedFake(inner_fake, fake.source, fake.constraint_dims) # type: ignore[arg-type]
|
return TrackedFake(inner_fake, fake.source, fake.symbolic_context) # type: ignore[arg-type]
|
||||||
|
|
||||||
return [maybe_transform_fake(fake) for fake in self.tracked_fakes]
|
return [maybe_transform_fake(fake) for fake in self.tracked_fakes]
|
||||||
|
|
||||||
@ -2091,8 +2121,6 @@ class ShapeEnv:
|
|||||||
# The order of checking the guards matters. In this specific example:
|
# The order of checking the guards matters. In this specific example:
|
||||||
# If True branch guard check precedes False branch and for True branch, y.size(0) check precedes x == True,
|
# If True branch guard check precedes False branch and for True branch, y.size(0) check precedes x == True,
|
||||||
# we may have an unnessary shape speciliazation for y.
|
# we may have an unnessary shape speciliazation for y.
|
||||||
assert not ex.is_nested
|
|
||||||
|
|
||||||
def maybe_specialize_sym_int_with_hint(maybe_sym) -> int:
|
def maybe_specialize_sym_int_with_hint(maybe_sym) -> int:
|
||||||
assert isinstance(maybe_sym, (int, torch.SymInt))
|
assert isinstance(maybe_sym, (int, torch.SymInt))
|
||||||
if is_symbolic(maybe_sym):
|
if is_symbolic(maybe_sym):
|
||||||
@ -2174,7 +2202,13 @@ class ShapeEnv:
|
|||||||
}
|
}
|
||||||
# iterate over unbound strides in sorted order
|
# iterate over unbound strides in sorted order
|
||||||
val_list = sorted(
|
val_list = sorted(
|
||||||
[(ex_stride[i], i) for i in range(len(stride)) if stride[i] is None]
|
[(ex_stride[i], i) for i in range(len(stride)) if stride[i] is None],
|
||||||
|
key=lambda tup: (
|
||||||
|
# Order singletons by their coefficients.
|
||||||
|
# 1 here to order singletons after non-singletons.
|
||||||
|
(1, tup[0].node.singleton_coeff(), tup[1]) if is_singleton(tup[0])
|
||||||
|
else (0, *tup)
|
||||||
|
)
|
||||||
)
|
)
|
||||||
for _, i in val_list:
|
for _, i in val_list:
|
||||||
if stride[i] is None and ex_stride[i] in candidates:
|
if stride[i] is None and ex_stride[i] in candidates:
|
||||||
@ -2367,9 +2401,6 @@ class ShapeEnv:
|
|||||||
dynamic_dim = DimDynamic.DYNAMIC
|
dynamic_dim = DimDynamic.DYNAMIC
|
||||||
|
|
||||||
if dynamic_dim is DimDynamic.STATIC:
|
if dynamic_dim is DimDynamic.STATIC:
|
||||||
# We don't expect to ever reach here even the user specifies
|
|
||||||
# dynamic=False, because automatic_dynamic skipped for
|
|
||||||
# nested tensors.
|
|
||||||
return sympy.Integer(val)
|
return sympy.Integer(val)
|
||||||
|
|
||||||
elif dynamic_dim is DimDynamic.DUCK:
|
elif dynamic_dim is DimDynamic.DUCK:
|
||||||
@@ -2492,11 +2523,7 @@ class ShapeEnv:
         sources,
         source_ref=lambda n: n.name(),
         *,
-        # An input is either a SymInt (in which case you directly have
-        # DimConstraint) or a Tensor (in which case you have a
-        # DimList[DimConstraint]). Whenever Optional is accepted, that
-        # just means there are no constraints
-        constraint_inputs: Optional[InputList[Union[DimConstraint, Optional[DimList[DimConstraint]]]]] = None,
+        input_contexts: Optional[DimList[SymbolicContext]] = None,
         equalities_inputs: Optional[Set[Tuple[Source, Source]]] = None,
         _simplified=False,
         # Indicates if we should produce guards for known static values.
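As a hedged usage sketch of the new signature (class and enum names are those used in this diff; the import path is assumed), callers of `produce_guards()` now pass one symbolic context per placeholder instead of a per-dimension constraint list:

```python
from torch.fx.experimental.symbolic_shapes import DimDynamic, StatelessSymbolicContext

# One context for a 2-D placeholder: duck-size dim 0, keep dim 1 dynamic,
# with no range constraint on either dimension.
ctx = StatelessSymbolicContext(
    dynamic_sizes=[DimDynamic.DUCK, DimDynamic.DYNAMIC],
    constraint_sizes=[None, None],
)

# shape_env.produce_guards(placeholders, sources, input_contexts=[ctx, ...])
# replaces the old constraint_inputs=[[None, None], ...] calling convention.
```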
@@ -2521,22 +2548,28 @@ class ShapeEnv:
         assert len(placeholders) == len(sources)
         Tensorlike = (torch.Tensor, FakeTensorMeta)

+        def _create_no_constraints_context(t):
+            return StatelessSymbolicContext(
+                # Ignored; only the constraints part is relevant below.
+                dynamic_sizes=[DimDynamic.DYNAMIC] * t.dim(),
+                constraint_sizes=[None] * t.dim()
+            )
+
         # Expand optional inputs, or verify invariants are upheld
-        if constraint_inputs is None:
-            constraint_inputs = [
-                [None] * t.dim() if isinstance(t, Tensorlike) else None for t in placeholders
+        if input_contexts is None:
+            input_contexts = [
+                _create_no_constraints_context(t) if isinstance(t, Tensorlike)
+                else None for t in placeholders
             ]
         else:
-            assert len(constraint_inputs) == len(placeholders)
-            for i, (t, constraint) in enumerate(zip(placeholders, constraint_inputs)):
+            assert len(input_contexts) == len(placeholders)
+            for i, (t, context) in enumerate(zip(placeholders, input_contexts)):
                 if isinstance(t, Tensorlike):
-                    if constraint is None:
-                        constraint_inputs[i] = [None] * t.dim()
-                    else:
-                        assert len(constraint) == t.dim()
+                    if context is None:
+                        input_contexts[i] = _create_no_constraints_context(t)
                 else:
                     assert isinstance(t, (SymInt, int))
-                    assert not isinstance(constraint, list)
+                    assert not isinstance(context, list)

         # It took a lot of sweat to figure out the algorithm here. Let's
         # explain how it works.
@@ -2682,10 +2715,6 @@ class ShapeEnv:
                         # expect to have to compile in this case anyway
                         if i not in (0, 1):
                             constraint_violated = True
-                    else:
-                        # TODO: Maybe non-strict constraint shouldn't error
-                        # here? Check what happens in practice
-                        constraint_violated = True
                     if constraint_violated:
                         def hint(s):
                             sexpr = ShapeGuardPrinter(symbol_to_source, source_ref, self.var_to_sources).doprint(s)
@@ -2723,7 +2752,7 @@ class ShapeEnv:
                     )
                     record_constraint_violation(constraint.warn_only, self.debug_name(source), msg)

-        for t, source, constraint in zip(placeholders, sources, constraint_inputs):
+        for t, source, context in zip(placeholders, sources, input_contexts):
             if isinstance(source, str):
                 from torch._dynamo.source import LocalSource
                 source = LocalSource(source)
@@ -2734,32 +2763,35 @@ class ShapeEnv:
                 track_symint(source, t)
                 continue
             assert isinstance(t, Tensorlike)
-            sources_and_tensors = [(source, t)]
             if is_traceable_wrapper_subclass(t):
-                # If our placeholder is a tensor subclass, then the "true" symints
-                # come from the subclass's inner tensors.
-                attrs, _ = t.__tensor_flatten__()
                 from torch._dynamo.source import AttrSource
-                inner_sources_and_tensors = [(AttrSource(source, attr), getattr(t, attr)) for attr in attrs]
-                if t.is_nested:
-                    # For NestedTensors we need to track BOTH symints on the outer
-                    # tensor and tensor because we'd like to guard on the ragged
-                    # size but the symint representing ragged size is not in terms
-                    # of the symints on the inner tensors.
-                    sources_and_tensors.extend(inner_sources_and_tensors)
-                else:
-                    # For other tensor subclasses, only track the symints from
-                    # the inner tensors
-                    sources_and_tensors = inner_sources_and_tensors

-            for src, curr_t in sources_and_tensors:
+                assert isinstance(context, SubclassSymbolicContext)
+
+                # For subclasses, we need to track symints on BOTH the outer
+                # and inner tensors.
+                sources_tensors_constraints = [
+                    (source, t, context.constraint_sizes)
+                ]
+                attrs, _ = t.__tensor_flatten__()
+                for attr in attrs:
+                    inner_t = getattr(t, attr)
+                    inner_context = context.inner_contexts[attr]
+                    sources_tensors_constraints.append((
+                        AttrSource(source, attr),
+                        inner_t,
+                        inner_context.constraint_sizes
+                    ))
+            else:
+                sources_tensors_constraints = [(source, t, context.constraint_sizes)]
+
+            for src, curr_t, constraint in sources_tensors_constraints:
                 for i, ss in enumerate(curr_t.size()):
                     property_source = TensorPropertySource(src, TensorProperty.SIZE, i)
                     track_symint(property_source, ss, constraint[i])
-                if not t.is_nested:
-                    for i, ss in enumerate(curr_t.stride()):
-                        track_symint(TensorPropertySource(src, TensorProperty.STRIDE, i), ss)
-                    track_symint(TensorPropertySource(src, TensorProperty.STORAGE_OFFSET), curr_t.storage_offset())
+                for i, ss in enumerate(curr_t.stride()):
+                    track_symint(TensorPropertySource(src, TensorProperty.STRIDE, i), ss)
+                track_symint(TensorPropertySource(src, TensorProperty.STORAGE_OFFSET), curr_t.storage_offset())

         # 1. Every input must equal the final simplified symbolic expression
         # stored on the placeholder. Given a placeholder (s0*2, s1),
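For concreteness, a small illustration of what the flatten-based walk above iterates over, using the `TwoTensor` test subclass that also appears later in this diff (its module path is assumed here):

```python
import torch
from torch.testing._internal.two_tensor import TwoTensor

t = TwoTensor(torch.randn(3, 4), torch.randn(3, 4))
attrs, ctx = t.__tensor_flatten__()
print(attrs, ctx)  # ['a', 'b'] None

# Each attr names an inner tensor whose sizes/strides/storage_offset get tracked
# under an AttrSource-qualified source, alongside the outer tensor's own symints.
inner_tensors = {attr: getattr(t, attr) for attr in attrs}
print({attr: tuple(x.shape) for attr, x in inner_tensors.items()})
```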
@@ -3246,6 +3278,11 @@ class ShapeEnv:
         """
         result_expr = safe_expand(expr).xreplace(self.var_to_val)
         if not result_expr.is_number:
+
+            from torch.utils._sympy.singleton_int import SingletonInt
+
+            if isinstance(result_expr, SingletonInt):
+                return None
             r = self._maybe_evaluate_static(result_expr, compute_hint=True)
             if r is not None:
                 return r
@@ -663,7 +663,7 @@ def bisect(shape_env):
             shape_env.produce_guards(
                 [new_with_shape_env(shape_env, a.fake) for a in tracked_fakes],
                 [a.source for a in tracked_fakes],
-                constraint_inputs=[a.constraint_dims for a in tracked_fakes],
+                input_contexts=[a.symbolic_context for a in tracked_fakes],
             )
             return None
         except ValidationException as e:
@@ -130,6 +130,10 @@ class NestedTensor(torch.Tensor):
             )
         )

+        # collapsed ragged dim must always be dynamic
+        torch._dynamo.mark_dynamic(self, self._ragged_idx)
+        torch._dynamo.mark_dynamic(self._values, self._ragged_idx - 1)
+
     def values(self):
         return self._values

@@ -164,7 +168,6 @@ class NestedTensor(torch.Tensor):
     def __tensor_flatten__(self):
         ctx = {
             "requires_grad": self.requires_grad,
-            "ragged_size": self._size[self._ragged_idx],
             "max_seqlen": self._max_seqlen,
             "min_seqlen": self._min_seqlen,
             "ragged_idx": self._ragged_idx,
@@ -175,37 +178,13 @@ class NestedTensor(torch.Tensor):
         return inner_tensors, ctx

     @staticmethod
-    def __tensor_unflatten__(inner_tensors: Dict, meta):
+    def __tensor_unflatten__(inner_tensors: Dict, meta, outer_size, outer_stride):
         assert len(inner_tensors) >= 2 and len(inner_tensors) <= 3
         values = inner_tensors["_values"]
         offsets = inner_tensors["_offsets"]
         lengths = inner_tensors.get("_lengths", None)
+        ragged_idx = meta["ragged_idx"]

-        # NOTE [ Storing symbolic values as plain attributes on subclasses ]
-        #
-        # When a subclass like NestedTensor stores a "size-like" value (which
-        # can either be Symintified or not) into meta, it's responsible for:
-        #
-        # (1) Propagating that symint during torch dispatch when performing
-        # operations, i.e. torch dispatch plays the role of a meta kernel.
-        #
-        # (2) Facilitating the behavior around symbolic -> non-symbolic
-        # conversions and vice versa, see below.
-        #
-        # [ non-symbolic -> symbolic (fakification in meta_utils) ]
-        #
-        # __tensor_unflatten__ is passed symbolic dense tensors and meta from
-        # non-symbolic subclasses. In this case, the subclass is responsible for
-        # intercepting meta["ragged_size"] for example and replacing it with the
-        # symintified version.
-        #
-        # [ symbolic -> non-symbolic ]
-        #
-        # __tensor_unflatten__ is passed non-symbolic dense tensors and with
-        # meta extracted from fake subclasses. In this case the subclass gets
-        # propagated the meta["ragged_size"] which is still a symint and the
-        # subclass is responsible for making sure that the symint doesn't leak.
-        #
         # Note that we cannot simply check if is_fake(values) because
         # during aot autograd, FunctionalTensors are not fake but hold
         # symbolic sizes.
@@ -213,7 +192,8 @@ class NestedTensor(torch.Tensor):
         if has_free_symbols(ragged_source) or has_free_symbols(values):
             # Associate offsets or lengths (possibly fake, possibly functionalized)
             # with the ragged_size.
-            _tensor_symint_registry[ragged_source] = meta["ragged_size"]
+            ragged_size = outer_size[ragged_idx]
+            _tensor_symint_registry[ragged_source] = ragged_size

         return NestedTensor(
             values,
@@ -222,7 +202,7 @@ class NestedTensor(torch.Tensor):
             requires_grad=meta["requires_grad"],
             _max_seqlen=meta["max_seqlen"],
             _min_seqlen=meta["min_seqlen"],
-            _ragged_idx=meta["ragged_idx"],
+            _ragged_idx=ragged_idx,
         )

     @classmethod
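A hedged end-to-end sketch of the constructor change above, assuming the jagged layout factory available in the tree this diff targets: building a jagged `NestedTensor` goes through the patched `__init__`, so its ragged dimension is marked dynamic automatically.

```python
import torch

# Three rows of lengths 2, 3, 4 with feature dim 5 -> one ragged (jagged) dimension.
nt = torch.nested.nested_tensor(
    [torch.randn(2, 5), torch.randn(3, 5), torch.randn(4, 5)],
    layout=torch.jagged,
)

# The inner buffers exposed via __tensor_flatten__(): packed values plus offsets.
print(nt.values().shape)  # (sum of row lengths, 5)
print(nt.offsets())       # per-row boundaries into the packed values
```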
@@ -222,7 +222,7 @@ class SparseSemiStructuredTensor(torch.Tensor):
         return ['sparse_tensor_cutlass', 'meta_tensor_cutlass'], (self.original_shape, self.transposed)

     @staticmethod
-    def __tensor_unflatten__(inner_tensors, meta):
+    def __tensor_unflatten__(inner_tensors, meta, outer_size, outer_stride):
         original_shape, transposed = meta

         if len(inner_tensors) == 2:
@@ -42,7 +42,7 @@ class TwoTensor(torch.Tensor):
         return ["a", "b"], None

     @staticmethod
-    def __tensor_unflatten__(inner_tensors, meta):
+    def __tensor_unflatten__(inner_tensors, meta, outer_size, outer_stride):
         assert meta is None
         a, b = inner_tensors["a"], inner_tensors["b"]
         return TwoTensor(a, b)
@@ -142,13 +142,34 @@ def is_traceable_wrapper_subclass(t):
     is 'traceable' with torch.compile.
     In order for a tensor subclass to support TorchDispatchMode-style tracing in PT2,
     It must implement two magic methods: __tensor_flatten__ and __tensor_unflatten__.
-    It is also expected to obey some restrictions around traceability and aliasing
-    (TODO: add clear documentation around this.)
+    It is also expected to obey some restrictions around traceability and aliasing:
+        * The subclass's __torch_dispatch__() implementation should desugar into pytorch
+            dispatcher operations that can be traced into a graph.
+        * The subclass should use return_and_correct_aliasing(). This is needed today to make
+            sure that torch.compile does the right thing in a few cases around input mutation
+            and output aliasing.
+
+    Expected magic method signatures:
+        attrs, ctx = t.__tensor_flatten__()
+            attrs: list of attribute name strings for inner tensors
+            ctx: dict containing any other subclass-specific metadata needed for unflattening
+
+        t = MySubClass.__tensor_unflatten__(inner_tensors, ctx, outer_size, outer_stride)
+            inner_tensors: dict mapping attribute name -> tensor for each inner tensor
+            ctx: dict with subclass metadata in the form that __tensor_flatten__() produces
+            outer_size: expected (possibly symbolic) size that the returned subclass
+                instance should have. Note that this arg is useful for certain subclasses
+                that require the shape info to be constructed. In most cases, this arg can be
+                safely ignored.
+            outer_stride: expected (possibly symbolic) stride that the returned subclass
+                instance should have. Note that this arg is useful for certain subclasses
+                that require the stride info to be constructed. In most cases, this arg can be
+                safely ignored.
     """
     is_subclass = isinstance(t, torch.Tensor) and type(t) != torch.Tensor
     return is_subclass and hasattr(t, "__tensor_flatten__") and hasattr(t, "__tensor_unflatten__")

-def transform_subclass(t, callback):
+def transform_subclass(t, callback, outer_size=None, outer_stride=None):
     """
     Given a traceable, wrapper tensor subclass ``t`` that implements
     ``__torch_dispatch__`` and holds some inner tensors,
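A minimal sketch of a subclass satisfying the contract described in the docstring above, including the new four-argument `__tensor_unflatten__`. `WrapperTensor` is hypothetical (not part of this diff); only `_make_wrapper_subclass`, `return_and_correct_aliasing`, and the pytree helpers are assumed from PyTorch itself.

```python
import torch
import torch.utils._pytree as pytree
from torch.utils._python_dispatch import return_and_correct_aliasing


class WrapperTensor(torch.Tensor):
    """Wraps a single inner tensor; every op is forwarded to it."""

    @staticmethod
    def __new__(cls, inner):
        # Outer tensor mirrors the inner tensor's metadata but holds no storage.
        return torch.Tensor._make_wrapper_subclass(
            cls,
            inner.shape,
            strides=inner.stride(),
            storage_offset=inner.storage_offset(),
            dtype=inner.dtype,
            device=inner.device,
            requires_grad=inner.requires_grad,
        )

    def __init__(self, inner):
        self.inner = inner

    def __tensor_flatten__(self):
        # attrs: inner tensor attribute names; ctx: extra metadata (none needed here)
        return ["inner"], None

    @staticmethod
    def __tensor_unflatten__(inner_tensors, ctx, outer_size, outer_stride):
        # outer_size / outer_stride may be symbolic under PT2. This subclass's shape
        # is fully determined by its inner tensor, so they can be safely ignored.
        assert ctx is None
        return WrapperTensor(inner_tensors["inner"])

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        unwrap = lambda x: x.inner if isinstance(x, WrapperTensor) else x
        wrap = lambda x: WrapperTensor(x) if isinstance(x, torch.Tensor) else x
        out = func(*pytree.tree_map(unwrap, args), **pytree.tree_map(unwrap, kwargs))
        return return_and_correct_aliasing(func, args, kwargs, pytree.tree_map(wrap, out))


w = WrapperTensor(torch.randn(2, 3))
y = w + 1
print(type(y), y.inner.shape)  # ops dispatch on the inner tensor and get rewrapped
```

The `return_and_correct_aliasing()` call covers the aliasing restriction listed in the docstring; without it, view and in-place ops can misbehave under torch.compile.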
@@ -162,11 +183,28 @@ def transform_subclass(t, callback):
     gets the same (autograd, and aliasing) metadata as the original tensor.
     This is generally handled in other subsystems like AOTAutograd.
     """
+    outer_size = outer_size if outer_size is not None else t.size()
+    outer_stride = outer_stride if outer_stride is not None else t.stride()
+
     attrs, ctx = t.__tensor_flatten__()
     transformed_tensors_dict = {}
     for attr in attrs:
         transformed_tensors_dict[attr] = callback(attr, getattr(t, attr))
-    return type(t).__tensor_unflatten__(transformed_tensors_dict, ctx)
+    sub = type(t).__tensor_unflatten__(
+        transformed_tensors_dict, ctx, outer_size, outer_stride
+    )
+
+    # NB: Purposefully guard here to simplify the inner / outer symbols.
+    # Using sym_eq() for symbolic comparison can result in an expression that's too
+    # difficult to guard on, so we use == here.
+    assert sub.shape == outer_size, \
+        f"Expected return value from {type(t)}__tensor_unflatten__() to have " \
+        f"shape equal to {outer_size}, but got: {sub.shape}"
+    assert sub.stride() == outer_stride, \
+        f"Expected return value from {type(t)}__tensor_unflatten__() to have " \
+        f"stride equal to {outer_stride}, but got: {sub.stride()}"
+
+    return sub

 def _correct_storage_aliasing(func, schema_info, args, outs):
     """
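A hedged usage sketch of `transform_subclass()` after this change, again using the `TwoTensor` test subclass touched above (module path assumed): the rebuilt instance is checked against the original size and stride by the new asserts.

```python
import torch
from torch.testing._internal.two_tensor import TwoTensor
from torch.utils._python_dispatch import transform_subclass

t = TwoTensor(torch.randn(3, 4), torch.randn(3, 4))

# Rebuild t with each inner tensor cloned; outer_size / outer_stride default to
# t.size() / t.stride(), so the asserts verify the round-trip preserved both.
t2 = transform_subclass(t, lambda attr, inner: inner.clone())
assert t2.shape == t.shape and t2.stride() == t.stride()
```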