mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Revert "[dynamo] Be consistent with storing func source for UserMethodVariable (#159696)"
This reverts commit ee62177c196d716fc3a2d641370bed8a673a45d3. Reverted https://github.com/pytorch/pytorch/pull/159696 on behalf of https://github.com/anijain2305 due to broke internal tests ([comment](https://github.com/pytorch/pytorch/pull/159696#issuecomment-3161196192))
This commit is contained in:
@ -42,7 +42,6 @@ from .variables.base import ValueMutationExisting, VariableTracker
|
||||
from .variables.functions import (
|
||||
ContextlibContextManagerLocalGeneratorObjectVariable,
|
||||
LocalGeneratorObjectVariable,
|
||||
UserMethodVariable,
|
||||
)
|
||||
from .variables.nn_module import NNModuleVariable
|
||||
from .variables.tensor import (
|
||||
@ -251,10 +250,7 @@ class PyCodegen:
|
||||
value.source is not None
|
||||
and allow_cache
|
||||
and not (
|
||||
value.is_realized()
|
||||
and isinstance(
|
||||
value, (LocalGeneratorObjectVariable, UserMethodVariable)
|
||||
)
|
||||
value.is_realized() and isinstance(value, LocalGeneratorObjectVariable)
|
||||
)
|
||||
):
|
||||
# There's a corner case for export: for instance, if the computation
|
||||
|
@ -1122,26 +1122,13 @@ class UserMethodVariable(UserFunctionVariable):
|
||||
return super().inspect_parameter_names()[1:]
|
||||
|
||||
def var_getattr(self, tx: "InstructionTranslator", name: str):
|
||||
if name == "__func__":
|
||||
# self.source points to the source of the function object and not
|
||||
# the method object
|
||||
return VariableTracker.build(tx, self.fn, self.source)
|
||||
source = self.source and AttrSource(self.source, name)
|
||||
if name == "__self__":
|
||||
return self.obj
|
||||
if name == "__func__":
|
||||
return VariableTracker.build(tx, self.fn, source)
|
||||
return super().var_getattr(tx, name)
|
||||
|
||||
def reconstruct(self, codegen):
|
||||
if not self.obj.source or not self.source:
|
||||
raise NotImplementedError
|
||||
|
||||
def get_bound_method():
|
||||
codegen(self.source)
|
||||
codegen.extend_output(codegen.create_load_attrs("__get__"))
|
||||
|
||||
codegen.add_push_null(get_bound_method)
|
||||
codegen(self.obj.source)
|
||||
codegen.extend_output(create_call_function(1, False))
|
||||
|
||||
|
||||
class WrappedUserMethodVariable(UserMethodVariable):
|
||||
def __init__(self, wrapped, context, **kwargs) -> None:
|
||||
|
@ -1380,9 +1380,7 @@ class UserDefinedObjectVariable(UserDefinedVariable):
|
||||
self.value.__class__, name, NO_SUCH_SUBOBJ
|
||||
)
|
||||
is_accessible_from_type_mro = (
|
||||
subobj_from_class is subobj
|
||||
and self.cls_source is not None
|
||||
and self.source is not None
|
||||
subobj_from_class is subobj and self.cls_source is not None
|
||||
)
|
||||
|
||||
if isinstance(subobj, property):
|
||||
@ -1414,11 +1412,6 @@ class UserDefinedObjectVariable(UserDefinedVariable):
|
||||
func = subobj.__get__(self.value)
|
||||
return VariableTracker.build(tx, func, source)
|
||||
elif isinstance(subobj, classmethod):
|
||||
if is_accessible_from_type_mro:
|
||||
# Accessing from __dict__ does not resolve the descriptor, it
|
||||
# returns a classmethod object, so access the __func__
|
||||
# attribute to get to the actual function.
|
||||
source = AttrSource(self.get_source_by_walking_mro(name), "__func__")
|
||||
return variables.UserMethodVariable(
|
||||
subobj.__func__, self.var_getattr(tx, "__class__"), source=source
|
||||
)
|
||||
@ -1468,9 +1461,6 @@ class UserDefinedObjectVariable(UserDefinedVariable):
|
||||
isinstance(subobj, types.MethodType)
|
||||
and isinstance(self.value, torch.nn.Module)
|
||||
):
|
||||
if is_accessible_from_type_mro:
|
||||
source = self.get_source_by_walking_mro(name)
|
||||
|
||||
# Since we get subobj via self._getattr_static, which may not trigger dynamic lookup.
|
||||
# Static lookup can't tell us it's a method or function correctly,
|
||||
# so we trigger dynamic lookup here to get the correct type.
|
||||
|
Reference in New Issue
Block a user