Compare commits

...

1 Commits

Author SHA1 Message Date
7fd3a2cf43 [Dynamo] Support for proxying frozen dataclasses
ghstack-source-id: fb6556cd2f9424fe223147471fe95126441954d9
Pull Request resolved: https://github.com/pytorch/pytorch/pull/134846
2024-09-01 13:30:12 -07:00
7 changed files with 225 additions and 12 deletions

View File

@ -44,6 +44,7 @@ from torch._dynamo.testing import (
CompileCounter,
CompileCounterWithBackend,
expectedFailureDynamic,
requiresPy310,
same,
skipIfNotPy311,
unsupported,
@ -10142,6 +10143,113 @@ def ___make_guard_fn():
}
self.assertEqual(expected_fqn, gm.meta["dynamo_flat_name_to_original_fqn"])
def test_proxy_frozen_dataclass(self):
@dataclasses.dataclass(frozen=True)
class TestDataClass:
x: torch.Tensor
y: torch.Tensor
@allow_in_graph
def inner_fn(dc):
return dc.x + dc.y
def fn(x, y):
dc = TestDataClass(x, y)
return inner_fn(dc)
fn_opt = torch.compile(fullgraph=True)(fn)
inps = (torch.ones(2, 2), torch.ones(2, 2))
actual = fn_opt(*inps)
expected = fn(*inps)
self.assertEqual(actual, expected)
def test_reconstruct_frozen_dataclass(self):
@dataclasses.dataclass(frozen=True)
class TestDataClass:
x: torch.Tensor
y: torch.Tensor
def fn(x, y):
dc = TestDataClass(x, y)
torch._dynamo.graph_break()
return dc.x + dc.y
fn_opt = torch.compile()(fn)
inps = (torch.ones(2, 2), torch.ones(2, 2))
actual = fn_opt(*inps)
expected = fn(*inps)
def test_frozen_dataclass_default_value(self):
@dataclasses.dataclass(frozen=True)
class TestDataClass:
x: torch.Tensor
y: torch.Tensor
z: int = dataclasses.field(default=5)
a: int = 6
@allow_in_graph
def inner_fn(dc):
return dc.x + dc.y + dc.z + dc.a
def fn(x, y):
dc = TestDataClass(x, y)
return inner_fn(dc)
fn_opt = torch.compile(fullgraph=True)(fn)
inps = (torch.ones(2, 2), torch.ones(2, 2))
actual = fn_opt(*inps)
expected = fn(*inps)
self.assertEqual(actual, expected)
def test_frozen_dataclass_default_factory(self):
@dataclasses.dataclass(frozen=True)
class TestDataClass:
x: torch.Tensor
y: torch.Tensor
z: int = dataclasses.field(default_factory=list)
a: int = dataclasses.field(default_factory=lambda: [5])
@allow_in_graph
def inner_fn(dc):
return dc.x + dc.y + dc.a[0]
def fn(x, y):
dc = TestDataClass(x, y)
return inner_fn(dc)
fn_opt = torch.compile(fullgraph=True)(fn)
inps = (torch.ones(2, 2), torch.ones(2, 2))
actual = fn_opt(*inps)
expected = fn(*inps)
self.assertEqual(actual, expected)
@requiresPy310
def test_frozen_dataclass_kw_only(self):
@dataclasses.dataclass(frozen=True)
class TestDataClass:
x: torch.Tensor
y: torch.Tensor
z: int = dataclasses.field(kw_only=True)
a: int = dataclasses.field(kw_only=True)
@allow_in_graph
def inner_fn(dc):
return dc.x + dc.y + dc.a + dc.z
def fn(x, y):
dc = TestDataClass(x, y, z=5, a=2)
return inner_fn(dc)
fn_opt = torch.compile(fullgraph=True)(fn)
inps = (torch.ones(2, 2), torch.ones(2, 2))
actual = fn_opt(*inps)
expected = fn(*inps)
self.assertEqual(actual, expected)
def test_shape_env_no_recording(self):
main = ShapeEnv(should_record_events=False)

View File

@ -17,13 +17,14 @@ from .bytecode_transformation import (
from .codegen import PyCodegen
from .exc import unimplemented
from .source import GlobalSource, LocalSource, Source
from .utils import nn_module_new, object_new
from .utils import is_frozen_dataclass, nn_module_new, object_new
from .variables.base import (
is_side_effect_safe,
MutableLocalBase,
MutableLocalSource,
VariableTracker,
)
from .variables.user_defined import FrozenDataClassVariable
class MutableSideEffects(MutableLocalBase):
@ -285,6 +286,8 @@ class SideEffects:
variable_cls = variables.UnspecializedNNModuleVariable
elif issubclass(user_cls, MutableMapping):
variable_cls = variables.MutableMappingVariable
elif is_frozen_dataclass(user_cls):
variable_cls = FrozenDataClassVariable
else:
variable_cls = variables.UserDefinedObjectVariable

View File

@ -372,6 +372,13 @@ def skipIfPy312(fn):
return fn
def requiresPy310(fn):
if sys.version_info >= (3, 10):
return fn
else:
unittest.skip(fn)
# Controls tests generated in test/inductor/test_torchinductor_dynamic_shapes.py
# and test/dynamo/test_dynamic_shapes.py
def expectedFailureDynamic(fn):

View File

@ -30,6 +30,7 @@ import uuid
import warnings
import weakref
from contextlib import contextmanager
from dataclasses import is_dataclass
from functools import lru_cache
from types import MethodWrapperType
from typing import (
@ -2313,9 +2314,13 @@ def import_submodule(mod: types.ModuleType):
def object_has_getattribute(value: Any):
return class_has_getattribute(type(value))
def class_has_getattribute(cls: type):
try:
if isinstance(
inspect.getattr_static(type(value), "__getattribute__"),
inspect.getattr_static(cls, "__getattribute__"),
types.FunctionType,
):
return True
@ -2961,6 +2966,16 @@ def to_fake_tensor(t, fake_mode):
)
# NB: this works for both classes and instances
def is_frozen_dataclass(value):
return (
not object_has_getattribute(value)
and not class_has_getattribute(value)
and is_dataclass(value)
and value.__dataclass_params__.frozen
)
def get_first_attr(obj, *attrs):
"""
Return the first available attribute or throw an exception if none is present.

View File

@ -82,6 +82,7 @@ from ..utils import (
get_fake_value,
get_locals_to_steal,
get_static_address_type,
is_frozen_dataclass,
is_function_or_wrapper,
is_lru_cache_wrapped_function,
is_namedtuple,
@ -193,6 +194,7 @@ from .torch_function import (
TorchFunctionModeVariable,
)
from .user_defined import (
FrozenDataClassVariable,
KeyedJaggedTensorVariable,
MutableMappingVariable,
SourcelessGraphModuleVariable,
@ -1132,6 +1134,10 @@ class VariableBuilder:
elif issubclass(type(value), MutableMapping):
self.install_guards(GuardBuilder.TYPE_MATCH)
return MutableMappingVariable(value, source=self.source)
elif is_frozen_dataclass(value):
self.install_guards(GuardBuilder.TYPE_MATCH)
result = FrozenDataClassVariable.create(self.tx, value, source=self.source)
return self.tx.output.side_effects.track_object_existing(value, result)
else:
return self.wrap_user_defined(value)

View File

@ -671,16 +671,6 @@ def _call_hasattr_customobj(
)
class DataClassVariable(ConstDictVariable):
"""
This class doesn't appear to be used anywhere.
It used to be used to deal with transformers.file_utils.ModelOutput
from huggingface.
Keeping since we wish to support dataclasses in general in the future
"""
class CustomizedDictVariable(ConstDictVariable):
@staticmethod
def is_matching_cls_hf(cls):

View File

@ -2,6 +2,7 @@
import collections
import contextlib
import dataclasses
import enum
import functools
import inspect
@ -39,6 +40,7 @@ from ..utils import (
check_constant_args,
get_custom_getattr,
has_torch_function,
is_frozen_dataclass,
is_namedtuple_cls,
is_utils_checkpoint,
is_wrapper_or_member_descriptor,
@ -452,6 +454,40 @@ class UserDefinedClassVariable(UserDefinedVariable):
assert all(x is not None for x in items)
return variables.NamedTupleVariable(items, self.value)
elif is_frozen_dataclass(self.value) and self.is_standard_new():
from .builder import SourcelessBuilder
fields = dataclasses.fields(self.value)
items = list(args)
items.extend([None] * (len(fields) - len(items)))
default_kwargs = {}
for field, var_tracker in zip(fields, items):
if var_tracker is None:
if field.name in kwargs:
var_tracker = kwargs[field.name]
else:
if not field.init:
continue
if field.default is not dataclasses.MISSING:
var_tracker = SourcelessBuilder.create(tx, field.default)
elif field.default_factory is not dataclasses.MISSING:
factory_fn = SourcelessBuilder.create(
tx, field.default_factory
)
var_tracker = factory_fn.call_function(tx, [], {})
else:
# if we are subclass, the constructor could possibly
# be missing args
continue
default_kwargs[field.name] = var_tracker
kwargs.update(default_kwargs)
var = tx.output.side_effects.track_object_new_from_user_defined_class(self)
var.call_method(tx, "__init__", args, kwargs)
return var
elif (
self.is_standard_new()
and SideEffects.cls_supports_mutation_side_effects(self.value)
@ -1175,6 +1211,54 @@ class UserDefinedObjectVariable(UserDefinedVariable):
)(collections.OrderedDict.__getitem__(self.value, key.as_python_constant()))
class FrozenDataClassVariable(UserDefinedObjectVariable):
@staticmethod
def create(tx, value, source):
from dataclasses import fields
assert is_frozen_dataclass(value)
from .builder import VariableBuilder
field_map = {}
for field in fields(value):
if hasattr(value, field.name):
field_map[field.name] = VariableBuilder(
tx, AttrSource(source, field.name)
)(getattr(value, field.name))
return FrozenDataClassVariable(value, fields=field_map, source=source)
def __init__(self, value, fields=None, **kwargs) -> None:
super().__init__(value, **kwargs)
if fields is None:
fields = {}
self.fields = fields
def as_proxy(self):
from dataclasses import fields
args = []
kwargs = {}
for field in fields(self.value):
proxy = self.fields[field.name].as_proxy()
if hasattr(field, "kw_only") and field.kw_only:
kwargs[field.name] = proxy
else:
args.append(proxy)
return self.python_type()(*args, **kwargs)
# NB: This is called during __init__ for a frozen dataclass
# use this to accumulate the most up-to-date field values
def method_setattr_standard(self, tx: "InstructionTranslator", name, value):
self.fields[name.as_python_constant()] = value
return super().method_setattr_standard(tx, name, value)
def __repr__(self) -> str:
return f"{self.__class__.__name__}({self.value_type.__name__})"
class SourcelessGraphModuleVariable(UserDefinedObjectVariable):
def __init__(
self,