mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
[dict] Allow Dynamo to trace through explicit dict dunder method call (#154794)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/154794 Approved by: https://github.com/mlazos ghstack dependencies: #154003, #154793
This commit is contained in:
committed by
PyTorch MergeBot
parent
57d64298a0
commit
ba8d19ec02
@ -1360,6 +1360,18 @@ class DictMethodsTests(torch._dynamo.test_case.TestCase):
|
||||
self.assertRaises(TypeError, d.values, 1)
|
||||
|
||||
|
||||
class DictSubclassMethodsTests(DictMethodsTests):
|
||||
thetype = SimpleDict
|
||||
|
||||
@unittest.expectedFailure
|
||||
def test_cmp_eq(self):
|
||||
return super().test_cmp_eq()
|
||||
|
||||
@unittest.expectedFailure
|
||||
def test_cmp_ne(self):
|
||||
return super().test_cmp_ne()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from torch._dynamo.test_case import run_tests
|
||||
|
||||
|
@ -108,6 +108,7 @@ from .tensor import (
|
||||
UnspecializedPythonVariable,
|
||||
)
|
||||
from .user_defined import (
|
||||
UserDefinedDictVariable,
|
||||
UserDefinedObjectVariable,
|
||||
UserDefinedSetVariable,
|
||||
UserDefinedVariable,
|
||||
@ -2588,7 +2589,13 @@ class BuiltinVariable(VariableTracker):
|
||||
# This call looks like `{"one": torch.ones(1)} | {"two": torch.ones(2)}`.
|
||||
if isinstance(
|
||||
a,
|
||||
(ConstDictVariable, DictKeysVariable, SetVariable, UserDefinedSetVariable),
|
||||
(
|
||||
ConstDictVariable,
|
||||
DictKeysVariable,
|
||||
SetVariable,
|
||||
UserDefinedDictVariable,
|
||||
UserDefinedSetVariable,
|
||||
),
|
||||
):
|
||||
return a.call_method(tx, "__or__", [b], {})
|
||||
|
||||
|
@ -611,7 +611,12 @@ class ConstDictVariable(VariableTracker):
|
||||
# Dicts can only be unioned with other dicts or subclasses of dicts.
|
||||
# Sets can be unioned with other sets, frozensets or subclasses of sets.
|
||||
_raise = not (
|
||||
(istype(self, ConstDictVariable) and istype(args[0], ConstDictVariable))
|
||||
(
|
||||
istype(self, ConstDictVariable)
|
||||
and istype(
|
||||
args[0], (ConstDictVariable, variables.UserDefinedDictVariable)
|
||||
)
|
||||
)
|
||||
or (
|
||||
isinstance(self, SetVariable)
|
||||
and isinstance(
|
||||
|
@ -403,6 +403,9 @@ class UserDefinedClassVariable(UserDefinedVariable):
|
||||
return variables.ConstantVariable(self.value == args[0].value)
|
||||
elif name == "__ne__" and len(args) == 1 and hasattr(args[0], "value"):
|
||||
return variables.ConstantVariable(self.value != args[0].value)
|
||||
elif issubclass(self.value, dict) and name != "__new__":
|
||||
# __new__ is handled below
|
||||
return variables.BuiltinVariable(dict).call_method(tx, name, args, kwargs)
|
||||
elif issubclass(self.value, (set, frozenset)) and name != "__new__":
|
||||
# __new__ is handled below
|
||||
return variables.BuiltinVariable(set).call_method(tx, name, args, kwargs)
|
||||
@ -1693,6 +1696,16 @@ class UserDefinedDictVariable(UserDefinedObjectVariable):
|
||||
def is_underlying_vt_modified(self, side_effects):
|
||||
return side_effects.is_modified(self._dict_vt)
|
||||
|
||||
@property
|
||||
def items(self):
|
||||
return self._dict_vt.items
|
||||
|
||||
def install_dict_keys_match_guard(self):
|
||||
return self._dict_vt.install_dict_keys_match_guard()
|
||||
|
||||
def install_dict_contains_guard(self):
|
||||
return self._dict_vt.install_dict_contains_guard()
|
||||
|
||||
|
||||
class UserDefinedSetVariable(UserDefinedObjectVariable):
|
||||
"""
|
||||
|
Reference in New Issue
Block a user