mirror of
https://github.com/pytorch/pytorch.git
synced 2025-11-02 14:34:54 +08:00
[dynamo][dicts] Support hasattr on dicts (#134590)
Fixes - https://github.com/pytorch/pytorch/issues/134577 Pull Request resolved: https://github.com/pytorch/pytorch/pull/134590 Approved by: https://github.com/Skylion007 ghstack dependencies: #134610
This commit is contained in:
committed by
PyTorch MergeBot
parent
880e3d18a4
commit
c566f2465f
@ -1683,6 +1683,22 @@ class FunctionTests(torch._dynamo.test_case.TestCase):
|
||||
tmp = {1: "D", 10: "B", 3: "E", 0: "F"}
|
||||
return x + 1, sorted(tmp), sorted(tmp, reverse=True)
|
||||
|
||||
def test_dict_hasattr(self):
|
||||
def fn(x):
|
||||
if hasattr(x, "to"):
|
||||
return x.to("cpu")
|
||||
if hasattr(x, "items"):
|
||||
return torch.cos(x["a"])
|
||||
return x
|
||||
|
||||
opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
|
||||
|
||||
x = dict(a=torch.randn(3))
|
||||
self.assertEqual(fn(x), opt_fn(x))
|
||||
|
||||
x = torch.randn(4)
|
||||
self.assertEqual(fn(x), opt_fn(x))
|
||||
|
||||
@make_test
|
||||
def test_list_clear(a, b):
|
||||
tmp = [a + 1, a + 2]
|
||||
|
||||
@ -355,6 +355,15 @@ class ConstDictVariable(VariableTracker):
|
||||
def unpack_var_sequence(self, tx):
|
||||
return [x.vt for x in self.items.keys()]
|
||||
|
||||
def call_hasattr(self, tx, name):
|
||||
# dict not allow setting arbitrary attributes. To check for hasattr, we can just check the __dict__ of the dict.
|
||||
# OrderedDict though requires side effects tracking because it supports arbitrary setattr.
|
||||
if self.user_cls is dict:
|
||||
if name in self.user_cls.__dict__:
|
||||
return ConstantVariable.create(True)
|
||||
return ConstantVariable.create(False)
|
||||
unimplemented(f"hasattr on {self.user_cls} is not supported")
|
||||
|
||||
|
||||
class DefaultDictVariable(ConstDictVariable):
|
||||
def __init__(self, items, user_cls, default_factory=None, **kwargs) -> None:
|
||||
|
||||
Reference in New Issue
Block a user