[dict] Allow Dynamo to trace through explicit dict dunder method call (#154794)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/154794
Approved by: https://github.com/mlazos
ghstack dependencies: #154003, #154793
This commit is contained in:
Guilherme Leobas
2025-07-08 15:27:55 -03:00
committed by PyTorch MergeBot
parent 57d64298a0
commit ba8d19ec02
4 changed files with 39 additions and 2 deletions

View File

@ -1360,6 +1360,18 @@ class DictMethodsTests(torch._dynamo.test_case.TestCase):
self.assertRaises(TypeError, d.values, 1)
class DictSubclassMethodsTests(DictMethodsTests):
thetype = SimpleDict
@unittest.expectedFailure
def test_cmp_eq(self):
return super().test_cmp_eq()
@unittest.expectedFailure
def test_cmp_ne(self):
return super().test_cmp_ne()
if __name__ == "__main__":
from torch._dynamo.test_case import run_tests

View File

@ -108,6 +108,7 @@ from .tensor import (
UnspecializedPythonVariable,
)
from .user_defined import (
UserDefinedDictVariable,
UserDefinedObjectVariable,
UserDefinedSetVariable,
UserDefinedVariable,
@ -2588,7 +2589,13 @@ class BuiltinVariable(VariableTracker):
# This call looks like `{"one": torch.ones(1)} | {"two": torch.ones(2)}`.
if isinstance(
a,
(ConstDictVariable, DictKeysVariable, SetVariable, UserDefinedSetVariable),
(
ConstDictVariable,
DictKeysVariable,
SetVariable,
UserDefinedDictVariable,
UserDefinedSetVariable,
),
):
return a.call_method(tx, "__or__", [b], {})

View File

@ -611,7 +611,12 @@ class ConstDictVariable(VariableTracker):
# Dicts can only be unioned with other dicts or subclasses of dicts.
# Sets can be unioned with other sets, frozensets or subclasses of sets.
_raise = not (
(istype(self, ConstDictVariable) and istype(args[0], ConstDictVariable))
(
istype(self, ConstDictVariable)
and istype(
args[0], (ConstDictVariable, variables.UserDefinedDictVariable)
)
)
or (
isinstance(self, SetVariable)
and isinstance(

View File

@ -403,6 +403,9 @@ class UserDefinedClassVariable(UserDefinedVariable):
return variables.ConstantVariable(self.value == args[0].value)
elif name == "__ne__" and len(args) == 1 and hasattr(args[0], "value"):
return variables.ConstantVariable(self.value != args[0].value)
elif issubclass(self.value, dict) and name != "__new__":
# __new__ is handled below
return variables.BuiltinVariable(dict).call_method(tx, name, args, kwargs)
elif issubclass(self.value, (set, frozenset)) and name != "__new__":
# __new__ is handled below
return variables.BuiltinVariable(set).call_method(tx, name, args, kwargs)
@ -1693,6 +1696,16 @@ class UserDefinedDictVariable(UserDefinedObjectVariable):
def is_underlying_vt_modified(self, side_effects):
return side_effects.is_modified(self._dict_vt)
@property
def items(self):
return self._dict_vt.items
def install_dict_keys_match_guard(self):
return self._dict_vt.install_dict_keys_match_guard()
def install_dict_contains_guard(self):
return self._dict_vt.install_dict_contains_guard()
class UserDefinedSetVariable(UserDefinedObjectVariable):
"""