[dynamo] Support custom __setattr__ on UserDefinedObjectVariable (#123318)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/123318
Approved by: https://github.com/anijain2305
Committed by: PyTorch MergeBot
Parent: 89724843bb
Commit: 212e460dce
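
What this enables, as a minimal standalone sketch (illustrative, not part of the patch; the class and function names are made up, but the pattern mirrors the new tests below): a class with a user-defined __setattr__ can now be traced under fullgraph=True instead of graph-breaking.

import torch


class PrefixedAttrs:
    # User-defined __setattr__/__getattr__, like the UserDefineSetAttr
    # helper added to the tests: attributes are stored under a prefixed key.
    def __setattr__(self, key, value):
        super().__setattr__(f"pfx_{key}", value)

    def __getattr__(self, key):
        return self.__dict__[f"pfx_{key}"]


@torch.compile(backend="eager", fullgraph=True)
def fn(x):
    obj = PrefixedAttrs()
    obj.x = x          # traced through the custom __setattr__
    obj.y = obj.x + 1  # previously an unsupported graph break
    return obj


out = fn(torch.randn(8))
assert sorted(out.__dict__) == ["pfx_x", "pfx_y"]

The benchmark expected-result CSVs below change accordingly (fewer graph breaks for models using this pattern, e.g. detectron2).
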
@@ -86,7 +86,7 @@ detectron2_fasterrcnn_r_50_fpn,eager_fail_to_run,0
-detectron2_fcos_r_50_fpn,pass,35
+detectron2_fcos_r_50_fpn,pass,30
@@ -54,47 +54,47 @@ densenet121,pass,0
-detectron2_fasterrcnn_r_101_c4,pass,51
+detectron2_fasterrcnn_r_101_c4,pass,54
-detectron2_fasterrcnn_r_101_dc5,pass,51
+detectron2_fasterrcnn_r_101_dc5,pass,54
-detectron2_fasterrcnn_r_101_fpn,pass,55
+detectron2_fasterrcnn_r_101_fpn,pass,58
-detectron2_fasterrcnn_r_50_c4,pass,51
+detectron2_fasterrcnn_r_50_c4,pass,54
-detectron2_fasterrcnn_r_50_dc5,pass,51
+detectron2_fasterrcnn_r_50_dc5,pass,54
-detectron2_fasterrcnn_r_50_fpn,pass,55
+detectron2_fasterrcnn_r_50_fpn,pass,58
-detectron2_fcos_r_50_fpn,pass,38
+detectron2_fcos_r_50_fpn,pass,33
-detectron2_maskrcnn_r_101_c4,fail_accuracy,66
+detectron2_maskrcnn_r_101_c4,fail_accuracy,75
-detectron2_maskrcnn_r_101_fpn,pass,73
+detectron2_maskrcnn_r_101_fpn,pass,82
-detectron2_maskrcnn_r_50_c4,pass,66
+detectron2_maskrcnn_r_50_c4,pass,75
-detectron2_maskrcnn_r_50_fpn,pass,73
+detectron2_maskrcnn_r_50_fpn,pass,82
@@ -86,7 +86,7 @@ detectron2_fasterrcnn_r_50_fpn,eager_fail_to_run,0
-detectron2_fcos_r_50_fpn,pass,35
+detectron2_fcos_r_50_fpn,pass,30
@@ -54,7 +54,7 @@ densenet121,pass,0
-detectron2_fcos_r_50_fpn,pass,38
+detectron2_fcos_r_50_fpn,pass,33
@@ -86,7 +86,7 @@ detectron2_fasterrcnn_r_50_fpn,eager_fail_to_run,0
-detectron2_fcos_r_50_fpn,pass,36
+detectron2_fcos_r_50_fpn,pass,31
@@ -86,7 +86,7 @@ detectron2_fasterrcnn_r_50_fpn,eager_fail_to_run,0
-detectron2_fcos_r_50_fpn,pass,35
+detectron2_fcos_r_50_fpn,pass,30
@@ -86,7 +86,7 @@ detectron2_fasterrcnn_r_50_fpn,eager_fail_to_run,0
-detectron2_fcos_r_50_fpn,pass,36
+detectron2_fcos_r_50_fpn,pass,31
@@ -143,6 +143,18 @@ def closure_adder(val):
     return inner
 
 
+class UserDefineSetAttr:
+    setup = False
+
+    def __setattr__(self, key, value):
+        assert torch.compiler.is_dynamo_compiling() or UserDefineSetAttr.setup
+        super().__setattr__(f"pfx_{key}", value)
+
+    def __getattr__(self, key):
+        assert torch.compiler.is_dynamo_compiling() or UserDefineSetAttr.setup
+        return self.__dict__[f"pfx_{key}"]
+
+
 class MiscTests(torch._dynamo.test_case.TestCase):
     def test_get_cache_entry(self):
         def f(x):
@@ -488,6 +500,34 @@ class MiscTests(torch._dynamo.test_case.TestCase):
         cleanup_op("mylib::foo")
         del lib
 
+    def test_user_defined_setattr1(self):
+        @torch.compile(backend="eager", fullgraph=True)
+        def fn(obj):
+            obj.y = obj.x + 1
+
+        obj = UserDefineSetAttr()
+        with patch.object(UserDefineSetAttr, "setup", True):
+            obj.x = torch.randn(8)
+        fn(obj)
+        with patch.object(UserDefineSetAttr, "setup", True):
+            self.assertEqual(obj.y, obj.x + 1)
+        self.assertEqual(obj.__dict__.keys(), {"pfx_x", "pfx_y"})
+
+    def test_user_defined_setattr2(self):
+        @torch.compile(backend="eager", fullgraph=True)
+        def fn(x):
+            obj = UserDefineSetAttr()
+            obj.x = x
+            obj.y = obj.x + 1
+            return obj
+
+        x = torch.randn(8)
+        obj = fn(x)
+        with patch.object(UserDefineSetAttr, "setup", True):
+            self.assertIs(obj.x, x)
+            self.assertEqual(obj.y, x + 1)
+        self.assertEqual(obj.__dict__.keys(), {"pfx_x", "pfx_y"})
+
     def test_closure_recompiles(self):
         cnt = CompileCounter()
@@ -17,7 +17,7 @@ from collections import namedtuple
 from copy import deepcopy
 from enum import Enum
 from functools import wraps
-from typing import Any, Iterator, List
+from typing import Any, Dict, Iterator, List, Tuple
 from unittest import mock
 
 import numpy as np
@@ -4453,6 +4453,95 @@ def forward(self, s0 : torch.SymInt, s1 : torch.SymInt, L_x_ : torch.Tensor):
         with self.assertRaisesRegex(AssertionError, ""):
             f_fail(torch.ones(6, 4))
 
+    def test_detectron2_instances_cat(self):
+        class Instances:
+            def __init__(self, image_size: Tuple[int, int], **kwargs: Any):
+                self._image_size = image_size
+                self._fields: Dict[str, Any] = {}
+                for k, v in kwargs.items():
+                    self.set(k, v)
+
+            @property
+            def image_size(self) -> Tuple[int, int]:
+                return self._image_size
+
+            def __setattr__(self, name: str, val: Any) -> None:
+                if name.startswith("_"):
+                    super().__setattr__(name, val)
+                else:
+                    self.set(name, val)
+
+            def __getattr__(self, name: str) -> Any:
+                if name == "_fields" or name not in self._fields:
+                    raise AttributeError(
+                        f"Cannot find field '{name}' in the given Instances!"
+                    )
+                return self._fields[name]
+
+            def __len__(self) -> int:
+                for v in self._fields.values():
+                    # use __len__ because len() has to be int and is not friendly to tracing
+                    return v.__len__()
+                raise NotImplementedError("Empty Instances does not support __len__!")
+
+            def set(self, name: str, value: Any) -> None:
+                # TODO(jansel): support catch_warnings
+                # with warnings.catch_warnings(record=True):
+                data_len = len(value)
+                if len(self._fields):
+                    assert (
+                        len(self) == data_len
+                    ), f"Adding a field of length {data_len} to a Instances of length {len(self)}"
+                self._fields[name] = value
+
+            def get(self, name: str) -> Any:
+                return self._fields[name]
+
+            @staticmethod
+            def cat(instance_lists: List["Instances"]) -> "Instances":
+                # TODO(jansel): support all isinstance generator
+                # assert all(isinstance(i, Instances) for i in instance_lists)
+                assert len(instance_lists) > 0
+                if len(instance_lists) == 1:
+                    return instance_lists[0]
+
+                image_size = instance_lists[0].image_size
+                if not isinstance(
+                    image_size, torch.Tensor
+                ):  # could be a tensor in tracing
+                    for i in instance_lists[1:]:
+                        assert i.image_size == image_size
+                ret = Instances(image_size)
+                for k in instance_lists[0]._fields.keys():
+                    values = [i.get(k) for i in instance_lists]
+                    v0 = values[0]
+                    if isinstance(v0, torch.Tensor):
+                        values = torch.cat(values, dim=0)
+                    elif isinstance(v0, list):
+                        values = list(itertools.chain(*values))
+                    elif hasattr(type(v0), "cat"):
+                        values = type(v0).cat(values)
+                    else:
+                        raise ValueError(
+                            f"Unsupported type {type(v0)} for concatenation"
+                        )
+                    ret.set(k, values)
+                return ret
+
+        instances = [
+            Instances((16, 16), a=[torch.randn(16, 16)], b=[torch.randn(16, 16)])
+            for _ in range(3)
+        ]
+
+        @torch.compile(backend="eager", fullgraph=True)
+        def fn(instances):
+            return instances[0].cat(instances)
+
+        actual = fn(instances)
+        expected = instances[0].cat(instances)
+        self.assertEqual(type(actual), type(expected))
+        self.assertEqual(actual.__dict__, expected.__dict__)
+
     def test_super_in_staticmethod(self):
         class A:
             @staticmethod
@@ -10,6 +10,7 @@ from . import utils
 
 from .bytecode_transformation import (
     create_call_function,
+    create_call_method,
     create_dup_top,
     create_instruction,
     create_load_attr,
@@ -265,6 +266,12 @@ class PyCodegen:
         self.tx.output.update_co_names(name)
         return create_load_method(name)
 
+    def load_method(self, name):
+        self.append_output(self.create_load_method(name))
+
+    def call_method(self, nargs):
+        self.extend_output(create_call_method(nargs))
+
     def create_load_attr(self, name) -> Instruction:
         if name not in self.code_options["co_names"]:
             self.code_options["co_names"] += (name,)
@@ -322,6 +329,9 @@ class PyCodegen:
             create_instruction("POP_TOP"),
         ]
 
+    def pop_top(self):
+        self.append_output(create_instruction("POP_TOP"))
+
     def call_function(self, nargs: int, push_null: bool):
         self.extend_output(create_call_function(nargs, push_null=push_null))
@@ -179,9 +179,9 @@ class SideEffects:
 
     @staticmethod
     def cls_supports_mutation_side_effects(cls):
-        return inspect.getattr_static(cls, "__setattr__", None) in (
-            object.__setattr__,
-            torch.nn.Module.__setattr__,
+        return (
+            inspect.getattr_static(cls, "__getattribute__", None)
+            is object.__getattribute__
         )
 
     def is_attribute_mutation(self, item):
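
The gate for attribute-mutation tracking changes here: instead of allowlisting standard __setattr__ implementations, a class only needs a default __getattribute__, and custom __setattr__ is handled later by the fast/slow setattr paths added below. A standalone sketch of the updated predicate (illustrative, not part of the patch; the class names are made up):

import inspect


class CustomSetattr:
    def __setattr__(self, key, value):  # now trackable
        super().__setattr__(key, value)


class CustomGetattribute:
    def __getattribute__(self, key):  # still unsupported
        return super().__getattribute__(key)


def cls_supports_mutation_side_effects(cls):
    # Same check as the patched method above: only a default
    # __getattribute__ qualifies for mutation tracking.
    return (
        inspect.getattr_static(cls, "__getattribute__", None)
        is object.__getattribute__
    )


assert cls_supports_mutation_side_effects(CustomSetattr)
assert not cls_supports_mutation_side_effects(CustomGetattribute)
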
@@ -192,6 +192,11 @@ class SideEffects:
             self.store_attr_mutations.get(item.mutable_local)
         )
 
+    def has_pending_mutation_of_attr(self, item, name):
+        return self.is_attribute_mutation(
+            item
+        ) and name in self.store_attr_mutations.get(item.mutable_local, ())
+
     def is_modified(self, item):
        if isinstance(item.mutable_local, AttributeMutationNew):
            return True
@@ -494,6 +499,19 @@ class SideEffects:
                 suffixes.append(
                     [create_instruction("DELETE_ATTR", argval=name)]
                 )
+            elif (
+                isinstance(var, variables.UserDefinedObjectVariable)
+                and var.needs_slow_setattr()
+            ):
+                # __setattr__ is defined on this object, so call object.__setattr__ directly
+                cg.load_import_from("builtins", "object")
+                cg.load_method("__setattr__")
+                cg(var.mutable_local.source)  # type: ignore[attr-defined]
+                cg(variables.ConstantVariable(name))
+                cg(value)
+                suffixes.append(
+                    [*create_call_method(3), create_instruction("POP_TOP")]
+                )
             else:
                 cg.tx.output.update_co_names(name)
                 cg(value)
@@ -503,8 +521,8 @@ class SideEffects:
             for _ in range(var.index):
                 cg.load_import_from(utils.__name__, "iter_next")
                 cg(var.mutable_local.source)  # type: ignore[attr-defined]
-                cg.extend_output(create_call_function(1, True))
-                cg.append_output(create_instruction("POP_TOP"))
+                cg.call_function(1, True)
+                cg.pop_top()
         else:
             raise AssertionError(type(var))
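
At replay time the epilogue above must write back the recorded attribute values without re-running the user's __setattr__, which was already inlined during tracing. The emitted LOAD_METHOD/CALL_METHOD sequence is equivalent to a direct object.__setattr__ call; a standalone sketch (illustrative, not part of the patch):

class Logged:
    def __setattr__(self, key, value):
        # a side effect that must not run a second time during replay
        print("custom __setattr__ called")
        super().__setattr__(key, value)


obj = Logged()
# What the generated suffix bytecode does, expressed in Python:
object.__setattr__(obj, "y", 41)  # bypasses Logged.__setattr__
assert obj.__dict__["y"] == 41    # stored directly, nothing printed
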
@@ -2419,9 +2419,11 @@ class InliningInstructionTranslator(InstructionTranslatorBase):
 
         code: types.CodeType = func.get_code()
         if code.co_name in ("__setitem__", "__setattr__") and not (
-            args is not None
-            and len(args) > 0
-            and isinstance(args[0], variables.CustomizedDictVariable)
+            args
+            and isinstance(
+                args[0],
+                (variables.CustomizedDictVariable, variables.UserDefinedObjectVariable),
+            )
         ):
             unimplemented(f"inline {code.co_name}")
@@ -26,7 +26,7 @@ from .dicts import (
     DefaultDictVariable,
     SetVariable,
 )
-from .distributed import BackwardHookVariable, DistributedVariable
+from .distributed import BackwardHookVariable, DistributedVariable, PlacementVariable
 from .functions import (
     FunctoolsPartialVariable,
     NestedUserFunctionVariable,
@@ -131,6 +131,7 @@ __all__ = [
     "NumpyNdarrayVariable",
     "NumpyVariable",
     "OptimizerVariable",
+    "PlacementVariable",
     "PythonModuleVariable",
     "RangeVariable",
     "RemovableHandleVariable",
@@ -1550,14 +1550,13 @@ class BuiltinVariable(VariableTracker):
     def call_setattr(
         self, tx, obj: VariableTracker, name_var: VariableTracker, val: VariableTracker
     ):
-        from .distributed import PlacementVariable
-
         if isinstance(
             obj,
             (
                 variables.DataClassVariable,
                 variables.CustomizedDictVariable,
-                PlacementVariable,
+                variables.PlacementVariable,
+                variables.UserDefinedObjectVariable,
             ),
         ):
             return obj.call_method(tx, "__setattr__", [name_var, val], {})
@@ -1602,7 +1601,7 @@ class BuiltinVariable(VariableTracker):
 
         # Step 3 - drop the version counter - this is a step required to get
         # .data setting to play correctly with the autograd engine.
-        # Esentially, dynamo is trying to faithful preserve the (absurd)
+        # Essentially, dynamo is trying to faithfully preserve the (absurd)
         # behavior of .data= from eager mode
         def _lower_version_count_by_1(x):
             version = x._version
@@ -87,9 +87,7 @@ class BaseUserFunctionVariable(VariableTracker):
     def call_function(
         self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
     ) -> "VariableTracker":
-        return tx.inline_user_function_return(
-            self, list(self.self_args()) + list(args), kwargs
-        )
+        return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs)
 
     def call_hasattr(self, tx, name: str) -> VariableTracker:
         result = False
@@ -25,7 +25,7 @@ from ..utils import (
 )
 from .base import VariableTracker
 from .functions import NestedUserFunctionVariable, UserFunctionVariable
-from .user_defined import UserDefinedObjectVariable
+from .user_defined import is_standard_setattr, UserDefinedObjectVariable
 
 
 class SuperVariable(VariableTracker):
@@ -170,8 +170,12 @@ class SuperVariable(VariableTracker):
             return super(variables.CustomizedDictVariable, self.objvar).call_method(
                 tx, "__setitem__", args, kwargs
             )
-        else:
-            unimplemented(f"non-function or method super: {inner_fn}")
+        elif is_standard_setattr(inner_fn) and isinstance(
+            self.objvar, UserDefinedObjectVariable
+        ):
+            return self.objvar.method_setattr_standard(tx, *args, **kwargs)
+
+        unimplemented(f"non-function or method super: {inner_fn}")
 
 
 class UnknownVariable(VariableTracker):
@@ -503,6 +507,8 @@ class AutogradFunctionContextVariable(UserDefinedObjectVariable):
         args: "List[VariableTracker]",
         kwargs: "Dict[str, VariableTracker]",
     ) -> "VariableTracker":
+        if name == "__setattr__":
+            return super().call_method(tx, name, args, kwargs)
         if name != "save_for_backward":
             unimplemented(f"autograd.Function context method: {name}")
         if self.saved_tensors is None:
@@ -604,8 +610,12 @@ class GetAttrVariable(VariableTracker):
             # redirect to var_getattr on the original obj
             if isinstance(obj, variables.UserDefinedObjectVariable):
                 obj._check_for_getattribute()
-                if key in obj.value.__dict__:
+                if (
+                    key in obj.value.__dict__
+                    or tx.output.side_effects.has_pending_mutation_of_attr(obj, key)
+                ):
                     return obj.var_getattr(tx, key)
 
         return super().call_method(tx, name, args, kwargs)
@@ -679,10 +679,12 @@ class UnspecializedNNModuleVariable(UserDefinedObjectVariable):
     @staticmethod
     @functools.lru_cache(None)
     def _nn_module_method_ids():
+        # Allow __setattr__ to fall through to base class handler
+        supported = {torch.nn.Module.__setattr__}
         return {
             id(x.__code__)
             for x in torch.nn.Module.__dict__.values()
-            if hasattr(x, "__code__")
+            if hasattr(x, "__code__") and x not in supported
         }
 
     def unpack_var_sequence(self, tx):
@@ -52,6 +52,13 @@ from .ctx_manager import GenericContextWrappingVariable, NullContextVariable
 from .dicts import DefaultDictVariable
 
 
+def is_standard_setattr(val):
+    return val in (
+        object.__setattr__,
+        torch.nn.Module.__setattr__,
+    )
+
+
 class UserDefinedVariable(VariableTracker):
     pass
@@ -443,6 +450,10 @@ class UserDefinedClassVariable(UserDefinedVariable):
         return super().const_getattr(tx, name)
 
 
+class NO_SUCH_SUBOBJ:
+    pass
+
+
 class UserDefinedObjectVariable(UserDefinedVariable):
     """
     Mostly objects of defined type. Catch-all for something where we only know the type.
@@ -540,6 +551,9 @@ class UserDefinedObjectVariable(UserDefinedVariable):
         if method is object.__init__:
             return ConstantVariable.create(None)
 
+        if is_standard_setattr(method):
+            return self.method_setattr_standard(tx, *args, **kwargs)
+
         # [NOTE] OrderedDict, dict subtypes must always have source
         # We cannot instantiate such subtypes in-graph due to builtin __new__
         if method is collections.OrderedDict.keys:
@@ -603,6 +617,22 @@ class UserDefinedObjectVariable(UserDefinedVariable):
 
         return super().call_method(tx, name, args, kwargs)
 
+    def method_setattr_standard(self, tx, name, value):
+        try:
+            name = name.as_python_constant()
+        except NotImplementedError:
+            unimplemented(f"non-const setattr name: {name}")
+        if not tx.output.side_effects.is_attribute_mutation(self):
+            unimplemented(f"setattr({self}, {name}, ...)")
+
+        tx.output.side_effects.store_attr(self, name, value)
+        return variables.ConstantVariable(None)
+
+    def needs_slow_setattr(self):
+        return not is_standard_setattr(
+            inspect.getattr_static(self.value, "__setattr__", None)
+        )
+
     def unpack_var_sequence(self, tx):
         if (
             self.source
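
method_setattr_standard buffers the store in SideEffects, and needs_slow_setattr picks the reconstruction strategy: a plain attribute store when the class uses a standard __setattr__, or the explicit object.__setattr__ call shown in the side_effects hunk above. A standalone sketch of the predicate (illustrative, not part of the patch; the class names are made up):

import inspect

import torch


def is_standard_setattr(val):
    return val in (object.__setattr__, torch.nn.Module.__setattr__)


class Plain:
    pass


class Custom:
    def __setattr__(self, key, value):
        super().__setattr__(key, value)


def needs_slow_setattr(obj):
    # Mirrors the new helper: a type that overrides __setattr__ takes
    # the slow path (object.__setattr__ at replay time).
    return not is_standard_setattr(
        inspect.getattr_static(obj, "__setattr__", None)
    )


assert not needs_slow_setattr(Plain())
assert needs_slow_setattr(Custom())
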
@@ -745,8 +775,8 @@ class UserDefinedObjectVariable(UserDefinedVariable):
         self._check_for_getattribute()
         getattr_fn = self._check_for_getattr()
 
-        class NO_SUCH_SUBOBJ:
-            pass
+        if tx.output.side_effects.has_pending_mutation_of_attr(self, name):
+            return tx.output.side_effects.load_attr(self, name)
 
         try:
             subobj = self._getattr_static(name)