[dict] Implement dict.__ior__ and fix return type in dict.__or__ (#155072)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/155072
Approved by: https://github.com/anijain2305
ghstack dependencies: #160156
This commit is contained in:
Guilherme Leobas
2025-08-26 14:38:22 -03:00
committed by PyTorch MergeBot
parent 2d44969bbd
commit e3718c4855
5 changed files with 173 additions and 20 deletions

View File

@ -1158,6 +1158,59 @@ class DictGuardTests(LoggingTestCase):
munge_exc(record.getMessage()),
)
@make_logging_test(recompiles=True)
def test_cmp_or(self, records):
@torch.compile(backend="eager", fullgraph=True)
def fn(x, d1, d2):
d = d1 | d2
if d.get(5, False):
return x.sin()
return x.cos()
x = torch.tensor(1.0)
d1 = self.thetype({1: 2, 3: 4})
d2 = self.thetype({1: 2, 5: 6})
y = fn(x, d1, d2)
# sanity check
self.assertEqual(len(records), 0)
self.assertEqual(y, x.sin())
y = fn(x, d1, d1)
self.assertEqual(len(records), 1)
self.assertEqual(y, x.cos())
record = self.getRecord(records, "d2")
self.assertIn(
"""KeyError on d2[5]""",
munge_exc(record.getMessage()),
)
@make_logging_test(recompiles=True)
def test_cmp_ior(self, records):
@torch.compile(backend="eager", fullgraph=True)
def fn(x, d1, d2):
d2 |= d1
if d2.get(3, False):
return x.sin()
return x.cos()
x = torch.tensor(1.0)
d1 = self.thetype({1: 2, 3: 4})
d2 = self.thetype({1: 2, 5: 6})
d3, d4 = d2.copy(), d2.copy()
y = fn(x, d1, d2)
# sanity check
self.assertEqual(len(records), 0)
self.assertEqual(y, x.sin())
y = fn(x, d3, d4)
self.assertEqual(len(records), 1)
self.assertEqual(y, x.cos())
record = self.getRecord(records, "d1")
self.assertIn(
"""KeyError on d1[3]""",
munge_exc(record.getMessage()),
)
class DictMethodsTests(torch._dynamo.test_case.TestCase):
thetype = dict
@ -1251,6 +1304,53 @@ class DictMethodsTests(torch._dynamo.test_case.TestCase):
# Test with non-dict types
self.assertRaises(TypeError, lambda: d1 | 1)
@make_dynamo_test
def test_binop_ior(self):
d1 = self.thetype({"a": 1, "b": 2})
d2 = self.thetype({"b": 3, "c": 4})
# Test the |= operator
d3, d4 = d1.copy(), d2.copy()
d3 |= d2
d4 |= d1
self.assertEqual(d3, {"a": 1, "b": 3, "c": 4})
self.assertEqual(d4, {"a": 1, "b": 2, "c": 4})
# Test with an iterable
d3, d4 = d1.copy(), d2.copy()
# Test the __ior__ method
d3, d4 = d1.copy(), d2.copy()
d3.__ior__(d2)
d4.__ior__(d1)
self.assertEqual(d3, {"a": 1, "b": 3, "c": 4})
self.assertEqual(d4, {"a": 1, "b": 2, "c": 4})
# Test Dict.__or__
d3, d4 = d1.copy(), d2.copy()
self.assertEqual(dict.__ior__(d3, d2), {"a": 1, "b": 3, "c": 4})
self.assertEqual(self.thetype.__ior__(d4, d1), {"a": 1, "b": 2, "c": 4})
# Test return value
d3, d4 = d1.copy(), d2.copy()
self.assertEqual(d3.__ior__(d2), {"a": 1, "b": 3, "c": 4})
self.assertEqual(dict.__ior__(d4, d1), {"a": 1, "b": 2, "c": 4})
# Test with non-dict types
self.assertRaises(TypeError, lambda: dict.__ior__(d1, 1))
@make_dynamo_test
def test_binop_ior_iterable(self):
d1 = self.thetype({"a": 1, "b": 2})
d2 = self.thetype({"b": 3, "c": 4})
d3, d4 = d1.copy(), d2.copy()
def fn(d):
yield from d.items()
self.assertEqual(d3.__ior__(d2.items()), {"a": 1, "b": 3, "c": 4})
self.assertEqual(d4.__ior__(fn(d1)), {"a": 1, "b": 2, "c": 4})
@make_dynamo_test
def test_clear(self):
d = self.thetype({"a": 1, "b": 2})
@ -1451,6 +1551,12 @@ class DictMethodsTests(torch._dynamo.test_case.TestCase):
# Test invalid usage
self.assertRaises(TypeError, d.values, 1)
@make_dynamo_test
def test_type(self):
d = self.thetype({"a": 1, "b": 2})
self.assertIsInstance(d, self.thetype)
self.assertIs(type(d), self.thetype)
class DictSubclassMethodsTests(DictMethodsTests):
thetype = SimpleDict
@ -1469,6 +1575,32 @@ class OrderedDictMethodsTests(DictMethodsTests):
b = self.thetype.fromkeys("bca")
self.assertFalse(a == b)
@make_dynamo_test
def test_binop_or_return_type(self):
d1 = self.thetype({"a": 1, "b": 2})
d2 = self.thetype({"b": 3, "c": 4})
# Test return type
self.assertIs(type(d1 | d2), OrderedDict)
self.assertIs(type(dict(d1) | d2), OrderedDict)
self.assertIs(type(d1 | dict(d2)), OrderedDict)
@make_dynamo_test
def test_binop_ior_return_type(self):
d1 = self.thetype({"a": 1, "b": 2})
d2 = self.thetype({"b": 3, "c": 4})
# Test return type
d3, d4 = d1.copy(), d2.copy()
self.assertIs(type(d3.__ior__(d2)), OrderedDict)
self.assertIs(type(dict.__ior__(d4, d2)), OrderedDict)
self.assertIs(type(self.thetype.__ior__(d4, d2)), OrderedDict)
d3, d4 = d1.copy(), d2.copy()
self.assertIs(type(dict.__ior__(d3, dict(d2))), OrderedDict)
self.assertIs(type(dict.__ior__(dict(d3), d2)), dict)
self.assertIs(type(dict(d4).__ior__(d2)), dict)
if __name__ == "__main__":
from torch._dynamo.test_case import run_tests

View File

@ -2805,6 +2805,8 @@ class BuiltinVariable(VariableTracker):
UserDefinedObjectVariable,
),
):
# TODO(guilhermeleobas): forward the call to b.__ror__(a) if
# a.__ror__(b) returns NotImplemented
return a.call_method(tx, "__or__", [b], {})
# None no-ops this handler and lets the driving function proceed

View File

@ -626,33 +626,48 @@ class ConstDictVariable(VariableTracker):
)
elif name == "__or__":
assert len(args) == 1
# Dicts can only be unioned with other dicts or subclasses of dicts.
# Sets can be unioned with other sets, frozensets or subclasses of sets.
_raise = not (
(
istype(self, ConstDictVariable)
and istype(
args[0], (ConstDictVariable, variables.UserDefinedDictVariable)
)
)
or (
isinstance(self, SetVariable)
and isinstance(
args[0], (SetVariable, variables.UserDefinedSetVariable)
)
)
)
other = args[0]
if _raise:
# Method resolution for binops works as follow (using __or__ as example):
# (1) dict.__or__(dict) => dict
# (2) dict.__or__(subclass): return NotImplemented
# (3) Check if subclass implements __ror__ => forward the call
# to subclass.__ror__(dict)
# Let's not forward the call to __ror__ yet because __ror__ can be
# implemented in C (i.e. OrderedDict subclass) which Dynamo cannot
# trace
# if istype(other, variables.UserDefinedDictVariable):
# if other.call_obj_hasattr(tx, "__ror__").value:
# return other.call_method(tx, "__ror__", [self], kwargs)
# The three dict types Dynamo can handle are dict, OrderedDict and
# defaultdict.
# TODO(guilhermeleobas): this check should be on builtin.py::call_or_
if not istype(
other, (ConstDictVariable, variables.UserDefinedDictVariable)
):
msg = (
f"unsupported operand type(s) for |: '{self.python_type().__name__}'"
f"and '{args[0].python_type().__name__}'"
f"and '{other.python_type().__name__}'"
)
raise_observed_exception(TypeError, tx, args=[msg])
# OrderedDict overloads __ror__
ts = {self.user_cls, other.user_cls}
user_cls = (
collections.OrderedDict
if any(issubclass(t, collections.OrderedDict) for t in ts)
else dict
)
self.install_dict_keys_match_guard()
new_dict_vt = self.clone(
items=self.items.copy(), mutation_type=ValueMutationNew(), source=None
items=self.items.copy(),
mutation_type=ValueMutationNew(),
source=None,
user_cls=user_cls,
)
# NB - Guard on all the keys of the other dict to ensure

View File

@ -1926,7 +1926,7 @@ class UserDefinedDictVariable(UserDefinedObjectVariable):
"dict_vt must be constructed by builder.py when source is present"
)
self._dict_vt = variables.ConstDictVariable(
{}, mutation_type=ValueMutationNew()
{}, type(value), mutation_type=ValueMutationNew()
)
self._dict_methods = dict_methods
@ -1966,6 +1966,10 @@ class UserDefinedDictVariable(UserDefinedObjectVariable):
def is_underlying_vt_modified(self, side_effects):
return side_effects.is_modified(self._dict_vt)
@property
def user_cls(self):
return self._dict_vt.user_cls
@property
def items(self):
return self._dict_vt.items