diff --git a/test/dynamo/test_functions.py b/test/dynamo/test_functions.py index 48acf9b16fcc..bcd299a9e8e1 100644 --- a/test/dynamo/test_functions.py +++ b/test/dynamo/test_functions.py @@ -1284,6 +1284,22 @@ class FunctionTests(torch._dynamo.test_case.TestCase): mytuple = FunctionTests.MyNamedTuple(a, b) return mytuple.add(), mytuple.static_method(), mytuple.class_method() + @make_test + def test_namedtuple_hasattr(a, b): + mytuple = FunctionTests.MyNamedTuple(a, b) + + def isinstance_namedtuple(obj) -> bool: + return ( + isinstance(obj, tuple) + and hasattr(obj, "_asdict") + and hasattr(obj, "_fields") + ) + + if isinstance_namedtuple(mytuple): + return a + b + else: + return a - b + @make_test def test_is_quantized(a, b): if not a.is_quantized: diff --git a/torch/_dynamo/variables/lists.py b/torch/_dynamo/variables/lists.py index 1f9b83a8a957..d51b4daff347 100644 --- a/torch/_dynamo/variables/lists.py +++ b/torch/_dynamo/variables/lists.py @@ -586,8 +586,7 @@ class NamedTupleVariable(TupleVariable): return self.items[fields.index(name)] def call_hasattr(self, tx, name: str) -> "VariableTracker": - fields = namedtuple_fields(self.tuple_cls) - return variables.ConstantVariable.create(name in fields) + return variables.ConstantVariable.create(hasattr(self.tuple_cls, name)) class SliceVariable(BaseListVariable):