[dynamo] Support custom __setattr__ on UserDefinedObjectVariable (#123318)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/123318
Approved by: https://github.com/anijain2305
This commit is contained in:
Jason Ansel
2024-04-07 11:07:37 -07:00
committed by PyTorch MergeBot
parent 89724843bb
commit 212e460dce
19 changed files with 240 additions and 41 deletions

View File

@ -86,7 +86,7 @@ detectron2_fasterrcnn_r_50_fpn,eager_fail_to_run,0
detectron2_fcos_r_50_fpn,pass,35
detectron2_fcos_r_50_fpn,pass,30

1 name accuracy graph_breaks
86 timm_efficientnet pass 0
87 timm_regnet pass 0
88 timm_resnest pass 0
89 timm_vision_transformer pass 0
90 timm_vision_transformer_large pass_due_to_skip 0
91 timm_vovnet pass 0
92 torch_multimodal_clip pass 0

View File

@ -54,47 +54,47 @@ densenet121,pass,0
detectron2_fasterrcnn_r_101_c4,pass,51
detectron2_fasterrcnn_r_101_c4,pass,54
detectron2_fasterrcnn_r_101_dc5,pass,51
detectron2_fasterrcnn_r_101_dc5,pass,54
detectron2_fasterrcnn_r_101_fpn,pass,55
detectron2_fasterrcnn_r_101_fpn,pass,58
detectron2_fasterrcnn_r_50_c4,pass,51
detectron2_fasterrcnn_r_50_c4,pass,54
detectron2_fasterrcnn_r_50_dc5,pass,51
detectron2_fasterrcnn_r_50_dc5,pass,54
detectron2_fasterrcnn_r_50_fpn,pass,55
detectron2_fasterrcnn_r_50_fpn,pass,58
detectron2_fcos_r_50_fpn,pass,38
detectron2_fcos_r_50_fpn,pass,33
detectron2_maskrcnn_r_101_c4,fail_accuracy,66
detectron2_maskrcnn_r_101_c4,fail_accuracy,75
detectron2_maskrcnn_r_101_fpn,pass,73
detectron2_maskrcnn_r_101_fpn,pass,82
detectron2_maskrcnn_r_50_c4,pass,66
detectron2_maskrcnn_r_50_c4,pass,75
detectron2_maskrcnn_r_50_fpn,pass,73
detectron2_maskrcnn_r_50_fpn,pass,82

1 name accuracy graph_breaks
54 nvidia_deeprecommender pass 0
55 opacus_cifar10 pass 0
56 phlippe_densenet pass 0
57 phlippe_resnet pass 0
58 pyhpc_equation_of_state pass 0
59 pyhpc_isoneutral_mixing pass 0
60 pyhpc_turbulent_kinetic_energy pass 0
61 pytorch_CycleGAN_and_pix2pix pass 0
62 pytorch_stargan pass 0
63 pytorch_unet pass 0
64 resnet152 pass 0
65 resnet18 pass 0
66 resnet50 pass 0
67 resnet50_quantized_qat pass 2
68 resnext50_32x4d pass 0
69 shufflenet_v2_x1_0 pass 0
70 soft_actor_critic pass 0
71 speech_transformer pass 10
72 squeezenet1_1 pass 0
73 stable_diffusion_unet pass_due_to_skip 0
74 timm_efficientdet model_fail_to_load 0
75 timm_efficientnet pass 0
76 timm_nfnet pass 0
77 timm_regnet pass 0
78 timm_resnest pass 0
79 timm_vision_transformer pass 0
80 timm_vision_transformer_large pass_due_to_skip 0
81 timm_vovnet pass 0
82 torch_multimodal_clip pass 0
83 tts_angular pass 2
84 vgg16 pass 0
85 vision_maskrcnn pass 28
86 yolov3 pass 2
87
88
89
90
91
92
93
94
95
96
97
98
99
100

View File

@ -86,7 +86,7 @@ detectron2_fasterrcnn_r_50_fpn,eager_fail_to_run,0
detectron2_fcos_r_50_fpn,pass,35
detectron2_fcos_r_50_fpn,pass,30

1 name accuracy graph_breaks
86 timm_regnet pass 0
87 timm_resnest pass 0
88 timm_vision_transformer pass 0
89 timm_vision_transformer_large pass_due_to_skip 0
90 timm_vovnet pass 0
91 torch_multimodal_clip pass 0
92 tts_angular pass 2

View File

@ -54,7 +54,7 @@ densenet121,pass,0
detectron2_fcos_r_50_fpn,pass,38
detectron2_fcos_r_50_fpn,pass,33

1 name accuracy graph_breaks
54 resnet152 pass 0
55 resnet18 pass 0
56 resnet50 pass 0
57 resnet50_quantized_qat pass 2
58 resnext50_32x4d pass 0
59 shufflenet_v2_x1_0 pass 0
60 soft_actor_critic pass 0

View File

@ -86,7 +86,7 @@ detectron2_fasterrcnn_r_50_fpn,eager_fail_to_run,0
detectron2_fcos_r_50_fpn,pass,36
detectron2_fcos_r_50_fpn,pass,31

1 name accuracy graph_breaks
86 timm_regnet pass 0
87 timm_resnest pass 0
88 timm_vision_transformer pass 0
89 timm_vision_transformer_large pass_due_to_skip 0
90 timm_vovnet pass 0
91 torch_multimodal_clip pass 0
92 tts_angular pass 2

View File

@ -86,7 +86,7 @@ detectron2_fasterrcnn_r_50_fpn,eager_fail_to_run,0
detectron2_fcos_r_50_fpn,pass,35
detectron2_fcos_r_50_fpn,pass,30

1 name accuracy graph_breaks
86 timm_efficientnet pass 0
87 timm_regnet pass 0
88 timm_resnest pass 0
89 timm_vision_transformer pass 0
90 timm_vision_transformer_large pass_due_to_skip 0
91 timm_vovnet pass 0
92 torch_multimodal_clip pass 0

View File

@ -86,7 +86,7 @@ detectron2_fasterrcnn_r_50_fpn,eager_fail_to_run,0
detectron2_fcos_r_50_fpn,pass,36
detectron2_fcos_r_50_fpn,pass,31

1 name accuracy graph_breaks
86 timm_efficientnet pass 0
87 timm_regnet pass 0
88 timm_resnest pass 0
89 timm_vision_transformer pass 0
90 timm_vision_transformer_large pass_due_to_skip 0
91 timm_vovnet pass 0
92 torch_multimodal_clip pass 0

View File

@ -143,6 +143,18 @@ def closure_adder(val):
return inner
class UserDefineSetAttr:
setup = False
def __setattr__(self, key, value):
assert torch.compiler.is_dynamo_compiling() or UserDefineSetAttr.setup
super().__setattr__(f"pfx_{key}", value)
def __getattr__(self, key):
assert torch.compiler.is_dynamo_compiling() or UserDefineSetAttr.setup
return self.__dict__[f"pfx_{key}"]
class MiscTests(torch._dynamo.test_case.TestCase):
def test_get_cache_entry(self):
def f(x):
@ -488,6 +500,34 @@ class MiscTests(torch._dynamo.test_case.TestCase):
cleanup_op("mylib::foo")
del lib
def test_user_defined_setattr1(self):
@torch.compile(backend="eager", fullgraph=True)
def fn(obj):
obj.y = obj.x + 1
obj = UserDefineSetAttr()
with patch.object(UserDefineSetAttr, "setup", True):
obj.x = torch.randn(8)
fn(obj)
with patch.object(UserDefineSetAttr, "setup", True):
self.assertEqual(obj.y, obj.x + 1)
self.assertEqual(obj.__dict__.keys(), {"pfx_x", "pfx_y"})
def test_user_defined_setattr2(self):
@torch.compile(backend="eager", fullgraph=True)
def fn(x):
obj = UserDefineSetAttr()
obj.x = x
obj.y = obj.x + 1
return obj
x = torch.randn(8)
obj = fn(x)
with patch.object(UserDefineSetAttr, "setup", True):
self.assertIs(obj.x, x)
self.assertEqual(obj.y, x + 1)
self.assertEqual(obj.__dict__.keys(), {"pfx_x", "pfx_y"})
def test_closure_recompiles(self):
cnt = CompileCounter()

View File

@ -17,7 +17,7 @@ from collections import namedtuple
from copy import deepcopy
from enum import Enum
from functools import wraps
from typing import Any, Iterator, List
from typing import Any, Dict, Iterator, List, Tuple
from unittest import mock
import numpy as np
@ -4453,6 +4453,95 @@ def forward(self, s0 : torch.SymInt, s1 : torch.SymInt, L_x_ : torch.Tensor):
with self.assertRaisesRegex(AssertionError, ""):
f_fail(torch.ones(6, 4))
def test_detectron2_instances_cat(self):
class Instances:
def __init__(self, image_size: Tuple[int, int], **kwargs: Any):
self._image_size = image_size
self._fields: Dict[str, Any] = {}
for k, v in kwargs.items():
self.set(k, v)
@property
def image_size(self) -> Tuple[int, int]:
return self._image_size
def __setattr__(self, name: str, val: Any) -> None:
if name.startswith("_"):
super().__setattr__(name, val)
else:
self.set(name, val)
def __getattr__(self, name: str) -> Any:
if name == "_fields" or name not in self._fields:
raise AttributeError(
f"Cannot find field '{name}' in the given Instances!"
)
return self._fields[name]
def __len__(self) -> int:
for v in self._fields.values():
# use __len__ because len() has to be int and is not friendly to tracing
return v.__len__()
raise NotImplementedError("Empty Instances does not support __len__!")
def set(self, name: str, value: Any) -> None:
# TODO(jansel): support catch_warnings
# with warnings.catch_warnings(record=True):
data_len = len(value)
if len(self._fields):
assert (
len(self) == data_len
), f"Adding a field of length {data_len} to a Instances of length {len(self)}"
self._fields[name] = value
def get(self, name: str) -> Any:
return self._fields[name]
@staticmethod
def cat(instance_lists: List["Instances"]) -> "Instances":
# TODO(jansel): support all isinstance generator
# assert all(isinstance(i, Instances) for i in instance_lists)
assert len(instance_lists) > 0
if len(instance_lists) == 1:
return instance_lists[0]
image_size = instance_lists[0].image_size
if not isinstance(
image_size, torch.Tensor
): # could be a tensor in tracing
for i in instance_lists[1:]:
assert i.image_size == image_size
ret = Instances(image_size)
for k in instance_lists[0]._fields.keys():
values = [i.get(k) for i in instance_lists]
v0 = values[0]
if isinstance(v0, torch.Tensor):
values = torch.cat(values, dim=0)
elif isinstance(v0, list):
values = list(itertools.chain(*values))
elif hasattr(type(v0), "cat"):
values = type(v0).cat(values)
else:
raise ValueError(
f"Unsupported type {type(v0)} for concatenation"
)
ret.set(k, values)
return ret
instances = [
Instances((16, 16), a=[torch.randn(16, 16)], b=[torch.randn(16, 16)])
for _ in range(3)
]
@torch.compile(backend="eager", fullgraph=True)
def fn(instances):
return instances[0].cat(instances)
actual = fn(instances)
expected = instances[0].cat(instances)
self.assertEqual(type(actual), type(expected))
self.assertEqual(actual.__dict__, expected.__dict__)
def test_super_in_staticmethod(self):
class A:
@staticmethod

View File

@ -10,6 +10,7 @@ from . import utils
from .bytecode_transformation import (
create_call_function,
create_call_method,
create_dup_top,
create_instruction,
create_load_attr,
@ -265,6 +266,12 @@ class PyCodegen:
self.tx.output.update_co_names(name)
return create_load_method(name)
def load_method(self, name):
self.append_output(self.create_load_method(name))
def call_method(self, nargs):
self.extend_output(create_call_method(nargs))
def create_load_attr(self, name) -> Instruction:
if name not in self.code_options["co_names"]:
self.code_options["co_names"] += (name,)
@ -322,6 +329,9 @@ class PyCodegen:
create_instruction("POP_TOP"),
]
def pop_top(self):
self.append_output(create_instruction("POP_TOP"))
def call_function(self, nargs: int, push_null: bool):
self.extend_output(create_call_function(nargs, push_null=push_null))

View File

@ -179,9 +179,9 @@ class SideEffects:
@staticmethod
def cls_supports_mutation_side_effects(cls):
return inspect.getattr_static(cls, "__setattr__", None) in (
object.__setattr__,
torch.nn.Module.__setattr__,
return (
inspect.getattr_static(cls, "__getattribute__", None)
is object.__getattribute__
)
def is_attribute_mutation(self, item):
@ -192,6 +192,11 @@ class SideEffects:
self.store_attr_mutations.get(item.mutable_local)
)
def has_pending_mutation_of_attr(self, item, name):
return self.is_attribute_mutation(
item
) and name in self.store_attr_mutations.get(item.mutable_local, ())
def is_modified(self, item):
if isinstance(item.mutable_local, AttributeMutationNew):
return True
@ -494,6 +499,19 @@ class SideEffects:
suffixes.append(
[create_instruction("DELETE_ATTR", argval=name)]
)
elif (
isinstance(var, variables.UserDefinedObjectVariable)
and var.needs_slow_setattr()
):
# __setattr__ is defined on this object, so call object.__setattr__ directly
cg.load_import_from("builtins", "object")
cg.load_method("__setattr__")
cg(var.mutable_local.source) # type: ignore[attr-defined]
cg(variables.ConstantVariable(name))
cg(value)
suffixes.append(
[*create_call_method(3), create_instruction("POP_TOP")]
)
else:
cg.tx.output.update_co_names(name)
cg(value)
@ -503,8 +521,8 @@ class SideEffects:
for _ in range(var.index):
cg.load_import_from(utils.__name__, "iter_next")
cg(var.mutable_local.source) # type: ignore[attr-defined]
cg.extend_output(create_call_function(1, True))
cg.append_output(create_instruction("POP_TOP"))
cg.call_function(1, True)
cg.pop_top()
else:
raise AssertionError(type(var))

View File

@ -2419,9 +2419,11 @@ class InliningInstructionTranslator(InstructionTranslatorBase):
code: types.CodeType = func.get_code()
if code.co_name in ("__setitem__", "__setattr__") and not (
args is not None
and len(args) > 0
and isinstance(args[0], variables.CustomizedDictVariable)
args
and isinstance(
args[0],
(variables.CustomizedDictVariable, variables.UserDefinedObjectVariable),
)
):
unimplemented(f"inline {code.co_name}")

View File

@ -26,7 +26,7 @@ from .dicts import (
DefaultDictVariable,
SetVariable,
)
from .distributed import BackwardHookVariable, DistributedVariable
from .distributed import BackwardHookVariable, DistributedVariable, PlacementVariable
from .functions import (
FunctoolsPartialVariable,
NestedUserFunctionVariable,
@ -131,6 +131,7 @@ __all__ = [
"NumpyNdarrayVariable",
"NumpyVariable",
"OptimizerVariable",
"PlacementVariable",
"PythonModuleVariable",
"RangeVariable",
"RemovableHandleVariable",

View File

@ -1550,14 +1550,13 @@ class BuiltinVariable(VariableTracker):
def call_setattr(
self, tx, obj: VariableTracker, name_var: VariableTracker, val: VariableTracker
):
from .distributed import PlacementVariable
if isinstance(
obj,
(
variables.DataClassVariable,
variables.CustomizedDictVariable,
PlacementVariable,
variables.PlacementVariable,
variables.UserDefinedObjectVariable,
),
):
return obj.call_method(tx, "__setattr__", [name_var, val], {})
@ -1602,7 +1601,7 @@ class BuiltinVariable(VariableTracker):
# Step 3 - drop the version counter - this is a step required to get
# .data setting to play correctly with the autograd engine.
# Esentially, dynamo is trying to faithful preserve the (absurd)
# Essentially, dynamo is trying to faithfully preserve the (absurd)
# behavior of .data= from eager mode
def _lower_version_count_by_1(x):
version = x._version

View File

@ -87,9 +87,7 @@ class BaseUserFunctionVariable(VariableTracker):
def call_function(
self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
) -> "VariableTracker":
return tx.inline_user_function_return(
self, list(self.self_args()) + list(args), kwargs
)
return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs)
def call_hasattr(self, tx, name: str) -> VariableTracker:
result = False

View File

@ -25,7 +25,7 @@ from ..utils import (
)
from .base import VariableTracker
from .functions import NestedUserFunctionVariable, UserFunctionVariable
from .user_defined import UserDefinedObjectVariable
from .user_defined import is_standard_setattr, UserDefinedObjectVariable
class SuperVariable(VariableTracker):
@ -170,8 +170,12 @@ class SuperVariable(VariableTracker):
return super(variables.CustomizedDictVariable, self.objvar).call_method(
tx, "__setitem__", args, kwargs
)
else:
unimplemented(f"non-function or method super: {inner_fn}")
elif is_standard_setattr(inner_fn) and isinstance(
self.objvar, UserDefinedObjectVariable
):
return self.objvar.method_setattr_standard(tx, *args, **kwargs)
unimplemented(f"non-function or method super: {inner_fn}")
class UnknownVariable(VariableTracker):
@ -503,6 +507,8 @@ class AutogradFunctionContextVariable(UserDefinedObjectVariable):
args: "List[VariableTracker]",
kwargs: "Dict[str, VariableTracker]",
) -> "VariableTracker":
if name == "__setattr__":
return super().call_method(tx, name, args, kwargs)
if name != "save_for_backward":
unimplemented(f"autograd.Function context method: {name}")
if self.saved_tensors is None:
@ -604,8 +610,12 @@ class GetAttrVariable(VariableTracker):
# redirect to var_getattr on the original obj
if isinstance(obj, variables.UserDefinedObjectVariable):
obj._check_for_getattribute()
if key in obj.value.__dict__:
if (
key in obj.value.__dict__
or tx.output.side_effects.has_pending_mutation_of_attr(obj, key)
):
return obj.var_getattr(tx, key)
return super().call_method(tx, name, args, kwargs)

View File

@ -679,10 +679,12 @@ class UnspecializedNNModuleVariable(UserDefinedObjectVariable):
@staticmethod
@functools.lru_cache(None)
def _nn_module_method_ids():
# Allow __setattr__ to fall through to base class handler
supported = {torch.nn.Module.__setattr__}
return {
id(x.__code__)
for x in torch.nn.Module.__dict__.values()
if hasattr(x, "__code__")
if hasattr(x, "__code__") and x not in supported
}
def unpack_var_sequence(self, tx):

View File

@ -52,6 +52,13 @@ from .ctx_manager import GenericContextWrappingVariable, NullContextVariable
from .dicts import DefaultDictVariable
def is_standard_setattr(val):
return val in (
object.__setattr__,
torch.nn.Module.__setattr__,
)
class UserDefinedVariable(VariableTracker):
pass
@ -443,6 +450,10 @@ class UserDefinedClassVariable(UserDefinedVariable):
return super().const_getattr(tx, name)
class NO_SUCH_SUBOBJ:
pass
class UserDefinedObjectVariable(UserDefinedVariable):
"""
Mostly objects of defined type. Catch-all for something where we only know the type.
@ -540,6 +551,9 @@ class UserDefinedObjectVariable(UserDefinedVariable):
if method is object.__init__:
return ConstantVariable.create(None)
if is_standard_setattr(method):
return self.method_setattr_standard(tx, *args, **kwargs)
# [NOTE] OrderedDict, dict subtypes must always have source
# We cannot instantiate such subtypes in-graph due to builtin __new__
if method is collections.OrderedDict.keys:
@ -603,6 +617,22 @@ class UserDefinedObjectVariable(UserDefinedVariable):
return super().call_method(tx, name, args, kwargs)
def method_setattr_standard(self, tx, name, value):
try:
name = name.as_python_constant()
except NotImplementedError:
unimplemented(f"non-const setattr name: {name}")
if not tx.output.side_effects.is_attribute_mutation(self):
unimplemented(f"setattr({self}, {name}, ...)")
tx.output.side_effects.store_attr(self, name, value)
return variables.ConstantVariable(None)
def needs_slow_setattr(self):
return not is_standard_setattr(
inspect.getattr_static(self.value, "__setattr__", None)
)
def unpack_var_sequence(self, tx):
if (
self.source
@ -745,8 +775,8 @@ class UserDefinedObjectVariable(UserDefinedVariable):
self._check_for_getattribute()
getattr_fn = self._check_for_getattr()
class NO_SUCH_SUBOBJ:
pass
if tx.output.side_effects.has_pending_mutation_of_attr(self, name):
return tx.output.side_effects.load_attr(self, name)
try:
subobj = self._getattr_static(name)