Revert "[dynamo][guards] Make class members go through obj.__class__.__dict__ (#159534)"

This reverts commit 1616777cd2a3170ff76afa3e7860b0969420c445.

Reverted https://github.com/pytorch/pytorch/pull/159534 on behalf of https://github.com/malfet due to Broke some inductor test and lint among other things, see 9c18901bfd/1 ([comment](https://github.com/pytorch/pytorch/pull/159534#issuecomment-3146983186))
This commit is contained in:
PyTorch MergeBot
2025-08-03 04:58:32 +00:00
parent 6e8d705a22
commit 805a102beb
9 changed files with 27 additions and 290 deletions

View File

@ -880,9 +880,8 @@ num_guards_executed=0)
counter += 1
class Bar:
def __init__(self):
self.x = 4
self.y = torch.randn(4)
x = 4
y = torch.randn(4)
bar = Bar()

View File

@ -54,9 +54,8 @@ class RunDiffGuardTests(torch._dynamo.test_case.TestCase):
def test_post_recompile(self):
class Foo:
def __init__(self):
self.a = 4
self.b = 5
a = 4
b = 5
foo = Foo()

View File

@ -139,8 +139,6 @@ class DictGuardManager(GuardManager):
class GuardAccessor: ...
class DictGetItemGuardAccessor(GuardAccessor): ...
class GetGenericDictGuardAccessor(GuardAccessor): ...
class TypeDictGuardAccessor(GuardAccessor): ...
class TypeMROGuardAccessor(GuardAccessor): ...
def install_object_aliasing_guard(
guard_managers: list[GuardManager],

View File

@ -132,8 +132,6 @@ from .source import (
TorchFunctionModeStackSource,
TorchSource,
TupleIteratorGetItemSource,
TypeDictSource,
TypeMROSource,
TypeSource,
UnspecializedBuiltinNNModuleSource,
UnspecializedNNModuleSource,
@ -864,9 +862,6 @@ class GuardBuilder(GuardBuilderBase):
self.guard_nn_modules = config.guard_nn_modules and justknobs_check(
"pytorch/compiler:guard_nn_modules"
)
self.already_guarded_not_present_in_generic_dict: OrderedSet[
tuple[str, str]
] = OrderedSet()
def guard_on_dict_keys_and_ignore_order(self, example_value, guard):
dict_mgr = self.get_guard_manager(guard)
@ -1214,20 +1209,6 @@ class GuardBuilder(GuardBuilderBase):
example_value=example_value,
guard_manager_enum=guard_manager_enum,
)
elif istype(source, TypeDictSource):
assert base_guard_manager # to make mypy happy
out = base_guard_manager.type_dict_manager(
source=source_name,
example_value=example_value,
guard_manager_enum=guard_manager_enum,
)
elif istype(source, TypeMROSource):
assert base_guard_manager # to make mypy happy
out = base_guard_manager.type_mro_manager(
source=source_name,
example_value=example_value,
guard_manager_enum=guard_manager_enum,
)
elif istype(
source,
(
@ -1653,12 +1634,10 @@ class GuardBuilder(GuardBuilderBase):
assert attr is not None
ref = self.arg_ref(guard)
val = self.get(guard.name)
assert isinstance(val, torch.nn.Module)
base_manager = self.get_guard_manager(guard)
if (ref, attr) in self.already_guarded_not_present_in_generic_dict:
return
mod_dict_source = f"{guard.name}.__dict__"
mod_generic_dict_manager = base_manager.get_generic_dict_manager(
source=mod_dict_source,
@ -1670,7 +1649,6 @@ class GuardBuilder(GuardBuilderBase):
mod_generic_dict_manager.add_dict_contains_guard(
False, attr, get_verbose_code_parts(code, guard)
)
self.already_guarded_not_present_in_generic_dict.add((ref, attr))
def TYPE_MATCH(self, guard: Guard) -> None:
# ___check_type_id is same as `id(type(x)) == y`

View File

@ -266,38 +266,6 @@ class GenericAttrSource(ChainedSource):
return f"object.__getattribute__({self.base.name()}, {self.member!r})"
# Represents obj.__dict__ where obj is a type object
@dataclasses.dataclass(frozen=True)
class TypeDictSource(ChainedSource):
def reconstruct(self, codegen: "PyCodegen"):
codegen(self.base)
codegen.extend_output(codegen.create_load_attrs("__dict__"))
def guard_source(self):
return self.base.guard_source()
def name(self):
# type(ob).__dict__ can return a proxy of the dict. But in the C++
# guard accessor, we are use type->tp_dict which is a dict. So,
# forcefully pass a dict object to ensure that the GuardManager
# registers that its working on a dict object.
return f"dict({self.base.name()}.__dict__)"
# Represents obj.__mro__ where object is type object
@dataclasses.dataclass(frozen=True)
class TypeMROSource(ChainedSource):
def reconstruct(self, codegen: "PyCodegen"):
codegen(self.base)
codegen.extend_output(codegen.create_load_attrs("__mro__"))
def guard_source(self):
return self.base.guard_source()
def name(self):
return f"{self.base.name()}.__mro__"
@dataclasses.dataclass(frozen=True)
class LocalCellSource(Source):
"""

View File

@ -42,7 +42,6 @@ from ..source import (
AttrSource,
GenericAttrSource,
GetItemSource,
TypeMROSource,
TypeSource,
WeakRefCallSource,
)
@ -135,7 +134,9 @@ class SuperVariable(VariableTracker):
# Equivalent of something like type(L['self']).__mro__[1].attr_name
if type_to_use_source:
source = AttrSource(
GetItemSource(TypeMROSource(type_to_use_source), index),
GetItemSource(
AttrSource(type_to_use_source, "__mro__"), index
),
name,
)
return resolved_getattr, source
@ -246,7 +247,7 @@ class SuperVariable(VariableTracker):
# different from type(self) with polymorphism.
cls_source = None
if self.objvar.source:
cls_source = TypeSource(self.objvar.source)
cls_source = AttrSource(self.objvar.source, "__class__")
cls_variable = VariableTracker.build(
tx, self.objvar.value_type, cls_source
)

View File

@ -989,7 +989,7 @@ class UnspecializedNNModuleVariable(UserDefinedObjectVariable):
fn = self.value_type.forward
if self.source:
source = self.get_source_by_walking_mro(name)
source = AttrSource(AttrSource(self.source, "__class__"), name)
else:
source = None
@ -1017,7 +1017,7 @@ class UnspecializedNNModuleVariable(UserDefinedObjectVariable):
if name in ["_call_impl", "_wrapped_call_impl"]:
fn = getattr(self.value_type, name)
if self.source:
source = self.get_source_by_walking_mro(name)
source = AttrSource(AttrSource(self.source, "__class__"), name)
else:
source = None
@ -1032,7 +1032,9 @@ class UnspecializedNNModuleVariable(UserDefinedObjectVariable):
method = None
if isinstance(method, staticmethod):
source = AttrSource(self.get_source_by_walking_mro(name), "__func__")
source = AttrSource(
AttrSource(AttrSource(self.source, "__class__"), name), "__func__"
)
return tx.inline_user_function_return(
variables.UserFunctionVariable(method.__func__, source=source),
args,

View File

@ -60,11 +60,8 @@ from ..source import (
AttrSource,
CallFunctionNoArgsSource,
DataclassFieldsSource,
DictGetItemSource,
GetItemSource,
RandomValueSource,
TypeDictSource,
TypeMROSource,
TypeSource,
UnspecializedParamBufferSource,
)
@ -1001,9 +998,11 @@ class UserDefinedObjectVariable(UserDefinedVariable):
# check for methods implemented in C++
if isinstance(method, types.FunctionType):
source = None
if self.source:
source = self.get_source_by_walking_mro(name)
source = (
None
if self.source is None
else AttrSource(AttrSource(self.source, "__class__"), name)
)
# TODO(jansel): add a guard to check for monkey patching?
from ..mutation_guard import unpatched_nn_module_init
@ -1225,40 +1224,12 @@ class UserDefinedObjectVariable(UserDefinedVariable):
for idx, klass in enumerate(type(self.value).__mro__):
if name in klass.__dict__:
if idx != 0:
mro_source = TypeMROSource(self.cls_source)
klass_source = GetItemSource(mro_source, idx)
else:
klass_source = self.cls_source
dict_source = TypeDictSource(klass_source)
out_source = DictGetItemSource(dict_source, name)
for absent_idx in range(1, idx):
# Insert a guard that the name is not present in the mro hierarchy
mro_source = TypeMROSource(self.cls_source)
klass_source = GetItemSource(mro_source, absent_idx)
dict_source = TypeDictSource(klass_source)
install_guard(
dict_source.make_guard(
functools.partial(
GuardBuilder.DICT_CONTAINS, key=name, invert=True
)
)
)
# Insert a guard that the name is not present in the object __dict__
if (
self.source
and hasattr(self.value, "__dict__")
and name not in self.value.__dict__
):
install_guard(
self.source.make_guard(
functools.partial(
GuardBuilder.NOT_PRESENT_IN_GENERIC_DICT, attr=name
)
)
)
return out_source
mro_source = AttrSource(self.cls_source, "__mro__")
klass_source = GetItemSource(mro_source, idx)
dict_source = AttrSource(klass_source, "__dict__")
# TODO(anijain2305) - This is a mapping proxy object. Ideally we
# should use DictGetItemSource here.
return GetItemSource(dict_source, name)
unimplemented_v2(
gb_type="could not find name in object's mro",
@ -1368,17 +1339,10 @@ class UserDefinedObjectVariable(UserDefinedVariable):
if subobj is torch.nn.Module.__init__:
subobj = unpatched_nn_module_init
subobj_from_class = inspect.getattr_static(
self.value.__class__, name, NO_SUCH_SUBOBJ
)
is_accessible_from_type_mro = (
subobj_from_class is subobj and self.cls_source is not None
)
if isinstance(subobj, property):
if self.source:
# Read the class attribute to reach the property
source = self.get_source_by_walking_mro(name)
source = AttrSource(AttrSource(self.source, "__class__"), name)
# Get the getter function
source = AttrSource(source, "fget")
return variables.UserMethodVariable(
@ -1396,11 +1360,6 @@ class UserDefinedObjectVariable(UserDefinedVariable):
# Safe because `staticmethod.__get__` basically won't trigger user
# code and just returns the underlying `__func__`:
# https://github.com/python/cpython/blob/3.11/Objects/funcobject.c#L1088-L1100
if is_accessible_from_type_mro:
# Accessing from __dict__ does not resolve the descriptor, it
# returns a staticmethod object, so access the __func__
# attribute to get to the actual function.
source = AttrSource(self.get_source_by_walking_mro(name), "__func__")
func = subobj.__get__(self.value)
return VariableTracker.build(tx, func, source)
elif isinstance(subobj, classmethod):
@ -1526,15 +1485,10 @@ class UserDefinedObjectVariable(UserDefinedVariable):
source = self._wrap_source(source)
if subobj is not NO_SUCH_SUBOBJ:
if is_wrapper_or_member_descriptor(
subobj
) or torch._C._dynamo.utils.is_instancemethod(subobj):
if is_wrapper_or_member_descriptor(subobj):
options = {"source": source}
return variables.GetAttrVariable(self, name, **options)
if source:
if is_accessible_from_type_mro:
source = self.get_source_by_walking_mro(name)
return variables.LazyVariableTracker.create(subobj, source)
else:
# Check if the subobj is accessible from the class itself. If the class source is known, we can create a

View File

@ -5511,118 +5511,6 @@ class TypeGuardAccessor : public GuardAccessor {
void clone_visitor(TypeGuardAccessor* to) {}
};
/**
* Represent x.__dict__ accessor, where x is type object.
*/
class TypeDictGuardAccessor : public GuardAccessor {
public:
// name = __type_dict_accessor__, a unique string used as attribute name.
TypeDictGuardAccessor(
RootGuardManager* root,
py::str name,
std::string source,
py::handle example_value,
py::handle guard_manager_enum)
: GuardAccessor(
root,
std::move(name),
std::move(source),
example_value,
guard_manager_enum) {}
// NB: Intentional duplication between check_nopybind and
// check_verbose_nopybind.
bool check_nopybind(PyObject* obj, bool matches_dict_tag = false)
override { // borrowed ref
PyObject* x = ((PyTypeObject*)obj)->tp_dict; // borrowed ref
if (x == nullptr) {
return false;
}
return _guard_manager->check_nopybind(x);
}
GuardDebugInfo check_verbose_nopybind(
PyObject* obj) override { // borrowed ref
PyObject* x = ((PyTypeObject*)obj)->tp_dict; // borrowed ref
if (x == nullptr) {
return GuardDebugInfo(false, "null type dict on " + repr(), 0);
}
return _guard_manager->check_verbose_nopybind(x);
}
std::string repr() const override {
return "TypeDictGuardAccessor";
}
public: // cloning functions
TypeDictGuardAccessor(
GuardManager* guard_manager,
TypeDictGuardAccessor* from)
: GuardAccessor(guard_manager, from) {
from->clone_visitor(this);
}
GuardAccessor* clone(
RootGuardManager* cloned_root,
const py::function& clone_filter_fn) override {
return clone_common<TypeDictGuardAccessor>(cloned_root, clone_filter_fn);
}
void clone_visitor(TypeDictGuardAccessor* to) {}
};
/**
* Represent x.__mro__ accessor, where x is type object.
*/
class TypeMROGuardAccessor : public GuardAccessor {
public:
// name = __type_mro_accessor__, a unique string used as attribute name.
TypeMROGuardAccessor(
RootGuardManager* root,
py::str name,
std::string source,
py::handle example_value,
py::handle guard_manager_enum)
: GuardAccessor(
root,
std::move(name),
std::move(source),
example_value,
guard_manager_enum) {}
// NB: Intentional duplication between check_nopybind and
// check_verbose_nopybind.
bool check_nopybind(PyObject* obj, bool matches_dict_tag = false)
override { // borrowed ref
PyObject* x = ((PyTypeObject*)obj)->tp_mro; // borrowed ref
return _guard_manager->check_nopybind(x);
}
GuardDebugInfo check_verbose_nopybind(
PyObject* obj) override { // borrowed ref
PyObject* x = ((PyTypeObject*)obj)->tp_mro; // borrowed ref
return _guard_manager->check_verbose_nopybind(x);
}
std::string repr() const override {
return "TypeMROGuardAccessor";
}
public: // cloning functions
TypeMROGuardAccessor(GuardManager* guard_manager, TypeMROGuardAccessor* from)
: GuardAccessor(guard_manager, from) {
from->clone_visitor(this);
}
GuardAccessor* clone(
RootGuardManager* cloned_root,
const py::function& clone_filter_fn) override {
return clone_common<TypeMROGuardAccessor>(cloned_root, clone_filter_fn);
}
void clone_visitor(TypeMROGuardAccessor* to) {}
};
/**
* Getitem tuple_iterator accessor.
*/
@ -6545,16 +6433,6 @@ PyObject* torch_c_dynamo_guards_init() {
GuardAccessor,
std::unique_ptr<TypeGuardAccessor>>(py_m, "TypeGuardAccessor");
// NOLINTNEXTLINE(bugprone-unused-raii)
py::class_<
TypeDictGuardAccessor,
GuardAccessor,
std::unique_ptr<TypeDictGuardAccessor>>(py_m, "TypeDictGuardAccessor");
// NOLINTNEXTLINE(bugprone-unused-raii)
py::class_<
TypeMROGuardAccessor,
GuardAccessor,
std::unique_ptr<TypeMROGuardAccessor>>(py_m, "TypeMROGuardAccessor");
// NOLINTNEXTLINE(bugprone-unused-raii)
py::class_<
WeakRefCallGuardAccessor,
GuardAccessor,
@ -7035,46 +6913,6 @@ PyObject* torch_c_dynamo_guards_init() {
py::return_value_policy::reference)
// return by reference because GuardManager has the ownership of accessors
// and guard managers
.def(
"type_dict_manager",
[](GuardManager& self,
std::string source,
py::handle example_value,
py::handle guard_manager_enum) -> GuardManager* {
// A unique key is used to save as the accessor key.
py::str unique_key("__type_dict_accessor__");
return self.get_child_manager<TypeDictGuardAccessor>(
std::move(unique_key),
std::move(source),
example_value,
guard_manager_enum);
},
py::arg("source"),
py::arg("example_value"),
py::arg("guard_manager_enum"),
py::return_value_policy::reference)
// return by reference because GuardManager has the ownership of accessors
// and guard managers
.def(
"type_mro_manager",
[](GuardManager& self,
std::string source,
py::handle example_value,
py::handle guard_manager_enum) -> GuardManager* {
// A unique key is used to save as the accessor key.
py::str unique_key("__type_mro_accessor__");
return self.get_child_manager<TypeMROGuardAccessor>(
std::move(unique_key),
std::move(source),
example_value,
guard_manager_enum);
},
py::arg("source"),
py::arg("example_value"),
py::arg("guard_manager_enum"),
py::return_value_policy::reference)
// return by reference because GuardManager has the ownership of accessors
// and guard managers
.def(
"weakref_call_manager",
[](GuardManager& self,