Change mutation type of MutableMappingVariable to AttributeMutationNew (#159366)

Also add MutableMappingVariable to `call_or_` / `call_ior`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159366
Approved by: https://github.com/zou3519
ghstack dependencies: #159365
This commit is contained in:
Guilherme Leobas
2025-08-15 12:44:28 -03:00
committed by PyTorch MergeBot
parent 0242d40fa5
commit 2542e71f3f
29 changed files with 29 additions and 5 deletions

View File

@ -243,6 +243,10 @@ def set_difference_update(set1, *others):
set1.update(result)
def assert_dict_equal(self_, d1, d2, msg=None):
self_.assertTrue(d1 == d2, msg)
def assert_multi_line_equal(self_, first, second, msg=None):
return self_.assertTrue(first == second, msg)

View File

@ -140,7 +140,7 @@ class CPythonTestCase(TestCase):
assertListEqual = unittest.TestCase.assertListEqual
assertTupleEqual = unittest.TestCase.assertTupleEqual
assertSetEqual = unittest.TestCase.assertSetEqual
assertDictEqual = unittest.TestCase.assertDictEqual
assertDictEqual = polyfills.assert_dict_equal
assertRaises = unittest.TestCase.assertRaises
assertRaisesRegex = unittest.TestCase.assertRaisesRegex
assertWarns = unittest.TestCase.assertWarns

View File

@ -109,6 +109,7 @@ from .tensor import (
UnspecializedPythonVariable,
)
from .user_defined import (
MutableMappingVariable,
UserDefinedDictVariable,
UserDefinedObjectVariable,
UserDefinedSetVariable,
@ -1865,6 +1866,12 @@ class BuiltinVariable(VariableTracker):
hints=["Ensure your call to cast() has exactly 2 arguments."],
)
def call_dir(self, tx: "InstructionTranslator", arg):
if isinstance(arg, variables.UserDefinedClassVariable):
return VariableTracker.build(tx, dir(arg.value))
if isinstance(arg, BuiltinVariable):
return VariableTracker.build(tx, dir(arg.fn))
def call_dict(self, tx: "InstructionTranslator", *args, **kwargs):
return BuiltinVariable.call_custom_dict(tx, dict, *args, **kwargs)
@ -2255,7 +2262,6 @@ class BuiltinVariable(VariableTracker):
"assertRaisesRegex",
"assertNotWarns",
"assertWarnsRegex",
"assertDictEqual",
"assertWarns",
)
):
@ -2742,6 +2748,7 @@ class BuiltinVariable(VariableTracker):
(
ConstDictVariable,
DictKeysVariable,
MutableMappingVariable,
SetVariable,
UserDefinedDictVariable,
UserDefinedSetVariable,
@ -2770,7 +2777,13 @@ class BuiltinVariable(VariableTracker):
# This call looks like `{"one": torch.ones(1)} |= {"two": torch.ones(2)}`.
if isinstance(
a,
(ConstDictVariable, DictKeysVariable, SetVariable, UserDefinedSetVariable),
(
ConstDictVariable,
DictKeysVariable,
MutableMappingVariable,
SetVariable,
UserDefinedSetVariable,
),
):
return a.call_method(tx, "__ior__", [b], {})

View File

@ -91,7 +91,12 @@ from ..utils import (
tuple_methods,
unpatched_nn_module_getattr,
)
from .base import AttributeMutationExisting, ValueMutationNew, VariableTracker
from .base import (
AttributeMutationExisting,
AttributeMutationNew,
ValueMutationNew,
VariableTracker,
)
from .dicts import DefaultDictVariable
from .lists import SizeVariable
@ -2084,7 +2089,9 @@ class MutableMappingVariable(UserDefinedObjectVariable):
def __init__(self, value, **kwargs):
super().__init__(value, **kwargs)
self.generic_dict_vt = variables.ConstDictVariable({})
self.mutation_type = AttributeMutationExisting()
self.mutation_type = (
AttributeMutationExisting() if self.source else AttributeMutationNew()
)
def var_getattr(self, tx: "InstructionTranslator", name: str) -> "VariableTracker":
# A common pattern in the init code of MutableMapping objects is to