mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Change mutation type of MutableMappingVariable
to AttributeMutationNew
(#159366)
Also add MutableMappingVariable to `call_or_` / `call_ior` Pull Request resolved: https://github.com/pytorch/pytorch/pull/159366 Approved by: https://github.com/zou3519 ghstack dependencies: #159365
This commit is contained in:
committed by
PyTorch MergeBot
parent
0242d40fa5
commit
2542e71f3f
@ -109,6 +109,7 @@ from .tensor import (
|
||||
UnspecializedPythonVariable,
|
||||
)
|
||||
from .user_defined import (
|
||||
MutableMappingVariable,
|
||||
UserDefinedDictVariable,
|
||||
UserDefinedObjectVariable,
|
||||
UserDefinedSetVariable,
|
||||
@ -1865,6 +1866,12 @@ class BuiltinVariable(VariableTracker):
|
||||
hints=["Ensure your call to cast() has exactly 2 arguments."],
|
||||
)
|
||||
|
||||
def call_dir(self, tx: "InstructionTranslator", arg):
|
||||
if isinstance(arg, variables.UserDefinedClassVariable):
|
||||
return VariableTracker.build(tx, dir(arg.value))
|
||||
if isinstance(arg, BuiltinVariable):
|
||||
return VariableTracker.build(tx, dir(arg.fn))
|
||||
|
||||
def call_dict(self, tx: "InstructionTranslator", *args, **kwargs):
|
||||
return BuiltinVariable.call_custom_dict(tx, dict, *args, **kwargs)
|
||||
|
||||
@ -2255,7 +2262,6 @@ class BuiltinVariable(VariableTracker):
|
||||
"assertRaisesRegex",
|
||||
"assertNotWarns",
|
||||
"assertWarnsRegex",
|
||||
"assertDictEqual",
|
||||
"assertWarns",
|
||||
)
|
||||
):
|
||||
@ -2742,6 +2748,7 @@ class BuiltinVariable(VariableTracker):
|
||||
(
|
||||
ConstDictVariable,
|
||||
DictKeysVariable,
|
||||
MutableMappingVariable,
|
||||
SetVariable,
|
||||
UserDefinedDictVariable,
|
||||
UserDefinedSetVariable,
|
||||
@ -2770,7 +2777,13 @@ class BuiltinVariable(VariableTracker):
|
||||
# This call looks like `{"one": torch.ones(1)} |= {"two": torch.ones(2)}`.
|
||||
if isinstance(
|
||||
a,
|
||||
(ConstDictVariable, DictKeysVariable, SetVariable, UserDefinedSetVariable),
|
||||
(
|
||||
ConstDictVariable,
|
||||
DictKeysVariable,
|
||||
MutableMappingVariable,
|
||||
SetVariable,
|
||||
UserDefinedSetVariable,
|
||||
),
|
||||
):
|
||||
return a.call_method(tx, "__ior__", [b], {})
|
||||
|
||||
|
Reference in New Issue
Block a user