Expand dynamic dims support for traceable subclasses (#114311)

Continuation of #112185, following the design in this [doc](https://docs.google.com/document/d/1ipSxcTzEMMOAPvxP-YJlD5JBZZmIGgh8Q34ixtOUCRo).

Summary:
* Introduce `SubclassSymbolicContext`, containing separate dynamic dim / constraint policies for the outer and inner tensors
    * Expand the automatic dynamic algorithm to recurse into inner tensors and produce one of these for a subclass instance
    * Maintain legacy behavior for subclasses by recursively calling `mark_dynamic()` on inner tensors *of the same dim as the outer tensor* when `mark_dynamic(outer, ...)` is called (a usage sketch follows this list)
    * Addresses this: 6a86cf00ad/torch/_dynamo/variables/builder.py (L1750)
* Add `outer_size` and `outer_stride` arguments to `__tensor_unflatten__()` so the subclass can see which symbols were allocated for the outer size / stride (the returned tensor is expected to compare equal to these outer symbols; see the example subclass after this list)
    * Signatures now:
    ```python
    # attrs is a list of inner tensor attributes on x; inner_tensor = getattr(x, attr)
    # ctx is anything useful for rebuilding the class we want to guard on
    attrs, ctx = x.__tensor_flatten__()
    ...
    # inner_tensors is a dict of {attr -> tensor}
    # ctx is taken unmodified from flattening and (eventually) guarded on
    # outer_size is the expected size of the output; possibly symbolic
    # outer_stride is the expected strides of the output; possibly symbolic
    y = MySubclass.__tensor_unflatten__(inner_tensors, ctx, outer_size, outer_stride)

    # at the __tensor_unflatten__() call-site in PT2, we assert y.shape == outer_size and y.stride() == outer_stride
    # the assert simplifies symbols when there are relationships between outer and inner symbols
    ```
    * Size info is needed at least for `NestedTensor`; stride info is needed at least for `DTensor`
    * Punting on `outer_storage_offset` because storage_offset handling is horribly broken in PT2 right now
* ~~Add new `__tensor_mark_dynamic__()` to allow overriding the behavior of mark_dynamic on a per-subclass basis~~ (deferred to future work)
* ~~Add guards for tensor subclasses by calling `__tensor_flatten__()` in the guard to test equality on `ctx`~~
    * Now handled in #114469
* Next PR: add TENSOR_MATCH guards on inner tensors
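
Below are two illustrative sketches referenced in the list above. They are not part of this PR's diff; `WrapperTensor` is a hypothetical minimal subclass written against the updated four-argument `__tensor_unflatten__()` signature.

```python
import torch

class WrapperTensor(torch.Tensor):  # hypothetical example subclass
    @staticmethod
    def __new__(cls, data: torch.Tensor):
        # Mirror the inner tensor's metadata onto the outer wrapper.
        return torch.Tensor._make_wrapper_subclass(
            cls,
            data.size(),
            strides=data.stride(),
            dtype=data.dtype,
            device=data.device,
            requires_grad=data.requires_grad,
        )

    def __init__(self, data: torch.Tensor):
        self._data = data

    def __tensor_flatten__(self):
        # attrs: names of inner tensor attributes; ctx: extra metadata to guard on
        return ["_data"], None

    @staticmethod
    def __tensor_unflatten__(inner_tensors, ctx, outer_size, outer_stride):
        # outer_size / outer_stride are the (possibly symbolic) size / stride the
        # returned wrapper is expected to have; the PT2 call-site asserts the match.
        # This wrapper's shape is fully determined by its inner tensor, so the
        # extra arguments can safely be ignored here.
        return WrapperTensor(inner_tensors["_data"])

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        unwrapped = [a._data if isinstance(a, WrapperTensor) else a for a in args]
        return WrapperTensor(func(*unwrapped, **kwargs))
```

A sketch of the default `mark_dynamic()` mirroring behavior for subclasses, assuming the toy `WrapperTensor` above traces cleanly under PT2 (this mirrors the `ScaledTensor` test added in this PR):

```python
w = WrapperTensor(torch.randn(2, 4))
# With this PR, mark_dynamic() on the outer wrapper is also applied to w._data,
# since _data has the same number of dims as the outer tensor.
torch._dynamo.mark_dynamic(w, 0)
torch._dynamo.mark_dynamic(w, 1)

compiled = torch.compile(lambda t: t.sin(), fullgraph=True, dynamic=False)
compiled(w)
compiled(WrapperTensor(torch.randn(3, 5)))  # expected: no recompile for the marked dims
```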

Pull Request resolved: https://github.com/pytorch/pytorch/pull/114311
Approved by: https://github.com/ezyang, https://github.com/drisspg, https://github.com/voznesenskym, https://github.com/bdhirsh
Author: Joel Schlosser (2023-12-05 12:50:59 -05:00)
Committed by: PyTorch MergeBot
Parent: 259a99669d
Commit: 22704426c3
21 changed files with 453 additions and 258 deletions

View File

@ -969,6 +969,7 @@ coverage_ignore_functions = [
"is_concrete_int", "is_concrete_int",
"is_contiguous", "is_contiguous",
"is_non_overlapping_and_dense_indicator", "is_non_overlapping_and_dense_indicator",
"is_singleton",
"is_symbol_binding_fx_node", "is_symbol_binding_fx_node",
"is_symbolic", "is_symbolic",
"parallel_and", "parallel_and",
@ -2869,6 +2870,7 @@ coverage_ignore_classes = [
"SymbolicContext", "SymbolicContext",
"StatelessSymbolicContext", "StatelessSymbolicContext",
"StatefulSymbolicContext", "StatefulSymbolicContext",
"SubclassSymbolicContext",
# torch.fx.experimental.unification.match # torch.fx.experimental.unification.match
"Dispatcher", "Dispatcher",
"VarDispatcher", "VarDispatcher",

View File

@ -696,7 +696,7 @@ class AutogradFunctionTests(torch._dynamo.test_case.TestCase):
) )
@staticmethod @staticmethod
def __tensor_unflatten__(tensors, metadatas): def __tensor_unflatten__(tensors, metadatas, outer_size, outer_stride):
return FooTensor(tensors["_data"], metadatas[0], metadatas[1]) return FooTensor(tensors["_data"], metadatas[0], metadatas[1])
@classmethod @classmethod

View File

@ -63,6 +63,52 @@ class SigmoidToExpSubclass(torch.Tensor):
return super().__torch_function__(func, types, args, kwargs) return super().__torch_function__(func, types, args, kwargs)
# Wrapper subclass with two inner tensors: data and scale
# data has same shape as outer, and scale has single dim size
class ScaledTensor(torch.Tensor):
def __new__(
cls,
data: torch.Tensor,
scale: torch.Tensor,
):
return torch.Tensor._make_wrapper_subclass(
cls,
data.size(),
strides=data.stride(),
storage_offset=data.storage_offset(),
dtype=data.dtype,
layout=data.layout,
requires_grad=data.requires_grad,
device=data.device,
)
def __init__(self, data: torch.Tensor, scale: torch.Tensor):
self._data = data
self._scale = scale
def __tensor_flatten__(self):
ctx = {}
return ["_data", "_scale"], ctx
@staticmethod
def __tensor_unflatten__(inner_tensors, metadata, outer_size, outer_stride):
assert len(inner_tensors) == 2
return ScaledTensor(inner_tensors["_data"], inner_tensors["_scale"])
@classmethod
def __torch_dispatch__(cls, func, types, args, kwargs=None):
scaled_tensor = args[0]
out = func(scaled_tensor._data, *args[1:], **kwargs)
return ScaledTensor(out, scaled_tensor._scale)
def __repr__(self):
return f"{self._data.__repr__()}\n{self._scale.__repr__()}"
def func(a):
return a.sin()
class EagerRecordGraphAndInputs: class EagerRecordGraphAndInputs:
def __init__(self): def __init__(self):
self.graphs = [] self.graphs = []
@ -77,6 +123,21 @@ class EagerRecordGraphAndInputs:
GLOBAL_TEST_SUBCLASSES = {MockSubclass, DummyNDim, SigmoidToExpSubclass} GLOBAL_TEST_SUBCLASSES = {MockSubclass, DummyNDim, SigmoidToExpSubclass}
# Returns True if the function recompiles between inputs1 and inputs2 with the
# specified dynamic setting.
def _recompiles_for_inputs(fn, inputs1, inputs2, dynamic=True):
compile_count = [0]
def counter(gm, example_inputs):
compile_count[0] += 1
return gm
compiled_f = torch.compile(fn, fullgraph=True, backend=counter, dynamic=dynamic)
compiled_f(*inputs1)
compiled_f(*inputs2)
return compile_count[0] > 1
class SubclassTests(torch._dynamo.test_case.TestCase): class SubclassTests(torch._dynamo.test_case.TestCase):
@classmethod @classmethod
def setUpClass(cls): def setUpClass(cls):
@ -608,7 +669,7 @@ class GraphModule(torch.nn.Module):
return ["inner_elem"], None return ["inner_elem"], None
@staticmethod @staticmethod
def __tensor_unflatten__(inner_tensors, _): def __tensor_unflatten__(inner_tensors, _, outer_size, outer_stride):
return DoubleSizeMaybeAddGeThreeTensor(inner_tensors["inner_elem"]) return DoubleSizeMaybeAddGeThreeTensor(inner_tensors["inner_elem"])
def __repr__(self): def __repr__(self):
@ -688,6 +749,42 @@ class GraphModule(torch.nn.Module):
self.assertEqual(lower_bound_str, expected_lower_bound) self.assertEqual(lower_bound_str, expected_lower_bound)
self.assertEqual(upper_bound_str, expected_upper_bound) self.assertEqual(upper_bound_str, expected_upper_bound)
def test_wrapper_subclass_with_same_sized_inner_tensor(self):
# shouldn't recompile for different sizes when dynamic=True
sub1 = ScaledTensor(torch.randn(2, 4), torch.randn(6))
sub2 = ScaledTensor(torch.randn(3, 5), torch.randn(7))
self.assertFalse(_recompiles_for_inputs(func, (sub1,), (sub2,), dynamic=True))
# should recompile for different data size when dynamic=False
sub1 = ScaledTensor(torch.randn(2, 4), torch.randn(6))
sub2 = ScaledTensor(torch.randn(3, 5), torch.randn(6))
self.assertTrue(_recompiles_for_inputs(func, (sub1,), (sub2,), dynamic=False))
# avoid recompile using manual mark_dynamic() for different data size
sub1 = ScaledTensor(torch.randn(2, 4), torch.randn(6))
# NB: mark_dynamic() on outer tensor should translate to inner tensors of the same size
torch._dynamo.mark_dynamic(sub1, 0)
torch._dynamo.mark_dynamic(sub1, 1)
sub2 = ScaledTensor(torch.randn(3, 5), torch.randn(6))
self.assertFalse(_recompiles_for_inputs(func, (sub1,), (sub2,), dynamic=False))
# Broken because we don't guard properly on inner tensors yet.
# TODO: Enable this when we do
@unittest.expectedFailure
def test_wrapper_subclass_with_differently_sized_inner_tensor(self):
# should recompile for different scale size when dynamic=False
sub1 = ScaledTensor(torch.randn(2, 4), torch.randn(3))
sub2 = ScaledTensor(torch.randn(2, 4), torch.randn(5))
self.assertTrue(_recompiles_for_inputs(func, (sub1,), (sub2,), dynamic=False))
# still recompiles using manual mark_dynamic() on outer for different scale size
sub1 = ScaledTensor(torch.randn(2, 4), torch.randn(3))
# NB: mark_dynamic() on outer tensor doesn't translate to inner tensors of different size
torch._dynamo.mark_dynamic(sub1, 0)
torch._dynamo.mark_dynamic(sub1, 1)
sub2 = ScaledTensor(torch.randn(2, 4), torch.randn(5))
self.assertTrue(_recompiles_for_inputs(func, (sub1,), (sub2,), dynamic=False))
def test_recompile_with_symbool_inputs(self): def test_recompile_with_symbool_inputs(self):
def f(pred: bool): def f(pred: bool):
if pred: if pred:
@ -832,18 +929,9 @@ class TestNestedTensor(torch._dynamo.test_case.TestCase):
) )
return jagged_from_tensor_and_lengths(values_tensor, starts, lengths) return jagged_from_tensor_and_lengths(values_tensor, starts, lengths)
def _check_recompiles(self, fn, inputs1, inputs2, recompiles): def _check_recompiles(self, fn, inputs1, inputs2, expected_recompiles):
compile_count = [0] actual_recompiles = _recompiles_for_inputs(fn, inputs1, inputs2)
self.assertEqual(actual_recompiles, expected_recompiles)
def counter(gm, example_inputs):
compile_count[0] += 1
return gm
compiled_f = torch.compile(fn, fullgraph=True, backend=counter, dynamic=True)
out = compiled_f(*inputs1)
self.assertEqual(compile_count[0], 1)
out = compiled_f(*inputs2)
self.assertEqual(compile_count[0], 2 if recompiles else 1)
def test_unary_does_not_recompile(self): def test_unary_does_not_recompile(self):
nt1, _ = self._get_jagged_tensor(((2, 3, 4), 3), None) nt1, _ = self._get_jagged_tensor(((2, 3, 4), 3), None)
@ -857,9 +945,11 @@ class TestNestedTensor(torch._dynamo.test_case.TestCase):
else: else:
return nt1.sin() return nt1.sin()
# Basic binary # NB: If we have shape e.g. (3, j0, 3), duck sizing will give us (s0, s1, s0).
nt1, offsets = self._get_jagged_tensor(((2, 3, 4), 3), None) # This causes a recompile later on when it realizes the batch and last dim
nt2, _ = self._get_jagged_tensor(((2, 3, 4), 3), offsets) # should not always be equal. To avoid that, we use (3, j0, 5) here.
nt1, offsets = self._get_jagged_tensor(((2, 3, 4), 5), None)
nt2, _ = self._get_jagged_tensor(((2, 3, 4), 5), offsets)
nt3, offsets = self._get_jagged_tensor(((3, 4, 5), 4), None) nt3, offsets = self._get_jagged_tensor(((3, 4, 5), 4), None)
nt4, _ = self._get_jagged_tensor(((3, 4, 5), 4), offsets) nt4, _ = self._get_jagged_tensor(((3, 4, 5), 4), offsets)
self._check_recompiles(binary, (nt1, nt2), (nt3, nt4), False) self._check_recompiles(binary, (nt1, nt2), (nt3, nt4), False)
@ -872,9 +962,9 @@ class TestNestedTensor(torch._dynamo.test_case.TestCase):
return nt1.sin() return nt1.sin()
# Binary recompiles because singleton ints no longer match # Binary recompiles because singleton ints no longer match
nt1, offsets = self._get_jagged_tensor(((2, 3, 4), 3), None) nt1, offsets = self._get_jagged_tensor(((2, 3, 4), 5), None)
nt2, _ = self._get_jagged_tensor(((2, 3, 4), 3), offsets) nt2, _ = self._get_jagged_tensor(((2, 3, 4), 5), offsets)
nt3, _ = self._get_jagged_tensor(((2, 3, 4), 3), None) nt3, _ = self._get_jagged_tensor(((2, 3, 4), 5), None)
self._check_recompiles(binary, (nt1, nt2), (nt1, nt3), True) self._check_recompiles(binary, (nt1, nt2), (nt1, nt3), True)
# TODO: cannot parametrize this test class with device for some reason # TODO: cannot parametrize this test class with device for some reason
@ -909,7 +999,10 @@ class TestNestedTensor(torch._dynamo.test_case.TestCase):
self._test_autograd("inductor") self._test_autograd("inductor")
def test_unbind(self): def test_unbind(self):
nt, _ = self._get_jagged_tensor(((2, 3, 4), 3), None) # NB: If we have shape e.g. (3, j0, 3), duck sizing will give us (s0, s1, s0).
# This causes a recompile later on when it realizes the batch and last dim
# should not always be equal. To avoid that, we use (3, j0, 5) here.
nt, _ = self._get_jagged_tensor(((2, 3, 4), 5), None)
nt2, _ = self._get_jagged_tensor(((2, 3, 5), 2), None) nt2, _ = self._get_jagged_tensor(((2, 3, 5), 2), None)
nt3, _ = self._get_jagged_tensor(((2, 3, 4, 5), 3), None) nt3, _ = self._get_jagged_tensor(((2, 3, 4, 5), 3), None)

View File

@ -32,6 +32,7 @@ from torch.nn.utils.rnn import PackedSequence
from torch.testing._internal.common_device_type import instantiate_device_type_tests, toleranceOverride, tol from torch.testing._internal.common_device_type import instantiate_device_type_tests, toleranceOverride, tol
from torch.testing._internal.common_methods_invocations import op_db from torch.testing._internal.common_methods_invocations import op_db
from torch.testing._internal.common_modules import module_db, modules from torch.testing._internal.common_modules import module_db, modules
from torch.testing._internal.common_utils import parametrize, instantiate_parametrized_tests
from torch.testing._internal.control_flow_opinfo_db import control_flow_opinfo_db from torch.testing._internal.control_flow_opinfo_db import control_flow_opinfo_db
from torch.testing._internal.optests import _test_aot_autograd_forwards_backwards_helper, aot_autograd_check from torch.testing._internal.optests import _test_aot_autograd_forwards_backwards_helper, aot_autograd_check
from functorch import ( from functorch import (
@ -1330,9 +1331,6 @@ def forward(self, primals_1):
return [(x,), (x,)] return [(x,), (x,)]
self.verify_aot_autograd(f, partial(inp_callable, req_grad=False), test_mutation=True) self.verify_aot_autograd(f, partial(inp_callable, req_grad=False), test_mutation=True)
self.verify_aot_autograd(f, partial(inp_callable, req_grad=False), test_mutation=True, make_inputs_subclasses=True)
with self.assertRaisesRegex(AssertionError, "which is currently unsupported in the subclass use case"):
self.verify_aot_autograd(f, partial(inp_callable, req_grad=True), test_mutation=True, make_inputs_subclasses=True)
fw_graph = self.verify_aot_autograd(f, partial(inp_callable, req_grad=True), test_mutation=True) fw_graph = self.verify_aot_autograd(f, partial(inp_callable, req_grad=True), test_mutation=True)
# TODO: make this test run with dynamic shapes so it is more meaningful # TODO: make this test run with dynamic shapes so it is more meaningful
# metadata output order: (a_updated_meta, out1_meta, out2_meta, out3_meta) # metadata output order: (a_updated_meta, out1_meta, out2_meta, out3_meta)
@ -1346,6 +1344,21 @@ def forward(self, primals_1):
unsqueeze = torch.ops.aten.unsqueeze.default(transpose, 0) unsqueeze = torch.ops.aten.unsqueeze.default(transpose, 0)
return [transpose, squeeze, transpose_1, unsqueeze, mul]""") return [transpose, squeeze, transpose_1, unsqueeze, mul]""")
@parametrize("req_grad", [False, True])
def test_subclass_metadata_mutation(self, req_grad):
def f(a):
a.transpose_(1, 0)
tmp = a.mul(2)
return tmp.transpose(1, 0)
def inp_callable(req_grad):
x = torch.ones(1, 2, 4, requires_grad=req_grad).clone()
return [(x,), (x,)]
# See https://github.com/pytorch/pytorch/issues/114975
with self.assertRaisesRegex(RuntimeError, "Metadata mutations are currently not allowed on tensor subclasses"):
self.verify_aot_autograd(f, partial(inp_callable, req_grad=req_grad), test_mutation=True, make_inputs_subclasses=True)
def test_input_data_and_metadata_mutation(self): def test_input_data_and_metadata_mutation(self):
def f(a): def f(a):
a.t_() a.t_()
@ -1879,7 +1892,7 @@ def forward(self, primals_1, primals_2):
self.verify_aot_autograd(f, partial(inp_callable, req_grad=False), test_mutation=True) self.verify_aot_autograd(f, partial(inp_callable, req_grad=False), test_mutation=True)
with self.assertRaisesRegex(RuntimeError, "is a tensor subclass. This is not supported today"): with self.assertRaisesRegex(RuntimeError, "Metadata mutations are currently not allowed on tensor subclasses"):
self.verify_aot_autograd(f, partial(inp_callable, req_grad=False), test_mutation=True, make_inputs_subclasses=True) self.verify_aot_autograd(f, partial(inp_callable, req_grad=False), test_mutation=True, make_inputs_subclasses=True)
fw_graph = self.verify_aot_autograd(f, partial(inp_callable, req_grad=True), test_mutation=True) fw_graph = self.verify_aot_autograd(f, partial(inp_callable, req_grad=True), test_mutation=True)
@ -3606,6 +3619,9 @@ def forward(self, tangents_1, tangents_2):
self.assertEqual(b_ref_base.grad.a, b_test_base.grad.a) self.assertEqual(b_ref_base.grad.a, b_test_base.grad.a)
self.assertEqual(b_ref_base.grad.b, b_test_base.grad.b) self.assertEqual(b_ref_base.grad.b, b_test_base.grad.b)
# NB: Metadata mutation for subclasses is currently broken and disabled
# See https://github.com/pytorch/pytorch/issues/114975
@unittest.expectedFailure
def test_aot_dispatch_input_metadata_mutation(self): def test_aot_dispatch_input_metadata_mutation(self):
def f(a, b): def f(a, b):
a.t_() a.t_()
@ -3651,6 +3667,9 @@ def forward(self, tangents_1, tangents_2):
self.assertEqual(b_ref_base.grad.a, b_test_base.grad.a) self.assertEqual(b_ref_base.grad.a, b_test_base.grad.a)
self.assertEqual(b_ref_base.grad.b, b_test_base.grad.b) self.assertEqual(b_ref_base.grad.b, b_test_base.grad.b)
# NB: Metadata mutation for subclasses is currently broken and disabled
# See https://github.com/pytorch/pytorch/issues/114975
@unittest.expectedFailure
def test_aot_dispatch_input_data_and_metadata_mutation(self): def test_aot_dispatch_input_data_and_metadata_mutation(self):
def f(a, b): def f(a, b):
a.t_() a.t_()
@ -3742,6 +3761,7 @@ def forward(self, tangents_1, tangents_2):
self.assertEqual(a_ref_base.grad.a, a_test_base.grad.a) self.assertEqual(a_ref_base.grad.a, a_test_base.grad.a)
self.assertEqual(a_ref_base.grad.b, a_test_base.grad.b) self.assertEqual(a_ref_base.grad.b, a_test_base.grad.b)
class TestAOTModuleSimplified(AOTTestCase): class TestAOTModuleSimplified(AOTTestCase):
def test_aot_module_simplified(self): def test_aot_module_simplified(self):
class MockModule(torch.nn.Module): class MockModule(torch.nn.Module):
@ -4139,6 +4159,7 @@ class TestEagerFusionModuleInfo(AOTTestCase):
_test_aot_autograd_module_helper(self, device, dtype, training, module_info, dynamic=True) _test_aot_autograd_module_helper(self, device, dtype, training, module_info, dynamic=True)
instantiate_parametrized_tests(TestAOTAutograd)
only_for = ("cpu") only_for = ("cpu")
instantiate_device_type_tests( instantiate_device_type_tests(
TestPythonKey, TestPythonKey,

View File

@ -42,7 +42,7 @@ def strip_end(s, suffix):
def show_guards(gm): def show_guards(gm):
names = [strip_end(n, "_1") for n in fx_placeholder_targets(gm)] names = [strip_end(n, "_1") for n in fx_placeholder_targets(gm)]
return "\n".join( return "\n".join(
gm.shape_env.produce_guards(fx_placeholder_vals(gm), names, _simplified=True, constraint_inputs=None) gm.shape_env.produce_guards(fx_placeholder_vals(gm), names, _simplified=True, input_contexts=None)
) )

View File

@ -1,6 +1,7 @@
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
import torch import torch
from torch.utils._python_dispatch import is_traceable_wrapper_subclass
from . import allowed_functions from . import allowed_functions
from .eval_frame import DisableContext, innermost_fn, RunOnlyContext from .eval_frame import DisableContext, innermost_fn, RunOnlyContext
from .exc import IncorrectUsage from .exc import IncorrectUsage
@ -158,6 +159,19 @@ def forbid_in_graph(fn):
return fn return fn
# Helper function to flatten a tensor subclass and apply a function to
# all inner tensors that match the outer dim. Used to reduce duplication
# across the various marking APIs.
def _apply_func_to_inner_tensors_of_same_dim(func, t, *args):
assert is_traceable_wrapper_subclass(t)
attrs, ctx = t.__tensor_flatten__()
for attr in attrs:
inner = getattr(t, attr)
if inner.dim() == t.dim():
func(inner, *args)
@forbid_in_graph @forbid_in_graph
def mark_dynamic(t, index): def mark_dynamic(t, index):
""" """
@ -183,6 +197,11 @@ def mark_dynamic(t, index):
before torch.compile. before torch.compile.
""" """
if is_traceable_wrapper_subclass(t):
# default behavior: mirror mark_dynamic() on all inner tensors with same dim as t
# TODO: Make this configurable via a supported public API
_apply_func_to_inner_tensors_of_same_dim(mark_dynamic, t, index)
if isinstance(index, int): if isinstance(index, int):
if not hasattr(t, "_dynamo_dynamic_indices"): if not hasattr(t, "_dynamo_dynamic_indices"):
t._dynamo_dynamic_indices = set() t._dynamo_dynamic_indices = set()
@ -201,6 +220,11 @@ def maybe_mark_dynamic(t, index):
Mark a tensor as having a dynamic dim, but don't enforce it (i.e., if this Mark a tensor as having a dynamic dim, but don't enforce it (i.e., if this
dimension ends up getting specialized, don't error). dimension ends up getting specialized, don't error).
""" """
if is_traceable_wrapper_subclass(t):
# default behavior: mirror maybe_mark_dynamic() on all inner tensors with same dim as t
# TODO: Make this configurable via a supported public API
_apply_func_to_inner_tensors_of_same_dim(maybe_mark_dynamic, t, index)
if isinstance(index, int): if isinstance(index, int):
if not hasattr(t, "_dynamo_weak_dynamic_indices"): if not hasattr(t, "_dynamo_weak_dynamic_indices"):
t._dynamo_weak_dynamic_indices = set() t._dynamo_weak_dynamic_indices = set()
@ -223,6 +247,11 @@ def mark_static(t, index=None):
This has lower precedence than mark_dynamic. This has lower precedence than mark_dynamic.
""" """
if is_traceable_wrapper_subclass(t):
# default behavior: mirror mark_static() on all inner tensors with same dim as t
# TODO: Make this configurable via a supported public API
_apply_func_to_inner_tensors_of_same_dim(mark_static, t, index)
if isinstance(index, int): if isinstance(index, int):
if not hasattr(t, "_dynamo_static_indices"): if not hasattr(t, "_dynamo_static_indices"):
t._dynamo_static_indices = set() t._dynamo_static_indices = set()
@ -261,7 +290,7 @@ def _allow_in_graph_einops():
try: try:
# requires einops > 0.6.1, torch >= 2.0 # requires einops > 0.6.1, torch >= 2.0
from einops._torch_specific import ( # noqa: F401 from einops._torch_specific import ( # type: ignore[attr-defined] # noqa: F401
_ops_were_registered_in_torchdynamo, _ops_were_registered_in_torchdynamo,
) )

View File

@ -627,7 +627,7 @@ class GuardBuilder(GuardBuilderBase):
output_graph = self.check_fn_manager.output_graph output_graph = self.check_fn_manager.output_graph
# NB: self.output_graph can be None in the debug_nops tests # NB: self.output_graph can be None in the debug_nops tests
fs = output_graph.tracked_fakes fs = output_graph.tracked_fakes
constraint_inputs = [a.constraint_dims for a in fs] input_contexts = [a.symbolic_context for a in fs]
def get_sources(t_id, dim): def get_sources(t_id, dim):
# Looks up base sources mapped to a tensor id and uses them to create # Looks up base sources mapped to a tensor id and uses them to create
@ -670,7 +670,7 @@ class GuardBuilder(GuardBuilderBase):
guards = output_graph.shape_env.produce_guards( guards = output_graph.shape_env.produce_guards(
[a.fake for a in fs], [a.fake for a in fs],
[a.source for a in fs], [a.source for a in fs],
constraint_inputs=constraint_inputs, input_contexts=input_contexts,
equalities_inputs=equalities_inputs, equalities_inputs=equalities_inputs,
source_ref=self.source_ref, source_ref=self.source_ref,
# Export keeps static. # Export keeps static.

View File

@ -26,10 +26,10 @@ from torch._streambase import _EventBase, _StreamBase
from torch._subclasses.fake_tensor import FakeTensor, is_fake, maybe_get_fake_mode from torch._subclasses.fake_tensor import FakeTensor, is_fake, maybe_get_fake_mode
from torch.fx.experimental.symbolic_shapes import ( from torch.fx.experimental.symbolic_shapes import (
_constrain_range_for_size, _constrain_range_for_size,
DimConstraint,
DimDynamic, DimDynamic,
RelaxedUnspecConstraint, RelaxedUnspecConstraint,
StatefulSymbolicContext, StatefulSymbolicContext,
SubclassSymbolicContext,
SymbolicContext, SymbolicContext,
) )
from torch.fx.immutable_collections import immutable_list from torch.fx.immutable_collections import immutable_list
@ -1579,7 +1579,7 @@ class TrackedFake:
fake: Union[FakeTensor, SymInt] fake: Union[FakeTensor, SymInt]
source: Source source: Source
# Is None when fake is SymInt # Is None when fake is SymInt
constraint_dims: Optional[DimList[DimConstraint]] symbolic_context: Optional[SymbolicContext]
def __hash__(self) -> int: def __hash__(self) -> int:
return hash((self.fake, self.source.name())) return hash((self.fake, self.source.name()))
@ -1592,13 +1592,40 @@ class TrackedFake:
# Performs automatic dynamic dim determination. # Performs automatic dynamic dim determination.
# Returns a SymbolicContext # Returns a SymbolicContext
def _automatic_dynamic(e, tx, source, static_shapes) -> SymbolicContext: def _automatic_dynamic(
e, tx, source, static_shapes, outer_only=False
) -> SymbolicContext:
name = source.name() name = source.name()
prior_policy = tx.output.tracing_context.tensor_to_context.get(e, None) prior_policy = tx.output.tracing_context.tensor_to_context.get(e, None)
source_to_symint_node_cache = ( source_to_symint_node_cache = (
prior_policy.source_to_symint_node_cache if prior_policy else None prior_policy.source_to_symint_node_cache if prior_policy else None
) )
if is_traceable_wrapper_subclass(e) and not outer_only:
# Get symbolic context for outer tensor
outer_context = _automatic_dynamic(
e, tx, source, static_shapes, outer_only=True
)
# Get symbolic contexts for inner tensors
attrs, _ = type(e).__tensor_flatten__(e)
inner_contexts = {} # mapping from attr -> symbolic context
for attr in attrs:
inner_tensor = getattr(e, attr)
inner_source = AttrSource(source, attr)
inner_context = _automatic_dynamic(
inner_tensor, tx, inner_source, static_shapes
)
inner_contexts[attr] = inner_context
return SubclassSymbolicContext(
dynamic_sizes=outer_context.dynamic_sizes,
constraint_sizes=outer_context.constraint_sizes,
tensor_source=outer_context.tensor_source,
source_to_symint_node_cache=outer_context.source_to_symint_node_cache,
inner_contexts=inner_contexts,
)
if static_shapes: if static_shapes:
return StatefulSymbolicContext( return StatefulSymbolicContext(
dynamic_sizes=[DimDynamic.STATIC] * e.dim(), dynamic_sizes=[DimDynamic.STATIC] * e.dim(),
@ -1609,7 +1636,9 @@ def _automatic_dynamic(e, tx, source, static_shapes) -> SymbolicContext:
# We preserve the dynamism of inputs. For example, when users call # We preserve the dynamism of inputs. For example, when users call
# make_fx(torch.cond, tracing_mode="symbolic")(*args), inputs have SymInt sizes. # make_fx(torch.cond, tracing_mode="symbolic")(*args), inputs have SymInt sizes.
if any(isinstance(s, SymInt) for s in e.size()): from torch.fx.experimental.symbolic_shapes import is_singleton
if any(isinstance(s, SymInt) and not is_singleton(s) for s in e.size()):
return StatefulSymbolicContext( return StatefulSymbolicContext(
dynamic_sizes=[ dynamic_sizes=[
DimDynamic.DYNAMIC if isinstance(s, SymInt) else DimDynamic.STATIC DimDynamic.DYNAMIC if isinstance(s, SymInt) else DimDynamic.STATIC
@ -1734,7 +1763,12 @@ def _automatic_dynamic(e, tx, source, static_shapes) -> SymbolicContext:
constraint_dims.append(constraint_dim) constraint_dims.append(constraint_dim)
# Now, figure out if the dim is dynamic/duck/static # Now, figure out if the dim is dynamic/duck/static
if constraint_dim is not None or marked_dynamic or marked_weak_dynamic: if (
constraint_dim is not None
or marked_dynamic
or marked_weak_dynamic
or is_singleton(e.shape[i])
):
# NB: We could assert static_shapes is False here, but it # NB: We could assert static_shapes is False here, but it
# seems better to allow the user to override symbolic_context in this # seems better to allow the user to override symbolic_context in this
# case # case
@ -1768,20 +1802,13 @@ def wrap_to_fake_tensor_and_record(e, tx, *, source: Optional[Source], is_tensor
e, is_tensor, guard_source=source.guard_source() e, is_tensor, guard_source=source.guard_source()
) )
symbolic_context = None symbolic_context = _automatic_dynamic(e, tx, source, static_shapes)
if not e.is_nested:
# TODO: We should probably support this for nested tensors too
symbolic_context = _automatic_dynamic(e, tx, source, static_shapes)
if symbolic_context:
tx.output.tracing_context.tensor_to_context[e] = symbolic_context
log.debug( log.debug(
"wrap_to_fake %s %s %s %s", "wrap_to_fake %s %s %s",
source.name(), source.name(),
tuple(e.shape), tuple(e.shape),
symbolic_context.dynamic_sizes if symbolic_context is not None else None, symbolic_context,
symbolic_context.constraint_sizes if symbolic_context is not None else None,
) )
fake_e = wrap_fake_exception( fake_e = wrap_fake_exception(
lambda: tx.fake_mode.from_tensor( lambda: tx.fake_mode.from_tensor(
@ -1790,22 +1817,36 @@ def wrap_to_fake_tensor_and_record(e, tx, *, source: Optional[Source], is_tensor
symbolic_context=symbolic_context, symbolic_context=symbolic_context,
) )
) )
if is_tensor and not (static_shapes and source.is_nn_module()):
# TODO: just store the whole symbolic_context here # list of (fake_tensor, real_tensor, source, symbolic_context)
tx.output.tracked_fakes.append( tracking_info = [(fake_e, e, source, symbolic_context)]
TrackedFake( if is_traceable_wrapper_subclass(fake_e):
fake_e, attrs, _ = fake_e.__tensor_flatten__()
source, for attr in attrs:
symbolic_context.constraint_sizes fake_inner = getattr(fake_e, attr)
if symbolic_context is not None inner = getattr(e, attr)
else None, tracking_info.append(
(
fake_inner,
inner,
AttrSource(source, attr),
symbolic_context.inner_contexts[attr],
)
) )
)
tx.output.tracked_fakes_id_to_source[id(e)].append(source) for fake, real, source, symbolic_context in tracking_info:
tx.output.tensor_weakref_to_sizes_strides[e] = { tx.output.tracing_context.tensor_to_context[real] = symbolic_context
"size": fake_e.size(), tx.output.tensor_weakref_to_sizes_strides[real] = {
"stride": fake_e.stride(), "size": fake.size(),
} "stride": fake.stride(),
}
if is_tensor and not (static_shapes and source.is_nn_module()):
tx.output.tracked_fakes.append(
TrackedFake(fake, source, symbolic_context)
)
tx.output.tracked_fakes_id_to_source[id(real)].append(source)
return fake_e return fake_e
else: else:
return e return e

View File

@ -127,6 +127,10 @@ def run_functionalized_fw_and_collect_metadata(
mutates_metadata = has_metadata_mutation( mutates_metadata = has_metadata_mutation(
f_arg, arg, check_only_storage_mutation=False f_arg, arg, check_only_storage_mutation=False
) )
if mutates_metadata and is_traceable_wrapper_subclass(arg):
raise RuntimeError(
"Metadata mutations are currently not allowed on tensor subclasses"
)
mutates_storage_metadata = has_metadata_mutation( mutates_storage_metadata = has_metadata_mutation(
f_arg, arg, check_only_storage_mutation=True f_arg, arg, check_only_storage_mutation=True
) )

View File

@ -6,7 +6,7 @@ input/output types, metadata, config, function signatures etc.
import collections import collections
from dataclasses import dataclass from dataclasses import dataclass
from enum import Enum from enum import Enum
from typing import Any, Callable, Dict, List, NewType, Optional, Set, Union from typing import Any, Callable, Dict, List, NewType, Optional, Set, Tuple, Union
import torch import torch
import torch.utils._pytree as pytree import torch.utils._pytree as pytree
@ -138,9 +138,12 @@ class SubclassCreationMeta:
# so holding onto this at runtime shouldn't leak memory) # so holding onto this at runtime shouldn't leak memory)
original_subclass: torch.Tensor original_subclass: torch.Tensor
# meta and inner_keys are produced by the subclass's __tensor_flatten__. # meta and inner_keys are produced by the subclass's __tensor_flatten__.
# We need to keep them around to plumb them into __tensor_unflatten__. # We need to keep them around along with outer_size / outer_stride to plumb them
# into __tensor_unflatten__.
meta: Any meta: Any
inner_keys: List[Any] inner_keys: List[Any]
outer_size: Tuple[int, ...]
outer_stride: Tuple[int, ...]
def creation_fn(self, all_args, *, is_runtime: bool): def creation_fn(self, all_args, *, is_runtime: bool):
curr_args = all_args[ curr_args = all_args[
@ -149,8 +152,13 @@ class SubclassCreationMeta:
assert len(curr_args) == len( assert len(curr_args) == len(
self.inner_keys self.inner_keys
), f"inner_keys: {str(self.inner_keys)}. len(curr_args): {len(curr_args)}" ), f"inner_keys: {str(self.inner_keys)}. len(curr_args): {len(curr_args)}"
# NB: Sometimes we have real inner tensors and symbolic metadata.
# TODO: Resolve this so we always have matching real / symbolic tensors / metadata.
out = type(self.original_subclass).__tensor_unflatten__( # type: ignore[attr-defined] out = type(self.original_subclass).__tensor_unflatten__( # type: ignore[attr-defined]
dict(zip(self.inner_keys, curr_args)), self.meta dict(zip(self.inner_keys, curr_args)),
self.meta,
self.outer_size,
self.outer_stride,
) )
if not is_runtime: if not is_runtime:
# After wrapping up the inner dense tensors into a subclass, we need to make sure that our new wrapper # After wrapping up the inner dense tensors into a subclass, we need to make sure that our new wrapper

View File

@ -54,6 +54,8 @@ def create_subclass_meta(
original_subclass=a, original_subclass=a,
meta=meta, meta=meta,
inner_keys=attrs, inner_keys=attrs,
outer_size=a.shape,
outer_stride=a.stride(),
) )
) )
else: else:

View File

@ -1537,27 +1537,13 @@ def make_contiguous_strides_for(
if not shape: if not shape:
return () return ()
# TODO: Move this somewhere central? from torch.fx.experimental.symbolic_shapes import is_singleton
def _is_singleton(s):
# check for SingletonSymNode
if not isinstance(s, torch.SymInt):
return False
if s.node.singleton_int() is not None:
return True
# check for SymInt wrapping a SingletonSymNode (fake-ifying causes this)
return (
s.node.is_symbolic()
and s.node.hint is not None
and isinstance(s.node.hint, torch.SymInt)
and s.node.hint.node.singleton_int() is not None
)
multiplier = 1 multiplier = 1
strides = [] strides = []
for l in reversed(shape): for l in reversed(shape):
strides.append(multiplier) strides.append(multiplier)
multiplier *= l if _is_singleton(l) else sym_max(l, 1) multiplier *= l if is_singleton(l) else sym_max(l, 1)
result = tuple(reversed(strides)) result = tuple(reversed(strides))

View File

@ -186,8 +186,6 @@ class MetaConverter:
source: Optional[Source] = None, source: Optional[Source] = None,
symbolic_context: Optional["SymbolicContext"] = None, symbolic_context: Optional["SymbolicContext"] = None,
): ):
from torch._subclasses.fake_tensor import FakeTensor
if source is None: if source is None:
from torch._dynamo.source import ConstantSource from torch._dynamo.source import ConstantSource
@ -235,10 +233,11 @@ class MetaConverter:
maybe_suppress = shape_env.suppress_guards maybe_suppress = shape_env.suppress_guards
def sym_sizes_strides_storage_offset( def sym_sizes_strides_storage_offset(
t, src t, src, symbolic_context=symbolic_context
) -> Tuple[Tuple[int, ...], Tuple[int, ...], int]: ) -> Tuple[Tuple[int, ...], Tuple[int, ...], int]:
if shape_env is not None: if shape_env is not None:
if isinstance(t, FakeTensor) and t.fake_mode.shape_env is shape_env: fake_mode = torch._subclasses.fake_tensor.maybe_get_fake_mode(t)
if fake_mode is not None and fake_mode.shape_env is shape_env:
# Don't reallocate the sizes; the shape envs are the same, # Don't reallocate the sizes; the shape envs are the same,
# so reuse the old sizes/strides/etc # so reuse the old sizes/strides/etc
return (t.size(), t.stride(), t.storage_offset()) return (t.size(), t.stride(), t.storage_offset())
@ -246,16 +245,25 @@ class MetaConverter:
return shape_env.create_symbolic_sizes_strides_storage_offset( return shape_env.create_symbolic_sizes_strides_storage_offset(
t, t,
src, src,
# Assume that the set of dims that are dynamic are the same between
# the wrapper tensor and any inner tensors.
# We can revisit this if this assumption does not hold
# for any important subclasses later.
symbolic_context=symbolic_context, symbolic_context=symbolic_context,
) )
else: else:
assert symbolic_context is None assert symbolic_context is None
return (t.size(), t.stride(), t.storage_offset()) return (t.size(), t.stride(), t.storage_offset())
def empty_create(inner_t, inner_src, symbolic_context=symbolic_context):
(
inner_sizes,
inner_strides,
inner_storage_offset,
) = sym_sizes_strides_storage_offset(inner_t, inner_src, symbolic_context)
return torch.empty_strided(
inner_sizes,
inner_strides,
dtype=inner_t.dtype,
device="meta",
)
# see expired-storages # see expired-storages
self.check_expired_count += 1 self.check_expired_count += 1
if self.check_expired_count >= self.check_expired_frequency: if self.check_expired_count >= self.check_expired_frequency:
@ -443,99 +451,45 @@ class MetaConverter:
else: else:
is_leaf = safe_is_leaf(t) is_leaf = safe_is_leaf(t)
if not t.is_nested:
# Nested tensor subclasses have special logic for
# creating symbolic size/strides/storage_offset
(
sizes,
strides,
storage_offset,
) = sym_sizes_strides_storage_offset(t, source)
def empty_create(inner_t, inner_src): from torch.fx.experimental.symbolic_shapes import (
( SubclassSymbolicContext,
inner_sizes, )
inner_strides,
inner_storage_offset, (
) = sym_sizes_strides_storage_offset(inner_t, inner_src) sizes,
return torch.empty_strided( strides,
inner_sizes, storage_offset,
inner_strides, ) = sym_sizes_strides_storage_offset(t, source, symbolic_context)
dtype=inner_t.dtype,
device="meta",
)
# If we have a subclass that desugars into dense tensors, # If we have a subclass that desugars into dense tensors,
# perform our callback on each inner tensor. # perform our callback on each inner tensor.
if is_traceable_wrapper_subclass(t): if is_traceable_wrapper_subclass(t):
# Note: transform_subclass will use __tensor_unflatten__ to generate # Note: transform_subclass will use __tensor_unflatten__ to generate
# a fresh subclass wrapper, which is why sizes/strides are not passed in # a fresh subclass wrapper. We assume that if the inner tensors of
# to the creation function here. # the subclass are given symbolic sizes, their sizes will be used
# We assume that if the inner tensors of the subclass are given symbolic sizes, # to construct the (symbolic) sizes of the wrapper tensor.
# their sizes will be used to construct the (symbolic) sizes of the wrapper tensor.
from torch._dynamo.source import AttrSource from torch._dynamo.source import AttrSource
if t.is_nested: assert symbolic_context is None or isinstance(
# Avoid circular import symbolic_context, SubclassSymbolicContext
from torch._dynamo.source import ( )
TensorProperty, r = transform_subclass(
TensorPropertySource, t,
) lambda attr, inner_t: callback(
lambda: empty_create(
# For nested tensors, manually do transform_subclass inner_t,
# so we can insert some special processing on ctx AttrSource(source, attr),
attrs, ctx = t.__tensor_flatten__() symbolic_context=(
transformed_tensors_dict = {} None
orig_shape_env = None if symbolic_context is None
for attr in attrs: else symbolic_context.inner_contexts[attr]
inner_t = getattr(t, attr)
if orig_shape_env is None:
orig_shape_env = (
inner_t.fake_mode.shape_env
if isinstance(inner_t, FakeTensor)
else None
)
transformed_tensors_dict[attr] = callback(
lambda: empty_create(
inner_t, AttrSource(source, attr)
)
)
# We expect JaggedTensor to have a 'ragged_size' in
# its context
assert isinstance(ctx, dict)
assert "ragged_size" in ctx
assert isinstance(t._size[1], torch.SymInt)
if orig_shape_env is shape_env:
# It's already fake and the shape envs line up, reuse the old size
# Do not assert singleton_int; it may already
# be a variable
ctx["ragged_size"] = t._size[1]
else:
assert t._size[1].node.singleton_int() is not None
# Replace the eager ragged size with our freshly
# allocated jagged size that has a source
ctx["ragged_size"] = shape_env.create_symintnode(
shape_env.create_symbol(
t._size[1],
TensorPropertySource(
source, TensorProperty.SIZE, 1
),
), ),
hint=t._size[1],
) )
r = type(t).__tensor_unflatten__( ),
transformed_tensors_dict, ctx outer_size=sizes,
) outer_stride=strides,
else: )
r = transform_subclass(
t,
lambda attr, inner_t: callback(
lambda: empty_create(
inner_t,
AttrSource(source, attr),
)
),
)
else: else:
r = callback( r = callback(
lambda: torch.empty_strided( lambda: torch.empty_strided(

View File

@ -389,7 +389,7 @@ class AsyncCollectiveTensor(torch.Tensor):
return self.elem.tolist() return self.elem.tolist()
@staticmethod @staticmethod
def __tensor_unflatten__(inner_tensors, meta): def __tensor_unflatten__(inner_tensors, meta, outer_size, outer_stride):
assert meta is None assert meta is None
elem = inner_tensors["elem"] elem = inner_tensors["elem"]
return AsyncCollectiveTensor(elem) return AsyncCollectiveTensor(elem)

View File

@ -255,7 +255,7 @@ class DTensor(torch.Tensor): # pyre-ignore[13]: pyre is bad at __new__
return ["_local_tensor"], (self._spec, self.requires_grad) return ["_local_tensor"], (self._spec, self.requires_grad)
@staticmethod @staticmethod
def __tensor_unflatten__(inner_tensors, flatten_spec): def __tensor_unflatten__(inner_tensors, flatten_spec, outer_size, outer_stride):
assert ( assert (
flatten_spec is not None flatten_spec is not None
), "Expecting spec to be not None from `__tensor_flatten__` return value!" ), "Expecting spec to be not None from `__tensor_flatten__` return value!"
@ -265,10 +265,10 @@ class DTensor(torch.Tensor): # pyre-ignore[13]: pyre is bad at __new__
local_tensor, local_tensor,
spec.mesh, spec.mesh,
spec.placements, spec.placements,
shape=spec.tensor_meta.shape, shape=outer_size,
dtype=spec.tensor_meta.dtype, dtype=spec.tensor_meta.dtype,
requires_grad=requires_grad, requires_grad=requires_grad,
stride=spec.tensor_meta.stride, stride=outer_stride,
) )
__torch_function__ = torch._C._disabled_torch_function_impl __torch_function__ = torch._C._disabled_torch_function_impl

View File

@ -62,8 +62,9 @@ __all__ = [
"has_symbolic_sizes_strides", "create_contiguous", "ShapeEnv", "is_concrete_int", "has_symbolic_sizes_strides", "create_contiguous", "ShapeEnv", "is_concrete_int",
"guard_int", "guard_float", "guard_scalar", "canonicalize_bool_expr", "guard_int", "guard_float", "guard_scalar", "canonicalize_bool_expr",
"hint_int", "SYMPY_INTERP", "free_symbols", "is_symbol_binding_fx_node", "hint_int", "SYMPY_INTERP", "free_symbols", "is_symbol_binding_fx_node",
"is_concrete_bool", "SHAPEENV_EVENT_KEY", "CURRENT_NODE_KEY", "is_concrete_bool", "is_singleton", "SHAPEENV_EVENT_KEY", "CURRENT_NODE_KEY",
"has_free_symbols", "sym_eq", "SymbolicContext", "StatelessSymbolicContext", "StatefulSymbolicContext" "has_free_symbols", "sym_eq", "SymbolicContext", "StatelessSymbolicContext",
"StatefulSymbolicContext", "SubclassSymbolicContext"
] ]
# FX node metadata keys for symbolic shape FX graph. # FX node metadata keys for symbolic shape FX graph.
@ -203,6 +204,21 @@ def is_concrete_bool(a: Union[bool, SymBool]):
return False return False
def is_singleton(s):
# check for SingletonSymNode
if not isinstance(s, torch.SymInt):
return False
if s.node.singleton_int() is not None:
return True
# check for symbolic variable wrapping a SingletonSymNode (fake-ifying causes this)
return (
s.node.is_symbolic()
and s.node.hint is not None
and isinstance(s.node.hint, torch.SymInt)
and s.node.hint.node.singleton_int() is not None
)
def _iterate_exprs(val: Union[SymInt, torch.Tensor]) -> Iterable[sympy.Basic]: def _iterate_exprs(val: Union[SymInt, torch.Tensor]) -> Iterable[sympy.Basic]:
if isinstance(val, SymTypes): if isinstance(val, SymTypes):
# This allow applies to the jagged layout NestedTensor case as # This allow applies to the jagged layout NestedTensor case as
@ -851,6 +867,20 @@ class StatefulSymbolicContext(StatelessSymbolicContext):
object.__setattr__(self, 'source_to_symint_node_cache', {}) object.__setattr__(self, 'source_to_symint_node_cache', {})
@dataclass(frozen=True)
class SubclassSymbolicContext(StatefulSymbolicContext):
"""
The correct symbolic context for a given inner tensor of a traceable tensor subclass
may differ from that of the outer symbolic context. This structure allows for this
flexibility, with inner symbolic contexts mapped via attr -> symbolic context.
"""
inner_contexts: Dict[str, SymbolicContext] = None
def __post_init__(self):
if self.inner_contexts is None:
self.inner_contexts = {}
def is_symbolic(val: Union[int, SymInt, float, SymFloat, bool, SymBool]) -> bool: def is_symbolic(val: Union[int, SymInt, float, SymFloat, bool, SymBool]) -> bool:
if isinstance(val, (int, float, bool)): if isinstance(val, (int, float, bool)):
return False return False
@ -1864,7 +1894,7 @@ class ShapeEnv:
# FakeTensorMeta for two reasons: # FakeTensorMeta for two reasons:
# 1. this is all the information we need when recording ShapeEnvEvents. # 1. this is all the information we need when recording ShapeEnvEvents.
# 2. it works even if each TrackedFake changes its metadata. # 2. it works even if each TrackedFake changes its metadata.
return TrackedFake(inner_fake, fake.source, fake.constraint_dims) # type: ignore[arg-type] return TrackedFake(inner_fake, fake.source, fake.symbolic_context) # type: ignore[arg-type]
return [maybe_transform_fake(fake) for fake in self.tracked_fakes] return [maybe_transform_fake(fake) for fake in self.tracked_fakes]
@ -2091,8 +2121,6 @@ class ShapeEnv:
# The order of checking the guards matters. In this specific example: # The order of checking the guards matters. In this specific example:
# If True branch guard check precedes False branch and for True branch, y.size(0) check precedes x == True, # If True branch guard check precedes False branch and for True branch, y.size(0) check precedes x == True,
# we may have an unnessary shape speciliazation for y. # we may have an unnessary shape speciliazation for y.
assert not ex.is_nested
def maybe_specialize_sym_int_with_hint(maybe_sym) -> int: def maybe_specialize_sym_int_with_hint(maybe_sym) -> int:
assert isinstance(maybe_sym, (int, torch.SymInt)) assert isinstance(maybe_sym, (int, torch.SymInt))
if is_symbolic(maybe_sym): if is_symbolic(maybe_sym):
@ -2174,7 +2202,13 @@ class ShapeEnv:
} }
# iterate over unbound strides in sorted order # iterate over unbound strides in sorted order
val_list = sorted( val_list = sorted(
[(ex_stride[i], i) for i in range(len(stride)) if stride[i] is None] [(ex_stride[i], i) for i in range(len(stride)) if stride[i] is None],
key=lambda tup: (
# Order singletons by their coefficients.
# 1 here to order singletons after non-singletons.
(1, tup[0].node.singleton_coeff(), tup[1]) if is_singleton(tup[0])
else (0, *tup)
)
) )
for _, i in val_list: for _, i in val_list:
if stride[i] is None and ex_stride[i] in candidates: if stride[i] is None and ex_stride[i] in candidates:
@ -2367,9 +2401,6 @@ class ShapeEnv:
dynamic_dim = DimDynamic.DYNAMIC dynamic_dim = DimDynamic.DYNAMIC
if dynamic_dim is DimDynamic.STATIC: if dynamic_dim is DimDynamic.STATIC:
# We don't expect to ever reach here even the user specifies
# dynamic=False, because automatic_dynamic skipped for
# nested tensors.
return sympy.Integer(val) return sympy.Integer(val)
elif dynamic_dim is DimDynamic.DUCK: elif dynamic_dim is DimDynamic.DUCK:
@ -2492,11 +2523,7 @@ class ShapeEnv:
sources, sources,
source_ref=lambda n: n.name(), source_ref=lambda n: n.name(),
*, *,
# An input is either a SymInt (in which case you directly have input_contexts: Optional[DimList[SymbolicContext]] = None,
# DimConstraint) or a Tensor (in which case you have a
# DimList[DimConstraint]). Whenever Optional is accepted, that
# just means there are no constraints
constraint_inputs: Optional[InputList[Union[DimConstraint, Optional[DimList[DimConstraint]]]]] = None,
equalities_inputs: Optional[Set[Tuple[Source, Source]]] = None, equalities_inputs: Optional[Set[Tuple[Source, Source]]] = None,
_simplified=False, _simplified=False,
# Indicates if we should produce guards for known static values. # Indicates if we should produce guards for known static values.
@ -2521,22 +2548,28 @@ class ShapeEnv:
assert len(placeholders) == len(sources) assert len(placeholders) == len(sources)
Tensorlike = (torch.Tensor, FakeTensorMeta) Tensorlike = (torch.Tensor, FakeTensorMeta)
def _create_no_constraints_context(t):
return StatelessSymbolicContext(
# Ignored; only the constraints part is relevant below.
dynamic_sizes=[DimDynamic.DYNAMIC] * t.dim(),
constraint_sizes=[None] * t.dim()
)
# Expand optional inputs, or verify invariants are upheld # Expand optional inputs, or verify invariants are upheld
if constraint_inputs is None: if input_contexts is None:
constraint_inputs = [ input_contexts = [
[None] * t.dim() if isinstance(t, Tensorlike) else None for t in placeholders _create_no_constraints_context(t) if isinstance(t, Tensorlike)
else None for t in placeholders
] ]
else: else:
assert len(constraint_inputs) == len(placeholders) assert len(input_contexts) == len(placeholders)
for i, (t, constraint) in enumerate(zip(placeholders, constraint_inputs)): for i, (t, context) in enumerate(zip(placeholders, input_contexts)):
if isinstance(t, Tensorlike): if isinstance(t, Tensorlike):
if constraint is None: if context is None:
constraint_inputs[i] = [None] * t.dim() input_contexts[i] = _create_no_constraints_context(t)
else:
assert len(constraint) == t.dim()
else: else:
assert isinstance(t, (SymInt, int)) assert isinstance(t, (SymInt, int))
assert not isinstance(constraint, list) assert not isinstance(context, list)
# It took a lot of sweat to figure out the algorithm here. Let's # It took a lot of sweat to figure out the algorithm here. Let's
# explain how it works. # explain how it works.
@ -2682,10 +2715,6 @@ class ShapeEnv:
# expect to have to compile in this case anyway # expect to have to compile in this case anyway
if i not in (0, 1): if i not in (0, 1):
constraint_violated = True constraint_violated = True
else:
# TODO: Maybe non-strict constraint shouldn't error
# here? Check what happens in practice
constraint_violated = True
if constraint_violated: if constraint_violated:
def hint(s): def hint(s):
sexpr = ShapeGuardPrinter(symbol_to_source, source_ref, self.var_to_sources).doprint(s) sexpr = ShapeGuardPrinter(symbol_to_source, source_ref, self.var_to_sources).doprint(s)
@ -2723,7 +2752,7 @@ class ShapeEnv:
) )
record_constraint_violation(constraint.warn_only, self.debug_name(source), msg) record_constraint_violation(constraint.warn_only, self.debug_name(source), msg)
for t, source, constraint in zip(placeholders, sources, constraint_inputs): for t, source, context in zip(placeholders, sources, input_contexts):
if isinstance(source, str): if isinstance(source, str):
from torch._dynamo.source import LocalSource from torch._dynamo.source import LocalSource
source = LocalSource(source) source = LocalSource(source)
@ -2734,32 +2763,35 @@ class ShapeEnv:
track_symint(source, t) track_symint(source, t)
continue continue
assert isinstance(t, Tensorlike) assert isinstance(t, Tensorlike)
sources_and_tensors = [(source, t)]
if is_traceable_wrapper_subclass(t): if is_traceable_wrapper_subclass(t):
# If our placeholder is a tensor subclass, then the "true" symints
# come from the subclass's inner tensors.
attrs, _ = t.__tensor_flatten__()
from torch._dynamo.source import AttrSource from torch._dynamo.source import AttrSource
inner_sources_and_tensors = [(AttrSource(source, attr), getattr(t, attr)) for attr in attrs]
if t.is_nested:
# For NestedTensors we need to track BOTH symints on the outer
# tensor and tensor because we'd like to guard on the ragged
# size but the symint representing ragged size is not in terms
# of the symints on the inner tensors.
sources_and_tensors.extend(inner_sources_and_tensors)
else:
# For other tensor subclasses, only track the symints from
# the inner tensors
sources_and_tensors = inner_sources_and_tensors
for src, curr_t in sources_and_tensors: assert isinstance(context, SubclassSymbolicContext)
# For subclasses, we need to track symints on BOTH the outer
# and inner tensors.
sources_tensors_constraints = [
(source, t, context.constraint_sizes)
]
attrs, _ = t.__tensor_flatten__()
for attr in attrs:
inner_t = getattr(t, attr)
inner_context = context.inner_contexts[attr]
sources_tensors_constraints.append((
AttrSource(source, attr),
inner_t,
inner_context.constraint_sizes
))
else:
sources_tensors_constraints = [(source, t, context.constraint_sizes)]
for src, curr_t, constraint in sources_tensors_constraints:
for i, ss in enumerate(curr_t.size()): for i, ss in enumerate(curr_t.size()):
property_source = TensorPropertySource(src, TensorProperty.SIZE, i) property_source = TensorPropertySource(src, TensorProperty.SIZE, i)
track_symint(property_source, ss, constraint[i]) track_symint(property_source, ss, constraint[i])
if not t.is_nested: for i, ss in enumerate(curr_t.stride()):
for i, ss in enumerate(curr_t.stride()): track_symint(TensorPropertySource(src, TensorProperty.STRIDE, i), ss)
track_symint(TensorPropertySource(src, TensorProperty.STRIDE, i), ss) track_symint(TensorPropertySource(src, TensorProperty.STORAGE_OFFSET), curr_t.storage_offset())
track_symint(TensorPropertySource(src, TensorProperty.STORAGE_OFFSET), curr_t.storage_offset())
# 1. Every input must equal the final simplified symbolic expression # 1. Every input must equal the final simplified symbolic expression
# stored on the placeholder. Given a placeholder (s0*2, s1), # stored on the placeholder. Given a placeholder (s0*2, s1),
@ -3246,6 +3278,11 @@ class ShapeEnv:
""" """
result_expr = safe_expand(expr).xreplace(self.var_to_val) result_expr = safe_expand(expr).xreplace(self.var_to_val)
if not result_expr.is_number: if not result_expr.is_number:
from torch.utils._sympy.singleton_int import SingletonInt
if isinstance(result_expr, SingletonInt):
return None
r = self._maybe_evaluate_static(result_expr, compute_hint=True) r = self._maybe_evaluate_static(result_expr, compute_hint=True)
if r is not None: if r is not None:
return r return r

View File

@ -663,7 +663,7 @@ def bisect(shape_env):
shape_env.produce_guards( shape_env.produce_guards(
[new_with_shape_env(shape_env, a.fake) for a in tracked_fakes], [new_with_shape_env(shape_env, a.fake) for a in tracked_fakes],
[a.source for a in tracked_fakes], [a.source for a in tracked_fakes],
constraint_inputs=[a.constraint_dims for a in tracked_fakes], input_contexts=[a.symbolic_context for a in tracked_fakes],
) )
return None return None
except ValidationException as e: except ValidationException as e:

View File

@ -130,6 +130,10 @@ class NestedTensor(torch.Tensor):
) )
) )
# collapsed ragged dim must always be dynamic
torch._dynamo.mark_dynamic(self, self._ragged_idx)
torch._dynamo.mark_dynamic(self._values, self._ragged_idx - 1)
def values(self): def values(self):
return self._values return self._values
@ -164,7 +168,6 @@ class NestedTensor(torch.Tensor):
def __tensor_flatten__(self): def __tensor_flatten__(self):
ctx = { ctx = {
"requires_grad": self.requires_grad, "requires_grad": self.requires_grad,
"ragged_size": self._size[self._ragged_idx],
"max_seqlen": self._max_seqlen, "max_seqlen": self._max_seqlen,
"min_seqlen": self._min_seqlen, "min_seqlen": self._min_seqlen,
"ragged_idx": self._ragged_idx, "ragged_idx": self._ragged_idx,
@ -175,37 +178,13 @@ class NestedTensor(torch.Tensor):
return inner_tensors, ctx return inner_tensors, ctx
@staticmethod @staticmethod
def __tensor_unflatten__(inner_tensors: Dict, meta): def __tensor_unflatten__(inner_tensors: Dict, meta, outer_size, outer_stride):
assert len(inner_tensors) >= 2 and len(inner_tensors) <= 3 assert len(inner_tensors) >= 2 and len(inner_tensors) <= 3
values = inner_tensors["_values"] values = inner_tensors["_values"]
offsets = inner_tensors["_offsets"] offsets = inner_tensors["_offsets"]
lengths = inner_tensors.get("_lengths", None) lengths = inner_tensors.get("_lengths", None)
ragged_idx = meta["ragged_idx"]
# NOTE [ Storing symbolic values as plain attributes on subclasses ]
#
# When a subclass like NestedTensor stores a "size-like" value (which
# can either be Symintified or not) into meta, it's responsible for:
#
# (1) Propagating that symint during torch dispatch when performing
# operations, i.e. torch dispatch plays the role of a meta kernel.
#
# (2) Facilitating the behavior around symbolic -> non-symbolic
# conversions and vice versa, see below.
#
# [ non-symbolic -> symbolic (fakification in meta_utils) ]
#
# __tensor_unflatten__ is passed symbolic dense tensors and meta from
# non-symbolic subclasses. In this case, the subclass is responsible for
# intercepting meta["ragged_size"] for example and replacing it with the
# symintified version.
#
# [ symbolic -> non-symbolic ]
#
# __tensor_unflatten__ is passed non-symbolic dense tensors and with
# meta extracted from fake subclasses. In this case the subclass gets
# propagated the meta["ragged_size"] which is still a symint and the
# subclass is responsible for making sure that the symint doesn't leak.
#
# Note that we cannot simply check if is_fake(values) because # Note that we cannot simply check if is_fake(values) because
# during aot autograd, FunctionalTensors are not fake but hold # during aot autograd, FunctionalTensors are not fake but hold
# symbolic sizes. # symbolic sizes.
@ -213,7 +192,8 @@ class NestedTensor(torch.Tensor):
if has_free_symbols(ragged_source) or has_free_symbols(values): if has_free_symbols(ragged_source) or has_free_symbols(values):
# Associate offsets or lengths (possibly fake, possibly functionalized) # Associate offsets or lengths (possibly fake, possibly functionalized)
# with the ragged_size. # with the ragged_size.
_tensor_symint_registry[ragged_source] = meta["ragged_size"] ragged_size = outer_size[ragged_idx]
_tensor_symint_registry[ragged_source] = ragged_size
return NestedTensor( return NestedTensor(
values, values,
@ -222,7 +202,7 @@ class NestedTensor(torch.Tensor):
requires_grad=meta["requires_grad"], requires_grad=meta["requires_grad"],
_max_seqlen=meta["max_seqlen"], _max_seqlen=meta["max_seqlen"],
_min_seqlen=meta["min_seqlen"], _min_seqlen=meta["min_seqlen"],
_ragged_idx=meta["ragged_idx"], _ragged_idx=ragged_idx,
) )
@classmethod @classmethod

View File

@ -222,7 +222,7 @@ class SparseSemiStructuredTensor(torch.Tensor):
return ['sparse_tensor_cutlass', 'meta_tensor_cutlass'], (self.original_shape, self.transposed) return ['sparse_tensor_cutlass', 'meta_tensor_cutlass'], (self.original_shape, self.transposed)
@staticmethod @staticmethod
def __tensor_unflatten__(inner_tensors, meta): def __tensor_unflatten__(inner_tensors, meta, outer_size, outer_stride):
original_shape, transposed = meta original_shape, transposed = meta
if len(inner_tensors) == 2: if len(inner_tensors) == 2:

View File

@ -42,7 +42,7 @@ class TwoTensor(torch.Tensor):
return ["a", "b"], None return ["a", "b"], None
@staticmethod @staticmethod
def __tensor_unflatten__(inner_tensors, meta): def __tensor_unflatten__(inner_tensors, meta, outer_size, outer_stride):
assert meta is None assert meta is None
a, b = inner_tensors["a"], inner_tensors["b"] a, b = inner_tensors["a"], inner_tensors["b"]
return TwoTensor(a, b) return TwoTensor(a, b)

View File

@ -142,13 +142,34 @@ def is_traceable_wrapper_subclass(t):
is 'traceable' with torch.compile. is 'traceable' with torch.compile.
In order for a tensor subclass to support TorchDispatchMode-style tracing in PT2, In order for a tensor subclass to support TorchDispatchMode-style tracing in PT2,
It must implement two magic methods: __tensor_flatten__ and __tensor_unflatten__. It must implement two magic methods: __tensor_flatten__ and __tensor_unflatten__.
It is also expected to obey some restrictions around traceability and aliasing It is also expected to obey some restrictions around traceability and aliasing:
(TODO: add clear documentation around this.) * The subclass's __torch_dispatch__() implementation should desugar into pytorch
dispatcher operations that can be traced into a graph.
* The subclass should use return_and_correct_aliasing(). This is needed today to make
sure that torch.compile does the right thing in a few cases around input mutation
and output aliasing.
Expected magic method signatures:
attrs, ctx = t.__tensor_flatten__()
attrs: list of attribute name strings for inner tensors
ctx: dict containing any other subclass-specific metadata needed for unflattening
t = MySubClass.__tensor_unflatten__(inner_tensors, ctx, outer_size, outer_stride)
inner_tensors: dict mapping attribute name -> tensor for each inner tensor
ctx: dict with subclass metadata in the form that __tensor_flatten__() produces
outer_size: expected (possibly symbolic) size that the returned subclass
instance should have. Note that this arg is useful for certain subclasses
that require the shape info to be constructed. In most cases, this arg can be
safely ignored.
outer_stride: expected (possibly symbolic) stride that the returned subclass
instance should have. Note that this arg is useful for certain subclasses
that require the stride info to be constructed. In most cases, this arg can be
safely ignored.
""" """
is_subclass = isinstance(t, torch.Tensor) and type(t) != torch.Tensor is_subclass = isinstance(t, torch.Tensor) and type(t) != torch.Tensor
return is_subclass and hasattr(t, "__tensor_flatten__") and hasattr(t, "__tensor_unflatten__") return is_subclass and hasattr(t, "__tensor_flatten__") and hasattr(t, "__tensor_unflatten__")
def transform_subclass(t, callback): def transform_subclass(t, callback, outer_size=None, outer_stride=None):
""" """
Given a traceable, wrapper tensor subclass ``t`` that implements Given a traceable, wrapper tensor subclass ``t`` that implements
``__torch_dispatch__`` and holds some inner tensors, ``__torch_dispatch__`` and holds some inner tensors,
@ -162,11 +183,28 @@ def transform_subclass(t, callback):
gets the same (autograd, and aliasing) metadata as the original tensor. gets the same (autograd, and aliasing) metadata as the original tensor.
This is generally handled in other subsystems like AOTAutograd. This is generally handled in other subsystems like AOTAutograd.
""" """
outer_size = outer_size if outer_size is not None else t.size()
outer_stride = outer_stride if outer_stride is not None else t.stride()
attrs, ctx = t.__tensor_flatten__() attrs, ctx = t.__tensor_flatten__()
transformed_tensors_dict = {} transformed_tensors_dict = {}
for attr in attrs: for attr in attrs:
transformed_tensors_dict[attr] = callback(attr, getattr(t, attr)) transformed_tensors_dict[attr] = callback(attr, getattr(t, attr))
return type(t).__tensor_unflatten__(transformed_tensors_dict, ctx) sub = type(t).__tensor_unflatten__(
transformed_tensors_dict, ctx, outer_size, outer_stride
)
# NB: Purposefully guard here to simplify the inner / outer symbols.
# Using sym_eq() for symbolic comparison can result in an expression that's too
# difficult to guard on, so we use == here.
assert sub.shape == outer_size, \
f"Expected return value from {type(t)}__tensor_unflatten__() to have " \
f"shape equal to {outer_size}, but got: {sub.shape}"
assert sub.stride() == outer_stride, \
f"Expected return value from {type(t)}__tensor_unflatten__() to have " \
f"stride equal to {outer_stride}, but got: {sub.stride()}"
return sub
def _correct_storage_aliasing(func, schema_info, args, outs): def _correct_storage_aliasing(func, schema_info, args, outs):
""" """