mirror of https://github.com/pytorch/pytorch.git
synced 2025-10-21 13:44:15 +08:00
Revert "[RELAND][dynamo][nn-modules] Trace through nn.Module dunder methods for UnspecializedNNModule (#126578)"
This reverts commit b2d602306a9eb19e30328cbaee941c874f8148a9. Reverted https://github.com/pytorch/pytorch/pull/126578 on behalf of https://github.com/clee2000 due to failed internal test D58394084. Author has forward fix but includes external changes so reverting is a bit easier to coordinate ([comment](https://github.com/pytorch/pytorch/pull/126578#issuecomment-2161481839))
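For context, the reverted PR (#126578) made Dynamo trace through nn.Module dunder methods (__init__, __setattr__, __delattr__, __iter__, __getitem__) for unspecialized modules, which is why the guard expectations in the test hunk below shift between "L['self'].net[0]"-style and "L['self']._modules['net']"-style sources. A minimal, hypothetical sketch of the kind of user code that exercises those dunders under torch.compile; class and variable names are illustrative and not taken from the PR's tests:

    import torch
    import torch.nn as nn

    class ToyModule(nn.Module):
        def __init__(self):
            # The reverted PR also traced through nn.Module.__init__ and
            # __setattr__ instead of specializing on the module instance.
            super().__init__()
            self.net = nn.Sequential(nn.Linear(4, 4), nn.ReLU())

        def forward(self, x):
            # Iterating and indexing self.net goes through the container
            # dunders (__iter__ / __getitem__) whose guard sources change
            # in the guard-check test below.
            for layer in self.net:
                x = layer(x)
            return x + self.net[0](x)

    mod = ToyModule()
    compiled = torch.compile(mod, backend="eager")
    out = compiled(torch.randn(2, 4))

After this revert, guard sources for such accesses go back to the "net[0]"-style form shown restored in the test expectations below.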
@@ -1084,14 +1084,12 @@ class TestSingleProc(DynamoDistributedSingleProcTestCase):
# far from an exhaustive check of all the expected guards, just check a couple of them.
FileCheck().check("""local "L['self']" TYPE_MATCH""").check(
"""local "L['self']" ID_MATCH"""
).check(f"""{expected_guard_source} "L['self'].net" TYPE_MATCH""").check(
f"""{expected_guard_source} "L['self'].net" ID_MATCH"""
).check(
f"""{expected_guard_source} "L['self']._modules['net']" TYPE_MATCH"""
f"""{expected_guard_source} "L['self'].net[0]" TYPE_MATCH"""
).check(
f"""{expected_guard_source} "L['self']._modules['net']" ID_MATCH"""
).check(
f"""{expected_guard_source} "L['self']._modules['net']._modules['0']" TYPE_MATCH"""
).check(
f"""{expected_guard_source} "L['self']._modules['net']._modules['1']" ID_MATCH"""
f"""{expected_guard_source} "L['self'].net[0]" ID_MATCH"""
).run(
GUARDS_FILE.getvalue()
)
@@ -5187,10 +5187,10 @@ class GraphModule(torch.nn.Module):
actual,
"""\
class GraphModule(torch.nn.Module):
def forward(self, L_self_buffers_tensor_constant0_: "f32[3, 3, 3]"):
l_self_buffers_tensor_constant0_ = L_self_buffers_tensor_constant0_
def forward(self, L_self_tensor_constant0: "f32[3, 3, 3]"):
l_self_tensor_constant0 = L_self_tensor_constant0

alias_default: "f32[3, 3, 3]" = torch.ops.aten.alias.default(l_self_buffers_tensor_constant0_); l_self_buffers_tensor_constant0_ = None
alias_default: "f32[3, 3, 3]" = torch.ops.aten.alias.default(l_self_tensor_constant0); l_self_tensor_constant0 = None

sin_default: "f32[3, 3, 3]" = torch.ops.aten.sin.default(alias_default)
@@ -5209,16 +5209,16 @@ class GraphModule(torch.nn.Module):
actual,
"""\
class GraphModule(torch.nn.Module):
def forward(self, L_self_modules_FX_CONST_FOLDED_ATTRS_parameters_0_: "f32[3, 3, 3]", L_self_modules_FX_CONST_FOLDED_ATTRS_parameters_1_: "f32[3, 3, 3]", L_flat_tangents_1_: "f32[3, 3, 3]"):
l_self_modules_fx_const_folded_attrs_parameters_0_ = L_self_modules_FX_CONST_FOLDED_ATTRS_parameters_0_
l_self_modules_fx_const_folded_attrs_parameters_1_ = L_self_modules_FX_CONST_FOLDED_ATTRS_parameters_1_
def forward(self, getattr_L_self_FX_CONST_FOLDED_ATTRS_0_: "f32[3, 3, 3]", getattr_L_self_FX_CONST_FOLDED_ATTRS_1_: "f32[3, 3, 3]", L_flat_tangents_1_: "f32[3, 3, 3]"):
getattr_l_self_fx_const_folded_attrs_0_ = getattr_L_self_FX_CONST_FOLDED_ATTRS_0_
getattr_l_self_fx_const_folded_attrs_1_ = getattr_L_self_FX_CONST_FOLDED_ATTRS_1_
l_flat_tangents_1_ = L_flat_tangents_1_

_new_zeros_with_same_feature_meta_default: "f32[3, 3, 3]" = torch.ops.aten._new_zeros_with_same_feature_meta.default(l_flat_tangents_1_, l_self_modules_fx_const_folded_attrs_parameters_0_); l_self_modules_fx_const_folded_attrs_parameters_0_ = None
_new_zeros_with_same_feature_meta_default: "f32[3, 3, 3]" = torch.ops.aten._new_zeros_with_same_feature_meta.default(l_flat_tangents_1_, getattr_l_self_fx_const_folded_attrs_0_); getattr_l_self_fx_const_folded_attrs_0_ = None

copy__default: "f32[3, 3, 3]" = torch.ops.aten.copy_.default(_new_zeros_with_same_feature_meta_default, l_flat_tangents_1_); _new_zeros_with_same_feature_meta_default = l_flat_tangents_1_ = None

mul_tensor: "f32[3, 3, 3]" = torch.ops.aten.mul.Tensor(copy__default, l_self_modules_fx_const_folded_attrs_parameters_1_); copy__default = l_self_modules_fx_const_folded_attrs_parameters_1_ = None
mul_tensor: "f32[3, 3, 3]" = torch.ops.aten.mul.Tensor(copy__default, getattr_l_self_fx_const_folded_attrs_1_); copy__default = getattr_l_self_fx_const_folded_attrs_1_ = None
return (mul_tensor,)
""",
)
@@ -2411,7 +2411,6 @@ aten::mm""",
num_matched.append(len(pattern.matched_events()))
self.assertEqual(num_matched, [i for i, _ in cases])

@skipIfTorchDynamo("profiler gets ignored if dynamo activated")
def test_profiler_pattern_matcher_json_report(self):
x = torch.ones((100, 100))
model = nn.Sequential(
@@ -1,7 +1,4 @@
# mypy: allow-untyped-defs
import threading
from contextlib import contextmanager

import torch

doc = """
@@ -40,20 +37,3 @@ def new_parameter_placeholder(size, dtype, device, requires_grad):
# Allocating a zero tensor would causes assert failures in autograd.
result.untyped_storage().resize_(0)
return result


_TLS = threading.local()


@contextmanager
def do_not_convert_to_tracable_parameter():
old_flag = getattr(_TLS, "convert_tracable_parameter", True)
_TLS.convert_tracable_parameter = False
try:
yield False
finally:
_TLS.convert_tracable_parameter = old_flag


def can_convert_to_tracable_parameter():
return getattr(_TLS, "convert_tracable_parameter", True)
@@ -11,9 +11,6 @@ from . import config
from .utils import ExactWeakKeyDictionary, is_lazy_module, nn_module_has_global_hooks


unpatched_nn_module_init = torch.nn.Module.__init__


class MutationTracker:
db = ExactWeakKeyDictionary()
@@ -347,7 +347,13 @@ class SideEffects:
elif isinstance(var.mutable_local, AttributeMutationNew):
if isinstance(var, variables.AutogradFunctionContextVariable):
unimplemented("AutogradFunctionContextVariable escaped")
cg.load_import_from(utils.__name__, "object_new")
if "__call_nn_module_init" in self.store_attr_mutations.get(
var.mutable_local, {}
):
assert isinstance(var, variables.UnspecializedNNModuleVariable)
cg.load_import_from(utils.__name__, "nn_module_new")
else:
cg.load_import_from(utils.__name__, "object_new")
cg(var.mutable_local.cls_source)
cg.extend_output(create_call_function(1, True))
cg.add_cache(var)
@@ -474,25 +480,9 @@ class SideEffects:
]
)
elif self.is_attribute_mutation(var):
# Applying mutations involves two steps: 1) Push all
# reconstructed objects onto the stack. 2) Call STORE_ATTR to
# apply the mutations.
#
# Dynamo must ensure that mutations are applied in the same
# order as in the original program. Therefore, two reverse
# operations occur below.
#
# The first reverse operation concerns `suffixes`. We apply
# suffixes in reverse order due to the way Python handles the
# stack. In Step 1, we push all reconstructed objects onto the
# stack, but the item at the top of the stack refers to the last
# attribute in the mutation order. If not fixed, this will apply
# the mutations of attributes in the reverse order. To account
# for this reversal, we iterate through the mutable attributes
# in reverse order.
for name, value in reversed(
self.store_attr_mutations.get(var.mutable_local, {}).items()
):
for name, value in self.store_attr_mutations.get(
var.mutable_local, {}
).items():
if isinstance(var, variables.NewGlobalVariable):
cg.tx.output.update_co_names(name)
cg(value)
@@ -500,6 +490,8 @@ class SideEffects:
suffixes.append(
[create_instruction("STORE_GLOBAL", argval=name)]
)
elif name == "__call_nn_module_init":
pass  # handled in codegen_save_tempvars
elif isinstance(value, variables.DeletedVariable):
if isinstance(
var.mutable_local, AttributeMutationExisting
@@ -416,15 +416,10 @@ def generic_jump(truth_fn: typing.Callable[[object], bool], push: bool):
self.push(value)
self.jump(inst)
elif isinstance(value, UserDefinedObjectVariable):
try:
x = value.var_getattr(self, "__bool__")
except exc.ObservedException:
# if __bool__ is missing, trying __len__ to infer a truth value.
x = value.var_getattr(self, "__bool__")
# if __bool__ is missing, trying __len__ to infer a truth value.
if isinstance(x, GetAttrVariable):
x = value.var_getattr(self, "__len__")
else:
if isinstance(x, GetAttrVariable):
# if __bool__ is missing, trying __len__ to infer a truth value.
x = value.var_getattr(self, "__len__")

# __bool__ or __len__ is function
if isinstance(x, UserMethodVariable):
@@ -2019,12 +2019,12 @@ def object_has_getattribute(value: Any):
return False


def get_custom_getattr(value: Any, ignore_nn_module_getattr: bool = False):
def get_custom_getattr(value: Any):
try:
getattr_fn = inspect.getattr_static(type(value), "__getattr__")
except AttributeError:
getattr_fn = None
if ignore_nn_module_getattr and getattr_fn is torch.nn.Module.__getattr__:
if getattr_fn is torch.nn.Module.__getattr__:
# ignore this case of getattr
getattr_fn = None
return getattr_fn
@@ -174,11 +174,7 @@ class ConstDictVariable(VariableTracker):
def __contains__(self, vt):
assert isinstance(vt, VariableTracker)
Hashable = ConstDictVariable._HashableTracker
return (
is_hashable(vt)
and Hashable(vt) in self.items
and not isinstance(self.items[Hashable(vt)], variables.DeletedVariable)
)
return is_hashable(vt) and Hashable(vt) in self.items

def reconstruct(self, codegen):
# instructions to load collections.OrderedDict if necessary
@@ -14,10 +14,8 @@ import torch._numpy as tnp
import torch.utils._pytree as pytree
from .. import config, variables
from ..bytecode_transformation import create_call_function, create_instruction
from ..create_parameter_op import do_not_convert_to_tracable_parameter
from ..exc import unimplemented
from ..guards import GuardBuilder, install_guard
from ..mutation_guard import unpatched_nn_module_init
from ..source import AttrSource, GetItemSource, ODictGetItemSource, TypeSource
from ..utils import (
check_unspec_or_constant_args,
@@ -123,6 +121,7 @@ class SuperVariable(VariableTracker):
kwargs: "Dict[str, VariableTracker]",
) -> "VariableTracker":
inner_fn, source = self._resolved_getattr_and_source(self, name)

if inner_fn is object.__init__:
return LambdaVariable(identity)
elif inner_fn is torch.nn.Module.__init__:
@@ -134,10 +133,12 @@ class SuperVariable(VariableTracker):
and isinstance(objvar.mutable_local, AttributeMutationNew)
and not (args or kwargs)
):
with do_not_convert_to_tracable_parameter():
return variables.UserFunctionVariable(
unpatched_nn_module_init, source=source
).call_function(tx, [self.objvar] + args, kwargs)
tx.output.side_effects.store_attr(
objvar,
"__call_nn_module_init",
variables.ConstantVariable.create(True),
)
return variables.ConstantVariable.create(None)
else:
unimplemented("super() nn.Module.__init__")
elif isinstance(inner_fn, types.FunctionType):
@@ -174,19 +175,6 @@ class SuperVariable(VariableTracker):
self.objvar, UserDefinedObjectVariable
):
return self.objvar.method_setattr_standard(tx, *args, **kwargs)
elif inner_fn is object.__delattr__:
attr = args[0]
try:
attr = attr.as_python_constant()
except NotImplementedError:
unimplemented(f"non-const delattr attr: {attr}")
if not tx.output.side_effects.is_attribute_mutation(self.objvar):
unimplemented(f"delattr({self.objvar}, {attr}, ...)")

tx.output.side_effects.store_attr(
self.objvar, attr, variables.DeletedVariable()
)
return variables.ConstantVariable(None)

unimplemented(f"non-function or method super: {inner_fn}")
@@ -215,7 +215,7 @@ class NNModuleVariable(VariableTracker):
if object_has_getattribute(base):
unimplemented("torch.nn.Module with a custom __getattribute__ defined")

getattr_fn = get_custom_getattr(base, ignore_nn_module_getattr=True)
getattr_fn = get_custom_getattr(base)
if getattr_fn is None:
return None
@@ -665,6 +665,7 @@ class NNModuleVariable(VariableTracker):
if isinstance(args[0], SliceVariable):
# Build a TupleVariable of NNModules
result = []
submods = []

# Turn the slice into the list of integers
keys = list(range(len(module)))[args[0].as_python_constant()]
@@ -678,8 +679,9 @@ class NNModuleVariable(VariableTracker):
source=src,
)
)
submods.append(submod)

new_module = module[args[0].as_python_constant()]
new_module = torch.nn.Sequential(*submods)
new_module_variable = tx.output.register_attr_or_module(
new_module,
f"{self}.__getitem__(slice)",
@@ -693,10 +695,8 @@ class NNModuleVariable(VariableTracker):
if isinstance(args[0], SymNodeVariable):
key = args[0].evaluate_expr(tx.output)
elif args[0].is_python_constant():
key = args[0].as_python_constant()
else:
unimplemented(f"getitem on NNModuleVariable with key {args[0]}")
key = args[0].as_python_constant()

submod = module[key]
return tx.output.register_attr_or_module(
@@ -790,7 +790,7 @@ class UnspecializedNNModuleVariable(UserDefinedObjectVariable):
@functools.lru_cache(None)
def _nn_module_method_ids():
# Allow __setattr__ to fall through to base class handler
supported = {torch.nn.Module.__setattr__, torch.nn.Module.__init__}
supported = {torch.nn.Module.__setattr__}
return {
id(x.__code__)
for x in torch.nn.Module.__dict__.values()
@@ -798,6 +798,8 @@ class UnspecializedNNModuleVariable(UserDefinedObjectVariable):
}

def unpack_var_sequence(self, tx):
from .builder import VariableBuilder

try:
fn = inspect.getattr_static(self.value_type, "__iter__")
except AttributeError as e:
@@ -808,16 +810,11 @@ class UnspecializedNNModuleVariable(UserDefinedObjectVariable):
torch.nn.ParameterList.__iter__,
torch.nn.Sequential.__iter__,
):
# The program can mutate the nn module object but the saved `value`
# will not reflect the mutations. So, trace through the `__iter__`
# function to reflect any tracked mutations.
return tx.inline_user_function_return(
variables.UserFunctionVariable(fn),
[
self,
],
{},
).unpack_var_sequence(tx)
assert self.source
return [
VariableBuilder(tx, source=GetItemSource(self.source, idx))(item)
for idx, item in enumerate(self.value)
]

return super().unpack_var_sequence(tx)
@@ -946,17 +943,6 @@ class UnspecializedNNModuleVariable(UserDefinedObjectVariable):
# Handle submodules
self.is_state_mutated = True

if method is torch.nn.Module.__setattr__ and isinstance(
args[1], variables.DeletedVariable
):
# Trace through __delattr__ to track mutations on the module
# members like `_modules``.
return tx.inline_user_function_return(
variables.UserFunctionVariable(torch.nn.Module.__delattr__),
[self, args[0]],
kwargs,
)

return super().call_method(tx, name, args, kwargs)
@@ -18,11 +18,7 @@ from torch._streambase import _StreamBase
from ..._guards import TracingContext
from .. import config, polyfill, variables
from ..codegen import PyCodegen
from ..create_parameter_op import (
can_convert_to_tracable_parameter,
new_parameter_placeholder,
tracable_create_parameter,
)
from ..create_parameter_op import new_parameter_placeholder, tracable_create_parameter
from ..device_interface import get_registered_device_interfaces
from ..exc import unimplemented
from ..guards import GuardBuilder, install_guard
@@ -875,9 +871,6 @@ Either create the tensor outside the compiled region, or do not set the tensor t
if data.source:
return cls._nn_param_via_prefix_insert(tx, data, requires_grad)

if not can_convert_to_tracable_parameter():
unimplemented("Workaround for issues with nn_parameter construction")

try:
shape = tuple(data.var_getattr(tx, "shape").as_python_constant())
dtype = data.var_getattr(tx, "dtype").as_python_constant()
@@ -34,8 +34,7 @@ import torch.nn
from torch._guards import TracingContext

from .. import variables
from ..create_parameter_op import do_not_convert_to_tracable_parameter
from ..exc import ObservedException, unimplemented
from ..exc import unimplemented
from ..guards import GuardBuilder, install_guard
from ..source import AttrSource, GetItemSource, ODictGetItemSource, RandomValueSource
from ..utils import (
@@ -58,7 +57,10 @@ from .dicts import DefaultDictVariable
def is_standard_setattr(val):
return val in (object.__setattr__,)
return val in (
object.__setattr__,
torch.nn.Module.__setattr__,
)


class UserDefinedVariable(VariableTracker):
@@ -376,7 +378,17 @@ class UserDefinedClassVariable(UserDefinedVariable):
else UserDefinedObjectVariable,
{},
)
with do_not_convert_to_tracable_parameter():
if (
inspect.getattr_static(self.value, "__init__", None)
is torch.nn.Module.__init__
):
tx.output.side_effects.store_attr(
var,
"__call_nn_module_init",
variables.ConstantVariable.create(True),
)
return var
else:
var.call_method(tx, "__init__", args, kwargs)
return var
elif variables.CustomizedDictVariable.is_matching_cls(self.value):
@@ -626,10 +638,6 @@ class UserDefinedObjectVariable(UserDefinedVariable):
else AttrSource(AttrSource(self.source, "__class__"), name)
)
# TODO(jansel): add a guard to check for monkey patching?
from ..mutation_guard import unpatched_nn_module_init

if method is torch.nn.Module.__init__:
method = unpatched_nn_module_init
return UserMethodVariable(method, self, source=source).call_function(
tx, args, kwargs
)
@@ -791,7 +799,7 @@ class UserDefinedObjectVariable(UserDefinedVariable):
def _getattr_static(self, name):
if (
isinstance(self.value, PyTreeSpec)
isinstance(self.value, (torch.nn.Module, PyTreeSpec))
or "__slots__" in self.value.__class__.__dict__
or type(self.value) == threading.local
):
@@ -804,6 +812,7 @@ class UserDefinedObjectVariable(UserDefinedVariable):
return cls_var
except AttributeError:
pass  # __slots__
# this might call torch.nn.Module.__getattr__
subobj = getattr(self.value, name)
else:
subobj = inspect.getattr_static(self.value, name)
@@ -1009,35 +1018,14 @@ class UserDefinedObjectVariable(UserDefinedVariable):
install_guard(
AttrSource(self.source, name).make_guard(GuardBuilder.HASATTR)
)
if self._check_for_getattribute():
unimplemented("hasattr with custom __getattribute__")
if self._check_for_getattribute() or self._check_for_getattr():
unimplemented("hasattr with custom __getattr__")

try:
self._getattr_static(name)
return variables.ConstantVariable.create(True)
except AttributeError:
# Now check in __getattr__ function
getattr_fn = self._check_for_getattr()
if isinstance(getattr_fn, types.FunctionType):
# Dynamo is going to trace the __getattr__ function with
# args=name. Set the source accordingly.
new_source = None
if self.source:
new_source = AttrSource(self.source, "__getattr__")
try:
result = variables.UserMethodVariable(
getattr_fn, self, source=new_source
).call_function(tx, [variables.ConstantVariable.create(name)], {})

return variables.ConstantVariable.create(
not isinstance(result, variables.DeletedVariable)
)
except ObservedException:
return variables.ConstantVariable.create(False)
elif getattr_fn is None:
return variables.ConstantVariable.create(False)
else:
unimplemented("UserDefined with non-function __getattr__")
return variables.ConstantVariable.create(False)

def odict_getitem(self, tx, key):
from .builder import VariableBuilder
@@ -1104,12 +1092,6 @@ class KeyedJaggedTensorVariable(UserDefinedObjectVariable):
return super().var_getattr(tx, name)


class RemovableHandleClass:
# Dummy class to pass to python_type of RemovableHandleVariable
# Useful for isinstance check on hooks
pass


class RemovableHandleVariable(VariableTracker):
REMOVED = -1
@@ -1140,6 +1122,3 @@ class RemovableHandleVariable(VariableTracker):
return
# unreachable due to codegen.add_cache() when the hook is installed
super().reconstruct(codegen)

def python_type(self):
return RemovableHandleClass