mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
[Dynamo] Fix NamedTuple hasattr bug (#124531)
Fixes #124402 Pull Request resolved: https://github.com/pytorch/pytorch/pull/124531 Approved by: https://github.com/jansel
This commit is contained in:
committed by
PyTorch MergeBot
parent
a6a3f2e06b
commit
0d90d4d613
@ -1284,6 +1284,22 @@ class FunctionTests(torch._dynamo.test_case.TestCase):
|
||||
mytuple = FunctionTests.MyNamedTuple(a, b)
|
||||
return mytuple.add(), mytuple.static_method(), mytuple.class_method()
|
||||
|
||||
@make_test
|
||||
def test_namedtuple_hasattr(a, b):
|
||||
mytuple = FunctionTests.MyNamedTuple(a, b)
|
||||
|
||||
def isinstance_namedtuple(obj) -> bool:
|
||||
return (
|
||||
isinstance(obj, tuple)
|
||||
and hasattr(obj, "_asdict")
|
||||
and hasattr(obj, "_fields")
|
||||
)
|
||||
|
||||
if isinstance_namedtuple(mytuple):
|
||||
return a + b
|
||||
else:
|
||||
return a - b
|
||||
|
||||
@make_test
|
||||
def test_is_quantized(a, b):
|
||||
if not a.is_quantized:
|
||||
|
@ -586,8 +586,7 @@ class NamedTupleVariable(TupleVariable):
|
||||
return self.items[fields.index(name)]
|
||||
|
||||
def call_hasattr(self, tx, name: str) -> "VariableTracker":
|
||||
fields = namedtuple_fields(self.tuple_cls)
|
||||
return variables.ConstantVariable.create(name in fields)
|
||||
return variables.ConstantVariable.create(hasattr(self.tuple_cls, name))
|
||||
|
||||
|
||||
class SliceVariable(BaseListVariable):
|
||||
|
Reference in New Issue
Block a user