[dynamo] Eagerly install guards (#111415)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/111415
Approved by: https://github.com/voznesenskym
ghstack dependencies: #111306
Committed by: PyTorch MergeBot
Parent: 2964682490
Commit: 9664190952
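At a high level, this change moves Dynamo from propagating guard sets on every `VariableTracker` (and merging them into the output graph later) to installing guards eagerly into the ambient tracing context via a new `install_guard` helper. The sketch below is a simplified, hypothetical model of the two styles; the real `Guard`, `GuardBuilder`, and `TracingContext` types touched in this diff carry far more state.

```python
# Toy model of the refactor; these classes are hypothetical stand-ins,
# not the real torch._dynamo / torch._guards implementations.
from dataclasses import dataclass, field


@dataclass(frozen=True)
class Guard:
    expr: str  # e.g. "type(L['x']) is int"


class TracingContext:
    # Stands in for TracingContext.get().guards_context.dynamo_guards
    guards: set = set()


def install_guard(*guards):
    # New style: record each guard in the ambient tracing context immediately.
    for g in guards:
        TracingContext.guards.add(g)


@dataclass
class VariableTracker:
    # Old style: every tracker carried its own guard set, and every call site
    # had to remember to propagate/merge it into the output graph.
    guards: set = field(default_factory=set)


# Old: vt.guards |= other_vt.guards; output_guards.update(vt.guards)
# New: the guard is installed the moment it is created.
install_guard(Guard("type(L['x']) is int"))
print(TracingContext.guards)
```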
@ -3292,7 +3292,9 @@ class GraphModule(torch.nn.Module):
cos = l_x_.cos(); l_x_ = None
return pytree.tree_unflatten([cos], self._out_spec)
"""
true_guard_code = ["cast_symbool_to_symint_guardless(L['pred']) == 1"]
true_guard_code = [
"cast_symbool_to_symint_guardless(L['pred']) == 1",
]
false_guard_code = [
"Ne(cast_symbool_to_symint_guardless(L['pred']), 1)",
"-9223372036854775808 <= cast_symbool_to_symint_guardless(L['pred'])",
@ -297,25 +297,25 @@ class HigherOrderOpTests(torch._dynamo.test_case.TestCase):
|
||||
actual_graph,
|
||||
"""\
|
||||
class GraphModule(torch.nn.Module):
|
||||
def forward(self, L_x_ : torch.Tensor, L_y_ : torch.Tensor, L_z_ : torch.Tensor):
|
||||
l_x_ = L_x_
|
||||
l_y_ = L_y_
|
||||
l_z_ = L_z_
|
||||
def forward(self, L_d_x_ : torch.Tensor, L_d_y_0_ : torch.Tensor, L_d_y_1_2_ : torch.Tensor):
|
||||
l_d_x_ = L_d_x_
|
||||
l_d_y_0_ = L_d_y_0_
|
||||
l_d_y_1_2_ = L_d_y_1_2_
|
||||
|
||||
wrap_body_0 = self.wrap_body_0
|
||||
wrap = torch._higher_order_ops.wrap.wrap(wrap_body_0, l_x_, l_y_, l_z_); wrap_body_0 = l_x_ = l_y_ = l_z_ = None
|
||||
wrap = torch._higher_order_ops.wrap.wrap(wrap_body_0, l_d_x_, l_d_y_0_, l_d_y_1_2_); wrap_body_0 = l_d_x_ = l_d_y_0_ = l_d_y_1_2_ = None
|
||||
getitem = wrap[0]; wrap = None
|
||||
return (getitem,)
|
||||
|
||||
class GraphModule(torch.nn.Module):
|
||||
def forward(self, l_x_, l_y_, l_z_):
|
||||
sin = l_x_.sin(); l_x_ = None
|
||||
cos = l_y_.cos(); l_y_ = None
|
||||
def forward(self, l_d_x_, l_d_y_0_, l_d_y_1_2_):
|
||||
sin = l_d_x_.sin(); l_d_x_ = None
|
||||
cos = l_d_y_0_.cos(); l_d_y_0_ = None
|
||||
add = sin + cos; sin = cos = None
|
||||
sin_1 = l_z_.sin(); l_z_ = None
|
||||
sin_1 = l_d_y_1_2_.sin(); l_d_y_1_2_ = None
|
||||
sub = add - sin_1; add = sin_1 = None
|
||||
return (sub,)
|
||||
""",
|
||||
""", # NOQA: B950
|
||||
)
|
||||
|
||||
def test_wrap_pytree_args_with_symint_constant(self):
|
||||
@ -3005,9 +3005,9 @@ class GraphModule(torch.nn.Module):
|
||||
actual,
|
||||
"""\
|
||||
class GraphModule(torch.nn.Module):
|
||||
def forward(self, L_x_ : torch.Tensor, L_y_ : torch.Tensor):
|
||||
l_x_ = L_x_
|
||||
def forward(self, L_y_ : torch.Tensor, L_x_ : torch.Tensor):
|
||||
child = L_y_
|
||||
l_x_ = L_x_
|
||||
|
||||
_check_randomness_arg = torch._functorch.vmap._check_randomness_arg('error')
|
||||
_check_randomness_arg_1 = torch._functorch.vmap._check_randomness_arg('error')
|
||||
@ -3269,16 +3269,14 @@ class GraphModule(torch.nn.Module):
return torch.func.vmap(torch.sum, in_dims)(x)

x = torch.randn(3, 3, 3, 3)
opt = torch.compile(wrapper_fn, backend="eager", fullgraph=False, dynamic=True)
cnt = CompileCounter()
opt = torch.compile(wrapper_fn, backend=cnt, fullgraph=False, dynamic=True)
expected = wrapper_fn(x, 0), wrapper_fn(x, 1), wrapper_fn(x, 2)
# Third invocation of `opt` makes `in_dims` as SymInt.
actual = opt(x, 0), opt(x, 1), opt(x, 2)
self.assertEqual(expected, actual)
self.assertEqual(len(counters["graph_break"]), 1)
self.assertEqual(
dict(counters["graph_break"]),
{"torch.func.vmap: in_dims is not an int or tuple variable.": 2},
)
self.assertEqual(cnt.frame_count, 3)
self.assertEqual(cnt.op_count, 9)

def test_vmap_multiple_invocation_out_dims(self):
counters.clear()
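The test updates above swap the opaque "eager" backend for a `CompileCounter`, so the test can assert on how many frames and ops Dynamo actually compiled. A minimal usage sketch, assuming the `CompileCounter` helper from `torch._dynamo.testing` behaves as in these tests:

```python
import torch
from torch._dynamo.testing import CompileCounter


def f(x):
    return x.sin() + x.cos()


cnt = CompileCounter()            # callable backend that records what Dynamo compiles
opt_f = torch.compile(f, backend=cnt)
opt_f(torch.randn(4))

print(cnt.frame_count)  # graphs compiled so far (1 here)
print(cnt.op_count)     # total FX ops across those graphs
```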
@ -3287,16 +3285,14 @@ class GraphModule(torch.nn.Module):
|
||||
return torch.func.vmap(lambda x: torch.sum(x, 0), out_dims=out_dims)(x)
|
||||
|
||||
x = torch.randn(3, 3, 3, 3)
|
||||
opt = torch.compile(wrapper_fn, backend="eager", fullgraph=False, dynamic=True)
|
||||
cnt = CompileCounter()
|
||||
opt = torch.compile(wrapper_fn, backend=cnt, fullgraph=False, dynamic=True)
|
||||
expected = wrapper_fn(x, 0), wrapper_fn(x, 1), wrapper_fn(x, 2)
|
||||
# Third invocation of `opt` makes `in_dims` as SymInt.
|
||||
actual = opt(x, 0), opt(x, 1), opt(x, 2)
|
||||
self.assertEqual(expected, actual)
|
||||
self.assertEqual(len(counters["graph_break"]), 1)
|
||||
self.assertEqual(
|
||||
dict(counters["graph_break"]),
|
||||
{"torch.func.vmap: out_dims is not an int or tuple variable.": 2},
|
||||
)
|
||||
self.assertEqual(cnt.frame_count, 3)
|
||||
self.assertEqual(cnt.op_count, 9)
|
||||
|
||||
def test_vmap_new_tensor_in_body(self):
|
||||
def fn(x):
|
||||
|
@ -1541,7 +1541,7 @@ utils_device.CURRENT_DEVICE == None""",
|
||||
args = [torch.randn(10), 4096, np.int64(8)]
|
||||
correct = fn(*args)
|
||||
cnts = torch._dynamo.testing.CompileCounter()
|
||||
opt_fn = torch._dynamo.optimize(cnts, dynamic=True)(fn)
|
||||
opt_fn = torch._dynamo.optimize(cnts, dynamic=True, nopython=True)(fn)
|
||||
self.assertTrue(same(opt_fn(*args), correct))
|
||||
self.assertTrue(same(opt_fn(*args), correct))
|
||||
self.assertEqual(cnts.frame_count, 1)
|
||||
|
@ -814,9 +814,8 @@ class MockModule(torch.nn.Module):
|
||||
class ReproTests(torch._dynamo.test_case.TestCase):
|
||||
def test_do_paste_mask(self):
|
||||
torch._dynamo.utils.counters.clear()
|
||||
opt__do_paste_mask = torch._dynamo.optimize(
|
||||
torch._dynamo.testing.CompileCounter()
|
||||
)(_do_paste_mask)
|
||||
cnt = torch._dynamo.testing.CompileCounter()
|
||||
opt__do_paste_mask = torch.compile(_do_paste_mask, backend=cnt)
|
||||
opt__do_paste_mask(
|
||||
torch.randn(1, 1, 28, 28),
|
||||
torch.tensor([[0.0, 1, 2, 4]]) * 1,
|
||||
@ -852,12 +851,9 @@ class ReproTests(torch._dynamo.test_case.TestCase):
|
||||
640,
|
||||
False,
|
||||
)
|
||||
|
||||
self.assertGreaterEqual(torch._dynamo.utils.counters["frames"]["ok"], 3)
|
||||
self.assertEqual(
|
||||
torch._dynamo.utils.counters["frames"]["total"],
|
||||
torch._dynamo.utils.counters["frames"]["ok"] + 1,
|
||||
)
|
||||
# (dynamic shapes, static shapes)
|
||||
self.assertIn(cnt.frame_count, (5, 7))
|
||||
self.assertIn(cnt.op_count, (106, 127))
|
||||
|
||||
def test_convert_boxes_to_pooler_format(self):
|
||||
boxes1 = [
|
||||
@ -2451,77 +2447,6 @@ class ReproTests(torch._dynamo.test_case.TestCase):
|
||||
self.assertEqual(f(x, x), opt_f(x, x))
|
||||
self.assertEqual(f(x, y), opt_f(x, y))
|
||||
|
||||
def test_reformer_remove_unused_args(self):
|
||||
# This test case is very interesting. First, let's describe
|
||||
# the bug this is testing for. The bug we fixed is twofold:
|
||||
#
|
||||
# - We prune GraphArgs that aren't used in the output graph.
|
||||
# However, sometimes it is possible for those GraphArgs to be
|
||||
# utilized in shape guards (you could imagine this happening if
|
||||
# dynamo poked some shape variables without recording them in the
|
||||
# graph.) If we prune those GraphArgs, we get a
|
||||
# "s1 not in ..." error as we can no longer codegen the
|
||||
# requested guards.
|
||||
#
|
||||
# - But in practice, Dynamo usually traces size accesses into the
|
||||
# graph, preventing the GraphArg from getting pruned. So how
|
||||
# come we were running into this in practice with hf_Reformer?
|
||||
# The answer is checkpointing!
|
||||
#
|
||||
# This brings us to the following test case. Here's what it does:
|
||||
#
|
||||
# 1. It traces some operations, and then checkpoints before inlining
|
||||
# the function call to g
|
||||
#
|
||||
# 2. g traces some more operations (triggering the shape guard
|
||||
# to be created), but then it graph breaks
|
||||
#
|
||||
# 3. Because you can't graph break in an inlining function, we roll
|
||||
# back to the outer checkpoint ("undoing" the operation that
|
||||
# induced the shape guard) and then immediately generate a
|
||||
# subgraph at that point.
|
||||
#
|
||||
# If we failed to checkpoint the ShapeEnv, it can still have guards
|
||||
# from the aborted speculation, which we will then still attempt to
|
||||
# codegen.
|
||||
#
|
||||
# There's an additional nuance: suppose x is used but y is not.
|
||||
# If you create a guard like y == x * 2, you will accidentally avoid
|
||||
# the "s1 not in ..." error, as y will get substituted with x * 2,
|
||||
# but x is still a GraphArg (it's used) and you don't end up with
|
||||
# the error. This is why we must show y + y == x, not vice versa.
|
||||
# Similarly, it is also why we must not do a simple guard like x == y
|
||||
#
|
||||
# Can we actually demonstrate that checkpointing the ShapeEnv is
|
||||
# necessary? It's not so easy to induce this case. Dynamo is very
|
||||
# eager about adding locals to GraphArgs; any local that is in scope,
|
||||
# even if it isn't used, is added to GraphArgs (see also
|
||||
# https://github.com/pytorch/torchdynamo/issues/1925 ). So long
|
||||
# as Dynamo eagerly guards in this way, we have an invariant that
|
||||
# all locals are guaranteed to show up in GraphArgs before the
|
||||
# inlining function call, in which case we will always have enough
|
||||
# information to codegen our guards so long as we don't prune the
|
||||
# unused GraphArgs away (and indeed, the direct fix for this bug
|
||||
# was to make sure we use original GraphArgs). Non locals,
|
||||
# conversely, typically are static, and so won't have guards allocated
|
||||
# for them. That being said, there may still be a way to trigger
|
||||
# this error.
|
||||
|
||||
def g(x, y):
|
||||
r = torch.cat((y, y)) + x
|
||||
print("foo")
|
||||
return r
|
||||
|
||||
def f(x, y):
|
||||
x = x * 3
|
||||
return g(x, y)
|
||||
|
||||
opt_f = torch._dynamo.optimize("aot_eager")(f)
|
||||
|
||||
x = torch.randn(4)
|
||||
y = torch.randn(2)
|
||||
self.assertEqual(f(x, y), opt_f(x, y))
|
||||
|
||||
def test_swin_base_tensor_attr(self):
|
||||
class Foo(torch.nn.Module):
|
||||
def __init__(self):
|
||||
|
@ -271,7 +271,7 @@ class SubclassTests(torch._dynamo.test_case.TestCase):
|
||||
kwargs = {}
|
||||
return super().__torch_function__(func, types, args, kwargs)
|
||||
|
||||
@torch.compile(backend="eager", fullgraph=True)
|
||||
@torch.compile(backend="eager")
|
||||
def fn(x):
|
||||
return x.sigmoid()
|
||||
|
||||
@ -819,13 +819,6 @@ class TestNestedTensor(torch._dynamo.test_case.TestCase):
nt3, _ = self._get_jagged_tensor(((2, 3, 4), 3), None)
self._check_recompiles(binary, (nt1, nt2), (nt1, nt3), True)

def test_binary_recompiles_due_to_duck_sizing(self):
# Even though the input is unused, we still guard due to duck sizing
nt1, offsets = self._get_jagged_tensor(((2, 3, 4), 3), None)
nt2, _ = self._get_jagged_tensor(((2, 3, 4), 3), offsets)
nt3, _ = self._get_jagged_tensor(((2, 3, 4), 3), None)
self._check_recompiles(lambda nt1, nt2: nt1.sin(), (nt1, nt2), (nt1, nt3), True)

# TODO: cannot parametrize this test class with device for some reason
def _test_autograd(self, backend):
a = torch.randn(2, 3, requires_grad=True, dtype=torch.float64)
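The removed test above exercised duck sizing: when two dynamic inputs happen to have the same size on the first compile, Dynamo gives them the same symbolic size and guards on that equality even if one of the inputs is unused. A hedged, self-contained illustration of that behavior (the exact recompile behavior can vary with configuration):

```python
import torch
from torch._dynamo.testing import CompileCounter

cnt = CompileCounter()


@torch.compile(backend=cnt, dynamic=True)
def f(a, b):
    return a.sin()  # b is never used


f(torch.randn(4), torch.randn(4))  # equal sizes are duck-sized to one symbol
f(torch.randn(4), torch.randn(5))  # the implied equality guard fails; expect a recompile
print(cnt.frame_count)
```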
@ -477,8 +477,8 @@ class SubGraphTests(torch._dynamo.test_case.TestCase):
opt_fn(v1, a, b, c)

# checking here we don't create 2^n graphs
self.assertEqual(cnt.frame_count, 12)
self.assertEqual(cnt.op_count, 16)
self.assertEqual(cnt.frame_count, 7)
self.assertEqual(cnt.op_count, 10)

def test_resume_with_no_grad1(self):
def fn(a, b):
@ -4,7 +4,7 @@ import hypothesis.strategies as st
|
||||
from hypothesis import given
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch.testing._internal.common_utils import TestCase, run_tests
|
||||
from torch.testing._internal.common_utils import TestCase, run_tests, skipIfTorchDynamo
|
||||
import torch.testing._internal.hypothesis_utils as hu
|
||||
hu.assert_deadline_disabled()
|
||||
|
||||
@ -56,6 +56,7 @@ class PruningOpTest(TestCase):
|
||||
self.assertEqual(pt_compressed_indices_map.dtype, indices_type)
|
||||
|
||||
|
||||
@skipIfTorchDynamo()
|
||||
@given(
|
||||
embedding_rows=st.integers(1, 100),
|
||||
embedding_dims=st.integers(1, 100),
|
||||
@ -67,6 +68,7 @@ class PruningOpTest(TestCase):
|
||||
self._test_rowwise_prune_op(embedding_rows, embedding_dims, torch.int, weights_dtype)
|
||||
|
||||
|
||||
@skipIfTorchDynamo()
|
||||
@given(
|
||||
embedding_rows=st.integers(1, 100),
|
||||
embedding_dims=st.integers(1, 100),
|
||||
|
@ -2530,6 +2530,7 @@ class TestSparseCSR(TestCase):
|
||||
run_test(4, 5, 4, 10, False)
|
||||
run_test(4, 4, 4, 16, True)
|
||||
|
||||
@skipIfTorchDynamo()
|
||||
@onlyCPU
|
||||
@dtypes(torch.float32, torch.float64, torch.bfloat16)
|
||||
@precisionOverride({torch.bfloat16: 0.01})
|
||||
@ -2894,6 +2895,7 @@ class TestSparseCSR(TestCase):
|
||||
run_test(shape, max(shape), index_dtype)
|
||||
run_test(shape, shape[0] * shape[1], index_dtype)
|
||||
|
||||
@skipIfTorchDynamo()
|
||||
@skipMeta
|
||||
@dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16))
|
||||
@all_sparse_compressed_layouts()
|
||||
|
@ -76,8 +76,6 @@ class PyCodegen:
self.clear_tos()
return

self.tx.output.guards.update(value.guards)

assert isinstance(value, VariableTracker)
output = self._output
graph_outputs = self.graph_outputs
@ -586,7 +586,7 @@ class GuardBuilder(GuardBuilderBase):
f"{id(torch._dynamo.eval_frame.guarded_backend_cache.current_backend)}"
)
code = [
f"___skip_backend_check() or ___current_backend() == ___lookup_backend({backend_id})"
f"(___skip_backend_check() or ___current_backend() == ___lookup_backend({backend_id}))"
]
self._produce_guard_code(guard, code)
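The only change in this hunk is wrapping the backend-check expression in parentheses. Guard code strings are eventually combined with other guard terms (typically as a conjunction), and without the parentheses the trailing `or` would bind too loosely. A plain-Python illustration of the precedence issue (the `___*` helpers above are Dynamo-internal; the names below are made up for the example):

```python
# Hypothetical stand-ins for the guard terms being combined.
skip_check = False
backend_ok = True
other_guard = False

# Unparenthesized: `and` binds tighter than `or`, so the backend check
# can satisfy the whole expression even when the other guard fails.
bad = other_guard and skip_check or backend_ok      # True (wrong)

# Parenthesized, as in the fixed guard string.
good = other_guard and (skip_check or backend_ok)   # False (intended)

assert bad != good
```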
@ -1366,3 +1366,19 @@ def make_dupe_guard(obj_source, dupe_source):
# However, this should always be a sound guard to add here.
return functools.partial(GuardBuilder.DUPLICATE_INPUT, source_b=dupe_source)
return None


def install_guard(*guards, skip=0):
"""
Add dynamo guards to the current tracing context.

Args:
guards: guard(s) to add
skip: number of stack frames to ignore for debug stack trace
"""
from torch._guards import TracingContext

add = TracingContext.get().guards_context.dynamo_guards.add
for guard in guards:
assert isinstance(guard, Guard)
add(guard, skip=skip + 1)
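Elsewhere in the diff, call sites that used to stash guards on `options["guards"]` now call this helper directly. A representative sketch of the call pattern, assuming an active tracing context and a Dynamo `Source` object:

```python
from torch._dynamo.guards import GuardBuilder, install_guard


def guard_tensor_input(source):
    # `source` is a Dynamo Source (e.g. a LocalSource for a function argument);
    # install_guard only works while a TracingContext is active, i.e. during tracing.
    install_guard(source.make_guard(GuardBuilder.TENSOR_MATCH))
```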
@ -62,7 +62,7 @@ from .exc import (
|
||||
unimplemented,
|
||||
unimplemented_with_warning,
|
||||
)
|
||||
from .guards import GuardBuilder
|
||||
from .guards import GuardBuilder, install_guard
|
||||
from .mutation_guard import is_dynamic_nn_module
|
||||
from .side_effects import SideEffects
|
||||
from .source import (
|
||||
@ -561,7 +561,11 @@ class OutputGraph(Checkpointable[OutputGraphState]):
|
||||
# FX deepcopy doesn't work for a partially created graph, so just remove new nodes
|
||||
removed_nodes = 0
|
||||
for node in reversed(list(self.graph.nodes)):
|
||||
if node.meta["creation_timestamp"] > self.timestamp:
|
||||
if (
|
||||
node.meta["creation_timestamp"] > self.timestamp
|
||||
# placeholders here may have been lazily added by existing objects
|
||||
and node.op != "placeholder"
|
||||
):
|
||||
# Erasing node alone does not remove the meta information
|
||||
# So, remove the help tensor explicitly
|
||||
if "example_value" in node.meta:
|
||||
@ -670,7 +674,6 @@ class OutputGraph(Checkpointable[OutputGraphState]):
|
||||
return variables.UnspecializedNNModuleVariable(target, **options)
|
||||
|
||||
options = dict(options)
|
||||
options["guards"] = set(options.get("guards", []))
|
||||
assert "source" in options
|
||||
source = options["source"]
|
||||
assert not isinstance(source, ParamBufferSource)
|
||||
@ -692,10 +695,10 @@ class OutputGraph(Checkpointable[OutputGraphState]):
|
||||
tracer = self.root_tracer
|
||||
|
||||
if not is_constant_source(source):
|
||||
options["guards"].add(source.make_guard(GuardBuilder.TENSOR_MATCH))
|
||||
install_guard(source.make_guard(GuardBuilder.TENSOR_MATCH))
|
||||
|
||||
if get_static_address_type(target) == "guarded":
|
||||
options["guards"].add(source.make_guard(GuardBuilder.DATA_PTR_MATCH))
|
||||
install_guard(source.make_guard(GuardBuilder.DATA_PTR_MATCH))
|
||||
|
||||
def wrap_name(module_key):
|
||||
assert self.param_name_to_source is not None
|
||||
@ -711,7 +714,7 @@ class OutputGraph(Checkpointable[OutputGraphState]):
|
||||
elif isinstance(target, torch.nn.Module):
|
||||
assert isinstance(target, torch.nn.Module)
|
||||
|
||||
options["guards"].add(source.make_guard(GuardBuilder.NN_MODULE))
|
||||
install_guard(source.make_guard(GuardBuilder.NN_MODULE))
|
||||
|
||||
def wrap_name(module_key):
|
||||
return NNModuleVariable(type(target), module_key, **options)
|
||||
@ -1005,9 +1008,6 @@ class OutputGraph(Checkpointable[OutputGraphState]):
|
||||
|
||||
assert isinstance(rv, list)
|
||||
assert isinstance(root, FakeRootModule)
|
||||
for output in rv:
|
||||
self.guards.update(output.guards)
|
||||
|
||||
self.create_node(
|
||||
"output",
|
||||
"output",
|
||||
|
@ -54,7 +54,7 @@ from .codegen import PyCodegen
|
||||
from .current_scope_id import current_scope_id
|
||||
from .exc import ArgsMismatchError, BackendCompilerFailed, unimplemented, Unsupported
|
||||
from .funcname_cache import get_funcname
|
||||
from .guards import GuardBuilder
|
||||
from .guards import GuardBuilder, install_guard
|
||||
from .output_graph import GraphCompileReason, OutputGraph, OutputGraphState
|
||||
from .replay_record import DummyModule, ExecutionRecorder
|
||||
from .resume_execution import ContinueExecutionCache, ReenterWith
|
||||
@ -323,13 +323,11 @@ def _detect_and_normalize_assert_statement(
|
||||
def generic_jump(truth_fn: typing.Callable[[object], bool], push: bool):
|
||||
def inner(self: "InstructionTranslatorBase", inst: Instruction):
|
||||
value: VariableTracker = self.pop()
|
||||
self.output.guards.update(value.guards)
|
||||
if (
|
||||
config.rewrite_assert_with_torch_assert
|
||||
and _detect_and_normalize_assert_statement(self, truth_fn, push)
|
||||
):
|
||||
error_msg: VariableTracker = self.pop()
|
||||
self.output.guards.update(error_msg.guards)
|
||||
# Skip over things like `assert True`
|
||||
if value.is_python_constant() and bool(value.as_python_constant()):
|
||||
self.jump(inst)
|
||||
@ -419,7 +417,6 @@ def generic_jump(truth_fn: typing.Callable[[object], bool], push: bool):
|
||||
if isinstance(result, ConstantVariable) and isinstance(
|
||||
result.value, (bool, int)
|
||||
):
|
||||
self.output.guards.update(result.guards)
|
||||
if truth_fn(result.value):
|
||||
push and self.push(value)
|
||||
self.jump(inst)
|
||||
@ -686,9 +683,7 @@ class InstructionTranslatorBase(Checkpointable[InstructionTranslatorGraphState])
|
||||
"""
|
||||
A call to some user defined function by inlining it.
|
||||
"""
|
||||
result = InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
|
||||
self.output.guards.update(fn.guards)
|
||||
return result
|
||||
return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
|
||||
|
||||
def get_line_of_code_header(self, lineno=None):
|
||||
if lineno is None:
|
||||
@ -1139,7 +1134,6 @@ class InstructionTranslatorBase(Checkpointable[InstructionTranslatorGraphState])
|
||||
def FOR_ITER(self, inst):
|
||||
it = self.pop().realize()
|
||||
if isinstance(it, (variables.ListIteratorVariable, variables.IteratorVariable)):
|
||||
self.output.guards.update(it.guards)
|
||||
try:
|
||||
val, next_iter = it.next_variables(self)
|
||||
self.push(next_iter)
|
||||
@ -1233,8 +1227,6 @@ class InstructionTranslatorBase(Checkpointable[InstructionTranslatorGraphState])
|
||||
if sys.version_info >= (3, 11):
|
||||
null = self.pop()
|
||||
assert isinstance(null, NullVariable)
|
||||
self.output.guards.update(argsvars.guards)
|
||||
self.output.guards.update(kwargsvars.guards)
|
||||
|
||||
if (
|
||||
isinstance(fn, GetAttrVariable)
|
||||
@ -1327,13 +1319,9 @@ class InstructionTranslatorBase(Checkpointable[InstructionTranslatorGraphState])
|
||||
), f"Mutating module attribute {inst.argval} during export."
|
||||
|
||||
try:
|
||||
self.output.guards.update(
|
||||
BuiltinVariable(setattr)
|
||||
.call_function(
|
||||
BuiltinVariable(setattr).call_function(
|
||||
self, [obj, ConstantVariable.create(inst.argval), val], {}
|
||||
)
|
||||
.guards
|
||||
)
|
||||
return
|
||||
except Unsupported as e:
|
||||
if not self.should_compile_partial_graph():
|
||||
@ -1355,10 +1343,8 @@ class InstructionTranslatorBase(Checkpointable[InstructionTranslatorGraphState])
|
||||
|
||||
def DELETE_ATTR(self, inst):
|
||||
obj = self.pop()
|
||||
self.output.guards.update(
|
||||
BuiltinVariable(delattr)
|
||||
.call_function(self, [obj, ConstantVariable.create(inst.argval)], {})
|
||||
.guards
|
||||
BuiltinVariable(delattr).call_function(
|
||||
self, [obj, ConstantVariable.create(inst.argval)], {}
|
||||
)
|
||||
|
||||
def create_call_resume_at(self, offset):
|
||||
@ -1375,8 +1361,6 @@ class InstructionTranslatorBase(Checkpointable[InstructionTranslatorGraphState])
|
||||
def STORE_SUBSCR(self, inst):
|
||||
val, obj, key = self.popn(3)
|
||||
result = obj.call_method(self, "__setitem__", [key, val], {})
|
||||
# no result is pushed, so need to lift the guards to global
|
||||
self.output.guards.update(result.guards)
|
||||
|
||||
def BUILD_TUPLE(self, inst):
|
||||
items = self.popn(inst.argval)
|
||||
@ -1511,7 +1495,6 @@ class InstructionTranslatorBase(Checkpointable[InstructionTranslatorGraphState])
|
||||
obj,
|
||||
ListVariable(
|
||||
obj.items + [v],
|
||||
regen_guards=False,
|
||||
**VariableTracker.propagate([obj, v]),
|
||||
),
|
||||
)
|
||||
@ -1559,7 +1542,6 @@ class InstructionTranslatorBase(Checkpointable[InstructionTranslatorGraphState])
|
||||
def UNPACK_SEQUENCE(self, inst):
|
||||
seq = self.pop()
|
||||
if isinstance(seq, (BaseListVariable, SetVariable)):
|
||||
self.output.guards.update(seq.guards)
|
||||
val = seq.unpack_var_sequence(self)
|
||||
elif seq.is_python_constant() and isinstance(seq, ConstantVariable):
|
||||
val = seq.unpack_var_sequence(self)
|
||||
@ -1874,8 +1856,6 @@ class InstructionTranslatorBase(Checkpointable[InstructionTranslatorGraphState])
|
||||
if isinstance(ctx, GenericContextWrappingVariable):
|
||||
self.generic_context_manager_depth += 1
|
||||
|
||||
self.output.guards.update(ctx.guards)
|
||||
|
||||
exit = WithExitFunctionVariable(
|
||||
ctx,
|
||||
inst.target,
|
||||
@ -1961,9 +1941,7 @@ class InstructionTranslatorBase(Checkpointable[InstructionTranslatorGraphState])
)

def store_global_weakref(self, name, value):
self.output.guards.add(
GlobalWeakRefSource(name).make_guard(GuardBuilder.WEAKREF_ALIVE)
)
install_guard(GlobalWeakRefSource(name).make_guard(GuardBuilder.WEAKREF_ALIVE))
if name not in self.output.global_scope:
self.output.install_global(name, weakref.ref(value))
@ -2148,67 +2126,26 @@ class InstructionTranslator(InstructionTranslatorBase):
|
||||
vars.extend(cells_and_freevars)
|
||||
cells_and_freevars_set = set(cells_and_freevars)
|
||||
|
||||
self.symbolic_locals = collections.OrderedDict(
|
||||
(
|
||||
k,
|
||||
VariableBuilder(
|
||||
self,
|
||||
LocalSource(k, cell_or_freevar=k in cells_and_freevars_set),
|
||||
)(f_locals[k]),
|
||||
self.symbolic_locals = {
|
||||
k: variables.LazyVariableTracker.create(
|
||||
f_locals[k],
|
||||
source=LocalSource(k, cell_or_freevar=k in cells_and_freevars_set),
|
||||
)
|
||||
for k in vars
|
||||
if k in f_locals
|
||||
)
|
||||
}
|
||||
if export:
|
||||
# export gets super confused if we never realize unused inputs
|
||||
# export gets confused if we never realize unused inputs
|
||||
# in export mode just eagerly realize everything
|
||||
self.symbolic_locals = VariableTracker.apply(
|
||||
lambda x: x.realize(), self.symbolic_locals
|
||||
)
|
||||
|
||||
self.init_local_index_guards_hack()
|
||||
|
||||
self._freevars_ids = dict()
|
||||
for name in self.code_options["co_freevars"]:
|
||||
if name in f_locals:
|
||||
self._freevars_ids[name] = id(f_locals[name])
|
||||
|
||||
def init_local_index_guards_hack(self):
|
||||
# symbolic_locals contains the mapping from original f_locals to the
|
||||
# Variable objects. During the Variable building phase, each object also
|
||||
# has its associated guards. At the end, we will accumulate these
|
||||
# guards.
|
||||
#
|
||||
# One way of handling these guards is to just accumulate all of them
|
||||
# right now. However, many f_locals might not be used in the frame and
|
||||
# thus can unnecessarily increase guard execution overhead. Therefore,
|
||||
# we selectively update output.guards as we run the Python Bytecode
|
||||
# instruction by instruction.
|
||||
#
|
||||
# An exception here is list/dict variables. Guards related to these
|
||||
# variables have indexed access, like Tensor_match on args[0], and if
|
||||
# args is not used in this frame, we will miss a LIST_LENGTH check like
|
||||
# len(args) == 2. Missing the LIST_LENGTH check causes problem for the
|
||||
# next invocation when args is not a list, and args[0] is a runtime
|
||||
# error. Therefore, we recursively add guards for list/dict variable here.
|
||||
for val in self.symbolic_locals.values():
|
||||
if isinstance(
|
||||
val, (ListIteratorVariable, BaseListVariable, ConstDictVariable)
|
||||
):
|
||||
local_guards = VariableTracker.propagate(val)["guards"]
|
||||
index_guards = [
|
||||
guard
|
||||
for guard in local_guards
|
||||
if guard.create_fn
|
||||
in (
|
||||
GuardBuilder.LIST_LENGTH,
|
||||
GuardBuilder.DICT_KEYS,
|
||||
GuardBuilder.ODICT_KEYS,
|
||||
GuardBuilder.TUPLE_ITERATOR_LEN,
|
||||
)
|
||||
]
|
||||
self.output.guards.update(index_guards)
|
||||
|
||||
def run(self):
|
||||
super().run()
|
||||
|
||||
@ -2661,7 +2598,6 @@ class InliningGeneratorInstructionTranslator(InliningInstructionTranslator):
|
||||
if isinstance(
|
||||
tos, (variables.ListIteratorVariable, variables.IteratorVariable)
|
||||
):
|
||||
self.output.guards.update(tos.guards)
|
||||
try:
|
||||
val, next_iter = tos.next_variables(self)
|
||||
self.push(val)
|
||||
|
@ -1,12 +1,12 @@
|
||||
import collections
|
||||
from enum import Enum
|
||||
from typing import Any, Callable, Dict, List, Optional, Set
|
||||
from typing import Any, Callable, Dict, List
|
||||
|
||||
from .. import variables
|
||||
from ..current_scope_id import current_scope_id
|
||||
from ..exc import unimplemented
|
||||
from ..source import AttrSource, Source
|
||||
from ..utils import dict_values, identity, istype, odict_values
|
||||
from ..utils import identity, istype
|
||||
|
||||
|
||||
class MutableLocalSource(Enum):
|
||||
@ -154,21 +154,8 @@ class VariableTracker(metaclass=VariableTrackerMeta):

@staticmethod
def propagate(*vars: List[List["VariableTracker"]]):
"""Combine the guards from many VariableTracker into **kwargs for a new instance"""
guards = set()

def visit(var):
if type(var) in (list, tuple, dict_values, odict_values):
for i in var:
visit(i)
else:
assert isinstance(var, VariableTracker), typestr(var)
guards.update(var.guards)

visit(vars)
return {
"guards": guards,
}
# TODO(jansel): delete this function
return {}

def clone(self, **kwargs):
"""Shallow copy with some (optional) changes"""
@ -246,22 +233,8 @@ class VariableTracker(metaclass=VariableTrackerMeta):
cache[idx] = (result, value)
return result

def add_guard(self, guard):
return self.clone(guards=set.union(self.guards, {guard}))

def add_guards(self, guards):
if guards is None:
return self
assert isinstance(guards, set)
return self.clone(guards=set.union(self.guards, guards))

def add_options(self, options, *more):
if more:
return self.add_options(options).add_options(*more)
if isinstance(options, VariableTracker):
return self.add_guards(options.guards)
assert isinstance(options, dict)
return self.add_guards(options.get("guards", set()))
return self

def __str__(self):
return f"{self.__class__.__name__}()"
@ -283,13 +256,6 @@ class VariableTracker(metaclass=VariableTrackerMeta):
|
||||
except NotImplementedError:
|
||||
return False
|
||||
|
||||
def can_make_guard(self):
|
||||
try:
|
||||
self.make_guard(None)
|
||||
return True
|
||||
except NotImplementedError:
|
||||
return False
|
||||
|
||||
def make_guard(self, fn):
|
||||
if self.source:
|
||||
return self.source.make_guard(fn)
|
||||
@ -380,6 +346,10 @@ class VariableTracker(metaclass=VariableTrackerMeta):
"""Used by LazyVariableTracker to build the real VariableTracker"""
return self

def recursive_realize(self):
"""Realize all objects under this"""
return VariableTracker.apply(lambda x: x.realize(), self)

def unwrap(self) -> "VariableTracker":
"""Used by LazyVariableTracker to return the real VariableTracker if it already exists"""
return self
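`recursive_realize` forces every lazily wrapped value in a (possibly nested) structure to be built. A toy model of the laziness it operates on, using hypothetical classes rather than the real `LazyVariableTracker`:

```python
# Hypothetical sketch of lazy realization; not Dynamo's real classes.
class LazyValue:
    def __init__(self, build):
        self._build = build
        self._real = None

    def realize(self):
        # Build the underlying value only on first access.
        if self._real is None:
            self._real = self._build()
        return self._real


def recursive_realize(obj):
    # Walk lists/tuples/dicts and realize every LazyValue found.
    if isinstance(obj, LazyValue):
        return obj.realize()
    if isinstance(obj, (list, tuple)):
        return type(obj)(recursive_realize(x) for x in obj)
    if isinstance(obj, dict):
        return {k: recursive_realize(v) for k, v in obj.items()}
    return obj


nested = {"a": LazyValue(lambda: 1), "b": [LazyValue(lambda: 2)]}
print(recursive_realize(nested))  # {'a': 1, 'b': [2]}
```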
@ -391,14 +361,12 @@ class VariableTracker(metaclass=VariableTrackerMeta):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
guards: Optional[Set] = None,
|
||||
source: Source = None,
|
||||
mutable_local: MutableLocal = None,
|
||||
user_code_variable_name: str = None,
|
||||
parents_tracker: ParentsTracker = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.guards = guards or set()
|
||||
self.source = source
|
||||
self.mutable_local = mutable_local
|
||||
self.user_code_variable_name = user_code_variable_name
|
||||
|
@ -44,7 +44,7 @@ from ..allowed_functions import (
|
||||
|
||||
from ..device_interface import device_interfaces
|
||||
from ..exc import InternalTorchDynamoError, unimplemented
|
||||
from ..guards import GuardBuilder, make_dupe_guard
|
||||
from ..guards import GuardBuilder, install_guard, make_dupe_guard
|
||||
from ..side_effects import SideEffects
|
||||
from ..source import (
|
||||
AttrSource,
|
||||
@ -198,6 +198,7 @@ class GraphArg:
|
||||
|
||||
def erase(self):
|
||||
self._example = None
|
||||
self.example_strong_ref = None
|
||||
|
||||
def __eq__(self, other):
|
||||
return self.source.name() == other.source.name()
|
||||
@ -231,9 +232,7 @@ class VariableBuilder:
|
||||
side_effect_result = self.tx.output.side_effects[value]
|
||||
dup_guard = make_dupe_guard(self.source, side_effect_result.source)
|
||||
if dup_guard:
|
||||
side_effect_result = side_effect_result.add_guards(
|
||||
self.make_guards(dup_guard)
|
||||
)
|
||||
self.install_guards(dup_guard)
|
||||
return side_effect_result
|
||||
vt = self._wrap(value).clone(**self.options())
|
||||
if self._can_lift_attrs_to_inputs(vt):
|
||||
@ -272,14 +271,15 @@ class VariableBuilder:
def options(self):
return {"source": self.get_source()}

def make_guards(self, *guards):
def install_guards(self, *guards):
source = self.get_source()
if (
isinstance(source, ConstantSource)
or source.guard_source() == GuardSource.CONSTANT
):
return None
return {source.make_guard(guard) for guard in guards}
install_guard(*[source.make_guard(guard) for guard in guards], skip=1)
return {}

@classmethod
@functools.lru_cache(None)
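Because `install_guards` now registers guards as a side effect and returns an empty dict, the many `guards=self.make_guards(...)` keyword arguments in this file either disappear or become `**self.install_guards(...)` no-ops. A condensed before/after of the call-site pattern, excerpted from the hunks below (the enum case; other variable classes follow the same shape):

```python
# Before (old pattern, removed by this PR):
return EnumVariable(
    value=value,
    source=self.source,
    guards=make_guards(GuardBuilder.ID_MATCH),
)

# After (new pattern used throughout builder.py):
self.install_guards(GuardBuilder.ID_MATCH)
return EnumVariable(value=value, source=self.source)
```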
@ -330,7 +330,7 @@ class VariableBuilder:
|
||||
lambda self, value: LambdaVariable(
|
||||
InspectSignatureVariable.create,
|
||||
source=self.source,
|
||||
guards=self.make_guards(GuardBuilder.FUNCTION_MATCH),
|
||||
**self.install_guards(GuardBuilder.FUNCTION_MATCH),
|
||||
),
|
||||
),
|
||||
(comptime, lambda self, value: ComptimeVariable()),
|
||||
@ -339,7 +339,7 @@ class VariableBuilder:
|
||||
lambda self, value: LambdaVariable(
|
||||
_dataclasses_fields_lambda,
|
||||
source=self.source,
|
||||
guards=self.make_guards(GuardBuilder.FUNCTION_MATCH),
|
||||
**self.install_guards(GuardBuilder.FUNCTION_MATCH),
|
||||
),
|
||||
),
|
||||
(
|
||||
@ -347,7 +347,7 @@ class VariableBuilder:
|
||||
lambda self, value: TorchVariable(
|
||||
value,
|
||||
source=self.source,
|
||||
guards=self.make_guards(GuardBuilder.FUNCTION_MATCH),
|
||||
**self.install_guards(GuardBuilder.FUNCTION_MATCH),
|
||||
),
|
||||
),
|
||||
]
|
||||
@ -375,8 +375,6 @@ class VariableBuilder:
|
||||
class Autotuner:
|
||||
pass
|
||||
|
||||
make_guards = self.make_guards
|
||||
|
||||
# Handle exact type() match
|
||||
type_dispatch = self._type_dispatch().get(type(value))
|
||||
if type_dispatch is not None:
|
||||
@ -400,13 +398,13 @@ class VariableBuilder:
|
||||
return self.wrap_listlike(value)
|
||||
|
||||
elif value is torch.utils._pytree.SUPPORTED_NODES:
|
||||
# For SUPPORTED_NODES, we guard on the dictionary version (PEP509)
|
||||
# under the assumption that the values themselves don't change.
|
||||
self.install_guards(GuardBuilder.DICT_VERSION)
|
||||
result = {
|
||||
k: UserDefinedObjectVariable(
|
||||
value[k],
|
||||
source=GetItemSource(self.get_source(), k),
|
||||
# For SUPPORTED_NODES, we guard on the dictionary version (PEP509)
|
||||
# under the assumption that the values themselves don't change.
|
||||
guards=self.make_guards(GuardBuilder.DICT_VERSION),
|
||||
)
|
||||
for k in value.keys()
|
||||
}
|
||||
@ -429,9 +427,9 @@ class VariableBuilder:
|
||||
# Why is this OK for (specialized) nnmodules? We set up a setattr hook
|
||||
# to check for module property mutations, which does a reasonable,
|
||||
# but not completely secure job ensuring a property wasn't changed.
|
||||
guards = self.make_guards(GuardBuilder.BOOL_FALSE)
|
||||
self.install_guards(GuardBuilder.BOOL_FALSE)
|
||||
else:
|
||||
guards = self.make_guards(GuardBuilder.DICT_KEYS)
|
||||
self.install_guards(GuardBuilder.DICT_KEYS)
|
||||
|
||||
# store key variables in global location for reconstruction
|
||||
for key in value.keys():
|
||||
@ -448,7 +446,7 @@ class VariableBuilder:
|
||||
k: LazyVariableTracker.create(
|
||||
value[k],
|
||||
source=GetItemSource(self.get_source(), index_source(k)),
|
||||
).add_guards(guards)
|
||||
)
|
||||
for k in value.keys()
|
||||
}
|
||||
|
||||
@ -457,10 +455,9 @@ class VariableBuilder:
|
||||
result,
|
||||
type(value),
|
||||
self._wrap(value.default_factory),
|
||||
guards=guards,
|
||||
)
|
||||
else:
|
||||
result = ConstDictVariable(result, type(value), guards=guards)
|
||||
result = ConstDictVariable(result, type(value))
|
||||
|
||||
return self.tx.output.side_effects.track_dict(self.source, value, result)
|
||||
elif isinstance(value, torch.nn.Module):
|
||||
@ -472,23 +469,14 @@ class VariableBuilder:
|
||||
):
|
||||
# For frozenset, we can guard by object ID instead of value
|
||||
# equality, this allows us to handle non-literal values
|
||||
return ConstantVariable.create(
|
||||
value=value,
|
||||
source=self.source,
|
||||
guards=make_guards(GuardBuilder.ID_MATCH),
|
||||
)
|
||||
self.install_guards(GuardBuilder.ID_MATCH)
|
||||
return ConstantVariable.create(value=value, source=self.source)
|
||||
elif isinstance(value, enum.Enum):
|
||||
return EnumVariable(
|
||||
value=value,
|
||||
source=self.source,
|
||||
guards=make_guards(GuardBuilder.ID_MATCH),
|
||||
)
|
||||
self.install_guards(GuardBuilder.ID_MATCH)
|
||||
return EnumVariable(value=value, source=self.source)
|
||||
elif is_builtin_callable(value):
|
||||
return BuiltinVariable(
|
||||
value,
|
||||
source=self.source,
|
||||
guards=make_guards(GuardBuilder.BUILTIN_MATCH),
|
||||
)
|
||||
self.install_guards(GuardBuilder.BUILTIN_MATCH)
|
||||
return BuiltinVariable(value, source=self.source)
|
||||
elif is_utils_checkpoint(value):
|
||||
return build_checkpoint_variable(source=self.source)
|
||||
elif isinstance(value, functools.partial):
|
||||
@ -509,52 +497,50 @@ class VariableBuilder:
|
||||
self.tx, GetItemSource(keywords_source, k)
|
||||
)(v)
|
||||
|
||||
guards = {
|
||||
install_guard(
|
||||
self.get_source().make_guard(GuardBuilder.TYPE_MATCH),
|
||||
keywords_source.make_guard(GuardBuilder.DICT_KEYS),
|
||||
args_source.make_guard(GuardBuilder.LIST_LENGTH),
|
||||
}
|
||||
|
||||
return FunctoolsPartialVariable(
|
||||
func_obj, args, keywords, original=value, guards=guards
|
||||
)
|
||||
return FunctoolsPartialVariable(func_obj, args, keywords, original=value)
|
||||
elif is_typing(value):
|
||||
# typing.List, typing.Mapping, etc.
|
||||
self.install_guards(GuardBuilder.ID_MATCH)
|
||||
return TypingVariable(
|
||||
value,
|
||||
source=self.source,
|
||||
guards=make_guards(GuardBuilder.ID_MATCH),
|
||||
)
|
||||
elif np is not None and isinstance(value, np.generic):
|
||||
# numpy array scalars: convert to 0D arrays
|
||||
return self.wrap_numpy_ndarray(np.asarray(value))
|
||||
elif is_numpy(value):
|
||||
assert np
|
||||
return NumpyVariable(
|
||||
value,
|
||||
source=self.source,
|
||||
guards=make_guards(
|
||||
self.install_guards(
|
||||
GuardBuilder.FUNCTION_MATCH
|
||||
if callable(value)
|
||||
else GuardBuilder.TYPE_MATCH
|
||||
),
|
||||
)
|
||||
return NumpyVariable(value, source=self.source)
|
||||
# NB: These can't be put in type_dispatch, they have to run later
|
||||
elif CollectiveFunctionRewriteVariable.can_rewrite(value):
|
||||
self.install_guards(GuardBuilder.FUNCTION_MATCH)
|
||||
return CollectiveFunctionRewriteVariable.create(
|
||||
self.tx,
|
||||
value,
|
||||
source=self.source,
|
||||
guards=make_guards(GuardBuilder.FUNCTION_MATCH),
|
||||
)
|
||||
elif istype(value, torch.autograd.function.FunctionMeta):
|
||||
self.install_guards(GuardBuilder.FUNCTION_MATCH)
|
||||
return AutogradFunctionVariable(
|
||||
value,
|
||||
source=self.source,
|
||||
guards=make_guards(GuardBuilder.FUNCTION_MATCH),
|
||||
)
|
||||
elif isinstance(value, torch.autograd.function.FunctionCtx):
|
||||
saved_tensors_source = AttrSource(self.source, "saved_tensors")
|
||||
install_guard(
|
||||
self.source.make_guard(GuardBuilder.TYPE_MATCH),
|
||||
saved_tensors_source.make_guard(GuardBuilder.LIST_LENGTH),
|
||||
)
|
||||
saved_tensors = [
|
||||
VariableBuilder(self.tx, GetItemSource(saved_tensors_source, n))(v)
|
||||
for n, v in enumerate(value.saved_tensors)
|
||||
@ -565,8 +551,6 @@ class VariableBuilder:
|
||||
AutogradFunctionContextVariable(
|
||||
value,
|
||||
source=self.source,
|
||||
guards=make_guards(GuardBuilder.TYPE_MATCH)
|
||||
| {saved_tensors_source.make_guard(GuardBuilder.LIST_LENGTH)},
|
||||
saved_tensors=SavedTensorBox(saved_tensors),
|
||||
),
|
||||
)
|
||||
@ -579,53 +563,43 @@ class VariableBuilder:
|
||||
and value == getattr(value.__self__, "apply", None)
|
||||
):
|
||||
# handle aliased autograd function `apply` calls
|
||||
self.install_guards(GuardBuilder.FUNCTION_MATCH)
|
||||
return GetAttrVariable(
|
||||
AutogradFunctionVariable(
|
||||
value.__self__,
|
||||
source=self.source,
|
||||
guards=make_guards(GuardBuilder.FUNCTION_MATCH),
|
||||
),
|
||||
AutogradFunctionVariable(value.__self__, source=self.source),
|
||||
"apply",
|
||||
)
|
||||
elif np and isinstance(value, np.number):
|
||||
return self.wrap_unspecialized_primitive(value)
|
||||
elif DataClassVariable.is_matching_object(value):
|
||||
return DataClassVariable.wrap(self, value).add_guards(
|
||||
make_guards(GuardBuilder.TYPE_MATCH)
|
||||
)
|
||||
self.install_guards(GuardBuilder.TYPE_MATCH)
|
||||
return DataClassVariable.wrap(self, value)
|
||||
elif HFPretrainedConfigVariable.is_matching_object(value):
|
||||
return HFPretrainedConfigVariable(
|
||||
value, guards=make_guards(GuardBuilder.TYPE_MATCH)
|
||||
)
|
||||
self.install_guards(GuardBuilder.TYPE_MATCH)
|
||||
return HFPretrainedConfigVariable(value)
|
||||
elif isinstance(value, HigherOrderOperator):
|
||||
return TorchHigherOrderOperatorVariable.make(
|
||||
value,
|
||||
source=self.source,
|
||||
guards=self.make_guards(
|
||||
GuardBuilder.TYPE_MATCH, GuardBuilder.NAME_MATCH
|
||||
),
|
||||
)
|
||||
self.install_guards(GuardBuilder.TYPE_MATCH, GuardBuilder.NAME_MATCH)
|
||||
return TorchHigherOrderOperatorVariable.make(value, source=self.source)
|
||||
elif type(value).__name__ == "builtin_function_or_method" and isinstance(
|
||||
value.__self__, torch_special_class_types
|
||||
):
|
||||
self.install_guards(GuardBuilder.FUNCTION_MATCH)
|
||||
return TorchVariable(
|
||||
value,
|
||||
guards=make_guards(GuardBuilder.FUNCTION_MATCH),
|
||||
)
|
||||
elif isinstance(value, _StreamBase):
|
||||
self.install_guards(GuardBuilder.ID_MATCH)
|
||||
return StreamVariable(
|
||||
None,
|
||||
value,
|
||||
value.device.type,
|
||||
source=self.source,
|
||||
guards=make_guards(GuardBuilder.ID_MATCH),
|
||||
)
|
||||
elif isinstance(value, _EventBase):
|
||||
self.install_guards(GuardBuilder.ID_MATCH)
|
||||
return EventVariable(
|
||||
None,
|
||||
value,
|
||||
source=self.source,
|
||||
guards=make_guards(GuardBuilder.ID_MATCH),
|
||||
)
|
||||
elif (
|
||||
isinstance(value, torch._C._TensorMeta)
|
||||
@ -636,55 +610,36 @@ class VariableBuilder:
|
||||
istype(value, contextlib.nullcontext)
|
||||
and inspect.getattr_static(value, "enter_result", None) is None
|
||||
):
|
||||
return NullContextVariable(
|
||||
source=self.source,
|
||||
guards=make_guards(GuardBuilder.FUNCTION_MATCH),
|
||||
)
|
||||
# TODO(jansel): I think this can be TYPE_MATCH
|
||||
self.install_guards(GuardBuilder.FUNCTION_MATCH)
|
||||
return NullContextVariable(source=self.source)
|
||||
elif KeyedJaggedTensorVariable.is_matching_object(value):
|
||||
result = KeyedJaggedTensorVariable(
|
||||
value,
|
||||
source=self.source,
|
||||
guards=self.make_guards(GuardBuilder.TYPE_MATCH),
|
||||
)
|
||||
self.install_guards(GuardBuilder.TYPE_MATCH)
|
||||
result = KeyedJaggedTensorVariable(value, source=self.source)
|
||||
# TODO: this doing it manually is bad
|
||||
return self.tx.output.side_effects.track_object_existing(
|
||||
self.source, value, result
|
||||
)
|
||||
elif isinstance(value, torch.optim.Optimizer):
|
||||
return OptimizerVariable(
|
||||
value,
|
||||
source=self.source,
|
||||
guards=self.make_guards(GuardBuilder.TYPE_MATCH),
|
||||
)
|
||||
self.install_guards(GuardBuilder.TYPE_MATCH)
|
||||
return OptimizerVariable(value, source=self.source)
|
||||
elif ProcessGroupVariable.is_process_group(value):
|
||||
return ProcessGroupVariable(
|
||||
value,
|
||||
source=self.source,
|
||||
guards=self.make_guards(GuardBuilder.ID_MATCH),
|
||||
)
|
||||
self.install_guards(GuardBuilder.ID_MATCH)
|
||||
return ProcessGroupVariable(value, source=self.source)
|
||||
elif DeviceMeshVariable.is_device_mesh(value):
|
||||
# TODO: see if we need to add custom guard instead
|
||||
# of a simple ID_MATCH
|
||||
return DeviceMeshVariable(
|
||||
value,
|
||||
source=self.source,
|
||||
guards=self.make_guards(GuardBuilder.ID_MATCH),
|
||||
)
|
||||
# TODO: see if we need to add custom guard instead of a simple ID_MATCH
|
||||
self.install_guards(GuardBuilder.ID_MATCH)
|
||||
return DeviceMeshVariable(value, source=self.source)
|
||||
elif PlacementClassVariable.is_placement_type(value):
|
||||
# TODO: see if we need to add custom guard instead
|
||||
# of a simple ID_MATCH
|
||||
return PlacementClassVariable(
|
||||
value,
|
||||
source=self.source,
|
||||
guards=make_guards(GuardBuilder.ID_MATCH),
|
||||
)
|
||||
# TODO: see if we need to add custom guard instead of a simple ID_MATCH
|
||||
self.install_guards(GuardBuilder.ID_MATCH)
|
||||
return PlacementClassVariable(value, source=self.source)
|
||||
elif PlacementVariable.is_placement(value):
|
||||
# TODO: see if we need to add custom guard instead
|
||||
# of a simple ID_MATCH
|
||||
# TODO: see if we need to add custom guard instead of a simple ID_MATCH
|
||||
self.install_guards(GuardBuilder.ID_MATCH)
|
||||
return PlacementVariable(
|
||||
value,
|
||||
source=self.source,
|
||||
guards=make_guards(GuardBuilder.ID_MATCH),
|
||||
)
|
||||
elif isinstance(value, torch.SymBool):
|
||||
# Note: the idea here is to re-use the infra we've built for SymInt by simulating the
|
||||
@ -727,12 +682,12 @@ class VariableBuilder:
|
||||
new_symint == 1,
|
||||
)
|
||||
elif isinstance(value, (JITFunction, Autotuner)):
|
||||
self.install_guards(GuardBuilder.ID_MATCH)
|
||||
return TritonKernelVariable(
|
||||
value,
|
||||
None, # No kernel idx provided
|
||||
None, # No grid provided
|
||||
source=self.source,
|
||||
guards=make_guards(GuardBuilder.ID_MATCH),
|
||||
)
|
||||
elif trace_rules.lookup(value) is not None:
|
||||
return trace_rules.lookup(value).create_with_source(
|
||||
@ -741,10 +696,10 @@ class VariableBuilder:
|
||||
elif is_allowed(value):
|
||||
if is_user_defined_allowed(value):
|
||||
self.tx.output.has_user_defined_allowed_in_graph = True
|
||||
self.install_guards(GuardBuilder.FUNCTION_MATCH)
|
||||
return TorchVariable(
|
||||
value,
|
||||
source=self.source,
|
||||
guards=make_guards(GuardBuilder.FUNCTION_MATCH),
|
||||
)
|
||||
elif (
|
||||
istype(value, (type, types.FunctionType))
|
||||
@ -752,17 +707,17 @@ class VariableBuilder:
|
||||
and not inspect.getattr_static(value, "_torchdynamo_inline", False)
|
||||
and not inspect.getattr_static(value, "__script_if_tracing_wrapper", False)
|
||||
):
|
||||
self.install_guards(GuardBuilder.FUNCTION_MATCH)
|
||||
return SkipFilesVariable(
|
||||
value,
|
||||
skipfiles.check_verbose(value, allow_torch=True).reason,
|
||||
source=self.source,
|
||||
guards=make_guards(GuardBuilder.FUNCTION_MATCH),
|
||||
)
|
||||
elif istype(value, (types.FunctionType, torch.jit.ScriptFunction)):
|
||||
self.install_guards(GuardBuilder.FUNCTION_MATCH)
|
||||
return UserFunctionVariable(
|
||||
value,
|
||||
source=self.source,
|
||||
guards=make_guards(GuardBuilder.FUNCTION_MATCH),
|
||||
)
|
||||
elif isinstance(value, types.MethodType) and isinstance(
|
||||
value.__self__, torch.nn.Module
|
||||
@ -784,40 +739,33 @@ class VariableBuilder:
|
||||
assert self_obj and isinstance(
|
||||
self_obj, VariableTracker
|
||||
), "Failed to produce a valid self obj"
|
||||
self.install_guards(GuardBuilder.FUNCTION_MATCH)
|
||||
return UserMethodVariable(
|
||||
value.__func__,
|
||||
self_obj,
|
||||
source=self.source,
|
||||
guards=make_guards(GuardBuilder.FUNCTION_MATCH),
|
||||
)
|
||||
elif istype(value, (types.ModuleType, replay_record.DummyModule)):
|
||||
self.install_guards(GuardBuilder.FUNCTION_MATCH)
|
||||
return PythonModuleVariable(
|
||||
value,
|
||||
source=self.source,
|
||||
guards=make_guards(GuardBuilder.PYMODULE_MATCH),
|
||||
)
|
||||
elif isinstance(value, types.GetSetDescriptorType):
|
||||
return GetSetDescriptorVariable(
|
||||
value, guards=self.make_guards(GuardBuilder.FUNCTION_MATCH)
|
||||
)
|
||||
self.install_guards(GuardBuilder.FUNCTION_MATCH)
|
||||
return GetSetDescriptorVariable(value)
|
||||
elif isinstance(value, types.MethodWrapperType):
|
||||
return MethodWrapperVariable(
|
||||
value,
|
||||
source=self.source,
|
||||
guards=self.make_guards(GuardBuilder.FUNCTION_MATCH),
|
||||
)
|
||||
self.install_guards(GuardBuilder.FUNCTION_MATCH)
|
||||
return MethodWrapperVariable(value, source=self.source)
|
||||
elif issubclass(type(value), type):
|
||||
self.install_guards(GuardBuilder.FUNCTION_MATCH)
|
||||
return UserDefinedClassVariable(
|
||||
value,
|
||||
source=self.source,
|
||||
guards=make_guards(GuardBuilder.FUNCTION_MATCH),
|
||||
)
|
||||
else:
|
||||
result = UserDefinedObjectVariable(
|
||||
value,
|
||||
source=self.source,
|
||||
guards=self.make_guards(GuardBuilder.TYPE_MATCH),
|
||||
)
|
||||
self.install_guards(GuardBuilder.TYPE_MATCH)
|
||||
result = UserDefinedObjectVariable(value, source=self.source)
|
||||
if not SideEffects.cls_supports_mutation_side_effects(type(value)):
|
||||
# don't allow STORE_ATTR mutation with custom __setattr__
|
||||
return result
|
||||
@ -857,36 +805,32 @@ class VariableBuilder:
|
||||
def wrap_listlike(self, value: Union[tuple, list, odict_values, NamedTuple]):
|
||||
# One can index a tensor with a list/tuple. Therefore, we need to
|
||||
# have a stricter match.
|
||||
guards = self.make_guards(GuardBuilder.LIST_LENGTH)
|
||||
self.install_guards(GuardBuilder.LIST_LENGTH)
|
||||
|
||||
for item in value:
|
||||
if item is value:
|
||||
unimplemented("list elements are pointing to the list itself")
|
||||
|
||||
output = [
|
||||
VariableBuilder(self.tx, GetItemSource(self.get_source(), i))(
|
||||
item
|
||||
).add_guards(guards)
|
||||
VariableBuilder(self.tx, GetItemSource(self.get_source(), i))(item)
|
||||
for i, item in enumerate(value)
|
||||
]
|
||||
result = BaseListVariable.cls_for_instance(value)(
|
||||
output, mutable_local=MutableLocal(), guards=guards
|
||||
output, mutable_local=MutableLocal()
|
||||
)
|
||||
if istype(value, list):
|
||||
return self.tx.output.side_effects.track_list(self.source, value, result)
|
||||
return result
|
||||
|
||||
def wrap_tuple_iterator(self, value: tuple_iterator):
|
||||
guards = self.make_guards(GuardBuilder.TUPLE_ITERATOR_LEN)
|
||||
self.install_guards(GuardBuilder.TUPLE_ITERATOR_LEN)
|
||||
output = [
|
||||
VariableBuilder(self.tx, TupleIteratorGetItemSource(self.get_source(), i))(
|
||||
tuple_iterator_getitem(value, i)
|
||||
).add_guards(guards)
|
||||
)
|
||||
for i in range(tuple_iterator_len(value))
|
||||
]
|
||||
return TupleIteratorVariable(
|
||||
output, mutable_local=MutableLocal(), guards=guards
|
||||
)
|
||||
return TupleIteratorVariable(output, mutable_local=MutableLocal())
|
||||
|
||||
def wrap_slice_range(self, value: Union[slice, range]):
|
||||
items = [
|
||||
@ -896,21 +840,20 @@ class VariableBuilder:
|
||||
for k in ("start", "stop", "step")
|
||||
]
|
||||
if isinstance(value, slice):
|
||||
return SliceVariable(
|
||||
items, guards=self.make_guards(GuardBuilder.TYPE_MATCH)
|
||||
)
|
||||
self.install_guards(GuardBuilder.TYPE_MATCH)
|
||||
return SliceVariable(items)
|
||||
else:
|
||||
return RangeVariable(
|
||||
items, guards=self.make_guards(GuardBuilder.EQUALS_MATCH)
|
||||
)
|
||||
# TODO(jansel): I think this can be TYPE_MATCH
|
||||
self.install_guards(GuardBuilder.EQUALS_MATCH)
|
||||
return RangeVariable(items)
|
||||
|
||||
def wrap_module(self, value: torch.nn.Module):
|
||||
from ..eval_frame import OptimizedModule
|
||||
|
||||
if istype(value, OptimizedModule):
|
||||
guards = self.make_guards(GuardBuilder.TYPE_MATCH)
|
||||
self.install_guards(GuardBuilder.TYPE_MATCH)
|
||||
self.source = AttrSource(self.source, "_orig_mod")
|
||||
return self.wrap_module(value._orig_mod).add_guards(guards)
|
||||
return self.wrap_module(value._orig_mod)
|
||||
|
||||
if (
|
||||
isinstance(value, (torch.nn.RNN, torch.nn.GRU, torch.nn.LSTM))
|
||||
@ -919,9 +862,8 @@ class VariableBuilder:
|
||||
unimplemented("TorchDynamo purposely graph breaks on RNN, GRU, LSTMs")
|
||||
if mutation_guard.is_dynamic_nn_module(value):
|
||||
# created dynamically, don't specialize on it
|
||||
result = UnspecializedNNModuleVariable(
|
||||
value, guards=self.make_guards(GuardBuilder.TYPE_MATCH)
|
||||
)
|
||||
self.install_guards(GuardBuilder.TYPE_MATCH)
|
||||
result = UnspecializedNNModuleVariable(value)
|
||||
if not SideEffects.cls_supports_mutation_side_effects(type(value)):
|
||||
# don't allow STORE_ATTR mutation with custom __setattr__
|
||||
return result
|
||||
@ -931,9 +873,8 @@ class VariableBuilder:
|
||||
elif issubclass(
|
||||
value.__class__, torch.nn.parallel.distributed.DistributedDataParallel
|
||||
):
|
||||
return UnspecializedNNModuleVariable(
|
||||
value, guards=self.make_guards(GuardBuilder.TYPE_MATCH)
|
||||
)
|
||||
self.install_guards(GuardBuilder.TYPE_MATCH)
|
||||
return UnspecializedNNModuleVariable(value)
|
||||
elif getattr(value, "_is_fsdp_managed_module", False):
|
||||
# See note [Dynamo treats FSDP wrapped modules as UnspecializedNNModule]
|
||||
# in fully_sharded_data_parallel.py for more information
|
||||
@ -960,11 +901,8 @@ class VariableBuilder:
|
||||
#
|
||||
# ID_MATCH is required to disambiguate cases as simple as a unit test that constructs 2 models and wraps
|
||||
# them differently with different FSDP configs. (test_dynamo_distributed.py -k test_fsdp_aot_eager)
|
||||
return FSDPManagedNNModuleVariable(
|
||||
value,
|
||||
guards=self.make_guards(GuardBuilder.TYPE_MATCH, GuardBuilder.ID_MATCH),
|
||||
source=self.get_source(),
|
||||
)
|
||||
self.install_guards(GuardBuilder.TYPE_MATCH, GuardBuilder.ID_MATCH)
|
||||
return FSDPManagedNNModuleVariable(value, source=self.get_source())
|
||||
else:
|
||||
return self.tx.output.register_attr_or_module(
|
||||
value,
|
||||
@ -976,12 +914,12 @@ class VariableBuilder:
|
||||
def wrap_literal(self, value):
|
||||
unspec = not config.specialize_int
|
||||
if unspec and type(value) is torch.Size:
|
||||
self.install_guards(GuardBuilder.LIST_LENGTH)
|
||||
return SizeVariable(
|
||||
[
|
||||
VariableBuilder(self.tx, GetItemSource(self.get_source(), i))(v)
|
||||
for i, v in enumerate(value)
|
||||
],
|
||||
guards=self.make_guards(GuardBuilder.LIST_LENGTH),
]
)
elif unspec and type(value) is int:
# unspecializing int by default, but still
@ -995,17 +933,13 @@ class VariableBuilder:
# NN modules on the fly)
or self.source.guard_source().is_nn_module()
):
return ConstantVariable.create(
value=value,
guards=self.make_guards(GuardBuilder.CONSTANT_MATCH),
)
self.install_guards(GuardBuilder.CONSTANT_MATCH)
return ConstantVariable.create(value=value)
else:
return self.wrap_unspecialized_primitive(value)
else:
return ConstantVariable.create(
value=value,
guards=self.make_guards(GuardBuilder.CONSTANT_MATCH),
)
self.install_guards(GuardBuilder.CONSTANT_MATCH)
return ConstantVariable.create(value=value)

def assert_not_wrapped_by_this_graph(self, value: torch.Tensor):
if is_fake(value) and maybe_get_fake_mode(value) is self.tx.fake_mode:
@ -1027,11 +961,7 @@ class VariableBuilder:
) and not source.guard_source().is_fsdp_module():
self.assert_not_wrapped_by_this_graph(value)
return self.tx.output.register_attr_or_module(
value,
self.name,
source=source,
# Guards are done inside register_attr_or_module
# guards=self.make_guards(GuardBuilder.TENSOR_MATCH),
value, self.name, source=source
)

if is_constant_source(source):
@ -1099,20 +1029,7 @@ class VariableBuilder:
options["torch_function_fn"] = build_torch_function_fn(
self.tx, value, self.source
)
options["guards"] = self.make_guards(GuardBuilder.TYPE_MATCH)
else:
options["guards"] = set()

options["guards"].update(
self.make_guards(
functools.partial(
GuardBuilder.TENSOR_MATCH,
value=value
if isinstance(source, NumpyTensorSource)
else TensorWeakRef(value),
)
)
)
self.install_guards(GuardBuilder.TYPE_MATCH)

if (
isinstance(value, torch.Tensor)
@ -1130,6 +1047,16 @@ class VariableBuilder:
source=source,
**options,
)

self.install_guards(
functools.partial(
GuardBuilder.TENSOR_MATCH,
value=value
if isinstance(source, NumpyTensorSource)
else TensorWeakRef(value),
)
)

self.tx.output.input_source_to_var[source] = tensor_variable
assert "tensor_dict" not in tensor_proxy.node.meta
tensor_proxy.node.meta["tensor_dict"] = value.__dict__.copy()
@ -1172,11 +1099,11 @@ class VariableBuilder:
# a tensor. It's a little annoying to make a VT to throw out, but there's so many side effects here
# that there's not another great way to do this atm.
# This creates the right graphargs, as well as registration for guards in tensor names and shape env.
tensor_vt = VariableBuilder(self.tx, source)(tensor_value)
VariableBuilder(self.tx, source)(tensor_value).recursive_realize()
proxy = self.tx.output.root_tracer.create_graph_input(
re.sub(r"[^a-zA-Z0-9]+", "_", self.name), type(tensor_value), source=source
)
options = {"source": source, "guards": tensor_vt.guards}
options = {"source": source}
numpy_ndarray_variable = wrap_fx_proxy_cls(
target_cls=NumpyNdarrayVariable,
tx=self.tx,
@ -1229,10 +1156,8 @@ class VariableBuilder:
# If specialize_int is False, also return
# a constant (but this should have been handled
# in the caller, TBH)
return ConstantVariable.create(
value=value,
guards=self.make_guards(GuardBuilder.CONSTANT_MATCH),
)
self.install_guards(GuardBuilder.CONSTANT_MATCH)
return ConstantVariable.create(value=value)

name = self.source.name()
if name not in self.tx.output.frame_state:
@ -1264,10 +1189,8 @@ class VariableBuilder:
else: # assume_static_by_default
# TODO: dynamic_dim = DimDynamic.STATIC should work but
# for some reason it doesn't
return ConstantVariable.create(
value=value,
guards=self.make_guards(GuardBuilder.CONSTANT_MATCH),
)
self.install_guards(GuardBuilder.CONSTANT_MATCH)
return ConstantVariable.create(value=value)

wrapped_value = shape_env.create_unspecified_symint_and_symbol(
value,
@ -1281,11 +1204,8 @@ class VariableBuilder:
else:
wrapped_value = torch.tensor(value)
if not isinstance(self.get_source(), RandomValueSource):
guards = {self.get_source().make_guard(GuardBuilder.TYPE_MATCH)}
options = {"guards": guards}
else:
options = {}
options.update({"source": self.get_source()})
install_guard(self.get_source().make_guard(GuardBuilder.TYPE_MATCH))
options = {"source": self.get_source()}
if isinstance(wrapped_value, torch.Tensor):
options.update({"raw_value": value})

@ -1352,16 +1272,19 @@ def _dataclasses_fields_lambda(obj):


def wrap_fx_proxy(tx, proxy, example_value=None, subclass_type=None, **options):
return wrap_fx_proxy_cls(
target_cls=TensorVariable
if not subclass_type
else TensorWithTFOverrideVariable,
tx=tx,
proxy=proxy,
example_value=example_value,
subclass_type=subclass_type,
kwargs = {
"tx": tx,
"proxy": proxy,
"example_value": example_value,
"subclass_type": subclass_type,
**options,
)
}
if subclass_type is None:
return wrap_fx_proxy_cls(target_cls=TensorVariable, **kwargs)
else:
result = wrap_fx_proxy_cls(target_cls=TensorWithTFOverrideVariable, **kwargs)
result.install_global(tx)
return result


# Note: Unfortunate split due to some gross classes existing that subclass TensorVariable

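The recurring edit in the builder hunks above is the move from threading guard sets through VariableTracker options (guards=self.make_guards(...)) to installing guards eagerly at the point where a value is wrapped (self.install_guards(...) / install_guard(...)). The sketch below is illustrative only and is not part of the commit; it uses simplified stand-in classes rather than the real torch._dynamo types to contrast the two styles.

# Illustrative sketch only: simplified stand-ins, not the real torch._dynamo classes.
class Guard:
    def __init__(self, source, kind):
        self.source, self.kind = source, kind

    def __repr__(self):
        return f"Guard({self.source!r}, {self.kind!r})"

INSTALLED_GUARDS = set()  # stand-in for the tracing context's guard set

def install_guard(*guards):
    # eager style: the guard is recorded for the compiled frame immediately
    INSTALLED_GUARDS.update(guards)

class OldStyleVariable:
    # old style: each variable carries its guard set, and every caller must
    # remember to merge/propagate it into the output graph
    def __init__(self, value, guards=None):
        self.value = value
        self.guards = set(guards or ())

class NewStyleVariable:
    # new style: no guard plumbing on the variable itself
    def __init__(self, value):
        self.value = value

def wrap_old(value, source):
    return OldStyleVariable(value, guards={Guard(source, "CONSTANT_MATCH")})

def wrap_new(value, source):
    install_guard(Guard(source, "CONSTANT_MATCH"))
    return NewStyleVariable(value)

old = wrap_old(1, "L['x']")   # guard travels with the variable
new = wrap_new(2, "L['y']")   # guard already sits in INSTALLED_GUARDS
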
@ -19,7 +19,7 @@ from ..exc import (
UserError,
UserErrorType,
)
from ..guards import GuardBuilder
from ..guards import GuardBuilder, install_guard
from ..replay_record import DummyModule
from ..source import AttrSource, GetItemSource, is_constant_source, TypeSource
from ..utils import (
@ -339,7 +339,6 @@ class BuiltinVariable(VariableTracker):
a,
ListVariable(
list(a.items) + list(b.unpack_var_sequence(tx)),
regen_guards=False,
**options,
),
)
@ -826,23 +825,22 @@ class BuiltinVariable(VariableTracker):
mutable_local=MutableLocal(),
)
elif obj.has_unpack_var_sequence(tx):
guards = set()
if obj.source and not is_constant_source(obj.source):
if isinstance(obj, TupleIteratorVariable):
guards.add(obj.source.make_guard(GuardBuilder.TUPLE_ITERATOR_LEN))
install_guard(
obj.source.make_guard(GuardBuilder.TUPLE_ITERATOR_LEN)
)
else:
guards.add(obj.source.make_guard(GuardBuilder.LIST_LENGTH))
install_guard(obj.source.make_guard(GuardBuilder.LIST_LENGTH))
if cls is SetVariable:
return cls(
list(obj.unpack_var_sequence(tx)),
mutable_local=MutableLocal(),
guards=guards,
).add_options(self, obj)

return cls(
list(obj.unpack_var_sequence(tx)),
mutable_local=MutableLocal(),
guards=guards,
).add_options(self, obj)

call_iter = _call_iter_tuple_list
@ -1060,7 +1058,6 @@ class BuiltinVariable(VariableTracker):
from .builder import SourcelessBuilder, VariableBuilder

options = VariableTracker.propagate(self, obj, name_var)
guards = options["guards"]
name = name_var.as_python_constant()

if not name_var.is_python_constant():
@ -1075,10 +1072,9 @@ class BuiltinVariable(VariableTracker):

if default is not None:
hasattr_var = self.call_hasattr(tx, obj, name_var)
guards.update(hasattr_var.guards)
assert hasattr_var.as_python_constant() in (True, False)
if not hasattr_var.as_python_constant():
return default.add_guards(guards)
return default

if obj.source:
source = AttrSource(obj.source, name)
@ -1152,14 +1148,14 @@ class BuiltinVariable(VariableTracker):
elif ConstantVariable.is_literal(member):
return ConstantVariable.create(member, **options)
else:
return VariableBuilder(tx, source)(member).add_guards(guards)
return VariableBuilder(tx, source)(member)
elif isinstance(obj, (PythonModuleVariable, DummyModule)):
member = obj.value.__dict__[name]

if config.replay_record_enabled:
tx.exec_recorder.record_module_access(obj.value, name, member)

return VariableBuilder(tx, source)(member).add_guards(guards)
return VariableBuilder(tx, source)(member)
elif istype(obj, UserFunctionVariable) and name in ("__name__", "__module__"):
return ConstantVariable.create(
getattr(obj.fn, name), **VariableTracker.propagate(obj)

@ -6,7 +6,7 @@ from torch._dynamo.source import GetItemSource

from .. import variables
from ..exc import unimplemented, UserError, UserErrorType
from ..guards import GuardBuilder
from ..guards import GuardBuilder, install_guard
from ..utils import np
from .base import typestr, VariableTracker

@ -41,21 +41,15 @@ class ConstantVariable(VariableTracker):
items = []
for i, x in enumerate(value):
item_source = GetItemSource(source, i) if source else None
guards = (
{item_source.make_guard(GuardBuilder.CONSTANT_MATCH)}
if item_source
else None
)
if item_source:
install_guard(item_source.make_guard(GuardBuilder.CONSTANT_MATCH))
items.append(
ConstantVariable.create(
x,
source=item_source,
guards=guards,
)
)
return variables.BaseListVariable.cls_for(type(value))(
items, regen_guards=True, **kwargs
)
return variables.BaseListVariable.cls_for(type(value))(items, **kwargs)

return ConstantVariable(value, **kwargs)

@ -9,7 +9,7 @@ from .. import variables
from ..bytecode_transformation import create_call_function, create_instruction
from ..device_interface import get_interface_for_device
from ..exc import unimplemented, Unsupported
from ..guards import GuardBuilder
from ..guards import GuardBuilder, install_guard
from ..source import AttrSource, GlobalStateSource
from .base import VariableTracker
from .functions import (
@ -161,7 +161,7 @@ class GenericContextWrappingVariable(ContextWrappingVariable):
class GradModeVariable(ContextWrappingVariable):
"""represents torch.{no_grad,enable_grad,set_grad_mode}()"""

_guards_singleton = {Guard(GlobalStateSource(), GuardBuilder.GRAD_MODE)}
_guards_singleton = Guard(GlobalStateSource(), GuardBuilder.GRAD_MODE)

@staticmethod
def create(tx, target_value, initialized=True, **kwargs):
@ -179,8 +179,8 @@ class GradModeVariable(ContextWrappingVariable):
super().__init__(
target_values=target_values, initial_values=initial_values, **kwargs
)
self.guards = self.guards | self._guards_singleton
self.initialized = initialized
install_guard(self._guards_singleton)

def enter(self, tx):
if not self.initialized:
@ -263,7 +263,7 @@ class InferenceModeVariable(ContextWrappingVariable):
class TorchFunctionDisableVariable(ContextWrappingVariable):
"""represents whether torch function overrides are enabled or not"""

_guards_singleton = {Guard(GlobalStateSource(), GuardBuilder.TORCH_FUNCTION_STATE)}
_guards_singleton = Guard(GlobalStateSource(), GuardBuilder.TORCH_FUNCTION_STATE)

@staticmethod
def create(tx, **kwargs):
@ -281,7 +281,7 @@ class TorchFunctionDisableVariable(ContextWrappingVariable):
super().__init__(
target_values=target_values, initial_values=initial_values, **kwargs
)
self.guards = self.guards | self._guards_singleton
install_guard(self._guards_singleton)

def enter(self, tx):
return variables.ConstantVariable.create(
@ -296,9 +296,9 @@ class TorchFunctionDisableVariable(ContextWrappingVariable):
class DeterministicAlgorithmsVariable(ContextWrappingVariable):
"""represents torch.{are_deterministic_algorithms_enabled,use_deterministic_algorithms}()"""

_guards_singleton = {
Guard(GlobalStateSource(), GuardBuilder.DETERMINISTIC_ALGORITHMS)
}
_guards_singleton = Guard(
GlobalStateSource(), GuardBuilder.DETERMINISTIC_ALGORITHMS
)

@staticmethod
def create(tx, target_value, **kwargs):
@ -315,7 +315,7 @@ class DeterministicAlgorithmsVariable(ContextWrappingVariable):
super().__init__(
target_values=target_values, initial_values=initial_values, **kwargs
)
self.guards = self.guards | self._guards_singleton
install_guard(self._guards_singleton)

def enter(self, tx):
return variables.ConstantVariable.create(

@ -13,7 +13,7 @@ from ..bytecode_transformation import create_call_function, create_instruction
from ..eval_frame import skip_code

from ..exc import unimplemented
from ..guards import GuardBuilder, make_dupe_guard
from ..guards import GuardBuilder, install_guard, make_dupe_guard
from ..source import AttrSource, GetItemSource, GlobalWeakRefSource
from ..utils import global_key_name, istensor, iter_contains
from .base import MutableLocal, VariableTracker
@ -24,10 +24,8 @@ from .tensor import TensorVariable
class ConstDictVariable(VariableTracker):
def __init__(self, items, user_cls, **kwargs):
super().__init__(**kwargs)

# All the keys are constants
assert not any(isinstance(x, VariableTracker) for x in items)
self.guards.update(VariableTracker.propagate(items.values())["guards"])
self.items = items
self.user_cls = user_cls

@ -298,7 +296,6 @@ class SetVariable(VariableTracker):
def __init__(
self,
items: List[VariableTracker],
regen_guards=True,
**kwargs,
):
super().__init__(**kwargs)
@ -309,10 +306,6 @@ class SetVariable(VariableTracker):
self.items = []
self._add(items)

# Sometimes, we know that we have passed in the guards from the items in the set
if regen_guards:
self.guards.update(VariableTracker.propagate(items)["guards"])

def as_proxy(self):
return [x.as_proxy() for x in self.items]

@ -378,9 +371,7 @@ class SetVariable(VariableTracker):
e.vt.source, set_element.vt.source
)
if alias_guard:
e.vt = e.vt.add_guards(
{e.vt.source.make_guard(alias_guard)}
)
install_guard(e.vt.source.make_guard(alias_guard))

return self.items

@ -401,7 +392,6 @@ class SetVariable(VariableTracker):
result = SetVariable(
self._add(item),
mutable_local=self.mutable_local,
regen_guards=False,
**options,
)
tx.replace_all(self, result)
@ -413,7 +403,7 @@ class SetVariable(VariableTracker):
result = items.pop()
tx.replace_all(
self,
SetVariable(items, regen_guards=False, **options),
SetVariable(items, **options),
)
return result
elif name == "__len__":
@ -797,46 +787,40 @@ class PythonSysModulesVariable(VariableTracker):
def _contains_helper(self, tx, key: VariableTracker):
k = ConstDictVariable.get_key(key)
has_key = k in sys.modules
guard = self.make_guard(
install_guard(
self.make_guard(
functools.partial(GuardBuilder.DICT_CONTAINS, key=k, invert=not has_key)
)
guards = {*self.guards, guard}
return k, has_key, guards
)
return k, has_key

def call_contains(self, tx, key: VariableTracker):
k, has_key, guards = self._contains_helper(tx, key)
return ConstantVariable.create(
value=has_key,
guards=guards,
)
k, has_key = self._contains_helper(tx, key)
return ConstantVariable.create(value=has_key)

def call_get(
self, tx, key: VariableTracker, default: Optional[VariableTracker] = None
):
from .builder import VariableBuilder

k, has_key, guards = self._contains_helper(tx, key)
k, has_key = self._contains_helper(tx, key)

if has_key:
return VariableBuilder(
tx,
GetItemSource(self.source, k),
)(
sys.modules[k]
).add_guards(guards)
)(sys.modules[k])

if default is not None:
return default.add_guards(guards)
return default

return ConstantVariable.create(value=None, guards=guards)
return ConstantVariable.create(value=None)

def call_getitem(self, tx, key: VariableTracker):
from .builder import VariableBuilder

k, has_key, guards = self._contains_helper(tx, key)
k, has_key = self._contains_helper(tx, key)
return VariableBuilder(
tx,
GetItemSource(self.source, k),
)(
sys.modules[k]
).add_guards(guards)
)(sys.modules[k])

@ -614,12 +614,6 @@ class FunctoolsPartialVariable(VariableTracker):
self.keywords = keywords
self.original = original

self.guards.update(VariableTracker.propagate(func)["guards"])
for arg in args:
self.guards.update(VariableTracker.propagate(arg)["guards"])
for val in keywords.values():
self.guards.update(VariableTracker.propagate(val)["guards"])

def call_function(
self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
) -> "VariableTracker":

@ -25,7 +25,6 @@ from ..exc import (
UserError,
UserErrorType,
)
from ..guards import GuardBuilder
from ..source import FSDPNNModuleSource, GetItemSource, NNModuleSource
from ..utils import proxy_args_kwargs
from .dicts import ConstDictVariable
@ -100,9 +99,6 @@ def validate_args_and_maybe_create_graph_inputs(
assert isinstance(a, VariableTracker)

if isinstance(a, ConstantVariable):
# Ensures that we recompile when the constant value changes
a.add_guard(GuardBuilder.CONSTANT_MATCH)

if manually_set_subgraph_inputs:
# This arg is not used in the body of the higher order op.
# Currently, this new input is added to make the calls
@ -194,6 +190,11 @@ def speculate_subgraph(
)

try:
f, sub_args, sub_kwargs = VariableTracker.apply(
# ensure guards on args get installed in parent subgraph
lambda x: x.realize(),
(f, sub_args, sub_kwargs),
)
with tx.output.subtracer(source_target, tracer) as subtracer:
args = validate_args_and_maybe_create_graph_inputs(
sub_args, subtracer, tx, manually_set_subgraph_inputs
@ -247,7 +248,6 @@ def speculate_subgraph(
"HigherOrderOperator body's output must consist of tensors only"
)

tx.output.guards.update(output.guards)
# The output proxies might not belong to this SubgraphTracer
# (if they are free variables that were never lifted)
# so lift them here.
@ -411,7 +411,6 @@ class CondHigherOrderVariable(TorchHigherOrderOperatorVariable):
f"item but got {str(type(args[0]))} "
f"with original python type {str(args[0].python_type())}.",
)
tx.output.guards.update(args[0].guards)

# operands
if not isinstance(args[3], (ListVariable, TupleVariable)):
@ -1116,6 +1115,7 @@ class AutogradFunctionMethodHigherOrderVariable(TorchHigherOrderOperatorVariable
else:
fn = TorchVariable(self.value)
checkpoint = tx.copy_graphstate()
# TODO(jansel): BUG!!! we aren't copying on the line below, so the post-pre check below is pointless
pre_guards = tx.output.guards
graph_checkpoint = tx.output.graph

@ -23,7 +23,6 @@ class LazyCache:
self.vt.parents_tracker.add(parents_tracker)
del self.value
del self.source
tx.output.guards.update(self.vt.guards)


class LazyVariableTracker(VariableTracker):
@ -79,8 +78,6 @@ class LazyVariableTracker(VariableTracker):
return getattr(self.realize(), item)

# most methods are auto-generated below, these are the ones we want to exclude
add_guards = VariableTracker.add_guards
add_guard = VariableTracker.add_guard
add_options = VariableTracker.add_options
apply = VariableTracker.apply
copy = VariableTracker.copy

@ -48,16 +48,11 @@ class BaseListVariable(VariableTracker):
def __init__(
self,
items: List[VariableTracker],
regen_guards=True,
**kwargs,
):
super().__init__(**kwargs)
assert isinstance(items, list)
assert all(isinstance(x, VariableTracker) for x in items)
# Sometimes, we know that we have passed in the guards from the items in the list
if regen_guards:
self.guards.update(VariableTracker.propagate(items)["guards"])

self.items: List[VariableTracker] = items

def _as_proxy(self):
@ -246,7 +241,6 @@ class CommonListMethodsVariable(BaseListVariable):
self,
type(self)(
self.items + [arg],
regen_guards=False,
**options,
),
)
@ -263,7 +257,6 @@ class CommonListMethodsVariable(BaseListVariable):
self,
type(self)(
list(self.items) + list(arg.unpack_var_sequence(tx)),
regen_guards=False,
**options,
),
)
@ -274,7 +267,7 @@ class CommonListMethodsVariable(BaseListVariable):
items.insert(idx.as_python_constant(), value)
return tx.replace_all(
self,
type(self)(items, regen_guards=False, **options),
type(self)(items, **options),
)
elif name == "pop" and self.mutable_local:
assert not kwargs
@ -282,14 +275,14 @@ class CommonListMethodsVariable(BaseListVariable):
result = items.pop(*[a.as_python_constant() for a in args])
tx.replace_all(
self,
type(self)(items, regen_guards=False, **options),
type(self)(items, **options),
)
return result
elif name == "clear" and self.mutable_local:
assert not kwargs and not args
return tx.replace_all(
self,
type(self)([], regen_guards=False, **options),
type(self)([], **options),
)
elif (
name == "__setitem__"
@ -304,16 +297,14 @@ class CommonListMethodsVariable(BaseListVariable):
items[key.as_python_constant()] = list(value.items)
else:
items[key.as_python_constant()] = value
result = ListVariable(items, regen_guards=False, **options)
result = ListVariable(items, **options)
return tx.replace_all(self, result)
elif name == "copy":
# List copy() doesn't have args and kwargs
assert not kwargs
assert not args
items = list(self.items)
return type(self)(
items, regen_guards=False, mutable_local=MutableLocal(), **options
)
return type(self)(items, mutable_local=MutableLocal(), **options)
else:
return super().call_method(tx, name, args, kwargs)

@ -351,7 +342,7 @@ class ListVariable(CommonListMethodsVariable):
items[key.as_python_constant()] = value.unpack_var_sequence(tx)
else:
items[key.as_python_constant()] = value
result = ListVariable(items, regen_guards=False, **options)
result = ListVariable(items, **options)
return tx.replace_all(self, result)
else:
return super().call_method(tx, name, args, kwargs)
@ -396,7 +387,7 @@ class DequeVariable(CommonListMethodsVariable):
)
items = list(self.items)
items[key.as_python_constant()] = value
result = DequeVariable(items, regen_guards=False, **options)
result = DequeVariable(items, **options)
return tx.replace_all(self, result)
elif name == "extendleft" and self.mutable_local:
assert not kwargs
@ -405,7 +396,6 @@ class DequeVariable(CommonListMethodsVariable):
self,
DequeVariable(
list(arg.unpack_var_sequence(tx)) + list(self.items),
regen_guards=False,
**options,
),
)
@ -416,7 +406,7 @@ class DequeVariable(CommonListMethodsVariable):
result = items.popleft()
tx.replace_all(
self,
DequeVariable(list(items), regen_guards=False, **options),
DequeVariable(list(items), **options),
)
return result
elif name == "appendleft" and self.mutable_local:
@ -425,7 +415,6 @@ class DequeVariable(CommonListMethodsVariable):
self,
DequeVariable(
[args[0]] + list(self.items),
regen_guards=False,
**options,
),
)

@ -13,7 +13,7 @@ import torch._numpy as tnp
from .. import config, polyfill, variables
from ..bytecode_transformation import create_call_function, create_instruction
from ..exc import unimplemented
from ..guards import GuardBuilder
from ..guards import GuardBuilder, install_guard
from ..source import AttrSource, GetItemSource, ODictGetItemSource, TypeSource
from ..utils import (
check_constant_args,
@ -97,9 +97,8 @@ class SuperVariable(VariableTracker):
return GetAttrVariable(self, name, **options)
if source:
options["source"] = source
return variables.ConstantVariable.create(value, **options).add_guard(
source.make_guard(GuardBuilder.CONSTANT_MATCH)
)
install_guard(source.make_guard(GuardBuilder.CONSTANT_MATCH))
return variables.ConstantVariable.create(value, **options)
return variables.ConstantVariable.create(value, **options)

def call_method(

@ -10,7 +10,7 @@ import torch.nn
from .. import skipfiles, variables
from ..allowed_functions import is_allowed
from ..exc import unimplemented, UnspecializeRestartAnalysis, Unsupported
from ..guards import GuardBuilder
from ..guards import GuardBuilder, install_guard
from ..mutation_guard import GenerationTracker
from ..source import (
AttrSource,
@ -127,11 +127,12 @@ class NNModuleVariable(VariableTracker):
options = VariableTracker.propagate(self)
mod = tx.output.get_submodule(self.module_key)
result = hasattr(mod, name)
return variables.ConstantVariable.create(result, **options).add_guard(
install_guard(
NNModuleSource(AttrSource(self.source, name)).make_guard(
GuardBuilder.HASATTR
)
)
return variables.ConstantVariable.create(result, **options)

def is_training(self, tx):
mod = tx.output.get_submodule(self.module_key)
@ -167,7 +168,6 @@ class NNModuleVariable(VariableTracker):
from .builder import VariableBuilder

options = VariableTracker.propagate(self)
guards = options.get("guards", set())

if self.source:
source = AttrSource(self.source, name)
@ -220,13 +220,12 @@ class NNModuleVariable(VariableTracker):
if istype(subobj, property):
return variables.UserFunctionVariable(
subobj.fget,
guards=guards,
source=source,
).call_function(tx, [(self)], {})
elif istype(subobj, classmethod):
return variables.UserMethodVariable(
subobj.__func__,
variables.UserDefinedObjectVariable(type(base), guards=guards),
variables.UserDefinedObjectVariable(type(base)),
**options,
)
elif istype(subobj, staticmethod):
@ -616,7 +615,7 @@ class NNModuleVariable(VariableTracker):
):
# Inline the function
fn = getattr(module, name).__func__
fn_source = AttrSource(self.source, "__func__")
fn_source = AttrSource(AttrSource(self.source, name), "__func__")
options["source"] = fn_source
return tx.inline_user_function_return(
variables.UserFunctionVariable(fn, **options),
@ -759,7 +758,7 @@ class UnspecializedNNModuleVariable(UserDefinedObjectVariable):
assert not args or kwargs
if tx.output.side_effects.has_pending_mutation(self):
unimplemented("Module.parameters() with pending mutation")
options["guards"].add(
install_guard(
self.source.make_guard(GuardBuilder.NN_MODULE_PARAM_NAMES)
)
items = []

@ -4,7 +4,7 @@ from typing import Dict, List
import torch
from ..decorators import mark_static_address

from ..guards import GuardBuilder
from ..guards import GuardBuilder, install_guard
from ..source import AttrSource, GetItemSource, GlobalWeakRefSource
from ..utils import global_key_name

@ -126,13 +126,12 @@ class OptimizerVariable(UserDefinedObjectVariable):

# state guards take a long time to generate
# so we manually generate them here
guards = set()
state_source = AttrSource(self.source, "state")
guards.add(state_source.make_guard(GuardBuilder.DICT_KEYS))
install_guard(state_source.make_guard(GuardBuilder.DICT_KEYS))
for p, value in self.value.state.items():
tx.store_global_weakref(global_key_name(p), p)
p_state_source = GetItemSource(state_source, self.tensor_to_source[p])
guards.add(p_state_source.make_guard(GuardBuilder.DICT_KEYS))
install_guard(p_state_source.make_guard(GuardBuilder.DICT_KEYS))
for k, v in value.items():
if (
isinstance(v, torch.Tensor)
@ -141,7 +140,7 @@ class OptimizerVariable(UserDefinedObjectVariable):
):
self.tensor_to_source[v] = GetItemSource(p_state_source, k)
elif v is None or isinstance(v, (bool, int, float, str)):
guards.add(
install_guard(
GetItemSource(p_state_source, k).make_guard(
GuardBuilder.CONSTANT_MATCH
)
@ -149,12 +148,10 @@ class OptimizerVariable(UserDefinedObjectVariable):
else:
raise GuardInstallException()

tx.output.guards.update(guards)

group_guards = VariableBuilder(tx, AttrSource(self.source, "param_groups"))(
# this next line has the side effect of installing guards
VariableBuilder(tx, AttrSource(self.source, "param_groups"))(
self.value.param_groups
)
tx.output.guards.update(group_guards.guards)
).recursive_realize()

def wrap_tensor(self, tx, tensor_value):
"""Wrap state tensor in a TensorVariable"""

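The optimizer hunk above (and the numpy-ndarray hunk in the builder earlier) calls VariableBuilder(...)(...).recursive_realize() purely for its side effect: building the variable now installs its guards eagerly, so the returned tracker itself is discarded. A minimal sketch of that idea, with hypothetical stand-ins rather than the real VariableBuilder / LazyVariableTracker:

# Hypothetical stand-ins; the real VariableBuilder / LazyVariableTracker differ.
INSTALLED = []

def install_guard(guard):
    INSTALLED.append(guard)

class LazyVar:
    def __init__(self, value, source):
        self._value, self._source = value, source
        self._realized = None

    def realize(self):
        # Building the concrete variable is what installs the guard, so a
        # caller may invoke realize() purely for that side effect.
        if self._realized is None:
            install_guard(f"DICT_KEYS({self._source})")
            self._realized = self._value
        return self._realized

LazyVar([{"lr": 0.1}], "L['optimizer'].param_groups").realize()
print(INSTALLED)  # the guard was recorded even though the tracker was discarded
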
@ -1,4 +1,5 @@
import functools

import inspect
import operator
import types
@ -31,7 +32,7 @@ from .. import config, variables
from .._trace_wrapped_higher_order_op import trace_wrapped

from ..exc import unimplemented, UserError, UserErrorType
from ..guards import GuardBuilder
from ..guards import GuardBuilder, install_guard
from ..source import AttrSource
from ..utils import (
fqn,
@ -206,12 +207,8 @@ class TensorVariable(VariableTracker):
from .builder import VariableBuilder

attr_source = AttrSource(self.source, name)
has_attr_guard = attr_source.make_guard(GuardBuilder.HASATTR)
return (
VariableBuilder(tx, attr_source)(real_value)
.add_options(self)
.add_guard(has_attr_guard)
)
install_guard(attr_source.make_guard(GuardBuilder.HASATTR))
return VariableBuilder(tx, attr_source)(real_value).add_options(self)

def var_getattr(self, tx, name):
from . import ConstantVariable, TorchVariable
@ -254,7 +251,7 @@ class TensorVariable(VariableTracker):
# In some cases, a <tensor>.<attr> guard can be evaluated first, and break if
# <tensor> is later changed to another type
if result is not None and self.source is not None:
result = result.add_guard(self.make_guard(GuardBuilder.TYPE_MATCH))
install_guard(self.make_guard(GuardBuilder.TYPE_MATCH))

# It's hard to get inplace view (metadata mutation) on graph input work properly across
# dynamo/aot/inductor, just fall back.
@ -607,7 +604,6 @@ class TensorVariable(VariableTracker):
unimplemented(
"boolean masking setitem backwards requires dynamic shapes"
)
tx.output.guards.update(options["guards"])
tx.output.create_proxy(
"call_function",
operator.setitem,

@ -8,6 +8,7 @@ import types
from typing import Dict, List

from torch._streambase import _StreamBase
from ..guards import install_guard

try:
import numpy as np
@ -159,10 +160,10 @@ class TorchCtxManagerClassVariable(VariableTracker):

@classmethod
def create_with_source(cls, value, source):
install_guard(source.make_guard(GuardBuilder.FUNCTION_MATCH))
return TorchCtxManagerClassVariable(
value,
source=source,
guards={source.make_guard(GuardBuilder.FUNCTION_MATCH)},
)

def __init__(self, value, **kwargs):
@ -259,6 +260,9 @@ class TorchVariable(VariableTracker):
except RuntimeError as e:
assert "No such operator" in str(e), str(e)
self_should_be_none = None
except AssertionError as e:
assert "Unknown attribute" in str(e), str(e)
self_should_be_none = None

# assert "_ntuple.<locals>.parse" not in str(value)

@ -425,18 +429,18 @@ class TorchVariable(VariableTracker):
return self._call_ntuple(tx, args, kwargs, options)
elif self.value is torch.is_grad_enabled:
assert not (args or kwargs)
return ConstantVariable.create(
torch.is_grad_enabled(), **options
).add_guards(GradModeVariable._guards_singleton)
install_guard(GradModeVariable._guards_singleton)
return ConstantVariable.create(torch.is_grad_enabled(), **options)
elif self.value is torch.use_deterministic_algorithms and len(args) == 1:
return DeterministicAlgorithmsVariable.create(
tx, args[0].as_python_constant(), **options
)
elif self.value is torch.are_deterministic_algorithms_enabled:
assert not (args or kwargs)
install_guard(DeterministicAlgorithmsVariable._guards_singleton)
return ConstantVariable.create(
torch.are_deterministic_algorithms_enabled(), **options
).add_guards(DeterministicAlgorithmsVariable._guards_singleton)
)
elif self.value is torch.autograd.graph.disable_saved_tensors_hooks:
assert len(args) == 1
return DisabledSavedTensorsHooksVariable.create(
@ -444,9 +448,8 @@ class TorchVariable(VariableTracker):
)
elif self.value is torch._C._is_torch_function_enabled:
assert not (args or kwargs)
return ConstantVariable.create(
tx.output.torch_function_enabled, **options
).add_guards(TorchFunctionDisableVariable._guards_singleton)
install_guard(TorchFunctionDisableVariable._guards_singleton)
return ConstantVariable.create(tx.output.torch_function_enabled, **options)
elif self.value in (
torch.overrides.has_torch_function_variadic,
torch.overrides.has_torch_function_unary,

@ -5,6 +5,7 @@ import torch.utils._pytree as pytree

from torch.overrides import _get_overloaded_args, get_default_nowrap_functions
from ..exc import unimplemented
from ..guards import GuardBuilder, install_guard
from ..source import AttrSource, GlobalSource
from ..utils import is_tensor_base_attr_getter
from .base import VariableTracker
@ -133,14 +134,15 @@ class TensorWithTFOverrideVariable(TensorVariable):
kwargs.pop("class_type") is torch.Tensor
), "invalid class type in TensorWithTFOverrideVariable.from_tensor_var"
var = cls(torch_function_fn=torch_function_fn, class_type=class_type, **kwargs)
var.install_global(tx)
return var

def install_global(self, tx):
# stash the subclass type to rewrap an output tensor if needed
# this is needed because the actual type needs to be available
# each time the compiled artifact is run and outputs a wrapped tensor.
if var.global_mangled_class_name() not in tx.output.global_scope:
tx.output.install_global(var.global_mangled_class_name(), class_type)

return var
if self.global_mangled_class_name() not in tx.output.global_scope:
tx.output.install_global(self.global_mangled_class_name(), self.class_type)

def python_type(self):
return self.class_type
@ -157,7 +159,7 @@ class TensorWithTFOverrideVariable(TensorVariable):
# [Note: __torch_function__] We currently only support attributes that are defined on
# base tensors, custom attribute accesses will graph break.
import torch
from .builder import SourcelessBuilder, VariableBuilder
from .builder import SourcelessBuilder

if name in banned_attrs or not hasattr(torch.Tensor, name):
unimplemented(
@ -172,14 +174,11 @@ class TensorWithTFOverrideVariable(TensorVariable):

if tx.output.torch_function_enabled:
if self.source:
get_fn = VariableBuilder(
tx,
source=AttrSource(
AttrSource(AttrSource(self.source, "__class__"), name),
"__get__",
),
)(inspect.getattr_static(self.python_type(), name).__get__)
else:
install_guard(
AttrSource(AttrSource(self.source, "__class__"), name).make_guard(
GuardBuilder.FUNCTION_MATCH
)
)
get_fn = SourcelessBuilder()(tx, getattr(torch.Tensor, name).__get__)

return self.call_torch_function(

@ -17,7 +17,7 @@ from torch._guards import TracingContext
from .. import variables
from ..allowed_functions import is_allowed
from ..exc import unimplemented
from ..guards import GuardBuilder
from ..guards import GuardBuilder, install_guard
from ..source import AttrSource, ODictGetItemSource, RandomValueSource
from ..utils import (
all_hook_names,
@ -266,9 +266,10 @@ class UserDefinedObjectVariable(UserDefinedVariable):
assert not (args or kwargs)
keys = list(self.value.keys())
assert all(map(ConstantVariable.is_literal, keys))
install_guard(self.source.make_guard(GuardBuilder.ODICT_KEYS))
return TupleVariable(
[ConstantVariable.create(k, **options) for k in keys], **options
).add_guard(self.source.make_guard(GuardBuilder.ODICT_KEYS))
)

if (
method in (collections.OrderedDict.__contains__, dict.__contains__)
@ -278,9 +279,10 @@ class UserDefinedObjectVariable(UserDefinedVariable):
in (collections.OrderedDict.keys, dict.keys)
):
assert not kwargs
install_guard(self.source.make_guard(GuardBuilder.ODICT_KEYS))
return ConstantVariable.create(
args[0].as_python_constant() in self.value, **options
).add_guard(self.source.make_guard(GuardBuilder.ODICT_KEYS))
)

if (
method is collections.OrderedDict.items
@ -376,20 +378,15 @@ class UserDefinedObjectVariable(UserDefinedVariable):
)
):
options = VariableTracker.propagate(self, args, kwargs.values())
options.setdefault("guards", set())
if self.source:
options["guards"].add(
AttrSource(self.source, "func").make_guard(GuardBuilder.ID_MATCH)
)
options["guards"].add(
install_guard(
AttrSource(self.source, "func").make_guard(GuardBuilder.ID_MATCH),
AttrSource(self.source, "args").make_guard(
GuardBuilder.CONSTANT_MATCH
)
)
options["guards"].add(
),
AttrSource(self.source, "keywords").make_guard(
GuardBuilder.CONSTANT_MATCH
)
),
)

partial_args = [
@ -410,7 +407,7 @@ class UserDefinedObjectVariable(UserDefinedVariable):
tx, partial_args, partial_kwargs
)
elif callable(self.value):
self.add_guard(self.source.make_guard(GuardBuilder.FUNCTION_MATCH))
install_guard(self.source.make_guard(GuardBuilder.FUNCTION_MATCH))
return self.call_method(tx, "__call__", args, kwargs)

return super().call_function(tx, args, kwargs)
@ -578,7 +575,7 @@ class UserDefinedObjectVariable(UserDefinedVariable):
pass
options = VariableTracker.propagate(self)
if self.source:
options["guards"].add(
install_guard(
AttrSource(self.source, name).make_guard(GuardBuilder.HASATTR)
)
if self._check_for_getattribute() or self._check_for_getattr():

@ -241,7 +241,13 @@ class Guard:
return output

def create(self, builder: GuardBuilderBase):
try:
return self.create_fn(builder, self)
except Exception:
log.error("Error while creating guard:\n%s", str(self).rstrip())
if self.stack:
log.error("Created at:\n%s", "".join(self.stack.format()[-4:]).rstrip())
raise

def is_nn_module(self):
return self.source.is_nn_module()

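The final hunk wraps Guard.create in a try/except so that a failing guard logs which guard it was and where it was created before re-raising. A generic version of that logging-on-failure pattern is sketched below; the helper name and arguments are hypothetical and not part of the dynamo code.

import logging

log = logging.getLogger(__name__)

def create_with_context(create_fn, description, stack_frames=None):
    # Hypothetical helper mirroring the pattern above: report what failed and
    # where it came from, then let the original exception propagate unchanged.
    try:
        return create_fn()
    except Exception:
        log.error("Error while creating guard:\n%s", description.rstrip())
        if stack_frames:
            log.error("Created at:\n%s", "".join(stack_frames[-4:]).rstrip())
        raise
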