mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Change mutation type of MutableMappingVariable to AttributeMutationNew (#159366)
Also add MutableMappingVariable to `call_or_` / `call_ior` Pull Request resolved: https://github.com/pytorch/pytorch/pull/159366 Approved by: https://github.com/zou3519 ghstack dependencies: #159365
This commit is contained in:
committed by
PyTorch MergeBot
parent
0242d40fa5
commit
2542e71f3f
@ -243,6 +243,10 @@ def set_difference_update(set1, *others):
|
||||
set1.update(result)
|
||||
|
||||
|
||||
def assert_dict_equal(self_, d1, d2, msg=None):
|
||||
self_.assertTrue(d1 == d2, msg)
|
||||
|
||||
|
||||
def assert_multi_line_equal(self_, first, second, msg=None):
|
||||
return self_.assertTrue(first == second, msg)
|
||||
|
||||
|
||||
@ -140,7 +140,7 @@ class CPythonTestCase(TestCase):
|
||||
assertListEqual = unittest.TestCase.assertListEqual
|
||||
assertTupleEqual = unittest.TestCase.assertTupleEqual
|
||||
assertSetEqual = unittest.TestCase.assertSetEqual
|
||||
assertDictEqual = unittest.TestCase.assertDictEqual
|
||||
assertDictEqual = polyfills.assert_dict_equal
|
||||
assertRaises = unittest.TestCase.assertRaises
|
||||
assertRaisesRegex = unittest.TestCase.assertRaisesRegex
|
||||
assertWarns = unittest.TestCase.assertWarns
|
||||
|
||||
@ -109,6 +109,7 @@ from .tensor import (
|
||||
UnspecializedPythonVariable,
|
||||
)
|
||||
from .user_defined import (
|
||||
MutableMappingVariable,
|
||||
UserDefinedDictVariable,
|
||||
UserDefinedObjectVariable,
|
||||
UserDefinedSetVariable,
|
||||
@ -1865,6 +1866,12 @@ class BuiltinVariable(VariableTracker):
|
||||
hints=["Ensure your call to cast() has exactly 2 arguments."],
|
||||
)
|
||||
|
||||
def call_dir(self, tx: "InstructionTranslator", arg):
|
||||
if isinstance(arg, variables.UserDefinedClassVariable):
|
||||
return VariableTracker.build(tx, dir(arg.value))
|
||||
if isinstance(arg, BuiltinVariable):
|
||||
return VariableTracker.build(tx, dir(arg.fn))
|
||||
|
||||
def call_dict(self, tx: "InstructionTranslator", *args, **kwargs):
|
||||
return BuiltinVariable.call_custom_dict(tx, dict, *args, **kwargs)
|
||||
|
||||
@ -2255,7 +2262,6 @@ class BuiltinVariable(VariableTracker):
|
||||
"assertRaisesRegex",
|
||||
"assertNotWarns",
|
||||
"assertWarnsRegex",
|
||||
"assertDictEqual",
|
||||
"assertWarns",
|
||||
)
|
||||
):
|
||||
@ -2742,6 +2748,7 @@ class BuiltinVariable(VariableTracker):
|
||||
(
|
||||
ConstDictVariable,
|
||||
DictKeysVariable,
|
||||
MutableMappingVariable,
|
||||
SetVariable,
|
||||
UserDefinedDictVariable,
|
||||
UserDefinedSetVariable,
|
||||
@ -2770,7 +2777,13 @@ class BuiltinVariable(VariableTracker):
|
||||
# This call looks like `{"one": torch.ones(1)} |= {"two": torch.ones(2)}`.
|
||||
if isinstance(
|
||||
a,
|
||||
(ConstDictVariable, DictKeysVariable, SetVariable, UserDefinedSetVariable),
|
||||
(
|
||||
ConstDictVariable,
|
||||
DictKeysVariable,
|
||||
MutableMappingVariable,
|
||||
SetVariable,
|
||||
UserDefinedSetVariable,
|
||||
),
|
||||
):
|
||||
return a.call_method(tx, "__ior__", [b], {})
|
||||
|
||||
|
||||
@ -91,7 +91,12 @@ from ..utils import (
|
||||
tuple_methods,
|
||||
unpatched_nn_module_getattr,
|
||||
)
|
||||
from .base import AttributeMutationExisting, ValueMutationNew, VariableTracker
|
||||
from .base import (
|
||||
AttributeMutationExisting,
|
||||
AttributeMutationNew,
|
||||
ValueMutationNew,
|
||||
VariableTracker,
|
||||
)
|
||||
from .dicts import DefaultDictVariable
|
||||
from .lists import SizeVariable
|
||||
|
||||
@ -2084,7 +2089,9 @@ class MutableMappingVariable(UserDefinedObjectVariable):
|
||||
def __init__(self, value, **kwargs):
|
||||
super().__init__(value, **kwargs)
|
||||
self.generic_dict_vt = variables.ConstDictVariable({})
|
||||
self.mutation_type = AttributeMutationExisting()
|
||||
self.mutation_type = (
|
||||
AttributeMutationExisting() if self.source else AttributeMutationNew()
|
||||
)
|
||||
|
||||
def var_getattr(self, tx: "InstructionTranslator", name: str) -> "VariableTracker":
|
||||
# A common pattern in the init code of MutableMapping objects is to
|
||||
|
||||
Reference in New Issue
Block a user