[dynamo] Eagerly install guards (#111415)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/111415
Approved by: https://github.com/voznesenskym
ghstack dependencies: #111306
This commit is contained in:
Jason Ansel
2023-11-07 08:12:57 -08:00
committed by PyTorch MergeBot
parent 2964682490
commit 9664190952
30 changed files with 333 additions and 622 deletions

View File

@ -3292,7 +3292,9 @@ class GraphModule(torch.nn.Module):
cos = l_x_.cos(); l_x_ = None cos = l_x_.cos(); l_x_ = None
return pytree.tree_unflatten([cos], self._out_spec) return pytree.tree_unflatten([cos], self._out_spec)
""" """
true_guard_code = ["cast_symbool_to_symint_guardless(L['pred']) == 1"] true_guard_code = [
"cast_symbool_to_symint_guardless(L['pred']) == 1",
]
false_guard_code = [ false_guard_code = [
"Ne(cast_symbool_to_symint_guardless(L['pred']), 1)", "Ne(cast_symbool_to_symint_guardless(L['pred']), 1)",
"-9223372036854775808 <= cast_symbool_to_symint_guardless(L['pred'])", "-9223372036854775808 <= cast_symbool_to_symint_guardless(L['pred'])",

View File

@ -297,25 +297,25 @@ class HigherOrderOpTests(torch._dynamo.test_case.TestCase):
actual_graph, actual_graph,
"""\ """\
class GraphModule(torch.nn.Module): class GraphModule(torch.nn.Module):
def forward(self, L_x_ : torch.Tensor, L_y_ : torch.Tensor, L_z_ : torch.Tensor): def forward(self, L_d_x_ : torch.Tensor, L_d_y_0_ : torch.Tensor, L_d_y_1_2_ : torch.Tensor):
l_x_ = L_x_ l_d_x_ = L_d_x_
l_y_ = L_y_ l_d_y_0_ = L_d_y_0_
l_z_ = L_z_ l_d_y_1_2_ = L_d_y_1_2_
wrap_body_0 = self.wrap_body_0 wrap_body_0 = self.wrap_body_0
wrap = torch._higher_order_ops.wrap.wrap(wrap_body_0, l_x_, l_y_, l_z_); wrap_body_0 = l_x_ = l_y_ = l_z_ = None wrap = torch._higher_order_ops.wrap.wrap(wrap_body_0, l_d_x_, l_d_y_0_, l_d_y_1_2_); wrap_body_0 = l_d_x_ = l_d_y_0_ = l_d_y_1_2_ = None
getitem = wrap[0]; wrap = None getitem = wrap[0]; wrap = None
return (getitem,) return (getitem,)
class GraphModule(torch.nn.Module): class GraphModule(torch.nn.Module):
def forward(self, l_x_, l_y_, l_z_): def forward(self, l_d_x_, l_d_y_0_, l_d_y_1_2_):
sin = l_x_.sin(); l_x_ = None sin = l_d_x_.sin(); l_d_x_ = None
cos = l_y_.cos(); l_y_ = None cos = l_d_y_0_.cos(); l_d_y_0_ = None
add = sin + cos; sin = cos = None add = sin + cos; sin = cos = None
sin_1 = l_z_.sin(); l_z_ = None sin_1 = l_d_y_1_2_.sin(); l_d_y_1_2_ = None
sub = add - sin_1; add = sin_1 = None sub = add - sin_1; add = sin_1 = None
return (sub,) return (sub,)
""", """, # NOQA: B950
) )
def test_wrap_pytree_args_with_symint_constant(self): def test_wrap_pytree_args_with_symint_constant(self):
@ -3005,9 +3005,9 @@ class GraphModule(torch.nn.Module):
actual, actual,
"""\ """\
class GraphModule(torch.nn.Module): class GraphModule(torch.nn.Module):
def forward(self, L_x_ : torch.Tensor, L_y_ : torch.Tensor): def forward(self, L_y_ : torch.Tensor, L_x_ : torch.Tensor):
l_x_ = L_x_
child = L_y_ child = L_y_
l_x_ = L_x_
_check_randomness_arg = torch._functorch.vmap._check_randomness_arg('error') _check_randomness_arg = torch._functorch.vmap._check_randomness_arg('error')
_check_randomness_arg_1 = torch._functorch.vmap._check_randomness_arg('error') _check_randomness_arg_1 = torch._functorch.vmap._check_randomness_arg('error')
@ -3269,16 +3269,14 @@ class GraphModule(torch.nn.Module):
return torch.func.vmap(torch.sum, in_dims)(x) return torch.func.vmap(torch.sum, in_dims)(x)
x = torch.randn(3, 3, 3, 3) x = torch.randn(3, 3, 3, 3)
opt = torch.compile(wrapper_fn, backend="eager", fullgraph=False, dynamic=True) cnt = CompileCounter()
opt = torch.compile(wrapper_fn, backend=cnt, fullgraph=False, dynamic=True)
expected = wrapper_fn(x, 0), wrapper_fn(x, 1), wrapper_fn(x, 2) expected = wrapper_fn(x, 0), wrapper_fn(x, 1), wrapper_fn(x, 2)
# Third invocation of `opt` makes `in_dims` as SymInt. # Third invocation of `opt` makes `in_dims` as SymInt.
actual = opt(x, 0), opt(x, 1), opt(x, 2) actual = opt(x, 0), opt(x, 1), opt(x, 2)
self.assertEqual(expected, actual) self.assertEqual(expected, actual)
self.assertEqual(len(counters["graph_break"]), 1) self.assertEqual(cnt.frame_count, 3)
self.assertEqual( self.assertEqual(cnt.op_count, 9)
dict(counters["graph_break"]),
{"torch.func.vmap: in_dims is not an int or tuple variable.": 2},
)
def test_vmap_multiple_invocation_out_dims(self): def test_vmap_multiple_invocation_out_dims(self):
counters.clear() counters.clear()
@ -3287,16 +3285,14 @@ class GraphModule(torch.nn.Module):
return torch.func.vmap(lambda x: torch.sum(x, 0), out_dims=out_dims)(x) return torch.func.vmap(lambda x: torch.sum(x, 0), out_dims=out_dims)(x)
x = torch.randn(3, 3, 3, 3) x = torch.randn(3, 3, 3, 3)
opt = torch.compile(wrapper_fn, backend="eager", fullgraph=False, dynamic=True) cnt = CompileCounter()
opt = torch.compile(wrapper_fn, backend=cnt, fullgraph=False, dynamic=True)
expected = wrapper_fn(x, 0), wrapper_fn(x, 1), wrapper_fn(x, 2) expected = wrapper_fn(x, 0), wrapper_fn(x, 1), wrapper_fn(x, 2)
# Third invocation of `opt` makes `in_dims` as SymInt. # Third invocation of `opt` makes `in_dims` as SymInt.
actual = opt(x, 0), opt(x, 1), opt(x, 2) actual = opt(x, 0), opt(x, 1), opt(x, 2)
self.assertEqual(expected, actual) self.assertEqual(expected, actual)
self.assertEqual(len(counters["graph_break"]), 1) self.assertEqual(cnt.frame_count, 3)
self.assertEqual( self.assertEqual(cnt.op_count, 9)
dict(counters["graph_break"]),
{"torch.func.vmap: out_dims is not an int or tuple variable.": 2},
)
def test_vmap_new_tensor_in_body(self): def test_vmap_new_tensor_in_body(self):
def fn(x): def fn(x):

View File

@ -1541,7 +1541,7 @@ utils_device.CURRENT_DEVICE == None""",
args = [torch.randn(10), 4096, np.int64(8)] args = [torch.randn(10), 4096, np.int64(8)]
correct = fn(*args) correct = fn(*args)
cnts = torch._dynamo.testing.CompileCounter() cnts = torch._dynamo.testing.CompileCounter()
opt_fn = torch._dynamo.optimize(cnts, dynamic=True)(fn) opt_fn = torch._dynamo.optimize(cnts, dynamic=True, nopython=True)(fn)
self.assertTrue(same(opt_fn(*args), correct)) self.assertTrue(same(opt_fn(*args), correct))
self.assertTrue(same(opt_fn(*args), correct)) self.assertTrue(same(opt_fn(*args), correct))
self.assertEqual(cnts.frame_count, 1) self.assertEqual(cnts.frame_count, 1)

View File

@ -814,9 +814,8 @@ class MockModule(torch.nn.Module):
class ReproTests(torch._dynamo.test_case.TestCase): class ReproTests(torch._dynamo.test_case.TestCase):
def test_do_paste_mask(self): def test_do_paste_mask(self):
torch._dynamo.utils.counters.clear() torch._dynamo.utils.counters.clear()
opt__do_paste_mask = torch._dynamo.optimize( cnt = torch._dynamo.testing.CompileCounter()
torch._dynamo.testing.CompileCounter() opt__do_paste_mask = torch.compile(_do_paste_mask, backend=cnt)
)(_do_paste_mask)
opt__do_paste_mask( opt__do_paste_mask(
torch.randn(1, 1, 28, 28), torch.randn(1, 1, 28, 28),
torch.tensor([[0.0, 1, 2, 4]]) * 1, torch.tensor([[0.0, 1, 2, 4]]) * 1,
@ -852,12 +851,9 @@ class ReproTests(torch._dynamo.test_case.TestCase):
640, 640,
False, False,
) )
# (dynamic shapes, static shapes)
self.assertGreaterEqual(torch._dynamo.utils.counters["frames"]["ok"], 3) self.assertIn(cnt.frame_count, (5, 7))
self.assertEqual( self.assertIn(cnt.op_count, (106, 127))
torch._dynamo.utils.counters["frames"]["total"],
torch._dynamo.utils.counters["frames"]["ok"] + 1,
)
def test_convert_boxes_to_pooler_format(self): def test_convert_boxes_to_pooler_format(self):
boxes1 = [ boxes1 = [
@ -2451,77 +2447,6 @@ class ReproTests(torch._dynamo.test_case.TestCase):
self.assertEqual(f(x, x), opt_f(x, x)) self.assertEqual(f(x, x), opt_f(x, x))
self.assertEqual(f(x, y), opt_f(x, y)) self.assertEqual(f(x, y), opt_f(x, y))
def test_reformer_remove_unused_args(self):
# This test case is very interesting. First, let's describe
# the bug this is testing for. The bug we fixed is twofold:
#
# - We prune GraphArgs that aren't used in the output graph.
# However, sometimes it is possible for those GraphArgs to be
# utilized in shape guards (you could imagine this happening if
# dynamo poked some shape variables without recording them in the
# graph.) If we prune those GraphArgs, we get a
# "s1 not in ..." error as we can no longer codegen the
# requested guards.
#
# - But in practice, Dynamo usually traces size accesses into the
# graph, preventing the GraphArg from getting pruned. So how
# come we were running into this in practice with hf_Reformer?
# The answer is checkpointing!
#
# This brings us to the following test case. Here's what it does:
#
# 1. It traces some operations, and then checkpoints before inlining
# the function call to g
#
# 2. g traces some more operations (triggering the shape guard
# to be created), but then it graph breaks
#
# 3. Because you can't graph break in an inlining function, we roll
# back to the outer checkpoint ("undoing" the operation that
# induced the shape guard) and then immediately generate a
# subgraph at that point.
#
# If we failed to checkpoint the ShapeEnv, it can still have guards
# from the aborted speculation, which we will then still attempt to
# codegen.
#
# There's an additional nuance: suppose x is used but y is not.
# If you create a guard like y == x * 2, you will accidentally avoid
# the "s1 not in ..." error, as y will get substituted with x * 2,
# but x is still a GraphArg (it's used) and you don't end up with
# the error. This is why we must show y + y == x, not vice versa.
# Similarly, it is also why we must not do a simple guard like x == y
#
# Can we actually demonstrate that checkpointing the ShapeEnv is
# necessary? It's not so easy to induce this case. Dynamo is very
# eager about adding locals to GraphArgs; any local that is in scope,
# even if it isn't used, is added to GraphArgs (see also
# https://github.com/pytorch/torchdynamo/issues/1925 ). So long
# as Dynamo eagerly guards in this way, we have an invariant that
# all locals are guaranteed to show up in GraphArgs before the
# inlining function call, in which case we will always have enough
# information to codegen our guards so long as we don't prune the
# unused GraphArgs away (and indeed, the direct fix for this bug
# was to make sure we use original GraphArgs). Non locals,
# conversely, typically are static, and so won't have guards allocated
# for them. That being said, there may still be a way to trigger
# this error.
def g(x, y):
r = torch.cat((y, y)) + x
print("foo")
return r
def f(x, y):
x = x * 3
return g(x, y)
opt_f = torch._dynamo.optimize("aot_eager")(f)
x = torch.randn(4)
y = torch.randn(2)
self.assertEqual(f(x, y), opt_f(x, y))
def test_swin_base_tensor_attr(self): def test_swin_base_tensor_attr(self):
class Foo(torch.nn.Module): class Foo(torch.nn.Module):
def __init__(self): def __init__(self):

View File

@ -271,7 +271,7 @@ class SubclassTests(torch._dynamo.test_case.TestCase):
kwargs = {} kwargs = {}
return super().__torch_function__(func, types, args, kwargs) return super().__torch_function__(func, types, args, kwargs)
@torch.compile(backend="eager", fullgraph=True) @torch.compile(backend="eager")
def fn(x): def fn(x):
return x.sigmoid() return x.sigmoid()
@ -819,13 +819,6 @@ class TestNestedTensor(torch._dynamo.test_case.TestCase):
nt3, _ = self._get_jagged_tensor(((2, 3, 4), 3), None) nt3, _ = self._get_jagged_tensor(((2, 3, 4), 3), None)
self._check_recompiles(binary, (nt1, nt2), (nt1, nt3), True) self._check_recompiles(binary, (nt1, nt2), (nt1, nt3), True)
def test_binary_recompiles_due_to_duck_sizing(self):
# Even though the input is unused, we still guard due to duck sizing
nt1, offsets = self._get_jagged_tensor(((2, 3, 4), 3), None)
nt2, _ = self._get_jagged_tensor(((2, 3, 4), 3), offsets)
nt3, _ = self._get_jagged_tensor(((2, 3, 4), 3), None)
self._check_recompiles(lambda nt1, nt2: nt1.sin(), (nt1, nt2), (nt1, nt3), True)
# TODO: cannot parametrize this test class with device for some reason # TODO: cannot parametrize this test class with device for some reason
def _test_autograd(self, backend): def _test_autograd(self, backend):
a = torch.randn(2, 3, requires_grad=True, dtype=torch.float64) a = torch.randn(2, 3, requires_grad=True, dtype=torch.float64)

View File

@ -477,8 +477,8 @@ class SubGraphTests(torch._dynamo.test_case.TestCase):
opt_fn(v1, a, b, c) opt_fn(v1, a, b, c)
# checking here we don't create 2^n graphs # checking here we don't create 2^n graphs
self.assertEqual(cnt.frame_count, 12) self.assertEqual(cnt.frame_count, 7)
self.assertEqual(cnt.op_count, 16) self.assertEqual(cnt.op_count, 10)
def test_resume_with_no_grad1(self): def test_resume_with_no_grad1(self):
def fn(a, b): def fn(a, b):

View File

@ -4,7 +4,7 @@ import hypothesis.strategies as st
from hypothesis import given from hypothesis import given
import numpy as np import numpy as np
import torch import torch
from torch.testing._internal.common_utils import TestCase, run_tests from torch.testing._internal.common_utils import TestCase, run_tests, skipIfTorchDynamo
import torch.testing._internal.hypothesis_utils as hu import torch.testing._internal.hypothesis_utils as hu
hu.assert_deadline_disabled() hu.assert_deadline_disabled()
@ -56,6 +56,7 @@ class PruningOpTest(TestCase):
self.assertEqual(pt_compressed_indices_map.dtype, indices_type) self.assertEqual(pt_compressed_indices_map.dtype, indices_type)
@skipIfTorchDynamo()
@given( @given(
embedding_rows=st.integers(1, 100), embedding_rows=st.integers(1, 100),
embedding_dims=st.integers(1, 100), embedding_dims=st.integers(1, 100),
@ -67,6 +68,7 @@ class PruningOpTest(TestCase):
self._test_rowwise_prune_op(embedding_rows, embedding_dims, torch.int, weights_dtype) self._test_rowwise_prune_op(embedding_rows, embedding_dims, torch.int, weights_dtype)
@skipIfTorchDynamo()
@given( @given(
embedding_rows=st.integers(1, 100), embedding_rows=st.integers(1, 100),
embedding_dims=st.integers(1, 100), embedding_dims=st.integers(1, 100),

View File

@ -2530,6 +2530,7 @@ class TestSparseCSR(TestCase):
run_test(4, 5, 4, 10, False) run_test(4, 5, 4, 10, False)
run_test(4, 4, 4, 16, True) run_test(4, 4, 4, 16, True)
@skipIfTorchDynamo()
@onlyCPU @onlyCPU
@dtypes(torch.float32, torch.float64, torch.bfloat16) @dtypes(torch.float32, torch.float64, torch.bfloat16)
@precisionOverride({torch.bfloat16: 0.01}) @precisionOverride({torch.bfloat16: 0.01})
@ -2894,6 +2895,7 @@ class TestSparseCSR(TestCase):
run_test(shape, max(shape), index_dtype) run_test(shape, max(shape), index_dtype)
run_test(shape, shape[0] * shape[1], index_dtype) run_test(shape, shape[0] * shape[1], index_dtype)
@skipIfTorchDynamo()
@skipMeta @skipMeta
@dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16)) @dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16))
@all_sparse_compressed_layouts() @all_sparse_compressed_layouts()

View File

@ -76,8 +76,6 @@ class PyCodegen:
self.clear_tos() self.clear_tos()
return return
self.tx.output.guards.update(value.guards)
assert isinstance(value, VariableTracker) assert isinstance(value, VariableTracker)
output = self._output output = self._output
graph_outputs = self.graph_outputs graph_outputs = self.graph_outputs

View File

@ -586,7 +586,7 @@ class GuardBuilder(GuardBuilderBase):
f"{id(torch._dynamo.eval_frame.guarded_backend_cache.current_backend)}" f"{id(torch._dynamo.eval_frame.guarded_backend_cache.current_backend)}"
) )
code = [ code = [
f"___skip_backend_check() or ___current_backend() == ___lookup_backend({backend_id})" f"(___skip_backend_check() or ___current_backend() == ___lookup_backend({backend_id}))"
] ]
self._produce_guard_code(guard, code) self._produce_guard_code(guard, code)
@ -1366,3 +1366,19 @@ def make_dupe_guard(obj_source, dupe_source):
# However, this should always be a sound guard to add here. # However, this should always be a sound guard to add here.
return functools.partial(GuardBuilder.DUPLICATE_INPUT, source_b=dupe_source) return functools.partial(GuardBuilder.DUPLICATE_INPUT, source_b=dupe_source)
return None return None
def install_guard(*guards, skip=0):
"""
Add dynamo guards to the current tracing context.
Args:
guards: guard(s) to add
skip: number of stack frames to ignore for debug stack trace
"""
from torch._guards import TracingContext
add = TracingContext.get().guards_context.dynamo_guards.add
for guard in guards:
assert isinstance(guard, Guard)
add(guard, skip=skip + 1)

View File

@ -62,7 +62,7 @@ from .exc import (
unimplemented, unimplemented,
unimplemented_with_warning, unimplemented_with_warning,
) )
from .guards import GuardBuilder from .guards import GuardBuilder, install_guard
from .mutation_guard import is_dynamic_nn_module from .mutation_guard import is_dynamic_nn_module
from .side_effects import SideEffects from .side_effects import SideEffects
from .source import ( from .source import (
@ -561,7 +561,11 @@ class OutputGraph(Checkpointable[OutputGraphState]):
# FX deepcopy doesn't work for a partially created graph, so just remove new nodes # FX deepcopy doesn't work for a partially created graph, so just remove new nodes
removed_nodes = 0 removed_nodes = 0
for node in reversed(list(self.graph.nodes)): for node in reversed(list(self.graph.nodes)):
if node.meta["creation_timestamp"] > self.timestamp: if (
node.meta["creation_timestamp"] > self.timestamp
# placeholders here may have been lazily added by existing objects
and node.op != "placeholder"
):
# Erasing node alone does not remove the meta information # Erasing node alone does not remove the meta information
# So, remove the help tensor explicitly # So, remove the help tensor explicitly
if "example_value" in node.meta: if "example_value" in node.meta:
@ -670,7 +674,6 @@ class OutputGraph(Checkpointable[OutputGraphState]):
return variables.UnspecializedNNModuleVariable(target, **options) return variables.UnspecializedNNModuleVariable(target, **options)
options = dict(options) options = dict(options)
options["guards"] = set(options.get("guards", []))
assert "source" in options assert "source" in options
source = options["source"] source = options["source"]
assert not isinstance(source, ParamBufferSource) assert not isinstance(source, ParamBufferSource)
@ -692,10 +695,10 @@ class OutputGraph(Checkpointable[OutputGraphState]):
tracer = self.root_tracer tracer = self.root_tracer
if not is_constant_source(source): if not is_constant_source(source):
options["guards"].add(source.make_guard(GuardBuilder.TENSOR_MATCH)) install_guard(source.make_guard(GuardBuilder.TENSOR_MATCH))
if get_static_address_type(target) == "guarded": if get_static_address_type(target) == "guarded":
options["guards"].add(source.make_guard(GuardBuilder.DATA_PTR_MATCH)) install_guard(source.make_guard(GuardBuilder.DATA_PTR_MATCH))
def wrap_name(module_key): def wrap_name(module_key):
assert self.param_name_to_source is not None assert self.param_name_to_source is not None
@ -711,7 +714,7 @@ class OutputGraph(Checkpointable[OutputGraphState]):
elif isinstance(target, torch.nn.Module): elif isinstance(target, torch.nn.Module):
assert isinstance(target, torch.nn.Module) assert isinstance(target, torch.nn.Module)
options["guards"].add(source.make_guard(GuardBuilder.NN_MODULE)) install_guard(source.make_guard(GuardBuilder.NN_MODULE))
def wrap_name(module_key): def wrap_name(module_key):
return NNModuleVariable(type(target), module_key, **options) return NNModuleVariable(type(target), module_key, **options)
@ -1005,9 +1008,6 @@ class OutputGraph(Checkpointable[OutputGraphState]):
assert isinstance(rv, list) assert isinstance(rv, list)
assert isinstance(root, FakeRootModule) assert isinstance(root, FakeRootModule)
for output in rv:
self.guards.update(output.guards)
self.create_node( self.create_node(
"output", "output",
"output", "output",

View File

@ -54,7 +54,7 @@ from .codegen import PyCodegen
from .current_scope_id import current_scope_id from .current_scope_id import current_scope_id
from .exc import ArgsMismatchError, BackendCompilerFailed, unimplemented, Unsupported from .exc import ArgsMismatchError, BackendCompilerFailed, unimplemented, Unsupported
from .funcname_cache import get_funcname from .funcname_cache import get_funcname
from .guards import GuardBuilder from .guards import GuardBuilder, install_guard
from .output_graph import GraphCompileReason, OutputGraph, OutputGraphState from .output_graph import GraphCompileReason, OutputGraph, OutputGraphState
from .replay_record import DummyModule, ExecutionRecorder from .replay_record import DummyModule, ExecutionRecorder
from .resume_execution import ContinueExecutionCache, ReenterWith from .resume_execution import ContinueExecutionCache, ReenterWith
@ -323,13 +323,11 @@ def _detect_and_normalize_assert_statement(
def generic_jump(truth_fn: typing.Callable[[object], bool], push: bool): def generic_jump(truth_fn: typing.Callable[[object], bool], push: bool):
def inner(self: "InstructionTranslatorBase", inst: Instruction): def inner(self: "InstructionTranslatorBase", inst: Instruction):
value: VariableTracker = self.pop() value: VariableTracker = self.pop()
self.output.guards.update(value.guards)
if ( if (
config.rewrite_assert_with_torch_assert config.rewrite_assert_with_torch_assert
and _detect_and_normalize_assert_statement(self, truth_fn, push) and _detect_and_normalize_assert_statement(self, truth_fn, push)
): ):
error_msg: VariableTracker = self.pop() error_msg: VariableTracker = self.pop()
self.output.guards.update(error_msg.guards)
# Skip over things like `assert True` # Skip over things like `assert True`
if value.is_python_constant() and bool(value.as_python_constant()): if value.is_python_constant() and bool(value.as_python_constant()):
self.jump(inst) self.jump(inst)
@ -419,7 +417,6 @@ def generic_jump(truth_fn: typing.Callable[[object], bool], push: bool):
if isinstance(result, ConstantVariable) and isinstance( if isinstance(result, ConstantVariable) and isinstance(
result.value, (bool, int) result.value, (bool, int)
): ):
self.output.guards.update(result.guards)
if truth_fn(result.value): if truth_fn(result.value):
push and self.push(value) push and self.push(value)
self.jump(inst) self.jump(inst)
@ -686,9 +683,7 @@ class InstructionTranslatorBase(Checkpointable[InstructionTranslatorGraphState])
""" """
A call to some user defined function by inlining it. A call to some user defined function by inlining it.
""" """
result = InliningInstructionTranslator.inline_call(self, fn, args, kwargs) return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
self.output.guards.update(fn.guards)
return result
def get_line_of_code_header(self, lineno=None): def get_line_of_code_header(self, lineno=None):
if lineno is None: if lineno is None:
@ -1139,7 +1134,6 @@ class InstructionTranslatorBase(Checkpointable[InstructionTranslatorGraphState])
def FOR_ITER(self, inst): def FOR_ITER(self, inst):
it = self.pop().realize() it = self.pop().realize()
if isinstance(it, (variables.ListIteratorVariable, variables.IteratorVariable)): if isinstance(it, (variables.ListIteratorVariable, variables.IteratorVariable)):
self.output.guards.update(it.guards)
try: try:
val, next_iter = it.next_variables(self) val, next_iter = it.next_variables(self)
self.push(next_iter) self.push(next_iter)
@ -1233,8 +1227,6 @@ class InstructionTranslatorBase(Checkpointable[InstructionTranslatorGraphState])
if sys.version_info >= (3, 11): if sys.version_info >= (3, 11):
null = self.pop() null = self.pop()
assert isinstance(null, NullVariable) assert isinstance(null, NullVariable)
self.output.guards.update(argsvars.guards)
self.output.guards.update(kwargsvars.guards)
if ( if (
isinstance(fn, GetAttrVariable) isinstance(fn, GetAttrVariable)
@ -1327,12 +1319,8 @@ class InstructionTranslatorBase(Checkpointable[InstructionTranslatorGraphState])
), f"Mutating module attribute {inst.argval} during export." ), f"Mutating module attribute {inst.argval} during export."
try: try:
self.output.guards.update( BuiltinVariable(setattr).call_function(
BuiltinVariable(setattr) self, [obj, ConstantVariable.create(inst.argval), val], {}
.call_function(
self, [obj, ConstantVariable.create(inst.argval), val], {}
)
.guards
) )
return return
except Unsupported as e: except Unsupported as e:
@ -1355,10 +1343,8 @@ class InstructionTranslatorBase(Checkpointable[InstructionTranslatorGraphState])
def DELETE_ATTR(self, inst): def DELETE_ATTR(self, inst):
obj = self.pop() obj = self.pop()
self.output.guards.update( BuiltinVariable(delattr).call_function(
BuiltinVariable(delattr) self, [obj, ConstantVariable.create(inst.argval)], {}
.call_function(self, [obj, ConstantVariable.create(inst.argval)], {})
.guards
) )
def create_call_resume_at(self, offset): def create_call_resume_at(self, offset):
@ -1375,8 +1361,6 @@ class InstructionTranslatorBase(Checkpointable[InstructionTranslatorGraphState])
def STORE_SUBSCR(self, inst): def STORE_SUBSCR(self, inst):
val, obj, key = self.popn(3) val, obj, key = self.popn(3)
result = obj.call_method(self, "__setitem__", [key, val], {}) result = obj.call_method(self, "__setitem__", [key, val], {})
# no result is pushed, so need to lift the guards to global
self.output.guards.update(result.guards)
def BUILD_TUPLE(self, inst): def BUILD_TUPLE(self, inst):
items = self.popn(inst.argval) items = self.popn(inst.argval)
@ -1511,7 +1495,6 @@ class InstructionTranslatorBase(Checkpointable[InstructionTranslatorGraphState])
obj, obj,
ListVariable( ListVariable(
obj.items + [v], obj.items + [v],
regen_guards=False,
**VariableTracker.propagate([obj, v]), **VariableTracker.propagate([obj, v]),
), ),
) )
@ -1559,7 +1542,6 @@ class InstructionTranslatorBase(Checkpointable[InstructionTranslatorGraphState])
def UNPACK_SEQUENCE(self, inst): def UNPACK_SEQUENCE(self, inst):
seq = self.pop() seq = self.pop()
if isinstance(seq, (BaseListVariable, SetVariable)): if isinstance(seq, (BaseListVariable, SetVariable)):
self.output.guards.update(seq.guards)
val = seq.unpack_var_sequence(self) val = seq.unpack_var_sequence(self)
elif seq.is_python_constant() and isinstance(seq, ConstantVariable): elif seq.is_python_constant() and isinstance(seq, ConstantVariable):
val = seq.unpack_var_sequence(self) val = seq.unpack_var_sequence(self)
@ -1874,8 +1856,6 @@ class InstructionTranslatorBase(Checkpointable[InstructionTranslatorGraphState])
if isinstance(ctx, GenericContextWrappingVariable): if isinstance(ctx, GenericContextWrappingVariable):
self.generic_context_manager_depth += 1 self.generic_context_manager_depth += 1
self.output.guards.update(ctx.guards)
exit = WithExitFunctionVariable( exit = WithExitFunctionVariable(
ctx, ctx,
inst.target, inst.target,
@ -1961,9 +1941,7 @@ class InstructionTranslatorBase(Checkpointable[InstructionTranslatorGraphState])
) )
def store_global_weakref(self, name, value): def store_global_weakref(self, name, value):
self.output.guards.add( install_guard(GlobalWeakRefSource(name).make_guard(GuardBuilder.WEAKREF_ALIVE))
GlobalWeakRefSource(name).make_guard(GuardBuilder.WEAKREF_ALIVE)
)
if name not in self.output.global_scope: if name not in self.output.global_scope:
self.output.install_global(name, weakref.ref(value)) self.output.install_global(name, weakref.ref(value))
@ -2148,67 +2126,26 @@ class InstructionTranslator(InstructionTranslatorBase):
vars.extend(cells_and_freevars) vars.extend(cells_and_freevars)
cells_and_freevars_set = set(cells_and_freevars) cells_and_freevars_set = set(cells_and_freevars)
self.symbolic_locals = collections.OrderedDict( self.symbolic_locals = {
( k: variables.LazyVariableTracker.create(
k, f_locals[k],
VariableBuilder( source=LocalSource(k, cell_or_freevar=k in cells_and_freevars_set),
self,
LocalSource(k, cell_or_freevar=k in cells_and_freevars_set),
)(f_locals[k]),
) )
for k in vars for k in vars
if k in f_locals if k in f_locals
) }
if export: if export:
# export gets super confused if we never realize unused inputs # export gets confused if we never realize unused inputs
# in export mode just eagerly realize everything # in export mode just eagerly realize everything
self.symbolic_locals = VariableTracker.apply( self.symbolic_locals = VariableTracker.apply(
lambda x: x.realize(), self.symbolic_locals lambda x: x.realize(), self.symbolic_locals
) )
self.init_local_index_guards_hack()
self._freevars_ids = dict() self._freevars_ids = dict()
for name in self.code_options["co_freevars"]: for name in self.code_options["co_freevars"]:
if name in f_locals: if name in f_locals:
self._freevars_ids[name] = id(f_locals[name]) self._freevars_ids[name] = id(f_locals[name])
def init_local_index_guards_hack(self):
# symbolic_locals contains the mapping from original f_locals to the
# Variable objects. During the Variable building phase, each object also
# has its associated guards. At the end, we will accumulate these
# guards.
#
# One way of handling these guards is to just accumulate all of them
# right now. However, many f_locals might not be used in the frame and
# thus can unnecessarily increase guard execution overhead. Therefore,
# we selectively update output.guards as we run the Python Bytecode
# instruction by instruction.
#
# An exception here is list/dict variables. Guards related to these
# variables have indexed access, like Tensor_match on args[0], and if
# args is not used in this frame, we will miss a LIST_LENGTH check like
# len(args) == 2. Missing the LIST_LENGTH check causes problem for the
# next invocation when args is not a list, and args[0] is a runtime
# error. Therefore, we recursively add guards for list/dict variable here.
for val in self.symbolic_locals.values():
if isinstance(
val, (ListIteratorVariable, BaseListVariable, ConstDictVariable)
):
local_guards = VariableTracker.propagate(val)["guards"]
index_guards = [
guard
for guard in local_guards
if guard.create_fn
in (
GuardBuilder.LIST_LENGTH,
GuardBuilder.DICT_KEYS,
GuardBuilder.ODICT_KEYS,
GuardBuilder.TUPLE_ITERATOR_LEN,
)
]
self.output.guards.update(index_guards)
def run(self): def run(self):
super().run() super().run()
@ -2661,7 +2598,6 @@ class InliningGeneratorInstructionTranslator(InliningInstructionTranslator):
if isinstance( if isinstance(
tos, (variables.ListIteratorVariable, variables.IteratorVariable) tos, (variables.ListIteratorVariable, variables.IteratorVariable)
): ):
self.output.guards.update(tos.guards)
try: try:
val, next_iter = tos.next_variables(self) val, next_iter = tos.next_variables(self)
self.push(val) self.push(val)

View File

@ -1,12 +1,12 @@
import collections import collections
from enum import Enum from enum import Enum
from typing import Any, Callable, Dict, List, Optional, Set from typing import Any, Callable, Dict, List
from .. import variables from .. import variables
from ..current_scope_id import current_scope_id from ..current_scope_id import current_scope_id
from ..exc import unimplemented from ..exc import unimplemented
from ..source import AttrSource, Source from ..source import AttrSource, Source
from ..utils import dict_values, identity, istype, odict_values from ..utils import identity, istype
class MutableLocalSource(Enum): class MutableLocalSource(Enum):
@ -154,21 +154,8 @@ class VariableTracker(metaclass=VariableTrackerMeta):
@staticmethod @staticmethod
def propagate(*vars: List[List["VariableTracker"]]): def propagate(*vars: List[List["VariableTracker"]]):
"""Combine the guards from many VariableTracker into **kwargs for a new instance""" # TODO(jansel): delete this function
guards = set() return {}
def visit(var):
if type(var) in (list, tuple, dict_values, odict_values):
for i in var:
visit(i)
else:
assert isinstance(var, VariableTracker), typestr(var)
guards.update(var.guards)
visit(vars)
return {
"guards": guards,
}
def clone(self, **kwargs): def clone(self, **kwargs):
"""Shallow copy with some (optional) changes""" """Shallow copy with some (optional) changes"""
@ -246,22 +233,8 @@ class VariableTracker(metaclass=VariableTrackerMeta):
cache[idx] = (result, value) cache[idx] = (result, value)
return result return result
def add_guard(self, guard):
return self.clone(guards=set.union(self.guards, {guard}))
def add_guards(self, guards):
if guards is None:
return self
assert isinstance(guards, set)
return self.clone(guards=set.union(self.guards, guards))
def add_options(self, options, *more): def add_options(self, options, *more):
if more: return self
return self.add_options(options).add_options(*more)
if isinstance(options, VariableTracker):
return self.add_guards(options.guards)
assert isinstance(options, dict)
return self.add_guards(options.get("guards", set()))
def __str__(self): def __str__(self):
return f"{self.__class__.__name__}()" return f"{self.__class__.__name__}()"
@ -283,13 +256,6 @@ class VariableTracker(metaclass=VariableTrackerMeta):
except NotImplementedError: except NotImplementedError:
return False return False
def can_make_guard(self):
try:
self.make_guard(None)
return True
except NotImplementedError:
return False
def make_guard(self, fn): def make_guard(self, fn):
if self.source: if self.source:
return self.source.make_guard(fn) return self.source.make_guard(fn)
@ -380,6 +346,10 @@ class VariableTracker(metaclass=VariableTrackerMeta):
"""Used by LazyVariableTracker to build the real VariableTracker""" """Used by LazyVariableTracker to build the real VariableTracker"""
return self return self
def recursive_realize(self):
"""Realize all objects under this"""
return VariableTracker.apply(lambda x: x.realize(), self)
def unwrap(self) -> "VariableTracker": def unwrap(self) -> "VariableTracker":
"""Used by LazyVariableTracker to return the real VariableTracker if it already exists""" """Used by LazyVariableTracker to return the real VariableTracker if it already exists"""
return self return self
@ -391,14 +361,12 @@ class VariableTracker(metaclass=VariableTrackerMeta):
def __init__( def __init__(
self, self,
*, *,
guards: Optional[Set] = None,
source: Source = None, source: Source = None,
mutable_local: MutableLocal = None, mutable_local: MutableLocal = None,
user_code_variable_name: str = None, user_code_variable_name: str = None,
parents_tracker: ParentsTracker = None, parents_tracker: ParentsTracker = None,
): ):
super().__init__() super().__init__()
self.guards = guards or set()
self.source = source self.source = source
self.mutable_local = mutable_local self.mutable_local = mutable_local
self.user_code_variable_name = user_code_variable_name self.user_code_variable_name = user_code_variable_name

View File

@ -44,7 +44,7 @@ from ..allowed_functions import (
from ..device_interface import device_interfaces from ..device_interface import device_interfaces
from ..exc import InternalTorchDynamoError, unimplemented from ..exc import InternalTorchDynamoError, unimplemented
from ..guards import GuardBuilder, make_dupe_guard from ..guards import GuardBuilder, install_guard, make_dupe_guard
from ..side_effects import SideEffects from ..side_effects import SideEffects
from ..source import ( from ..source import (
AttrSource, AttrSource,
@ -198,6 +198,7 @@ class GraphArg:
def erase(self): def erase(self):
self._example = None self._example = None
self.example_strong_ref = None
def __eq__(self, other): def __eq__(self, other):
return self.source.name() == other.source.name() return self.source.name() == other.source.name()
@ -231,9 +232,7 @@ class VariableBuilder:
side_effect_result = self.tx.output.side_effects[value] side_effect_result = self.tx.output.side_effects[value]
dup_guard = make_dupe_guard(self.source, side_effect_result.source) dup_guard = make_dupe_guard(self.source, side_effect_result.source)
if dup_guard: if dup_guard:
side_effect_result = side_effect_result.add_guards( self.install_guards(dup_guard)
self.make_guards(dup_guard)
)
return side_effect_result return side_effect_result
vt = self._wrap(value).clone(**self.options()) vt = self._wrap(value).clone(**self.options())
if self._can_lift_attrs_to_inputs(vt): if self._can_lift_attrs_to_inputs(vt):
@ -272,14 +271,15 @@ class VariableBuilder:
def options(self): def options(self):
return {"source": self.get_source()} return {"source": self.get_source()}
def make_guards(self, *guards): def install_guards(self, *guards):
source = self.get_source() source = self.get_source()
if ( if (
isinstance(source, ConstantSource) isinstance(source, ConstantSource)
or source.guard_source() == GuardSource.CONSTANT or source.guard_source() == GuardSource.CONSTANT
): ):
return None return None
return {source.make_guard(guard) for guard in guards} install_guard(*[source.make_guard(guard) for guard in guards], skip=1)
return {}
@classmethod @classmethod
@functools.lru_cache(None) @functools.lru_cache(None)
@ -330,7 +330,7 @@ class VariableBuilder:
lambda self, value: LambdaVariable( lambda self, value: LambdaVariable(
InspectSignatureVariable.create, InspectSignatureVariable.create,
source=self.source, source=self.source,
guards=self.make_guards(GuardBuilder.FUNCTION_MATCH), **self.install_guards(GuardBuilder.FUNCTION_MATCH),
), ),
), ),
(comptime, lambda self, value: ComptimeVariable()), (comptime, lambda self, value: ComptimeVariable()),
@ -339,7 +339,7 @@ class VariableBuilder:
lambda self, value: LambdaVariable( lambda self, value: LambdaVariable(
_dataclasses_fields_lambda, _dataclasses_fields_lambda,
source=self.source, source=self.source,
guards=self.make_guards(GuardBuilder.FUNCTION_MATCH), **self.install_guards(GuardBuilder.FUNCTION_MATCH),
), ),
), ),
( (
@ -347,7 +347,7 @@ class VariableBuilder:
lambda self, value: TorchVariable( lambda self, value: TorchVariable(
value, value,
source=self.source, source=self.source,
guards=self.make_guards(GuardBuilder.FUNCTION_MATCH), **self.install_guards(GuardBuilder.FUNCTION_MATCH),
), ),
), ),
] ]
@ -375,8 +375,6 @@ class VariableBuilder:
class Autotuner: class Autotuner:
pass pass
make_guards = self.make_guards
# Handle exact type() match # Handle exact type() match
type_dispatch = self._type_dispatch().get(type(value)) type_dispatch = self._type_dispatch().get(type(value))
if type_dispatch is not None: if type_dispatch is not None:
@ -400,13 +398,13 @@ class VariableBuilder:
return self.wrap_listlike(value) return self.wrap_listlike(value)
elif value is torch.utils._pytree.SUPPORTED_NODES: elif value is torch.utils._pytree.SUPPORTED_NODES:
# For SUPPORTED_NODES, we guard on the dictionary version (PEP509)
# under the assumption that the values themselves don't change.
self.install_guards(GuardBuilder.DICT_VERSION)
result = { result = {
k: UserDefinedObjectVariable( k: UserDefinedObjectVariable(
value[k], value[k],
source=GetItemSource(self.get_source(), k), source=GetItemSource(self.get_source(), k),
# For SUPPORTED_NODES, we guard on the dictionary version (PEP509)
# under the assumption that the values themselves don't change.
guards=self.make_guards(GuardBuilder.DICT_VERSION),
) )
for k in value.keys() for k in value.keys()
} }
@ -429,9 +427,9 @@ class VariableBuilder:
# Why is this OK for (specialized) nnmodules? We set up a setattr hook # Why is this OK for (specialized) nnmodules? We set up a setattr hook
# to check for module property mutations, which does a reasonable, # to check for module property mutations, which does a reasonable,
# but not completely secure job ensuring a property wasn't changed. # but not completely secure job ensuring a property wasn't changed.
guards = self.make_guards(GuardBuilder.BOOL_FALSE) self.install_guards(GuardBuilder.BOOL_FALSE)
else: else:
guards = self.make_guards(GuardBuilder.DICT_KEYS) self.install_guards(GuardBuilder.DICT_KEYS)
# store key variables in global location for reconstruction # store key variables in global location for reconstruction
for key in value.keys(): for key in value.keys():
@ -448,7 +446,7 @@ class VariableBuilder:
k: LazyVariableTracker.create( k: LazyVariableTracker.create(
value[k], value[k],
source=GetItemSource(self.get_source(), index_source(k)), source=GetItemSource(self.get_source(), index_source(k)),
).add_guards(guards) )
for k in value.keys() for k in value.keys()
} }
@ -457,10 +455,9 @@ class VariableBuilder:
result, result,
type(value), type(value),
self._wrap(value.default_factory), self._wrap(value.default_factory),
guards=guards,
) )
else: else:
result = ConstDictVariable(result, type(value), guards=guards) result = ConstDictVariable(result, type(value))
return self.tx.output.side_effects.track_dict(self.source, value, result) return self.tx.output.side_effects.track_dict(self.source, value, result)
elif isinstance(value, torch.nn.Module): elif isinstance(value, torch.nn.Module):
@ -472,23 +469,14 @@ class VariableBuilder:
): ):
# For frozenset, we can guard by object ID instead of value # For frozenset, we can guard by object ID instead of value
# equality, this allows us to handle non-literal values # equality, this allows us to handle non-literal values
return ConstantVariable.create( self.install_guards(GuardBuilder.ID_MATCH)
value=value, return ConstantVariable.create(value=value, source=self.source)
source=self.source,
guards=make_guards(GuardBuilder.ID_MATCH),
)
elif isinstance(value, enum.Enum): elif isinstance(value, enum.Enum):
return EnumVariable( self.install_guards(GuardBuilder.ID_MATCH)
value=value, return EnumVariable(value=value, source=self.source)
source=self.source,
guards=make_guards(GuardBuilder.ID_MATCH),
)
elif is_builtin_callable(value): elif is_builtin_callable(value):
return BuiltinVariable( self.install_guards(GuardBuilder.BUILTIN_MATCH)
value, return BuiltinVariable(value, source=self.source)
source=self.source,
guards=make_guards(GuardBuilder.BUILTIN_MATCH),
)
elif is_utils_checkpoint(value): elif is_utils_checkpoint(value):
return build_checkpoint_variable(source=self.source) return build_checkpoint_variable(source=self.source)
elif isinstance(value, functools.partial): elif isinstance(value, functools.partial):
@ -509,52 +497,50 @@ class VariableBuilder:
self.tx, GetItemSource(keywords_source, k) self.tx, GetItemSource(keywords_source, k)
)(v) )(v)
guards = { install_guard(
self.get_source().make_guard(GuardBuilder.TYPE_MATCH), self.get_source().make_guard(GuardBuilder.TYPE_MATCH),
keywords_source.make_guard(GuardBuilder.DICT_KEYS), keywords_source.make_guard(GuardBuilder.DICT_KEYS),
args_source.make_guard(GuardBuilder.LIST_LENGTH), args_source.make_guard(GuardBuilder.LIST_LENGTH),
}
return FunctoolsPartialVariable(
func_obj, args, keywords, original=value, guards=guards
) )
return FunctoolsPartialVariable(func_obj, args, keywords, original=value)
elif is_typing(value): elif is_typing(value):
# typing.List, typing.Mapping, etc. # typing.List, typing.Mapping, etc.
self.install_guards(GuardBuilder.ID_MATCH)
return TypingVariable( return TypingVariable(
value, value,
source=self.source, source=self.source,
guards=make_guards(GuardBuilder.ID_MATCH),
) )
elif np is not None and isinstance(value, np.generic): elif np is not None and isinstance(value, np.generic):
# numpy array scalars: convert to 0D arrays # numpy array scalars: convert to 0D arrays
return self.wrap_numpy_ndarray(np.asarray(value)) return self.wrap_numpy_ndarray(np.asarray(value))
elif is_numpy(value): elif is_numpy(value):
assert np assert np
return NumpyVariable( self.install_guards(
value, GuardBuilder.FUNCTION_MATCH
source=self.source, if callable(value)
guards=make_guards( else GuardBuilder.TYPE_MATCH
GuardBuilder.FUNCTION_MATCH
if callable(value)
else GuardBuilder.TYPE_MATCH
),
) )
return NumpyVariable(value, source=self.source)
# NB: These can't be put in type_dispatch, they have to run later # NB: These can't be put in type_dispatch, they have to run later
elif CollectiveFunctionRewriteVariable.can_rewrite(value): elif CollectiveFunctionRewriteVariable.can_rewrite(value):
self.install_guards(GuardBuilder.FUNCTION_MATCH)
return CollectiveFunctionRewriteVariable.create( return CollectiveFunctionRewriteVariable.create(
self.tx, self.tx,
value, value,
source=self.source, source=self.source,
guards=make_guards(GuardBuilder.FUNCTION_MATCH),
) )
elif istype(value, torch.autograd.function.FunctionMeta): elif istype(value, torch.autograd.function.FunctionMeta):
self.install_guards(GuardBuilder.FUNCTION_MATCH)
return AutogradFunctionVariable( return AutogradFunctionVariable(
value, value,
source=self.source, source=self.source,
guards=make_guards(GuardBuilder.FUNCTION_MATCH),
) )
elif isinstance(value, torch.autograd.function.FunctionCtx): elif isinstance(value, torch.autograd.function.FunctionCtx):
saved_tensors_source = AttrSource(self.source, "saved_tensors") saved_tensors_source = AttrSource(self.source, "saved_tensors")
install_guard(
self.source.make_guard(GuardBuilder.TYPE_MATCH),
saved_tensors_source.make_guard(GuardBuilder.LIST_LENGTH),
)
saved_tensors = [ saved_tensors = [
VariableBuilder(self.tx, GetItemSource(saved_tensors_source, n))(v) VariableBuilder(self.tx, GetItemSource(saved_tensors_source, n))(v)
for n, v in enumerate(value.saved_tensors) for n, v in enumerate(value.saved_tensors)
@ -565,8 +551,6 @@ class VariableBuilder:
AutogradFunctionContextVariable( AutogradFunctionContextVariable(
value, value,
source=self.source, source=self.source,
guards=make_guards(GuardBuilder.TYPE_MATCH)
| {saved_tensors_source.make_guard(GuardBuilder.LIST_LENGTH)},
saved_tensors=SavedTensorBox(saved_tensors), saved_tensors=SavedTensorBox(saved_tensors),
), ),
) )
@ -579,53 +563,43 @@ class VariableBuilder:
and value == getattr(value.__self__, "apply", None) and value == getattr(value.__self__, "apply", None)
): ):
# handle aliased autograd function `apply` calls # handle aliased autograd function `apply` calls
self.install_guards(GuardBuilder.FUNCTION_MATCH)
return GetAttrVariable( return GetAttrVariable(
AutogradFunctionVariable( AutogradFunctionVariable(value.__self__, source=self.source),
value.__self__,
source=self.source,
guards=make_guards(GuardBuilder.FUNCTION_MATCH),
),
"apply", "apply",
) )
elif np and isinstance(value, np.number): elif np and isinstance(value, np.number):
return self.wrap_unspecialized_primitive(value) return self.wrap_unspecialized_primitive(value)
elif DataClassVariable.is_matching_object(value): elif DataClassVariable.is_matching_object(value):
return DataClassVariable.wrap(self, value).add_guards( self.install_guards(GuardBuilder.TYPE_MATCH)
make_guards(GuardBuilder.TYPE_MATCH) return DataClassVariable.wrap(self, value)
)
elif HFPretrainedConfigVariable.is_matching_object(value): elif HFPretrainedConfigVariable.is_matching_object(value):
return HFPretrainedConfigVariable( self.install_guards(GuardBuilder.TYPE_MATCH)
value, guards=make_guards(GuardBuilder.TYPE_MATCH) return HFPretrainedConfigVariable(value)
)
elif isinstance(value, HigherOrderOperator): elif isinstance(value, HigherOrderOperator):
return TorchHigherOrderOperatorVariable.make( self.install_guards(GuardBuilder.TYPE_MATCH, GuardBuilder.NAME_MATCH)
value, return TorchHigherOrderOperatorVariable.make(value, source=self.source)
source=self.source,
guards=self.make_guards(
GuardBuilder.TYPE_MATCH, GuardBuilder.NAME_MATCH
),
)
elif type(value).__name__ == "builtin_function_or_method" and isinstance( elif type(value).__name__ == "builtin_function_or_method" and isinstance(
value.__self__, torch_special_class_types value.__self__, torch_special_class_types
): ):
self.install_guards(GuardBuilder.FUNCTION_MATCH)
return TorchVariable( return TorchVariable(
value, value,
guards=make_guards(GuardBuilder.FUNCTION_MATCH),
) )
elif isinstance(value, _StreamBase): elif isinstance(value, _StreamBase):
self.install_guards(GuardBuilder.ID_MATCH)
return StreamVariable( return StreamVariable(
None, None,
value, value,
value.device.type, value.device.type,
source=self.source, source=self.source,
guards=make_guards(GuardBuilder.ID_MATCH),
) )
elif isinstance(value, _EventBase): elif isinstance(value, _EventBase):
self.install_guards(GuardBuilder.ID_MATCH)
return EventVariable( return EventVariable(
None, None,
value, value,
source=self.source, source=self.source,
guards=make_guards(GuardBuilder.ID_MATCH),
) )
elif ( elif (
isinstance(value, torch._C._TensorMeta) isinstance(value, torch._C._TensorMeta)
@ -636,55 +610,36 @@ class VariableBuilder:
istype(value, contextlib.nullcontext) istype(value, contextlib.nullcontext)
and inspect.getattr_static(value, "enter_result", None) is None and inspect.getattr_static(value, "enter_result", None) is None
): ):
return NullContextVariable( # TODO(jansel): I think this can be TYPE_MATCH
source=self.source, self.install_guards(GuardBuilder.FUNCTION_MATCH)
guards=make_guards(GuardBuilder.FUNCTION_MATCH), return NullContextVariable(source=self.source)
)
elif KeyedJaggedTensorVariable.is_matching_object(value): elif KeyedJaggedTensorVariable.is_matching_object(value):
result = KeyedJaggedTensorVariable( self.install_guards(GuardBuilder.TYPE_MATCH)
value, result = KeyedJaggedTensorVariable(value, source=self.source)
source=self.source,
guards=self.make_guards(GuardBuilder.TYPE_MATCH),
)
# TODO: this doing it manually is bad # TODO: this doing it manually is bad
return self.tx.output.side_effects.track_object_existing( return self.tx.output.side_effects.track_object_existing(
self.source, value, result self.source, value, result
) )
elif isinstance(value, torch.optim.Optimizer): elif isinstance(value, torch.optim.Optimizer):
return OptimizerVariable( self.install_guards(GuardBuilder.TYPE_MATCH)
value, return OptimizerVariable(value, source=self.source)
source=self.source,
guards=self.make_guards(GuardBuilder.TYPE_MATCH),
)
elif ProcessGroupVariable.is_process_group(value): elif ProcessGroupVariable.is_process_group(value):
return ProcessGroupVariable( self.install_guards(GuardBuilder.ID_MATCH)
value, return ProcessGroupVariable(value, source=self.source)
source=self.source,
guards=self.make_guards(GuardBuilder.ID_MATCH),
)
elif DeviceMeshVariable.is_device_mesh(value): elif DeviceMeshVariable.is_device_mesh(value):
# TODO: see if we need to add custom guard instead # TODO: see if we need to add custom guard instead of a simple ID_MATCH
# of a simple ID_MATCH self.install_guards(GuardBuilder.ID_MATCH)
return DeviceMeshVariable( return DeviceMeshVariable(value, source=self.source)
value,
source=self.source,
guards=self.make_guards(GuardBuilder.ID_MATCH),
)
elif PlacementClassVariable.is_placement_type(value): elif PlacementClassVariable.is_placement_type(value):
# TODO: see if we need to add custom guard instead # TODO: see if we need to add custom guard instead of a simple ID_MATCH
# of a simple ID_MATCH self.install_guards(GuardBuilder.ID_MATCH)
return PlacementClassVariable( return PlacementClassVariable(value, source=self.source)
value,
source=self.source,
guards=make_guards(GuardBuilder.ID_MATCH),
)
elif PlacementVariable.is_placement(value): elif PlacementVariable.is_placement(value):
# TODO: see if we need to add custom guard instead # TODO: see if we need to add custom guard instead of a simple ID_MATCH
# of a simple ID_MATCH self.install_guards(GuardBuilder.ID_MATCH)
return PlacementVariable( return PlacementVariable(
value, value,
source=self.source, source=self.source,
guards=make_guards(GuardBuilder.ID_MATCH),
) )
elif isinstance(value, torch.SymBool): elif isinstance(value, torch.SymBool):
# Note: the idea here is to re-use the infra we've built for SymInt by simulating the # Note: the idea here is to re-use the infra we've built for SymInt by simulating the
@ -727,12 +682,12 @@ class VariableBuilder:
new_symint == 1, new_symint == 1,
) )
elif isinstance(value, (JITFunction, Autotuner)): elif isinstance(value, (JITFunction, Autotuner)):
self.install_guards(GuardBuilder.ID_MATCH)
return TritonKernelVariable( return TritonKernelVariable(
value, value,
None, # No kernel idx provided None, # No kernel idx provided
None, # No grid provided None, # No grid provided
source=self.source, source=self.source,
guards=make_guards(GuardBuilder.ID_MATCH),
) )
elif trace_rules.lookup(value) is not None: elif trace_rules.lookup(value) is not None:
return trace_rules.lookup(value).create_with_source( return trace_rules.lookup(value).create_with_source(
@ -741,10 +696,10 @@ class VariableBuilder:
elif is_allowed(value): elif is_allowed(value):
if is_user_defined_allowed(value): if is_user_defined_allowed(value):
self.tx.output.has_user_defined_allowed_in_graph = True self.tx.output.has_user_defined_allowed_in_graph = True
self.install_guards(GuardBuilder.FUNCTION_MATCH)
return TorchVariable( return TorchVariable(
value, value,
source=self.source, source=self.source,
guards=make_guards(GuardBuilder.FUNCTION_MATCH),
) )
elif ( elif (
istype(value, (type, types.FunctionType)) istype(value, (type, types.FunctionType))
@ -752,17 +707,17 @@ class VariableBuilder:
and not inspect.getattr_static(value, "_torchdynamo_inline", False) and not inspect.getattr_static(value, "_torchdynamo_inline", False)
and not inspect.getattr_static(value, "__script_if_tracing_wrapper", False) and not inspect.getattr_static(value, "__script_if_tracing_wrapper", False)
): ):
self.install_guards(GuardBuilder.FUNCTION_MATCH)
return SkipFilesVariable( return SkipFilesVariable(
value, value,
skipfiles.check_verbose(value, allow_torch=True).reason, skipfiles.check_verbose(value, allow_torch=True).reason,
source=self.source, source=self.source,
guards=make_guards(GuardBuilder.FUNCTION_MATCH),
) )
elif istype(value, (types.FunctionType, torch.jit.ScriptFunction)): elif istype(value, (types.FunctionType, torch.jit.ScriptFunction)):
self.install_guards(GuardBuilder.FUNCTION_MATCH)
return UserFunctionVariable( return UserFunctionVariable(
value, value,
source=self.source, source=self.source,
guards=make_guards(GuardBuilder.FUNCTION_MATCH),
) )
elif isinstance(value, types.MethodType) and isinstance( elif isinstance(value, types.MethodType) and isinstance(
value.__self__, torch.nn.Module value.__self__, torch.nn.Module
@ -784,40 +739,33 @@ class VariableBuilder:
assert self_obj and isinstance( assert self_obj and isinstance(
self_obj, VariableTracker self_obj, VariableTracker
), "Failed to produce a valid self obj" ), "Failed to produce a valid self obj"
self.install_guards(GuardBuilder.FUNCTION_MATCH)
return UserMethodVariable( return UserMethodVariable(
value.__func__, value.__func__,
self_obj, self_obj,
source=self.source, source=self.source,
guards=make_guards(GuardBuilder.FUNCTION_MATCH),
) )
elif istype(value, (types.ModuleType, replay_record.DummyModule)): elif istype(value, (types.ModuleType, replay_record.DummyModule)):
self.install_guards(GuardBuilder.FUNCTION_MATCH)
return PythonModuleVariable( return PythonModuleVariable(
value, value,
source=self.source, source=self.source,
guards=make_guards(GuardBuilder.PYMODULE_MATCH),
) )
elif isinstance(value, types.GetSetDescriptorType): elif isinstance(value, types.GetSetDescriptorType):
return GetSetDescriptorVariable( self.install_guards(GuardBuilder.FUNCTION_MATCH)
value, guards=self.make_guards(GuardBuilder.FUNCTION_MATCH) return GetSetDescriptorVariable(value)
)
elif isinstance(value, types.MethodWrapperType): elif isinstance(value, types.MethodWrapperType):
return MethodWrapperVariable( self.install_guards(GuardBuilder.FUNCTION_MATCH)
value, return MethodWrapperVariable(value, source=self.source)
source=self.source,
guards=self.make_guards(GuardBuilder.FUNCTION_MATCH),
)
elif issubclass(type(value), type): elif issubclass(type(value), type):
self.install_guards(GuardBuilder.FUNCTION_MATCH)
return UserDefinedClassVariable( return UserDefinedClassVariable(
value, value,
source=self.source, source=self.source,
guards=make_guards(GuardBuilder.FUNCTION_MATCH),
) )
else: else:
result = UserDefinedObjectVariable( self.install_guards(GuardBuilder.TYPE_MATCH)
value, result = UserDefinedObjectVariable(value, source=self.source)
source=self.source,
guards=self.make_guards(GuardBuilder.TYPE_MATCH),
)
if not SideEffects.cls_supports_mutation_side_effects(type(value)): if not SideEffects.cls_supports_mutation_side_effects(type(value)):
# don't allow STORE_ATTR mutation with custom __setattr__ # don't allow STORE_ATTR mutation with custom __setattr__
return result return result
@ -857,36 +805,32 @@ class VariableBuilder:
def wrap_listlike(self, value: Union[tuple, list, odict_values, NamedTuple]): def wrap_listlike(self, value: Union[tuple, list, odict_values, NamedTuple]):
# One can index a tensor with a list/tuple. Therefore, we need to # One can index a tensor with a list/tuple. Therefore, we need to
# have a stricter match. # have a stricter match.
guards = self.make_guards(GuardBuilder.LIST_LENGTH) self.install_guards(GuardBuilder.LIST_LENGTH)
for item in value: for item in value:
if item is value: if item is value:
unimplemented("list elements are pointing to the list itself") unimplemented("list elements are pointing to the list itself")
output = [ output = [
VariableBuilder(self.tx, GetItemSource(self.get_source(), i))( VariableBuilder(self.tx, GetItemSource(self.get_source(), i))(item)
item
).add_guards(guards)
for i, item in enumerate(value) for i, item in enumerate(value)
] ]
result = BaseListVariable.cls_for_instance(value)( result = BaseListVariable.cls_for_instance(value)(
output, mutable_local=MutableLocal(), guards=guards output, mutable_local=MutableLocal()
) )
if istype(value, list): if istype(value, list):
return self.tx.output.side_effects.track_list(self.source, value, result) return self.tx.output.side_effects.track_list(self.source, value, result)
return result return result
def wrap_tuple_iterator(self, value: tuple_iterator): def wrap_tuple_iterator(self, value: tuple_iterator):
guards = self.make_guards(GuardBuilder.TUPLE_ITERATOR_LEN) self.install_guards(GuardBuilder.TUPLE_ITERATOR_LEN)
output = [ output = [
VariableBuilder(self.tx, TupleIteratorGetItemSource(self.get_source(), i))( VariableBuilder(self.tx, TupleIteratorGetItemSource(self.get_source(), i))(
tuple_iterator_getitem(value, i) tuple_iterator_getitem(value, i)
).add_guards(guards) )
for i in range(tuple_iterator_len(value)) for i in range(tuple_iterator_len(value))
] ]
return TupleIteratorVariable( return TupleIteratorVariable(output, mutable_local=MutableLocal())
output, mutable_local=MutableLocal(), guards=guards
)
def wrap_slice_range(self, value: Union[slice, range]): def wrap_slice_range(self, value: Union[slice, range]):
items = [ items = [
@ -896,21 +840,20 @@ class VariableBuilder:
for k in ("start", "stop", "step") for k in ("start", "stop", "step")
] ]
if isinstance(value, slice): if isinstance(value, slice):
return SliceVariable( self.install_guards(GuardBuilder.TYPE_MATCH)
items, guards=self.make_guards(GuardBuilder.TYPE_MATCH) return SliceVariable(items)
)
else: else:
return RangeVariable( # TODO(jansel): I think this can be TYPE_MATCH
items, guards=self.make_guards(GuardBuilder.EQUALS_MATCH) self.install_guards(GuardBuilder.EQUALS_MATCH)
) return RangeVariable(items)
def wrap_module(self, value: torch.nn.Module): def wrap_module(self, value: torch.nn.Module):
from ..eval_frame import OptimizedModule from ..eval_frame import OptimizedModule
if istype(value, OptimizedModule): if istype(value, OptimizedModule):
guards = self.make_guards(GuardBuilder.TYPE_MATCH) self.install_guards(GuardBuilder.TYPE_MATCH)
self.source = AttrSource(self.source, "_orig_mod") self.source = AttrSource(self.source, "_orig_mod")
return self.wrap_module(value._orig_mod).add_guards(guards) return self.wrap_module(value._orig_mod)
if ( if (
isinstance(value, (torch.nn.RNN, torch.nn.GRU, torch.nn.LSTM)) isinstance(value, (torch.nn.RNN, torch.nn.GRU, torch.nn.LSTM))
@ -919,9 +862,8 @@ class VariableBuilder:
unimplemented("TorchDynamo purposely graph breaks on RNN, GRU, LSTMs") unimplemented("TorchDynamo purposely graph breaks on RNN, GRU, LSTMs")
if mutation_guard.is_dynamic_nn_module(value): if mutation_guard.is_dynamic_nn_module(value):
# created dynamically, don't specialize on it # created dynamically, don't specialize on it
result = UnspecializedNNModuleVariable( self.install_guards(GuardBuilder.TYPE_MATCH)
value, guards=self.make_guards(GuardBuilder.TYPE_MATCH) result = UnspecializedNNModuleVariable(value)
)
if not SideEffects.cls_supports_mutation_side_effects(type(value)): if not SideEffects.cls_supports_mutation_side_effects(type(value)):
# don't allow STORE_ATTR mutation with custom __setattr__ # don't allow STORE_ATTR mutation with custom __setattr__
return result return result
@ -931,9 +873,8 @@ class VariableBuilder:
elif issubclass( elif issubclass(
value.__class__, torch.nn.parallel.distributed.DistributedDataParallel value.__class__, torch.nn.parallel.distributed.DistributedDataParallel
): ):
return UnspecializedNNModuleVariable( self.install_guards(GuardBuilder.TYPE_MATCH)
value, guards=self.make_guards(GuardBuilder.TYPE_MATCH) return UnspecializedNNModuleVariable(value)
)
elif getattr(value, "_is_fsdp_managed_module", False): elif getattr(value, "_is_fsdp_managed_module", False):
# See note [Dynamo treats FSDP wrapped modules as UnspecializedNNModule] # See note [Dynamo treats FSDP wrapped modules as UnspecializedNNModule]
# in fully_sharded_data_parallel.py for more information # in fully_sharded_data_parallel.py for more information
@ -960,11 +901,8 @@ class VariableBuilder:
# #
# ID_MATCH is required to disambiguate cases as simple as a unit test that constructs 2 models and wraps # ID_MATCH is required to disambiguate cases as simple as a unit test that constructs 2 models and wraps
# them differently with different FSDP configs. (test_dynamo_distributed.py -k test_fsdp_aot_eager) # them differently with different FSDP configs. (test_dynamo_distributed.py -k test_fsdp_aot_eager)
return FSDPManagedNNModuleVariable( self.install_guards(GuardBuilder.TYPE_MATCH, GuardBuilder.ID_MATCH)
value, return FSDPManagedNNModuleVariable(value, source=self.get_source())
guards=self.make_guards(GuardBuilder.TYPE_MATCH, GuardBuilder.ID_MATCH),
source=self.get_source(),
)
else: else:
return self.tx.output.register_attr_or_module( return self.tx.output.register_attr_or_module(
value, value,
@ -976,12 +914,12 @@ class VariableBuilder:
def wrap_literal(self, value): def wrap_literal(self, value):
unspec = not config.specialize_int unspec = not config.specialize_int
if unspec and type(value) is torch.Size: if unspec and type(value) is torch.Size:
self.install_guards(GuardBuilder.LIST_LENGTH)
return SizeVariable( return SizeVariable(
[ [
VariableBuilder(self.tx, GetItemSource(self.get_source(), i))(v) VariableBuilder(self.tx, GetItemSource(self.get_source(), i))(v)
for i, v in enumerate(value) for i, v in enumerate(value)
], ]
guards=self.make_guards(GuardBuilder.LIST_LENGTH),
) )
elif unspec and type(value) is int: elif unspec and type(value) is int:
# unspecializing int by default, but still # unspecializing int by default, but still
@ -995,17 +933,13 @@ class VariableBuilder:
# NN modules on the fly) # NN modules on the fly)
or self.source.guard_source().is_nn_module() or self.source.guard_source().is_nn_module()
): ):
return ConstantVariable.create( self.install_guards(GuardBuilder.CONSTANT_MATCH)
value=value, return ConstantVariable.create(value=value)
guards=self.make_guards(GuardBuilder.CONSTANT_MATCH),
)
else: else:
return self.wrap_unspecialized_primitive(value) return self.wrap_unspecialized_primitive(value)
else: else:
return ConstantVariable.create( self.install_guards(GuardBuilder.CONSTANT_MATCH)
value=value, return ConstantVariable.create(value=value)
guards=self.make_guards(GuardBuilder.CONSTANT_MATCH),
)
def assert_not_wrapped_by_this_graph(self, value: torch.Tensor): def assert_not_wrapped_by_this_graph(self, value: torch.Tensor):
if is_fake(value) and maybe_get_fake_mode(value) is self.tx.fake_mode: if is_fake(value) and maybe_get_fake_mode(value) is self.tx.fake_mode:
@ -1027,11 +961,7 @@ class VariableBuilder:
) and not source.guard_source().is_fsdp_module(): ) and not source.guard_source().is_fsdp_module():
self.assert_not_wrapped_by_this_graph(value) self.assert_not_wrapped_by_this_graph(value)
return self.tx.output.register_attr_or_module( return self.tx.output.register_attr_or_module(
value, value, self.name, source=source
self.name,
source=source,
# Guards are done inside register_attr_or_module
# guards=self.make_guards(GuardBuilder.TENSOR_MATCH),
) )
if is_constant_source(source): if is_constant_source(source):
@ -1099,20 +1029,7 @@ class VariableBuilder:
options["torch_function_fn"] = build_torch_function_fn( options["torch_function_fn"] = build_torch_function_fn(
self.tx, value, self.source self.tx, value, self.source
) )
options["guards"] = self.make_guards(GuardBuilder.TYPE_MATCH) self.install_guards(GuardBuilder.TYPE_MATCH)
else:
options["guards"] = set()
options["guards"].update(
self.make_guards(
functools.partial(
GuardBuilder.TENSOR_MATCH,
value=value
if isinstance(source, NumpyTensorSource)
else TensorWeakRef(value),
)
)
)
if ( if (
isinstance(value, torch.Tensor) isinstance(value, torch.Tensor)
@ -1130,6 +1047,16 @@ class VariableBuilder:
source=source, source=source,
**options, **options,
) )
self.install_guards(
functools.partial(
GuardBuilder.TENSOR_MATCH,
value=value
if isinstance(source, NumpyTensorSource)
else TensorWeakRef(value),
)
)
self.tx.output.input_source_to_var[source] = tensor_variable self.tx.output.input_source_to_var[source] = tensor_variable
assert "tensor_dict" not in tensor_proxy.node.meta assert "tensor_dict" not in tensor_proxy.node.meta
tensor_proxy.node.meta["tensor_dict"] = value.__dict__.copy() tensor_proxy.node.meta["tensor_dict"] = value.__dict__.copy()
@ -1172,11 +1099,11 @@ class VariableBuilder:
# a tensor. It's a little annoying to make a VT to throw out, but there's so many side effects here # a tensor. It's a little annoying to make a VT to throw out, but there's so many side effects here
# that there's not another great way to do this atm. # that there's not another great way to do this atm.
# This creates the right graphargs, as well as registration for guards in tensor names and shape env. # This creates the right graphargs, as well as registration for guards in tensor names and shape env.
tensor_vt = VariableBuilder(self.tx, source)(tensor_value) VariableBuilder(self.tx, source)(tensor_value).recursive_realize()
proxy = self.tx.output.root_tracer.create_graph_input( proxy = self.tx.output.root_tracer.create_graph_input(
re.sub(r"[^a-zA-Z0-9]+", "_", self.name), type(tensor_value), source=source re.sub(r"[^a-zA-Z0-9]+", "_", self.name), type(tensor_value), source=source
) )
options = {"source": source, "guards": tensor_vt.guards} options = {"source": source}
numpy_ndarray_variable = wrap_fx_proxy_cls( numpy_ndarray_variable = wrap_fx_proxy_cls(
target_cls=NumpyNdarrayVariable, target_cls=NumpyNdarrayVariable,
tx=self.tx, tx=self.tx,
@ -1229,10 +1156,8 @@ class VariableBuilder:
# If specialize_int is False, also return # If specialize_int is False, also return
# a constant (but this should have been handled # a constant (but this should have been handled
# in the caller, TBH) # in the caller, TBH)
return ConstantVariable.create( self.install_guards(GuardBuilder.CONSTANT_MATCH)
value=value, return ConstantVariable.create(value=value)
guards=self.make_guards(GuardBuilder.CONSTANT_MATCH),
)
name = self.source.name() name = self.source.name()
if name not in self.tx.output.frame_state: if name not in self.tx.output.frame_state:
@ -1264,10 +1189,8 @@ class VariableBuilder:
else: # assume_static_by_default else: # assume_static_by_default
# TODO: dynamic_dim = DimDynamic.STATIC should work but # TODO: dynamic_dim = DimDynamic.STATIC should work but
# for some reason it doesn't # for some reason it doesn't
return ConstantVariable.create( self.install_guards(GuardBuilder.CONSTANT_MATCH)
value=value, return ConstantVariable.create(value=value)
guards=self.make_guards(GuardBuilder.CONSTANT_MATCH),
)
wrapped_value = shape_env.create_unspecified_symint_and_symbol( wrapped_value = shape_env.create_unspecified_symint_and_symbol(
value, value,
@ -1281,11 +1204,8 @@ class VariableBuilder:
else: else:
wrapped_value = torch.tensor(value) wrapped_value = torch.tensor(value)
if not isinstance(self.get_source(), RandomValueSource): if not isinstance(self.get_source(), RandomValueSource):
guards = {self.get_source().make_guard(GuardBuilder.TYPE_MATCH)} install_guard(self.get_source().make_guard(GuardBuilder.TYPE_MATCH))
options = {"guards": guards} options = {"source": self.get_source()}
else:
options = {}
options.update({"source": self.get_source()})
if isinstance(wrapped_value, torch.Tensor): if isinstance(wrapped_value, torch.Tensor):
options.update({"raw_value": value}) options.update({"raw_value": value})
@ -1352,16 +1272,19 @@ def _dataclasses_fields_lambda(obj):
def wrap_fx_proxy(tx, proxy, example_value=None, subclass_type=None, **options): def wrap_fx_proxy(tx, proxy, example_value=None, subclass_type=None, **options):
return wrap_fx_proxy_cls( kwargs = {
target_cls=TensorVariable "tx": tx,
if not subclass_type "proxy": proxy,
else TensorWithTFOverrideVariable, "example_value": example_value,
tx=tx, "subclass_type": subclass_type,
proxy=proxy,
example_value=example_value,
subclass_type=subclass_type,
**options, **options,
) }
if subclass_type is None:
return wrap_fx_proxy_cls(target_cls=TensorVariable, **kwargs)
else:
result = wrap_fx_proxy_cls(target_cls=TensorWithTFOverrideVariable, **kwargs)
result.install_global(tx)
return result
# Note: Unfortunate split due to some gross classes existing that subclass TensorVariable # Note: Unfortunate split due to some gross classes existing that subclass TensorVariable

View File

@ -19,7 +19,7 @@ from ..exc import (
UserError, UserError,
UserErrorType, UserErrorType,
) )
from ..guards import GuardBuilder from ..guards import GuardBuilder, install_guard
from ..replay_record import DummyModule from ..replay_record import DummyModule
from ..source import AttrSource, GetItemSource, is_constant_source, TypeSource from ..source import AttrSource, GetItemSource, is_constant_source, TypeSource
from ..utils import ( from ..utils import (
@ -339,7 +339,6 @@ class BuiltinVariable(VariableTracker):
a, a,
ListVariable( ListVariable(
list(a.items) + list(b.unpack_var_sequence(tx)), list(a.items) + list(b.unpack_var_sequence(tx)),
regen_guards=False,
**options, **options,
), ),
) )
@ -826,23 +825,22 @@ class BuiltinVariable(VariableTracker):
mutable_local=MutableLocal(), mutable_local=MutableLocal(),
) )
elif obj.has_unpack_var_sequence(tx): elif obj.has_unpack_var_sequence(tx):
guards = set()
if obj.source and not is_constant_source(obj.source): if obj.source and not is_constant_source(obj.source):
if isinstance(obj, TupleIteratorVariable): if isinstance(obj, TupleIteratorVariable):
guards.add(obj.source.make_guard(GuardBuilder.TUPLE_ITERATOR_LEN)) install_guard(
obj.source.make_guard(GuardBuilder.TUPLE_ITERATOR_LEN)
)
else: else:
guards.add(obj.source.make_guard(GuardBuilder.LIST_LENGTH)) install_guard(obj.source.make_guard(GuardBuilder.LIST_LENGTH))
if cls is SetVariable: if cls is SetVariable:
return cls( return cls(
list(obj.unpack_var_sequence(tx)), list(obj.unpack_var_sequence(tx)),
mutable_local=MutableLocal(), mutable_local=MutableLocal(),
guards=guards,
).add_options(self, obj) ).add_options(self, obj)
return cls( return cls(
list(obj.unpack_var_sequence(tx)), list(obj.unpack_var_sequence(tx)),
mutable_local=MutableLocal(), mutable_local=MutableLocal(),
guards=guards,
).add_options(self, obj) ).add_options(self, obj)
call_iter = _call_iter_tuple_list call_iter = _call_iter_tuple_list
@ -1060,7 +1058,6 @@ class BuiltinVariable(VariableTracker):
from .builder import SourcelessBuilder, VariableBuilder from .builder import SourcelessBuilder, VariableBuilder
options = VariableTracker.propagate(self, obj, name_var) options = VariableTracker.propagate(self, obj, name_var)
guards = options["guards"]
name = name_var.as_python_constant() name = name_var.as_python_constant()
if not name_var.is_python_constant(): if not name_var.is_python_constant():
@ -1075,10 +1072,9 @@ class BuiltinVariable(VariableTracker):
if default is not None: if default is not None:
hasattr_var = self.call_hasattr(tx, obj, name_var) hasattr_var = self.call_hasattr(tx, obj, name_var)
guards.update(hasattr_var.guards)
assert hasattr_var.as_python_constant() in (True, False) assert hasattr_var.as_python_constant() in (True, False)
if not hasattr_var.as_python_constant(): if not hasattr_var.as_python_constant():
return default.add_guards(guards) return default
if obj.source: if obj.source:
source = AttrSource(obj.source, name) source = AttrSource(obj.source, name)
@ -1152,14 +1148,14 @@ class BuiltinVariable(VariableTracker):
elif ConstantVariable.is_literal(member): elif ConstantVariable.is_literal(member):
return ConstantVariable.create(member, **options) return ConstantVariable.create(member, **options)
else: else:
return VariableBuilder(tx, source)(member).add_guards(guards) return VariableBuilder(tx, source)(member)
elif isinstance(obj, (PythonModuleVariable, DummyModule)): elif isinstance(obj, (PythonModuleVariable, DummyModule)):
member = obj.value.__dict__[name] member = obj.value.__dict__[name]
if config.replay_record_enabled: if config.replay_record_enabled:
tx.exec_recorder.record_module_access(obj.value, name, member) tx.exec_recorder.record_module_access(obj.value, name, member)
return VariableBuilder(tx, source)(member).add_guards(guards) return VariableBuilder(tx, source)(member)
elif istype(obj, UserFunctionVariable) and name in ("__name__", "__module__"): elif istype(obj, UserFunctionVariable) and name in ("__name__", "__module__"):
return ConstantVariable.create( return ConstantVariable.create(
getattr(obj.fn, name), **VariableTracker.propagate(obj) getattr(obj.fn, name), **VariableTracker.propagate(obj)

View File

@ -6,7 +6,7 @@ from torch._dynamo.source import GetItemSource
from .. import variables from .. import variables
from ..exc import unimplemented, UserError, UserErrorType from ..exc import unimplemented, UserError, UserErrorType
from ..guards import GuardBuilder from ..guards import GuardBuilder, install_guard
from ..utils import np from ..utils import np
from .base import typestr, VariableTracker from .base import typestr, VariableTracker
@ -41,21 +41,15 @@ class ConstantVariable(VariableTracker):
items = [] items = []
for i, x in enumerate(value): for i, x in enumerate(value):
item_source = GetItemSource(source, i) if source else None item_source = GetItemSource(source, i) if source else None
guards = ( if item_source:
{item_source.make_guard(GuardBuilder.CONSTANT_MATCH)} install_guard(item_source.make_guard(GuardBuilder.CONSTANT_MATCH))
if item_source
else None
)
items.append( items.append(
ConstantVariable.create( ConstantVariable.create(
x, x,
source=item_source, source=item_source,
guards=guards,
) )
) )
return variables.BaseListVariable.cls_for(type(value))( return variables.BaseListVariable.cls_for(type(value))(items, **kwargs)
items, regen_guards=True, **kwargs
)
return ConstantVariable(value, **kwargs) return ConstantVariable(value, **kwargs)

View File

@ -9,7 +9,7 @@ from .. import variables
from ..bytecode_transformation import create_call_function, create_instruction from ..bytecode_transformation import create_call_function, create_instruction
from ..device_interface import get_interface_for_device from ..device_interface import get_interface_for_device
from ..exc import unimplemented, Unsupported from ..exc import unimplemented, Unsupported
from ..guards import GuardBuilder from ..guards import GuardBuilder, install_guard
from ..source import AttrSource, GlobalStateSource from ..source import AttrSource, GlobalStateSource
from .base import VariableTracker from .base import VariableTracker
from .functions import ( from .functions import (
@ -161,7 +161,7 @@ class GenericContextWrappingVariable(ContextWrappingVariable):
class GradModeVariable(ContextWrappingVariable): class GradModeVariable(ContextWrappingVariable):
"""represents torch.{no_grad,enable_grad,set_grad_mode}()""" """represents torch.{no_grad,enable_grad,set_grad_mode}()"""
_guards_singleton = {Guard(GlobalStateSource(), GuardBuilder.GRAD_MODE)} _guards_singleton = Guard(GlobalStateSource(), GuardBuilder.GRAD_MODE)
@staticmethod @staticmethod
def create(tx, target_value, initialized=True, **kwargs): def create(tx, target_value, initialized=True, **kwargs):
@ -179,8 +179,8 @@ class GradModeVariable(ContextWrappingVariable):
super().__init__( super().__init__(
target_values=target_values, initial_values=initial_values, **kwargs target_values=target_values, initial_values=initial_values, **kwargs
) )
self.guards = self.guards | self._guards_singleton
self.initialized = initialized self.initialized = initialized
install_guard(self._guards_singleton)
def enter(self, tx): def enter(self, tx):
if not self.initialized: if not self.initialized:
@ -263,7 +263,7 @@ class InferenceModeVariable(ContextWrappingVariable):
class TorchFunctionDisableVariable(ContextWrappingVariable): class TorchFunctionDisableVariable(ContextWrappingVariable):
"""represents whether torch function overrides are enabled or not""" """represents whether torch function overrides are enabled or not"""
_guards_singleton = {Guard(GlobalStateSource(), GuardBuilder.TORCH_FUNCTION_STATE)} _guards_singleton = Guard(GlobalStateSource(), GuardBuilder.TORCH_FUNCTION_STATE)
@staticmethod @staticmethod
def create(tx, **kwargs): def create(tx, **kwargs):
@ -281,7 +281,7 @@ class TorchFunctionDisableVariable(ContextWrappingVariable):
super().__init__( super().__init__(
target_values=target_values, initial_values=initial_values, **kwargs target_values=target_values, initial_values=initial_values, **kwargs
) )
self.guards = self.guards | self._guards_singleton install_guard(self._guards_singleton)
def enter(self, tx): def enter(self, tx):
return variables.ConstantVariable.create( return variables.ConstantVariable.create(
@ -296,9 +296,9 @@ class TorchFunctionDisableVariable(ContextWrappingVariable):
class DeterministicAlgorithmsVariable(ContextWrappingVariable): class DeterministicAlgorithmsVariable(ContextWrappingVariable):
"""represents torch.{are_deterministic_algorithms_enabled,use_deterministic_algorithms}()""" """represents torch.{are_deterministic_algorithms_enabled,use_deterministic_algorithms}()"""
_guards_singleton = { _guards_singleton = Guard(
Guard(GlobalStateSource(), GuardBuilder.DETERMINISTIC_ALGORITHMS) GlobalStateSource(), GuardBuilder.DETERMINISTIC_ALGORITHMS
} )
@staticmethod @staticmethod
def create(tx, target_value, **kwargs): def create(tx, target_value, **kwargs):
@ -315,7 +315,7 @@ class DeterministicAlgorithmsVariable(ContextWrappingVariable):
super().__init__( super().__init__(
target_values=target_values, initial_values=initial_values, **kwargs target_values=target_values, initial_values=initial_values, **kwargs
) )
self.guards = self.guards | self._guards_singleton install_guard(self._guards_singleton)
def enter(self, tx): def enter(self, tx):
return variables.ConstantVariable.create( return variables.ConstantVariable.create(

View File

@ -13,7 +13,7 @@ from ..bytecode_transformation import create_call_function, create_instruction
from ..eval_frame import skip_code from ..eval_frame import skip_code
from ..exc import unimplemented from ..exc import unimplemented
from ..guards import GuardBuilder, make_dupe_guard from ..guards import GuardBuilder, install_guard, make_dupe_guard
from ..source import AttrSource, GetItemSource, GlobalWeakRefSource from ..source import AttrSource, GetItemSource, GlobalWeakRefSource
from ..utils import global_key_name, istensor, iter_contains from ..utils import global_key_name, istensor, iter_contains
from .base import MutableLocal, VariableTracker from .base import MutableLocal, VariableTracker
@ -24,10 +24,8 @@ from .tensor import TensorVariable
class ConstDictVariable(VariableTracker): class ConstDictVariable(VariableTracker):
def __init__(self, items, user_cls, **kwargs): def __init__(self, items, user_cls, **kwargs):
super().__init__(**kwargs) super().__init__(**kwargs)
# All the keys are constants # All the keys are constants
assert not any(isinstance(x, VariableTracker) for x in items) assert not any(isinstance(x, VariableTracker) for x in items)
self.guards.update(VariableTracker.propagate(items.values())["guards"])
self.items = items self.items = items
self.user_cls = user_cls self.user_cls = user_cls
@ -298,7 +296,6 @@ class SetVariable(VariableTracker):
def __init__( def __init__(
self, self,
items: List[VariableTracker], items: List[VariableTracker],
regen_guards=True,
**kwargs, **kwargs,
): ):
super().__init__(**kwargs) super().__init__(**kwargs)
@ -309,10 +306,6 @@ class SetVariable(VariableTracker):
self.items = [] self.items = []
self._add(items) self._add(items)
# Sometimes, we know that we have passed in the guards from the items in the set
if regen_guards:
self.guards.update(VariableTracker.propagate(items)["guards"])
def as_proxy(self): def as_proxy(self):
return [x.as_proxy() for x in self.items] return [x.as_proxy() for x in self.items]
@ -378,9 +371,7 @@ class SetVariable(VariableTracker):
e.vt.source, set_element.vt.source e.vt.source, set_element.vt.source
) )
if alias_guard: if alias_guard:
e.vt = e.vt.add_guards( install_guard(e.vt.source.make_guard(alias_guard))
{e.vt.source.make_guard(alias_guard)}
)
return self.items return self.items
@ -401,7 +392,6 @@ class SetVariable(VariableTracker):
result = SetVariable( result = SetVariable(
self._add(item), self._add(item),
mutable_local=self.mutable_local, mutable_local=self.mutable_local,
regen_guards=False,
**options, **options,
) )
tx.replace_all(self, result) tx.replace_all(self, result)
@ -413,7 +403,7 @@ class SetVariable(VariableTracker):
result = items.pop() result = items.pop()
tx.replace_all( tx.replace_all(
self, self,
SetVariable(items, regen_guards=False, **options), SetVariable(items, **options),
) )
return result return result
elif name == "__len__": elif name == "__len__":
@ -797,46 +787,40 @@ class PythonSysModulesVariable(VariableTracker):
def _contains_helper(self, tx, key: VariableTracker): def _contains_helper(self, tx, key: VariableTracker):
k = ConstDictVariable.get_key(key) k = ConstDictVariable.get_key(key)
has_key = k in sys.modules has_key = k in sys.modules
guard = self.make_guard( install_guard(
functools.partial(GuardBuilder.DICT_CONTAINS, key=k, invert=not has_key) self.make_guard(
functools.partial(GuardBuilder.DICT_CONTAINS, key=k, invert=not has_key)
)
) )
guards = {*self.guards, guard} return k, has_key
return k, has_key, guards
def call_contains(self, tx, key: VariableTracker): def call_contains(self, tx, key: VariableTracker):
k, has_key, guards = self._contains_helper(tx, key) k, has_key = self._contains_helper(tx, key)
return ConstantVariable.create( return ConstantVariable.create(value=has_key)
value=has_key,
guards=guards,
)
def call_get( def call_get(
self, tx, key: VariableTracker, default: Optional[VariableTracker] = None self, tx, key: VariableTracker, default: Optional[VariableTracker] = None
): ):
from .builder import VariableBuilder from .builder import VariableBuilder
k, has_key, guards = self._contains_helper(tx, key) k, has_key = self._contains_helper(tx, key)
if has_key: if has_key:
return VariableBuilder( return VariableBuilder(
tx, tx,
GetItemSource(self.source, k), GetItemSource(self.source, k),
)( )(sys.modules[k])
sys.modules[k]
).add_guards(guards)
if default is not None: if default is not None:
return default.add_guards(guards) return default
return ConstantVariable.create(value=None, guards=guards) return ConstantVariable.create(value=None)
def call_getitem(self, tx, key: VariableTracker): def call_getitem(self, tx, key: VariableTracker):
from .builder import VariableBuilder from .builder import VariableBuilder
k, has_key, guards = self._contains_helper(tx, key) k, has_key = self._contains_helper(tx, key)
return VariableBuilder( return VariableBuilder(
tx, tx,
GetItemSource(self.source, k), GetItemSource(self.source, k),
)( )(sys.modules[k])
sys.modules[k]
).add_guards(guards)

View File

@ -614,12 +614,6 @@ class FunctoolsPartialVariable(VariableTracker):
self.keywords = keywords self.keywords = keywords
self.original = original self.original = original
self.guards.update(VariableTracker.propagate(func)["guards"])
for arg in args:
self.guards.update(VariableTracker.propagate(arg)["guards"])
for val in keywords.values():
self.guards.update(VariableTracker.propagate(val)["guards"])
def call_function( def call_function(
self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]" self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
) -> "VariableTracker": ) -> "VariableTracker":

View File

@ -25,7 +25,6 @@ from ..exc import (
UserError, UserError,
UserErrorType, UserErrorType,
) )
from ..guards import GuardBuilder
from ..source import FSDPNNModuleSource, GetItemSource, NNModuleSource from ..source import FSDPNNModuleSource, GetItemSource, NNModuleSource
from ..utils import proxy_args_kwargs from ..utils import proxy_args_kwargs
from .dicts import ConstDictVariable from .dicts import ConstDictVariable
@ -100,9 +99,6 @@ def validate_args_and_maybe_create_graph_inputs(
assert isinstance(a, VariableTracker) assert isinstance(a, VariableTracker)
if isinstance(a, ConstantVariable): if isinstance(a, ConstantVariable):
# Ensures that we recompile when the constant value changes
a.add_guard(GuardBuilder.CONSTANT_MATCH)
if manually_set_subgraph_inputs: if manually_set_subgraph_inputs:
# This arg is not used in the body of the higher order op. # This arg is not used in the body of the higher order op.
# Currently, this new input is added to make the calls # Currently, this new input is added to make the calls
@ -194,6 +190,11 @@ def speculate_subgraph(
) )
try: try:
f, sub_args, sub_kwargs = VariableTracker.apply(
# ensure guards on args get installed in parent subgraph
lambda x: x.realize(),
(f, sub_args, sub_kwargs),
)
with tx.output.subtracer(source_target, tracer) as subtracer: with tx.output.subtracer(source_target, tracer) as subtracer:
args = validate_args_and_maybe_create_graph_inputs( args = validate_args_and_maybe_create_graph_inputs(
sub_args, subtracer, tx, manually_set_subgraph_inputs sub_args, subtracer, tx, manually_set_subgraph_inputs
@ -247,7 +248,6 @@ def speculate_subgraph(
"HigherOrderOperator body's output must consist of tensors only" "HigherOrderOperator body's output must consist of tensors only"
) )
tx.output.guards.update(output.guards)
# The output proxies might not belong to this SubgraphTracer # The output proxies might not belong to this SubgraphTracer
# (if they are free variables that were never lifted) # (if they are free variables that were never lifted)
# so lift them here. # so lift them here.
@ -411,7 +411,6 @@ class CondHigherOrderVariable(TorchHigherOrderOperatorVariable):
f"item but got {str(type(args[0]))} " f"item but got {str(type(args[0]))} "
f"with original python type {str(args[0].python_type())}.", f"with original python type {str(args[0].python_type())}.",
) )
tx.output.guards.update(args[0].guards)
# operands # operands
if not isinstance(args[3], (ListVariable, TupleVariable)): if not isinstance(args[3], (ListVariable, TupleVariable)):
@ -1116,6 +1115,7 @@ class AutogradFunctionMethodHigherOrderVariable(TorchHigherOrderOperatorVariable
else: else:
fn = TorchVariable(self.value) fn = TorchVariable(self.value)
checkpoint = tx.copy_graphstate() checkpoint = tx.copy_graphstate()
# TODO(jansel): BUG!!! we aren't copying on the line below, so the post-pre check below is pointless
pre_guards = tx.output.guards pre_guards = tx.output.guards
graph_checkpoint = tx.output.graph graph_checkpoint = tx.output.graph

View File

@ -23,7 +23,6 @@ class LazyCache:
self.vt.parents_tracker.add(parents_tracker) self.vt.parents_tracker.add(parents_tracker)
del self.value del self.value
del self.source del self.source
tx.output.guards.update(self.vt.guards)
class LazyVariableTracker(VariableTracker): class LazyVariableTracker(VariableTracker):
@ -79,8 +78,6 @@ class LazyVariableTracker(VariableTracker):
return getattr(self.realize(), item) return getattr(self.realize(), item)
# most methods are auto-generated below, these are the ones we want to exclude # most methods are auto-generated below, these are the ones we want to exclude
add_guards = VariableTracker.add_guards
add_guard = VariableTracker.add_guard
add_options = VariableTracker.add_options add_options = VariableTracker.add_options
apply = VariableTracker.apply apply = VariableTracker.apply
copy = VariableTracker.copy copy = VariableTracker.copy

View File

@ -48,16 +48,11 @@ class BaseListVariable(VariableTracker):
def __init__( def __init__(
self, self,
items: List[VariableTracker], items: List[VariableTracker],
regen_guards=True,
**kwargs, **kwargs,
): ):
super().__init__(**kwargs) super().__init__(**kwargs)
assert isinstance(items, list) assert isinstance(items, list)
assert all(isinstance(x, VariableTracker) for x in items) assert all(isinstance(x, VariableTracker) for x in items)
# Sometimes, we know that we have passed in the guards from the items in the list
if regen_guards:
self.guards.update(VariableTracker.propagate(items)["guards"])
self.items: List[VariableTracker] = items self.items: List[VariableTracker] = items
def _as_proxy(self): def _as_proxy(self):
@ -246,7 +241,6 @@ class CommonListMethodsVariable(BaseListVariable):
self, self,
type(self)( type(self)(
self.items + [arg], self.items + [arg],
regen_guards=False,
**options, **options,
), ),
) )
@ -263,7 +257,6 @@ class CommonListMethodsVariable(BaseListVariable):
self, self,
type(self)( type(self)(
list(self.items) + list(arg.unpack_var_sequence(tx)), list(self.items) + list(arg.unpack_var_sequence(tx)),
regen_guards=False,
**options, **options,
), ),
) )
@ -274,7 +267,7 @@ class CommonListMethodsVariable(BaseListVariable):
items.insert(idx.as_python_constant(), value) items.insert(idx.as_python_constant(), value)
return tx.replace_all( return tx.replace_all(
self, self,
type(self)(items, regen_guards=False, **options), type(self)(items, **options),
) )
elif name == "pop" and self.mutable_local: elif name == "pop" and self.mutable_local:
assert not kwargs assert not kwargs
@ -282,14 +275,14 @@ class CommonListMethodsVariable(BaseListVariable):
result = items.pop(*[a.as_python_constant() for a in args]) result = items.pop(*[a.as_python_constant() for a in args])
tx.replace_all( tx.replace_all(
self, self,
type(self)(items, regen_guards=False, **options), type(self)(items, **options),
) )
return result return result
elif name == "clear" and self.mutable_local: elif name == "clear" and self.mutable_local:
assert not kwargs and not args assert not kwargs and not args
return tx.replace_all( return tx.replace_all(
self, self,
type(self)([], regen_guards=False, **options), type(self)([], **options),
) )
elif ( elif (
name == "__setitem__" name == "__setitem__"
@ -304,16 +297,14 @@ class CommonListMethodsVariable(BaseListVariable):
items[key.as_python_constant()] = list(value.items) items[key.as_python_constant()] = list(value.items)
else: else:
items[key.as_python_constant()] = value items[key.as_python_constant()] = value
result = ListVariable(items, regen_guards=False, **options) result = ListVariable(items, **options)
return tx.replace_all(self, result) return tx.replace_all(self, result)
elif name == "copy": elif name == "copy":
# List copy() doesn't have args and kwargs # List copy() doesn't have args and kwargs
assert not kwargs assert not kwargs
assert not args assert not args
items = list(self.items) items = list(self.items)
return type(self)( return type(self)(items, mutable_local=MutableLocal(), **options)
items, regen_guards=False, mutable_local=MutableLocal(), **options
)
else: else:
return super().call_method(tx, name, args, kwargs) return super().call_method(tx, name, args, kwargs)
@ -351,7 +342,7 @@ class ListVariable(CommonListMethodsVariable):
items[key.as_python_constant()] = value.unpack_var_sequence(tx) items[key.as_python_constant()] = value.unpack_var_sequence(tx)
else: else:
items[key.as_python_constant()] = value items[key.as_python_constant()] = value
result = ListVariable(items, regen_guards=False, **options) result = ListVariable(items, **options)
return tx.replace_all(self, result) return tx.replace_all(self, result)
else: else:
return super().call_method(tx, name, args, kwargs) return super().call_method(tx, name, args, kwargs)
@ -396,7 +387,7 @@ class DequeVariable(CommonListMethodsVariable):
) )
items = list(self.items) items = list(self.items)
items[key.as_python_constant()] = value items[key.as_python_constant()] = value
result = DequeVariable(items, regen_guards=False, **options) result = DequeVariable(items, **options)
return tx.replace_all(self, result) return tx.replace_all(self, result)
elif name == "extendleft" and self.mutable_local: elif name == "extendleft" and self.mutable_local:
assert not kwargs assert not kwargs
@ -405,7 +396,6 @@ class DequeVariable(CommonListMethodsVariable):
self, self,
DequeVariable( DequeVariable(
list(arg.unpack_var_sequence(tx)) + list(self.items), list(arg.unpack_var_sequence(tx)) + list(self.items),
regen_guards=False,
**options, **options,
), ),
) )
@ -416,7 +406,7 @@ class DequeVariable(CommonListMethodsVariable):
result = items.popleft() result = items.popleft()
tx.replace_all( tx.replace_all(
self, self,
DequeVariable(list(items), regen_guards=False, **options), DequeVariable(list(items), **options),
) )
return result return result
elif name == "appendleft" and self.mutable_local: elif name == "appendleft" and self.mutable_local:
@ -425,7 +415,6 @@ class DequeVariable(CommonListMethodsVariable):
self, self,
DequeVariable( DequeVariable(
[args[0]] + list(self.items), [args[0]] + list(self.items),
regen_guards=False,
**options, **options,
), ),
) )

View File

@ -13,7 +13,7 @@ import torch._numpy as tnp
from .. import config, polyfill, variables from .. import config, polyfill, variables
from ..bytecode_transformation import create_call_function, create_instruction from ..bytecode_transformation import create_call_function, create_instruction
from ..exc import unimplemented from ..exc import unimplemented
from ..guards import GuardBuilder from ..guards import GuardBuilder, install_guard
from ..source import AttrSource, GetItemSource, ODictGetItemSource, TypeSource from ..source import AttrSource, GetItemSource, ODictGetItemSource, TypeSource
from ..utils import ( from ..utils import (
check_constant_args, check_constant_args,
@ -97,9 +97,8 @@ class SuperVariable(VariableTracker):
return GetAttrVariable(self, name, **options) return GetAttrVariable(self, name, **options)
if source: if source:
options["source"] = source options["source"] = source
return variables.ConstantVariable.create(value, **options).add_guard( install_guard(source.make_guard(GuardBuilder.CONSTANT_MATCH))
source.make_guard(GuardBuilder.CONSTANT_MATCH) return variables.ConstantVariable.create(value, **options)
)
return variables.ConstantVariable.create(value, **options) return variables.ConstantVariable.create(value, **options)
def call_method( def call_method(

View File

@ -10,7 +10,7 @@ import torch.nn
from .. import skipfiles, variables from .. import skipfiles, variables
from ..allowed_functions import is_allowed from ..allowed_functions import is_allowed
from ..exc import unimplemented, UnspecializeRestartAnalysis, Unsupported from ..exc import unimplemented, UnspecializeRestartAnalysis, Unsupported
from ..guards import GuardBuilder from ..guards import GuardBuilder, install_guard
from ..mutation_guard import GenerationTracker from ..mutation_guard import GenerationTracker
from ..source import ( from ..source import (
AttrSource, AttrSource,
@ -127,11 +127,12 @@ class NNModuleVariable(VariableTracker):
options = VariableTracker.propagate(self) options = VariableTracker.propagate(self)
mod = tx.output.get_submodule(self.module_key) mod = tx.output.get_submodule(self.module_key)
result = hasattr(mod, name) result = hasattr(mod, name)
return variables.ConstantVariable.create(result, **options).add_guard( install_guard(
NNModuleSource(AttrSource(self.source, name)).make_guard( NNModuleSource(AttrSource(self.source, name)).make_guard(
GuardBuilder.HASATTR GuardBuilder.HASATTR
) )
) )
return variables.ConstantVariable.create(result, **options)
def is_training(self, tx): def is_training(self, tx):
mod = tx.output.get_submodule(self.module_key) mod = tx.output.get_submodule(self.module_key)
@ -167,7 +168,6 @@ class NNModuleVariable(VariableTracker):
from .builder import VariableBuilder from .builder import VariableBuilder
options = VariableTracker.propagate(self) options = VariableTracker.propagate(self)
guards = options.get("guards", set())
if self.source: if self.source:
source = AttrSource(self.source, name) source = AttrSource(self.source, name)
@ -220,13 +220,12 @@ class NNModuleVariable(VariableTracker):
if istype(subobj, property): if istype(subobj, property):
return variables.UserFunctionVariable( return variables.UserFunctionVariable(
subobj.fget, subobj.fget,
guards=guards,
source=source, source=source,
).call_function(tx, [(self)], {}) ).call_function(tx, [(self)], {})
elif istype(subobj, classmethod): elif istype(subobj, classmethod):
return variables.UserMethodVariable( return variables.UserMethodVariable(
subobj.__func__, subobj.__func__,
variables.UserDefinedObjectVariable(type(base), guards=guards), variables.UserDefinedObjectVariable(type(base)),
**options, **options,
) )
elif istype(subobj, staticmethod): elif istype(subobj, staticmethod):
@ -616,7 +615,7 @@ class NNModuleVariable(VariableTracker):
): ):
# Inline the function # Inline the function
fn = getattr(module, name).__func__ fn = getattr(module, name).__func__
fn_source = AttrSource(self.source, "__func__") fn_source = AttrSource(AttrSource(self.source, name), "__func__")
options["source"] = fn_source options["source"] = fn_source
return tx.inline_user_function_return( return tx.inline_user_function_return(
variables.UserFunctionVariable(fn, **options), variables.UserFunctionVariable(fn, **options),
@ -759,7 +758,7 @@ class UnspecializedNNModuleVariable(UserDefinedObjectVariable):
assert not args or kwargs assert not args or kwargs
if tx.output.side_effects.has_pending_mutation(self): if tx.output.side_effects.has_pending_mutation(self):
unimplemented("Module.parameters() with pending mutation") unimplemented("Module.parameters() with pending mutation")
options["guards"].add( install_guard(
self.source.make_guard(GuardBuilder.NN_MODULE_PARAM_NAMES) self.source.make_guard(GuardBuilder.NN_MODULE_PARAM_NAMES)
) )
items = [] items = []

View File

@ -4,7 +4,7 @@ from typing import Dict, List
import torch import torch
from ..decorators import mark_static_address from ..decorators import mark_static_address
from ..guards import GuardBuilder from ..guards import GuardBuilder, install_guard
from ..source import AttrSource, GetItemSource, GlobalWeakRefSource from ..source import AttrSource, GetItemSource, GlobalWeakRefSource
from ..utils import global_key_name from ..utils import global_key_name
@ -126,13 +126,12 @@ class OptimizerVariable(UserDefinedObjectVariable):
# state guards take a long time to generate # state guards take a long time to generate
# so we manually generate them here # so we manually generate them here
guards = set()
state_source = AttrSource(self.source, "state") state_source = AttrSource(self.source, "state")
guards.add(state_source.make_guard(GuardBuilder.DICT_KEYS)) install_guard(state_source.make_guard(GuardBuilder.DICT_KEYS))
for p, value in self.value.state.items(): for p, value in self.value.state.items():
tx.store_global_weakref(global_key_name(p), p) tx.store_global_weakref(global_key_name(p), p)
p_state_source = GetItemSource(state_source, self.tensor_to_source[p]) p_state_source = GetItemSource(state_source, self.tensor_to_source[p])
guards.add(p_state_source.make_guard(GuardBuilder.DICT_KEYS)) install_guard(p_state_source.make_guard(GuardBuilder.DICT_KEYS))
for k, v in value.items(): for k, v in value.items():
if ( if (
isinstance(v, torch.Tensor) isinstance(v, torch.Tensor)
@ -141,7 +140,7 @@ class OptimizerVariable(UserDefinedObjectVariable):
): ):
self.tensor_to_source[v] = GetItemSource(p_state_source, k) self.tensor_to_source[v] = GetItemSource(p_state_source, k)
elif v is None or isinstance(v, (bool, int, float, str)): elif v is None or isinstance(v, (bool, int, float, str)):
guards.add( install_guard(
GetItemSource(p_state_source, k).make_guard( GetItemSource(p_state_source, k).make_guard(
GuardBuilder.CONSTANT_MATCH GuardBuilder.CONSTANT_MATCH
) )
@ -149,12 +148,10 @@ class OptimizerVariable(UserDefinedObjectVariable):
else: else:
raise GuardInstallException() raise GuardInstallException()
tx.output.guards.update(guards) # this next line has the side effect of installing guards
VariableBuilder(tx, AttrSource(self.source, "param_groups"))(
group_guards = VariableBuilder(tx, AttrSource(self.source, "param_groups"))(
self.value.param_groups self.value.param_groups
) ).recursive_realize()
tx.output.guards.update(group_guards.guards)
def wrap_tensor(self, tx, tensor_value): def wrap_tensor(self, tx, tensor_value):
"""Wrap state tensor in a TensorVariable""" """Wrap state tensor in a TensorVariable"""

View File

@ -1,4 +1,5 @@
import functools import functools
import inspect import inspect
import operator import operator
import types import types
@ -31,7 +32,7 @@ from .. import config, variables
from .._trace_wrapped_higher_order_op import trace_wrapped from .._trace_wrapped_higher_order_op import trace_wrapped
from ..exc import unimplemented, UserError, UserErrorType from ..exc import unimplemented, UserError, UserErrorType
from ..guards import GuardBuilder from ..guards import GuardBuilder, install_guard
from ..source import AttrSource from ..source import AttrSource
from ..utils import ( from ..utils import (
fqn, fqn,
@ -206,12 +207,8 @@ class TensorVariable(VariableTracker):
from .builder import VariableBuilder from .builder import VariableBuilder
attr_source = AttrSource(self.source, name) attr_source = AttrSource(self.source, name)
has_attr_guard = attr_source.make_guard(GuardBuilder.HASATTR) install_guard(attr_source.make_guard(GuardBuilder.HASATTR))
return ( return VariableBuilder(tx, attr_source)(real_value).add_options(self)
VariableBuilder(tx, attr_source)(real_value)
.add_options(self)
.add_guard(has_attr_guard)
)
def var_getattr(self, tx, name): def var_getattr(self, tx, name):
from . import ConstantVariable, TorchVariable from . import ConstantVariable, TorchVariable
@ -254,7 +251,7 @@ class TensorVariable(VariableTracker):
# In some cases, a <tensor>.<attr> guard can be evaluated first, and break if # In some cases, a <tensor>.<attr> guard can be evaluated first, and break if
# <tensor> is later changed to another type # <tensor> is later changed to another type
if result is not None and self.source is not None: if result is not None and self.source is not None:
result = result.add_guard(self.make_guard(GuardBuilder.TYPE_MATCH)) install_guard(self.make_guard(GuardBuilder.TYPE_MATCH))
# It's hard to get inplace view (metadata mutation) on graph input work properly across # It's hard to get inplace view (metadata mutation) on graph input work properly across
# dynamo/aot/inductor, just fall back. # dynamo/aot/inductor, just fall back.
@ -607,7 +604,6 @@ class TensorVariable(VariableTracker):
unimplemented( unimplemented(
"boolean masking setitem backwards requires dynamic shapes" "boolean masking setitem backwards requires dynamic shapes"
) )
tx.output.guards.update(options["guards"])
tx.output.create_proxy( tx.output.create_proxy(
"call_function", "call_function",
operator.setitem, operator.setitem,

View File

@ -8,6 +8,7 @@ import types
from typing import Dict, List from typing import Dict, List
from torch._streambase import _StreamBase from torch._streambase import _StreamBase
from ..guards import install_guard
try: try:
import numpy as np import numpy as np
@ -159,10 +160,10 @@ class TorchCtxManagerClassVariable(VariableTracker):
@classmethod @classmethod
def create_with_source(cls, value, source): def create_with_source(cls, value, source):
install_guard(source.make_guard(GuardBuilder.FUNCTION_MATCH))
return TorchCtxManagerClassVariable( return TorchCtxManagerClassVariable(
value, value,
source=source, source=source,
guards={source.make_guard(GuardBuilder.FUNCTION_MATCH)},
) )
def __init__(self, value, **kwargs): def __init__(self, value, **kwargs):
@ -259,6 +260,9 @@ class TorchVariable(VariableTracker):
except RuntimeError as e: except RuntimeError as e:
assert "No such operator" in str(e), str(e) assert "No such operator" in str(e), str(e)
self_should_be_none = None self_should_be_none = None
except AssertionError as e:
assert "Unknown attribute" in str(e), str(e)
self_should_be_none = None
# assert "_ntuple.<locals>.parse" not in str(value) # assert "_ntuple.<locals>.parse" not in str(value)
@ -425,18 +429,18 @@ class TorchVariable(VariableTracker):
return self._call_ntuple(tx, args, kwargs, options) return self._call_ntuple(tx, args, kwargs, options)
elif self.value is torch.is_grad_enabled: elif self.value is torch.is_grad_enabled:
assert not (args or kwargs) assert not (args or kwargs)
return ConstantVariable.create( install_guard(GradModeVariable._guards_singleton)
torch.is_grad_enabled(), **options return ConstantVariable.create(torch.is_grad_enabled(), **options)
).add_guards(GradModeVariable._guards_singleton)
elif self.value is torch.use_deterministic_algorithms and len(args) == 1: elif self.value is torch.use_deterministic_algorithms and len(args) == 1:
return DeterministicAlgorithmsVariable.create( return DeterministicAlgorithmsVariable.create(
tx, args[0].as_python_constant(), **options tx, args[0].as_python_constant(), **options
) )
elif self.value is torch.are_deterministic_algorithms_enabled: elif self.value is torch.are_deterministic_algorithms_enabled:
assert not (args or kwargs) assert not (args or kwargs)
install_guard(DeterministicAlgorithmsVariable._guards_singleton)
return ConstantVariable.create( return ConstantVariable.create(
torch.are_deterministic_algorithms_enabled(), **options torch.are_deterministic_algorithms_enabled(), **options
).add_guards(DeterministicAlgorithmsVariable._guards_singleton) )
elif self.value is torch.autograd.graph.disable_saved_tensors_hooks: elif self.value is torch.autograd.graph.disable_saved_tensors_hooks:
assert len(args) == 1 assert len(args) == 1
return DisabledSavedTensorsHooksVariable.create( return DisabledSavedTensorsHooksVariable.create(
@ -444,9 +448,8 @@ class TorchVariable(VariableTracker):
) )
elif self.value is torch._C._is_torch_function_enabled: elif self.value is torch._C._is_torch_function_enabled:
assert not (args or kwargs) assert not (args or kwargs)
return ConstantVariable.create( install_guard(TorchFunctionDisableVariable._guards_singleton)
tx.output.torch_function_enabled, **options return ConstantVariable.create(tx.output.torch_function_enabled, **options)
).add_guards(TorchFunctionDisableVariable._guards_singleton)
elif self.value in ( elif self.value in (
torch.overrides.has_torch_function_variadic, torch.overrides.has_torch_function_variadic,
torch.overrides.has_torch_function_unary, torch.overrides.has_torch_function_unary,

View File

@ -5,6 +5,7 @@ import torch.utils._pytree as pytree
from torch.overrides import _get_overloaded_args, get_default_nowrap_functions from torch.overrides import _get_overloaded_args, get_default_nowrap_functions
from ..exc import unimplemented from ..exc import unimplemented
from ..guards import GuardBuilder, install_guard
from ..source import AttrSource, GlobalSource from ..source import AttrSource, GlobalSource
from ..utils import is_tensor_base_attr_getter from ..utils import is_tensor_base_attr_getter
from .base import VariableTracker from .base import VariableTracker
@ -133,14 +134,15 @@ class TensorWithTFOverrideVariable(TensorVariable):
kwargs.pop("class_type") is torch.Tensor kwargs.pop("class_type") is torch.Tensor
), "invalid class type in TensorWithTFOverrideVariable.from_tensor_var" ), "invalid class type in TensorWithTFOverrideVariable.from_tensor_var"
var = cls(torch_function_fn=torch_function_fn, class_type=class_type, **kwargs) var = cls(torch_function_fn=torch_function_fn, class_type=class_type, **kwargs)
var.install_global(tx)
return var
def install_global(self, tx):
# stash the subclass type to rewrap an output tensor if needed # stash the subclass type to rewrap an output tensor if needed
# this is needed because the actual type needs to be available # this is needed because the actual type needs to be available
# each time the compiled artifact is run and outputs a wrapped tensor. # each time the compiled artifact is run and outputs a wrapped tensor.
if var.global_mangled_class_name() not in tx.output.global_scope: if self.global_mangled_class_name() not in tx.output.global_scope:
tx.output.install_global(var.global_mangled_class_name(), class_type) tx.output.install_global(self.global_mangled_class_name(), self.class_type)
return var
def python_type(self): def python_type(self):
return self.class_type return self.class_type
@ -157,7 +159,7 @@ class TensorWithTFOverrideVariable(TensorVariable):
# [Note: __torch_function__] We currently only support attributes that are defined on # [Note: __torch_function__] We currently only support attributes that are defined on
# base tensors, custom attribute accesses will graph break. # base tensors, custom attribute accesses will graph break.
import torch import torch
from .builder import SourcelessBuilder, VariableBuilder from .builder import SourcelessBuilder
if name in banned_attrs or not hasattr(torch.Tensor, name): if name in banned_attrs or not hasattr(torch.Tensor, name):
unimplemented( unimplemented(
@ -172,15 +174,12 @@ class TensorWithTFOverrideVariable(TensorVariable):
if tx.output.torch_function_enabled: if tx.output.torch_function_enabled:
if self.source: if self.source:
get_fn = VariableBuilder( install_guard(
tx, AttrSource(AttrSource(self.source, "__class__"), name).make_guard(
source=AttrSource( GuardBuilder.FUNCTION_MATCH
AttrSource(AttrSource(self.source, "__class__"), name), )
"__get__", )
), get_fn = SourcelessBuilder()(tx, getattr(torch.Tensor, name).__get__)
)(inspect.getattr_static(self.python_type(), name).__get__)
else:
get_fn = SourcelessBuilder()(tx, getattr(torch.Tensor, name).__get__)
return self.call_torch_function( return self.call_torch_function(
tx, tx,

View File

@ -17,7 +17,7 @@ from torch._guards import TracingContext
from .. import variables from .. import variables
from ..allowed_functions import is_allowed from ..allowed_functions import is_allowed
from ..exc import unimplemented from ..exc import unimplemented
from ..guards import GuardBuilder from ..guards import GuardBuilder, install_guard
from ..source import AttrSource, ODictGetItemSource, RandomValueSource from ..source import AttrSource, ODictGetItemSource, RandomValueSource
from ..utils import ( from ..utils import (
all_hook_names, all_hook_names,
@ -266,9 +266,10 @@ class UserDefinedObjectVariable(UserDefinedVariable):
assert not (args or kwargs) assert not (args or kwargs)
keys = list(self.value.keys()) keys = list(self.value.keys())
assert all(map(ConstantVariable.is_literal, keys)) assert all(map(ConstantVariable.is_literal, keys))
install_guard(self.source.make_guard(GuardBuilder.ODICT_KEYS))
return TupleVariable( return TupleVariable(
[ConstantVariable.create(k, **options) for k in keys], **options [ConstantVariable.create(k, **options) for k in keys], **options
).add_guard(self.source.make_guard(GuardBuilder.ODICT_KEYS)) )
if ( if (
method in (collections.OrderedDict.__contains__, dict.__contains__) method in (collections.OrderedDict.__contains__, dict.__contains__)
@ -278,9 +279,10 @@ class UserDefinedObjectVariable(UserDefinedVariable):
in (collections.OrderedDict.keys, dict.keys) in (collections.OrderedDict.keys, dict.keys)
): ):
assert not kwargs assert not kwargs
install_guard(self.source.make_guard(GuardBuilder.ODICT_KEYS))
return ConstantVariable.create( return ConstantVariable.create(
args[0].as_python_constant() in self.value, **options args[0].as_python_constant() in self.value, **options
).add_guard(self.source.make_guard(GuardBuilder.ODICT_KEYS)) )
if ( if (
method is collections.OrderedDict.items method is collections.OrderedDict.items
@ -376,20 +378,15 @@ class UserDefinedObjectVariable(UserDefinedVariable):
) )
): ):
options = VariableTracker.propagate(self, args, kwargs.values()) options = VariableTracker.propagate(self, args, kwargs.values())
options.setdefault("guards", set())
if self.source: if self.source:
options["guards"].add( install_guard(
AttrSource(self.source, "func").make_guard(GuardBuilder.ID_MATCH) AttrSource(self.source, "func").make_guard(GuardBuilder.ID_MATCH),
)
options["guards"].add(
AttrSource(self.source, "args").make_guard( AttrSource(self.source, "args").make_guard(
GuardBuilder.CONSTANT_MATCH GuardBuilder.CONSTANT_MATCH
) ),
)
options["guards"].add(
AttrSource(self.source, "keywords").make_guard( AttrSource(self.source, "keywords").make_guard(
GuardBuilder.CONSTANT_MATCH GuardBuilder.CONSTANT_MATCH
) ),
) )
partial_args = [ partial_args = [
@ -410,7 +407,7 @@ class UserDefinedObjectVariable(UserDefinedVariable):
tx, partial_args, partial_kwargs tx, partial_args, partial_kwargs
) )
elif callable(self.value): elif callable(self.value):
self.add_guard(self.source.make_guard(GuardBuilder.FUNCTION_MATCH)) install_guard(self.source.make_guard(GuardBuilder.FUNCTION_MATCH))
return self.call_method(tx, "__call__", args, kwargs) return self.call_method(tx, "__call__", args, kwargs)
return super().call_function(tx, args, kwargs) return super().call_function(tx, args, kwargs)
@ -578,7 +575,7 @@ class UserDefinedObjectVariable(UserDefinedVariable):
pass pass
options = VariableTracker.propagate(self) options = VariableTracker.propagate(self)
if self.source: if self.source:
options["guards"].add( install_guard(
AttrSource(self.source, name).make_guard(GuardBuilder.HASATTR) AttrSource(self.source, name).make_guard(GuardBuilder.HASATTR)
) )
if self._check_for_getattribute() or self._check_for_getattr(): if self._check_for_getattribute() or self._check_for_getattr():

View File

@ -241,7 +241,13 @@ class Guard:
return output return output
def create(self, builder: GuardBuilderBase): def create(self, builder: GuardBuilderBase):
return self.create_fn(builder, self) try:
return self.create_fn(builder, self)
except Exception:
log.error("Error while creating guard:\n%s", str(self).rstrip())
if self.stack:
log.error("Created at:\n%s", "".join(self.stack.format()[-4:]).rstrip())
raise
def is_nn_module(self): def is_nn_module(self):
return self.source.is_nn_module() return self.source.is_nn_module()