Revert "[RELAND][dynamo][nn-modules] Trace through nn.Module dunder methods for UnspecializedNNModule (#126578)"

This reverts commit b2d602306a9eb19e30328cbaee941c874f8148a9.

Reverted https://github.com/pytorch/pytorch/pull/126578 on behalf of https://github.com/clee2000 due to failed internal test D58394084.  Author has forward fix but includes external changes so reverting is a bit easier to coordinate ([comment](https://github.com/pytorch/pytorch/pull/126578#issuecomment-2161481839))
This commit is contained in:
PyTorch MergeBot
2024-06-11 19:41:41 +00:00
parent 45dccfddcd
commit adb699189b
50 changed files with 72 additions and 169 deletions

View File

@ -1084,14 +1084,12 @@ class TestSingleProc(DynamoDistributedSingleProcTestCase):
# far from an exhaustive check of all the expected guards, just check a couple of them.
FileCheck().check("""local "L['self']" TYPE_MATCH""").check(
"""local "L['self']" ID_MATCH"""
).check(f"""{expected_guard_source} "L['self'].net" TYPE_MATCH""").check(
f"""{expected_guard_source} "L['self'].net" ID_MATCH"""
).check(
f"""{expected_guard_source} "L['self']._modules['net']" TYPE_MATCH"""
f"""{expected_guard_source} "L['self'].net[0]" TYPE_MATCH"""
).check(
f"""{expected_guard_source} "L['self']._modules['net']" ID_MATCH"""
).check(
f"""{expected_guard_source} "L['self']._modules['net']._modules['0']" TYPE_MATCH"""
).check(
f"""{expected_guard_source} "L['self']._modules['net']._modules['1']" ID_MATCH"""
f"""{expected_guard_source} "L['self'].net[0]" ID_MATCH"""
).run(
GUARDS_FILE.getvalue()
)

View File

@ -5187,10 +5187,10 @@ class GraphModule(torch.nn.Module):
actual,
"""\
class GraphModule(torch.nn.Module):
def forward(self, L_self_buffers_tensor_constant0_: "f32[3, 3, 3]"):
l_self_buffers_tensor_constant0_ = L_self_buffers_tensor_constant0_
def forward(self, L_self_tensor_constant0: "f32[3, 3, 3]"):
l_self_tensor_constant0 = L_self_tensor_constant0
alias_default: "f32[3, 3, 3]" = torch.ops.aten.alias.default(l_self_buffers_tensor_constant0_); l_self_buffers_tensor_constant0_ = None
alias_default: "f32[3, 3, 3]" = torch.ops.aten.alias.default(l_self_tensor_constant0); l_self_tensor_constant0 = None
sin_default: "f32[3, 3, 3]" = torch.ops.aten.sin.default(alias_default)
@ -5209,16 +5209,16 @@ class GraphModule(torch.nn.Module):
actual,
"""\
class GraphModule(torch.nn.Module):
def forward(self, L_self_modules_FX_CONST_FOLDED_ATTRS_parameters_0_: "f32[3, 3, 3]", L_self_modules_FX_CONST_FOLDED_ATTRS_parameters_1_: "f32[3, 3, 3]", L_flat_tangents_1_: "f32[3, 3, 3]"):
l_self_modules_fx_const_folded_attrs_parameters_0_ = L_self_modules_FX_CONST_FOLDED_ATTRS_parameters_0_
l_self_modules_fx_const_folded_attrs_parameters_1_ = L_self_modules_FX_CONST_FOLDED_ATTRS_parameters_1_
def forward(self, getattr_L_self_FX_CONST_FOLDED_ATTRS_0_: "f32[3, 3, 3]", getattr_L_self_FX_CONST_FOLDED_ATTRS_1_: "f32[3, 3, 3]", L_flat_tangents_1_: "f32[3, 3, 3]"):
getattr_l_self_fx_const_folded_attrs_0_ = getattr_L_self_FX_CONST_FOLDED_ATTRS_0_
getattr_l_self_fx_const_folded_attrs_1_ = getattr_L_self_FX_CONST_FOLDED_ATTRS_1_
l_flat_tangents_1_ = L_flat_tangents_1_
_new_zeros_with_same_feature_meta_default: "f32[3, 3, 3]" = torch.ops.aten._new_zeros_with_same_feature_meta.default(l_flat_tangents_1_, l_self_modules_fx_const_folded_attrs_parameters_0_); l_self_modules_fx_const_folded_attrs_parameters_0_ = None
_new_zeros_with_same_feature_meta_default: "f32[3, 3, 3]" = torch.ops.aten._new_zeros_with_same_feature_meta.default(l_flat_tangents_1_, getattr_l_self_fx_const_folded_attrs_0_); getattr_l_self_fx_const_folded_attrs_0_ = None
copy__default: "f32[3, 3, 3]" = torch.ops.aten.copy_.default(_new_zeros_with_same_feature_meta_default, l_flat_tangents_1_); _new_zeros_with_same_feature_meta_default = l_flat_tangents_1_ = None
mul_tensor: "f32[3, 3, 3]" = torch.ops.aten.mul.Tensor(copy__default, l_self_modules_fx_const_folded_attrs_parameters_1_); copy__default = l_self_modules_fx_const_folded_attrs_parameters_1_ = None
mul_tensor: "f32[3, 3, 3]" = torch.ops.aten.mul.Tensor(copy__default, getattr_l_self_fx_const_folded_attrs_1_); copy__default = getattr_l_self_fx_const_folded_attrs_1_ = None
return (mul_tensor,)
""",
)

View File

@ -2411,7 +2411,6 @@ aten::mm""",
num_matched.append(len(pattern.matched_events()))
self.assertEqual(num_matched, [i for i, _ in cases])
@skipIfTorchDynamo("profiler gets ignored if dynamo activated")
def test_profiler_pattern_matcher_json_report(self):
x = torch.ones((100, 100))
model = nn.Sequential(

View File

@ -1,7 +1,4 @@
# mypy: allow-untyped-defs
import threading
from contextlib import contextmanager
import torch
doc = """
@ -40,20 +37,3 @@ def new_parameter_placeholder(size, dtype, device, requires_grad):
# Allocating a zero tensor would causes assert failures in autograd.
result.untyped_storage().resize_(0)
return result
_TLS = threading.local()
@contextmanager
def do_not_convert_to_tracable_parameter():
old_flag = getattr(_TLS, "convert_tracable_parameter", True)
_TLS.convert_tracable_parameter = False
try:
yield False
finally:
_TLS.convert_tracable_parameter = old_flag
def can_convert_to_tracable_parameter():
return getattr(_TLS, "convert_tracable_parameter", True)

View File

@ -11,9 +11,6 @@ from . import config
from .utils import ExactWeakKeyDictionary, is_lazy_module, nn_module_has_global_hooks
unpatched_nn_module_init = torch.nn.Module.__init__
class MutationTracker:
db = ExactWeakKeyDictionary()

View File

@ -347,7 +347,13 @@ class SideEffects:
elif isinstance(var.mutable_local, AttributeMutationNew):
if isinstance(var, variables.AutogradFunctionContextVariable):
unimplemented("AutogradFunctionContextVariable escaped")
cg.load_import_from(utils.__name__, "object_new")
if "__call_nn_module_init" in self.store_attr_mutations.get(
var.mutable_local, {}
):
assert isinstance(var, variables.UnspecializedNNModuleVariable)
cg.load_import_from(utils.__name__, "nn_module_new")
else:
cg.load_import_from(utils.__name__, "object_new")
cg(var.mutable_local.cls_source)
cg.extend_output(create_call_function(1, True))
cg.add_cache(var)
@ -474,25 +480,9 @@ class SideEffects:
]
)
elif self.is_attribute_mutation(var):
# Applying mutations involves two steps: 1) Push all
# reconstructed objects onto the stack. 2) Call STORE_ATTR to
# apply the mutations.
#
# Dynamo must ensure that mutations are applied in the same
# order as in the original program. Therefore, two reverse
# operations occur below.
#
# The first reverse operation concerns `suffixes`. We apply
# suffixes in reverse order due to the way Python handles the
# stack. In Step 1, we push all reconstructed objects onto the
# stack, but the item at the top of the stack refers to the last
# attribute in the mutation order. If not fixed, this will apply
# the mutations of attributes in the reverse order. To account
# for this reversal, we iterate through the mutable attributes
# in reverse order.
for name, value in reversed(
self.store_attr_mutations.get(var.mutable_local, {}).items()
):
for name, value in self.store_attr_mutations.get(
var.mutable_local, {}
).items():
if isinstance(var, variables.NewGlobalVariable):
cg.tx.output.update_co_names(name)
cg(value)
@ -500,6 +490,8 @@ class SideEffects:
suffixes.append(
[create_instruction("STORE_GLOBAL", argval=name)]
)
elif name == "__call_nn_module_init":
pass # handled in codegen_save_tempvars
elif isinstance(value, variables.DeletedVariable):
if isinstance(
var.mutable_local, AttributeMutationExisting

View File

@ -416,15 +416,10 @@ def generic_jump(truth_fn: typing.Callable[[object], bool], push: bool):
self.push(value)
self.jump(inst)
elif isinstance(value, UserDefinedObjectVariable):
try:
x = value.var_getattr(self, "__bool__")
except exc.ObservedException:
# if __bool__ is missing, trying __len__ to infer a truth value.
x = value.var_getattr(self, "__bool__")
# if __bool__ is missing, trying __len__ to infer a truth value.
if isinstance(x, GetAttrVariable):
x = value.var_getattr(self, "__len__")
else:
if isinstance(x, GetAttrVariable):
# if __bool__ is missing, trying __len__ to infer a truth value.
x = value.var_getattr(self, "__len__")
# __bool__ or __len__ is function
if isinstance(x, UserMethodVariable):

View File

@ -2019,12 +2019,12 @@ def object_has_getattribute(value: Any):
return False
def get_custom_getattr(value: Any, ignore_nn_module_getattr: bool = False):
def get_custom_getattr(value: Any):
try:
getattr_fn = inspect.getattr_static(type(value), "__getattr__")
except AttributeError:
getattr_fn = None
if ignore_nn_module_getattr and getattr_fn is torch.nn.Module.__getattr__:
if getattr_fn is torch.nn.Module.__getattr__:
# ignore this case of getattr
getattr_fn = None
return getattr_fn

View File

@ -174,11 +174,7 @@ class ConstDictVariable(VariableTracker):
def __contains__(self, vt):
assert isinstance(vt, VariableTracker)
Hashable = ConstDictVariable._HashableTracker
return (
is_hashable(vt)
and Hashable(vt) in self.items
and not isinstance(self.items[Hashable(vt)], variables.DeletedVariable)
)
return is_hashable(vt) and Hashable(vt) in self.items
def reconstruct(self, codegen):
# instructions to load collections.OrderedDict if necessary

View File

@ -14,10 +14,8 @@ import torch._numpy as tnp
import torch.utils._pytree as pytree
from .. import config, variables
from ..bytecode_transformation import create_call_function, create_instruction
from ..create_parameter_op import do_not_convert_to_tracable_parameter
from ..exc import unimplemented
from ..guards import GuardBuilder, install_guard
from ..mutation_guard import unpatched_nn_module_init
from ..source import AttrSource, GetItemSource, ODictGetItemSource, TypeSource
from ..utils import (
check_unspec_or_constant_args,
@ -123,6 +121,7 @@ class SuperVariable(VariableTracker):
kwargs: "Dict[str, VariableTracker]",
) -> "VariableTracker":
inner_fn, source = self._resolved_getattr_and_source(self, name)
if inner_fn is object.__init__:
return LambdaVariable(identity)
elif inner_fn is torch.nn.Module.__init__:
@ -134,10 +133,12 @@ class SuperVariable(VariableTracker):
and isinstance(objvar.mutable_local, AttributeMutationNew)
and not (args or kwargs)
):
with do_not_convert_to_tracable_parameter():
return variables.UserFunctionVariable(
unpatched_nn_module_init, source=source
).call_function(tx, [self.objvar] + args, kwargs)
tx.output.side_effects.store_attr(
objvar,
"__call_nn_module_init",
variables.ConstantVariable.create(True),
)
return variables.ConstantVariable.create(None)
else:
unimplemented("super() nn.Module.__init__")
elif isinstance(inner_fn, types.FunctionType):
@ -174,19 +175,6 @@ class SuperVariable(VariableTracker):
self.objvar, UserDefinedObjectVariable
):
return self.objvar.method_setattr_standard(tx, *args, **kwargs)
elif inner_fn is object.__delattr__:
attr = args[0]
try:
attr = attr.as_python_constant()
except NotImplementedError:
unimplemented(f"non-const delattr attr: {attr}")
if not tx.output.side_effects.is_attribute_mutation(self.objvar):
unimplemented(f"delattr({self.objvar}, {attr}, ...)")
tx.output.side_effects.store_attr(
self.objvar, attr, variables.DeletedVariable()
)
return variables.ConstantVariable(None)
unimplemented(f"non-function or method super: {inner_fn}")

View File

@ -215,7 +215,7 @@ class NNModuleVariable(VariableTracker):
if object_has_getattribute(base):
unimplemented("torch.nn.Module with a custom __getattribute__ defined")
getattr_fn = get_custom_getattr(base, ignore_nn_module_getattr=True)
getattr_fn = get_custom_getattr(base)
if getattr_fn is None:
return None
@ -665,6 +665,7 @@ class NNModuleVariable(VariableTracker):
if isinstance(args[0], SliceVariable):
# Build a TupleVariable of NNModules
result = []
submods = []
# Turn the slice into the list of integers
keys = list(range(len(module)))[args[0].as_python_constant()]
@ -678,8 +679,9 @@ class NNModuleVariable(VariableTracker):
source=src,
)
)
submods.append(submod)
new_module = module[args[0].as_python_constant()]
new_module = torch.nn.Sequential(*submods)
new_module_variable = tx.output.register_attr_or_module(
new_module,
f"{self}.__getitem__(slice)",
@ -693,10 +695,8 @@ class NNModuleVariable(VariableTracker):
if isinstance(args[0], SymNodeVariable):
key = args[0].evaluate_expr(tx.output)
elif args[0].is_python_constant():
key = args[0].as_python_constant()
else:
unimplemented(f"getitem on NNModuleVariable with key {args[0]}")
key = args[0].as_python_constant()
submod = module[key]
return tx.output.register_attr_or_module(
@ -790,7 +790,7 @@ class UnspecializedNNModuleVariable(UserDefinedObjectVariable):
@functools.lru_cache(None)
def _nn_module_method_ids():
# Allow __setattr__ to fall through to base class handler
supported = {torch.nn.Module.__setattr__, torch.nn.Module.__init__}
supported = {torch.nn.Module.__setattr__}
return {
id(x.__code__)
for x in torch.nn.Module.__dict__.values()
@ -798,6 +798,8 @@ class UnspecializedNNModuleVariable(UserDefinedObjectVariable):
}
def unpack_var_sequence(self, tx):
from .builder import VariableBuilder
try:
fn = inspect.getattr_static(self.value_type, "__iter__")
except AttributeError as e:
@ -808,16 +810,11 @@ class UnspecializedNNModuleVariable(UserDefinedObjectVariable):
torch.nn.ParameterList.__iter__,
torch.nn.Sequential.__iter__,
):
# The program can mutate the nn module object but the saved `value`
# will not reflect the mutations. So, trace through the `__iter__`
# function to reflect any tracked mutations.
return tx.inline_user_function_return(
variables.UserFunctionVariable(fn),
[
self,
],
{},
).unpack_var_sequence(tx)
assert self.source
return [
VariableBuilder(tx, source=GetItemSource(self.source, idx))(item)
for idx, item in enumerate(self.value)
]
return super().unpack_var_sequence(tx)
@ -946,17 +943,6 @@ class UnspecializedNNModuleVariable(UserDefinedObjectVariable):
# Handle submodules
self.is_state_mutated = True
if method is torch.nn.Module.__setattr__ and isinstance(
args[1], variables.DeletedVariable
):
# Trace through __delattr__ to track mutations on the module
# members like `_modules``.
return tx.inline_user_function_return(
variables.UserFunctionVariable(torch.nn.Module.__delattr__),
[self, args[0]],
kwargs,
)
return super().call_method(tx, name, args, kwargs)

View File

@ -18,11 +18,7 @@ from torch._streambase import _StreamBase
from ..._guards import TracingContext
from .. import config, polyfill, variables
from ..codegen import PyCodegen
from ..create_parameter_op import (
can_convert_to_tracable_parameter,
new_parameter_placeholder,
tracable_create_parameter,
)
from ..create_parameter_op import new_parameter_placeholder, tracable_create_parameter
from ..device_interface import get_registered_device_interfaces
from ..exc import unimplemented
from ..guards import GuardBuilder, install_guard
@ -875,9 +871,6 @@ Either create the tensor outside the compiled region, or do not set the tensor t
if data.source:
return cls._nn_param_via_prefix_insert(tx, data, requires_grad)
if not can_convert_to_tracable_parameter():
unimplemented("Workaround for issues with nn_parameter construction")
try:
shape = tuple(data.var_getattr(tx, "shape").as_python_constant())
dtype = data.var_getattr(tx, "dtype").as_python_constant()

View File

@ -34,8 +34,7 @@ import torch.nn
from torch._guards import TracingContext
from .. import variables
from ..create_parameter_op import do_not_convert_to_tracable_parameter
from ..exc import ObservedException, unimplemented
from ..exc import unimplemented
from ..guards import GuardBuilder, install_guard
from ..source import AttrSource, GetItemSource, ODictGetItemSource, RandomValueSource
from ..utils import (
@ -58,7 +57,10 @@ from .dicts import DefaultDictVariable
def is_standard_setattr(val):
return val in (object.__setattr__,)
return val in (
object.__setattr__,
torch.nn.Module.__setattr__,
)
class UserDefinedVariable(VariableTracker):
@ -376,7 +378,17 @@ class UserDefinedClassVariable(UserDefinedVariable):
else UserDefinedObjectVariable,
{},
)
with do_not_convert_to_tracable_parameter():
if (
inspect.getattr_static(self.value, "__init__", None)
is torch.nn.Module.__init__
):
tx.output.side_effects.store_attr(
var,
"__call_nn_module_init",
variables.ConstantVariable.create(True),
)
return var
else:
var.call_method(tx, "__init__", args, kwargs)
return var
elif variables.CustomizedDictVariable.is_matching_cls(self.value):
@ -626,10 +638,6 @@ class UserDefinedObjectVariable(UserDefinedVariable):
else AttrSource(AttrSource(self.source, "__class__"), name)
)
# TODO(jansel): add a guard to check for monkey patching?
from ..mutation_guard import unpatched_nn_module_init
if method is torch.nn.Module.__init__:
method = unpatched_nn_module_init
return UserMethodVariable(method, self, source=source).call_function(
tx, args, kwargs
)
@ -791,7 +799,7 @@ class UserDefinedObjectVariable(UserDefinedVariable):
def _getattr_static(self, name):
if (
isinstance(self.value, PyTreeSpec)
isinstance(self.value, (torch.nn.Module, PyTreeSpec))
or "__slots__" in self.value.__class__.__dict__
or type(self.value) == threading.local
):
@ -804,6 +812,7 @@ class UserDefinedObjectVariable(UserDefinedVariable):
return cls_var
except AttributeError:
pass # __slots__
# this might call torch.nn.Module.__getattr__
subobj = getattr(self.value, name)
else:
subobj = inspect.getattr_static(self.value, name)
@ -1009,35 +1018,14 @@ class UserDefinedObjectVariable(UserDefinedVariable):
install_guard(
AttrSource(self.source, name).make_guard(GuardBuilder.HASATTR)
)
if self._check_for_getattribute():
unimplemented("hasattr with custom __getattribute__")
if self._check_for_getattribute() or self._check_for_getattr():
unimplemented("hasattr with custom __getattr__")
try:
self._getattr_static(name)
return variables.ConstantVariable.create(True)
except AttributeError:
# Now check in __getattr__ function
getattr_fn = self._check_for_getattr()
if isinstance(getattr_fn, types.FunctionType):
# Dynamo is going to trace the __getattr__ function with
# args=name. Set the source accordingly.
new_source = None
if self.source:
new_source = AttrSource(self.source, "__getattr__")
try:
result = variables.UserMethodVariable(
getattr_fn, self, source=new_source
).call_function(tx, [variables.ConstantVariable.create(name)], {})
return variables.ConstantVariable.create(
not isinstance(result, variables.DeletedVariable)
)
except ObservedException:
return variables.ConstantVariable.create(False)
elif getattr_fn is None:
return variables.ConstantVariable.create(False)
else:
unimplemented("UserDefined with non-function __getattr__")
return variables.ConstantVariable.create(False)
def odict_getitem(self, tx, key):
from .builder import VariableBuilder
@ -1104,12 +1092,6 @@ class KeyedJaggedTensorVariable(UserDefinedObjectVariable):
return super().var_getattr(tx, name)
class RemovableHandleClass:
# Dummy class to pass to python_type of RemovableHandleVariable
# Useful for isinstance check on hooks
pass
class RemovableHandleVariable(VariableTracker):
REMOVED = -1
@ -1140,6 +1122,3 @@ class RemovableHandleVariable(VariableTracker):
return
# unreachable due to codegen.add_cache() when the hook is installed
super().reconstruct(codegen)
def python_type(self):
return RemovableHandleClass