[dynamo][dicts] Support hasattr on dicts (#134590)

Fixes - https://github.com/pytorch/pytorch/issues/134577

Pull Request resolved: https://github.com/pytorch/pytorch/pull/134590
Approved by: https://github.com/Skylion007
ghstack dependencies: #134610
This commit is contained in:
Animesh Jain
2024-08-27 21:06:31 -07:00
committed by PyTorch MergeBot
parent 880e3d18a4
commit c566f2465f
2 changed files with 25 additions and 0 deletions

View File

@ -1683,6 +1683,22 @@ class FunctionTests(torch._dynamo.test_case.TestCase):
tmp = {1: "D", 10: "B", 3: "E", 0: "F"}
return x + 1, sorted(tmp), sorted(tmp, reverse=True)
def test_dict_hasattr(self):
def fn(x):
if hasattr(x, "to"):
return x.to("cpu")
if hasattr(x, "items"):
return torch.cos(x["a"])
return x
opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
x = dict(a=torch.randn(3))
self.assertEqual(fn(x), opt_fn(x))
x = torch.randn(4)
self.assertEqual(fn(x), opt_fn(x))
@make_test
def test_list_clear(a, b):
tmp = [a + 1, a + 2]

View File

@ -355,6 +355,15 @@ class ConstDictVariable(VariableTracker):
def unpack_var_sequence(self, tx):
return [x.vt for x in self.items.keys()]
def call_hasattr(self, tx, name):
# dict not allow setting arbitrary attributes. To check for hasattr, we can just check the __dict__ of the dict.
# OrderedDict though requires side effects tracking because it supports arbitrary setattr.
if self.user_cls is dict:
if name in self.user_cls.__dict__:
return ConstantVariable.create(True)
return ConstantVariable.create(False)
unimplemented(f"hasattr on {self.user_cls} is not supported")
class DefaultDictVariable(ConstDictVariable):
def __init__(self, items, user_cls, default_factory=None, **kwargs) -> None: