Compare commits

...

3 Commits

8 changed files with 64 additions and 66 deletions

View File

@ -5378,6 +5378,26 @@ def forward(self, s0 : torch.SymInt, s1 : torch.SymInt, L_x_ : torch.Tensor):
inps = gen_inps(3, 5)
self.assertEqual(g(*inps), opt_g(*inps))
def test_guard_with_tuple_mutation(self):
class Foo:
def __init__(self):
self.x = 10
foo = Foo()
d = {
"a": 2,
"b": (foo,),
}
def fn(x, d):
return x * d["a"] * d["b"][0].x
opt_fn = torch.compile(fn, backend="eager")
inp = torch.randn(3, 3)
self.assertEqual(fn(inp, d), opt_fn(inp, d))
d["b"][0].x = 12
self.assertEqual(fn(inp, d), opt_fn(inp, d))
instantiate_parametrized_tests(ReproTests)

View File

@ -4392,7 +4392,7 @@ graph():
inp = (torch.tensor(6), torch.randn(13))
self.assertTrue(torch.allclose(ep.module()(*inp), M()(*inp)))
@unittest.skip("Test is only supposed to work with non-strict mode")
@testing.expectedFailureTrainingIRToRunDecomp
def test_issue_113041(self):
class TestModule(torch.nn.Module):
def __init__(self):

View File

@ -85,7 +85,6 @@ from .source import (
GlobalStateSource,
GlobalWeakRefSource,
GradSource,
is_unspecialized_builtin_nnmodule_attr,
LocalSource,
NNModuleSource,
NumpyTensorSource,
@ -96,7 +95,6 @@ from .source import (
SubclassAttrListSource,
TupleIteratorGetItemSource,
TypeSource,
UnspecializedBuiltinNNModuleSource,
UnspecializedNNModuleSource,
WeakRefCallSource,
)
@ -867,7 +865,6 @@ class GuardBuilder(GuardBuilderBase):
NNModuleSource,
UnspecializedNNModuleSource,
FSDPNNModuleSource,
UnspecializedBuiltinNNModuleSource,
),
):
assert base_guard_manager # to make mypy happy
@ -1684,6 +1681,13 @@ class GuardBuilder(GuardBuilderBase):
else:
self._produce_guard_code(guard, code)
def EMPTY_NN_MODULE_HOOKS_DICT(self, guard):
"""Special guard to skip guards on empty hooks. This is controlled by skip_nnmodule_hook_guards"""
if config.skip_nnmodule_hook_guards:
# This is unsafe if you add/remove a hook on nn module variable
return
self.SEQUENCE_LENGTH(guard)
def OBJECT_MUTATION(self, guard: Guard):
mutation_guard.watch(self.get(guard.name), self.check_fn_manager)
@ -2112,22 +2116,6 @@ class DeletedGuardFn:
pass
def is_nn_module_hook(source: Source) -> bool:
# Note that we only skip guards on builtin nn modules like Conv2D etc. But still this is a soundness issue if one
# adds/removes a hook after the model is compiled.
return (
is_unspecialized_builtin_nnmodule_attr(source)
and isinstance(source, AttrSource)
and source.member
in (
"_backward_hooks",
"_backward_pre_hooks",
"_forward_hooks",
"_forward_pre_hooks",
)
)
# NB: Naively, you'd expect this to only be a function that produces
# the callable that constitutes the guard. However, there is some
# delicate handling for invalidating this check function when the
@ -2188,11 +2176,6 @@ class CheckFunctionManager:
):
continue
# This is unsafe if you add/remove a hook on unspecialized nn module variable
if config.skip_nnmodule_hook_guards and is_nn_module_hook(
guard.originating_source
):
continue
guard.create(builder)
self.check_fn = self.compile_check_fn(builder, guards, guard_fail_fn)

View File

@ -535,11 +535,6 @@ class UnspecializedNNModuleSource(NNModuleSource):
return _GUARD_SOURCE_NOT_NN_MODULE[self.base.guard_source()]
@dataclasses.dataclass(frozen=True)
class UnspecializedBuiltinNNModuleSource(UnspecializedNNModuleSource):
pass
@dataclasses.dataclass(frozen=True)
class FSDPNNModuleSource(NNModuleSource):
def guard_source(self):
@ -676,9 +671,3 @@ def is_from_defaults(source: Source):
def is_cell_contents(source: Source):
return isinstance(source, AttrSource) and source.member == "cell_contents"
def is_unspecialized_builtin_nnmodule_attr(source: Source):
return isinstance(source, AttrSource) and isinstance(
source.base, UnspecializedBuiltinNNModuleSource
)

View File

@ -80,11 +80,7 @@ from .misc import (
TypingVariable,
UnknownVariable,
)
from .nn_module import (
NNModuleVariable,
UnspecializedBuiltinNNModuleVariable,
UnspecializedNNModuleVariable,
)
from .nn_module import NNModuleVariable, UnspecializedNNModuleVariable
from .optimizer import OptimizerVariable
from .sdpa import SDPAParamsVariable
@ -160,7 +156,6 @@ __all__ = [
"TupleVariable",
"UnknownVariable",
"UnspecializedNNModuleVariable",
"UnspecializedBuiltinNNModuleVariable",
"UnspecializedPythonVariable",
"UntypedStorageVariable",
"UserDefinedClassVariable",

View File

@ -67,7 +67,6 @@ from ..source import (
is_constant_source,
is_from_defaults,
is_from_optimizer_source,
is_unspecialized_builtin_nnmodule_attr,
LocalSource,
NumpyTensorSource,
OptimizerSource,
@ -1087,16 +1086,6 @@ class VariableBuilder:
self.install_guards(GuardBuilder.CONSTANT_MATCH)
return ConstantVariable.create(value=value)
if (
self.source
and is_unspecialized_builtin_nnmodule_attr(self.source)
and type(value) is tuple
and all(ConstantVariable.is_literal(x) for x in value)
):
# Heuristic to speedup up guards coming from conv2d attrs like dilation and padding.
self.install_guards(GuardBuilder.CONSTANT_MATCH)
return TupleVariable([ConstantVariable.create(x) for x in value])
# One can index a tensor with a list/tuple. Therefore, we need to
# have a stricter match.
self.install_guards(GuardBuilder.SEQUENCE_LENGTH)

View File

@ -23,7 +23,6 @@ from ..source import (
FSDPNNModuleSource,
GetItemSource,
NNModuleSource,
UnspecializedBuiltinNNModuleSource,
UnspecializedNNModuleSource,
)
from ..utils import (
@ -1038,6 +1037,29 @@ class UnspecializedNNModuleVariable(UserDefinedObjectVariable):
return dict_vt.maybe_getitem_const(name_vt)
return None
def var_getattr(self, tx, name):
# Allow skipping of empty hook dict guards on inbuilt nn modules
if name in (
"_backward_hooks",
"_backward_pre_hooks",
"_forward_hooks",
"_forward_pre_hooks",
):
if not tx.output.side_effects.has_pending_mutation_of_attr(
self, name
) and self.value.__module__.startswith(("torch.nn.", "torch.ao.")):
hooks_dict = getattr(self.value, name)
if isinstance(hooks_dict, dict) and len(hooks_dict) == 0:
if self.source:
hooks_source = AttrSource(self.source, name)
install_guard(
hooks_source.make_guard(
GuardBuilder.EMPTY_NN_MODULE_HOOKS_DICT
)
)
return variables.ConstDictVariable({})
return super().var_getattr(tx, name)
def manually_trace_nn_module_getattr(self, tx, name):
"""
Dynamo tracing of nn.Module __getattr__ can be expensive if the model
@ -1093,12 +1115,3 @@ class FSDPManagedNNModuleVariable(UnspecializedNNModuleVariable):
value = FSDPManagedNNModuleVariable._wrap_source(value)
return super().__setattr__(name, value)
class UnspecializedBuiltinNNModuleVariable(UnspecializedNNModuleVariable):
# A subclass of UnspecializedNNModuleVariable to differentiate between user-defined and builtin nn modules.
def __setattr__(self, name: str, value: Any) -> None:
if name == "source":
value = UnspecializedBuiltinNNModuleSource(value)
return super().__setattr__(name, value)

View File

@ -836,9 +836,18 @@ std::string get_exception_message() {
}
bool is_immutable_object(py::handle example_value) {
return PyTuple_CheckExact(example_value.ptr()) ||
PyLong_Check(example_value.ptr()) || PyFloat_Check(example_value.ptr()) ||
PyBool_Check(example_value.ptr()) ||
if (PyTuple_Check(example_value.ptr())) {
// Check that each element is immutable
for (Py_ssize_t i = 0; i < PyTuple_Size(example_value.ptr()); ++i) {
if (!is_immutable_object(
py::handle(PyTuple_GetItem(example_value.ptr(), i)))) {
return false;
}
}
return true;
}
return PyLong_Check(example_value.ptr()) ||
PyFloat_Check(example_value.ptr()) || PyBool_Check(example_value.ptr()) ||
PyUnicode_Check(example_value.ptr()) ||
THPVariable_Check(example_value.ptr());
}