mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[dict] Implement dict.__ior__ and fix return type in dict.__or__ (#155072)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/155072 Approved by: https://github.com/anijain2305 ghstack dependencies: #160156
This commit is contained in:
committed by
PyTorch MergeBot
parent
2d44969bbd
commit
e3718c4855
@ -1158,6 +1158,59 @@ class DictGuardTests(LoggingTestCase):
|
||||
munge_exc(record.getMessage()),
|
||||
)
|
||||
|
||||
@make_logging_test(recompiles=True)
|
||||
def test_cmp_or(self, records):
|
||||
@torch.compile(backend="eager", fullgraph=True)
|
||||
def fn(x, d1, d2):
|
||||
d = d1 | d2
|
||||
if d.get(5, False):
|
||||
return x.sin()
|
||||
return x.cos()
|
||||
|
||||
x = torch.tensor(1.0)
|
||||
d1 = self.thetype({1: 2, 3: 4})
|
||||
d2 = self.thetype({1: 2, 5: 6})
|
||||
y = fn(x, d1, d2)
|
||||
# sanity check
|
||||
self.assertEqual(len(records), 0)
|
||||
self.assertEqual(y, x.sin())
|
||||
|
||||
y = fn(x, d1, d1)
|
||||
self.assertEqual(len(records), 1)
|
||||
self.assertEqual(y, x.cos())
|
||||
record = self.getRecord(records, "d2")
|
||||
self.assertIn(
|
||||
"""KeyError on d2[5]""",
|
||||
munge_exc(record.getMessage()),
|
||||
)
|
||||
|
||||
@make_logging_test(recompiles=True)
|
||||
def test_cmp_ior(self, records):
|
||||
@torch.compile(backend="eager", fullgraph=True)
|
||||
def fn(x, d1, d2):
|
||||
d2 |= d1
|
||||
if d2.get(3, False):
|
||||
return x.sin()
|
||||
return x.cos()
|
||||
|
||||
x = torch.tensor(1.0)
|
||||
d1 = self.thetype({1: 2, 3: 4})
|
||||
d2 = self.thetype({1: 2, 5: 6})
|
||||
d3, d4 = d2.copy(), d2.copy()
|
||||
y = fn(x, d1, d2)
|
||||
# sanity check
|
||||
self.assertEqual(len(records), 0)
|
||||
self.assertEqual(y, x.sin())
|
||||
|
||||
y = fn(x, d3, d4)
|
||||
self.assertEqual(len(records), 1)
|
||||
self.assertEqual(y, x.cos())
|
||||
record = self.getRecord(records, "d1")
|
||||
self.assertIn(
|
||||
"""KeyError on d1[3]""",
|
||||
munge_exc(record.getMessage()),
|
||||
)
|
||||
|
||||
|
||||
class DictMethodsTests(torch._dynamo.test_case.TestCase):
|
||||
thetype = dict
|
||||
@ -1251,6 +1304,53 @@ class DictMethodsTests(torch._dynamo.test_case.TestCase):
|
||||
# Test with non-dict types
|
||||
self.assertRaises(TypeError, lambda: d1 | 1)
|
||||
|
||||
@make_dynamo_test
|
||||
def test_binop_ior(self):
|
||||
d1 = self.thetype({"a": 1, "b": 2})
|
||||
d2 = self.thetype({"b": 3, "c": 4})
|
||||
|
||||
# Test the |= operator
|
||||
d3, d4 = d1.copy(), d2.copy()
|
||||
d3 |= d2
|
||||
d4 |= d1
|
||||
self.assertEqual(d3, {"a": 1, "b": 3, "c": 4})
|
||||
self.assertEqual(d4, {"a": 1, "b": 2, "c": 4})
|
||||
|
||||
# Test with an iterable
|
||||
d3, d4 = d1.copy(), d2.copy()
|
||||
|
||||
# Test the __ior__ method
|
||||
d3, d4 = d1.copy(), d2.copy()
|
||||
d3.__ior__(d2)
|
||||
d4.__ior__(d1)
|
||||
self.assertEqual(d3, {"a": 1, "b": 3, "c": 4})
|
||||
self.assertEqual(d4, {"a": 1, "b": 2, "c": 4})
|
||||
|
||||
# Test Dict.__or__
|
||||
d3, d4 = d1.copy(), d2.copy()
|
||||
self.assertEqual(dict.__ior__(d3, d2), {"a": 1, "b": 3, "c": 4})
|
||||
self.assertEqual(self.thetype.__ior__(d4, d1), {"a": 1, "b": 2, "c": 4})
|
||||
|
||||
# Test return value
|
||||
d3, d4 = d1.copy(), d2.copy()
|
||||
self.assertEqual(d3.__ior__(d2), {"a": 1, "b": 3, "c": 4})
|
||||
self.assertEqual(dict.__ior__(d4, d1), {"a": 1, "b": 2, "c": 4})
|
||||
|
||||
# Test with non-dict types
|
||||
self.assertRaises(TypeError, lambda: dict.__ior__(d1, 1))
|
||||
|
||||
@make_dynamo_test
|
||||
def test_binop_ior_iterable(self):
|
||||
d1 = self.thetype({"a": 1, "b": 2})
|
||||
d2 = self.thetype({"b": 3, "c": 4})
|
||||
d3, d4 = d1.copy(), d2.copy()
|
||||
|
||||
def fn(d):
|
||||
yield from d.items()
|
||||
|
||||
self.assertEqual(d3.__ior__(d2.items()), {"a": 1, "b": 3, "c": 4})
|
||||
self.assertEqual(d4.__ior__(fn(d1)), {"a": 1, "b": 2, "c": 4})
|
||||
|
||||
@make_dynamo_test
|
||||
def test_clear(self):
|
||||
d = self.thetype({"a": 1, "b": 2})
|
||||
@ -1451,6 +1551,12 @@ class DictMethodsTests(torch._dynamo.test_case.TestCase):
|
||||
# Test invalid usage
|
||||
self.assertRaises(TypeError, d.values, 1)
|
||||
|
||||
@make_dynamo_test
|
||||
def test_type(self):
|
||||
d = self.thetype({"a": 1, "b": 2})
|
||||
self.assertIsInstance(d, self.thetype)
|
||||
self.assertIs(type(d), self.thetype)
|
||||
|
||||
|
||||
class DictSubclassMethodsTests(DictMethodsTests):
|
||||
thetype = SimpleDict
|
||||
@ -1469,6 +1575,32 @@ class OrderedDictMethodsTests(DictMethodsTests):
|
||||
b = self.thetype.fromkeys("bca")
|
||||
self.assertFalse(a == b)
|
||||
|
||||
@make_dynamo_test
|
||||
def test_binop_or_return_type(self):
|
||||
d1 = self.thetype({"a": 1, "b": 2})
|
||||
d2 = self.thetype({"b": 3, "c": 4})
|
||||
|
||||
# Test return type
|
||||
self.assertIs(type(d1 | d2), OrderedDict)
|
||||
self.assertIs(type(dict(d1) | d2), OrderedDict)
|
||||
self.assertIs(type(d1 | dict(d2)), OrderedDict)
|
||||
|
||||
@make_dynamo_test
|
||||
def test_binop_ior_return_type(self):
|
||||
d1 = self.thetype({"a": 1, "b": 2})
|
||||
d2 = self.thetype({"b": 3, "c": 4})
|
||||
|
||||
# Test return type
|
||||
d3, d4 = d1.copy(), d2.copy()
|
||||
self.assertIs(type(d3.__ior__(d2)), OrderedDict)
|
||||
self.assertIs(type(dict.__ior__(d4, d2)), OrderedDict)
|
||||
self.assertIs(type(self.thetype.__ior__(d4, d2)), OrderedDict)
|
||||
|
||||
d3, d4 = d1.copy(), d2.copy()
|
||||
self.assertIs(type(dict.__ior__(d3, dict(d2))), OrderedDict)
|
||||
self.assertIs(type(dict.__ior__(dict(d3), d2)), dict)
|
||||
self.assertIs(type(dict(d4).__ior__(d2)), dict)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from torch._dynamo.test_case import run_tests
|
||||
|
@ -2805,6 +2805,8 @@ class BuiltinVariable(VariableTracker):
|
||||
UserDefinedObjectVariable,
|
||||
),
|
||||
):
|
||||
# TODO(guilhermeleobas): forward the call to b.__ror__(a) if
|
||||
# a.__ror__(b) returns NotImplemented
|
||||
return a.call_method(tx, "__or__", [b], {})
|
||||
|
||||
# None no-ops this handler and lets the driving function proceed
|
||||
|
@ -626,33 +626,48 @@ class ConstDictVariable(VariableTracker):
|
||||
)
|
||||
elif name == "__or__":
|
||||
assert len(args) == 1
|
||||
# Dicts can only be unioned with other dicts or subclasses of dicts.
|
||||
# Sets can be unioned with other sets, frozensets or subclasses of sets.
|
||||
_raise = not (
|
||||
(
|
||||
istype(self, ConstDictVariable)
|
||||
and istype(
|
||||
args[0], (ConstDictVariable, variables.UserDefinedDictVariable)
|
||||
)
|
||||
)
|
||||
or (
|
||||
isinstance(self, SetVariable)
|
||||
and isinstance(
|
||||
args[0], (SetVariable, variables.UserDefinedSetVariable)
|
||||
)
|
||||
)
|
||||
)
|
||||
other = args[0]
|
||||
|
||||
if _raise:
|
||||
# Method resolution for binops works as follow (using __or__ as example):
|
||||
# (1) dict.__or__(dict) => dict
|
||||
# (2) dict.__or__(subclass): return NotImplemented
|
||||
# (3) Check if subclass implements __ror__ => forward the call
|
||||
# to subclass.__ror__(dict)
|
||||
|
||||
# Let's not forward the call to __ror__ yet because __ror__ can be
|
||||
# implemented in C (i.e. OrderedDict subclass) which Dynamo cannot
|
||||
# trace
|
||||
# if istype(other, variables.UserDefinedDictVariable):
|
||||
# if other.call_obj_hasattr(tx, "__ror__").value:
|
||||
# return other.call_method(tx, "__ror__", [self], kwargs)
|
||||
|
||||
# The three dict types Dynamo can handle are dict, OrderedDict and
|
||||
# defaultdict.
|
||||
|
||||
# TODO(guilhermeleobas): this check should be on builtin.py::call_or_
|
||||
if not istype(
|
||||
other, (ConstDictVariable, variables.UserDefinedDictVariable)
|
||||
):
|
||||
msg = (
|
||||
f"unsupported operand type(s) for |: '{self.python_type().__name__}'"
|
||||
f"and '{args[0].python_type().__name__}'"
|
||||
f"and '{other.python_type().__name__}'"
|
||||
)
|
||||
raise_observed_exception(TypeError, tx, args=[msg])
|
||||
|
||||
# OrderedDict overloads __ror__
|
||||
ts = {self.user_cls, other.user_cls}
|
||||
user_cls = (
|
||||
collections.OrderedDict
|
||||
if any(issubclass(t, collections.OrderedDict) for t in ts)
|
||||
else dict
|
||||
)
|
||||
|
||||
self.install_dict_keys_match_guard()
|
||||
new_dict_vt = self.clone(
|
||||
items=self.items.copy(), mutation_type=ValueMutationNew(), source=None
|
||||
items=self.items.copy(),
|
||||
mutation_type=ValueMutationNew(),
|
||||
source=None,
|
||||
user_cls=user_cls,
|
||||
)
|
||||
|
||||
# NB - Guard on all the keys of the other dict to ensure
|
||||
|
@ -1926,7 +1926,7 @@ class UserDefinedDictVariable(UserDefinedObjectVariable):
|
||||
"dict_vt must be constructed by builder.py when source is present"
|
||||
)
|
||||
self._dict_vt = variables.ConstDictVariable(
|
||||
{}, mutation_type=ValueMutationNew()
|
||||
{}, type(value), mutation_type=ValueMutationNew()
|
||||
)
|
||||
self._dict_methods = dict_methods
|
||||
|
||||
@ -1966,6 +1966,10 @@ class UserDefinedDictVariable(UserDefinedObjectVariable):
|
||||
def is_underlying_vt_modified(self, side_effects):
|
||||
return side_effects.is_modified(self._dict_vt)
|
||||
|
||||
@property
|
||||
def user_cls(self):
|
||||
return self._dict_vt.user_cls
|
||||
|
||||
@property
|
||||
def items(self):
|
||||
return self._dict_vt.items
|
||||
|
Reference in New Issue
Block a user