mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)

[dynamo] Eagerly install guards (#111415)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/111415
Approved by: https://github.com/voznesenskym
ghstack dependencies: #111306

This commit is 9664190952, committed by PyTorch MergeBot (parent 2964682490).
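Note: the core of this change is visible throughout the hunks below -- guard sets no longer ride on VariableTracker objects (`guards=...` kwargs, `add_guards`, `VariableTracker.propagate`, `self.output.guards.update(...)`) and are instead installed eagerly into the tracing context the moment they are created. A minimal sketch of the before/after pattern, using names from this diff (`install_guard`, `GuardBuilder`, `make_guard`); `source` and `SomeVariable` are placeholders:

    # Before: guards attached to the variable, merged into output.guards
    # only when (and if) the variable is used -- easy to drop or over-collect.
    vt = SomeVariable(value, source=source,
                      guards={source.make_guard(GuardBuilder.TYPE_MATCH)})
    # ... later, at each use site ...
    self.output.guards.update(vt.guards)

    # After: the guard is recorded in the tracing context immediately.
    install_guard(source.make_guard(GuardBuilder.TYPE_MATCH))
    vt = SomeVariable(value, source=source)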
@@ -3292,7 +3292,9 @@ class GraphModule(torch.nn.Module):
         cos = l_x_.cos(); l_x_ = None
         return pytree.tree_unflatten([cos], self._out_spec)
 """
-        true_guard_code = ["cast_symbool_to_symint_guardless(L['pred']) == 1"]
+        true_guard_code = [
+            "cast_symbool_to_symint_guardless(L['pred']) == 1",
+        ]
         false_guard_code = [
             "Ne(cast_symbool_to_symint_guardless(L['pred']), 1)",
             "-9223372036854775808 <= cast_symbool_to_symint_guardless(L['pred'])",
@@ -297,25 +297,25 @@ class HigherOrderOpTests(torch._dynamo.test_case.TestCase):
             actual_graph,
             """\
 class GraphModule(torch.nn.Module):
-    def forward(self, L_x_ : torch.Tensor, L_y_ : torch.Tensor, L_z_ : torch.Tensor):
-        l_x_ = L_x_
-        l_y_ = L_y_
-        l_z_ = L_z_
+    def forward(self, L_d_x_ : torch.Tensor, L_d_y_0_ : torch.Tensor, L_d_y_1_2_ : torch.Tensor):
+        l_d_x_ = L_d_x_
+        l_d_y_0_ = L_d_y_0_
+        l_d_y_1_2_ = L_d_y_1_2_
 
         wrap_body_0 = self.wrap_body_0
-        wrap = torch._higher_order_ops.wrap.wrap(wrap_body_0, l_x_, l_y_, l_z_); wrap_body_0 = l_x_ = l_y_ = l_z_ = None
+        wrap = torch._higher_order_ops.wrap.wrap(wrap_body_0, l_d_x_, l_d_y_0_, l_d_y_1_2_); wrap_body_0 = l_d_x_ = l_d_y_0_ = l_d_y_1_2_ = None
         getitem = wrap[0]; wrap = None
         return (getitem,)
 
     class GraphModule(torch.nn.Module):
-        def forward(self, l_x_, l_y_, l_z_):
-            sin = l_x_.sin(); l_x_ = None
-            cos = l_y_.cos(); l_y_ = None
+        def forward(self, l_d_x_, l_d_y_0_, l_d_y_1_2_):
+            sin = l_d_x_.sin(); l_d_x_ = None
+            cos = l_d_y_0_.cos(); l_d_y_0_ = None
             add = sin + cos; sin = cos = None
-            sin_1 = l_z_.sin(); l_z_ = None
+            sin_1 = l_d_y_1_2_.sin(); l_d_y_1_2_ = None
             sub = add - sin_1; add = sin_1 = None
             return (sub,)
-""",
+""", # NOQA: B950
         )
 
     def test_wrap_pytree_args_with_symint_constant(self):
@@ -3005,9 +3005,9 @@ class GraphModule(torch.nn.Module):
             actual,
             """\
 class GraphModule(torch.nn.Module):
-    def forward(self, L_x_ : torch.Tensor, L_y_ : torch.Tensor):
-        l_x_ = L_x_
+    def forward(self, L_y_ : torch.Tensor, L_x_ : torch.Tensor):
         child = L_y_
+        l_x_ = L_x_
 
         _check_randomness_arg = torch._functorch.vmap._check_randomness_arg('error')
         _check_randomness_arg_1 = torch._functorch.vmap._check_randomness_arg('error')
@@ -3269,16 +3269,14 @@ class GraphModule(torch.nn.Module):
             return torch.func.vmap(torch.sum, in_dims)(x)
 
         x = torch.randn(3, 3, 3, 3)
-        opt = torch.compile(wrapper_fn, backend="eager", fullgraph=False, dynamic=True)
+        cnt = CompileCounter()
+        opt = torch.compile(wrapper_fn, backend=cnt, fullgraph=False, dynamic=True)
         expected = wrapper_fn(x, 0), wrapper_fn(x, 1), wrapper_fn(x, 2)
         # Third invocation of `opt` makes `in_dims` as SymInt.
         actual = opt(x, 0), opt(x, 1), opt(x, 2)
         self.assertEqual(expected, actual)
-        self.assertEqual(len(counters["graph_break"]), 1)
-        self.assertEqual(
-            dict(counters["graph_break"]),
-            {"torch.func.vmap: in_dims is not an int or tuple variable.": 2},
-        )
+        self.assertEqual(cnt.frame_count, 3)
+        self.assertEqual(cnt.op_count, 9)
 
     def test_vmap_multiple_invocation_out_dims(self):
         counters.clear()
@@ -3287,16 +3285,14 @@ class GraphModule(torch.nn.Module):
             return torch.func.vmap(lambda x: torch.sum(x, 0), out_dims=out_dims)(x)
 
         x = torch.randn(3, 3, 3, 3)
-        opt = torch.compile(wrapper_fn, backend="eager", fullgraph=False, dynamic=True)
+        cnt = CompileCounter()
+        opt = torch.compile(wrapper_fn, backend=cnt, fullgraph=False, dynamic=True)
         expected = wrapper_fn(x, 0), wrapper_fn(x, 1), wrapper_fn(x, 2)
         # Third invocation of `opt` makes `in_dims` as SymInt.
         actual = opt(x, 0), opt(x, 1), opt(x, 2)
         self.assertEqual(expected, actual)
-        self.assertEqual(len(counters["graph_break"]), 1)
-        self.assertEqual(
-            dict(counters["graph_break"]),
-            {"torch.func.vmap: out_dims is not an int or tuple variable.": 2},
-        )
+        self.assertEqual(cnt.frame_count, 3)
+        self.assertEqual(cnt.op_count, 9)
 
     def test_vmap_new_tensor_in_body(self):
         def fn(x):
@@ -1541,7 +1541,7 @@ utils_device.CURRENT_DEVICE == None""",
         args = [torch.randn(10), 4096, np.int64(8)]
         correct = fn(*args)
         cnts = torch._dynamo.testing.CompileCounter()
-        opt_fn = torch._dynamo.optimize(cnts, dynamic=True)(fn)
+        opt_fn = torch._dynamo.optimize(cnts, dynamic=True, nopython=True)(fn)
         self.assertTrue(same(opt_fn(*args), correct))
         self.assertTrue(same(opt_fn(*args), correct))
         self.assertEqual(cnts.frame_count, 1)
@@ -814,9 +814,8 @@ class MockModule(torch.nn.Module):
 class ReproTests(torch._dynamo.test_case.TestCase):
     def test_do_paste_mask(self):
         torch._dynamo.utils.counters.clear()
-        opt__do_paste_mask = torch._dynamo.optimize(
-            torch._dynamo.testing.CompileCounter()
-        )(_do_paste_mask)
+        cnt = torch._dynamo.testing.CompileCounter()
+        opt__do_paste_mask = torch.compile(_do_paste_mask, backend=cnt)
         opt__do_paste_mask(
             torch.randn(1, 1, 28, 28),
             torch.tensor([[0.0, 1, 2, 4]]) * 1,
@@ -852,12 +851,9 @@ class ReproTests(torch._dynamo.test_case.TestCase):
             640,
             False,
         )
-        self.assertGreaterEqual(torch._dynamo.utils.counters["frames"]["ok"], 3)
-        self.assertEqual(
-            torch._dynamo.utils.counters["frames"]["total"],
-            torch._dynamo.utils.counters["frames"]["ok"] + 1,
-        )
+        # (dynamic shapes, static shapes)
+        self.assertIn(cnt.frame_count, (5, 7))
+        self.assertIn(cnt.op_count, (106, 127))
 
     def test_convert_boxes_to_pooler_format(self):
         boxes1 = [
@@ -2451,77 +2447,6 @@ class ReproTests(torch._dynamo.test_case.TestCase):
         self.assertEqual(f(x, x), opt_f(x, x))
         self.assertEqual(f(x, y), opt_f(x, y))
 
-    def test_reformer_remove_unused_args(self):
-        # This test case is very interesting. First, let's describe
-        # the bug this is testing for. The bug we fixed is twofold:
-        #
-        # - We prune GraphArgs that aren't used in the output graph.
-        #   However, sometimes it is possible for those GraphArgs to be
-        #   utilized in shape guards (you could imagine this happening if
-        #   dynamo poked some shape variables without recording them in the
-        #   graph.) If we prune those GraphArgs, we get a
-        #   "s1 not in ..." error as we can no longer codegen the
-        #   requested guards.
-        #
-        # - But in practice, Dynamo usually traces size accesses into the
-        #   graph, preventing the GraphArg from getting pruned. So how
-        #   come we were running into this in practice with hf_Reformer?
-        #   The answer is checkpointing!
-        #
-        # This brings us to the following test case. Here's what it does:
-        #
-        # 1. It traces some operations, and then checkpoints before inlining
-        #    the function call to g
-        #
-        # 2. g traces some more operations (triggering the shape guard
-        #    to be created), but then it graph breaks
-        #
-        # 3. Because you can't graph break in an inlining function, we roll
-        #    back to the outer checkpoint ("undoing" the operation that
-        #    induced the shape guard) and then immediately generate a
-        #    subgraph at that point.
-        #
-        # If we failed to checkpoint the ShapeEnv, it can still have guards
-        # from the aborted speculation, which we will then still attempt to
-        # codegen.
-        #
-        # There's an additional nuance: suppose x is used but y is not.
-        # If you create a guard like y == x * 2, you will accidentally avoid
-        # the "s1 not in ..." error, as y will get substituted with x * 2,
-        # but x is still a GraphArg (it's used) and you don't end up with
-        # the error. This is why we must show y + y == x, not vice versa.
-        # Similarly, it is also why we must not do a simple guard like x == y
-        #
-        # Can we actually demonstrate that checkpointing the ShapeEnv is
-        # necessary? It's not so easy to induce this case. Dynamo is very
-        # eager about adding locals to GraphArgs; any local that is in scope,
-        # even if it isn't used, is added to GraphArgs (see also
-        # https://github.com/pytorch/torchdynamo/issues/1925 ). So long
-        # as Dynamo eagerly guards in this way, we have an invariant that
-        # all locals are guaranteed to show up in GraphArgs before the
-        # inlining function call, in which case we will always have enough
-        # information to codegen our guards so long as we don't prune the
-        # unused GraphArgs away (and indeed, the direct fix for this bug
-        # was to make sure we use original GraphArgs). Non locals,
-        # conversely, typically are static, and so won't have guards allocated
-        # for them. That being said, there may still be a way to trigger
-        # this error.
-
-        def g(x, y):
-            r = torch.cat((y, y)) + x
-            print("foo")
-            return r
-
-        def f(x, y):
-            x = x * 3
-            return g(x, y)
-
-        opt_f = torch._dynamo.optimize("aot_eager")(f)
-
-        x = torch.randn(4)
-        y = torch.randn(2)
-        self.assertEqual(f(x, y), opt_f(x, y))
-
     def test_swin_base_tensor_attr(self):
         class Foo(torch.nn.Module):
             def __init__(self):
@@ -271,7 +271,7 @@ class SubclassTests(torch._dynamo.test_case.TestCase):
                 kwargs = {}
             return super().__torch_function__(func, types, args, kwargs)
 
-        @torch.compile(backend="eager", fullgraph=True)
+        @torch.compile(backend="eager")
         def fn(x):
             return x.sigmoid()
 
@@ -819,13 +819,6 @@ class TestNestedTensor(torch._dynamo.test_case.TestCase):
         nt3, _ = self._get_jagged_tensor(((2, 3, 4), 3), None)
         self._check_recompiles(binary, (nt1, nt2), (nt1, nt3), True)
 
-    def test_binary_recompiles_due_to_duck_sizing(self):
-        # Even though the input is unused, we still guard due to duck sizing
-        nt1, offsets = self._get_jagged_tensor(((2, 3, 4), 3), None)
-        nt2, _ = self._get_jagged_tensor(((2, 3, 4), 3), offsets)
-        nt3, _ = self._get_jagged_tensor(((2, 3, 4), 3), None)
-        self._check_recompiles(lambda nt1, nt2: nt1.sin(), (nt1, nt2), (nt1, nt3), True)
-
     # TODO: cannot parametrize this test class with device for some reason
     def _test_autograd(self, backend):
         a = torch.randn(2, 3, requires_grad=True, dtype=torch.float64)
@@ -477,8 +477,8 @@ class SubGraphTests(torch._dynamo.test_case.TestCase):
         opt_fn(v1, a, b, c)
 
         # checking here we don't create 2^n graphs
-        self.assertEqual(cnt.frame_count, 12)
-        self.assertEqual(cnt.op_count, 16)
+        self.assertEqual(cnt.frame_count, 7)
+        self.assertEqual(cnt.op_count, 10)
 
     def test_resume_with_no_grad1(self):
         def fn(a, b):
@@ -4,7 +4,7 @@ import hypothesis.strategies as st
 from hypothesis import given
 import numpy as np
 import torch
-from torch.testing._internal.common_utils import TestCase, run_tests
+from torch.testing._internal.common_utils import TestCase, run_tests, skipIfTorchDynamo
 import torch.testing._internal.hypothesis_utils as hu
 hu.assert_deadline_disabled()
 
@@ -56,6 +56,7 @@ class PruningOpTest(TestCase):
         self.assertEqual(pt_compressed_indices_map.dtype, indices_type)
 
 
+    @skipIfTorchDynamo()
     @given(
         embedding_rows=st.integers(1, 100),
         embedding_dims=st.integers(1, 100),
@@ -67,6 +68,7 @@ class PruningOpTest(TestCase):
         self._test_rowwise_prune_op(embedding_rows, embedding_dims, torch.int, weights_dtype)
 
 
+    @skipIfTorchDynamo()
     @given(
         embedding_rows=st.integers(1, 100),
         embedding_dims=st.integers(1, 100),
@@ -2530,6 +2530,7 @@ class TestSparseCSR(TestCase):
         run_test(4, 5, 4, 10, False)
         run_test(4, 4, 4, 16, True)
 
+    @skipIfTorchDynamo()
     @onlyCPU
     @dtypes(torch.float32, torch.float64, torch.bfloat16)
     @precisionOverride({torch.bfloat16: 0.01})
@@ -2894,6 +2895,7 @@ class TestSparseCSR(TestCase):
             run_test(shape, max(shape), index_dtype)
             run_test(shape, shape[0] * shape[1], index_dtype)
 
+    @skipIfTorchDynamo()
     @skipMeta
     @dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16))
     @all_sparse_compressed_layouts()
@@ -76,8 +76,6 @@ class PyCodegen:
                 self.clear_tos()
                 return
 
-        self.tx.output.guards.update(value.guards)
-
         assert isinstance(value, VariableTracker)
         output = self._output
         graph_outputs = self.graph_outputs
@@ -586,7 +586,7 @@ class GuardBuilder(GuardBuilderBase):
             f"{id(torch._dynamo.eval_frame.guarded_backend_cache.current_backend)}"
         )
         code = [
-            f"___skip_backend_check() or ___current_backend() == ___lookup_backend({backend_id})"
+            f"(___skip_backend_check() or ___current_backend() == ___lookup_backend({backend_id}))"
        ]
        self._produce_guard_code(guard, code)
 
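Note: the added parentheses matter because each guard's code fragment is later combined with others into a single boolean check, and a bare `or` fragment changes meaning under that combination (`A or B and C` parses as `A or (B and C)`). The joining below is an illustrative assumption about how fragments are consumed, not code from this diff:

    parts = [
        "___skip_backend_check() or ___current_backend() == ___lookup_backend(0)",
        "x == 1",
    ]
    unsafe = " and ".join(parts)  # the or-clause leaks into the conjunction
    safe = " and ".join(f"({p})" for p in parts)  # each fragment stays self-contained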
@@ -1366,3 +1366,19 @@ def make_dupe_guard(obj_source, dupe_source):
         # However, this should always be a sound guard to add here.
         return functools.partial(GuardBuilder.DUPLICATE_INPUT, source_b=dupe_source)
     return None
+
+
+def install_guard(*guards, skip=0):
+    """
+    Add dynamo guards to the current tracing context.
+
+    Args:
+        guards: guard(s) to add
+        skip: number of stack frames to ignore for debug stack trace
+    """
+    from torch._guards import TracingContext
+
+    add = TracingContext.get().guards_context.dynamo_guards.add
+    for guard in guards:
+        assert isinstance(guard, Guard)
+        add(guard, skip=skip + 1)
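Note: a hedged usage sketch of the new helper. Inside an active dynamo tracing context, a guard built from a `Source` is registered immediately instead of being attached to a VariableTracker (`some_source` is a placeholder for any dynamo `Source`):

    # Assumes a dynamo TracingContext is active for the current frame.
    from torch._dynamo.guards import GuardBuilder, install_guard

    def record_type_guard(some_source):
        # eagerly record a TYPE_MATCH guard keyed to this source
        install_guard(some_source.make_guard(GuardBuilder.TYPE_MATCH))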
@@ -62,7 +62,7 @@ from .exc import (
     unimplemented,
     unimplemented_with_warning,
 )
-from .guards import GuardBuilder
+from .guards import GuardBuilder, install_guard
 from .mutation_guard import is_dynamic_nn_module
 from .side_effects import SideEffects
 from .source import (
@@ -561,7 +561,11 @@ class OutputGraph(Checkpointable[OutputGraphState]):
         # FX deepcopy doesn't work for a partially created graph, so just remove new nodes
         removed_nodes = 0
         for node in reversed(list(self.graph.nodes)):
-            if node.meta["creation_timestamp"] > self.timestamp:
+            if (
+                node.meta["creation_timestamp"] > self.timestamp
+                # placeholders here may have been lazily added by existing objects
+                and node.op != "placeholder"
+            ):
                 # Erasing node alone does not remove the meta information
                 # So, remove the help tensor explicitly
                 if "example_value" in node.meta:
|
|||||||
return variables.UnspecializedNNModuleVariable(target, **options)
|
return variables.UnspecializedNNModuleVariable(target, **options)
|
||||||
|
|
||||||
options = dict(options)
|
options = dict(options)
|
||||||
options["guards"] = set(options.get("guards", []))
|
|
||||||
assert "source" in options
|
assert "source" in options
|
||||||
source = options["source"]
|
source = options["source"]
|
||||||
assert not isinstance(source, ParamBufferSource)
|
assert not isinstance(source, ParamBufferSource)
|
||||||
@@ -692,10 +695,10 @@ class OutputGraph(Checkpointable[OutputGraphState]):
             tracer = self.root_tracer
 
             if not is_constant_source(source):
-                options["guards"].add(source.make_guard(GuardBuilder.TENSOR_MATCH))
+                install_guard(source.make_guard(GuardBuilder.TENSOR_MATCH))
 
             if get_static_address_type(target) == "guarded":
-                options["guards"].add(source.make_guard(GuardBuilder.DATA_PTR_MATCH))
+                install_guard(source.make_guard(GuardBuilder.DATA_PTR_MATCH))
 
             def wrap_name(module_key):
                 assert self.param_name_to_source is not None
@@ -711,7 +714,7 @@ class OutputGraph(Checkpointable[OutputGraphState]):
         elif isinstance(target, torch.nn.Module):
             assert isinstance(target, torch.nn.Module)
 
-            options["guards"].add(source.make_guard(GuardBuilder.NN_MODULE))
+            install_guard(source.make_guard(GuardBuilder.NN_MODULE))
 
             def wrap_name(module_key):
                 return NNModuleVariable(type(target), module_key, **options)
@@ -1005,9 +1008,6 @@ class OutputGraph(Checkpointable[OutputGraphState]):
 
         assert isinstance(rv, list)
         assert isinstance(root, FakeRootModule)
-        for output in rv:
-            self.guards.update(output.guards)
-
         self.create_node(
             "output",
             "output",
@@ -54,7 +54,7 @@ from .codegen import PyCodegen
 from .current_scope_id import current_scope_id
 from .exc import ArgsMismatchError, BackendCompilerFailed, unimplemented, Unsupported
 from .funcname_cache import get_funcname
-from .guards import GuardBuilder
+from .guards import GuardBuilder, install_guard
 from .output_graph import GraphCompileReason, OutputGraph, OutputGraphState
 from .replay_record import DummyModule, ExecutionRecorder
 from .resume_execution import ContinueExecutionCache, ReenterWith
@@ -323,13 +323,11 @@ def _detect_and_normalize_assert_statement(
 def generic_jump(truth_fn: typing.Callable[[object], bool], push: bool):
     def inner(self: "InstructionTranslatorBase", inst: Instruction):
         value: VariableTracker = self.pop()
-        self.output.guards.update(value.guards)
         if (
             config.rewrite_assert_with_torch_assert
             and _detect_and_normalize_assert_statement(self, truth_fn, push)
         ):
             error_msg: VariableTracker = self.pop()
-            self.output.guards.update(error_msg.guards)
             # Skip over things like `assert True`
             if value.is_python_constant() and bool(value.as_python_constant()):
                 self.jump(inst)
@@ -419,7 +417,6 @@ def generic_jump(truth_fn: typing.Callable[[object], bool], push: bool):
             if isinstance(result, ConstantVariable) and isinstance(
                 result.value, (bool, int)
             ):
-                self.output.guards.update(result.guards)
                 if truth_fn(result.value):
                     push and self.push(value)
                     self.jump(inst)
@@ -686,9 +683,7 @@ class InstructionTranslatorBase(Checkpointable[InstructionTranslatorGraphState])
         """
         A call to some user defined function by inlining it.
         """
-        result = InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
-        self.output.guards.update(fn.guards)
-        return result
+        return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
 
     def get_line_of_code_header(self, lineno=None):
         if lineno is None:
@@ -1139,7 +1134,6 @@ class InstructionTranslatorBase(Checkpointable[InstructionTranslatorGraphState])
     def FOR_ITER(self, inst):
         it = self.pop().realize()
         if isinstance(it, (variables.ListIteratorVariable, variables.IteratorVariable)):
-            self.output.guards.update(it.guards)
             try:
                 val, next_iter = it.next_variables(self)
                 self.push(next_iter)
@@ -1233,8 +1227,6 @@ class InstructionTranslatorBase(Checkpointable[InstructionTranslatorGraphState])
         if sys.version_info >= (3, 11):
             null = self.pop()
             assert isinstance(null, NullVariable)
-        self.output.guards.update(argsvars.guards)
-        self.output.guards.update(kwargsvars.guards)
 
         if (
             isinstance(fn, GetAttrVariable)
@@ -1327,12 +1319,8 @@ class InstructionTranslatorBase(Checkpointable[InstructionTranslatorGraphState])
             ), f"Mutating module attribute {inst.argval} during export."
 
         try:
-            self.output.guards.update(
-                BuiltinVariable(setattr)
-                .call_function(
-                    self, [obj, ConstantVariable.create(inst.argval), val], {}
-                )
-                .guards
-            )
+            BuiltinVariable(setattr).call_function(
+                self, [obj, ConstantVariable.create(inst.argval), val], {}
+            )
             return
         except Unsupported as e:
@@ -1355,10 +1343,8 @@ class InstructionTranslatorBase(Checkpointable[InstructionTranslatorGraphState])
 
     def DELETE_ATTR(self, inst):
         obj = self.pop()
-        self.output.guards.update(
-            BuiltinVariable(delattr)
-            .call_function(self, [obj, ConstantVariable.create(inst.argval)], {})
-            .guards
-        )
+        BuiltinVariable(delattr).call_function(
+            self, [obj, ConstantVariable.create(inst.argval)], {}
+        )
 
     def create_call_resume_at(self, offset):
@@ -1375,8 +1361,6 @@ class InstructionTranslatorBase(Checkpointable[InstructionTranslatorGraphState])
     def STORE_SUBSCR(self, inst):
         val, obj, key = self.popn(3)
         result = obj.call_method(self, "__setitem__", [key, val], {})
-        # no result is pushed, so need to lift the guards to global
-        self.output.guards.update(result.guards)
 
     def BUILD_TUPLE(self, inst):
         items = self.popn(inst.argval)
@@ -1511,7 +1495,6 @@ class InstructionTranslatorBase(Checkpointable[InstructionTranslatorGraphState])
             obj,
             ListVariable(
                 obj.items + [v],
-                regen_guards=False,
                 **VariableTracker.propagate([obj, v]),
             ),
         )
@@ -1559,7 +1542,6 @@ class InstructionTranslatorBase(Checkpointable[InstructionTranslatorGraphState])
     def UNPACK_SEQUENCE(self, inst):
         seq = self.pop()
         if isinstance(seq, (BaseListVariable, SetVariable)):
-            self.output.guards.update(seq.guards)
             val = seq.unpack_var_sequence(self)
         elif seq.is_python_constant() and isinstance(seq, ConstantVariable):
             val = seq.unpack_var_sequence(self)
@@ -1874,8 +1856,6 @@ class InstructionTranslatorBase(Checkpointable[InstructionTranslatorGraphState])
         if isinstance(ctx, GenericContextWrappingVariable):
             self.generic_context_manager_depth += 1
 
-        self.output.guards.update(ctx.guards)
-
         exit = WithExitFunctionVariable(
             ctx,
             inst.target,
@@ -1961,9 +1941,7 @@ class InstructionTranslatorBase(Checkpointable[InstructionTranslatorGraphState])
         )
 
     def store_global_weakref(self, name, value):
-        self.output.guards.add(
-            GlobalWeakRefSource(name).make_guard(GuardBuilder.WEAKREF_ALIVE)
-        )
+        install_guard(GlobalWeakRefSource(name).make_guard(GuardBuilder.WEAKREF_ALIVE))
         if name not in self.output.global_scope:
             self.output.install_global(name, weakref.ref(value))
 
@@ -2148,67 +2126,26 @@ class InstructionTranslator(InstructionTranslatorBase):
         vars.extend(cells_and_freevars)
         cells_and_freevars_set = set(cells_and_freevars)
 
-        self.symbolic_locals = collections.OrderedDict(
-            (
-                k,
-                VariableBuilder(
-                    self,
-                    LocalSource(k, cell_or_freevar=k in cells_and_freevars_set),
-                )(f_locals[k]),
-            )
-            for k in vars
-            if k in f_locals
-        )
+        self.symbolic_locals = {
+            k: variables.LazyVariableTracker.create(
+                f_locals[k],
+                source=LocalSource(k, cell_or_freevar=k in cells_and_freevars_set),
+            )
+            for k in vars
+            if k in f_locals
+        }
         if export:
-            # export gets super confused if we never realize unused inputs
+            # export gets confused if we never realize unused inputs
             # in export mode just eagerly realize everything
             self.symbolic_locals = VariableTracker.apply(
                 lambda x: x.realize(), self.symbolic_locals
             )
 
-        self.init_local_index_guards_hack()
-
         self._freevars_ids = dict()
         for name in self.code_options["co_freevars"]:
             if name in f_locals:
                 self._freevars_ids[name] = id(f_locals[name])
 
-    def init_local_index_guards_hack(self):
-        # symbolic_locals contains the mapping from original f_locals to the
-        # Variable objects. During the Variable building phase, each object also
-        # has its associated guards. At the end, we will accumulate these
-        # guards.
-        #
-        # One way of handling these guards is to just accumulate all of them
-        # right now. However, many f_locals might not be used in the frame and
-        # thus can unnecessarily increase guard execution overhead. Therefore,
-        # we selectively update output.guards as we run the Python Bytecode
-        # instruction by instruction.
-        #
-        # An exception here is list/dict variables. Guards related to these
-        # variables have indexed access, like Tensor_match on args[0], and if
-        # args is not used in this frame, we will miss a LIST_LENGTH check like
-        # len(args) == 2. Missing the LIST_LENGTH check causes problem for the
-        # next invocation when args is not a list, and args[0] is a runtime
-        # error. Therefore, we recursively add guards for list/dict variable here.
-        for val in self.symbolic_locals.values():
-            if isinstance(
-                val, (ListIteratorVariable, BaseListVariable, ConstDictVariable)
-            ):
-                local_guards = VariableTracker.propagate(val)["guards"]
-                index_guards = [
-                    guard
-                    for guard in local_guards
-                    if guard.create_fn
-                    in (
-                        GuardBuilder.LIST_LENGTH,
-                        GuardBuilder.DICT_KEYS,
-                        GuardBuilder.ODICT_KEYS,
-                        GuardBuilder.TUPLE_ITERATOR_LEN,
-                    )
-                ]
-                self.output.guards.update(index_guards)
-
     def run(self):
         super().run()
 
@@ -2661,7 +2598,6 @@ class InliningGeneratorInstructionTranslator(InliningInstructionTranslator):
             if isinstance(
                 tos, (variables.ListIteratorVariable, variables.IteratorVariable)
             ):
-                self.output.guards.update(tos.guards)
                 try:
                     val, next_iter = tos.next_variables(self)
                     self.push(val)
@@ -1,12 +1,12 @@
 import collections
 from enum import Enum
-from typing import Any, Callable, Dict, List, Optional, Set
+from typing import Any, Callable, Dict, List
 
 from .. import variables
 from ..current_scope_id import current_scope_id
 from ..exc import unimplemented
 from ..source import AttrSource, Source
-from ..utils import dict_values, identity, istype, odict_values
+from ..utils import identity, istype
 
 
 class MutableLocalSource(Enum):
@@ -154,21 +154,8 @@ class VariableTracker(metaclass=VariableTrackerMeta):
 
     @staticmethod
     def propagate(*vars: List[List["VariableTracker"]]):
-        """Combine the guards from many VariableTracker into **kwargs for a new instance"""
-        guards = set()
-
-        def visit(var):
-            if type(var) in (list, tuple, dict_values, odict_values):
-                for i in var:
-                    visit(i)
-            else:
-                assert isinstance(var, VariableTracker), typestr(var)
-                guards.update(var.guards)
-
-        visit(vars)
-        return {
-            "guards": guards,
-        }
+        # TODO(jansel): delete this function
+        return {}
 
     def clone(self, **kwargs):
         """Shallow copy with some (optional) changes"""
@@ -246,22 +233,8 @@ class VariableTracker(metaclass=VariableTrackerMeta):
             cache[idx] = (result, value)
         return result
 
-    def add_guard(self, guard):
-        return self.clone(guards=set.union(self.guards, {guard}))
-
-    def add_guards(self, guards):
-        if guards is None:
-            return self
-        assert isinstance(guards, set)
-        return self.clone(guards=set.union(self.guards, guards))
-
     def add_options(self, options, *more):
-        if more:
-            return self.add_options(options).add_options(*more)
-        if isinstance(options, VariableTracker):
-            return self.add_guards(options.guards)
-        assert isinstance(options, dict)
-        return self.add_guards(options.get("guards", set()))
+        return self
 
     def __str__(self):
         return f"{self.__class__.__name__}()"
@@ -283,13 +256,6 @@ class VariableTracker(metaclass=VariableTrackerMeta):
         except NotImplementedError:
             return False
 
-    def can_make_guard(self):
-        try:
-            self.make_guard(None)
-            return True
-        except NotImplementedError:
-            return False
-
     def make_guard(self, fn):
         if self.source:
             return self.source.make_guard(fn)
@@ -380,6 +346,10 @@ class VariableTracker(metaclass=VariableTrackerMeta):
         """Used by LazyVariableTracker to build the real VariableTracker"""
         return self
 
+    def recursive_realize(self):
+        """Realize all objects under this"""
+        return VariableTracker.apply(lambda x: x.realize(), self)
+
     def unwrap(self) -> "VariableTracker":
         """Used by LazyVariableTracker to return the real VariableTracker if it already exists"""
         return self
@@ -391,14 +361,12 @@ class VariableTracker(metaclass=VariableTrackerMeta):
     def __init__(
         self,
         *,
-        guards: Optional[Set] = None,
         source: Source = None,
         mutable_local: MutableLocal = None,
         user_code_variable_name: str = None,
         parents_tracker: ParentsTracker = None,
     ):
         super().__init__()
-        self.guards = guards or set()
         self.source = source
         self.mutable_local = mutable_local
         self.user_code_variable_name = user_code_variable_name
|
|||||||
@ -44,7 +44,7 @@ from ..allowed_functions import (
|
|||||||
|
|
||||||
from ..device_interface import device_interfaces
|
from ..device_interface import device_interfaces
|
||||||
from ..exc import InternalTorchDynamoError, unimplemented
|
from ..exc import InternalTorchDynamoError, unimplemented
|
||||||
from ..guards import GuardBuilder, make_dupe_guard
|
from ..guards import GuardBuilder, install_guard, make_dupe_guard
|
||||||
from ..side_effects import SideEffects
|
from ..side_effects import SideEffects
|
||||||
from ..source import (
|
from ..source import (
|
||||||
AttrSource,
|
AttrSource,
|
||||||
@@ -198,6 +198,7 @@ class GraphArg:
 
     def erase(self):
         self._example = None
+        self.example_strong_ref = None
 
     def __eq__(self, other):
         return self.source.name() == other.source.name()
@@ -231,9 +232,7 @@ class VariableBuilder:
             side_effect_result = self.tx.output.side_effects[value]
             dup_guard = make_dupe_guard(self.source, side_effect_result.source)
             if dup_guard:
-                side_effect_result = side_effect_result.add_guards(
-                    self.make_guards(dup_guard)
-                )
+                self.install_guards(dup_guard)
             return side_effect_result
         vt = self._wrap(value).clone(**self.options())
         if self._can_lift_attrs_to_inputs(vt):
@@ -272,14 +271,15 @@ class VariableBuilder:
     def options(self):
         return {"source": self.get_source()}
 
-    def make_guards(self, *guards):
+    def install_guards(self, *guards):
         source = self.get_source()
         if (
             isinstance(source, ConstantSource)
             or source.guard_source() == GuardSource.CONSTANT
         ):
             return None
-        return {source.make_guard(guard) for guard in guards}
+        install_guard(*[source.make_guard(guard) for guard in guards], skip=1)
+        return {}
 
     @classmethod
     @functools.lru_cache(None)
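Note: returning `{}` here is what lets call sites mechanically swap `guards=self.make_guards(...)` for `**self.install_guards(...)`: splatting an empty dict contributes no keyword arguments, while the guards are installed as a side effect. A standalone illustration of the splat behavior (stub names, not dynamo code):

    def install_guards_stub(*guards):
        print(f"installed {len(guards)} guard(s)")  # stands in for install_guard()
        return {}

    def make_variable(value, source=None):
        return (value, source)

    # **{} adds nothing, so the constructor call is otherwise unchanged
    vt = make_variable("v", source="src", **install_guards_stub("TYPE_MATCH"))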
@@ -330,7 +330,7 @@ class VariableBuilder:
             lambda self, value: LambdaVariable(
                 InspectSignatureVariable.create,
                 source=self.source,
-                guards=self.make_guards(GuardBuilder.FUNCTION_MATCH),
+                **self.install_guards(GuardBuilder.FUNCTION_MATCH),
             ),
         ),
         (comptime, lambda self, value: ComptimeVariable()),
@@ -339,7 +339,7 @@ class VariableBuilder:
             lambda self, value: LambdaVariable(
                 _dataclasses_fields_lambda,
                 source=self.source,
-                guards=self.make_guards(GuardBuilder.FUNCTION_MATCH),
+                **self.install_guards(GuardBuilder.FUNCTION_MATCH),
             ),
         ),
         (
@@ -347,7 +347,7 @@ class VariableBuilder:
             lambda self, value: TorchVariable(
                 value,
                 source=self.source,
-                guards=self.make_guards(GuardBuilder.FUNCTION_MATCH),
+                **self.install_guards(GuardBuilder.FUNCTION_MATCH),
             ),
         ),
     ]
@@ -375,8 +375,6 @@ class VariableBuilder:
             class Autotuner:
                 pass
 
-        make_guards = self.make_guards
-
         # Handle exact type() match
         type_dispatch = self._type_dispatch().get(type(value))
         if type_dispatch is not None:
@@ -400,13 +398,13 @@ class VariableBuilder:
             return self.wrap_listlike(value)
 
         elif value is torch.utils._pytree.SUPPORTED_NODES:
+            # For SUPPORTED_NODES, we guard on the dictionary version (PEP509)
+            # under the assumption that the values themselves don't change.
+            self.install_guards(GuardBuilder.DICT_VERSION)
             result = {
                 k: UserDefinedObjectVariable(
                     value[k],
                     source=GetItemSource(self.get_source(), k),
-                    # For SUPPORTED_NODES, we guard on the dictionary version (PEP509)
-                    # under the assumption that the values themselves don't change.
-                    guards=self.make_guards(GuardBuilder.DICT_VERSION),
                 )
                 for k in value.keys()
             }
@@ -429,9 +427,9 @@ class VariableBuilder:
                 # Why is this OK for (specialized) nnmodules? We set up a setattr hook
                 # to check for module property mutations, which does a reasonable,
                 # but not completely secure job ensuring a property wasn't changed.
-                guards = self.make_guards(GuardBuilder.BOOL_FALSE)
+                self.install_guards(GuardBuilder.BOOL_FALSE)
             else:
-                guards = self.make_guards(GuardBuilder.DICT_KEYS)
+                self.install_guards(GuardBuilder.DICT_KEYS)
 
             # store key variables in global location for reconstruction
             for key in value.keys():
@@ -448,7 +446,7 @@ class VariableBuilder:
                 k: LazyVariableTracker.create(
                     value[k],
                     source=GetItemSource(self.get_source(), index_source(k)),
-                ).add_guards(guards)
+                )
                 for k in value.keys()
             }
 
@@ -457,10 +455,9 @@ class VariableBuilder:
                     result,
                     type(value),
                     self._wrap(value.default_factory),
-                    guards=guards,
                 )
             else:
-                result = ConstDictVariable(result, type(value), guards=guards)
+                result = ConstDictVariable(result, type(value))
 
             return self.tx.output.side_effects.track_dict(self.source, value, result)
         elif isinstance(value, torch.nn.Module):
@@ -472,23 +469,14 @@ class VariableBuilder:
         ):
             # For frozenset, we can guard by object ID instead of value
             # equality, this allows us to handle non-literal values
-            return ConstantVariable.create(
-                value=value,
-                source=self.source,
-                guards=make_guards(GuardBuilder.ID_MATCH),
-            )
+            self.install_guards(GuardBuilder.ID_MATCH)
+            return ConstantVariable.create(value=value, source=self.source)
         elif isinstance(value, enum.Enum):
-            return EnumVariable(
-                value=value,
-                source=self.source,
-                guards=make_guards(GuardBuilder.ID_MATCH),
-            )
+            self.install_guards(GuardBuilder.ID_MATCH)
+            return EnumVariable(value=value, source=self.source)
         elif is_builtin_callable(value):
-            return BuiltinVariable(
-                value,
-                source=self.source,
-                guards=make_guards(GuardBuilder.BUILTIN_MATCH),
-            )
+            self.install_guards(GuardBuilder.BUILTIN_MATCH)
+            return BuiltinVariable(value, source=self.source)
         elif is_utils_checkpoint(value):
             return build_checkpoint_variable(source=self.source)
         elif isinstance(value, functools.partial):
@@ -509,52 +497,50 @@ class VariableBuilder:
                     self.tx, GetItemSource(keywords_source, k)
                 )(v)
 
-            guards = {
+            install_guard(
                 self.get_source().make_guard(GuardBuilder.TYPE_MATCH),
                 keywords_source.make_guard(GuardBuilder.DICT_KEYS),
                 args_source.make_guard(GuardBuilder.LIST_LENGTH),
-            }
-
-            return FunctoolsPartialVariable(
-                func_obj, args, keywords, original=value, guards=guards
             )
+            return FunctoolsPartialVariable(func_obj, args, keywords, original=value)
         elif is_typing(value):
             # typing.List, typing.Mapping, etc.
+            self.install_guards(GuardBuilder.ID_MATCH)
             return TypingVariable(
                 value,
                 source=self.source,
-                guards=make_guards(GuardBuilder.ID_MATCH),
             )
         elif np is not None and isinstance(value, np.generic):
             # numpy array scalars: convert to 0D arrays
             return self.wrap_numpy_ndarray(np.asarray(value))
         elif is_numpy(value):
             assert np
-            return NumpyVariable(
-                value,
-                source=self.source,
-                guards=make_guards(
-                    GuardBuilder.FUNCTION_MATCH
-                    if callable(value)
-                    else GuardBuilder.TYPE_MATCH
-                ),
-            )
+            self.install_guards(
+                GuardBuilder.FUNCTION_MATCH
+                if callable(value)
+                else GuardBuilder.TYPE_MATCH
+            )
+            return NumpyVariable(value, source=self.source)
         # NB: These can't be put in type_dispatch, they have to run later
         elif CollectiveFunctionRewriteVariable.can_rewrite(value):
+            self.install_guards(GuardBuilder.FUNCTION_MATCH)
             return CollectiveFunctionRewriteVariable.create(
                 self.tx,
                 value,
                 source=self.source,
-                guards=make_guards(GuardBuilder.FUNCTION_MATCH),
             )
         elif istype(value, torch.autograd.function.FunctionMeta):
+            self.install_guards(GuardBuilder.FUNCTION_MATCH)
             return AutogradFunctionVariable(
                 value,
                 source=self.source,
-                guards=make_guards(GuardBuilder.FUNCTION_MATCH),
             )
         elif isinstance(value, torch.autograd.function.FunctionCtx):
             saved_tensors_source = AttrSource(self.source, "saved_tensors")
+            install_guard(
+                self.source.make_guard(GuardBuilder.TYPE_MATCH),
+                saved_tensors_source.make_guard(GuardBuilder.LIST_LENGTH),
+            )
             saved_tensors = [
                 VariableBuilder(self.tx, GetItemSource(saved_tensors_source, n))(v)
                 for n, v in enumerate(value.saved_tensors)
@@ -565,8 +551,6 @@ class VariableBuilder:
                 AutogradFunctionContextVariable(
                     value,
                     source=self.source,
-                    guards=make_guards(GuardBuilder.TYPE_MATCH)
-                    | {saved_tensors_source.make_guard(GuardBuilder.LIST_LENGTH)},
                     saved_tensors=SavedTensorBox(saved_tensors),
                 ),
             )
@@ -579,53 +563,43 @@ class VariableBuilder:
             and value == getattr(value.__self__, "apply", None)
         ):
             # handle aliased autograd function `apply` calls
+            self.install_guards(GuardBuilder.FUNCTION_MATCH)
             return GetAttrVariable(
-                AutogradFunctionVariable(
-                    value.__self__,
-                    source=self.source,
-                    guards=make_guards(GuardBuilder.FUNCTION_MATCH),
-                ),
+                AutogradFunctionVariable(value.__self__, source=self.source),
                 "apply",
             )
         elif np and isinstance(value, np.number):
             return self.wrap_unspecialized_primitive(value)
         elif DataClassVariable.is_matching_object(value):
-            return DataClassVariable.wrap(self, value).add_guards(
-                make_guards(GuardBuilder.TYPE_MATCH)
-            )
+            self.install_guards(GuardBuilder.TYPE_MATCH)
+            return DataClassVariable.wrap(self, value)
         elif HFPretrainedConfigVariable.is_matching_object(value):
-            return HFPretrainedConfigVariable(
-                value, guards=make_guards(GuardBuilder.TYPE_MATCH)
-            )
+            self.install_guards(GuardBuilder.TYPE_MATCH)
+            return HFPretrainedConfigVariable(value)
        elif isinstance(value, HigherOrderOperator):
-            return TorchHigherOrderOperatorVariable.make(
-                value,
-                source=self.source,
-                guards=self.make_guards(
-                    GuardBuilder.TYPE_MATCH, GuardBuilder.NAME_MATCH
-                ),
-            )
+            self.install_guards(GuardBuilder.TYPE_MATCH, GuardBuilder.NAME_MATCH)
+            return TorchHigherOrderOperatorVariable.make(value, source=self.source)
         elif type(value).__name__ == "builtin_function_or_method" and isinstance(
             value.__self__, torch_special_class_types
         ):
+            self.install_guards(GuardBuilder.FUNCTION_MATCH)
             return TorchVariable(
                 value,
-                guards=make_guards(GuardBuilder.FUNCTION_MATCH),
             )
         elif isinstance(value, _StreamBase):
+            self.install_guards(GuardBuilder.ID_MATCH)
             return StreamVariable(
                 None,
                 value,
                 value.device.type,
                 source=self.source,
-                guards=make_guards(GuardBuilder.ID_MATCH),
             )
         elif isinstance(value, _EventBase):
+            self.install_guards(GuardBuilder.ID_MATCH)
             return EventVariable(
                 None,
                 value,
                 source=self.source,
-                guards=make_guards(GuardBuilder.ID_MATCH),
             )
         elif (
             isinstance(value, torch._C._TensorMeta)
@ -636,55 +610,36 @@ class VariableBuilder:
|
|||||||
istype(value, contextlib.nullcontext)
|
istype(value, contextlib.nullcontext)
|
||||||
and inspect.getattr_static(value, "enter_result", None) is None
|
and inspect.getattr_static(value, "enter_result", None) is None
|
||||||
):
|
):
|
||||||
return NullContextVariable(
|
# TODO(jansel): I think this can be TYPE_MATCH
|
||||||
source=self.source,
|
self.install_guards(GuardBuilder.FUNCTION_MATCH)
|
||||||
guards=make_guards(GuardBuilder.FUNCTION_MATCH),
|
return NullContextVariable(source=self.source)
|
||||||
)
|
|
||||||
elif KeyedJaggedTensorVariable.is_matching_object(value):
|
elif KeyedJaggedTensorVariable.is_matching_object(value):
|
||||||
result = KeyedJaggedTensorVariable(
|
self.install_guards(GuardBuilder.TYPE_MATCH)
|
||||||
value,
|
result = KeyedJaggedTensorVariable(value, source=self.source)
|
||||||
source=self.source,
|
|
||||||
guards=self.make_guards(GuardBuilder.TYPE_MATCH),
|
|
||||||
)
|
|
||||||
# TODO: this doing it manually is bad
|
# TODO: this doing it manually is bad
|
||||||
return self.tx.output.side_effects.track_object_existing(
|
return self.tx.output.side_effects.track_object_existing(
|
||||||
self.source, value, result
|
self.source, value, result
|
||||||
)
|
)
|
||||||
elif isinstance(value, torch.optim.Optimizer):
|
elif isinstance(value, torch.optim.Optimizer):
|
||||||
return OptimizerVariable(
|
self.install_guards(GuardBuilder.TYPE_MATCH)
|
||||||
value,
|
return OptimizerVariable(value, source=self.source)
|
||||||
source=self.source,
|
|
||||||
guards=self.make_guards(GuardBuilder.TYPE_MATCH),
|
|
||||||
)
|
|
||||||
elif ProcessGroupVariable.is_process_group(value):
|
elif ProcessGroupVariable.is_process_group(value):
|
||||||
return ProcessGroupVariable(
|
self.install_guards(GuardBuilder.ID_MATCH)
|
||||||
value,
|
return ProcessGroupVariable(value, source=self.source)
|
||||||
source=self.source,
|
|
||||||
guards=self.make_guards(GuardBuilder.ID_MATCH),
|
|
||||||
)
|
|
||||||
elif DeviceMeshVariable.is_device_mesh(value):
|
elif DeviceMeshVariable.is_device_mesh(value):
|
||||||
# TODO: see if we need to add custom guard instead
|
# TODO: see if we need to add custom guard instead of a simple ID_MATCH
|
||||||
# of a simple ID_MATCH
|
self.install_guards(GuardBuilder.ID_MATCH)
|
||||||
return DeviceMeshVariable(
|
return DeviceMeshVariable(value, source=self.source)
|
||||||
value,
|
|
||||||
source=self.source,
|
|
||||||
guards=self.make_guards(GuardBuilder.ID_MATCH),
|
|
||||||
)
|
|
||||||
elif PlacementClassVariable.is_placement_type(value):
|
elif PlacementClassVariable.is_placement_type(value):
|
||||||
# TODO: see if we need to add custom guard instead
|
# TODO: see if we need to add custom guard instead of a simple ID_MATCH
|
||||||
# of a simple ID_MATCH
|
self.install_guards(GuardBuilder.ID_MATCH)
|
||||||
return PlacementClassVariable(
|
return PlacementClassVariable(value, source=self.source)
|
||||||
value,
|
|
||||||
source=self.source,
|
|
||||||
guards=make_guards(GuardBuilder.ID_MATCH),
|
|
||||||
)
|
|
||||||
elif PlacementVariable.is_placement(value):
|
elif PlacementVariable.is_placement(value):
|
||||||
# TODO: see if we need to add custom guard instead
|
# TODO: see if we need to add custom guard instead of a simple ID_MATCH
|
||||||
# of a simple ID_MATCH
|
self.install_guards(GuardBuilder.ID_MATCH)
|
||||||
return PlacementVariable(
|
return PlacementVariable(
|
||||||
value,
|
value,
|
||||||
source=self.source,
|
source=self.source,
|
||||||
guards=make_guards(GuardBuilder.ID_MATCH),
|
|
||||||
)
|
)
|
||||||
elif isinstance(value, torch.SymBool):
|
elif isinstance(value, torch.SymBool):
|
||||||
# Note: the idea here is to re-use the infra we've built for SymInt by simulating the
|
# Note: the idea here is to re-use the infra we've built for SymInt by simulating the
|
||||||
@ -727,12 +682,12 @@ class VariableBuilder:
|
|||||||
new_symint == 1,
|
new_symint == 1,
|
||||||
)
|
)
|
||||||
elif isinstance(value, (JITFunction, Autotuner)):
|
elif isinstance(value, (JITFunction, Autotuner)):
|
||||||
|
self.install_guards(GuardBuilder.ID_MATCH)
|
||||||
return TritonKernelVariable(
|
return TritonKernelVariable(
|
||||||
value,
|
value,
|
||||||
None, # No kernel idx provided
|
None, # No kernel idx provided
|
||||||
None, # No grid provided
|
None, # No grid provided
|
||||||
source=self.source,
|
source=self.source,
|
||||||
guards=make_guards(GuardBuilder.ID_MATCH),
|
|
||||||
)
|
)
|
||||||
elif trace_rules.lookup(value) is not None:
|
elif trace_rules.lookup(value) is not None:
|
||||||
return trace_rules.lookup(value).create_with_source(
|
return trace_rules.lookup(value).create_with_source(
|
||||||
@ -741,10 +696,10 @@ class VariableBuilder:
|
|||||||
elif is_allowed(value):
|
elif is_allowed(value):
|
||||||
if is_user_defined_allowed(value):
|
if is_user_defined_allowed(value):
|
||||||
self.tx.output.has_user_defined_allowed_in_graph = True
|
self.tx.output.has_user_defined_allowed_in_graph = True
|
||||||
|
self.install_guards(GuardBuilder.FUNCTION_MATCH)
|
||||||
return TorchVariable(
|
return TorchVariable(
|
||||||
value,
|
value,
|
||||||
source=self.source,
|
source=self.source,
|
||||||
guards=make_guards(GuardBuilder.FUNCTION_MATCH),
|
|
||||||
)
|
)
|
||||||
elif (
|
elif (
|
||||||
istype(value, (type, types.FunctionType))
|
istype(value, (type, types.FunctionType))
|
||||||
@ -752,17 +707,17 @@ class VariableBuilder:
|
|||||||
and not inspect.getattr_static(value, "_torchdynamo_inline", False)
|
and not inspect.getattr_static(value, "_torchdynamo_inline", False)
|
||||||
and not inspect.getattr_static(value, "__script_if_tracing_wrapper", False)
|
and not inspect.getattr_static(value, "__script_if_tracing_wrapper", False)
|
||||||
):
|
):
|
||||||
|
self.install_guards(GuardBuilder.FUNCTION_MATCH)
|
||||||
return SkipFilesVariable(
|
return SkipFilesVariable(
|
||||||
value,
|
value,
|
||||||
skipfiles.check_verbose(value, allow_torch=True).reason,
|
skipfiles.check_verbose(value, allow_torch=True).reason,
|
||||||
source=self.source,
|
source=self.source,
|
||||||
guards=make_guards(GuardBuilder.FUNCTION_MATCH),
|
|
||||||
)
|
)
|
||||||
elif istype(value, (types.FunctionType, torch.jit.ScriptFunction)):
|
elif istype(value, (types.FunctionType, torch.jit.ScriptFunction)):
|
||||||
|
self.install_guards(GuardBuilder.FUNCTION_MATCH)
|
||||||
return UserFunctionVariable(
|
return UserFunctionVariable(
|
||||||
value,
|
value,
|
||||||
source=self.source,
|
source=self.source,
|
||||||
guards=make_guards(GuardBuilder.FUNCTION_MATCH),
|
|
||||||
)
|
)
|
||||||
elif isinstance(value, types.MethodType) and isinstance(
|
elif isinstance(value, types.MethodType) and isinstance(
|
||||||
value.__self__, torch.nn.Module
|
value.__self__, torch.nn.Module
|
||||||
@ -784,40 +739,33 @@ class VariableBuilder:
|
|||||||
assert self_obj and isinstance(
|
assert self_obj and isinstance(
|
||||||
self_obj, VariableTracker
|
self_obj, VariableTracker
|
||||||
), "Failed to produce a valid self obj"
|
), "Failed to produce a valid self obj"
|
||||||
|
self.install_guards(GuardBuilder.FUNCTION_MATCH)
|
||||||
return UserMethodVariable(
|
return UserMethodVariable(
|
||||||
value.__func__,
|
value.__func__,
|
||||||
self_obj,
|
self_obj,
|
||||||
source=self.source,
|
source=self.source,
|
||||||
guards=make_guards(GuardBuilder.FUNCTION_MATCH),
|
|
||||||
)
|
)
|
||||||
elif istype(value, (types.ModuleType, replay_record.DummyModule)):
|
elif istype(value, (types.ModuleType, replay_record.DummyModule)):
|
||||||
|
self.install_guards(GuardBuilder.FUNCTION_MATCH)
|
||||||
return PythonModuleVariable(
|
return PythonModuleVariable(
|
||||||
value,
|
value,
|
||||||
source=self.source,
|
source=self.source,
|
||||||
guards=make_guards(GuardBuilder.PYMODULE_MATCH),
|
|
||||||
)
|
)
|
||||||
elif isinstance(value, types.GetSetDescriptorType):
|
elif isinstance(value, types.GetSetDescriptorType):
|
||||||
return GetSetDescriptorVariable(
|
self.install_guards(GuardBuilder.FUNCTION_MATCH)
|
||||||
value, guards=self.make_guards(GuardBuilder.FUNCTION_MATCH)
|
return GetSetDescriptorVariable(value)
|
||||||
)
|
|
||||||
elif isinstance(value, types.MethodWrapperType):
|
elif isinstance(value, types.MethodWrapperType):
|
||||||
return MethodWrapperVariable(
|
self.install_guards(GuardBuilder.FUNCTION_MATCH)
|
||||||
value,
|
return MethodWrapperVariable(value, source=self.source)
|
||||||
source=self.source,
|
|
||||||
guards=self.make_guards(GuardBuilder.FUNCTION_MATCH),
|
|
||||||
)
|
|
||||||
elif issubclass(type(value), type):
|
elif issubclass(type(value), type):
|
||||||
|
self.install_guards(GuardBuilder.FUNCTION_MATCH)
|
||||||
return UserDefinedClassVariable(
|
return UserDefinedClassVariable(
|
||||||
value,
|
value,
|
||||||
source=self.source,
|
source=self.source,
|
||||||
guards=make_guards(GuardBuilder.FUNCTION_MATCH),
|
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
result = UserDefinedObjectVariable(
|
self.install_guards(GuardBuilder.TYPE_MATCH)
|
||||||
value,
|
result = UserDefinedObjectVariable(value, source=self.source)
|
||||||
source=self.source,
|
|
||||||
guards=self.make_guards(GuardBuilder.TYPE_MATCH),
|
|
||||||
)
|
|
||||||
if not SideEffects.cls_supports_mutation_side_effects(type(value)):
|
if not SideEffects.cls_supports_mutation_side_effects(type(value)):
|
||||||
# don't allow STORE_ATTR mutation with custom __setattr__
|
# don't allow STORE_ATTR mutation with custom __setattr__
|
||||||
return result
|
return result
|
||||||
@ -857,36 +805,32 @@ class VariableBuilder:
|
|||||||
def wrap_listlike(self, value: Union[tuple, list, odict_values, NamedTuple]):
|
def wrap_listlike(self, value: Union[tuple, list, odict_values, NamedTuple]):
|
||||||
# One can index a tensor with a list/tuple. Therefore, we need to
|
# One can index a tensor with a list/tuple. Therefore, we need to
|
||||||
# have a stricter match.
|
# have a stricter match.
|
||||||
guards = self.make_guards(GuardBuilder.LIST_LENGTH)
|
self.install_guards(GuardBuilder.LIST_LENGTH)
|
||||||
|
|
||||||
for item in value:
|
for item in value:
|
||||||
if item is value:
|
if item is value:
|
||||||
unimplemented("list elements are pointing to the list itself")
|
unimplemented("list elements are pointing to the list itself")
|
||||||
|
|
||||||
output = [
|
output = [
|
||||||
VariableBuilder(self.tx, GetItemSource(self.get_source(), i))(
|
VariableBuilder(self.tx, GetItemSource(self.get_source(), i))(item)
|
||||||
item
|
|
||||||
).add_guards(guards)
|
|
||||||
for i, item in enumerate(value)
|
for i, item in enumerate(value)
|
||||||
]
|
]
|
||||||
result = BaseListVariable.cls_for_instance(value)(
|
result = BaseListVariable.cls_for_instance(value)(
|
||||||
output, mutable_local=MutableLocal(), guards=guards
|
output, mutable_local=MutableLocal()
|
||||||
)
|
)
|
||||||
if istype(value, list):
|
if istype(value, list):
|
||||||
return self.tx.output.side_effects.track_list(self.source, value, result)
|
return self.tx.output.side_effects.track_list(self.source, value, result)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
def wrap_tuple_iterator(self, value: tuple_iterator):
|
def wrap_tuple_iterator(self, value: tuple_iterator):
|
||||||
guards = self.make_guards(GuardBuilder.TUPLE_ITERATOR_LEN)
|
self.install_guards(GuardBuilder.TUPLE_ITERATOR_LEN)
|
||||||
output = [
|
output = [
|
||||||
VariableBuilder(self.tx, TupleIteratorGetItemSource(self.get_source(), i))(
|
VariableBuilder(self.tx, TupleIteratorGetItemSource(self.get_source(), i))(
|
||||||
tuple_iterator_getitem(value, i)
|
tuple_iterator_getitem(value, i)
|
||||||
).add_guards(guards)
|
)
|
||||||
for i in range(tuple_iterator_len(value))
|
for i in range(tuple_iterator_len(value))
|
||||||
]
|
]
|
||||||
return TupleIteratorVariable(
|
return TupleIteratorVariable(output, mutable_local=MutableLocal())
|
||||||
output, mutable_local=MutableLocal(), guards=guards
|
|
||||||
)
|
|
||||||
|
|
||||||
def wrap_slice_range(self, value: Union[slice, range]):
|
def wrap_slice_range(self, value: Union[slice, range]):
|
||||||
items = [
|
items = [
|
||||||
@ -896,21 +840,20 @@ class VariableBuilder:
|
|||||||
for k in ("start", "stop", "step")
|
for k in ("start", "stop", "step")
|
||||||
]
|
]
|
||||||
if isinstance(value, slice):
|
if isinstance(value, slice):
|
||||||
return SliceVariable(
|
self.install_guards(GuardBuilder.TYPE_MATCH)
|
||||||
items, guards=self.make_guards(GuardBuilder.TYPE_MATCH)
|
return SliceVariable(items)
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
return RangeVariable(
|
# TODO(jansel): I think this can be TYPE_MATCH
|
||||||
items, guards=self.make_guards(GuardBuilder.EQUALS_MATCH)
|
self.install_guards(GuardBuilder.EQUALS_MATCH)
|
||||||
)
|
return RangeVariable(items)
|
||||||
|
|
||||||
def wrap_module(self, value: torch.nn.Module):
|
def wrap_module(self, value: torch.nn.Module):
|
||||||
from ..eval_frame import OptimizedModule
|
from ..eval_frame import OptimizedModule
|
||||||
|
|
||||||
if istype(value, OptimizedModule):
|
if istype(value, OptimizedModule):
|
||||||
guards = self.make_guards(GuardBuilder.TYPE_MATCH)
|
self.install_guards(GuardBuilder.TYPE_MATCH)
|
||||||
self.source = AttrSource(self.source, "_orig_mod")
|
self.source = AttrSource(self.source, "_orig_mod")
|
||||||
return self.wrap_module(value._orig_mod).add_guards(guards)
|
return self.wrap_module(value._orig_mod)
|
||||||
|
|
||||||
if (
|
if (
|
||||||
isinstance(value, (torch.nn.RNN, torch.nn.GRU, torch.nn.LSTM))
|
isinstance(value, (torch.nn.RNN, torch.nn.GRU, torch.nn.LSTM))
|
||||||
@ -919,9 +862,8 @@ class VariableBuilder:
|
|||||||
unimplemented("TorchDynamo purposely graph breaks on RNN, GRU, LSTMs")
|
unimplemented("TorchDynamo purposely graph breaks on RNN, GRU, LSTMs")
|
||||||
if mutation_guard.is_dynamic_nn_module(value):
|
if mutation_guard.is_dynamic_nn_module(value):
|
||||||
# created dynamically, don't specialize on it
|
# created dynamically, don't specialize on it
|
||||||
result = UnspecializedNNModuleVariable(
|
self.install_guards(GuardBuilder.TYPE_MATCH)
|
||||||
value, guards=self.make_guards(GuardBuilder.TYPE_MATCH)
|
result = UnspecializedNNModuleVariable(value)
|
||||||
)
|
|
||||||
if not SideEffects.cls_supports_mutation_side_effects(type(value)):
|
if not SideEffects.cls_supports_mutation_side_effects(type(value)):
|
||||||
# don't allow STORE_ATTR mutation with custom __setattr__
|
# don't allow STORE_ATTR mutation with custom __setattr__
|
||||||
return result
|
return result
|
||||||
@ -931,9 +873,8 @@ class VariableBuilder:
|
|||||||
elif issubclass(
|
elif issubclass(
|
||||||
value.__class__, torch.nn.parallel.distributed.DistributedDataParallel
|
value.__class__, torch.nn.parallel.distributed.DistributedDataParallel
|
||||||
):
|
):
|
||||||
return UnspecializedNNModuleVariable(
|
self.install_guards(GuardBuilder.TYPE_MATCH)
|
||||||
value, guards=self.make_guards(GuardBuilder.TYPE_MATCH)
|
return UnspecializedNNModuleVariable(value)
|
||||||
)
|
|
||||||
elif getattr(value, "_is_fsdp_managed_module", False):
|
elif getattr(value, "_is_fsdp_managed_module", False):
|
||||||
# See note [Dynamo treats FSDP wrapped modules as UnspecializedNNModule]
|
# See note [Dynamo treats FSDP wrapped modules as UnspecializedNNModule]
|
||||||
# in fully_sharded_data_parallel.py for more information
|
# in fully_sharded_data_parallel.py for more information
|
||||||
@ -960,11 +901,8 @@ class VariableBuilder:
|
|||||||
#
|
#
|
||||||
# ID_MATCH is required to disambiguate cases as simple as a unit test that constructs 2 models and wraps
|
# ID_MATCH is required to disambiguate cases as simple as a unit test that constructs 2 models and wraps
|
||||||
# them differently with different FSDP configs. (test_dynamo_distributed.py -k test_fsdp_aot_eager)
|
# them differently with different FSDP configs. (test_dynamo_distributed.py -k test_fsdp_aot_eager)
|
||||||
return FSDPManagedNNModuleVariable(
|
self.install_guards(GuardBuilder.TYPE_MATCH, GuardBuilder.ID_MATCH)
|
||||||
value,
|
return FSDPManagedNNModuleVariable(value, source=self.get_source())
|
||||||
guards=self.make_guards(GuardBuilder.TYPE_MATCH, GuardBuilder.ID_MATCH),
|
|
||||||
source=self.get_source(),
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
return self.tx.output.register_attr_or_module(
|
return self.tx.output.register_attr_or_module(
|
||||||
value,
|
value,
|
||||||
@ -976,12 +914,12 @@ class VariableBuilder:
|
|||||||
def wrap_literal(self, value):
|
def wrap_literal(self, value):
|
||||||
unspec = not config.specialize_int
|
unspec = not config.specialize_int
|
||||||
if unspec and type(value) is torch.Size:
|
if unspec and type(value) is torch.Size:
|
||||||
|
self.install_guards(GuardBuilder.LIST_LENGTH)
|
||||||
return SizeVariable(
|
return SizeVariable(
|
||||||
[
|
[
|
||||||
VariableBuilder(self.tx, GetItemSource(self.get_source(), i))(v)
|
VariableBuilder(self.tx, GetItemSource(self.get_source(), i))(v)
|
||||||
for i, v in enumerate(value)
|
for i, v in enumerate(value)
|
||||||
],
|
]
|
||||||
guards=self.make_guards(GuardBuilder.LIST_LENGTH),
|
|
||||||
)
|
)
|
||||||
elif unspec and type(value) is int:
|
elif unspec and type(value) is int:
|
||||||
# unspecializing int by default, but still
|
# unspecializing int by default, but still
|
||||||
@ -995,17 +933,13 @@ class VariableBuilder:
|
|||||||
# NN modules on the fly)
|
# NN modules on the fly)
|
||||||
or self.source.guard_source().is_nn_module()
|
or self.source.guard_source().is_nn_module()
|
||||||
):
|
):
|
||||||
return ConstantVariable.create(
|
self.install_guards(GuardBuilder.CONSTANT_MATCH)
|
||||||
value=value,
|
return ConstantVariable.create(value=value)
|
||||||
guards=self.make_guards(GuardBuilder.CONSTANT_MATCH),
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
return self.wrap_unspecialized_primitive(value)
|
return self.wrap_unspecialized_primitive(value)
|
||||||
else:
|
else:
|
||||||
return ConstantVariable.create(
|
self.install_guards(GuardBuilder.CONSTANT_MATCH)
|
||||||
value=value,
|
return ConstantVariable.create(value=value)
|
||||||
guards=self.make_guards(GuardBuilder.CONSTANT_MATCH),
|
|
||||||
)
|
|
||||||
|
|
||||||
def assert_not_wrapped_by_this_graph(self, value: torch.Tensor):
|
def assert_not_wrapped_by_this_graph(self, value: torch.Tensor):
|
||||||
if is_fake(value) and maybe_get_fake_mode(value) is self.tx.fake_mode:
|
if is_fake(value) and maybe_get_fake_mode(value) is self.tx.fake_mode:
|
||||||
@ -1027,11 +961,7 @@ class VariableBuilder:
|
|||||||
) and not source.guard_source().is_fsdp_module():
|
) and not source.guard_source().is_fsdp_module():
|
||||||
self.assert_not_wrapped_by_this_graph(value)
|
self.assert_not_wrapped_by_this_graph(value)
|
||||||
return self.tx.output.register_attr_or_module(
|
return self.tx.output.register_attr_or_module(
|
||||||
value,
|
value, self.name, source=source
|
||||||
self.name,
|
|
||||||
source=source,
|
|
||||||
# Guards are done inside register_attr_or_module
|
|
||||||
# guards=self.make_guards(GuardBuilder.TENSOR_MATCH),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if is_constant_source(source):
|
if is_constant_source(source):
|
||||||
@ -1099,20 +1029,7 @@ class VariableBuilder:
|
|||||||
options["torch_function_fn"] = build_torch_function_fn(
|
options["torch_function_fn"] = build_torch_function_fn(
|
||||||
self.tx, value, self.source
|
self.tx, value, self.source
|
||||||
)
|
)
|
||||||
options["guards"] = self.make_guards(GuardBuilder.TYPE_MATCH)
|
self.install_guards(GuardBuilder.TYPE_MATCH)
|
||||||
else:
|
|
||||||
options["guards"] = set()
|
|
||||||
|
|
||||||
options["guards"].update(
|
|
||||||
self.make_guards(
|
|
||||||
functools.partial(
|
|
||||||
GuardBuilder.TENSOR_MATCH,
|
|
||||||
value=value
|
|
||||||
if isinstance(source, NumpyTensorSource)
|
|
||||||
else TensorWeakRef(value),
|
|
||||||
)
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
if (
|
if (
|
||||||
isinstance(value, torch.Tensor)
|
isinstance(value, torch.Tensor)
|
||||||
@ -1130,6 +1047,16 @@ class VariableBuilder:
|
|||||||
source=source,
|
source=source,
|
||||||
**options,
|
**options,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self.install_guards(
|
||||||
|
functools.partial(
|
||||||
|
GuardBuilder.TENSOR_MATCH,
|
||||||
|
value=value
|
||||||
|
if isinstance(source, NumpyTensorSource)
|
||||||
|
else TensorWeakRef(value),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
self.tx.output.input_source_to_var[source] = tensor_variable
|
self.tx.output.input_source_to_var[source] = tensor_variable
|
||||||
assert "tensor_dict" not in tensor_proxy.node.meta
|
assert "tensor_dict" not in tensor_proxy.node.meta
|
||||||
tensor_proxy.node.meta["tensor_dict"] = value.__dict__.copy()
|
tensor_proxy.node.meta["tensor_dict"] = value.__dict__.copy()
|
||||||
@ -1172,11 +1099,11 @@ class VariableBuilder:
|
|||||||
# a tensor. It's a little annoying to make a VT to throw out, but there's so many side effects here
|
# a tensor. It's a little annoying to make a VT to throw out, but there's so many side effects here
|
||||||
# that there's not another great way to do this atm.
|
# that there's not another great way to do this atm.
|
||||||
# This creates the right graphargs, as well as registration for guards in tensor names and shape env.
|
# This creates the right graphargs, as well as registration for guards in tensor names and shape env.
|
||||||
tensor_vt = VariableBuilder(self.tx, source)(tensor_value)
|
VariableBuilder(self.tx, source)(tensor_value).recursive_realize()
|
||||||
proxy = self.tx.output.root_tracer.create_graph_input(
|
proxy = self.tx.output.root_tracer.create_graph_input(
|
||||||
re.sub(r"[^a-zA-Z0-9]+", "_", self.name), type(tensor_value), source=source
|
re.sub(r"[^a-zA-Z0-9]+", "_", self.name), type(tensor_value), source=source
|
||||||
)
|
)
|
||||||
options = {"source": source, "guards": tensor_vt.guards}
|
options = {"source": source}
|
||||||
numpy_ndarray_variable = wrap_fx_proxy_cls(
|
numpy_ndarray_variable = wrap_fx_proxy_cls(
|
||||||
target_cls=NumpyNdarrayVariable,
|
target_cls=NumpyNdarrayVariable,
|
||||||
tx=self.tx,
|
tx=self.tx,
|
||||||
@ -1229,10 +1156,8 @@ class VariableBuilder:
|
|||||||
# If specialize_int is False, also return
|
# If specialize_int is False, also return
|
||||||
# a constant (but this should have been handled
|
# a constant (but this should have been handled
|
||||||
# in the caller, TBH)
|
# in the caller, TBH)
|
||||||
return ConstantVariable.create(
|
self.install_guards(GuardBuilder.CONSTANT_MATCH)
|
||||||
value=value,
|
return ConstantVariable.create(value=value)
|
||||||
guards=self.make_guards(GuardBuilder.CONSTANT_MATCH),
|
|
||||||
)
|
|
||||||
|
|
||||||
name = self.source.name()
|
name = self.source.name()
|
||||||
if name not in self.tx.output.frame_state:
|
if name not in self.tx.output.frame_state:
|
||||||
@ -1264,10 +1189,8 @@ class VariableBuilder:
|
|||||||
else: # assume_static_by_default
|
else: # assume_static_by_default
|
||||||
# TODO: dynamic_dim = DimDynamic.STATIC should work but
|
# TODO: dynamic_dim = DimDynamic.STATIC should work but
|
||||||
# for some reason it doesn't
|
# for some reason it doesn't
|
||||||
return ConstantVariable.create(
|
self.install_guards(GuardBuilder.CONSTANT_MATCH)
|
||||||
value=value,
|
return ConstantVariable.create(value=value)
|
||||||
guards=self.make_guards(GuardBuilder.CONSTANT_MATCH),
|
|
||||||
)
|
|
||||||
|
|
||||||
wrapped_value = shape_env.create_unspecified_symint_and_symbol(
|
wrapped_value = shape_env.create_unspecified_symint_and_symbol(
|
||||||
value,
|
value,
|
||||||
@ -1281,11 +1204,8 @@ class VariableBuilder:
|
|||||||
else:
|
else:
|
||||||
wrapped_value = torch.tensor(value)
|
wrapped_value = torch.tensor(value)
|
||||||
if not isinstance(self.get_source(), RandomValueSource):
|
if not isinstance(self.get_source(), RandomValueSource):
|
||||||
guards = {self.get_source().make_guard(GuardBuilder.TYPE_MATCH)}
|
install_guard(self.get_source().make_guard(GuardBuilder.TYPE_MATCH))
|
||||||
options = {"guards": guards}
|
options = {"source": self.get_source()}
|
||||||
else:
|
|
||||||
options = {}
|
|
||||||
options.update({"source": self.get_source()})
|
|
||||||
if isinstance(wrapped_value, torch.Tensor):
|
if isinstance(wrapped_value, torch.Tensor):
|
||||||
options.update({"raw_value": value})
|
options.update({"raw_value": value})
|
||||||
|
|
||||||
@ -1352,16 +1272,19 @@ def _dataclasses_fields_lambda(obj):
|
|||||||
|
|
||||||
|
|
||||||
def wrap_fx_proxy(tx, proxy, example_value=None, subclass_type=None, **options):
|
def wrap_fx_proxy(tx, proxy, example_value=None, subclass_type=None, **options):
|
||||||
return wrap_fx_proxy_cls(
|
kwargs = {
|
||||||
target_cls=TensorVariable
|
"tx": tx,
|
||||||
if not subclass_type
|
"proxy": proxy,
|
||||||
else TensorWithTFOverrideVariable,
|
"example_value": example_value,
|
||||||
tx=tx,
|
"subclass_type": subclass_type,
|
||||||
proxy=proxy,
|
|
||||||
example_value=example_value,
|
|
||||||
subclass_type=subclass_type,
|
|
||||||
**options,
|
**options,
|
||||||
)
|
}
|
||||||
|
if subclass_type is None:
|
||||||
|
return wrap_fx_proxy_cls(target_cls=TensorVariable, **kwargs)
|
||||||
|
else:
|
||||||
|
result = wrap_fx_proxy_cls(target_cls=TensorWithTFOverrideVariable, **kwargs)
|
||||||
|
result.install_global(tx)
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
# Note: Unfortunate split due to some gross classes existing that subclass TensorVariable
|
# Note: Unfortunate split due to some gross classes existing that subclass TensorVariable
|
||||||
|
|||||||
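Note on the hunks above: throughout `builder.py` the commit replaces the old style of building a guard set with `make_guards(...)` and threading it through every `VariableTracker` constructor with an eager `self.install_guards(...)` call at wrap time. A minimal self-contained sketch of the two styles follows; the `Guard`, `Source`, and `VariableBuilderSketch` types are simplified stand-ins, not the real dynamo classes:

```python
from dataclasses import dataclass, field

@dataclass(frozen=True)
class Guard:
    source_name: str  # e.g. "L['x']"
    check: str        # e.g. "TYPE_MATCH"

@dataclass(frozen=True)
class Source:
    name: str

    def make_guard(self, check: str) -> Guard:
        return Guard(self.name, check)

@dataclass
class VariableBuilderSketch:
    source: Source
    # stands in for tx.output.guards, the per-frame guard accumulator
    output_guards: set = field(default_factory=set)

    def make_guards(self, *checks: str) -> set:
        # old style: return a guard set for the caller to thread through
        # every VariableTracker it constructs
        return {self.source.make_guard(c) for c in checks}

    def install_guards(self, *checks: str) -> None:
        # new style: record the guards immediately as a side effect,
        # so the VariableTracker no longer needs a guards= argument
        for c in checks:
            self.output_guards.add(self.source.make_guard(c))

builder = VariableBuilderSketch(Source("L['x']"))
builder.install_guards("TYPE_MATCH", "ID_MATCH")
assert {g.check for g in builder.output_guards} == {"TYPE_MATCH", "ID_MATCH"}
```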
--- a/torch/_dynamo/variables/builtin.py
+++ b/torch/_dynamo/variables/builtin.py
@@ -19,7 +19,7 @@ from ..exc import (
     UserError,
     UserErrorType,
 )
-from ..guards import GuardBuilder
+from ..guards import GuardBuilder, install_guard
 from ..replay_record import DummyModule
 from ..source import AttrSource, GetItemSource, is_constant_source, TypeSource
 from ..utils import (
@@ -339,7 +339,6 @@ class BuiltinVariable(VariableTracker):
                 a,
                 ListVariable(
                     list(a.items) + list(b.unpack_var_sequence(tx)),
-                    regen_guards=False,
                     **options,
                 ),
             )
@@ -826,23 +825,22 @@ class BuiltinVariable(VariableTracker):
                 mutable_local=MutableLocal(),
             )
         elif obj.has_unpack_var_sequence(tx):
-            guards = set()
             if obj.source and not is_constant_source(obj.source):
                 if isinstance(obj, TupleIteratorVariable):
-                    guards.add(obj.source.make_guard(GuardBuilder.TUPLE_ITERATOR_LEN))
+                    install_guard(
+                        obj.source.make_guard(GuardBuilder.TUPLE_ITERATOR_LEN)
+                    )
                 else:
-                    guards.add(obj.source.make_guard(GuardBuilder.LIST_LENGTH))
+                    install_guard(obj.source.make_guard(GuardBuilder.LIST_LENGTH))
             if cls is SetVariable:
                 return cls(
                     list(obj.unpack_var_sequence(tx)),
                     mutable_local=MutableLocal(),
-                    guards=guards,
                 ).add_options(self, obj)

             return cls(
                 list(obj.unpack_var_sequence(tx)),
                 mutable_local=MutableLocal(),
-                guards=guards,
             ).add_options(self, obj)

     call_iter = _call_iter_tuple_list
@@ -1060,7 +1058,6 @@ class BuiltinVariable(VariableTracker):
         from .builder import SourcelessBuilder, VariableBuilder

         options = VariableTracker.propagate(self, obj, name_var)
-        guards = options["guards"]
         name = name_var.as_python_constant()

         if not name_var.is_python_constant():
@@ -1075,10 +1072,9 @@ class BuiltinVariable(VariableTracker):

         if default is not None:
             hasattr_var = self.call_hasattr(tx, obj, name_var)
-            guards.update(hasattr_var.guards)
             assert hasattr_var.as_python_constant() in (True, False)
             if not hasattr_var.as_python_constant():
-                return default.add_guards(guards)
+                return default

         if obj.source:
             source = AttrSource(obj.source, name)
@@ -1152,14 +1148,14 @@ class BuiltinVariable(VariableTracker):
             elif ConstantVariable.is_literal(member):
                 return ConstantVariable.create(member, **options)
             else:
-                return VariableBuilder(tx, source)(member).add_guards(guards)
+                return VariableBuilder(tx, source)(member)
         elif isinstance(obj, (PythonModuleVariable, DummyModule)):
             member = obj.value.__dict__[name]

             if config.replay_record_enabled:
                 tx.exec_recorder.record_module_access(obj.value, name, member)

-            return VariableBuilder(tx, source)(member).add_guards(guards)
+            return VariableBuilder(tx, source)(member)
         elif istype(obj, UserFunctionVariable) and name in ("__name__", "__module__"):
             return ConstantVariable.create(
                 getattr(obj.fn, name), **VariableTracker.propagate(obj)
--- a/torch/_dynamo/variables/constant.py
+++ b/torch/_dynamo/variables/constant.py
@@ -6,7 +6,7 @@ from torch._dynamo.source import GetItemSource

 from .. import variables
 from ..exc import unimplemented, UserError, UserErrorType
-from ..guards import GuardBuilder
+from ..guards import GuardBuilder, install_guard
 from ..utils import np
 from .base import typestr, VariableTracker

@@ -41,21 +41,15 @@ class ConstantVariable(VariableTracker):
             items = []
             for i, x in enumerate(value):
                 item_source = GetItemSource(source, i) if source else None
-                guards = (
-                    {item_source.make_guard(GuardBuilder.CONSTANT_MATCH)}
-                    if item_source
-                    else None
-                )
+                if item_source:
+                    install_guard(item_source.make_guard(GuardBuilder.CONSTANT_MATCH))
                 items.append(
                     ConstantVariable.create(
                         x,
                         source=item_source,
-                        guards=guards,
                     )
                 )
-            return variables.BaseListVariable.cls_for(type(value))(
-                items, regen_guards=True, **kwargs
-            )
+            return variables.BaseListVariable.cls_for(type(value))(items, **kwargs)

         return ConstantVariable(value, **kwargs)
--- a/torch/_dynamo/variables/ctx_manager.py
+++ b/torch/_dynamo/variables/ctx_manager.py
@@ -9,7 +9,7 @@ from .. import variables
 from ..bytecode_transformation import create_call_function, create_instruction
 from ..device_interface import get_interface_for_device
 from ..exc import unimplemented, Unsupported
-from ..guards import GuardBuilder
+from ..guards import GuardBuilder, install_guard
 from ..source import AttrSource, GlobalStateSource
 from .base import VariableTracker
 from .functions import (
@@ -161,7 +161,7 @@ class GenericContextWrappingVariable(ContextWrappingVariable):
 class GradModeVariable(ContextWrappingVariable):
     """represents torch.{no_grad,enable_grad,set_grad_mode}()"""

-    _guards_singleton = {Guard(GlobalStateSource(), GuardBuilder.GRAD_MODE)}
+    _guards_singleton = Guard(GlobalStateSource(), GuardBuilder.GRAD_MODE)

     @staticmethod
     def create(tx, target_value, initialized=True, **kwargs):
@@ -179,8 +179,8 @@ class GradModeVariable(ContextWrappingVariable):
         super().__init__(
             target_values=target_values, initial_values=initial_values, **kwargs
         )
-        self.guards = self.guards | self._guards_singleton
         self.initialized = initialized
+        install_guard(self._guards_singleton)

     def enter(self, tx):
         if not self.initialized:
@@ -263,7 +263,7 @@ class InferenceModeVariable(ContextWrappingVariable):
 class TorchFunctionDisableVariable(ContextWrappingVariable):
     """represents whether torch function overrides are enabled or not"""

-    _guards_singleton = {Guard(GlobalStateSource(), GuardBuilder.TORCH_FUNCTION_STATE)}
+    _guards_singleton = Guard(GlobalStateSource(), GuardBuilder.TORCH_FUNCTION_STATE)

     @staticmethod
     def create(tx, **kwargs):
@@ -281,7 +281,7 @@ class TorchFunctionDisableVariable(ContextWrappingVariable):
         super().__init__(
             target_values=target_values, initial_values=initial_values, **kwargs
         )
-        self.guards = self.guards | self._guards_singleton
+        install_guard(self._guards_singleton)

     def enter(self, tx):
         return variables.ConstantVariable.create(
@@ -296,9 +296,9 @@ class TorchFunctionDisableVariable(ContextWrappingVariable):
 class DeterministicAlgorithmsVariable(ContextWrappingVariable):
     """represents torch.{are_deterministic_algorithms_enabled,use_deterministic_algorithms}()"""

-    _guards_singleton = {
-        Guard(GlobalStateSource(), GuardBuilder.DETERMINISTIC_ALGORITHMS)
-    }
+    _guards_singleton = Guard(
+        GlobalStateSource(), GuardBuilder.DETERMINISTIC_ALGORITHMS
+    )

     @staticmethod
     def create(tx, target_value, **kwargs):
@@ -315,7 +315,7 @@ class DeterministicAlgorithmsVariable(ContextWrappingVariable):
         super().__init__(
             target_values=target_values, initial_values=initial_values, **kwargs
         )
-        self.guards = self.guards | self._guards_singleton
+        install_guard(self._guards_singleton)

     def enter(self, tx):
         return variables.ConstantVariable.create(
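The context-manager hunks above show the same move for global-state guards: `_guards_singleton` shrinks from a set that was lazily unioned into `self.guards` to a single `Guard` that `__init__` installs eagerly. A toy sketch of the new shape, with hypothetical stand-in names rather than the real dynamo types:

```python
# Toy model of the eager singleton-guard pattern; GLOBAL_GUARDS stands in
# for the active tracing context's guard set.
GLOBAL_GUARDS: set = set()

def install_guard(guard) -> None:
    GLOBAL_GUARDS.add(guard)

class GradModeVariableSketch:
    # before: _guards_singleton = {Guard(GlobalStateSource(), GRAD_MODE)},
    # merged into self.guards and propagated through copies;
    # after: one Guard value, installed as soon as the variable is built
    _guards_singleton = ("GLOBAL_STATE", "GRAD_MODE")

    def __init__(self):
        install_guard(self._guards_singleton)

GradModeVariableSketch()
assert ("GLOBAL_STATE", "GRAD_MODE") in GLOBAL_GUARDS
```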
--- a/torch/_dynamo/variables/dicts.py
+++ b/torch/_dynamo/variables/dicts.py
@@ -13,7 +13,7 @@ from ..bytecode_transformation import create_call_function, create_instruction
 from ..eval_frame import skip_code

 from ..exc import unimplemented
-from ..guards import GuardBuilder, make_dupe_guard
+from ..guards import GuardBuilder, install_guard, make_dupe_guard
 from ..source import AttrSource, GetItemSource, GlobalWeakRefSource
 from ..utils import global_key_name, istensor, iter_contains
 from .base import MutableLocal, VariableTracker
@@ -24,10 +24,8 @@ from .tensor import TensorVariable
 class ConstDictVariable(VariableTracker):
     def __init__(self, items, user_cls, **kwargs):
         super().__init__(**kwargs)

         # All the keys are constants
         assert not any(isinstance(x, VariableTracker) for x in items)
-        self.guards.update(VariableTracker.propagate(items.values())["guards"])
         self.items = items
         self.user_cls = user_cls

@@ -298,7 +296,6 @@ class SetVariable(VariableTracker):
     def __init__(
         self,
         items: List[VariableTracker],
-        regen_guards=True,
         **kwargs,
     ):
         super().__init__(**kwargs)
@@ -309,10 +306,6 @@ class SetVariable(VariableTracker):
         self.items = []
         self._add(items)

-        # Sometimes, we know that we have passed in the guards from the items in the set
-        if regen_guards:
-            self.guards.update(VariableTracker.propagate(items)["guards"])
-
     def as_proxy(self):
         return [x.as_proxy() for x in self.items]

@@ -378,9 +371,7 @@ class SetVariable(VariableTracker):
                         e.vt.source, set_element.vt.source
                     )
                     if alias_guard:
-                        e.vt = e.vt.add_guards(
-                            {e.vt.source.make_guard(alias_guard)}
-                        )
+                        install_guard(e.vt.source.make_guard(alias_guard))

         return self.items

@@ -401,7 +392,6 @@ class SetVariable(VariableTracker):
             result = SetVariable(
                 self._add(item),
                 mutable_local=self.mutable_local,
-                regen_guards=False,
                 **options,
             )
             tx.replace_all(self, result)
@@ -413,7 +403,7 @@ class SetVariable(VariableTracker):
             result = items.pop()
             tx.replace_all(
                 self,
-                SetVariable(items, regen_guards=False, **options),
+                SetVariable(items, **options),
             )
             return result
         elif name == "__len__":
@@ -797,46 +787,40 @@ class PythonSysModulesVariable(VariableTracker):
     def _contains_helper(self, tx, key: VariableTracker):
         k = ConstDictVariable.get_key(key)
         has_key = k in sys.modules
-        guard = self.make_guard(
-            functools.partial(GuardBuilder.DICT_CONTAINS, key=k, invert=not has_key)
+        install_guard(
+            self.make_guard(
+                functools.partial(GuardBuilder.DICT_CONTAINS, key=k, invert=not has_key)
+            )
         )
-        guards = {*self.guards, guard}
-        return k, has_key, guards
+        return k, has_key

     def call_contains(self, tx, key: VariableTracker):
-        k, has_key, guards = self._contains_helper(tx, key)
-        return ConstantVariable.create(
-            value=has_key,
-            guards=guards,
-        )
+        k, has_key = self._contains_helper(tx, key)
+        return ConstantVariable.create(value=has_key)

     def call_get(
         self, tx, key: VariableTracker, default: Optional[VariableTracker] = None
     ):
         from .builder import VariableBuilder

-        k, has_key, guards = self._contains_helper(tx, key)
+        k, has_key = self._contains_helper(tx, key)

         if has_key:
             return VariableBuilder(
                 tx,
                 GetItemSource(self.source, k),
-            )(
-                sys.modules[k]
-            ).add_guards(guards)
+            )(sys.modules[k])

         if default is not None:
-            return default.add_guards(guards)
+            return default

-        return ConstantVariable.create(value=None, guards=guards)
+        return ConstantVariable.create(value=None)

     def call_getitem(self, tx, key: VariableTracker):
         from .builder import VariableBuilder

-        k, has_key, guards = self._contains_helper(tx, key)
+        k, has_key = self._contains_helper(tx, key)
         return VariableBuilder(
             tx,
             GetItemSource(self.source, k),
-        )(
-            sys.modules[k]
-        ).add_guards(guards)
+        )(sys.modules[k])
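The `PythonSysModulesVariable` hunks above illustrate how helper return values simplify once guards are installed eagerly: `_contains_helper` used to return a guard set that every caller had to forward, and now returns only the data. A runnable sketch of the new control flow, using a stand-in `install_guard` rather than the real one:

```python
import sys

INSTALLED = []  # stands in for the tracing context's guard list

def install_guard(guard) -> None:
    INSTALLED.append(guard)

def contains_helper(key: str):
    has_key = key in sys.modules
    # a DICT_CONTAINS-style guard: recompile if this membership flips
    install_guard(("DICT_CONTAINS", key, not has_key))
    return key, has_key

k, has_key = contains_helper("os")   # callers no longer juggle a guard set
assert has_key and INSTALLED[0][1] == "os"
```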
--- a/torch/_dynamo/variables/functions.py
+++ b/torch/_dynamo/variables/functions.py
@@ -614,12 +614,6 @@ class FunctoolsPartialVariable(VariableTracker):
         self.keywords = keywords
         self.original = original

-        self.guards.update(VariableTracker.propagate(func)["guards"])
-        for arg in args:
-            self.guards.update(VariableTracker.propagate(arg)["guards"])
-        for val in keywords.values():
-            self.guards.update(VariableTracker.propagate(val)["guards"])
-
     def call_function(
         self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
     ) -> "VariableTracker":
--- a/torch/_dynamo/variables/higher_order_ops.py
+++ b/torch/_dynamo/variables/higher_order_ops.py
@@ -25,7 +25,6 @@ from ..exc import (
     UserError,
     UserErrorType,
 )
-from ..guards import GuardBuilder
 from ..source import FSDPNNModuleSource, GetItemSource, NNModuleSource
 from ..utils import proxy_args_kwargs
 from .dicts import ConstDictVariable
@@ -100,9 +99,6 @@ def validate_args_and_maybe_create_graph_inputs(
         assert isinstance(a, VariableTracker)

         if isinstance(a, ConstantVariable):
-            # Ensures that we recompile when the constant value changes
-            a.add_guard(GuardBuilder.CONSTANT_MATCH)
-
             if manually_set_subgraph_inputs:
                 # This arg is not used in the body of the higher order op.
                 # Currently, this new input is added to make the calls
@@ -194,6 +190,11 @@ def speculate_subgraph(
     )

     try:
+        f, sub_args, sub_kwargs = VariableTracker.apply(
+            # ensure guards on args get installed in parent subgraph
+            lambda x: x.realize(),
+            (f, sub_args, sub_kwargs),
+        )
         with tx.output.subtracer(source_target, tracer) as subtracer:
             args = validate_args_and_maybe_create_graph_inputs(
                 sub_args, subtracer, tx, manually_set_subgraph_inputs
@@ -247,7 +248,6 @@ def speculate_subgraph(
                     "HigherOrderOperator body's output must consist of tensors only"
                 )

-            tx.output.guards.update(output.guards)
             # The output proxies might not belong to this SubgraphTracer
             # (if they are free variables that were never lifted)
             # so lift them here.
@@ -411,7 +411,6 @@ class CondHigherOrderVariable(TorchHigherOrderOperatorVariable):
                 f"item but got {str(type(args[0]))} "
                 f"with original python type {str(args[0].python_type())}.",
             )
-        tx.output.guards.update(args[0].guards)

         # operands
         if not isinstance(args[3], (ListVariable, TupleVariable)):
@@ -1116,6 +1115,7 @@ class AutogradFunctionMethodHigherOrderVariable(TorchHigherOrderOperatorVariable
         else:
             fn = TorchVariable(self.value)
         checkpoint = tx.copy_graphstate()
+        # TODO(jansel): BUG!!! we aren't copying on the line below, so the post-pre check below is pointless
         pre_guards = tx.output.guards
         graph_checkpoint = tx.output.graph
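The `speculate_subgraph` change above forces `f`, `sub_args`, and `sub_kwargs` to realize before the subtracer is entered: with eager installation, a lazy variable's guards only exist once the variable is actually built, so realization has to happen while the parent tracing context is still current. A toy illustration with hypothetical names:

```python
# Toy model: a lazy variable only installs its guards when realized,
# so realizing it early pins the guards to the enclosing (parent) scope.
class LazyVTSketch:
    def __init__(self, build):
        self._build = build   # building has the side effect of guarding
        self._value = None

    def realize(self):
        if self._value is None:
            self._value = self._build()
        return self._value

installed = []
lazy = LazyVTSketch(lambda: (installed.append("TYPE_MATCH"), "vt")[1])
assert installed == []      # nothing guarded while still lazy
lazy.realize()              # realize before entering the subgraph tracer
assert installed == ["TYPE_MATCH"]
```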
--- a/torch/_dynamo/variables/lazy.py
+++ b/torch/_dynamo/variables/lazy.py
@@ -23,7 +23,6 @@ class LazyCache:
         self.vt.parents_tracker.add(parents_tracker)
         del self.value
         del self.source
-        tx.output.guards.update(self.vt.guards)


 class LazyVariableTracker(VariableTracker):
@@ -79,8 +78,6 @@ class LazyVariableTracker(VariableTracker):
         return getattr(self.realize(), item)

     # most methods are auto-generated below, these are the ones we want to exclude
-    add_guards = VariableTracker.add_guards
-    add_guard = VariableTracker.add_guard
     add_options = VariableTracker.add_options
     apply = VariableTracker.apply
     copy = VariableTracker.copy
--- a/torch/_dynamo/variables/lists.py
+++ b/torch/_dynamo/variables/lists.py
@@ -48,16 +48,11 @@ class BaseListVariable(VariableTracker):
     def __init__(
         self,
         items: List[VariableTracker],
-        regen_guards=True,
         **kwargs,
     ):
         super().__init__(**kwargs)
         assert isinstance(items, list)
         assert all(isinstance(x, VariableTracker) for x in items)
-        # Sometimes, we know that we have passed in the guards from the items in the list
-        if regen_guards:
-            self.guards.update(VariableTracker.propagate(items)["guards"])
-
         self.items: List[VariableTracker] = items

     def _as_proxy(self):
@@ -246,7 +241,6 @@ class CommonListMethodsVariable(BaseListVariable):
                 self,
                 type(self)(
                     self.items + [arg],
-                    regen_guards=False,
                     **options,
                 ),
             )
@@ -263,7 +257,6 @@ class CommonListMethodsVariable(BaseListVariable):
                 self,
                 type(self)(
                     list(self.items) + list(arg.unpack_var_sequence(tx)),
-                    regen_guards=False,
                     **options,
                 ),
             )
@@ -274,7 +267,7 @@ class CommonListMethodsVariable(BaseListVariable):
             items.insert(idx.as_python_constant(), value)
             return tx.replace_all(
                 self,
-                type(self)(items, regen_guards=False, **options),
+                type(self)(items, **options),
             )
         elif name == "pop" and self.mutable_local:
             assert not kwargs
@@ -282,14 +275,14 @@ class CommonListMethodsVariable(BaseListVariable):
             result = items.pop(*[a.as_python_constant() for a in args])
             tx.replace_all(
                 self,
-                type(self)(items, regen_guards=False, **options),
+                type(self)(items, **options),
             )
             return result
         elif name == "clear" and self.mutable_local:
             assert not kwargs and not args
             return tx.replace_all(
                 self,
-                type(self)([], regen_guards=False, **options),
+                type(self)([], **options),
             )
         elif (
             name == "__setitem__"
@@ -304,16 +297,14 @@ class CommonListMethodsVariable(BaseListVariable):
                 items[key.as_python_constant()] = list(value.items)
             else:
                 items[key.as_python_constant()] = value
-            result = ListVariable(items, regen_guards=False, **options)
+            result = ListVariable(items, **options)
             return tx.replace_all(self, result)
         elif name == "copy":
             # List copy() doesn't have args and kwargs
             assert not kwargs
             assert not args
             items = list(self.items)
-            return type(self)(
-                items, regen_guards=False, mutable_local=MutableLocal(), **options
-            )
+            return type(self)(items, mutable_local=MutableLocal(), **options)
         else:
             return super().call_method(tx, name, args, kwargs)

@@ -351,7 +342,7 @@ class ListVariable(CommonListMethodsVariable):
                 items[key.as_python_constant()] = value.unpack_var_sequence(tx)
             else:
                 items[key.as_python_constant()] = value
-            result = ListVariable(items, regen_guards=False, **options)
+            result = ListVariable(items, **options)
             return tx.replace_all(self, result)
         else:
             return super().call_method(tx, name, args, kwargs)
@@ -396,7 +387,7 @@ class DequeVariable(CommonListMethodsVariable):
             )
             items = list(self.items)
             items[key.as_python_constant()] = value
-            result = DequeVariable(items, regen_guards=False, **options)
+            result = DequeVariable(items, **options)
             return tx.replace_all(self, result)
         elif name == "extendleft" and self.mutable_local:
             assert not kwargs
@@ -405,7 +396,6 @@ class DequeVariable(CommonListMethodsVariable):
                 self,
                 DequeVariable(
                     list(arg.unpack_var_sequence(tx)) + list(self.items),
-                    regen_guards=False,
                     **options,
                 ),
             )
@@ -416,7 +406,7 @@ class DequeVariable(CommonListMethodsVariable):
             result = items.popleft()
             tx.replace_all(
                 self,
-                DequeVariable(list(items), regen_guards=False, **options),
+                DequeVariable(list(items), **options),
             )
             return result
         elif name == "appendleft" and self.mutable_local:
@@ -425,7 +415,6 @@ class DequeVariable(CommonListMethodsVariable):
                 self,
                 DequeVariable(
                     [args[0]] + list(self.items),
-                    regen_guards=False,
                     **options,
                 ),
             )
@@ -13,7 +13,7 @@ import torch._numpy as tnp
 from .. import config, polyfill, variables
 from ..bytecode_transformation import create_call_function, create_instruction
 from ..exc import unimplemented
-from ..guards import GuardBuilder
+from ..guards import GuardBuilder, install_guard
 from ..source import AttrSource, GetItemSource, ODictGetItemSource, TypeSource
 from ..utils import (
     check_constant_args,
@@ -97,9 +97,8 @@ class SuperVariable(VariableTracker):
             return GetAttrVariable(self, name, **options)
         if source:
             options["source"] = source
-            return variables.ConstantVariable.create(value, **options).add_guard(
-                source.make_guard(GuardBuilder.CONSTANT_MATCH)
-            )
+            install_guard(source.make_guard(GuardBuilder.CONSTANT_MATCH))
+            return variables.ConstantVariable.create(value, **options)
         return variables.ConstantVariable.create(value, **options)

     def call_method(
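
This file appears to be torch/_dynamo/variables/misc.py (SuperVariable). The rewrite is the core pattern of the whole commit: instead of building a variable and chaining `.add_guard(...)` onto the return value, the guard is installed eagerly as a side effect and a guard-free variable is returned. A sketch with hand-rolled stand-ins (not the real dynamo API):

installed = []

def install_guard(guard):
    installed.append(guard)

def make_guard(source, kind):
    return (source, kind)

def wrap_constant(value, source):
    # Before: return Constant(value).add_guard(make_guard(source, "CONSTANT_MATCH"))
    # After: install first, then return a constant that carries no guards.
    install_guard(make_guard(source, "CONSTANT_MATCH"))
    return value

print(wrap_constant(42, "L['obj'].attr"), installed)
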
@@ -10,7 +10,7 @@ import torch.nn
 from .. import skipfiles, variables
 from ..allowed_functions import is_allowed
 from ..exc import unimplemented, UnspecializeRestartAnalysis, Unsupported
-from ..guards import GuardBuilder
+from ..guards import GuardBuilder, install_guard
 from ..mutation_guard import GenerationTracker
 from ..source import (
     AttrSource,
@@ -127,11 +127,12 @@ class NNModuleVariable(VariableTracker):
         options = VariableTracker.propagate(self)
         mod = tx.output.get_submodule(self.module_key)
         result = hasattr(mod, name)
-        return variables.ConstantVariable.create(result, **options).add_guard(
+        install_guard(
             NNModuleSource(AttrSource(self.source, name)).make_guard(
                 GuardBuilder.HASATTR
             )
         )
+        return variables.ConstantVariable.create(result, **options)

     def is_training(self, tx):
         mod = tx.output.get_submodule(self.module_key)
@@ -167,7 +168,6 @@ class NNModuleVariable(VariableTracker):
         from .builder import VariableBuilder

         options = VariableTracker.propagate(self)
-        guards = options.get("guards", set())

         if self.source:
             source = AttrSource(self.source, name)
@@ -220,13 +220,12 @@ class NNModuleVariable(VariableTracker):
         if istype(subobj, property):
             return variables.UserFunctionVariable(
                 subobj.fget,
-                guards=guards,
                 source=source,
             ).call_function(tx, [(self)], {})
         elif istype(subobj, classmethod):
             return variables.UserMethodVariable(
                 subobj.__func__,
-                variables.UserDefinedObjectVariable(type(base), guards=guards),
+                variables.UserDefinedObjectVariable(type(base)),
                 **options,
             )
         elif istype(subobj, staticmethod):
@@ -616,7 +615,7 @@ class NNModuleVariable(VariableTracker):
         ):
             # Inline the function
             fn = getattr(module, name).__func__
-            fn_source = AttrSource(self.source, "__func__")
+            fn_source = AttrSource(AttrSource(self.source, name), "__func__")
             options["source"] = fn_source
             return tx.inline_user_function_return(
                 variables.UserFunctionVariable(fn, **options),
@@ -759,7 +758,7 @@ class UnspecializedNNModuleVariable(UserDefinedObjectVariable):
         assert not args or kwargs
         if tx.output.side_effects.has_pending_mutation(self):
             unimplemented("Module.parameters() with pending mutation")
-        options["guards"].add(
+        install_guard(
             self.source.make_guard(GuardBuilder.NN_MODULE_PARAM_NAMES)
         )
         items = []
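
These hunks appear to be torch/_dynamo/variables/nn_module.py. Besides the mechanical install_guard conversions and the removal of the `guards=` plumbing, the @@ -616 hunk fixes a real bug: the inlined method's source used to be built as `AttrSource(self.source, "__func__")`, guarding `mod.__func__` instead of `mod.<name>.__func__`. A short sketch of the corrected chaining (toy Source class; the attribute name "forward" is just an example):

class AttrSource:
    def __init__(self, base, member):
        self.base, self.member = base, member

    def name(self):
        base = self.base if isinstance(self.base, str) else self.base.name()
        return f"{base}.{self.member}"

module_source = "L['mod']"
# Before: AttrSource(module_source, "__func__")  -> L['mod'].__func__
# After:  AttrSource(AttrSource(module_source, name), "__func__")
fn_source = AttrSource(AttrSource(module_source, "forward"), "__func__")
print(fn_source.name())  # L['mod'].forward.__func__
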
@@ -4,7 +4,7 @@ from typing import Dict, List
 import torch
 from ..decorators import mark_static_address

-from ..guards import GuardBuilder
+from ..guards import GuardBuilder, install_guard
 from ..source import AttrSource, GetItemSource, GlobalWeakRefSource
 from ..utils import global_key_name

@@ -126,13 +126,12 @@ class OptimizerVariable(UserDefinedObjectVariable):

         # state guards take a long time to generate
         # so we manually generate them here
-        guards = set()
         state_source = AttrSource(self.source, "state")
-        guards.add(state_source.make_guard(GuardBuilder.DICT_KEYS))
+        install_guard(state_source.make_guard(GuardBuilder.DICT_KEYS))
         for p, value in self.value.state.items():
             tx.store_global_weakref(global_key_name(p), p)
             p_state_source = GetItemSource(state_source, self.tensor_to_source[p])
-            guards.add(p_state_source.make_guard(GuardBuilder.DICT_KEYS))
+            install_guard(p_state_source.make_guard(GuardBuilder.DICT_KEYS))
             for k, v in value.items():
                 if (
                     isinstance(v, torch.Tensor)
@@ -141,7 +140,7 @@ class OptimizerVariable(UserDefinedObjectVariable):
                 ):
                     self.tensor_to_source[v] = GetItemSource(p_state_source, k)
                 elif v is None or isinstance(v, (bool, int, float, str)):
-                    guards.add(
+                    install_guard(
                         GetItemSource(p_state_source, k).make_guard(
                             GuardBuilder.CONSTANT_MATCH
                         )
@@ -149,12 +148,10 @@ class OptimizerVariable(UserDefinedObjectVariable):
                 else:
                     raise GuardInstallException()

-        tx.output.guards.update(guards)
-
-        group_guards = VariableBuilder(tx, AttrSource(self.source, "param_groups"))(
+        # this next line has the side effect of installing guards
+        VariableBuilder(tx, AttrSource(self.source, "param_groups"))(
             self.value.param_groups
-        )
-        tx.output.guards.update(group_guards.guards)
+        ).recursive_realize()

     def wrap_tensor(self, tx, tensor_value):
         """Wrap state tensor in a TensorVariable"""
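
This file appears to be torch/_dynamo/variables/optimizer.py. The interesting change is at the bottom: the param_groups variable is built and immediately discarded, and `.recursive_realize()` is called purely so that realization installs the guards (the new comment in the diff says exactly that). A toy sketch of a lazily guarded variable with that behavior (assumed semantics, not dynamo's implementation):

installed = []

class LazyVariable:
    def __init__(self, value, source):
        self.value, self.source = value, source
        self._realized = False

    def recursive_realize(self):
        # Realization has the side effect of installing this variable's guard.
        if not self._realized:
            installed.append(f"{self.source} DICT_KEYS")
            self._realized = True
        return self

LazyVariable({"lr": 0.1}, "L['opt'].param_groups[0]").recursive_realize()
print(installed)
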
@@ -1,4 +1,5 @@
 import functools
+
 import inspect
 import operator
 import types
@@ -31,7 +32,7 @@ from .. import config, variables
 from .._trace_wrapped_higher_order_op import trace_wrapped

 from ..exc import unimplemented, UserError, UserErrorType
-from ..guards import GuardBuilder
+from ..guards import GuardBuilder, install_guard
 from ..source import AttrSource
 from ..utils import (
     fqn,
@@ -206,12 +207,8 @@ class TensorVariable(VariableTracker):
         from .builder import VariableBuilder

         attr_source = AttrSource(self.source, name)
-        has_attr_guard = attr_source.make_guard(GuardBuilder.HASATTR)
-        return (
-            VariableBuilder(tx, attr_source)(real_value)
-            .add_options(self)
-            .add_guard(has_attr_guard)
-        )
+        install_guard(attr_source.make_guard(GuardBuilder.HASATTR))
+        return VariableBuilder(tx, attr_source)(real_value).add_options(self)

     def var_getattr(self, tx, name):
         from . import ConstantVariable, TorchVariable
@@ -254,7 +251,7 @@ class TensorVariable(VariableTracker):
         # In some cases, a <tensor>.<attr> guard can be evaluated first, and break if
         # <tensor> is later changed to another type
         if result is not None and self.source is not None:
-            result = result.add_guard(self.make_guard(GuardBuilder.TYPE_MATCH))
+            install_guard(self.make_guard(GuardBuilder.TYPE_MATCH))

         # It's hard to get inplace view (metadata mutation) on graph input work properly across
         # dynamo/aot/inductor, just fall back.
@@ -607,7 +604,6 @@ class TensorVariable(VariableTracker):
                 unimplemented(
                     "boolean masking setitem backwards requires dynamic shapes"
                 )
-            tx.output.guards.update(options["guards"])
             tx.output.create_proxy(
                 "call_function",
                 operator.setitem,
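
These hunks appear to be torch/_dynamo/variables/tensor.py. Note the deletion of `tx.output.guards.update(options["guards"])` in the last hunk: once every guard is installed eagerly into the shared output collection, the old two-step "accumulate on the variable, then flush into the output" becomes dead code. A sketch of the one-step flow (toy accumulator; the real one's shape is assumed):

class Output:
    def __init__(self):
        self.guards = set()

output = Output()

def install_guard(guard):
    # The guard lands in the output graph's collection immediately,
    # so there is no per-variable guard set left to flush later.
    output.guards.add(guard)

install_guard("L['x'] TYPE_MATCH")
print(output.guards)
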
@@ -8,6 +8,7 @@ import types
 from typing import Dict, List

 from torch._streambase import _StreamBase
+from ..guards import install_guard

 try:
     import numpy as np
@@ -159,10 +160,10 @@ class TorchCtxManagerClassVariable(VariableTracker):

     @classmethod
     def create_with_source(cls, value, source):
+        install_guard(source.make_guard(GuardBuilder.FUNCTION_MATCH))
         return TorchCtxManagerClassVariable(
             value,
             source=source,
-            guards={source.make_guard(GuardBuilder.FUNCTION_MATCH)},
         )

     def __init__(self, value, **kwargs):
@@ -259,6 +260,9 @@ class TorchVariable(VariableTracker):
         except RuntimeError as e:
             assert "No such operator" in str(e), str(e)
             self_should_be_none = None
+        except AssertionError as e:
+            assert "Unknown attribute" in str(e), str(e)
+            self_should_be_none = None

         # assert "_ntuple.<locals>.parse" not in str(value)

@@ -425,18 +429,18 @@ class TorchVariable(VariableTracker):
             return self._call_ntuple(tx, args, kwargs, options)
         elif self.value is torch.is_grad_enabled:
             assert not (args or kwargs)
-            return ConstantVariable.create(
-                torch.is_grad_enabled(), **options
-            ).add_guards(GradModeVariable._guards_singleton)
+            install_guard(GradModeVariable._guards_singleton)
+            return ConstantVariable.create(torch.is_grad_enabled(), **options)
         elif self.value is torch.use_deterministic_algorithms and len(args) == 1:
             return DeterministicAlgorithmsVariable.create(
                 tx, args[0].as_python_constant(), **options
             )
         elif self.value is torch.are_deterministic_algorithms_enabled:
             assert not (args or kwargs)
+            install_guard(DeterministicAlgorithmsVariable._guards_singleton)
             return ConstantVariable.create(
                 torch.are_deterministic_algorithms_enabled(), **options
-            ).add_guards(DeterministicAlgorithmsVariable._guards_singleton)
+            )
         elif self.value is torch.autograd.graph.disable_saved_tensors_hooks:
             assert len(args) == 1
             return DisabledSavedTensorsHooksVariable.create(
@@ -444,9 +448,8 @@ class TorchVariable(VariableTracker):
             )
         elif self.value is torch._C._is_torch_function_enabled:
             assert not (args or kwargs)
-            return ConstantVariable.create(
-                tx.output.torch_function_enabled, **options
-            ).add_guards(TorchFunctionDisableVariable._guards_singleton)
+            install_guard(TorchFunctionDisableVariable._guards_singleton)
+            return ConstantVariable.create(tx.output.torch_function_enabled, **options)
         elif self.value in (
             torch.overrides.has_torch_function_variadic,
             torch.overrides.has_torch_function_unary,
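
This file appears to be torch/_dynamo/variables/torch.py. The repeated shape here: constant-folding a global-state query such as torch.is_grad_enabled() is only sound if a guard re-checks that state on later calls, so the shared `_guards_singleton` is installed eagerly and a plain ConstantVariable is returned. A self-contained sketch of that idea (toy names throughout):

installed = []

def install_guard(guard):
    installed.append(guard)

class GradModeToy:
    # One shared guard object per kind of global state, mirroring
    # GradModeVariable._guards_singleton above.
    _guards_singleton = "GRAD_MODE"

grad_enabled = True  # stand-in for torch.is_grad_enabled()

def fold_is_grad_enabled():
    install_guard(GradModeToy._guards_singleton)
    return grad_enabled  # safe to treat as a compile-time constant now

print(fold_is_grad_enabled(), installed)
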
@@ -5,6 +5,7 @@ import torch.utils._pytree as pytree

 from torch.overrides import _get_overloaded_args, get_default_nowrap_functions
 from ..exc import unimplemented
+from ..guards import GuardBuilder, install_guard
 from ..source import AttrSource, GlobalSource
 from ..utils import is_tensor_base_attr_getter
 from .base import VariableTracker
@@ -133,14 +134,15 @@ class TensorWithTFOverrideVariable(TensorVariable):
             kwargs.pop("class_type") is torch.Tensor
         ), "invalid class type in TensorWithTFOverrideVariable.from_tensor_var"
         var = cls(torch_function_fn=torch_function_fn, class_type=class_type, **kwargs)
+        var.install_global(tx)
+        return var

+    def install_global(self, tx):
         # stash the subclass type to rewrap an output tensor if needed
         # this is needed because the actual type needs to be available
         # each time the compiled artifact is run and outputs a wrapped tensor.
-        if var.global_mangled_class_name() not in tx.output.global_scope:
-            tx.output.install_global(var.global_mangled_class_name(), class_type)
-
-        return var
+        if self.global_mangled_class_name() not in tx.output.global_scope:
+            tx.output.install_global(self.global_mangled_class_name(), self.class_type)

     def python_type(self):
         return self.class_type
@@ -157,7 +159,7 @@ class TensorWithTFOverrideVariable(TensorVariable):
         # [Note: __torch_function__] We currently only support attributes that are defined on
         # base tensors, custom attribute accesses will graph break.
         import torch
-        from .builder import SourcelessBuilder, VariableBuilder
+        from .builder import SourcelessBuilder

         if name in banned_attrs or not hasattr(torch.Tensor, name):
             unimplemented(
@@ -172,15 +174,12 @@ class TensorWithTFOverrideVariable(TensorVariable):

         if tx.output.torch_function_enabled:
             if self.source:
-                get_fn = VariableBuilder(
-                    tx,
-                    source=AttrSource(
-                        AttrSource(AttrSource(self.source, "__class__"), name),
-                        "__get__",
-                    ),
-                )(inspect.getattr_static(self.python_type(), name).__get__)
-            else:
-                get_fn = SourcelessBuilder()(tx, getattr(torch.Tensor, name).__get__)
+                install_guard(
+                    AttrSource(AttrSource(self.source, "__class__"), name).make_guard(
+                        GuardBuilder.FUNCTION_MATCH
+                    )
+                )
+            get_fn = SourcelessBuilder()(tx, getattr(torch.Tensor, name).__get__)

             return self.call_torch_function(
                 tx,
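
This file appears to be torch/_dynamo/variables/torch_function.py. The from_tensor_var change factors the "stash the subclass type in the output's globals" step into an idempotent install_global method on the variable itself. A sketch of the refactored shape (toy output object; the method names follow the diff, everything else is assumed):

class ToyOutput:
    def __init__(self):
        self.global_scope = {}

    def install_global(self, name, value):
        self.global_scope[name] = value

class ToyTFOverrideVar:
    def __init__(self, class_type):
        self.class_type = class_type

    def global_mangled_class_name(self):
        return f"__subclass_{self.class_type.__name__}"

    def install_global(self, tx_output):
        # Install once per mangled name, exactly as in the hunk above.
        if self.global_mangled_class_name() not in tx_output.global_scope:
            tx_output.install_global(
                self.global_mangled_class_name(), self.class_type
            )

out = ToyOutput()
var = ToyTFOverrideVar(class_type=float)
var.install_global(out)
var.install_global(out)  # second call is a no-op
print(out.global_scope)
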
@@ -17,7 +17,7 @@ from torch._guards import TracingContext
 from .. import variables
 from ..allowed_functions import is_allowed
 from ..exc import unimplemented
-from ..guards import GuardBuilder
+from ..guards import GuardBuilder, install_guard
 from ..source import AttrSource, ODictGetItemSource, RandomValueSource
 from ..utils import (
     all_hook_names,
@@ -266,9 +266,10 @@ class UserDefinedObjectVariable(UserDefinedVariable):
             assert not (args or kwargs)
             keys = list(self.value.keys())
             assert all(map(ConstantVariable.is_literal, keys))
+            install_guard(self.source.make_guard(GuardBuilder.ODICT_KEYS))
             return TupleVariable(
                 [ConstantVariable.create(k, **options) for k in keys], **options
-            ).add_guard(self.source.make_guard(GuardBuilder.ODICT_KEYS))
+            )

         if (
             method in (collections.OrderedDict.__contains__, dict.__contains__)
@@ -278,9 +279,10 @@ class UserDefinedObjectVariable(UserDefinedVariable):
             in (collections.OrderedDict.keys, dict.keys)
         ):
             assert not kwargs
+            install_guard(self.source.make_guard(GuardBuilder.ODICT_KEYS))
             return ConstantVariable.create(
                 args[0].as_python_constant() in self.value, **options
-            ).add_guard(self.source.make_guard(GuardBuilder.ODICT_KEYS))
+            )

         if (
             method is collections.OrderedDict.items
@@ -376,20 +378,15 @@ class UserDefinedObjectVariable(UserDefinedVariable):
             )
         ):
             options = VariableTracker.propagate(self, args, kwargs.values())
-            options.setdefault("guards", set())
             if self.source:
-                options["guards"].add(
-                    AttrSource(self.source, "func").make_guard(GuardBuilder.ID_MATCH)
-                )
-                options["guards"].add(
+                install_guard(
+                    AttrSource(self.source, "func").make_guard(GuardBuilder.ID_MATCH),
                     AttrSource(self.source, "args").make_guard(
                         GuardBuilder.CONSTANT_MATCH
-                    )
-                )
-                options["guards"].add(
+                    ),
                     AttrSource(self.source, "keywords").make_guard(
                         GuardBuilder.CONSTANT_MATCH
-                    )
+                    ),
                 )

             partial_args = [
@@ -410,7 +407,7 @@ class UserDefinedObjectVariable(UserDefinedVariable):
                 tx, partial_args, partial_kwargs
            )
         elif callable(self.value):
-            self.add_guard(self.source.make_guard(GuardBuilder.FUNCTION_MATCH))
+            install_guard(self.source.make_guard(GuardBuilder.FUNCTION_MATCH))
             return self.call_method(tx, "__call__", args, kwargs)

         return super().call_function(tx, args, kwargs)
@@ -578,7 +575,7 @@ class UserDefinedObjectVariable(UserDefinedVariable):
             pass
         options = VariableTracker.propagate(self)
         if self.source:
-            options["guards"].add(
+            install_guard(
                 AttrSource(self.source, name).make_guard(GuardBuilder.HASATTR)
             )
         if self._check_for_getattribute() or self._check_for_getattr():
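
These hunks appear to be torch/_dynamo/variables/user_defined.py. The functools.partial hunk also shows that install_guard accepts several guards at once: three separate `options["guards"].add(...)` statements collapse into a single variadic call. A sketch of that call shape (toy implementation; the signature is inferred from the call site above):

installed = []

def install_guard(*guards):
    installed.extend(guards)

source = "L['fn']"
install_guard(
    (f"{source}.func", "ID_MATCH"),
    (f"{source}.args", "CONSTANT_MATCH"),
    (f"{source}.keywords", "CONSTANT_MATCH"),
)
print(len(installed))  # 3
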
@@ -241,7 +241,13 @@ class Guard:
         return output

     def create(self, builder: GuardBuilderBase):
-        return self.create_fn(builder, self)
+        try:
+            return self.create_fn(builder, self)
+        except Exception:
+            log.error("Error while creating guard:\n%s", str(self).rstrip())
+            if self.stack:
+                log.error("Created at:\n%s", "".join(self.stack.format()[-4:]).rstrip())
+            raise

     def is_nn_module(self):
         return self.source.is_nn_module()
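
The final hunk appears to be torch/_guards.py: Guard.create now wraps guard construction so that a failure logs which guard broke, and the tail of the stack where it was created, before re-raising. A self-contained sketch of the same error-handling shape:

import logging
import traceback

log = logging.getLogger("toy_guards")

def create_guard(create_fn, guard_repr, stack=None):
    try:
        return create_fn()
    except Exception:
        # Attach context to the failure, then propagate it unchanged.
        log.error("Error while creating guard:\n%s", guard_repr.rstrip())
        if stack:
            log.error("Created at:\n%s", "".join(stack.format()[-4:]).rstrip())
        raise

def boom():
    raise ValueError("unsupported guard")

stack = traceback.extract_stack()  # a StackSummary, like Guard.stack above
try:
    create_guard(boom, "L['x'] TYPE_MATCH", stack)
except ValueError:
    print("guard creation failed and was logged")
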