[Dynamo] Fix NamedTuple hasattr bug (#124531)

Fixes #124402

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124531
Approved by: https://github.com/jansel
This commit is contained in:
Yanbo Liang
2024-04-21 04:36:22 +00:00
committed by PyTorch MergeBot
parent a6a3f2e06b
commit 0d90d4d613
2 changed files with 17 additions and 2 deletions

View File

@ -1284,6 +1284,22 @@ class FunctionTests(torch._dynamo.test_case.TestCase):
mytuple = FunctionTests.MyNamedTuple(a, b) mytuple = FunctionTests.MyNamedTuple(a, b)
return mytuple.add(), mytuple.static_method(), mytuple.class_method() return mytuple.add(), mytuple.static_method(), mytuple.class_method()
@make_test
def test_namedtuple_hasattr(a, b):
mytuple = FunctionTests.MyNamedTuple(a, b)
def isinstance_namedtuple(obj) -> bool:
return (
isinstance(obj, tuple)
and hasattr(obj, "_asdict")
and hasattr(obj, "_fields")
)
if isinstance_namedtuple(mytuple):
return a + b
else:
return a - b
@make_test @make_test
def test_is_quantized(a, b): def test_is_quantized(a, b):
if not a.is_quantized: if not a.is_quantized:

View File

@ -586,8 +586,7 @@ class NamedTupleVariable(TupleVariable):
return self.items[fields.index(name)] return self.items[fields.index(name)]
def call_hasattr(self, tx, name: str) -> "VariableTracker": def call_hasattr(self, tx, name: str) -> "VariableTracker":
fields = namedtuple_fields(self.tuple_cls) return variables.ConstantVariable.create(hasattr(self.tuple_cls, name))
return variables.ConstantVariable.create(name in fields)
class SliceVariable(BaseListVariable): class SliceVariable(BaseListVariable):