mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
This reverts commit 5f630d28d7ff9fdd8bd6cdbe2438e5c821007845. Reverted https://github.com/pytorch/pytorch/pull/162525 on behalf of https://github.com/anijain2305 due to internal tests fail ([comment](https://github.com/pytorch/pytorch/pull/162525#issuecomment-3310748980))
1367 lines
46 KiB
Python
1367 lines
46 KiB
Python
# Owner(s): ["module: dynamo"]
|
|
import abc
|
|
import functools
|
|
import inspect
|
|
import unittest
|
|
import weakref
|
|
|
|
import torch
|
|
import torch._dynamo
|
|
import torch._dynamo.test_case
|
|
from torch._C._dynamo import guards
|
|
from torch._dynamo.convert_frame import GlobalStateGuard
|
|
from torch._dynamo.eval_frame import _debug_get_cache_entry_list
|
|
from torch.testing._internal.common_utils import set_default_dtype
|
|
|
|
|
|
RootGuardManager = guards.RootGuardManager
|
|
DictGuardManager = guards.DictGuardManager
|
|
GetAttrGuardAccessor = guards.GetAttrGuardAccessor
|
|
GetItemGuardAccessor = guards.GetItemGuardAccessor
|
|
TypeGuardAccessor = guards.TypeGuardAccessor
|
|
OBJECT_ALIASING = guards.OBJECT_ALIASING
|
|
install_object_aliasing_guard = guards.install_object_aliasing_guard
|
|
NO_TENSOR_ALIASING = guards.NO_TENSOR_ALIASING
|
|
install_no_tensor_aliasing_guard = guards.install_no_tensor_aliasing_guard
|
|
|
|
|
|
x = torch.tensor(4)
|
|
weakref_x = weakref.ref(x)
|
|
|
|
default_mgr_enum = torch._dynamo.guards.GuardManagerType.GUARD_MANAGER
|
|
|
|
|
|
class Pair:
|
|
def __init__(self, x, y):
|
|
self.x = x
|
|
self.y = y
|
|
|
|
|
|
global_pair = Pair(torch.randn(4), 1)
|
|
|
|
|
|
def id_type(x):
|
|
return id(type(x))
|
|
|
|
|
|
def equals_match(x, expected):
|
|
return x == expected
|
|
|
|
|
|
def equals_match_verbose_code_parts(expected):
|
|
return [f"x == {expected}"]
|
|
|
|
|
|
def ge_match(x, expected):
|
|
return x >= expected
|
|
|
|
|
|
def ge_match_verbose_code_parts(expected):
|
|
return f"expected >= {expected}"
|
|
|
|
|
|
def less_match(x, expected):
|
|
return x < expected
|
|
|
|
|
|
def less_match_verbose_code_parts(expected):
|
|
return [f"expected < {expected}"]
|
|
|
|
|
|
class GuardManagerTests(torch._dynamo.test_case.TestCase):
|
|
def test_global_state_guard(self):
|
|
root = RootGuardManager()
|
|
guard = guards.GLOBAL_STATE(root, ["global_state_check"])
|
|
self.assertTrue(guard(None))
|
|
with set_default_dtype(torch.double):
|
|
self.assertFalse(guard(None))
|
|
self.assertExpectedInline(
|
|
str(guard.check_verbose(None)),
|
|
"""\
|
|
GuardDebugInfo(
|
|
result=0,
|
|
verbose_code_parts=['GLOBAL_STATE changed: default_dtype '],
|
|
num_guards_executed=0)
|
|
""",
|
|
)
|
|
self.assertTrue(guard(None))
|
|
self.assertTrue(guard.check_verbose(None).result)
|
|
_orig = torch.are_deterministic_algorithms_enabled()
|
|
try:
|
|
torch.use_deterministic_algorithms(not _orig)
|
|
self.assertFalse(guard(None))
|
|
self.assertExpectedInline(
|
|
str(guard.check_verbose(None)),
|
|
"""\
|
|
GuardDebugInfo(
|
|
result=0,
|
|
verbose_code_parts=['GLOBAL_STATE changed: deterministic_algorithms '],
|
|
num_guards_executed=0)
|
|
""",
|
|
)
|
|
finally:
|
|
torch.use_deterministic_algorithms(_orig)
|
|
self.assertTrue(guard(None))
|
|
self.assertTrue(guard.check_verbose(None).result)
|
|
|
|
def test_global_state_reason(self):
|
|
with torch.enable_grad():
|
|
guards = GlobalStateGuard()
|
|
with torch.no_grad():
|
|
self.assertIs(guards.check(), False)
|
|
self.assertEqual(guards.reason(), "grad_mode ")
|
|
|
|
def test_python_lambda_leaf_guard(self):
|
|
root = RootGuardManager()
|
|
const_guard = guards.LAMBDA_GUARD(
|
|
root,
|
|
functools.partial(equals_match, expected=5),
|
|
equals_match_verbose_code_parts(5),
|
|
)
|
|
self.assertTrue(const_guard(5))
|
|
self.assertFalse(const_guard(4))
|
|
self.assertFalse(const_guard("foo"))
|
|
|
|
def test_type_guard(self):
|
|
root = RootGuardManager()
|
|
foo = 4
|
|
guard = guards.TYPE_MATCH(root, id_type(foo), ["type(x) == int"])
|
|
|
|
self.assertTrue(guard(5))
|
|
self.assertTrue(guard(4))
|
|
self.assertFalse(guard("foo"))
|
|
|
|
foo = {"a": 1}
|
|
guard = guards.TYPE_MATCH(root, id_type(foo), ["type(x) == dict"])
|
|
self.assertTrue(guard(foo))
|
|
self.assertTrue(guard({}))
|
|
self.assertFalse(guard(5))
|
|
self.assertFalse(guard("foo"))
|
|
|
|
class Foo:
|
|
def __init__(self, x, y):
|
|
self.x = x
|
|
self.y = y
|
|
|
|
foo = Foo(1, 2)
|
|
|
|
guard = guards.TYPE_MATCH(root, id_type(foo), ["type(x) == Foo"])
|
|
self.assertTrue(guard(foo))
|
|
self.assertFalse(guard({}))
|
|
self.assertFalse(guard(5))
|
|
self.assertFalse(guard("foo"))
|
|
|
|
def test_id_guard(self):
|
|
root = RootGuardManager()
|
|
foo = 4
|
|
guard = guards.ID_MATCH(root, id(foo), ["id(x) == id(foo)"])
|
|
|
|
self.assertTrue(guard(foo))
|
|
self.assertFalse(guard(5))
|
|
self.assertFalse(guard("foo"))
|
|
|
|
foo = {"a": 1}
|
|
guard = guards.ID_MATCH(root, id(foo), ["id(x) == id(foo)"])
|
|
self.assertTrue(guard(foo))
|
|
self.assertFalse(guard({"a": 1}))
|
|
self.assertFalse(guard({}))
|
|
self.assertFalse(guard(5))
|
|
|
|
def test_equals_guard(self):
|
|
root = RootGuardManager()
|
|
foo = 4
|
|
guard = guards.EQUALS_MATCH(root, foo, ["x == 4"])
|
|
|
|
self.assertTrue(guard(4))
|
|
self.assertFalse(guard(5))
|
|
self.assertFalse(guard("foo"))
|
|
|
|
# tuple
|
|
foo = (1, 2, 3)
|
|
guard = guards.EQUALS_MATCH(root, foo, ["x == foo"])
|
|
self.assertTrue(guard(foo))
|
|
self.assertTrue(guard((1, 2, 3)))
|
|
self.assertFalse(guard((1, 2, 3, 4)))
|
|
self.assertFalse(guard({}))
|
|
|
|
# list
|
|
foo = [1, 2, 3]
|
|
guard = guards.EQUALS_MATCH(root, foo, ["x == foo"])
|
|
self.assertTrue(guard(foo))
|
|
self.assertTrue(guard([1, 2, 3]))
|
|
self.assertFalse(guard([1, 2, 3, 4]))
|
|
|
|
# type
|
|
foo = int
|
|
guard = guards.EQUALS_MATCH(root, foo, ["x == foo"])
|
|
self.assertTrue(guard(foo))
|
|
self.assertTrue(guard(int))
|
|
self.assertFalse(guard(float))
|
|
|
|
def test_default_device_guard(self):
|
|
root = RootGuardManager()
|
|
foo = 1
|
|
guard = guards.DEFAULT_DEVICE(root, ["cpu device"])
|
|
self.assertTrue(guard(foo))
|
|
|
|
try:
|
|
torch.set_default_device("cuda")
|
|
self.assertFalse(guard(foo))
|
|
finally:
|
|
torch.set_default_device(None)
|
|
|
|
def test_length_check_guard(self):
|
|
root = RootGuardManager()
|
|
foo = [1, 2, 3]
|
|
guard = guards.LENGTH_CHECK(root, len(foo), ["len(x) == len(foo)"])
|
|
self.assertTrue(guard(foo))
|
|
self.assertFalse(guard([]))
|
|
|
|
def test_no_hasattr_guard(self):
|
|
root = RootGuardManager()
|
|
|
|
class Bar:
|
|
def __init__(self) -> None:
|
|
self.bar = 2
|
|
|
|
bar = Bar()
|
|
|
|
class Foo:
|
|
def __init__(self) -> None:
|
|
self.foo = 2
|
|
|
|
foo = Foo()
|
|
|
|
guard = guards.NO_HASATTR(root, "foo", ["hasattr(x, 'foo') == False"])
|
|
self.assertTrue(guard(bar))
|
|
self.assertFalse(guard(foo))
|
|
|
|
def test_tensor_aliasing_guard(self):
|
|
guard_manager = RootGuardManager()
|
|
|
|
a = torch.randn(3, 4)
|
|
|
|
class Foo:
|
|
def __init__(self, x, y):
|
|
self.x = x
|
|
self.y = y
|
|
|
|
f_locals = Foo(a, a)
|
|
|
|
x_guard_mgr = guard_manager.getattr_manager("x", "", a, default_mgr_enum)
|
|
y_guard_mgr = guard_manager.getattr_manager("y", "", a, default_mgr_enum)
|
|
install_object_aliasing_guard(x_guard_mgr, y_guard_mgr, ["x is y"])
|
|
|
|
# Check structure
|
|
x_guards = x_guard_mgr.get_leaf_guards()
|
|
y_guards = y_guard_mgr.get_leaf_guards()
|
|
self.assertEqual(len(x_guards), 1)
|
|
self.assertEqual(len(y_guards), 1)
|
|
self.assertTrue(isinstance(x_guards[0], OBJECT_ALIASING))
|
|
self.assertTrue(isinstance(y_guards[0], OBJECT_ALIASING))
|
|
# Check that the two guards are the same object
|
|
self.assertTrue(x_guards[0] is y_guards[0])
|
|
|
|
f_locals_unaliased = Foo(torch.randn(3, 4), torch.randn(3, 4))
|
|
self.assertEqual(len(x_guard_mgr.get_leaf_guards()), 1)
|
|
self.assertEqual(len(y_guard_mgr.get_leaf_guards()), 1)
|
|
self.assertTrue(guard_manager.check(f_locals))
|
|
|
|
self.assertFalse(guard_manager.check(f_locals_unaliased))
|
|
|
|
def test_dict_version_guard(self):
|
|
root = RootGuardManager()
|
|
foo = {"a": 1, "b": 2}
|
|
guard = guards.DICT_VERSION(root, foo, ["x.version == foo.version"])
|
|
|
|
self.assertTrue(guard(foo))
|
|
self.assertFalse(guard(dict(foo)))
|
|
foo["a"] = 2
|
|
self.assertFalse(guard(foo))
|
|
self.assertFalse(guard({"a": 1, "b": 2}))
|
|
self.assertFalse(guard({}))
|
|
|
|
def test_dynamic_indices_guard(self):
|
|
root = RootGuardManager()
|
|
guard1 = guards.DYNAMIC_INDICES(root, set(), ["x.size(0) == y.size(0)"])
|
|
guard2 = guards.DYNAMIC_INDICES(root, set({0, 1}), ["x.size(0) == y.size(0)"])
|
|
|
|
x = torch.randn(4)
|
|
self.assertTrue(guard1(x))
|
|
self.assertTrue(guard2(x))
|
|
|
|
x._dynamo_dynamic_indices = set({0})
|
|
self.assertFalse(guard1(x))
|
|
self.assertTrue(guard2(x))
|
|
|
|
x._dynamo_dynamic_indices = set({2})
|
|
self.assertFalse(guard1(x))
|
|
self.assertFalse(guard2(x))
|
|
|
|
def test_tensor_match_guard(self):
|
|
guard_manager = RootGuardManager()
|
|
x = torch.randn(4, 4)
|
|
size = list(x.size())
|
|
stride = list(x.stride())
|
|
guard_manager.add_tensor_match_guard(
|
|
x,
|
|
size,
|
|
stride,
|
|
"x",
|
|
["check_tensor(x)"],
|
|
type(x),
|
|
torch._C._dispatch_keys(x),
|
|
)
|
|
self.assertTrue(guard_manager.check(x))
|
|
self.assertTrue(guard_manager.check_verbose(x).result)
|
|
self.assertTrue(guard_manager.check(torch.randn(4, 4)))
|
|
self.assertTrue(guard_manager.check_verbose(torch.randn(4, 4)).result)
|
|
self.assertFalse(guard_manager.check(x.t_()))
|
|
|
|
x = torch.randn(4, 4)
|
|
x.t_()
|
|
debug_info = guard_manager.check_verbose(x)
|
|
print(debug_info.verbose_code_parts[0])
|
|
self.assertTrue(
|
|
"tensor 'x' stride mismatch" in debug_info.verbose_code_parts[0]
|
|
)
|
|
|
|
def test_no_tensor_aliasing_guard(self):
|
|
guard_manager = RootGuardManager()
|
|
|
|
a = torch.randn(3, 4)
|
|
|
|
class Foo:
|
|
def __init__(self, x, y, z):
|
|
self.x = x
|
|
self.y = y
|
|
self.z = z
|
|
|
|
f_locals = Foo(a, a, a)
|
|
|
|
x_guard_mgr = guard_manager.getattr_manager("x", "", a, default_mgr_enum)
|
|
y_guard_mgr = guard_manager.getattr_manager("y", "", a, default_mgr_enum)
|
|
z_guard_mgr = guard_manager.getattr_manager("z", "", a, default_mgr_enum)
|
|
install_no_tensor_aliasing_guard(
|
|
[x_guard_mgr, y_guard_mgr, z_guard_mgr],
|
|
["x", "y", "z"],
|
|
["no_aliasing(x, y, z)"],
|
|
)
|
|
|
|
# Check structure
|
|
x_guards = x_guard_mgr.get_leaf_guards()
|
|
y_guards = y_guard_mgr.get_leaf_guards()
|
|
z_guards = z_guard_mgr.get_leaf_guards()
|
|
self.assertEqual(len(x_guards), 1)
|
|
self.assertEqual(len(y_guards), 1)
|
|
self.assertEqual(len(z_guards), 1)
|
|
self.assertTrue(isinstance(x_guards[0], NO_TENSOR_ALIASING))
|
|
self.assertTrue(isinstance(y_guards[0], NO_TENSOR_ALIASING))
|
|
self.assertTrue(isinstance(z_guards[0], NO_TENSOR_ALIASING))
|
|
# Check that the two guards are the same object
|
|
self.assertTrue(x_guards[0] is y_guards[0] is z_guards[0])
|
|
self.assertFalse(guard_manager.check(f_locals))
|
|
self.assertFalse(guard_manager.check_verbose(f_locals).result)
|
|
|
|
f_locals_unaliased = Foo(
|
|
torch.randn(3, 4),
|
|
torch.randn(3, 4),
|
|
torch.randn(3, 4),
|
|
)
|
|
self.assertTrue(guard_manager.check(f_locals_unaliased))
|
|
self.assertTrue(guard_manager.check_verbose(f_locals_unaliased).result)
|
|
# Check that hash map is cleared.
|
|
self.assertTrue(guard_manager.check(f_locals_unaliased))
|
|
|
|
f_locals_unaliased = Foo(
|
|
a,
|
|
torch.randn(3, 4),
|
|
a,
|
|
)
|
|
self.assertFalse(guard_manager.check(f_locals_unaliased))
|
|
self.assertFalse(guard_manager.check_verbose(f_locals_unaliased).result)
|
|
|
|
def test_weakref_alive_guard(self):
|
|
root = RootGuardManager()
|
|
x = torch.rand(3, 4)
|
|
weakref_x = weakref.ref(x)
|
|
|
|
guard = guards.NOT_NONE(root, ["weakref_x is not None"])
|
|
self.assertTrue(guard(weakref_x()))
|
|
del x
|
|
self.assertFalse(guard(weakref_x()))
|
|
|
|
@unittest.skipIf(not torch.cuda.is_available(), "requires cuda")
|
|
def test_call_function_no_args_guard(self):
|
|
root = RootGuardManager()
|
|
x = torch.cuda.current_device()
|
|
guard = guards.EQUALS_MATCH(root, x, [0])
|
|
self.assertTrue(guard(0))
|
|
self.assertFalse(guard(1))
|
|
self.assertFalse(guard(2))
|
|
|
|
def test_guard_manager_leaf_guard(self):
|
|
guard_manager = RootGuardManager()
|
|
guard_manager.add_type_match_guard(id_type(5), ["type(x) == int"])
|
|
guard_manager.add_lambda_guard(
|
|
functools.partial(ge_match, expected=5),
|
|
ge_match_verbose_code_parts(expected=5),
|
|
)
|
|
guard_manager.add_lambda_guard(
|
|
functools.partial(less_match, expected=10),
|
|
less_match_verbose_code_parts(expected=10),
|
|
)
|
|
self.assertEqual(len(guard_manager.get_leaf_guards()), 3)
|
|
self.assertEqual(len(guard_manager.get_accessors()), 0)
|
|
self.assertTrue(guard_manager.check(6))
|
|
self.assertFalse(guard_manager.check(4))
|
|
self.assertFalse(guard_manager.check("foo"))
|
|
|
|
def test_attr_guard_manager(self):
|
|
class Foo:
|
|
def __init__(self, x, y):
|
|
self.x = x
|
|
self.y = y
|
|
|
|
foo = Foo(1, 2)
|
|
guard_manager = RootGuardManager()
|
|
guard_manager.add_type_match_guard(id_type(foo), ["type(x) == Foo"])
|
|
guard_manager.getattr_manager("x", "x", 1, default_mgr_enum).add_lambda_guard(
|
|
functools.partial(equals_match, expected=foo.x),
|
|
equals_match_verbose_code_parts(foo.x),
|
|
)
|
|
guard_manager.getattr_manager("y", "y", 2, default_mgr_enum).add_lambda_guard(
|
|
functools.partial(equals_match, expected=foo.y),
|
|
equals_match_verbose_code_parts(foo.y),
|
|
)
|
|
self.assertEqual(len(guard_manager.get_leaf_guards()), 1)
|
|
# 2 child managers, one for x and one for y
|
|
self.assertEqual(len(guard_manager.get_accessors()), 2)
|
|
self.assertTrue(
|
|
isinstance(guard_manager.get_accessors()[0], GetAttrGuardAccessor)
|
|
)
|
|
self.assertTrue(
|
|
isinstance(guard_manager.get_accessors()[1], GetAttrGuardAccessor)
|
|
)
|
|
# Check leaf guards on child managers
|
|
self.assertEqual(
|
|
len(
|
|
guard_manager.getattr_manager(
|
|
attr="x",
|
|
source="x",
|
|
example_value=None,
|
|
guard_manager_enum=default_mgr_enum,
|
|
).get_leaf_guards()
|
|
),
|
|
1,
|
|
)
|
|
self.assertEqual(
|
|
len(
|
|
guard_manager.getattr_manager(
|
|
"y", "y", None, default_mgr_enum
|
|
).get_leaf_guards()
|
|
),
|
|
1,
|
|
)
|
|
|
|
self.assertTrue(guard_manager.check(foo))
|
|
self.assertFalse(guard_manager.check(Foo(3, 4)))
|
|
self.assertFalse(guard_manager.check("foo"))
|
|
|
|
def test_item_guard_manager(self):
|
|
foo = [1, 2]
|
|
guard_manager = RootGuardManager()
|
|
guard_manager.add_type_match_guard(id_type(foo), ["type(x) == Foo"])
|
|
guard_manager.getitem_manager(0, "", 1, default_mgr_enum).add_lambda_guard(
|
|
functools.partial(equals_match, expected=foo[0]),
|
|
equals_match_verbose_code_parts(foo[0]),
|
|
)
|
|
guard_manager.getitem_manager(1, "", 2, default_mgr_enum).add_lambda_guard(
|
|
functools.partial(equals_match, expected=foo[1]),
|
|
equals_match_verbose_code_parts(foo[1]),
|
|
)
|
|
self.assertEqual(len(guard_manager.get_leaf_guards()), 1)
|
|
# 2 child managers, one for x and one for y
|
|
self.assertEqual(len(guard_manager.get_accessors()), 2)
|
|
self.assertTrue(
|
|
isinstance(guard_manager.get_accessors()[0], GetItemGuardAccessor)
|
|
)
|
|
self.assertTrue(
|
|
isinstance(guard_manager.get_accessors()[1], GetItemGuardAccessor)
|
|
)
|
|
# Check leaf guards on child managers
|
|
self.assertEqual(
|
|
len(
|
|
guard_manager.getitem_manager(
|
|
0, "", None, default_mgr_enum
|
|
).get_leaf_guards()
|
|
),
|
|
1,
|
|
)
|
|
self.assertEqual(
|
|
len(
|
|
guard_manager.getitem_manager(
|
|
1, "", None, default_mgr_enum
|
|
).get_leaf_guards()
|
|
),
|
|
1,
|
|
)
|
|
|
|
self.assertTrue(guard_manager.check(foo))
|
|
self.assertFalse(guard_manager.check([3, 4]))
|
|
self.assertFalse(guard_manager.check("foo"))
|
|
|
|
def test_framelocals_accessor(self):
|
|
foo = {
|
|
"a": 1,
|
|
"b": 2,
|
|
}
|
|
|
|
guards_manager = RootGuardManager()
|
|
guards_manager.add_type_match_guard(id_type(foo), ["type(x) == Foo"])
|
|
guards_manager.framelocals_manager(
|
|
("a", 0), "", 1, default_mgr_enum
|
|
).add_equals_match_guard(1, ["a == 1"])
|
|
guards_manager.framelocals_manager(
|
|
("b", 1), "", 2, default_mgr_enum
|
|
).add_equals_match_guard(2, ["b == 2"])
|
|
|
|
self.assertTrue(guards_manager.check(foo))
|
|
self.assertFalse(guards_manager.check({"a": 1, "b": 3}))
|
|
|
|
def test_framelocals_guard_e2e(self):
|
|
def fn(x, y, z):
|
|
return x + y + z[0]
|
|
|
|
opt_fn = torch.compile(fn, backend="eager")
|
|
|
|
ref = opt_fn(torch.ones(3), 2, {0: 1, 2: 3})
|
|
with torch._dynamo.set_stance("fail_on_recompile"):
|
|
res = opt_fn(torch.ones(3), 2, {0: 1, 2: 3})
|
|
self.assertEqual(ref, res)
|
|
|
|
c1 = _debug_get_cache_entry_list(fn.__code__)
|
|
self.assertEqual(len(c1), 1)
|
|
guard_str = str(c1[0].guard_manager)
|
|
self.assertIn(
|
|
"source=L['x'], accessed_by=FrameLocalsGuardAccessor(key='x', framelocals_idx=0)",
|
|
guard_str,
|
|
)
|
|
self.assertIn(
|
|
"source=L['y'], accessed_by=FrameLocalsGuardAccessor(key='y', framelocals_idx=1)",
|
|
guard_str,
|
|
)
|
|
self.assertIn(
|
|
"source=L['z'], accessed_by=FrameLocalsGuardAccessor(key='z', framelocals_idx=2)",
|
|
guard_str,
|
|
)
|
|
|
|
def test_dict_getitem_accessor(self):
|
|
foo = {
|
|
"a": 1,
|
|
"b": 2,
|
|
}
|
|
|
|
guards_manager = RootGuardManager()
|
|
guards_manager.add_type_match_guard(id_type(foo), ["type(x) == Foo"])
|
|
guards_manager.dict_getitem_manager(
|
|
"a", "", 1, default_mgr_enum
|
|
).add_equals_match_guard(1, ["a == 1"])
|
|
guards_manager.dict_getitem_manager(
|
|
"b", "", 2, default_mgr_enum
|
|
).add_equals_match_guard(2, ["b == 2"])
|
|
|
|
self.assertTrue(guards_manager.check(foo))
|
|
self.assertFalse(guards_manager.check({"a": 1, "b": 3}))
|
|
|
|
def test_globals(self):
|
|
global global_pair, Pair
|
|
guard_manager = RootGuardManager()
|
|
gpair_mgr = guard_manager.globals_dict_manager(
|
|
globals(), "", None, default_mgr_enum
|
|
).getitem_manager("global_pair", "", global_pair, default_mgr_enum)
|
|
|
|
gpair_mgr.add_lambda_guard(
|
|
lambda x: isinstance(x, Pair)
|
|
and isinstance(x.x, torch.Tensor)
|
|
and isinstance(x.y, int),
|
|
"global guard fail",
|
|
)
|
|
|
|
self.assertTrue(guard_manager.check(global_pair))
|
|
global_pair.y = "foo"
|
|
self.assertFalse(guard_manager.check(global_pair))
|
|
|
|
def test_type_manager(self):
|
|
guard_manager = RootGuardManager()
|
|
|
|
class A:
|
|
a = 4
|
|
|
|
class B(A):
|
|
def mul(self, x):
|
|
super().mul(x)
|
|
|
|
foo = B()
|
|
f_locals = {"foo": foo}
|
|
|
|
# len(type(foo).__mro__) == 2
|
|
foo_mgr = guard_manager.getitem_manager("foo", "", foo, default_mgr_enum)
|
|
type_manager = foo_mgr.type_manager("", type(foo), default_mgr_enum)
|
|
self.assertTrue(isinstance(foo_mgr.get_accessors()[0], TypeGuardAccessor))
|
|
mro_manager = type_manager.getattr_manager(
|
|
"__mro__", "", type(foo).__mro__, default_mgr_enum
|
|
)
|
|
self.assertTrue(
|
|
isinstance(type_manager.get_accessors()[0], GetAttrGuardAccessor)
|
|
)
|
|
mro_manager.add_length_check_guard(
|
|
3,
|
|
"Expected len(type(foo).__mro__) == 3",
|
|
)
|
|
|
|
# type(foo).__mro__[0].a = 4
|
|
item_manager = mro_manager.getitem_manager(
|
|
1, "", type(foo).__mro__[1], default_mgr_enum
|
|
)
|
|
self.assertTrue(
|
|
isinstance(mro_manager.get_accessors()[0], GetItemGuardAccessor)
|
|
)
|
|
attr_manager = item_manager.getattr_manager(
|
|
"a", "", type(foo).__mro__[0].a, default_mgr_enum
|
|
)
|
|
self.assertTrue(
|
|
isinstance(item_manager.get_accessors()[0], GetAttrGuardAccessor)
|
|
)
|
|
attr_manager.add_lambda_guard(
|
|
lambda x: x == 4,
|
|
"Expected value 4",
|
|
)
|
|
|
|
self.assertTrue(guard_manager.check(f_locals))
|
|
|
|
def test_tuple_iterator_getitem(self):
|
|
a = (1, 2, 3, 4, 5, 6)
|
|
foo = iter(a)
|
|
next(foo) # foo points at index=1
|
|
|
|
guard_manager = RootGuardManager()
|
|
# Check a[3] which is tuple_iterator_getitem(foo, 2)
|
|
guard_manager.add_tuple_iterator_length_guard(
|
|
5, id_type(iter(())), ["len == 5"]
|
|
)
|
|
guard_manager.tuple_iterator_getitem_manager(
|
|
2, "", foo, default_mgr_enum
|
|
).add_equals_match_guard(a[3], ["x==4"])
|
|
|
|
# Check that type match works
|
|
self.assertFalse(guard_manager.check(False))
|
|
|
|
self.assertTrue(guard_manager.check(foo))
|
|
|
|
# Check that index error fails gracefully
|
|
b = (1, 2)
|
|
b_foo = iter(b)
|
|
self.assertFalse(guard_manager.check(b_foo))
|
|
|
|
def test_global_weakref(self):
|
|
guard_manager = RootGuardManager()
|
|
globals_manager = guard_manager.globals_dict_manager(
|
|
globals(), "", None, default_mgr_enum
|
|
)
|
|
weakref_manager = globals_manager.global_weakref_manager(
|
|
"weakref_x", "", None, default_mgr_enum
|
|
)
|
|
|
|
weakref_manager.add_lambda_guard(
|
|
lambda x: isinstance(x, torch.Tensor),
|
|
"global weakref fail",
|
|
)
|
|
|
|
self.assertTrue(guard_manager.check(None))
|
|
global x
|
|
del x
|
|
self.assertFalse(guard_manager.check(None))
|
|
|
|
def test_lambda_manager(self):
|
|
a = (1, 1, 3, 4, 5, 6)
|
|
|
|
guard_manager = RootGuardManager()
|
|
|
|
# Check that we can use the same accessor
|
|
foo_mgr = guard_manager.lambda_manager(
|
|
lambda x: x[2], "", None, default_mgr_enum
|
|
)
|
|
foo_mgr.add_lambda_guard(
|
|
lambda x: x == 3,
|
|
"Expected value 3",
|
|
)
|
|
self.assertTrue(guard_manager.check(a))
|
|
|
|
# test that exception works
|
|
guard_manager = RootGuardManager()
|
|
|
|
def fn(x):
|
|
raise AssertionError("Test")
|
|
return x
|
|
|
|
foo_mgr = guard_manager.lambda_manager(fn, "", None, default_mgr_enum)
|
|
|
|
self.assertFalse(guard_manager.check(None))
|
|
debug_info = guard_manager.check_verbose(None)
|
|
self.assertFalse(debug_info.result)
|
|
self.assertTrue("Test" in debug_info.verbose_code_parts[0])
|
|
|
|
def test_dict_contains_guard(self):
|
|
root = RootGuardManager()
|
|
foo = {"a": 1, "b": 2}
|
|
guard = guards.DICT_CONTAINS(root, True, "a", ["has a"])
|
|
|
|
self.assertTrue(guard(foo))
|
|
self.assertTrue(guard({"a": 1, "b": 2}))
|
|
self.assertFalse(guard({"b": 2, "c": 3}))
|
|
self.assertFalse(guard({}))
|
|
|
|
guard = guards.DICT_CONTAINS(root, False, "c", ["not has c"])
|
|
self.assertTrue(guard(foo))
|
|
self.assertTrue(guard({"a": 1, "b": 2}))
|
|
self.assertFalse(guard({"b": 2, "c": 3}))
|
|
self.assertTrue(guard({}))
|
|
|
|
def test_dict_guard_manager(self):
|
|
root = RootGuardManager()
|
|
|
|
def nothing():
|
|
pass
|
|
|
|
f_locals = {
|
|
"d": {"a": 1, nothing: {"z": 3}, 100: torch.randn(4)},
|
|
}
|
|
|
|
# its a getitem_manager just for f_locals. But the child guard manager
|
|
# should be a DictGuardManager.
|
|
dict_mgr = root.getitem_manager(
|
|
"d",
|
|
"",
|
|
f_locals["d"],
|
|
torch._dynamo.guards.GuardManagerType.DICT_GUARD_MANAGER,
|
|
)
|
|
self.assertTrue(isinstance(dict_mgr, DictGuardManager))
|
|
|
|
self.assertTrue(root.check(f_locals))
|
|
|
|
# Check that no one can add a leaf guard
|
|
with self.assertRaises(RuntimeError):
|
|
dict_mgr.add_id_match_guard(id_type(f_locals), "id match")
|
|
|
|
# Check that no one can add an arbitrary accessor
|
|
with self.assertRaises(RuntimeError):
|
|
dict_mgr.getitem_manager("a", "", f_locals["d"]["a"])
|
|
|
|
# Check that it fails with different length dict
|
|
f_locals_prime = {
|
|
"d": {"a": 1, "b": 2},
|
|
}
|
|
self.assertFalse(root.check(f_locals_prime))
|
|
|
|
# Add key-value manager ("a" : 1)
|
|
self.assertTrue(root.check(f_locals))
|
|
dict_mgr.get_key_manager(0, "", "a", default_mgr_enum).add_equals_match_guard(
|
|
"a",
|
|
["dict.keys()[0] == a"],
|
|
)
|
|
self.assertTrue(root.check(f_locals))
|
|
dict_mgr.get_value_manager(0, "", 1, default_mgr_enum).add_equals_match_guard(
|
|
1, ["d[0] == 1"]
|
|
)
|
|
self.assertTrue(root.check(f_locals))
|
|
|
|
# Add key-value manager (nothing : {"z" : 3})
|
|
self.assertTrue(root.check(f_locals))
|
|
dict_mgr.get_key_manager(1, "", nothing, default_mgr_enum).add_lambda_guard(
|
|
lambda x: x is nothing, ["x is nothing"]
|
|
)
|
|
self.assertTrue(root.check(f_locals))
|
|
value_mgr = dict_mgr.get_value_manager(
|
|
1,
|
|
"",
|
|
f_locals["d"][nothing],
|
|
torch._dynamo.guards.GuardManagerType.DICT_GUARD_MANAGER,
|
|
)
|
|
self.assertTrue(isinstance(value_mgr, DictGuardManager))
|
|
self.assertTrue(root.check(f_locals))
|
|
|
|
# Check structure
|
|
# Check that we are only guarding on two keys. This is common in
|
|
# LazyVariableTracker.
|
|
self.assertEqual(len(dict_mgr.get_key_value_managers()), 2)
|
|
|
|
f_locals["d"]["a"] = 2
|
|
self.assertFalse(root.check(f_locals))
|
|
self.assertFalse(root.check_verbose(f_locals).result)
|
|
|
|
f_locals["d"]["a"] = 1
|
|
self.assertTrue(root.check(f_locals))
|
|
|
|
f_locals["d"].pop(100)
|
|
# fails because of len check
|
|
self.assertFalse(root.check(f_locals))
|
|
|
|
def test_clone(self):
|
|
try:
|
|
from .utils import install_guard_manager_testing_hook
|
|
except ImportError:
|
|
from utils import install_guard_manager_testing_hook
|
|
|
|
def hook(guard_wrapper, f_locals, builder):
|
|
root = guard_wrapper.root
|
|
|
|
# Check full cloning works as expected
|
|
cloned_root = root.clone_manager(lambda x: True)
|
|
self.assertTrue(cloned_root.check(f_locals))
|
|
f_locals["foo"] = [3, 4]
|
|
self.assertFalse(cloned_root.check(f_locals))
|
|
f_locals["foo"] = [2, 3]
|
|
|
|
# Skip guarding on foo
|
|
cloned_root = root.clone_manager(lambda x: "foo" not in x.get_source())
|
|
f_locals["foo"] = [3, 4]
|
|
# Original root should fail, but new root should pass because of
|
|
# absence of guards on foo.
|
|
self.assertFalse(root.check(f_locals))
|
|
self.assertTrue(cloned_root.check(f_locals))
|
|
|
|
class Bar:
|
|
x = 4
|
|
y = torch.randn(4)
|
|
|
|
foo = [2, 3]
|
|
bar = Bar()
|
|
|
|
def fn(x, foo, bar):
|
|
return x + foo[0] + bar.x * bar.y
|
|
|
|
x = torch.randn(4)
|
|
opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
|
|
with install_guard_manager_testing_hook(hook):
|
|
opt_fn(x, foo, bar)
|
|
|
|
def test_diff_guard_manager(self):
|
|
try:
|
|
from .utils import install_guard_manager_testing_hook
|
|
except ImportError:
|
|
from utils import install_guard_manager_testing_hook
|
|
counter = 0
|
|
|
|
def hook(guard_wrapper, f_locals, builder):
|
|
nonlocal counter
|
|
root = guard_wrapper.root
|
|
diff_guard_root = guard_wrapper.diff_guard_root
|
|
|
|
# Check full cloning works as expected
|
|
self.assertTrue(root.check(f_locals))
|
|
self.assertTrue(diff_guard_root.check(f_locals))
|
|
|
|
# Check that tensor guards run well
|
|
old_tensor = f_locals["bar"].y
|
|
f_locals["bar"].y = torch.randn(5)
|
|
self.assertFalse(root.check(f_locals))
|
|
self.assertFalse(diff_guard_root.check(f_locals))
|
|
f_locals["bar"].y = old_tensor
|
|
|
|
# Original root should fail on foo changes, but diff_guard_root
|
|
# should pass because it does not have foo guards on counter = 0. On
|
|
# counter = 1, it should pass because we have caused a recompile
|
|
# because of foo, causing it to recompile on foo.
|
|
f_locals["foo"] = [3, 3]
|
|
self.assertFalse(root.check(f_locals))
|
|
if counter == 0:
|
|
self.assertTrue(diff_guard_root.check(f_locals))
|
|
else:
|
|
self.assertFalse(diff_guard_root.check(f_locals))
|
|
counter += 1
|
|
|
|
class Bar:
|
|
def __init__(self):
|
|
self.x = 4
|
|
self.y = torch.randn(4)
|
|
|
|
bar = Bar()
|
|
|
|
def fn(x, foo, bar):
|
|
return x + foo[0] + bar.x * bar.y
|
|
|
|
x = torch.randn(4)
|
|
opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
|
|
with install_guard_manager_testing_hook(hook):
|
|
foo = (12.0, 13)
|
|
opt_fn(x, foo, bar)
|
|
|
|
foo = (10.0, 11)
|
|
opt_fn(x, foo, bar)
|
|
|
|
|
|
class TypePropagationTests(torch._dynamo.test_case.TestCase):
|
|
@torch._dynamo.config.patch(skip_tensor_guards_with_matching_dict_tags=True)
|
|
def test_basic_types(self):
|
|
class Foo:
|
|
def __init__(self):
|
|
self.x = {"a": 2}
|
|
self.y = torch.randn(4)
|
|
self.z = {}
|
|
|
|
foo = Foo()
|
|
|
|
mod = torch.nn.Linear(4, 4)
|
|
|
|
def fn(x):
|
|
return x + foo.x["a"] + foo.y + mod(x)
|
|
|
|
try:
|
|
from .utils import install_guard_manager_testing_hook
|
|
except ImportError:
|
|
from utils import install_guard_manager_testing_hook
|
|
|
|
def hook(guard_wrapper, f_locals, builder):
|
|
from torch._dynamo.source import AttrSource, DictGetItemSource, LocalSource
|
|
|
|
foo_source = LocalSource("foo")
|
|
foo_x_source = AttrSource(foo_source, "x")
|
|
|
|
self.assertTrue(builder.get(foo_source.name()) is foo)
|
|
self.assertTrue(builder.get(foo_x_source.name()) is foo.x)
|
|
|
|
# Check types of foo.x
|
|
foo_x_mgr = builder.get_guard_manager_from_source(foo_x_source)
|
|
self.assertTrue(issubclass(foo_x_mgr.get_type_of_guarded_value(), dict))
|
|
|
|
# Check types of foo.x["a"]
|
|
foo_x_a_source = DictGetItemSource(foo_x_source, "a")
|
|
foo_x_a_mgr = builder.get_guard_manager_from_source(foo_x_a_source)
|
|
self.assertTrue(foo_x_a_mgr.is_guarded_value_immutable())
|
|
|
|
# Check types of foo.y
|
|
foo_y_source = AttrSource(foo_source, "y")
|
|
foo_y_mgr = builder.get_guard_manager_from_source(foo_y_source)
|
|
self.assertTrue(foo_y_mgr.is_guarded_value_immutable())
|
|
|
|
# Check types of foo.z
|
|
foo_z_source = AttrSource(foo_source, "z")
|
|
foo_z_mgr = builder.get_guard_manager_from_source(foo_z_source)
|
|
self.assertTrue(issubclass(foo_z_mgr.get_type_of_guarded_value(), dict))
|
|
|
|
# Check types of mod
|
|
mod_source = LocalSource("mod")
|
|
mod_mgr = builder.get_guard_manager_from_source(mod_source)
|
|
self.assertTrue(
|
|
issubclass(mod_mgr.get_type_of_guarded_value(), torch.nn.Module)
|
|
)
|
|
|
|
opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
|
|
with install_guard_manager_testing_hook(hook):
|
|
opt_fn(torch.randn(4, 4))
|
|
|
|
|
|
class RecursiveDictTagTests(torch._dynamo.test_case.TestCase):
|
|
def setUp(self):
|
|
self._prev = torch._dynamo.config.use_recursive_dict_tags_for_guards
|
|
torch._dynamo.config.use_recursive_dict_tags_for_guards = True
|
|
|
|
def tearDown(self):
|
|
torch._dynamo.config.use_recursive_dict_tags_for_guards = self._prev
|
|
|
|
|
|
class TagSafetyChecks(RecursiveDictTagTests):
|
|
def setUp(self):
|
|
self._prev = torch._dynamo.config.use_recursive_dict_tags_for_guards
|
|
torch._dynamo.config.use_recursive_dict_tags_for_guards = True
|
|
|
|
def tearDown(self):
|
|
torch._dynamo.config.use_recursive_dict_tags_for_guards = self._prev
|
|
|
|
def test_immutable_tag_safe(self):
|
|
class Bar:
|
|
pass
|
|
|
|
class Foo:
|
|
def __init__(self):
|
|
self.a = Bar()
|
|
self.b = torch.randn(4)
|
|
self.c = 3
|
|
self.d = (3, 4)
|
|
self.e = (3, Bar())
|
|
|
|
foo = Foo()
|
|
|
|
def fn(x):
|
|
if foo.a:
|
|
x = torch.sin(x)
|
|
x = x * foo.b + foo.c + foo.d[0] + foo.d[1] + foo.e[0]
|
|
if foo.e[1]:
|
|
x = torch.sin(x)
|
|
return x
|
|
|
|
try:
|
|
from .utils import install_guard_manager_testing_hook
|
|
except ImportError:
|
|
from utils import install_guard_manager_testing_hook
|
|
|
|
def hook(guard_wrapper, f_locals, builder):
|
|
from torch._dynamo.source import AttrSource, LocalSource
|
|
|
|
foo_source = LocalSource("foo")
|
|
foo_mgr = builder.get_guard_manager_from_source(foo_source)
|
|
for accessor in foo_mgr.get_accessors():
|
|
if isinstance(accessor, GetAttrGuardAccessor):
|
|
self.assertTrue(
|
|
accessor.get_attr_name() in ("a", "b", "c", "d", "e")
|
|
)
|
|
|
|
# Check types of foo.a
|
|
foo_a_source = AttrSource(foo_source, "a")
|
|
foo_a_mgr = builder.get_guard_manager_from_source(foo_a_source)
|
|
self.assertFalse(foo_a_mgr.is_tag_safe())
|
|
self.assertFalse(foo_a_mgr.is_tag_safe_root())
|
|
|
|
# Check types of foo.b
|
|
foo_b_source = AttrSource(foo_source, "b")
|
|
foo_b_mgr = builder.get_guard_manager_from_source(foo_b_source)
|
|
if torch._dynamo.config.skip_tensor_guards_with_matching_dict_tags:
|
|
self.assertTrue(foo_b_mgr.is_tag_safe())
|
|
else:
|
|
self.assertFalse(foo_b_mgr.is_tag_safe())
|
|
|
|
self.assertFalse(foo_b_mgr.is_tag_safe_root())
|
|
|
|
# Check types of foo.c
|
|
foo_c_source = AttrSource(foo_source, "c")
|
|
foo_c_mgr = builder.get_guard_manager_from_source(foo_c_source)
|
|
self.assertTrue(foo_c_mgr.is_tag_safe())
|
|
self.assertFalse(foo_c_mgr.is_tag_safe_root())
|
|
|
|
# Check types of foo.d
|
|
foo_d_source = AttrSource(foo_source, "d")
|
|
foo_d_mgr = builder.get_guard_manager_from_source(foo_d_source)
|
|
self.assertTrue(foo_d_mgr.is_tag_safe())
|
|
self.assertFalse(foo_d_mgr.is_tag_safe_root())
|
|
|
|
# Check types of foo.e
|
|
foo_e_source = AttrSource(foo_source, "e")
|
|
foo_e_mgr = builder.get_guard_manager_from_source(foo_e_source)
|
|
self.assertFalse(foo_e_mgr.is_tag_safe())
|
|
self.assertFalse(foo_e_mgr.is_tag_safe_root())
|
|
|
|
opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
|
|
with install_guard_manager_testing_hook(hook):
|
|
opt_fn(torch.randn(4, 4))
|
|
|
|
def test_dict_tag_safe(self):
|
|
class Foo:
|
|
def __init__(self):
|
|
self.a = 4
|
|
|
|
foo = Foo()
|
|
terminal_dict = {
|
|
"a": 1,
|
|
}
|
|
|
|
tag_safe_dict = {
|
|
"const": 1,
|
|
"tup": (2, 3),
|
|
"nested_dict": terminal_dict,
|
|
}
|
|
|
|
tag_unsafe_dict = {
|
|
"const": 1,
|
|
"foo": foo,
|
|
}
|
|
|
|
outer_dict = {
|
|
"safe": tag_safe_dict,
|
|
"unsafe": tag_unsafe_dict,
|
|
"terminal_dict": {"a": 1},
|
|
}
|
|
|
|
def fn(x):
|
|
x = x + outer_dict["safe"]["const"]
|
|
|
|
x = x + outer_dict["safe"]["tup"][0]
|
|
x = x + outer_dict["safe"]["tup"][1]
|
|
|
|
x = x + outer_dict["safe"]["nested_dict"]["a"]
|
|
|
|
x = x + outer_dict["unsafe"]["const"]
|
|
|
|
x = x + outer_dict["unsafe"]["foo"].a
|
|
|
|
if outer_dict["terminal_dict"]:
|
|
x = torch.sin(x)
|
|
return x
|
|
|
|
try:
|
|
from .utils import install_guard_manager_testing_hook
|
|
except ImportError:
|
|
from utils import install_guard_manager_testing_hook
|
|
|
|
def hook(guard_wrapper, f_locals, builder):
|
|
from torch._dynamo.source import DictGetItemSource, LocalSource
|
|
|
|
outer_source = LocalSource("outer_dict")
|
|
|
|
# Check tagness of outer dict
|
|
outer_mgr = builder.get_guard_manager_from_source(outer_source)
|
|
self.assertFalse(outer_mgr.is_tag_safe())
|
|
self.assertFalse(outer_mgr.is_tag_safe_root())
|
|
|
|
# Check tagness of outer["safe"]
|
|
outer_safe_source = DictGetItemSource(outer_source, "safe")
|
|
outer_safe_mgr = builder.get_guard_manager_from_source(outer_safe_source)
|
|
self.assertTrue(outer_safe_mgr.is_tag_safe())
|
|
self.assertFalse(outer_safe_mgr.is_tag_safe_root())
|
|
|
|
# Check tagness of outer["unsafe"]
|
|
outer_unsafe_source = DictGetItemSource(outer_source, "unsafe")
|
|
outer_unsafe_mgr = builder.get_guard_manager_from_source(
|
|
outer_unsafe_source
|
|
)
|
|
self.assertFalse(outer_unsafe_mgr.is_tag_safe())
|
|
self.assertFalse(outer_unsafe_mgr.is_tag_safe_root())
|
|
|
|
# Check tagness of outer["terminal_dict"]
|
|
outer_terminal_source = DictGetItemSource(outer_source, "terminal_dict")
|
|
outer_terminal_mgr = builder.get_guard_manager_from_source(
|
|
outer_terminal_source
|
|
)
|
|
self.assertTrue(outer_terminal_mgr.is_tag_safe())
|
|
self.assertFalse(outer_terminal_mgr.is_tag_safe_root())
|
|
|
|
# Check tagness of outer["safe"]["nested_dict"]
|
|
outer_safe_nested_source = DictGetItemSource(
|
|
outer_safe_source, "nested_dict"
|
|
)
|
|
outer_safe_nested_mgr = builder.get_guard_manager_from_source(
|
|
outer_safe_nested_source
|
|
)
|
|
self.assertTrue(outer_safe_nested_mgr.is_tag_safe())
|
|
# This should not be marked as a root
|
|
self.assertFalse(outer_safe_nested_mgr.is_tag_safe_root())
|
|
|
|
opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
|
|
with install_guard_manager_testing_hook(hook):
|
|
opt_fn(torch.randn(4, 4))
|
|
|
|
def test_nn_module_tag_safe(self):
|
|
class Foo(torch.nn.Module):
|
|
c = 2
|
|
|
|
def __init__(self):
|
|
super().__init__()
|
|
self.a = 4
|
|
|
|
def check(self, x):
|
|
return True
|
|
|
|
def forward(self, x):
|
|
inspect.signature(self.check).parameters.items()
|
|
return x + self.a + self.c
|
|
|
|
foo = Foo()
|
|
|
|
class Env(metaclass=abc.ABCMeta): # noqa: B024
|
|
pass
|
|
|
|
class Baz(torch.nn.Module, Env):
|
|
def __init__(self):
|
|
super().__init__()
|
|
self.foo = foo
|
|
|
|
def forward(self, x):
|
|
if "Foo" in str(type(self).__mro__):
|
|
x = torch.sin(x)
|
|
return self.foo(x)
|
|
|
|
baz = Baz()
|
|
|
|
def fn(x):
|
|
x = x + baz(x)
|
|
return x
|
|
|
|
try:
|
|
from .utils import install_guard_manager_testing_hook
|
|
except ImportError:
|
|
from utils import install_guard_manager_testing_hook
|
|
|
|
def hook(guard_wrapper, f_locals, builder):
|
|
from torch._dynamo.source import LocalSource
|
|
|
|
baz_source = LocalSource("baz")
|
|
|
|
# Check tagness of baz
|
|
baz_mgr = builder.get_guard_manager_from_source(baz_source)
|
|
self.assertTrue(baz_mgr.is_tag_safe())
|
|
self.assertTrue(baz_mgr.is_tag_safe_root())
|
|
|
|
opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
|
|
with install_guard_manager_testing_hook(hook):
|
|
opt_fn(torch.randn(4, 4))
|
|
|
|
def test_nn_module_tag_overridden_getattr_safe(self):
|
|
class Baz(torch.nn.Module, metaclass=abc.ABCMeta):
|
|
def __init__(self):
|
|
super().__init__()
|
|
self.norm = 2
|
|
|
|
def __getattr__(self, key):
|
|
if key == "a":
|
|
return 5
|
|
return super().__getattr__(key)
|
|
|
|
def forward(self, x):
|
|
return x + self.a + self.norm
|
|
|
|
baz = Baz()
|
|
|
|
def fn(x):
|
|
x = x + baz(x)
|
|
return x
|
|
|
|
try:
|
|
from .utils import install_guard_manager_testing_hook
|
|
except ImportError:
|
|
from utils import install_guard_manager_testing_hook
|
|
|
|
def hook(guard_wrapper, f_locals, builder):
|
|
from torch._dynamo.source import LocalSource
|
|
|
|
baz_source = LocalSource("baz")
|
|
|
|
# Check tagness of baz
|
|
baz_mgr = builder.get_guard_manager_from_source(baz_source)
|
|
self.assertTrue(baz_mgr.is_tag_safe())
|
|
self.assertTrue(baz_mgr.is_tag_safe_root())
|
|
|
|
opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
|
|
with install_guard_manager_testing_hook(hook):
|
|
opt_fn(torch.randn(4, 4))
|
|
|
|
|
|
class RecursiveDictGuardTests(RecursiveDictTagTests):
|
|
def test_disabling(self):
|
|
class Mod(torch.nn.Module):
|
|
def __init__(self):
|
|
super().__init__()
|
|
self.a = 4
|
|
|
|
def forward(self, x):
|
|
return x + self.a
|
|
|
|
mod = Mod()
|
|
mod_to_fail = Mod()
|
|
|
|
def fn(x):
|
|
return mod(x)
|
|
|
|
x = torch.randn(4, 4)
|
|
|
|
try:
|
|
from .utils import install_guard_manager_testing_hook
|
|
except ImportError:
|
|
from utils import install_guard_manager_testing_hook
|
|
|
|
def basic_hook_test(guard_wrapper, f_locals, builder):
|
|
from torch._dynamo.source import LocalSource
|
|
|
|
mod_source = LocalSource("mod")
|
|
|
|
# Check tagness of mod
|
|
mod_mgr = builder.get_guard_manager_from_source(mod_source)
|
|
self.assertTrue(mod_mgr.is_tag_safe())
|
|
self.assertTrue(mod_mgr.is_tag_safe_root())
|
|
self.assertFalse(mod_mgr.is_recursive_dict_tag_matching_disabled())
|
|
|
|
for _ in range(10):
|
|
self.assertTrue(guard_wrapper.check({"mod": mod, "x": x}))
|
|
self.assertFalse(mod_mgr.is_recursive_dict_tag_matching_disabled())
|
|
|
|
# Let the guard pass but dict matching fail, this should add new cached entry
|
|
self.assertTrue(guard_wrapper.check({"mod": mod_to_fail, "x": x}))
|
|
self.assertFalse(mod_mgr.is_recursive_dict_tag_matching_disabled())
|
|
|
|
# Let the guard fail, this should disable dict tag optimization as well
|
|
mod_to_fail.a = 5
|
|
self.assertFalse(guard_wrapper.check({"mod": mod_to_fail, "x": x}))
|
|
self.assertTrue(mod_mgr.is_recursive_dict_tag_matching_disabled())
|
|
|
|
opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
|
|
with install_guard_manager_testing_hook(basic_hook_test):
|
|
opt_fn(x)
|
|
|
|
# Test that dict tag matching failure leads to disable of dict tag optimization
|
|
torch.compiler.reset()
|
|
mod = Mod()
|
|
mod_to_fail = Mod()
|
|
|
|
def disable_on_dict_tag_match_failure(guard_wrapper, f_locals, builder):
|
|
from torch._dynamo.source import LocalSource
|
|
|
|
mod_source = LocalSource("mod")
|
|
|
|
# Check tagness of mod
|
|
mod_mgr = builder.get_guard_manager_from_source(mod_source)
|
|
self.assertTrue(mod_mgr.is_tag_safe())
|
|
self.assertTrue(mod_mgr.is_tag_safe_root())
|
|
self.assertFalse(mod_mgr.is_recursive_dict_tag_matching_disabled())
|
|
|
|
for _ in range(10):
|
|
self.assertTrue(guard_wrapper.check({"mod": mod, "x": x}))
|
|
self.assertFalse(mod_mgr.is_recursive_dict_tag_matching_disabled())
|
|
|
|
# Change the mod attr to cause dict tag matching to fail, this still
|
|
# get the guard pass. This should disable the dict tag optimization.
|
|
mod.a = 5
|
|
mod.a = 4
|
|
self.assertTrue(guard_wrapper.check({"mod": mod, "x": x}))
|
|
self.assertTrue(mod_mgr.is_recursive_dict_tag_matching_disabled())
|
|
|
|
opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
|
|
with install_guard_manager_testing_hook(disable_on_dict_tag_match_failure):
|
|
opt_fn(x)
|
|
|
|
# Test that max size limit breach disables the dict tag optimization
|
|
torch.compiler.reset()
|
|
mod = Mod()
|
|
mod_to_fail = Mod()
|
|
|
|
def max_size_test(guard_wrapper, f_locals, builder):
|
|
from torch._dynamo.source import LocalSource
|
|
|
|
mod_source = LocalSource("mod")
|
|
|
|
# Check tagness of mod
|
|
mod_mgr = builder.get_guard_manager_from_source(mod_source)
|
|
self.assertTrue(mod_mgr.is_tag_safe())
|
|
self.assertTrue(mod_mgr.is_tag_safe_root())
|
|
self.assertFalse(mod_mgr.is_recursive_dict_tag_matching_disabled())
|
|
|
|
for _ in range(10):
|
|
self.assertTrue(guard_wrapper.check({"mod": mod, "x": x}))
|
|
self.assertFalse(mod_mgr.is_recursive_dict_tag_matching_disabled())
|
|
|
|
# Let the guard pass but dict matching fail, since cache size is set
|
|
# to 1, this would cause dict tag optimization to be disabled.
|
|
self.assertTrue(guard_wrapper.check({"mod": mod_to_fail, "x": x}))
|
|
self.assertTrue(mod_mgr.is_recursive_dict_tag_matching_disabled())
|
|
|
|
opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
|
|
with torch._dynamo.config.patch(
|
|
max_saved_pointers_for_recursive_dict_tags_check=1
|
|
):
|
|
with install_guard_manager_testing_hook(max_size_test):
|
|
opt_fn(x)
|
|
|
|
|
|
if __name__ == "__main__":
|
|
from torch._dynamo.test_case import run_tests
|
|
|
|
run_tests()
|