Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-21 05:34:18 +08:00)
Implement serializable getattr support for tensor subclasses (#145772)
`builtins.getattr` is not serializable, so we replace it with a custom op that has a more refined schema.

Differential Revision: [D68899421](https://our.internmc.facebook.com/intern/diff/D68899421)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/145772
Approved by: https://github.com/bdhirsh
Committed by: PyTorch MergeBot
Parent: d5d3bdb55a
Commit: ebd992724f
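To see why the swap matters for serialization: `builtins.getattr` is an opaque Python callable with no operator schema, while the replacement is a registered operator whose schema the export serializer can record and round-trip. A minimal sketch, assuming a PyTorch build that includes this change:

    import builtins

    import torch
    import torch.export.custom_ops  # noqa: F401 -- registers the "export" op fragment

    # A plain Python builtin carries no schema the serializer could record:
    print(hasattr(builtins.getattr, "_schema"))  # False

    # The custom op does, so graphs that contain it can be saved and loaded:
    op = torch.ops.export.access_subclass_inner_tensor.default
    print(op._schema)
    # export::access_subclass_inner_tensor(Tensor src_subclass_tensor, str attr) -> Tensor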
@@ -1141,6 +1141,7 @@ API Reference
 .. autoclass:: torch.export.graph_signature.CustomObjArgument

 .. py:module:: torch.export.dynamic_shapes
+.. py:module:: torch.export.custom_ops

 .. automodule:: torch.export.unflatten
     :members:
@@ -1930,8 +1930,6 @@ graph():

     @testing.expectedFailureLegacyExportNonStrict  # Old export doesn't work with subclasses
     @testing.expectedFailureLegacyExportStrict  # Old export doesn't work with subclasses
-    @testing.expectedFailureSerDerNonStrict  # builtins.getattr is not supported T211130564
-    @testing.expectedFailureSerDer  # builtins.getattr is not supported T211130564
     def test_subclass_nested_attr_access(self):
         class Foo(torch.nn.Module):
             def __init__(self):
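The two removed decorators marked this test as an expected failure under the serialize/deserialize (SerDer) test variants; with the schema-carrying op in place of `builtins.getattr`, the round-trip now succeeds. What those variants exercise is roughly the following, shown here as a simplified sketch with a hypothetical module rather than the harness's actual code:

    import io

    import torch

    class M(torch.nn.Module):  # hypothetical stand-in for the test modules
        def forward(self, x):
            return x * 2

    ep = torch.export.export(M(), (torch.randn(2),))

    # Serialize the ExportedProgram, load it back, and run the result.
    buf = io.BytesIO()
    torch.export.save(ep, buf)
    buf.seek(0)
    ep2 = torch.export.load(buf)
    print(ep2.module()(torch.ones(2)))  # tensor([2., 2.])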
@@ -1958,6 +1956,7 @@ graph():
         ref_x = torch.randn(3, 4)
         ref_out = m(ref_x)
         ep_training = torch.export.export_for_training(m, (ref_x,), strict=False)
         self.assertTrue(torch.allclose(ep_training.module()(ref_x), ref_out))
         self.assertExpectedInline(
             str(ep_training.graph).strip(),
             """\
@@ -1970,9 +1969,9 @@ graph():
     %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul, %p_p2), kwargs = {})
     %add_1 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%add, %b_b1), kwargs = {})
     %sum_1 : [num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%add_1,), kwargs = {})
-    %getattr_65 : [num_users=1] = call_function[target=builtins.getattr](args = (%sum_1, a), kwargs = {})
-    %getattr_70 : [num_users=1] = call_function[target=builtins.getattr](args = (%getattr_65, b), kwargs = {})
-    %add_2 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%x, %getattr_70), kwargs = {})
+    %access_subclass_inner_tensor_default_64 : [num_users=1] = call_function[target=torch.ops.export.access_subclass_inner_tensor.default](args = (%sum_1, a), kwargs = {})
+    %access_subclass_inner_tensor_default_69 : [num_users=1] = call_function[target=torch.ops.export.access_subclass_inner_tensor.default](args = (%access_subclass_inner_tensor_default_64, b), kwargs = {})
+    %add_2 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%x, %access_subclass_inner_tensor_default_69), kwargs = {})
     return (add_2,)""",
         )
         ep = export(m, (ref_x,))
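The expected string above is just `str(ep_training.graph).strip()`; the same kind of dump can be produced for any module (hypothetical `M` below, not the test's `Foo`):

    import torch

    class M(torch.nn.Module):  # hypothetical example module
        def forward(self, x):
            return (x * 2).sum()

    ep_training = torch.export.export_for_training(M(), (torch.randn(3, 4),), strict=False)
    print(str(ep_training.graph).strip())
    # graph():
    #     %x : [num_users=1] = placeholder[target=x]
    #     ...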
@@ -1980,8 +1979,6 @@ graph():

     @testing.expectedFailureLegacyExportNonStrict  # Old export doesn't work with subclasses
     @testing.expectedFailureLegacyExportStrict  # Old export doesn't work with subclasses
-    @testing.expectedFailureSerDerNonStrict  # builtins.getattr is not supported T211130564
-    @testing.expectedFailureSerDer  # builtins.getattr is not supported T211130564
     def test_subclass_nested_attr_access_submodule(self):
         class Bar(torch.nn.Module):
             def __init__(self):
@@ -2028,9 +2025,9 @@ graph():
     %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul, %p_bar_p2), kwargs = {})
     %add_1 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%add, %b_bar_b1), kwargs = {})
     %sum_1 : [num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%add_1,), kwargs = {})
-    %getattr_65 : [num_users=1] = call_function[target=builtins.getattr](args = (%sum_1, a), kwargs = {})
-    %getattr_70 : [num_users=1] = call_function[target=builtins.getattr](args = (%getattr_65, b), kwargs = {})
-    %add_2 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%x, %getattr_70), kwargs = {})
+    %access_subclass_inner_tensor_default_64 : [num_users=1] = call_function[target=torch.ops.export.access_subclass_inner_tensor.default](args = (%sum_1, a), kwargs = {})
+    %access_subclass_inner_tensor_default_69 : [num_users=1] = call_function[target=torch.ops.export.access_subclass_inner_tensor.default](args = (%access_subclass_inner_tensor_default_64, b), kwargs = {})
+    %add_2 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%x, %access_subclass_inner_tensor_default_69), kwargs = {})
     return (add_2,)""",
         )
         ep = export(m, (ref_x,))
@@ -2038,8 +2035,6 @@ graph():

     @testing.expectedFailureLegacyExportNonStrict  # Old export doesn't work with subclasses
     @testing.expectedFailureLegacyExportStrict  # Old export doesn't work with subclasses
-    @testing.expectedFailureSerDerNonStrict  # builtins.getattr is not supported T211130564
-    @testing.expectedFailureSerDer  # builtins.getattr is not supported T211130564
     def test_subclass_nested_attr_access_const_metadata(self):
         class Foo(torch.nn.Module):
             def __init__(self):
@@ -2070,9 +2065,9 @@ graph():
     %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%p_p1, 2), kwargs = {})
     %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul, %p_p2), kwargs = {})
     %add_1 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%add, 4), kwargs = {})
-    %getattr_22 : [num_users=1] = call_function[target=builtins.getattr](args = (%add_1, elem), kwargs = {})
-    %getattr_27 : [num_users=1] = call_function[target=builtins.getattr](args = (%getattr_22, elem), kwargs = {})
-    %add_2 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%x, %getattr_27), kwargs = {})
+    %access_subclass_inner_tensor_default_10 : [num_users=1] = call_function[target=torch.ops.export.access_subclass_inner_tensor.default](args = (%add_1, elem), kwargs = {})
+    %access_subclass_inner_tensor_default_13 : [num_users=1] = call_function[target=torch.ops.export.access_subclass_inner_tensor.default](args = (%access_subclass_inner_tensor_default_10, elem), kwargs = {})
+    %add_2 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%x, %access_subclass_inner_tensor_default_13), kwargs = {})
     return (add_2,)""",
         )
         ep = export(m, (ref_x,))
@@ -2080,8 +2075,6 @@ graph():

     @testing.expectedFailureLegacyExportNonStrict  # Old export doesn't work with subclasses
     @testing.expectedFailureLegacyExportStrict  # Old export doesn't work with subclasses
-    @testing.expectedFailureSerDerNonStrict  # builtins.getattr is not supported T211130564
-    @testing.expectedFailureSerDer  # builtins.getattr is not supported T211130564
     def test_subclass_nested_attr_access_const_metadata_not_top_level(self):
         class Foo(torch.nn.Module):
             def __init__(self):
@@ -2122,8 +2115,6 @@ graph():

     @testing.expectedFailureLegacyExportNonStrict  # Old export doesn't work with subclasses
     @testing.expectedFailureLegacyExportStrict  # Old export doesn't work with subclasses
-    @testing.expectedFailureSerDerNonStrict  # builtins.getattr is not supported T211130564
-    @testing.expectedFailureSerDer  # builtins.getattr is not supported T211130564
     def test_subclass_nested_attr_access_const_metadata_not_top_level(self):
         class Foo(torch.nn.Module):
             def __init__(self):
@@ -2154,13 +2145,13 @@ graph():
     %x : [num_users=1] = placeholder[target=x]
     %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%p_p1, 2), kwargs = {})
     %add : [num_users=2] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul, %p_p2), kwargs = {})
-    %getattr_33 : [num_users=1] = call_function[target=builtins.getattr](args = (%add, a), kwargs = {})
-    %getattr_38 : [num_users=1] = call_function[target=builtins.getattr](args = (%getattr_33, elem), kwargs = {})
-    %add_1 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%add, %getattr_38), kwargs = {})
+    %access_subclass_inner_tensor_default_18 : [num_users=1] = call_function[target=torch.ops.export.access_subclass_inner_tensor.default](args = (%add, a), kwargs = {})
+    %access_subclass_inner_tensor_default_21 : [num_users=1] = call_function[target=torch.ops.export.access_subclass_inner_tensor.default](args = (%access_subclass_inner_tensor_default_18, elem), kwargs = {})
+    %add_1 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%add, %access_subclass_inner_tensor_default_21), kwargs = {})
     %add_2 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_1, 4), kwargs = {})
-    %getattr_45 : [num_users=1] = call_function[target=builtins.getattr](args = (%add_2, a), kwargs = {})
-    %getattr_50 : [num_users=1] = call_function[target=builtins.getattr](args = (%getattr_45, elem), kwargs = {})
-    %add_3 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%x, %getattr_50), kwargs = {})
+    %access_subclass_inner_tensor_default_25 : [num_users=1] = call_function[target=torch.ops.export.access_subclass_inner_tensor.default](args = (%add_2, a), kwargs = {})
+    %access_subclass_inner_tensor_default_28 : [num_users=1] = call_function[target=torch.ops.export.access_subclass_inner_tensor.default](args = (%access_subclass_inner_tensor_default_25, elem), kwargs = {})
+    %add_3 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%x, %access_subclass_inner_tensor_default_28), kwargs = {})
     return (add_3,)""",
         )
         ep = export(m, (ref_x,))
@@ -2168,8 +2159,6 @@ graph():

     @testing.expectedFailureLegacyExportNonStrict  # Old export doesn't work with subclasses
     @testing.expectedFailureLegacyExportStrict  # Old export doesn't work with subclasses
-    @testing.expectedFailureSerDerNonStrict  # builtins.getattr is not supported T211130564
-    @testing.expectedFailureSerDer  # builtins.getattr is not supported T211130564
     def test_subclass_nested_attr_access_complicated_metadata(self):
         class Foo(torch.nn.Module):
             def __init__(self):
@@ -2199,9 +2188,9 @@ graph():
     %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%p_p1, 2), kwargs = {})
     %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%x, %mul), kwargs = {})
     %add_1 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%add, %p_p2), kwargs = {})
-    %getattr_21 : [num_users=1] = call_function[target=builtins.getattr](args = (%add_1, elem), kwargs = {})
-    %getattr_26 : [num_users=1] = call_function[target=builtins.getattr](args = (%getattr_21, elem), kwargs = {})
-    %add_2 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%getattr_26, 4), kwargs = {})
+    %access_subclass_inner_tensor_default_10 : [num_users=1] = call_function[target=torch.ops.export.access_subclass_inner_tensor.default](args = (%add_1, elem), kwargs = {})
+    %access_subclass_inner_tensor_default_13 : [num_users=1] = call_function[target=torch.ops.export.access_subclass_inner_tensor.default](args = (%access_subclass_inner_tensor_default_10, elem), kwargs = {})
+    %add_2 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%access_subclass_inner_tensor_default_13, 4), kwargs = {})
     return (add_2,)""",
         )
         ep = export(m, (ref_x,))
@@ -1,5 +1,4 @@
 # mypy: allow-untyped-defs
-import builtins
 import inspect
 import math
 import operator
@@ -140,7 +139,6 @@ class Verifier(metaclass=_VerifierMeta):
             math.floor,
             math.trunc,
             round,
-            builtins.getattr,
         ]

     def allowed_op_types(self) -> tuple[type[Any], ...]:
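Dropping `builtins.getattr` from this list means the verifier no longer accepts it as a `call_function` target; the custom op passes instead because it is a real `OpOverload`. The shape of the check, as an illustrative sketch rather than the verifier's actual code:

    import math
    import operator

    import torch

    # Hypothetical condensation of the allowlist check above:
    _ALLOWED_BUILTIN_TARGETS = (operator.getitem, math.floor, math.trunc, round)

    def _is_allowed_call_function_target(target: object) -> bool:
        # torch.ops.export.access_subclass_inner_tensor.default passes as an
        # OpOverload; builtins.getattr is no longer in the builtin allowlist.
        return target in _ALLOWED_BUILTIN_TARGETS or isinstance(
            target, (torch._ops.OpOverload, torch._ops.HigherOrderOperator)
        )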
@@ -64,6 +64,8 @@ __all__ = [
     "UnflattenedModule",
 ]

+# To make sure export specific custom ops are loaded
+import torch.export.custom_ops

 from .decomp_utils import CustomDecompTable
 from .dynamic_shapes import Constraint, Dim, dims, ShapesCollection
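The import is there purely for its side effect: loading `torch.export.custom_ops` runs the `torch.library` registrations, so the op namespace is populated whenever `torch.export` is imported. A quick check, assuming a build with this change:

    import torch
    import torch.export  # transitively imports torch.export.custom_ops

    # The registration side effect makes the op resolvable by name:
    print(torch.ops.export.access_subclass_inner_tensor)  # the OpOverloadPacket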
@@ -1,13 +1,11 @@
 # mypy: allow-untyped-decorators
 # mypy: allow-untyped-defs
-import builtins
 import dataclasses
 import functools
 import inspect
 import logging
 import re
 import time
 import types
 import warnings
 from contextlib import contextmanager, nullcontext
 from typing import Any, Callable, Optional, Union
@@ -1514,8 +1512,7 @@ def _export_to_aten_ir_make_fx(
             out = original_getattr(self, attr)
             if attr in attrs_to_proxy:
                 if torch._C._is_torch_function_mode_enabled():
                     # If it is a static function or method, we should always inline
                     if not isinstance(out, (types.FunctionType, types.MethodType)):
                         if isinstance(out, torch.Tensor):
                             # When we get here there is no guarantee that we will hit the
                             # PreDispatchTorchFunctionMode, so we manually peek into the torch
                             # function mode list and tweak the PreDispatchTorchFunctionMode.
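Stripped of the tracing machinery, the interception above follows a simple pattern: read the attribute normally, then reroute only tensor-valued reads of proxied names. A freestanding sketch of that pattern, with hypothetical names rather than the exporter's actual code:

    import types

    import torch

    def make_intercepting_getattr(original_getattr, attrs_to_proxy, reroute):
        # `reroute` is a hypothetical callback standing in for proxy creation.
        def __getattr__(self, attr):
            out = original_getattr(self, attr)
            # Static functions and methods are always inlined, never rerouted.
            if (
                attr in attrs_to_proxy
                and not isinstance(out, (types.FunctionType, types.MethodType))
                and isinstance(out, torch.Tensor)
            ):
                return reroute(self, attr)
            return out

        return __getattr__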
@@ -1532,7 +1529,7 @@ def _export_to_aten_ir_make_fx(
                             proxy = get_proxy_slot(self, tracer).proxy
                             inner_proxy = tracer.create_proxy(
                                 "call_function",
-                                builtins.getattr,
+                                torch.ops.export.access_subclass_inner_tensor.default,
                                 (proxy, attr),
                                 {},
                             )
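The node this `create_proxy` call emits can be reproduced with a bare FX graph, shown here as a freestanding illustration rather than the exporter's own tracing:

    import torch
    import torch.export.custom_ops  # noqa: F401 -- makes the op resolvable

    g = torch.fx.Graph()
    x = g.placeholder("x")
    # The same target the tracer now records instead of builtins.getattr:
    inner = g.call_function(
        torch.ops.export.access_subclass_inner_tensor.default, (x, "a")
    )
    g.output(inner)
    print(g)  # the call_function node's target is an OpOverload with a schema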
@@ -1627,7 +1624,7 @@ def _export_to_aten_ir_make_fx(
         # from subclass tensors if we carefully rewrite track_tensor_tree
         # in a way that it doesn't do any tensor methods.
         torch.ops.aten.detach.default,
-        builtins.getattr,
+        torch.ops.export.access_subclass_inner_tensor.default,
     ):
         return False
     return True
torch/export/custom_ops.py (new file, 23 lines)
@@ -0,0 +1,23 @@
+import torch
+
+
+lib = torch.library.Library("export", "FRAGMENT")  # noqa: TOR901
+
+lib.define(
+    "access_subclass_inner_tensor(Tensor src_subclass_tensor, str attr) -> Tensor"
+)
+
+
+@torch.library.impl(lib, "access_subclass_inner_tensor", "Autograd")
+def _access_subclass_inner_tensor(
+    src_subclass_tensor: torch.Tensor, attr: str
+) -> torch.Tensor:
+    from torch.utils._python_dispatch import is_traceable_wrapper_subclass
+
+    assert is_traceable_wrapper_subclass(src_subclass_tensor)
+    val = getattr(src_subclass_tensor, attr, None)
+    if val is None or not isinstance(val, torch.Tensor):
+        raise RuntimeError(
+            f"Attribute {attr} is not a tensor or doesn't exist in {src_subclass_tensor}"
+        )
+    return val
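For a quick eager-mode check of the new op, any traceable wrapper subclass will do. The sketch below uses `TwoTensor`, an internal test-only subclass, purely for illustration; it is not a public API and may not exist in every build:

    import torch
    import torch.export.custom_ops  # noqa: F401
    from torch.testing._internal.two_tensor import TwoTensor  # internal, test-only

    t = TwoTensor(torch.randn(3), torch.randn(3))

    # Dispatches to the Autograd-layer kernel above, returning the inner tensor.
    inner = torch.ops.export.access_subclass_inner_tensor(t, "a")
    print(type(inner), inner.shape)

    # Missing or non-tensor attributes raise, per the kernel's check:
    try:
        torch.ops.export.access_subclass_inner_tensor(t, "nonexistent")
    except RuntimeError as e:
        print(e)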