[dynamo] Eagerly install guards (#111415)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/111415
Approved by: https://github.com/voznesenskym
ghstack dependencies: #111306
This commit is contained in:
Jason Ansel
2023-11-07 08:12:57 -08:00
committed by PyTorch MergeBot
parent 2964682490
commit 9664190952
30 changed files with 333 additions and 622 deletions

View File

@ -3292,7 +3292,9 @@ class GraphModule(torch.nn.Module):
cos = l_x_.cos(); l_x_ = None
return pytree.tree_unflatten([cos], self._out_spec)
"""
true_guard_code = ["cast_symbool_to_symint_guardless(L['pred']) == 1"]
true_guard_code = [
"cast_symbool_to_symint_guardless(L['pred']) == 1",
]
false_guard_code = [
"Ne(cast_symbool_to_symint_guardless(L['pred']), 1)",
"-9223372036854775808 <= cast_symbool_to_symint_guardless(L['pred'])",

View File

@ -297,25 +297,25 @@ class HigherOrderOpTests(torch._dynamo.test_case.TestCase):
actual_graph,
"""\
class GraphModule(torch.nn.Module):
def forward(self, L_x_ : torch.Tensor, L_y_ : torch.Tensor, L_z_ : torch.Tensor):
l_x_ = L_x_
l_y_ = L_y_
l_z_ = L_z_
def forward(self, L_d_x_ : torch.Tensor, L_d_y_0_ : torch.Tensor, L_d_y_1_2_ : torch.Tensor):
l_d_x_ = L_d_x_
l_d_y_0_ = L_d_y_0_
l_d_y_1_2_ = L_d_y_1_2_
wrap_body_0 = self.wrap_body_0
wrap = torch._higher_order_ops.wrap.wrap(wrap_body_0, l_x_, l_y_, l_z_); wrap_body_0 = l_x_ = l_y_ = l_z_ = None
wrap = torch._higher_order_ops.wrap.wrap(wrap_body_0, l_d_x_, l_d_y_0_, l_d_y_1_2_); wrap_body_0 = l_d_x_ = l_d_y_0_ = l_d_y_1_2_ = None
getitem = wrap[0]; wrap = None
return (getitem,)
class GraphModule(torch.nn.Module):
def forward(self, l_x_, l_y_, l_z_):
sin = l_x_.sin(); l_x_ = None
cos = l_y_.cos(); l_y_ = None
def forward(self, l_d_x_, l_d_y_0_, l_d_y_1_2_):
sin = l_d_x_.sin(); l_d_x_ = None
cos = l_d_y_0_.cos(); l_d_y_0_ = None
add = sin + cos; sin = cos = None
sin_1 = l_z_.sin(); l_z_ = None
sin_1 = l_d_y_1_2_.sin(); l_d_y_1_2_ = None
sub = add - sin_1; add = sin_1 = None
return (sub,)
""",
""", # NOQA: B950
)
def test_wrap_pytree_args_with_symint_constant(self):
@ -3005,9 +3005,9 @@ class GraphModule(torch.nn.Module):
actual,
"""\
class GraphModule(torch.nn.Module):
def forward(self, L_x_ : torch.Tensor, L_y_ : torch.Tensor):
l_x_ = L_x_
def forward(self, L_y_ : torch.Tensor, L_x_ : torch.Tensor):
child = L_y_
l_x_ = L_x_
_check_randomness_arg = torch._functorch.vmap._check_randomness_arg('error')
_check_randomness_arg_1 = torch._functorch.vmap._check_randomness_arg('error')
@ -3269,16 +3269,14 @@ class GraphModule(torch.nn.Module):
return torch.func.vmap(torch.sum, in_dims)(x)
x = torch.randn(3, 3, 3, 3)
opt = torch.compile(wrapper_fn, backend="eager", fullgraph=False, dynamic=True)
cnt = CompileCounter()
opt = torch.compile(wrapper_fn, backend=cnt, fullgraph=False, dynamic=True)
expected = wrapper_fn(x, 0), wrapper_fn(x, 1), wrapper_fn(x, 2)
# Third invocation of `opt` makes `in_dims` as SymInt.
actual = opt(x, 0), opt(x, 1), opt(x, 2)
self.assertEqual(expected, actual)
self.assertEqual(len(counters["graph_break"]), 1)
self.assertEqual(
dict(counters["graph_break"]),
{"torch.func.vmap: in_dims is not an int or tuple variable.": 2},
)
self.assertEqual(cnt.frame_count, 3)
self.assertEqual(cnt.op_count, 9)
def test_vmap_multiple_invocation_out_dims(self):
counters.clear()
@ -3287,16 +3285,14 @@ class GraphModule(torch.nn.Module):
return torch.func.vmap(lambda x: torch.sum(x, 0), out_dims=out_dims)(x)
x = torch.randn(3, 3, 3, 3)
opt = torch.compile(wrapper_fn, backend="eager", fullgraph=False, dynamic=True)
cnt = CompileCounter()
opt = torch.compile(wrapper_fn, backend=cnt, fullgraph=False, dynamic=True)
expected = wrapper_fn(x, 0), wrapper_fn(x, 1), wrapper_fn(x, 2)
# Third invocation of `opt` makes `in_dims` as SymInt.
actual = opt(x, 0), opt(x, 1), opt(x, 2)
self.assertEqual(expected, actual)
self.assertEqual(len(counters["graph_break"]), 1)
self.assertEqual(
dict(counters["graph_break"]),
{"torch.func.vmap: out_dims is not an int or tuple variable.": 2},
)
self.assertEqual(cnt.frame_count, 3)
self.assertEqual(cnt.op_count, 9)
def test_vmap_new_tensor_in_body(self):
def fn(x):

View File

@ -1541,7 +1541,7 @@ utils_device.CURRENT_DEVICE == None""",
args = [torch.randn(10), 4096, np.int64(8)]
correct = fn(*args)
cnts = torch._dynamo.testing.CompileCounter()
opt_fn = torch._dynamo.optimize(cnts, dynamic=True)(fn)
opt_fn = torch._dynamo.optimize(cnts, dynamic=True, nopython=True)(fn)
self.assertTrue(same(opt_fn(*args), correct))
self.assertTrue(same(opt_fn(*args), correct))
self.assertEqual(cnts.frame_count, 1)

View File

@ -814,9 +814,8 @@ class MockModule(torch.nn.Module):
class ReproTests(torch._dynamo.test_case.TestCase):
def test_do_paste_mask(self):
torch._dynamo.utils.counters.clear()
opt__do_paste_mask = torch._dynamo.optimize(
torch._dynamo.testing.CompileCounter()
)(_do_paste_mask)
cnt = torch._dynamo.testing.CompileCounter()
opt__do_paste_mask = torch.compile(_do_paste_mask, backend=cnt)
opt__do_paste_mask(
torch.randn(1, 1, 28, 28),
torch.tensor([[0.0, 1, 2, 4]]) * 1,
@ -852,12 +851,9 @@ class ReproTests(torch._dynamo.test_case.TestCase):
640,
False,
)
self.assertGreaterEqual(torch._dynamo.utils.counters["frames"]["ok"], 3)
self.assertEqual(
torch._dynamo.utils.counters["frames"]["total"],
torch._dynamo.utils.counters["frames"]["ok"] + 1,
)
# (dynamic shapes, static shapes)
self.assertIn(cnt.frame_count, (5, 7))
self.assertIn(cnt.op_count, (106, 127))
def test_convert_boxes_to_pooler_format(self):
boxes1 = [
@ -2451,77 +2447,6 @@ class ReproTests(torch._dynamo.test_case.TestCase):
self.assertEqual(f(x, x), opt_f(x, x))
self.assertEqual(f(x, y), opt_f(x, y))
def test_reformer_remove_unused_args(self):
# This test case is very interesting. First, let's describe
# the bug this is testing for. The bug we fixed is twofold:
#
# - We prune GraphArgs that aren't used in the output graph.
# However, sometimes it is possible for those GraphArgs to be
# utilized in shape guards (you could imagine this happening if
# dynamo poked some shape variables without recording them in the
# graph.) If we prune those GraphArgs, we get a
# "s1 not in ..." error as we can no longer codegen the
# requested guards.
#
# - But in practice, Dynamo usually traces size accesses into the
# graph, preventing the GraphArg from getting pruned. So how
# come we were running into this in practice with hf_Reformer?
# The answer is checkpointing!
#
# This brings us to the following test case. Here's what it does:
#
# 1. It traces some operations, and then checkpoints before inlining
# the function call to g
#
# 2. g traces some more operations (triggering the shape guard
# to be created), but then it graph breaks
#
# 3. Because you can't graph break in an inlining function, we roll
# back to the outer checkpoint ("undoing" the operation that
# induced the shape guard) and then immediately generate a
# subgraph at that point.
#
# If we failed to checkpoint the ShapeEnv, it can still have guards
# from the aborted speculation, which we will then still attempt to
# codegen.
#
# There's an additional nuance: suppose x is used but y is not.
# If you create a guard like y == x * 2, you will accidentally avoid
# the "s1 not in ..." error, as y will get substituted with x * 2,
# but x is still a GraphArg (it's used) and you don't end up with
# the error. This is why we must show y + y == x, not vice versa.
# Similarly, it is also why we must not do a simple guard like x == y
#
# Can we actually demonstrate that checkpointing the ShapeEnv is
# necessary? It's not so easy to induce this case. Dynamo is very
# eager about adding locals to GraphArgs; any local that is in scope,
# even if it isn't used, is added to GraphArgs (see also
# https://github.com/pytorch/torchdynamo/issues/1925 ). So long
# as Dynamo eagerly guards in this way, we have an invariant that
# all locals are guaranteed to show up in GraphArgs before the
# inlining function call, in which case we will always have enough
# information to codegen our guards so long as we don't prune the
# unused GraphArgs away (and indeed, the direct fix for this bug
# was to make sure we use original GraphArgs). Non locals,
# conversely, typically are static, and so won't have guards allocated
# for them. That being said, there may still be a way to trigger
# this error.
def g(x, y):
r = torch.cat((y, y)) + x
print("foo")
return r
def f(x, y):
x = x * 3
return g(x, y)
opt_f = torch._dynamo.optimize("aot_eager")(f)
x = torch.randn(4)
y = torch.randn(2)
self.assertEqual(f(x, y), opt_f(x, y))
def test_swin_base_tensor_attr(self):
class Foo(torch.nn.Module):
def __init__(self):

View File

@ -271,7 +271,7 @@ class SubclassTests(torch._dynamo.test_case.TestCase):
kwargs = {}
return super().__torch_function__(func, types, args, kwargs)
@torch.compile(backend="eager", fullgraph=True)
@torch.compile(backend="eager")
def fn(x):
return x.sigmoid()
@ -819,13 +819,6 @@ class TestNestedTensor(torch._dynamo.test_case.TestCase):
nt3, _ = self._get_jagged_tensor(((2, 3, 4), 3), None)
self._check_recompiles(binary, (nt1, nt2), (nt1, nt3), True)
def test_binary_recompiles_due_to_duck_sizing(self):
# Even though the input is unused, we still guard due to duck sizing
nt1, offsets = self._get_jagged_tensor(((2, 3, 4), 3), None)
nt2, _ = self._get_jagged_tensor(((2, 3, 4), 3), offsets)
nt3, _ = self._get_jagged_tensor(((2, 3, 4), 3), None)
self._check_recompiles(lambda nt1, nt2: nt1.sin(), (nt1, nt2), (nt1, nt3), True)
# TODO: cannot parametrize this test class with device for some reason
def _test_autograd(self, backend):
a = torch.randn(2, 3, requires_grad=True, dtype=torch.float64)

View File

@ -477,8 +477,8 @@ class SubGraphTests(torch._dynamo.test_case.TestCase):
opt_fn(v1, a, b, c)
# checking here we don't create 2^n graphs
self.assertEqual(cnt.frame_count, 12)
self.assertEqual(cnt.op_count, 16)
self.assertEqual(cnt.frame_count, 7)
self.assertEqual(cnt.op_count, 10)
def test_resume_with_no_grad1(self):
def fn(a, b):

View File

@ -4,7 +4,7 @@ import hypothesis.strategies as st
from hypothesis import given
import numpy as np
import torch
from torch.testing._internal.common_utils import TestCase, run_tests
from torch.testing._internal.common_utils import TestCase, run_tests, skipIfTorchDynamo
import torch.testing._internal.hypothesis_utils as hu
hu.assert_deadline_disabled()
@ -56,6 +56,7 @@ class PruningOpTest(TestCase):
self.assertEqual(pt_compressed_indices_map.dtype, indices_type)
@skipIfTorchDynamo()
@given(
embedding_rows=st.integers(1, 100),
embedding_dims=st.integers(1, 100),
@ -67,6 +68,7 @@ class PruningOpTest(TestCase):
self._test_rowwise_prune_op(embedding_rows, embedding_dims, torch.int, weights_dtype)
@skipIfTorchDynamo()
@given(
embedding_rows=st.integers(1, 100),
embedding_dims=st.integers(1, 100),

View File

@ -2530,6 +2530,7 @@ class TestSparseCSR(TestCase):
run_test(4, 5, 4, 10, False)
run_test(4, 4, 4, 16, True)
@skipIfTorchDynamo()
@onlyCPU
@dtypes(torch.float32, torch.float64, torch.bfloat16)
@precisionOverride({torch.bfloat16: 0.01})
@ -2894,6 +2895,7 @@ class TestSparseCSR(TestCase):
run_test(shape, max(shape), index_dtype)
run_test(shape, shape[0] * shape[1], index_dtype)
@skipIfTorchDynamo()
@skipMeta
@dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16))
@all_sparse_compressed_layouts()

View File

@ -76,8 +76,6 @@ class PyCodegen:
self.clear_tos()
return
self.tx.output.guards.update(value.guards)
assert isinstance(value, VariableTracker)
output = self._output
graph_outputs = self.graph_outputs

View File

@ -586,7 +586,7 @@ class GuardBuilder(GuardBuilderBase):
f"{id(torch._dynamo.eval_frame.guarded_backend_cache.current_backend)}"
)
code = [
f"___skip_backend_check() or ___current_backend() == ___lookup_backend({backend_id})"
f"(___skip_backend_check() or ___current_backend() == ___lookup_backend({backend_id}))"
]
self._produce_guard_code(guard, code)
@ -1366,3 +1366,19 @@ def make_dupe_guard(obj_source, dupe_source):
# However, this should always be a sound guard to add here.
return functools.partial(GuardBuilder.DUPLICATE_INPUT, source_b=dupe_source)
return None
def install_guard(*guards, skip=0):
"""
Add dynamo guards to the current tracing context.
Args:
guards: guard(s) to add
skip: number of stack frames to ignore for debug stack trace
"""
from torch._guards import TracingContext
add = TracingContext.get().guards_context.dynamo_guards.add
for guard in guards:
assert isinstance(guard, Guard)
add(guard, skip=skip + 1)

View File

@ -62,7 +62,7 @@ from .exc import (
unimplemented,
unimplemented_with_warning,
)
from .guards import GuardBuilder
from .guards import GuardBuilder, install_guard
from .mutation_guard import is_dynamic_nn_module
from .side_effects import SideEffects
from .source import (
@ -561,7 +561,11 @@ class OutputGraph(Checkpointable[OutputGraphState]):
# FX deepcopy doesn't work for a partially created graph, so just remove new nodes
removed_nodes = 0
for node in reversed(list(self.graph.nodes)):
if node.meta["creation_timestamp"] > self.timestamp:
if (
node.meta["creation_timestamp"] > self.timestamp
# placeholders here may have been lazily added by existing objects
and node.op != "placeholder"
):
# Erasing node alone does not remove the meta information
# So, remove the help tensor explicitly
if "example_value" in node.meta:
@ -670,7 +674,6 @@ class OutputGraph(Checkpointable[OutputGraphState]):
return variables.UnspecializedNNModuleVariable(target, **options)
options = dict(options)
options["guards"] = set(options.get("guards", []))
assert "source" in options
source = options["source"]
assert not isinstance(source, ParamBufferSource)
@ -692,10 +695,10 @@ class OutputGraph(Checkpointable[OutputGraphState]):
tracer = self.root_tracer
if not is_constant_source(source):
options["guards"].add(source.make_guard(GuardBuilder.TENSOR_MATCH))
install_guard(source.make_guard(GuardBuilder.TENSOR_MATCH))
if get_static_address_type(target) == "guarded":
options["guards"].add(source.make_guard(GuardBuilder.DATA_PTR_MATCH))
install_guard(source.make_guard(GuardBuilder.DATA_PTR_MATCH))
def wrap_name(module_key):
assert self.param_name_to_source is not None
@ -711,7 +714,7 @@ class OutputGraph(Checkpointable[OutputGraphState]):
elif isinstance(target, torch.nn.Module):
assert isinstance(target, torch.nn.Module)
options["guards"].add(source.make_guard(GuardBuilder.NN_MODULE))
install_guard(source.make_guard(GuardBuilder.NN_MODULE))
def wrap_name(module_key):
return NNModuleVariable(type(target), module_key, **options)
@ -1005,9 +1008,6 @@ class OutputGraph(Checkpointable[OutputGraphState]):
assert isinstance(rv, list)
assert isinstance(root, FakeRootModule)
for output in rv:
self.guards.update(output.guards)
self.create_node(
"output",
"output",

View File

@ -54,7 +54,7 @@ from .codegen import PyCodegen
from .current_scope_id import current_scope_id
from .exc import ArgsMismatchError, BackendCompilerFailed, unimplemented, Unsupported
from .funcname_cache import get_funcname
from .guards import GuardBuilder
from .guards import GuardBuilder, install_guard
from .output_graph import GraphCompileReason, OutputGraph, OutputGraphState
from .replay_record import DummyModule, ExecutionRecorder
from .resume_execution import ContinueExecutionCache, ReenterWith
@ -323,13 +323,11 @@ def _detect_and_normalize_assert_statement(
def generic_jump(truth_fn: typing.Callable[[object], bool], push: bool):
def inner(self: "InstructionTranslatorBase", inst: Instruction):
value: VariableTracker = self.pop()
self.output.guards.update(value.guards)
if (
config.rewrite_assert_with_torch_assert
and _detect_and_normalize_assert_statement(self, truth_fn, push)
):
error_msg: VariableTracker = self.pop()
self.output.guards.update(error_msg.guards)
# Skip over things like `assert True`
if value.is_python_constant() and bool(value.as_python_constant()):
self.jump(inst)
@ -419,7 +417,6 @@ def generic_jump(truth_fn: typing.Callable[[object], bool], push: bool):
if isinstance(result, ConstantVariable) and isinstance(
result.value, (bool, int)
):
self.output.guards.update(result.guards)
if truth_fn(result.value):
push and self.push(value)
self.jump(inst)
@ -686,9 +683,7 @@ class InstructionTranslatorBase(Checkpointable[InstructionTranslatorGraphState])
"""
A call to some user defined function by inlining it.
"""
result = InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
self.output.guards.update(fn.guards)
return result
return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
def get_line_of_code_header(self, lineno=None):
if lineno is None:
@ -1139,7 +1134,6 @@ class InstructionTranslatorBase(Checkpointable[InstructionTranslatorGraphState])
def FOR_ITER(self, inst):
it = self.pop().realize()
if isinstance(it, (variables.ListIteratorVariable, variables.IteratorVariable)):
self.output.guards.update(it.guards)
try:
val, next_iter = it.next_variables(self)
self.push(next_iter)
@ -1233,8 +1227,6 @@ class InstructionTranslatorBase(Checkpointable[InstructionTranslatorGraphState])
if sys.version_info >= (3, 11):
null = self.pop()
assert isinstance(null, NullVariable)
self.output.guards.update(argsvars.guards)
self.output.guards.update(kwargsvars.guards)
if (
isinstance(fn, GetAttrVariable)
@ -1327,13 +1319,9 @@ class InstructionTranslatorBase(Checkpointable[InstructionTranslatorGraphState])
), f"Mutating module attribute {inst.argval} during export."
try:
self.output.guards.update(
BuiltinVariable(setattr)
.call_function(
BuiltinVariable(setattr).call_function(
self, [obj, ConstantVariable.create(inst.argval), val], {}
)
.guards
)
return
except Unsupported as e:
if not self.should_compile_partial_graph():
@ -1355,10 +1343,8 @@ class InstructionTranslatorBase(Checkpointable[InstructionTranslatorGraphState])
def DELETE_ATTR(self, inst):
obj = self.pop()
self.output.guards.update(
BuiltinVariable(delattr)
.call_function(self, [obj, ConstantVariable.create(inst.argval)], {})
.guards
BuiltinVariable(delattr).call_function(
self, [obj, ConstantVariable.create(inst.argval)], {}
)
def create_call_resume_at(self, offset):
@ -1375,8 +1361,6 @@ class InstructionTranslatorBase(Checkpointable[InstructionTranslatorGraphState])
def STORE_SUBSCR(self, inst):
val, obj, key = self.popn(3)
result = obj.call_method(self, "__setitem__", [key, val], {})
# no result is pushed, so need to lift the guards to global
self.output.guards.update(result.guards)
def BUILD_TUPLE(self, inst):
items = self.popn(inst.argval)
@ -1511,7 +1495,6 @@ class InstructionTranslatorBase(Checkpointable[InstructionTranslatorGraphState])
obj,
ListVariable(
obj.items + [v],
regen_guards=False,
**VariableTracker.propagate([obj, v]),
),
)
@ -1559,7 +1542,6 @@ class InstructionTranslatorBase(Checkpointable[InstructionTranslatorGraphState])
def UNPACK_SEQUENCE(self, inst):
seq = self.pop()
if isinstance(seq, (BaseListVariable, SetVariable)):
self.output.guards.update(seq.guards)
val = seq.unpack_var_sequence(self)
elif seq.is_python_constant() and isinstance(seq, ConstantVariable):
val = seq.unpack_var_sequence(self)
@ -1874,8 +1856,6 @@ class InstructionTranslatorBase(Checkpointable[InstructionTranslatorGraphState])
if isinstance(ctx, GenericContextWrappingVariable):
self.generic_context_manager_depth += 1
self.output.guards.update(ctx.guards)
exit = WithExitFunctionVariable(
ctx,
inst.target,
@ -1961,9 +1941,7 @@ class InstructionTranslatorBase(Checkpointable[InstructionTranslatorGraphState])
)
def store_global_weakref(self, name, value):
self.output.guards.add(
GlobalWeakRefSource(name).make_guard(GuardBuilder.WEAKREF_ALIVE)
)
install_guard(GlobalWeakRefSource(name).make_guard(GuardBuilder.WEAKREF_ALIVE))
if name not in self.output.global_scope:
self.output.install_global(name, weakref.ref(value))
@ -2148,67 +2126,26 @@ class InstructionTranslator(InstructionTranslatorBase):
vars.extend(cells_and_freevars)
cells_and_freevars_set = set(cells_and_freevars)
self.symbolic_locals = collections.OrderedDict(
(
k,
VariableBuilder(
self,
LocalSource(k, cell_or_freevar=k in cells_and_freevars_set),
)(f_locals[k]),
self.symbolic_locals = {
k: variables.LazyVariableTracker.create(
f_locals[k],
source=LocalSource(k, cell_or_freevar=k in cells_and_freevars_set),
)
for k in vars
if k in f_locals
)
}
if export:
# export gets super confused if we never realize unused inputs
# export gets confused if we never realize unused inputs
# in export mode just eagerly realize everything
self.symbolic_locals = VariableTracker.apply(
lambda x: x.realize(), self.symbolic_locals
)
self.init_local_index_guards_hack()
self._freevars_ids = dict()
for name in self.code_options["co_freevars"]:
if name in f_locals:
self._freevars_ids[name] = id(f_locals[name])
def init_local_index_guards_hack(self):
# symbolic_locals contains the mapping from original f_locals to the
# Variable objects. During the Variable building phase, each object also
# has its associated guards. At the end, we will accumulate these
# guards.
#
# One way of handling these guards is to just accumulate all of them
# right now. However, many f_locals might not be used in the frame and
# thus can unnecessarily increase guard execution overhead. Therefore,
# we selectively update output.guards as we run the Python Bytecode
# instruction by instruction.
#
# An exception here is list/dict variables. Guards related to these
# variables have indexed access, like Tensor_match on args[0], and if
# args is not used in this frame, we will miss a LIST_LENGTH check like
# len(args) == 2. Missing the LIST_LENGTH check causes problem for the
# next invocation when args is not a list, and args[0] is a runtime
# error. Therefore, we recursively add guards for list/dict variable here.
for val in self.symbolic_locals.values():
if isinstance(
val, (ListIteratorVariable, BaseListVariable, ConstDictVariable)
):
local_guards = VariableTracker.propagate(val)["guards"]
index_guards = [
guard
for guard in local_guards
if guard.create_fn
in (
GuardBuilder.LIST_LENGTH,
GuardBuilder.DICT_KEYS,
GuardBuilder.ODICT_KEYS,
GuardBuilder.TUPLE_ITERATOR_LEN,
)
]
self.output.guards.update(index_guards)
def run(self):
super().run()
@ -2661,7 +2598,6 @@ class InliningGeneratorInstructionTranslator(InliningInstructionTranslator):
if isinstance(
tos, (variables.ListIteratorVariable, variables.IteratorVariable)
):
self.output.guards.update(tos.guards)
try:
val, next_iter = tos.next_variables(self)
self.push(val)

View File

@ -1,12 +1,12 @@
import collections
from enum import Enum
from typing import Any, Callable, Dict, List, Optional, Set
from typing import Any, Callable, Dict, List
from .. import variables
from ..current_scope_id import current_scope_id
from ..exc import unimplemented
from ..source import AttrSource, Source
from ..utils import dict_values, identity, istype, odict_values
from ..utils import identity, istype
class MutableLocalSource(Enum):
@ -154,21 +154,8 @@ class VariableTracker(metaclass=VariableTrackerMeta):
@staticmethod
def propagate(*vars: List[List["VariableTracker"]]):
"""Combine the guards from many VariableTracker into **kwargs for a new instance"""
guards = set()
def visit(var):
if type(var) in (list, tuple, dict_values, odict_values):
for i in var:
visit(i)
else:
assert isinstance(var, VariableTracker), typestr(var)
guards.update(var.guards)
visit(vars)
return {
"guards": guards,
}
# TODO(jansel): delete this function
return {}
def clone(self, **kwargs):
"""Shallow copy with some (optional) changes"""
@ -246,22 +233,8 @@ class VariableTracker(metaclass=VariableTrackerMeta):
cache[idx] = (result, value)
return result
def add_guard(self, guard):
return self.clone(guards=set.union(self.guards, {guard}))
def add_guards(self, guards):
if guards is None:
return self
assert isinstance(guards, set)
return self.clone(guards=set.union(self.guards, guards))
def add_options(self, options, *more):
if more:
return self.add_options(options).add_options(*more)
if isinstance(options, VariableTracker):
return self.add_guards(options.guards)
assert isinstance(options, dict)
return self.add_guards(options.get("guards", set()))
return self
def __str__(self):
return f"{self.__class__.__name__}()"
@ -283,13 +256,6 @@ class VariableTracker(metaclass=VariableTrackerMeta):
except NotImplementedError:
return False
def can_make_guard(self):
try:
self.make_guard(None)
return True
except NotImplementedError:
return False
def make_guard(self, fn):
if self.source:
return self.source.make_guard(fn)
@ -380,6 +346,10 @@ class VariableTracker(metaclass=VariableTrackerMeta):
"""Used by LazyVariableTracker to build the real VariableTracker"""
return self
def recursive_realize(self):
"""Realize all objects under this"""
return VariableTracker.apply(lambda x: x.realize(), self)
def unwrap(self) -> "VariableTracker":
"""Used by LazyVariableTracker to return the real VariableTracker if it already exists"""
return self
@ -391,14 +361,12 @@ class VariableTracker(metaclass=VariableTrackerMeta):
def __init__(
self,
*,
guards: Optional[Set] = None,
source: Source = None,
mutable_local: MutableLocal = None,
user_code_variable_name: str = None,
parents_tracker: ParentsTracker = None,
):
super().__init__()
self.guards = guards or set()
self.source = source
self.mutable_local = mutable_local
self.user_code_variable_name = user_code_variable_name

View File

@ -44,7 +44,7 @@ from ..allowed_functions import (
from ..device_interface import device_interfaces
from ..exc import InternalTorchDynamoError, unimplemented
from ..guards import GuardBuilder, make_dupe_guard
from ..guards import GuardBuilder, install_guard, make_dupe_guard
from ..side_effects import SideEffects
from ..source import (
AttrSource,
@ -198,6 +198,7 @@ class GraphArg:
def erase(self):
self._example = None
self.example_strong_ref = None
def __eq__(self, other):
return self.source.name() == other.source.name()
@ -231,9 +232,7 @@ class VariableBuilder:
side_effect_result = self.tx.output.side_effects[value]
dup_guard = make_dupe_guard(self.source, side_effect_result.source)
if dup_guard:
side_effect_result = side_effect_result.add_guards(
self.make_guards(dup_guard)
)
self.install_guards(dup_guard)
return side_effect_result
vt = self._wrap(value).clone(**self.options())
if self._can_lift_attrs_to_inputs(vt):
@ -272,14 +271,15 @@ class VariableBuilder:
def options(self):
return {"source": self.get_source()}
def make_guards(self, *guards):
def install_guards(self, *guards):
source = self.get_source()
if (
isinstance(source, ConstantSource)
or source.guard_source() == GuardSource.CONSTANT
):
return None
return {source.make_guard(guard) for guard in guards}
install_guard(*[source.make_guard(guard) for guard in guards], skip=1)
return {}
@classmethod
@functools.lru_cache(None)
@ -330,7 +330,7 @@ class VariableBuilder:
lambda self, value: LambdaVariable(
InspectSignatureVariable.create,
source=self.source,
guards=self.make_guards(GuardBuilder.FUNCTION_MATCH),
**self.install_guards(GuardBuilder.FUNCTION_MATCH),
),
),
(comptime, lambda self, value: ComptimeVariable()),
@ -339,7 +339,7 @@ class VariableBuilder:
lambda self, value: LambdaVariable(
_dataclasses_fields_lambda,
source=self.source,
guards=self.make_guards(GuardBuilder.FUNCTION_MATCH),
**self.install_guards(GuardBuilder.FUNCTION_MATCH),
),
),
(
@ -347,7 +347,7 @@ class VariableBuilder:
lambda self, value: TorchVariable(
value,
source=self.source,
guards=self.make_guards(GuardBuilder.FUNCTION_MATCH),
**self.install_guards(GuardBuilder.FUNCTION_MATCH),
),
),
]
@ -375,8 +375,6 @@ class VariableBuilder:
class Autotuner:
pass
make_guards = self.make_guards
# Handle exact type() match
type_dispatch = self._type_dispatch().get(type(value))
if type_dispatch is not None:
@ -400,13 +398,13 @@ class VariableBuilder:
return self.wrap_listlike(value)
elif value is torch.utils._pytree.SUPPORTED_NODES:
# For SUPPORTED_NODES, we guard on the dictionary version (PEP509)
# under the assumption that the values themselves don't change.
self.install_guards(GuardBuilder.DICT_VERSION)
result = {
k: UserDefinedObjectVariable(
value[k],
source=GetItemSource(self.get_source(), k),
# For SUPPORTED_NODES, we guard on the dictionary version (PEP509)
# under the assumption that the values themselves don't change.
guards=self.make_guards(GuardBuilder.DICT_VERSION),
)
for k in value.keys()
}
@ -429,9 +427,9 @@ class VariableBuilder:
# Why is this OK for (specialized) nnmodules? We set up a setattr hook
# to check for module property mutations, which does a reasonable,
# but not completely secure job ensuring a property wasn't changed.
guards = self.make_guards(GuardBuilder.BOOL_FALSE)
self.install_guards(GuardBuilder.BOOL_FALSE)
else:
guards = self.make_guards(GuardBuilder.DICT_KEYS)
self.install_guards(GuardBuilder.DICT_KEYS)
# store key variables in global location for reconstruction
for key in value.keys():
@ -448,7 +446,7 @@ class VariableBuilder:
k: LazyVariableTracker.create(
value[k],
source=GetItemSource(self.get_source(), index_source(k)),
).add_guards(guards)
)
for k in value.keys()
}
@ -457,10 +455,9 @@ class VariableBuilder:
result,
type(value),
self._wrap(value.default_factory),
guards=guards,
)
else:
result = ConstDictVariable(result, type(value), guards=guards)
result = ConstDictVariable(result, type(value))
return self.tx.output.side_effects.track_dict(self.source, value, result)
elif isinstance(value, torch.nn.Module):
@ -472,23 +469,14 @@ class VariableBuilder:
):
# For frozenset, we can guard by object ID instead of value
# equality, this allows us to handle non-literal values
return ConstantVariable.create(
value=value,
source=self.source,
guards=make_guards(GuardBuilder.ID_MATCH),
)
self.install_guards(GuardBuilder.ID_MATCH)
return ConstantVariable.create(value=value, source=self.source)
elif isinstance(value, enum.Enum):
return EnumVariable(
value=value,
source=self.source,
guards=make_guards(GuardBuilder.ID_MATCH),
)
self.install_guards(GuardBuilder.ID_MATCH)
return EnumVariable(value=value, source=self.source)
elif is_builtin_callable(value):
return BuiltinVariable(
value,
source=self.source,
guards=make_guards(GuardBuilder.BUILTIN_MATCH),
)
self.install_guards(GuardBuilder.BUILTIN_MATCH)
return BuiltinVariable(value, source=self.source)
elif is_utils_checkpoint(value):
return build_checkpoint_variable(source=self.source)
elif isinstance(value, functools.partial):
@ -509,52 +497,50 @@ class VariableBuilder:
self.tx, GetItemSource(keywords_source, k)
)(v)
guards = {
install_guard(
self.get_source().make_guard(GuardBuilder.TYPE_MATCH),
keywords_source.make_guard(GuardBuilder.DICT_KEYS),
args_source.make_guard(GuardBuilder.LIST_LENGTH),
}
return FunctoolsPartialVariable(
func_obj, args, keywords, original=value, guards=guards
)
return FunctoolsPartialVariable(func_obj, args, keywords, original=value)
elif is_typing(value):
# typing.List, typing.Mapping, etc.
self.install_guards(GuardBuilder.ID_MATCH)
return TypingVariable(
value,
source=self.source,
guards=make_guards(GuardBuilder.ID_MATCH),
)
elif np is not None and isinstance(value, np.generic):
# numpy array scalars: convert to 0D arrays
return self.wrap_numpy_ndarray(np.asarray(value))
elif is_numpy(value):
assert np
return NumpyVariable(
value,
source=self.source,
guards=make_guards(
self.install_guards(
GuardBuilder.FUNCTION_MATCH
if callable(value)
else GuardBuilder.TYPE_MATCH
),
)
return NumpyVariable(value, source=self.source)
# NB: These can't be put in type_dispatch, they have to run later
elif CollectiveFunctionRewriteVariable.can_rewrite(value):
self.install_guards(GuardBuilder.FUNCTION_MATCH)
return CollectiveFunctionRewriteVariable.create(
self.tx,
value,
source=self.source,
guards=make_guards(GuardBuilder.FUNCTION_MATCH),
)
elif istype(value, torch.autograd.function.FunctionMeta):
self.install_guards(GuardBuilder.FUNCTION_MATCH)
return AutogradFunctionVariable(
value,
source=self.source,
guards=make_guards(GuardBuilder.FUNCTION_MATCH),
)
elif isinstance(value, torch.autograd.function.FunctionCtx):
saved_tensors_source = AttrSource(self.source, "saved_tensors")
install_guard(
self.source.make_guard(GuardBuilder.TYPE_MATCH),
saved_tensors_source.make_guard(GuardBuilder.LIST_LENGTH),
)
saved_tensors = [
VariableBuilder(self.tx, GetItemSource(saved_tensors_source, n))(v)
for n, v in enumerate(value.saved_tensors)
@ -565,8 +551,6 @@ class VariableBuilder:
AutogradFunctionContextVariable(
value,
source=self.source,
guards=make_guards(GuardBuilder.TYPE_MATCH)
| {saved_tensors_source.make_guard(GuardBuilder.LIST_LENGTH)},
saved_tensors=SavedTensorBox(saved_tensors),
),
)
@ -579,53 +563,43 @@ class VariableBuilder:
and value == getattr(value.__self__, "apply", None)
):
# handle aliased autograd function `apply` calls
self.install_guards(GuardBuilder.FUNCTION_MATCH)
return GetAttrVariable(
AutogradFunctionVariable(
value.__self__,
source=self.source,
guards=make_guards(GuardBuilder.FUNCTION_MATCH),
),
AutogradFunctionVariable(value.__self__, source=self.source),
"apply",
)
elif np and isinstance(value, np.number):
return self.wrap_unspecialized_primitive(value)
elif DataClassVariable.is_matching_object(value):
return DataClassVariable.wrap(self, value).add_guards(
make_guards(GuardBuilder.TYPE_MATCH)
)
self.install_guards(GuardBuilder.TYPE_MATCH)
return DataClassVariable.wrap(self, value)
elif HFPretrainedConfigVariable.is_matching_object(value):
return HFPretrainedConfigVariable(
value, guards=make_guards(GuardBuilder.TYPE_MATCH)
)
self.install_guards(GuardBuilder.TYPE_MATCH)
return HFPretrainedConfigVariable(value)
elif isinstance(value, HigherOrderOperator):
return TorchHigherOrderOperatorVariable.make(
value,
source=self.source,
guards=self.make_guards(
GuardBuilder.TYPE_MATCH, GuardBuilder.NAME_MATCH
),
)
self.install_guards(GuardBuilder.TYPE_MATCH, GuardBuilder.NAME_MATCH)
return TorchHigherOrderOperatorVariable.make(value, source=self.source)
elif type(value).__name__ == "builtin_function_or_method" and isinstance(
value.__self__, torch_special_class_types
):
self.install_guards(GuardBuilder.FUNCTION_MATCH)
return TorchVariable(
value,
guards=make_guards(GuardBuilder.FUNCTION_MATCH),
)
elif isinstance(value, _StreamBase):
self.install_guards(GuardBuilder.ID_MATCH)
return StreamVariable(
None,
value,
value.device.type,
source=self.source,
guards=make_guards(GuardBuilder.ID_MATCH),
)
elif isinstance(value, _EventBase):
self.install_guards(GuardBuilder.ID_MATCH)
return EventVariable(
None,
value,
source=self.source,
guards=make_guards(GuardBuilder.ID_MATCH),
)
elif (
isinstance(value, torch._C._TensorMeta)
@ -636,55 +610,36 @@ class VariableBuilder:
istype(value, contextlib.nullcontext)
and inspect.getattr_static(value, "enter_result", None) is None
):
return NullContextVariable(
source=self.source,
guards=make_guards(GuardBuilder.FUNCTION_MATCH),
)
# TODO(jansel): I think this can be TYPE_MATCH
self.install_guards(GuardBuilder.FUNCTION_MATCH)
return NullContextVariable(source=self.source)
elif KeyedJaggedTensorVariable.is_matching_object(value):
result = KeyedJaggedTensorVariable(
value,
source=self.source,
guards=self.make_guards(GuardBuilder.TYPE_MATCH),
)
self.install_guards(GuardBuilder.TYPE_MATCH)
result = KeyedJaggedTensorVariable(value, source=self.source)
# TODO: this doing it manually is bad
return self.tx.output.side_effects.track_object_existing(
self.source, value, result
)
elif isinstance(value, torch.optim.Optimizer):
return OptimizerVariable(
value,
source=self.source,
guards=self.make_guards(GuardBuilder.TYPE_MATCH),
)
self.install_guards(GuardBuilder.TYPE_MATCH)
return OptimizerVariable(value, source=self.source)
elif ProcessGroupVariable.is_process_group(value):
return ProcessGroupVariable(
value,
source=self.source,
guards=self.make_guards(GuardBuilder.ID_MATCH),
)
self.install_guards(GuardBuilder.ID_MATCH)
return ProcessGroupVariable(value, source=self.source)
elif DeviceMeshVariable.is_device_mesh(value):
# TODO: see if we need to add custom guard instead
# of a simple ID_MATCH
return DeviceMeshVariable(
value,
source=self.source,
guards=self.make_guards(GuardBuilder.ID_MATCH),
)
# TODO: see if we need to add custom guard instead of a simple ID_MATCH
self.install_guards(GuardBuilder.ID_MATCH)
return DeviceMeshVariable(value, source=self.source)
elif PlacementClassVariable.is_placement_type(value):
# TODO: see if we need to add custom guard instead
# of a simple ID_MATCH
return PlacementClassVariable(
value,
source=self.source,
guards=make_guards(GuardBuilder.ID_MATCH),
)
# TODO: see if we need to add custom guard instead of a simple ID_MATCH
self.install_guards(GuardBuilder.ID_MATCH)
return PlacementClassVariable(value, source=self.source)
elif PlacementVariable.is_placement(value):
# TODO: see if we need to add custom guard instead
# of a simple ID_MATCH
# TODO: see if we need to add custom guard instead of a simple ID_MATCH
self.install_guards(GuardBuilder.ID_MATCH)
return PlacementVariable(
value,
source=self.source,
guards=make_guards(GuardBuilder.ID_MATCH),
)
elif isinstance(value, torch.SymBool):
# Note: the idea here is to re-use the infra we've built for SymInt by simulating the
@ -727,12 +682,12 @@ class VariableBuilder:
new_symint == 1,
)
elif isinstance(value, (JITFunction, Autotuner)):
self.install_guards(GuardBuilder.ID_MATCH)
return TritonKernelVariable(
value,
None, # No kernel idx provided
None, # No grid provided
source=self.source,
guards=make_guards(GuardBuilder.ID_MATCH),
)
elif trace_rules.lookup(value) is not None:
return trace_rules.lookup(value).create_with_source(
@ -741,10 +696,10 @@ class VariableBuilder:
elif is_allowed(value):
if is_user_defined_allowed(value):
self.tx.output.has_user_defined_allowed_in_graph = True
self.install_guards(GuardBuilder.FUNCTION_MATCH)
return TorchVariable(
value,
source=self.source,
guards=make_guards(GuardBuilder.FUNCTION_MATCH),
)
elif (
istype(value, (type, types.FunctionType))
@ -752,17 +707,17 @@ class VariableBuilder:
and not inspect.getattr_static(value, "_torchdynamo_inline", False)
and not inspect.getattr_static(value, "__script_if_tracing_wrapper", False)
):
self.install_guards(GuardBuilder.FUNCTION_MATCH)
return SkipFilesVariable(
value,
skipfiles.check_verbose(value, allow_torch=True).reason,
source=self.source,
guards=make_guards(GuardBuilder.FUNCTION_MATCH),
)
elif istype(value, (types.FunctionType, torch.jit.ScriptFunction)):
self.install_guards(GuardBuilder.FUNCTION_MATCH)
return UserFunctionVariable(
value,
source=self.source,
guards=make_guards(GuardBuilder.FUNCTION_MATCH),
)
elif isinstance(value, types.MethodType) and isinstance(
value.__self__, torch.nn.Module
@ -784,40 +739,33 @@ class VariableBuilder:
assert self_obj and isinstance(
self_obj, VariableTracker
), "Failed to produce a valid self obj"
self.install_guards(GuardBuilder.FUNCTION_MATCH)
return UserMethodVariable(
value.__func__,
self_obj,
source=self.source,
guards=make_guards(GuardBuilder.FUNCTION_MATCH),
)
elif istype(value, (types.ModuleType, replay_record.DummyModule)):
self.install_guards(GuardBuilder.FUNCTION_MATCH)
return PythonModuleVariable(
value,
source=self.source,
guards=make_guards(GuardBuilder.PYMODULE_MATCH),
)
elif isinstance(value, types.GetSetDescriptorType):
return GetSetDescriptorVariable(
value, guards=self.make_guards(GuardBuilder.FUNCTION_MATCH)
)
self.install_guards(GuardBuilder.FUNCTION_MATCH)
return GetSetDescriptorVariable(value)
elif isinstance(value, types.MethodWrapperType):
return MethodWrapperVariable(
value,
source=self.source,
guards=self.make_guards(GuardBuilder.FUNCTION_MATCH),
)
self.install_guards(GuardBuilder.FUNCTION_MATCH)
return MethodWrapperVariable(value, source=self.source)
elif issubclass(type(value), type):
self.install_guards(GuardBuilder.FUNCTION_MATCH)
return UserDefinedClassVariable(
value,
source=self.source,
guards=make_guards(GuardBuilder.FUNCTION_MATCH),
)
else:
result = UserDefinedObjectVariable(
value,
source=self.source,
guards=self.make_guards(GuardBuilder.TYPE_MATCH),
)
self.install_guards(GuardBuilder.TYPE_MATCH)
result = UserDefinedObjectVariable(value, source=self.source)
if not SideEffects.cls_supports_mutation_side_effects(type(value)):
# don't allow STORE_ATTR mutation with custom __setattr__
return result
@ -857,36 +805,32 @@ class VariableBuilder:
def wrap_listlike(self, value: Union[tuple, list, odict_values, NamedTuple]):
# One can index a tensor with a list/tuple. Therefore, we need to
# have a stricter match.
guards = self.make_guards(GuardBuilder.LIST_LENGTH)
self.install_guards(GuardBuilder.LIST_LENGTH)
for item in value:
if item is value:
unimplemented("list elements are pointing to the list itself")
output = [
VariableBuilder(self.tx, GetItemSource(self.get_source(), i))(
item
).add_guards(guards)
VariableBuilder(self.tx, GetItemSource(self.get_source(), i))(item)
for i, item in enumerate(value)
]
result = BaseListVariable.cls_for_instance(value)(
output, mutable_local=MutableLocal(), guards=guards
output, mutable_local=MutableLocal()
)
if istype(value, list):
return self.tx.output.side_effects.track_list(self.source, value, result)
return result
def wrap_tuple_iterator(self, value: tuple_iterator):
guards = self.make_guards(GuardBuilder.TUPLE_ITERATOR_LEN)
self.install_guards(GuardBuilder.TUPLE_ITERATOR_LEN)
output = [
VariableBuilder(self.tx, TupleIteratorGetItemSource(self.get_source(), i))(
tuple_iterator_getitem(value, i)
).add_guards(guards)
)
for i in range(tuple_iterator_len(value))
]
return TupleIteratorVariable(
output, mutable_local=MutableLocal(), guards=guards
)
return TupleIteratorVariable(output, mutable_local=MutableLocal())
def wrap_slice_range(self, value: Union[slice, range]):
items = [
@ -896,21 +840,20 @@ class VariableBuilder:
for k in ("start", "stop", "step")
]
if isinstance(value, slice):
return SliceVariable(
items, guards=self.make_guards(GuardBuilder.TYPE_MATCH)
)
self.install_guards(GuardBuilder.TYPE_MATCH)
return SliceVariable(items)
else:
return RangeVariable(
items, guards=self.make_guards(GuardBuilder.EQUALS_MATCH)
)
# TODO(jansel): I think this can be TYPE_MATCH
self.install_guards(GuardBuilder.EQUALS_MATCH)
return RangeVariable(items)
def wrap_module(self, value: torch.nn.Module):
from ..eval_frame import OptimizedModule
if istype(value, OptimizedModule):
guards = self.make_guards(GuardBuilder.TYPE_MATCH)
self.install_guards(GuardBuilder.TYPE_MATCH)
self.source = AttrSource(self.source, "_orig_mod")
return self.wrap_module(value._orig_mod).add_guards(guards)
return self.wrap_module(value._orig_mod)
if (
isinstance(value, (torch.nn.RNN, torch.nn.GRU, torch.nn.LSTM))
@ -919,9 +862,8 @@ class VariableBuilder:
unimplemented("TorchDynamo purposely graph breaks on RNN, GRU, LSTMs")
if mutation_guard.is_dynamic_nn_module(value):
# created dynamically, don't specialize on it
result = UnspecializedNNModuleVariable(
value, guards=self.make_guards(GuardBuilder.TYPE_MATCH)
)
self.install_guards(GuardBuilder.TYPE_MATCH)
result = UnspecializedNNModuleVariable(value)
if not SideEffects.cls_supports_mutation_side_effects(type(value)):
# don't allow STORE_ATTR mutation with custom __setattr__
return result
@ -931,9 +873,8 @@ class VariableBuilder:
elif issubclass(
value.__class__, torch.nn.parallel.distributed.DistributedDataParallel
):
return UnspecializedNNModuleVariable(
value, guards=self.make_guards(GuardBuilder.TYPE_MATCH)
)
self.install_guards(GuardBuilder.TYPE_MATCH)
return UnspecializedNNModuleVariable(value)
elif getattr(value, "_is_fsdp_managed_module", False):
# See note [Dynamo treats FSDP wrapped modules as UnspecializedNNModule]
# in fully_sharded_data_parallel.py for more information
@ -960,11 +901,8 @@ class VariableBuilder:
#
# ID_MATCH is required to disambiguate cases as simple as a unit test that constructs 2 models and wraps
# them differently with different FSDP configs. (test_dynamo_distributed.py -k test_fsdp_aot_eager)
return FSDPManagedNNModuleVariable(
value,
guards=self.make_guards(GuardBuilder.TYPE_MATCH, GuardBuilder.ID_MATCH),
source=self.get_source(),
)
self.install_guards(GuardBuilder.TYPE_MATCH, GuardBuilder.ID_MATCH)
return FSDPManagedNNModuleVariable(value, source=self.get_source())
else:
return self.tx.output.register_attr_or_module(
value,
@ -976,12 +914,12 @@ class VariableBuilder:
def wrap_literal(self, value):
unspec = not config.specialize_int
if unspec and type(value) is torch.Size:
self.install_guards(GuardBuilder.LIST_LENGTH)
return SizeVariable(
[
VariableBuilder(self.tx, GetItemSource(self.get_source(), i))(v)
for i, v in enumerate(value)
],
guards=self.make_guards(GuardBuilder.LIST_LENGTH),
]
)
elif unspec and type(value) is int:
# unspecializing int by default, but still
@ -995,17 +933,13 @@ class VariableBuilder:
# NN modules on the fly)
or self.source.guard_source().is_nn_module()
):
return ConstantVariable.create(
value=value,
guards=self.make_guards(GuardBuilder.CONSTANT_MATCH),
)
self.install_guards(GuardBuilder.CONSTANT_MATCH)
return ConstantVariable.create(value=value)
else:
return self.wrap_unspecialized_primitive(value)
else:
return ConstantVariable.create(
value=value,
guards=self.make_guards(GuardBuilder.CONSTANT_MATCH),
)
self.install_guards(GuardBuilder.CONSTANT_MATCH)
return ConstantVariable.create(value=value)
def assert_not_wrapped_by_this_graph(self, value: torch.Tensor):
if is_fake(value) and maybe_get_fake_mode(value) is self.tx.fake_mode:
@ -1027,11 +961,7 @@ class VariableBuilder:
) and not source.guard_source().is_fsdp_module():
self.assert_not_wrapped_by_this_graph(value)
return self.tx.output.register_attr_or_module(
value,
self.name,
source=source,
# Guards are done inside register_attr_or_module
# guards=self.make_guards(GuardBuilder.TENSOR_MATCH),
value, self.name, source=source
)
if is_constant_source(source):
@ -1099,20 +1029,7 @@ class VariableBuilder:
options["torch_function_fn"] = build_torch_function_fn(
self.tx, value, self.source
)
options["guards"] = self.make_guards(GuardBuilder.TYPE_MATCH)
else:
options["guards"] = set()
options["guards"].update(
self.make_guards(
functools.partial(
GuardBuilder.TENSOR_MATCH,
value=value
if isinstance(source, NumpyTensorSource)
else TensorWeakRef(value),
)
)
)
self.install_guards(GuardBuilder.TYPE_MATCH)
if (
isinstance(value, torch.Tensor)
@ -1130,6 +1047,16 @@ class VariableBuilder:
source=source,
**options,
)
self.install_guards(
functools.partial(
GuardBuilder.TENSOR_MATCH,
value=value
if isinstance(source, NumpyTensorSource)
else TensorWeakRef(value),
)
)
self.tx.output.input_source_to_var[source] = tensor_variable
assert "tensor_dict" not in tensor_proxy.node.meta
tensor_proxy.node.meta["tensor_dict"] = value.__dict__.copy()
@ -1172,11 +1099,11 @@ class VariableBuilder:
# a tensor. It's a little annoying to make a VT to throw out, but there's so many side effects here
# that there's not another great way to do this atm.
# This creates the right graphargs, as well as registration for guards in tensor names and shape env.
tensor_vt = VariableBuilder(self.tx, source)(tensor_value)
VariableBuilder(self.tx, source)(tensor_value).recursive_realize()
proxy = self.tx.output.root_tracer.create_graph_input(
re.sub(r"[^a-zA-Z0-9]+", "_", self.name), type(tensor_value), source=source
)
options = {"source": source, "guards": tensor_vt.guards}
options = {"source": source}
numpy_ndarray_variable = wrap_fx_proxy_cls(
target_cls=NumpyNdarrayVariable,
tx=self.tx,
@ -1229,10 +1156,8 @@ class VariableBuilder:
# If specialize_int is False, also return
# a constant (but this should have been handled
# in the caller, TBH)
return ConstantVariable.create(
value=value,
guards=self.make_guards(GuardBuilder.CONSTANT_MATCH),
)
self.install_guards(GuardBuilder.CONSTANT_MATCH)
return ConstantVariable.create(value=value)
name = self.source.name()
if name not in self.tx.output.frame_state:
@ -1264,10 +1189,8 @@ class VariableBuilder:
else: # assume_static_by_default
# TODO: dynamic_dim = DimDynamic.STATIC should work but
# for some reason it doesn't
return ConstantVariable.create(
value=value,
guards=self.make_guards(GuardBuilder.CONSTANT_MATCH),
)
self.install_guards(GuardBuilder.CONSTANT_MATCH)
return ConstantVariable.create(value=value)
wrapped_value = shape_env.create_unspecified_symint_and_symbol(
value,
@ -1281,11 +1204,8 @@ class VariableBuilder:
else:
wrapped_value = torch.tensor(value)
if not isinstance(self.get_source(), RandomValueSource):
guards = {self.get_source().make_guard(GuardBuilder.TYPE_MATCH)}
options = {"guards": guards}
else:
options = {}
options.update({"source": self.get_source()})
install_guard(self.get_source().make_guard(GuardBuilder.TYPE_MATCH))
options = {"source": self.get_source()}
if isinstance(wrapped_value, torch.Tensor):
options.update({"raw_value": value})
@ -1352,16 +1272,19 @@ def _dataclasses_fields_lambda(obj):
def wrap_fx_proxy(tx, proxy, example_value=None, subclass_type=None, **options):
return wrap_fx_proxy_cls(
target_cls=TensorVariable
if not subclass_type
else TensorWithTFOverrideVariable,
tx=tx,
proxy=proxy,
example_value=example_value,
subclass_type=subclass_type,
kwargs = {
"tx": tx,
"proxy": proxy,
"example_value": example_value,
"subclass_type": subclass_type,
**options,
)
}
if subclass_type is None:
return wrap_fx_proxy_cls(target_cls=TensorVariable, **kwargs)
else:
result = wrap_fx_proxy_cls(target_cls=TensorWithTFOverrideVariable, **kwargs)
result.install_global(tx)
return result
# Note: Unfortunate split due to some gross classes existing that subclass TensorVariable

View File

@ -19,7 +19,7 @@ from ..exc import (
UserError,
UserErrorType,
)
from ..guards import GuardBuilder
from ..guards import GuardBuilder, install_guard
from ..replay_record import DummyModule
from ..source import AttrSource, GetItemSource, is_constant_source, TypeSource
from ..utils import (
@ -339,7 +339,6 @@ class BuiltinVariable(VariableTracker):
a,
ListVariable(
list(a.items) + list(b.unpack_var_sequence(tx)),
regen_guards=False,
**options,
),
)
@ -826,23 +825,22 @@ class BuiltinVariable(VariableTracker):
mutable_local=MutableLocal(),
)
elif obj.has_unpack_var_sequence(tx):
guards = set()
if obj.source and not is_constant_source(obj.source):
if isinstance(obj, TupleIteratorVariable):
guards.add(obj.source.make_guard(GuardBuilder.TUPLE_ITERATOR_LEN))
install_guard(
obj.source.make_guard(GuardBuilder.TUPLE_ITERATOR_LEN)
)
else:
guards.add(obj.source.make_guard(GuardBuilder.LIST_LENGTH))
install_guard(obj.source.make_guard(GuardBuilder.LIST_LENGTH))
if cls is SetVariable:
return cls(
list(obj.unpack_var_sequence(tx)),
mutable_local=MutableLocal(),
guards=guards,
).add_options(self, obj)
return cls(
list(obj.unpack_var_sequence(tx)),
mutable_local=MutableLocal(),
guards=guards,
).add_options(self, obj)
call_iter = _call_iter_tuple_list
@ -1060,7 +1058,6 @@ class BuiltinVariable(VariableTracker):
from .builder import SourcelessBuilder, VariableBuilder
options = VariableTracker.propagate(self, obj, name_var)
guards = options["guards"]
name = name_var.as_python_constant()
if not name_var.is_python_constant():
@ -1075,10 +1072,9 @@ class BuiltinVariable(VariableTracker):
if default is not None:
hasattr_var = self.call_hasattr(tx, obj, name_var)
guards.update(hasattr_var.guards)
assert hasattr_var.as_python_constant() in (True, False)
if not hasattr_var.as_python_constant():
return default.add_guards(guards)
return default
if obj.source:
source = AttrSource(obj.source, name)
@ -1152,14 +1148,14 @@ class BuiltinVariable(VariableTracker):
elif ConstantVariable.is_literal(member):
return ConstantVariable.create(member, **options)
else:
return VariableBuilder(tx, source)(member).add_guards(guards)
return VariableBuilder(tx, source)(member)
elif isinstance(obj, (PythonModuleVariable, DummyModule)):
member = obj.value.__dict__[name]
if config.replay_record_enabled:
tx.exec_recorder.record_module_access(obj.value, name, member)
return VariableBuilder(tx, source)(member).add_guards(guards)
return VariableBuilder(tx, source)(member)
elif istype(obj, UserFunctionVariable) and name in ("__name__", "__module__"):
return ConstantVariable.create(
getattr(obj.fn, name), **VariableTracker.propagate(obj)

View File

@ -6,7 +6,7 @@ from torch._dynamo.source import GetItemSource
from .. import variables
from ..exc import unimplemented, UserError, UserErrorType
from ..guards import GuardBuilder
from ..guards import GuardBuilder, install_guard
from ..utils import np
from .base import typestr, VariableTracker
@ -41,21 +41,15 @@ class ConstantVariable(VariableTracker):
items = []
for i, x in enumerate(value):
item_source = GetItemSource(source, i) if source else None
guards = (
{item_source.make_guard(GuardBuilder.CONSTANT_MATCH)}
if item_source
else None
)
if item_source:
install_guard(item_source.make_guard(GuardBuilder.CONSTANT_MATCH))
items.append(
ConstantVariable.create(
x,
source=item_source,
guards=guards,
)
)
return variables.BaseListVariable.cls_for(type(value))(
items, regen_guards=True, **kwargs
)
return variables.BaseListVariable.cls_for(type(value))(items, **kwargs)
return ConstantVariable(value, **kwargs)

View File

@ -9,7 +9,7 @@ from .. import variables
from ..bytecode_transformation import create_call_function, create_instruction
from ..device_interface import get_interface_for_device
from ..exc import unimplemented, Unsupported
from ..guards import GuardBuilder
from ..guards import GuardBuilder, install_guard
from ..source import AttrSource, GlobalStateSource
from .base import VariableTracker
from .functions import (
@ -161,7 +161,7 @@ class GenericContextWrappingVariable(ContextWrappingVariable):
class GradModeVariable(ContextWrappingVariable):
"""represents torch.{no_grad,enable_grad,set_grad_mode}()"""
_guards_singleton = {Guard(GlobalStateSource(), GuardBuilder.GRAD_MODE)}
_guards_singleton = Guard(GlobalStateSource(), GuardBuilder.GRAD_MODE)
@staticmethod
def create(tx, target_value, initialized=True, **kwargs):
@ -179,8 +179,8 @@ class GradModeVariable(ContextWrappingVariable):
super().__init__(
target_values=target_values, initial_values=initial_values, **kwargs
)
self.guards = self.guards | self._guards_singleton
self.initialized = initialized
install_guard(self._guards_singleton)
def enter(self, tx):
if not self.initialized:
@ -263,7 +263,7 @@ class InferenceModeVariable(ContextWrappingVariable):
class TorchFunctionDisableVariable(ContextWrappingVariable):
"""represents whether torch function overrides are enabled or not"""
_guards_singleton = {Guard(GlobalStateSource(), GuardBuilder.TORCH_FUNCTION_STATE)}
_guards_singleton = Guard(GlobalStateSource(), GuardBuilder.TORCH_FUNCTION_STATE)
@staticmethod
def create(tx, **kwargs):
@ -281,7 +281,7 @@ class TorchFunctionDisableVariable(ContextWrappingVariable):
super().__init__(
target_values=target_values, initial_values=initial_values, **kwargs
)
self.guards = self.guards | self._guards_singleton
install_guard(self._guards_singleton)
def enter(self, tx):
return variables.ConstantVariable.create(
@ -296,9 +296,9 @@ class TorchFunctionDisableVariable(ContextWrappingVariable):
class DeterministicAlgorithmsVariable(ContextWrappingVariable):
"""represents torch.{are_deterministic_algorithms_enabled,use_deterministic_algorithms}()"""
_guards_singleton = {
Guard(GlobalStateSource(), GuardBuilder.DETERMINISTIC_ALGORITHMS)
}
_guards_singleton = Guard(
GlobalStateSource(), GuardBuilder.DETERMINISTIC_ALGORITHMS
)
@staticmethod
def create(tx, target_value, **kwargs):
@ -315,7 +315,7 @@ class DeterministicAlgorithmsVariable(ContextWrappingVariable):
super().__init__(
target_values=target_values, initial_values=initial_values, **kwargs
)
self.guards = self.guards | self._guards_singleton
install_guard(self._guards_singleton)
def enter(self, tx):
return variables.ConstantVariable.create(

View File

@ -13,7 +13,7 @@ from ..bytecode_transformation import create_call_function, create_instruction
from ..eval_frame import skip_code
from ..exc import unimplemented
from ..guards import GuardBuilder, make_dupe_guard
from ..guards import GuardBuilder, install_guard, make_dupe_guard
from ..source import AttrSource, GetItemSource, GlobalWeakRefSource
from ..utils import global_key_name, istensor, iter_contains
from .base import MutableLocal, VariableTracker
@ -24,10 +24,8 @@ from .tensor import TensorVariable
class ConstDictVariable(VariableTracker):
def __init__(self, items, user_cls, **kwargs):
super().__init__(**kwargs)
# All the keys are constants
assert not any(isinstance(x, VariableTracker) for x in items)
self.guards.update(VariableTracker.propagate(items.values())["guards"])
self.items = items
self.user_cls = user_cls
@ -298,7 +296,6 @@ class SetVariable(VariableTracker):
def __init__(
self,
items: List[VariableTracker],
regen_guards=True,
**kwargs,
):
super().__init__(**kwargs)
@ -309,10 +306,6 @@ class SetVariable(VariableTracker):
self.items = []
self._add(items)
# Sometimes, we know that we have passed in the guards from the items in the set
if regen_guards:
self.guards.update(VariableTracker.propagate(items)["guards"])
def as_proxy(self):
return [x.as_proxy() for x in self.items]
@ -378,9 +371,7 @@ class SetVariable(VariableTracker):
e.vt.source, set_element.vt.source
)
if alias_guard:
e.vt = e.vt.add_guards(
{e.vt.source.make_guard(alias_guard)}
)
install_guard(e.vt.source.make_guard(alias_guard))
return self.items
@ -401,7 +392,6 @@ class SetVariable(VariableTracker):
result = SetVariable(
self._add(item),
mutable_local=self.mutable_local,
regen_guards=False,
**options,
)
tx.replace_all(self, result)
@ -413,7 +403,7 @@ class SetVariable(VariableTracker):
result = items.pop()
tx.replace_all(
self,
SetVariable(items, regen_guards=False, **options),
SetVariable(items, **options),
)
return result
elif name == "__len__":
@ -797,46 +787,40 @@ class PythonSysModulesVariable(VariableTracker):
def _contains_helper(self, tx, key: VariableTracker):
k = ConstDictVariable.get_key(key)
has_key = k in sys.modules
guard = self.make_guard(
install_guard(
self.make_guard(
functools.partial(GuardBuilder.DICT_CONTAINS, key=k, invert=not has_key)
)
guards = {*self.guards, guard}
return k, has_key, guards
)
return k, has_key
def call_contains(self, tx, key: VariableTracker):
k, has_key, guards = self._contains_helper(tx, key)
return ConstantVariable.create(
value=has_key,
guards=guards,
)
k, has_key = self._contains_helper(tx, key)
return ConstantVariable.create(value=has_key)
def call_get(
self, tx, key: VariableTracker, default: Optional[VariableTracker] = None
):
from .builder import VariableBuilder
k, has_key, guards = self._contains_helper(tx, key)
k, has_key = self._contains_helper(tx, key)
if has_key:
return VariableBuilder(
tx,
GetItemSource(self.source, k),
)(
sys.modules[k]
).add_guards(guards)
)(sys.modules[k])
if default is not None:
return default.add_guards(guards)
return default
return ConstantVariable.create(value=None, guards=guards)
return ConstantVariable.create(value=None)
def call_getitem(self, tx, key: VariableTracker):
from .builder import VariableBuilder
k, has_key, guards = self._contains_helper(tx, key)
k, has_key = self._contains_helper(tx, key)
return VariableBuilder(
tx,
GetItemSource(self.source, k),
)(
sys.modules[k]
).add_guards(guards)
)(sys.modules[k])

View File

@ -614,12 +614,6 @@ class FunctoolsPartialVariable(VariableTracker):
self.keywords = keywords
self.original = original
self.guards.update(VariableTracker.propagate(func)["guards"])
for arg in args:
self.guards.update(VariableTracker.propagate(arg)["guards"])
for val in keywords.values():
self.guards.update(VariableTracker.propagate(val)["guards"])
def call_function(
self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
) -> "VariableTracker":

View File

@ -25,7 +25,6 @@ from ..exc import (
UserError,
UserErrorType,
)
from ..guards import GuardBuilder
from ..source import FSDPNNModuleSource, GetItemSource, NNModuleSource
from ..utils import proxy_args_kwargs
from .dicts import ConstDictVariable
@ -100,9 +99,6 @@ def validate_args_and_maybe_create_graph_inputs(
assert isinstance(a, VariableTracker)
if isinstance(a, ConstantVariable):
# Ensures that we recompile when the constant value changes
a.add_guard(GuardBuilder.CONSTANT_MATCH)
if manually_set_subgraph_inputs:
# This arg is not used in the body of the higher order op.
# Currently, this new input is added to make the calls
@ -194,6 +190,11 @@ def speculate_subgraph(
)
try:
f, sub_args, sub_kwargs = VariableTracker.apply(
# ensure guards on args get installed in parent subgraph
lambda x: x.realize(),
(f, sub_args, sub_kwargs),
)
with tx.output.subtracer(source_target, tracer) as subtracer:
args = validate_args_and_maybe_create_graph_inputs(
sub_args, subtracer, tx, manually_set_subgraph_inputs
@ -247,7 +248,6 @@ def speculate_subgraph(
"HigherOrderOperator body's output must consist of tensors only"
)
tx.output.guards.update(output.guards)
# The output proxies might not belong to this SubgraphTracer
# (if they are free variables that were never lifted)
# so lift them here.
@ -411,7 +411,6 @@ class CondHigherOrderVariable(TorchHigherOrderOperatorVariable):
f"item but got {str(type(args[0]))} "
f"with original python type {str(args[0].python_type())}.",
)
tx.output.guards.update(args[0].guards)
# operands
if not isinstance(args[3], (ListVariable, TupleVariable)):
@ -1116,6 +1115,7 @@ class AutogradFunctionMethodHigherOrderVariable(TorchHigherOrderOperatorVariable
else:
fn = TorchVariable(self.value)
checkpoint = tx.copy_graphstate()
# TODO(jansel): BUG!!! we aren't copying on the line below, so the post-pre check below is pointless
pre_guards = tx.output.guards
graph_checkpoint = tx.output.graph

View File

@ -23,7 +23,6 @@ class LazyCache:
self.vt.parents_tracker.add(parents_tracker)
del self.value
del self.source
tx.output.guards.update(self.vt.guards)
class LazyVariableTracker(VariableTracker):
@ -79,8 +78,6 @@ class LazyVariableTracker(VariableTracker):
return getattr(self.realize(), item)
# most methods are auto-generated below, these are the ones we want to exclude
add_guards = VariableTracker.add_guards
add_guard = VariableTracker.add_guard
add_options = VariableTracker.add_options
apply = VariableTracker.apply
copy = VariableTracker.copy

View File

@ -48,16 +48,11 @@ class BaseListVariable(VariableTracker):
def __init__(
self,
items: List[VariableTracker],
regen_guards=True,
**kwargs,
):
super().__init__(**kwargs)
assert isinstance(items, list)
assert all(isinstance(x, VariableTracker) for x in items)
# Sometimes, we know that we have passed in the guards from the items in the list
if regen_guards:
self.guards.update(VariableTracker.propagate(items)["guards"])
self.items: List[VariableTracker] = items
def _as_proxy(self):
@ -246,7 +241,6 @@ class CommonListMethodsVariable(BaseListVariable):
self,
type(self)(
self.items + [arg],
regen_guards=False,
**options,
),
)
@ -263,7 +257,6 @@ class CommonListMethodsVariable(BaseListVariable):
self,
type(self)(
list(self.items) + list(arg.unpack_var_sequence(tx)),
regen_guards=False,
**options,
),
)
@ -274,7 +267,7 @@ class CommonListMethodsVariable(BaseListVariable):
items.insert(idx.as_python_constant(), value)
return tx.replace_all(
self,
type(self)(items, regen_guards=False, **options),
type(self)(items, **options),
)
elif name == "pop" and self.mutable_local:
assert not kwargs
@ -282,14 +275,14 @@ class CommonListMethodsVariable(BaseListVariable):
result = items.pop(*[a.as_python_constant() for a in args])
tx.replace_all(
self,
type(self)(items, regen_guards=False, **options),
type(self)(items, **options),
)
return result
elif name == "clear" and self.mutable_local:
assert not kwargs and not args
return tx.replace_all(
self,
type(self)([], regen_guards=False, **options),
type(self)([], **options),
)
elif (
name == "__setitem__"
@ -304,16 +297,14 @@ class CommonListMethodsVariable(BaseListVariable):
items[key.as_python_constant()] = list(value.items)
else:
items[key.as_python_constant()] = value
result = ListVariable(items, regen_guards=False, **options)
result = ListVariable(items, **options)
return tx.replace_all(self, result)
elif name == "copy":
# List copy() doesn't have args and kwargs
assert not kwargs
assert not args
items = list(self.items)
return type(self)(
items, regen_guards=False, mutable_local=MutableLocal(), **options
)
return type(self)(items, mutable_local=MutableLocal(), **options)
else:
return super().call_method(tx, name, args, kwargs)
@ -351,7 +342,7 @@ class ListVariable(CommonListMethodsVariable):
items[key.as_python_constant()] = value.unpack_var_sequence(tx)
else:
items[key.as_python_constant()] = value
result = ListVariable(items, regen_guards=False, **options)
result = ListVariable(items, **options)
return tx.replace_all(self, result)
else:
return super().call_method(tx, name, args, kwargs)
@ -396,7 +387,7 @@ class DequeVariable(CommonListMethodsVariable):
)
items = list(self.items)
items[key.as_python_constant()] = value
result = DequeVariable(items, regen_guards=False, **options)
result = DequeVariable(items, **options)
return tx.replace_all(self, result)
elif name == "extendleft" and self.mutable_local:
assert not kwargs
@ -405,7 +396,6 @@ class DequeVariable(CommonListMethodsVariable):
self,
DequeVariable(
list(arg.unpack_var_sequence(tx)) + list(self.items),
regen_guards=False,
**options,
),
)
@ -416,7 +406,7 @@ class DequeVariable(CommonListMethodsVariable):
result = items.popleft()
tx.replace_all(
self,
DequeVariable(list(items), regen_guards=False, **options),
DequeVariable(list(items), **options),
)
return result
elif name == "appendleft" and self.mutable_local:
@ -425,7 +415,6 @@ class DequeVariable(CommonListMethodsVariable):
self,
DequeVariable(
[args[0]] + list(self.items),
regen_guards=False,
**options,
),
)

View File

@ -13,7 +13,7 @@ import torch._numpy as tnp
from .. import config, polyfill, variables
from ..bytecode_transformation import create_call_function, create_instruction
from ..exc import unimplemented
from ..guards import GuardBuilder
from ..guards import GuardBuilder, install_guard
from ..source import AttrSource, GetItemSource, ODictGetItemSource, TypeSource
from ..utils import (
check_constant_args,
@ -97,9 +97,8 @@ class SuperVariable(VariableTracker):
return GetAttrVariable(self, name, **options)
if source:
options["source"] = source
return variables.ConstantVariable.create(value, **options).add_guard(
source.make_guard(GuardBuilder.CONSTANT_MATCH)
)
install_guard(source.make_guard(GuardBuilder.CONSTANT_MATCH))
return variables.ConstantVariable.create(value, **options)
return variables.ConstantVariable.create(value, **options)
def call_method(

View File

@ -10,7 +10,7 @@ import torch.nn
from .. import skipfiles, variables
from ..allowed_functions import is_allowed
from ..exc import unimplemented, UnspecializeRestartAnalysis, Unsupported
from ..guards import GuardBuilder
from ..guards import GuardBuilder, install_guard
from ..mutation_guard import GenerationTracker
from ..source import (
AttrSource,
@ -127,11 +127,12 @@ class NNModuleVariable(VariableTracker):
options = VariableTracker.propagate(self)
mod = tx.output.get_submodule(self.module_key)
result = hasattr(mod, name)
return variables.ConstantVariable.create(result, **options).add_guard(
install_guard(
NNModuleSource(AttrSource(self.source, name)).make_guard(
GuardBuilder.HASATTR
)
)
return variables.ConstantVariable.create(result, **options)
def is_training(self, tx):
mod = tx.output.get_submodule(self.module_key)
@ -167,7 +168,6 @@ class NNModuleVariable(VariableTracker):
from .builder import VariableBuilder
options = VariableTracker.propagate(self)
guards = options.get("guards", set())
if self.source:
source = AttrSource(self.source, name)
@ -220,13 +220,12 @@ class NNModuleVariable(VariableTracker):
if istype(subobj, property):
return variables.UserFunctionVariable(
subobj.fget,
guards=guards,
source=source,
).call_function(tx, [(self)], {})
elif istype(subobj, classmethod):
return variables.UserMethodVariable(
subobj.__func__,
variables.UserDefinedObjectVariable(type(base), guards=guards),
variables.UserDefinedObjectVariable(type(base)),
**options,
)
elif istype(subobj, staticmethod):
@ -616,7 +615,7 @@ class NNModuleVariable(VariableTracker):
):
# Inline the function
fn = getattr(module, name).__func__
fn_source = AttrSource(self.source, "__func__")
fn_source = AttrSource(AttrSource(self.source, name), "__func__")
options["source"] = fn_source
return tx.inline_user_function_return(
variables.UserFunctionVariable(fn, **options),
@ -759,7 +758,7 @@ class UnspecializedNNModuleVariable(UserDefinedObjectVariable):
assert not args or kwargs
if tx.output.side_effects.has_pending_mutation(self):
unimplemented("Module.parameters() with pending mutation")
options["guards"].add(
install_guard(
self.source.make_guard(GuardBuilder.NN_MODULE_PARAM_NAMES)
)
items = []

View File

@ -4,7 +4,7 @@ from typing import Dict, List
import torch
from ..decorators import mark_static_address
from ..guards import GuardBuilder
from ..guards import GuardBuilder, install_guard
from ..source import AttrSource, GetItemSource, GlobalWeakRefSource
from ..utils import global_key_name
@ -126,13 +126,12 @@ class OptimizerVariable(UserDefinedObjectVariable):
# state guards take a long time to generate
# so we manually generate them here
guards = set()
state_source = AttrSource(self.source, "state")
guards.add(state_source.make_guard(GuardBuilder.DICT_KEYS))
install_guard(state_source.make_guard(GuardBuilder.DICT_KEYS))
for p, value in self.value.state.items():
tx.store_global_weakref(global_key_name(p), p)
p_state_source = GetItemSource(state_source, self.tensor_to_source[p])
guards.add(p_state_source.make_guard(GuardBuilder.DICT_KEYS))
install_guard(p_state_source.make_guard(GuardBuilder.DICT_KEYS))
for k, v in value.items():
if (
isinstance(v, torch.Tensor)
@ -141,7 +140,7 @@ class OptimizerVariable(UserDefinedObjectVariable):
):
self.tensor_to_source[v] = GetItemSource(p_state_source, k)
elif v is None or isinstance(v, (bool, int, float, str)):
guards.add(
install_guard(
GetItemSource(p_state_source, k).make_guard(
GuardBuilder.CONSTANT_MATCH
)
@ -149,12 +148,10 @@ class OptimizerVariable(UserDefinedObjectVariable):
else:
raise GuardInstallException()
tx.output.guards.update(guards)
group_guards = VariableBuilder(tx, AttrSource(self.source, "param_groups"))(
# this next line has the side effect of installing guards
VariableBuilder(tx, AttrSource(self.source, "param_groups"))(
self.value.param_groups
)
tx.output.guards.update(group_guards.guards)
).recursive_realize()
def wrap_tensor(self, tx, tensor_value):
"""Wrap state tensor in a TensorVariable"""

View File

@ -1,4 +1,5 @@
import functools
import inspect
import operator
import types
@ -31,7 +32,7 @@ from .. import config, variables
from .._trace_wrapped_higher_order_op import trace_wrapped
from ..exc import unimplemented, UserError, UserErrorType
from ..guards import GuardBuilder
from ..guards import GuardBuilder, install_guard
from ..source import AttrSource
from ..utils import (
fqn,
@ -206,12 +207,8 @@ class TensorVariable(VariableTracker):
from .builder import VariableBuilder
attr_source = AttrSource(self.source, name)
has_attr_guard = attr_source.make_guard(GuardBuilder.HASATTR)
return (
VariableBuilder(tx, attr_source)(real_value)
.add_options(self)
.add_guard(has_attr_guard)
)
install_guard(attr_source.make_guard(GuardBuilder.HASATTR))
return VariableBuilder(tx, attr_source)(real_value).add_options(self)
def var_getattr(self, tx, name):
from . import ConstantVariable, TorchVariable
@ -254,7 +251,7 @@ class TensorVariable(VariableTracker):
# In some cases, a <tensor>.<attr> guard can be evaluated first, and break if
# <tensor> is later changed to another type
if result is not None and self.source is not None:
result = result.add_guard(self.make_guard(GuardBuilder.TYPE_MATCH))
install_guard(self.make_guard(GuardBuilder.TYPE_MATCH))
# It's hard to get inplace view (metadata mutation) on graph input work properly across
# dynamo/aot/inductor, just fall back.
@ -607,7 +604,6 @@ class TensorVariable(VariableTracker):
unimplemented(
"boolean masking setitem backwards requires dynamic shapes"
)
tx.output.guards.update(options["guards"])
tx.output.create_proxy(
"call_function",
operator.setitem,

View File

@ -8,6 +8,7 @@ import types
from typing import Dict, List
from torch._streambase import _StreamBase
from ..guards import install_guard
try:
import numpy as np
@ -159,10 +160,10 @@ class TorchCtxManagerClassVariable(VariableTracker):
@classmethod
def create_with_source(cls, value, source):
install_guard(source.make_guard(GuardBuilder.FUNCTION_MATCH))
return TorchCtxManagerClassVariable(
value,
source=source,
guards={source.make_guard(GuardBuilder.FUNCTION_MATCH)},
)
def __init__(self, value, **kwargs):
@ -259,6 +260,9 @@ class TorchVariable(VariableTracker):
except RuntimeError as e:
assert "No such operator" in str(e), str(e)
self_should_be_none = None
except AssertionError as e:
assert "Unknown attribute" in str(e), str(e)
self_should_be_none = None
# assert "_ntuple.<locals>.parse" not in str(value)
@ -425,18 +429,18 @@ class TorchVariable(VariableTracker):
return self._call_ntuple(tx, args, kwargs, options)
elif self.value is torch.is_grad_enabled:
assert not (args or kwargs)
return ConstantVariable.create(
torch.is_grad_enabled(), **options
).add_guards(GradModeVariable._guards_singleton)
install_guard(GradModeVariable._guards_singleton)
return ConstantVariable.create(torch.is_grad_enabled(), **options)
elif self.value is torch.use_deterministic_algorithms and len(args) == 1:
return DeterministicAlgorithmsVariable.create(
tx, args[0].as_python_constant(), **options
)
elif self.value is torch.are_deterministic_algorithms_enabled:
assert not (args or kwargs)
install_guard(DeterministicAlgorithmsVariable._guards_singleton)
return ConstantVariable.create(
torch.are_deterministic_algorithms_enabled(), **options
).add_guards(DeterministicAlgorithmsVariable._guards_singleton)
)
elif self.value is torch.autograd.graph.disable_saved_tensors_hooks:
assert len(args) == 1
return DisabledSavedTensorsHooksVariable.create(
@ -444,9 +448,8 @@ class TorchVariable(VariableTracker):
)
elif self.value is torch._C._is_torch_function_enabled:
assert not (args or kwargs)
return ConstantVariable.create(
tx.output.torch_function_enabled, **options
).add_guards(TorchFunctionDisableVariable._guards_singleton)
install_guard(TorchFunctionDisableVariable._guards_singleton)
return ConstantVariable.create(tx.output.torch_function_enabled, **options)
elif self.value in (
torch.overrides.has_torch_function_variadic,
torch.overrides.has_torch_function_unary,

View File

@ -5,6 +5,7 @@ import torch.utils._pytree as pytree
from torch.overrides import _get_overloaded_args, get_default_nowrap_functions
from ..exc import unimplemented
from ..guards import GuardBuilder, install_guard
from ..source import AttrSource, GlobalSource
from ..utils import is_tensor_base_attr_getter
from .base import VariableTracker
@ -133,14 +134,15 @@ class TensorWithTFOverrideVariable(TensorVariable):
kwargs.pop("class_type") is torch.Tensor
), "invalid class type in TensorWithTFOverrideVariable.from_tensor_var"
var = cls(torch_function_fn=torch_function_fn, class_type=class_type, **kwargs)
var.install_global(tx)
return var
def install_global(self, tx):
# stash the subclass type to rewrap an output tensor if needed
# this is needed because the actual type needs to be available
# each time the compiled artifact is run and outputs a wrapped tensor.
if var.global_mangled_class_name() not in tx.output.global_scope:
tx.output.install_global(var.global_mangled_class_name(), class_type)
return var
if self.global_mangled_class_name() not in tx.output.global_scope:
tx.output.install_global(self.global_mangled_class_name(), self.class_type)
def python_type(self):
return self.class_type
@ -157,7 +159,7 @@ class TensorWithTFOverrideVariable(TensorVariable):
# [Note: __torch_function__] We currently only support attributes that are defined on
# base tensors, custom attribute accesses will graph break.
import torch
from .builder import SourcelessBuilder, VariableBuilder
from .builder import SourcelessBuilder
if name in banned_attrs or not hasattr(torch.Tensor, name):
unimplemented(
@ -172,14 +174,11 @@ class TensorWithTFOverrideVariable(TensorVariable):
if tx.output.torch_function_enabled:
if self.source:
get_fn = VariableBuilder(
tx,
source=AttrSource(
AttrSource(AttrSource(self.source, "__class__"), name),
"__get__",
),
)(inspect.getattr_static(self.python_type(), name).__get__)
else:
install_guard(
AttrSource(AttrSource(self.source, "__class__"), name).make_guard(
GuardBuilder.FUNCTION_MATCH
)
)
get_fn = SourcelessBuilder()(tx, getattr(torch.Tensor, name).__get__)
return self.call_torch_function(

View File

@ -17,7 +17,7 @@ from torch._guards import TracingContext
from .. import variables
from ..allowed_functions import is_allowed
from ..exc import unimplemented
from ..guards import GuardBuilder
from ..guards import GuardBuilder, install_guard
from ..source import AttrSource, ODictGetItemSource, RandomValueSource
from ..utils import (
all_hook_names,
@ -266,9 +266,10 @@ class UserDefinedObjectVariable(UserDefinedVariable):
assert not (args or kwargs)
keys = list(self.value.keys())
assert all(map(ConstantVariable.is_literal, keys))
install_guard(self.source.make_guard(GuardBuilder.ODICT_KEYS))
return TupleVariable(
[ConstantVariable.create(k, **options) for k in keys], **options
).add_guard(self.source.make_guard(GuardBuilder.ODICT_KEYS))
)
if (
method in (collections.OrderedDict.__contains__, dict.__contains__)
@ -278,9 +279,10 @@ class UserDefinedObjectVariable(UserDefinedVariable):
in (collections.OrderedDict.keys, dict.keys)
):
assert not kwargs
install_guard(self.source.make_guard(GuardBuilder.ODICT_KEYS))
return ConstantVariable.create(
args[0].as_python_constant() in self.value, **options
).add_guard(self.source.make_guard(GuardBuilder.ODICT_KEYS))
)
if (
method is collections.OrderedDict.items
@ -376,20 +378,15 @@ class UserDefinedObjectVariable(UserDefinedVariable):
)
):
options = VariableTracker.propagate(self, args, kwargs.values())
options.setdefault("guards", set())
if self.source:
options["guards"].add(
AttrSource(self.source, "func").make_guard(GuardBuilder.ID_MATCH)
)
options["guards"].add(
install_guard(
AttrSource(self.source, "func").make_guard(GuardBuilder.ID_MATCH),
AttrSource(self.source, "args").make_guard(
GuardBuilder.CONSTANT_MATCH
)
)
options["guards"].add(
),
AttrSource(self.source, "keywords").make_guard(
GuardBuilder.CONSTANT_MATCH
)
),
)
partial_args = [
@ -410,7 +407,7 @@ class UserDefinedObjectVariable(UserDefinedVariable):
tx, partial_args, partial_kwargs
)
elif callable(self.value):
self.add_guard(self.source.make_guard(GuardBuilder.FUNCTION_MATCH))
install_guard(self.source.make_guard(GuardBuilder.FUNCTION_MATCH))
return self.call_method(tx, "__call__", args, kwargs)
return super().call_function(tx, args, kwargs)
@ -578,7 +575,7 @@ class UserDefinedObjectVariable(UserDefinedVariable):
pass
options = VariableTracker.propagate(self)
if self.source:
options["guards"].add(
install_guard(
AttrSource(self.source, name).make_guard(GuardBuilder.HASATTR)
)
if self._check_for_getattribute() or self._check_for_getattr():

View File

@ -241,7 +241,13 @@ class Guard:
return output
def create(self, builder: GuardBuilderBase):
try:
return self.create_fn(builder, self)
except Exception:
log.error("Error while creating guard:\n%s", str(self).rstrip())
if self.stack:
log.error("Created at:\n%s", "".join(self.stack.format()[-4:]).rstrip())
raise
def is_nn_module(self):
return self.source.is_nn_module()