Compare commits

...

1 Commit

SHA1: 022649a10e
Message: [Dynamo] Automatically in-graph traceable tensor subclass ctors
ghstack-source-id: 54e4f586180aea0d38d0e6e32e553ace3e9a3469
Pull Request resolved: https://github.com/pytorch/pytorch/pull/135151
Date: 2024-09-05 16:42:55 -07:00
4 changed files with 109 additions and 2 deletions
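What the change enables, as a minimal sketch (mine, not part of the diff): after this PR, Dynamo proxies the constructor of a traceable tensor wrapper subclass into the FX graph instead of graph-breaking on it. TwoTensor below is the traceable subclass shipped with the PyTorch test suite and already imported by the test file in this diff; the exact pre-change graph-break behavior is an assumption based on the commit message.

import torch
from torch.testing._internal.two_tensor import TwoTensor  # traceable wrapper subclass used in the tests

@torch.compile(fullgraph=True)  # fullgraph compilation would previously fail on the subclass ctor call
def f(x):
    t = TwoTensor(x, x)  # constructor is now traced in-graph
    return t * t

out = f(torch.randn(3, 3))
print(type(out))  # expected: TwoTensor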

View File

@@ -30,6 +30,7 @@ from torch.testing._internal.common_utils import (
)
from torch.testing._internal.inductor_utils import HAS_CUDA
from torch.testing._internal.two_tensor import TwoTensor
from torch.utils._python_dispatch import return_and_correct_aliasing


def traceable_subclass(c):
@@ -1427,6 +1428,99 @@ s1 > 3""",
            lambda: torch.compile(lambda x: x * x)(x),
        )

    def test_subclass_constructor_proxying(self):
        import dataclasses
        from collections import namedtuple
        from typing import Any

        @dataclasses.dataclass(frozen=True)
        class SubclassTensorArgs:
            original_shape: torch.Size
            device: torch.device
            inner_meta: Any

        SubclassTensorArgs2 = namedtuple(
            "SubclassTensorArgs2",
            [
                "original_shape",
                "device",
                "inner_meta",
            ],
        )

        class SubclassTensor(torch.Tensor):
            @staticmethod
            def __new__(cls, a, meta):
                shape = a.shape
                kwargs = {}
                kwargs["strides"] = a.stride()
                kwargs["storage_offset"] = a.storage_offset()
                kwargs["device"] = a.device
                kwargs["layout"] = a.layout
                kwargs["requires_grad"] = a.requires_grad
                kwargs["dtype"] = a.dtype
                out = torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs)
                return out

            def __init__(self, a, meta):
                self.a = a
                self.meta = meta

            def __repr__(self):
                a_repr = repr(self.a)
                return f"SubclassTensor({a_repr})"

            def __tensor_flatten__(self):
                return ["a"], self.meta

            @staticmethod
            def __tensor_unflatten__(inner_tensors, meta, _, __):
                a = inner_tensors["a"]
                return SubclassTensor(a, meta)

            @classmethod
            def __torch_dispatch__(cls, func, types, args, kwargs):
                if kwargs is None:
                    kwargs = {}
                args_a = pytree.tree_map(
                    lambda x: x.a if isinstance(x, SubclassTensor) else x, args
                )
                kwargs_a = pytree.tree_map(
                    lambda x: x.a if isinstance(x, SubclassTensor) else x, kwargs
                )
                out_a = func(*args_a, **kwargs_a)
                out = pytree.tree_map(
                    lambda x: SubclassTensor(
                        x, SubclassTensorArgs2(x.shape, x.device, None)
                    )
                    if isinstance(x, torch.Tensor)
                    else x,
                    out_a,
                )
                return return_and_correct_aliasing(func, args, kwargs, out)

        @torch.compile(fullgraph=True)
        def f1(x):
            meta = SubclassTensorArgs(
                x.shape, x.device, SubclassTensorArgs(x.shape, x.device, None)
            )
            out = SubclassTensor(x, meta)
            return out * out

        x = torch.randn(3, 3)
        f1(x)

        @torch.compile(fullgraph=True)
        def f1(x):
            meta = SubclassTensorArgs2(
                x.shape, x.device, SubclassTensorArgs2(x.shape, x.device, None)
            )
            out = SubclassTensor(x, meta)
            return out * out

        x = torch.randn(3, 3)
        f1(x)

    def test_torch_function_subclass_survives_into_aot_autograd(self):
        # If you have a tensor subclass that relies on dispatch into the same op
        # without unwrapping and calling torch._C.DisableTorchFunctionSubclass(),

View File

@@ -15,6 +15,7 @@ import torch.onnx.operators
from torch._guards import TracingContext
from torch._logging import warning_once
from torch._streambase import _StreamBase
from torch.utils._python_dispatch import is_traceable_wrapper_subclass_type
from .. import config, polyfills, variables
from ..codegen import PyCodegen
@@ -1028,6 +1029,9 @@ Either create the tensor outside the compiled region, or do not set the tensor t
        if data.source:
            return cls._nn_param_via_prefix_insert(tx, data, requires_grad)

        if is_traceable_wrapper_subclass_type(data.class_type):
            unimplemented("Parameter constructor with tensor subclass NYI")

        if not can_convert_to_tracable_parameter():
            unimplemented("Workaround for issues with nn_parameter construction")
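A hedged illustration of the guard added above (mine, not from the diff; behavior inferred from the "Parameter constructor with tensor subclass NYI" message): constructing an nn.Parameter from a traceable wrapper subclass inside the compiled region is still declined, so Dynamo falls back via unimplemented(), which surfaces as an exception under fullgraph=True.

import torch
from torch.testing._internal.two_tensor import TwoTensor

@torch.compile(fullgraph=True)
def make_param(x):
    # nn.Parameter over a traceable subclass should hit the new NYI branch
    return torch.nn.Parameter(TwoTensor(x, x))

try:
    make_param(torch.randn(2, 2))
except Exception as exc:  # with fullgraph=True the graph break is raised rather than swallowed
    print(type(exc).__name__)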

View File

@@ -16,6 +16,7 @@ from typing import Dict, Generic, List, TYPE_CHECKING
import torch._dynamo.config
import torch.nn
from torch._guards import TracingContext
from torch.utils._python_dispatch import is_traceable_wrapper_subclass_type
from .. import polyfills, variables
from ..bytecode_transformation import create_call_function
@@ -512,7 +513,10 @@ class UserDefinedClassVariable(UserDefinedVariable):
                user_cls_source=self.source,
                mutable_local=MutableLocal(),
            )
        elif self.value in self._in_graph_classes():
        elif (
            self.value in self._in_graph_classes()
            or is_traceable_wrapper_subclass_type(self.value)
        ):
            # torch.LongTensor cannot accept a list of FakeTensors.
            # So we stack the list of FakeTensors instead.
            if (

View File

@@ -3,7 +3,7 @@ import contextlib
import warnings
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Set, Union, Protocol, Tuple, Sequence, overload, Deque
from typing import Any, Dict, List, Optional, Set, Union, Protocol, Tuple, Sequence, overload, Deque, Type
from typing_extensions import TypeGuard
from collections import deque
@@ -391,6 +391,11 @@ def is_traceable_wrapper_subclass(t: object) -> TypeGuard[TensorWithFlatten]:
        and hasattr(t, "__tensor_unflatten__")
    )


def is_traceable_wrapper_subclass_type(t: Type) -> TypeGuard[Type[TensorWithFlatten]]:
    """Same as above, but takes a type argument instead of an instance."""
    return (issubclass(t, torch.Tensor) and t != torch.Tensor
            and hasattr(t, "__tensor_flatten__") and hasattr(t, "__tensor_unflatten__"))


def transform_subclass(t, callback, outer_size=None, outer_stride=None):
    """