mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[dynamo][guards] Make class members go through obj.__class__.__dict__ (#159534)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/159534 Approved by: https://github.com/jansel
This commit is contained in:
committed by
PyTorch MergeBot
parent
4516c59f5f
commit
64cbaa876c
@ -880,8 +880,9 @@ num_guards_executed=0)
|
||||
counter += 1
|
||||
|
||||
class Bar:
|
||||
x = 4
|
||||
y = torch.randn(4)
|
||||
def __init__(self):
|
||||
self.x = 4
|
||||
self.y = torch.randn(4)
|
||||
|
||||
bar = Bar()
|
||||
|
||||
|
@ -54,8 +54,9 @@ class RunDiffGuardTests(torch._dynamo.test_case.TestCase):
|
||||
|
||||
def test_post_recompile(self):
|
||||
class Foo:
|
||||
a = 4
|
||||
b = 5
|
||||
def __init__(self):
|
||||
self.a = 4
|
||||
self.b = 5
|
||||
|
||||
foo = Foo()
|
||||
|
||||
|
@ -139,6 +139,8 @@ class DictGuardManager(GuardManager):
|
||||
class GuardAccessor: ...
|
||||
class DictGetItemGuardAccessor(GuardAccessor): ...
|
||||
class GetGenericDictGuardAccessor(GuardAccessor): ...
|
||||
class TypeDictGuardAccessor(GuardAccessor): ...
|
||||
class TypeMROGuardAccessor(GuardAccessor): ...
|
||||
|
||||
def install_object_aliasing_guard(
|
||||
guard_managers: list[GuardManager],
|
||||
|
@ -134,6 +134,8 @@ from .source import (
|
||||
TorchFunctionModeStackSource,
|
||||
TorchSource,
|
||||
TupleIteratorGetItemSource,
|
||||
TypeDictSource,
|
||||
TypeMROSource,
|
||||
TypeSource,
|
||||
UnspecializedBuiltinNNModuleSource,
|
||||
UnspecializedNNModuleSource,
|
||||
@ -864,6 +866,9 @@ class GuardBuilder(GuardBuilderBase):
|
||||
self.guard_nn_modules = config.guard_nn_modules and justknobs_check(
|
||||
"pytorch/compiler:guard_nn_modules"
|
||||
)
|
||||
self.already_guarded_not_present_in_generic_dict: OrderedSet[
|
||||
tuple[str, str]
|
||||
] = OrderedSet()
|
||||
|
||||
def guard_on_dict_keys_and_ignore_order(self, example_value, guard):
|
||||
dict_mgr = self.get_guard_manager(guard)
|
||||
@ -1211,6 +1216,20 @@ class GuardBuilder(GuardBuilderBase):
|
||||
example_value=example_value,
|
||||
guard_manager_enum=guard_manager_enum,
|
||||
)
|
||||
elif istype(source, TypeDictSource):
|
||||
assert base_guard_manager # to make mypy happy
|
||||
out = base_guard_manager.type_dict_manager(
|
||||
source=source_name,
|
||||
example_value=example_value,
|
||||
guard_manager_enum=guard_manager_enum,
|
||||
)
|
||||
elif istype(source, TypeMROSource):
|
||||
assert base_guard_manager # to make mypy happy
|
||||
out = base_guard_manager.type_mro_manager(
|
||||
source=source_name,
|
||||
example_value=example_value,
|
||||
guard_manager_enum=guard_manager_enum,
|
||||
)
|
||||
elif istype(
|
||||
source,
|
||||
(
|
||||
@ -1656,10 +1675,12 @@ class GuardBuilder(GuardBuilderBase):
|
||||
assert attr is not None
|
||||
ref = self.arg_ref(guard)
|
||||
val = self.get(guard.name)
|
||||
assert isinstance(val, torch.nn.Module)
|
||||
|
||||
base_manager = self.get_guard_manager(guard)
|
||||
|
||||
if (ref, attr) in self.already_guarded_not_present_in_generic_dict:
|
||||
return
|
||||
|
||||
mod_dict_source = f"{guard.name}.__dict__"
|
||||
mod_generic_dict_manager = base_manager.get_generic_dict_manager(
|
||||
source=mod_dict_source,
|
||||
@ -1671,6 +1692,7 @@ class GuardBuilder(GuardBuilderBase):
|
||||
mod_generic_dict_manager.add_dict_contains_guard(
|
||||
False, attr, get_verbose_code_parts(code, guard)
|
||||
)
|
||||
self.already_guarded_not_present_in_generic_dict.add((ref, attr))
|
||||
|
||||
def TYPE_MATCH(self, guard: Guard) -> None:
|
||||
# ___check_type_id is same as `id(type(x)) == y`
|
||||
|
@ -266,6 +266,38 @@ class GenericAttrSource(ChainedSource):
|
||||
return f"object.__getattribute__({self.base.name()}, {self.member!r})"
|
||||
|
||||
|
||||
# Represents obj.__dict__ where obj is a type object
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class TypeDictSource(ChainedSource):
|
||||
def reconstruct(self, codegen: "PyCodegen") -> None:
|
||||
codegen(self.base)
|
||||
codegen.extend_output(codegen.create_load_attrs("__dict__"))
|
||||
|
||||
def guard_source(self) -> GuardSource:
|
||||
return self.base.guard_source()
|
||||
|
||||
def name(self) -> str:
|
||||
# type(ob).__dict__ can return a proxy of the dict. But in the C++
|
||||
# guard accessor, we are use type->tp_dict which is a dict. So,
|
||||
# forcefully pass a dict object to ensure that the GuardManager
|
||||
# registers that its working on a dict object.
|
||||
return f"dict({self.base.name()}.__dict__)"
|
||||
|
||||
|
||||
# Represents obj.__mro__ where object is type object
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class TypeMROSource(ChainedSource):
|
||||
def reconstruct(self, codegen: "PyCodegen") -> None:
|
||||
codegen(self.base)
|
||||
codegen.extend_output(codegen.create_load_attrs("__mro__"))
|
||||
|
||||
def guard_source(self) -> GuardSource:
|
||||
return self.base.guard_source()
|
||||
|
||||
def name(self) -> str:
|
||||
return f"{self.base.name()}.__mro__"
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class LocalCellSource(Source):
|
||||
"""
|
||||
|
@ -42,6 +42,7 @@ from ..source import (
|
||||
AttrSource,
|
||||
GenericAttrSource,
|
||||
GetItemSource,
|
||||
TypeMROSource,
|
||||
TypeSource,
|
||||
WeakRefCallSource,
|
||||
)
|
||||
@ -134,9 +135,7 @@ class SuperVariable(VariableTracker):
|
||||
# Equivalent of something like type(L['self']).__mro__[1].attr_name
|
||||
if type_to_use_source:
|
||||
source = AttrSource(
|
||||
GetItemSource(
|
||||
AttrSource(type_to_use_source, "__mro__"), index
|
||||
),
|
||||
GetItemSource(TypeMROSource(type_to_use_source), index),
|
||||
name,
|
||||
)
|
||||
return resolved_getattr, source
|
||||
@ -247,7 +246,7 @@ class SuperVariable(VariableTracker):
|
||||
# different from type(self) with polymorphism.
|
||||
cls_source = None
|
||||
if self.objvar.source:
|
||||
cls_source = AttrSource(self.objvar.source, "__class__")
|
||||
cls_source = TypeSource(self.objvar.source)
|
||||
cls_variable = VariableTracker.build(
|
||||
tx, self.objvar.value_type, cls_source
|
||||
)
|
||||
|
@ -989,7 +989,7 @@ class UnspecializedNNModuleVariable(UserDefinedObjectVariable):
|
||||
fn = self.value_type.forward
|
||||
|
||||
if self.source:
|
||||
source = AttrSource(AttrSource(self.source, "__class__"), name)
|
||||
source = self.get_source_by_walking_mro(name)
|
||||
else:
|
||||
source = None
|
||||
|
||||
@ -1017,7 +1017,7 @@ class UnspecializedNNModuleVariable(UserDefinedObjectVariable):
|
||||
if name in ["_call_impl", "_wrapped_call_impl"]:
|
||||
fn = getattr(self.value_type, name)
|
||||
if self.source:
|
||||
source = AttrSource(AttrSource(self.source, "__class__"), name)
|
||||
source = self.get_source_by_walking_mro(name)
|
||||
else:
|
||||
source = None
|
||||
|
||||
@ -1032,9 +1032,7 @@ class UnspecializedNNModuleVariable(UserDefinedObjectVariable):
|
||||
method = None
|
||||
|
||||
if isinstance(method, staticmethod):
|
||||
source = AttrSource(
|
||||
AttrSource(AttrSource(self.source, "__class__"), name), "__func__"
|
||||
)
|
||||
source = AttrSource(self.get_source_by_walking_mro(name), "__func__")
|
||||
return tx.inline_user_function_return(
|
||||
variables.UserFunctionVariable(method.__func__, source=source),
|
||||
args,
|
||||
|
@ -60,8 +60,11 @@ from ..source import (
|
||||
AttrSource,
|
||||
CallFunctionNoArgsSource,
|
||||
DataclassFieldsSource,
|
||||
DictGetItemSource,
|
||||
GetItemSource,
|
||||
RandomValueSource,
|
||||
TypeDictSource,
|
||||
TypeMROSource,
|
||||
TypeSource,
|
||||
UnspecializedParamBufferSource,
|
||||
)
|
||||
@ -135,6 +138,14 @@ def is_forbidden_context_manager(ctx):
|
||||
return ctx in f_ctxs
|
||||
|
||||
|
||||
def is_cython_function(obj):
|
||||
return (
|
||||
callable(obj)
|
||||
and hasattr(type(obj), "__name__")
|
||||
and type(obj).__name__ == "cython_function_or_method"
|
||||
)
|
||||
|
||||
|
||||
class UserDefinedVariable(VariableTracker):
|
||||
value: object
|
||||
|
||||
@ -998,11 +1009,9 @@ class UserDefinedObjectVariable(UserDefinedVariable):
|
||||
|
||||
# check for methods implemented in C++
|
||||
if isinstance(method, types.FunctionType):
|
||||
source = (
|
||||
None
|
||||
if self.source is None
|
||||
else AttrSource(AttrSource(self.source, "__class__"), name)
|
||||
)
|
||||
source = None
|
||||
if self.source:
|
||||
source = self.get_source_by_walking_mro(name)
|
||||
# TODO(jansel): add a guard to check for monkey patching?
|
||||
from ..mutation_guard import unpatched_nn_module_init
|
||||
|
||||
@ -1224,12 +1233,40 @@ class UserDefinedObjectVariable(UserDefinedVariable):
|
||||
|
||||
for idx, klass in enumerate(type(self.value).__mro__):
|
||||
if name in klass.__dict__:
|
||||
mro_source = AttrSource(self.cls_source, "__mro__")
|
||||
klass_source = GetItemSource(mro_source, idx)
|
||||
dict_source = AttrSource(klass_source, "__dict__")
|
||||
# TODO(anijain2305) - This is a mapping proxy object. Ideally we
|
||||
# should use DictGetItemSource here.
|
||||
return GetItemSource(dict_source, name)
|
||||
if idx != 0:
|
||||
mro_source = TypeMROSource(self.cls_source)
|
||||
klass_source = GetItemSource(mro_source, idx)
|
||||
else:
|
||||
klass_source = self.cls_source
|
||||
dict_source = TypeDictSource(klass_source)
|
||||
out_source = DictGetItemSource(dict_source, name)
|
||||
|
||||
for absent_idx in range(1, idx):
|
||||
# Insert a guard that the name is not present in the mro hierarchy
|
||||
mro_source = TypeMROSource(self.cls_source)
|
||||
klass_source = GetItemSource(mro_source, absent_idx)
|
||||
dict_source = TypeDictSource(klass_source)
|
||||
install_guard(
|
||||
dict_source.make_guard(
|
||||
functools.partial(
|
||||
GuardBuilder.DICT_CONTAINS, key=name, invert=True
|
||||
)
|
||||
)
|
||||
)
|
||||
# Insert a guard that the name is not present in the object __dict__
|
||||
if (
|
||||
self.source
|
||||
and hasattr(self.value, "__dict__")
|
||||
and name not in self.value.__dict__
|
||||
):
|
||||
install_guard(
|
||||
self.source.make_guard(
|
||||
functools.partial(
|
||||
GuardBuilder.NOT_PRESENT_IN_GENERIC_DICT, attr=name
|
||||
)
|
||||
)
|
||||
)
|
||||
return out_source
|
||||
|
||||
unimplemented_v2(
|
||||
gb_type="could not find name in object's mro",
|
||||
@ -1339,10 +1376,17 @@ class UserDefinedObjectVariable(UserDefinedVariable):
|
||||
if subobj is torch.nn.Module.__init__:
|
||||
subobj = unpatched_nn_module_init
|
||||
|
||||
subobj_from_class = inspect.getattr_static(
|
||||
self.value.__class__, name, NO_SUCH_SUBOBJ
|
||||
)
|
||||
is_accessible_from_type_mro = (
|
||||
subobj_from_class is subobj and self.cls_source is not None
|
||||
)
|
||||
|
||||
if isinstance(subobj, property):
|
||||
if self.source:
|
||||
# Read the class attribute to reach the property
|
||||
source = AttrSource(AttrSource(self.source, "__class__"), name)
|
||||
source = self.get_source_by_walking_mro(name)
|
||||
# Get the getter function
|
||||
source = AttrSource(source, "fget")
|
||||
return variables.UserMethodVariable(
|
||||
@ -1360,6 +1404,11 @@ class UserDefinedObjectVariable(UserDefinedVariable):
|
||||
# Safe because `staticmethod.__get__` basically won't trigger user
|
||||
# code and just returns the underlying `__func__`:
|
||||
# https://github.com/python/cpython/blob/3.11/Objects/funcobject.c#L1088-L1100
|
||||
if is_accessible_from_type_mro:
|
||||
# Accessing from __dict__ does not resolve the descriptor, it
|
||||
# returns a staticmethod object, so access the __func__
|
||||
# attribute to get to the actual function.
|
||||
source = AttrSource(self.get_source_by_walking_mro(name), "__func__")
|
||||
func = subobj.__get__(self.value)
|
||||
return VariableTracker.build(tx, func, source)
|
||||
elif isinstance(subobj, classmethod):
|
||||
@ -1485,10 +1534,17 @@ class UserDefinedObjectVariable(UserDefinedVariable):
|
||||
source = self._wrap_source(source)
|
||||
|
||||
if subobj is not NO_SUCH_SUBOBJ:
|
||||
if is_wrapper_or_member_descriptor(subobj):
|
||||
if (
|
||||
is_wrapper_or_member_descriptor(subobj)
|
||||
or torch._C._dynamo.utils.is_instancemethod(subobj)
|
||||
or is_cython_function(subobj)
|
||||
):
|
||||
options = {"source": source}
|
||||
return variables.GetAttrVariable(self, name, **options)
|
||||
if source:
|
||||
if is_accessible_from_type_mro:
|
||||
source = self.get_source_by_walking_mro(name)
|
||||
|
||||
return variables.LazyVariableTracker.create(subobj, source)
|
||||
else:
|
||||
# Check if the subobj is accessible from the class itself. If the class source is known, we can create a
|
||||
|
@ -5511,6 +5511,118 @@ class TypeGuardAccessor : public GuardAccessor {
|
||||
void clone_visitor(TypeGuardAccessor* to) {}
|
||||
};
|
||||
|
||||
/**
|
||||
* Represent x.__dict__ accessor, where x is type object.
|
||||
*/
|
||||
class TypeDictGuardAccessor : public GuardAccessor {
|
||||
public:
|
||||
// name = __type_dict_accessor__, a unique string used as attribute name.
|
||||
TypeDictGuardAccessor(
|
||||
RootGuardManager* root,
|
||||
py::str name,
|
||||
std::string source,
|
||||
py::handle example_value,
|
||||
py::handle guard_manager_enum)
|
||||
: GuardAccessor(
|
||||
root,
|
||||
std::move(name),
|
||||
std::move(source),
|
||||
example_value,
|
||||
guard_manager_enum) {}
|
||||
|
||||
// NB: Intentional duplication between check_nopybind and
|
||||
// check_verbose_nopybind.
|
||||
bool check_nopybind(PyObject* obj, bool matches_dict_tag = false)
|
||||
override { // borrowed ref
|
||||
PyObject* x = ((PyTypeObject*)obj)->tp_dict; // borrowed ref
|
||||
if (x == nullptr) {
|
||||
return false;
|
||||
}
|
||||
return _guard_manager->check_nopybind(x);
|
||||
}
|
||||
|
||||
GuardDebugInfo check_verbose_nopybind(
|
||||
PyObject* obj) override { // borrowed ref
|
||||
PyObject* x = ((PyTypeObject*)obj)->tp_dict; // borrowed ref
|
||||
if (x == nullptr) {
|
||||
return GuardDebugInfo(false, "null type dict on " + repr(), 0);
|
||||
}
|
||||
return _guard_manager->check_verbose_nopybind(x);
|
||||
}
|
||||
|
||||
std::string repr() const override {
|
||||
return "TypeDictGuardAccessor";
|
||||
}
|
||||
|
||||
public: // cloning functions
|
||||
TypeDictGuardAccessor(
|
||||
GuardManager* guard_manager,
|
||||
TypeDictGuardAccessor* from)
|
||||
: GuardAccessor(guard_manager, from) {
|
||||
from->clone_visitor(this);
|
||||
}
|
||||
|
||||
GuardAccessor* clone(
|
||||
RootGuardManager* cloned_root,
|
||||
const py::function& clone_filter_fn) override {
|
||||
return clone_common<TypeDictGuardAccessor>(cloned_root, clone_filter_fn);
|
||||
}
|
||||
|
||||
void clone_visitor(TypeDictGuardAccessor* to) {}
|
||||
};
|
||||
|
||||
/**
|
||||
* Represent x.__mro__ accessor, where x is type object.
|
||||
*/
|
||||
class TypeMROGuardAccessor : public GuardAccessor {
|
||||
public:
|
||||
// name = __type_mro_accessor__, a unique string used as attribute name.
|
||||
TypeMROGuardAccessor(
|
||||
RootGuardManager* root,
|
||||
py::str name,
|
||||
std::string source,
|
||||
py::handle example_value,
|
||||
py::handle guard_manager_enum)
|
||||
: GuardAccessor(
|
||||
root,
|
||||
std::move(name),
|
||||
std::move(source),
|
||||
example_value,
|
||||
guard_manager_enum) {}
|
||||
|
||||
// NB: Intentional duplication between check_nopybind and
|
||||
// check_verbose_nopybind.
|
||||
bool check_nopybind(PyObject* obj, bool matches_dict_tag = false)
|
||||
override { // borrowed ref
|
||||
PyObject* x = ((PyTypeObject*)obj)->tp_mro; // borrowed ref
|
||||
return _guard_manager->check_nopybind(x);
|
||||
}
|
||||
|
||||
GuardDebugInfo check_verbose_nopybind(
|
||||
PyObject* obj) override { // borrowed ref
|
||||
PyObject* x = ((PyTypeObject*)obj)->tp_mro; // borrowed ref
|
||||
return _guard_manager->check_verbose_nopybind(x);
|
||||
}
|
||||
|
||||
std::string repr() const override {
|
||||
return "TypeMROGuardAccessor";
|
||||
}
|
||||
|
||||
public: // cloning functions
|
||||
TypeMROGuardAccessor(GuardManager* guard_manager, TypeMROGuardAccessor* from)
|
||||
: GuardAccessor(guard_manager, from) {
|
||||
from->clone_visitor(this);
|
||||
}
|
||||
|
||||
GuardAccessor* clone(
|
||||
RootGuardManager* cloned_root,
|
||||
const py::function& clone_filter_fn) override {
|
||||
return clone_common<TypeMROGuardAccessor>(cloned_root, clone_filter_fn);
|
||||
}
|
||||
|
||||
void clone_visitor(TypeMROGuardAccessor* to) {}
|
||||
};
|
||||
|
||||
/**
|
||||
* Getitem tuple_iterator accessor.
|
||||
*/
|
||||
@ -6585,6 +6697,16 @@ PyObject* torch_c_dynamo_guards_init() {
|
||||
GuardAccessor,
|
||||
std::unique_ptr<TypeGuardAccessor>>(py_m, "TypeGuardAccessor");
|
||||
// NOLINTNEXTLINE(bugprone-unused-raii)
|
||||
py::class_<
|
||||
TypeDictGuardAccessor,
|
||||
GuardAccessor,
|
||||
std::unique_ptr<TypeDictGuardAccessor>>(py_m, "TypeDictGuardAccessor");
|
||||
// NOLINTNEXTLINE(bugprone-unused-raii)
|
||||
py::class_<
|
||||
TypeMROGuardAccessor,
|
||||
GuardAccessor,
|
||||
std::unique_ptr<TypeMROGuardAccessor>>(py_m, "TypeMROGuardAccessor");
|
||||
// NOLINTNEXTLINE(bugprone-unused-raii)
|
||||
py::class_<
|
||||
WeakRefCallGuardAccessor,
|
||||
GuardAccessor,
|
||||
@ -7075,6 +7197,46 @@ PyObject* torch_c_dynamo_guards_init() {
|
||||
py::return_value_policy::reference)
|
||||
// return by reference because GuardManager has the ownership of accessors
|
||||
// and guard managers
|
||||
.def(
|
||||
"type_dict_manager",
|
||||
[](GuardManager& self,
|
||||
std::string source,
|
||||
py::handle example_value,
|
||||
py::handle guard_manager_enum) -> GuardManager* {
|
||||
// A unique key is used to save as the accessor key.
|
||||
py::str unique_key("__type_dict_accessor__");
|
||||
return self.get_child_manager<TypeDictGuardAccessor>(
|
||||
std::move(unique_key),
|
||||
std::move(source),
|
||||
example_value,
|
||||
guard_manager_enum);
|
||||
},
|
||||
py::arg("source"),
|
||||
py::arg("example_value"),
|
||||
py::arg("guard_manager_enum"),
|
||||
py::return_value_policy::reference)
|
||||
// return by reference because GuardManager has the ownership of accessors
|
||||
// and guard managers
|
||||
.def(
|
||||
"type_mro_manager",
|
||||
[](GuardManager& self,
|
||||
std::string source,
|
||||
py::handle example_value,
|
||||
py::handle guard_manager_enum) -> GuardManager* {
|
||||
// A unique key is used to save as the accessor key.
|
||||
py::str unique_key("__type_mro_accessor__");
|
||||
return self.get_child_manager<TypeMROGuardAccessor>(
|
||||
std::move(unique_key),
|
||||
std::move(source),
|
||||
example_value,
|
||||
guard_manager_enum);
|
||||
},
|
||||
py::arg("source"),
|
||||
py::arg("example_value"),
|
||||
py::arg("guard_manager_enum"),
|
||||
py::return_value_policy::reference)
|
||||
// return by reference because GuardManager has the ownership of accessors
|
||||
// and guard managers
|
||||
.def(
|
||||
"weakref_call_manager",
|
||||
[](GuardManager& self,
|
||||
|
Reference in New Issue
Block a user