Implement serializable getattr support for tensor subclasses (#145772)

builtins.getattr is not serializable, so we replace it with a custom op that has more refined schema.

Differential Revision: [D68899421](https://our.internmc.facebook.com/intern/diff/D68899421)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/145772
Approved by: https://github.com/bdhirsh
This commit is contained in:
Tugsbayasgalan Manlaibaatar
2025-02-09 17:36:40 -08:00
committed by PyTorch MergeBot
parent d5d3bdb55a
commit ebd992724f
6 changed files with 48 additions and 38 deletions

View File

@ -1141,6 +1141,7 @@ API Reference
.. autoclass:: torch.export.graph_signature.CustomObjArgument
.. py:module:: torch.export.dynamic_shapes
.. py:module:: torch.export.custom_ops
.. automodule:: torch.export.unflatten
:members:

View File

@ -1930,8 +1930,6 @@ graph():
@testing.expectedFailureLegacyExportNonStrict # Old export doesn't work with subclasses
@testing.expectedFailureLegacyExportStrict # Old export doesn't work with subclasses
@testing.expectedFailureSerDerNonStrict # builtins.getattr is not supported T211130564
@testing.expectedFailureSerDer # builtins.getattr is not supported T211130564
def test_subclass_nested_attr_access(self):
class Foo(torch.nn.Module):
def __init__(self):
@ -1958,6 +1956,7 @@ graph():
ref_x = torch.randn(3, 4)
ref_out = m(ref_x)
ep_training = torch.export.export_for_training(m, (ref_x,), strict=False)
self.assertTrue(torch.allclose(ep_training.module()(ref_x), ref_out))
self.assertExpectedInline(
str(ep_training.graph).strip(),
"""\
@ -1970,9 +1969,9 @@ graph():
%add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul, %p_p2), kwargs = {})
%add_1 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%add, %b_b1), kwargs = {})
%sum_1 : [num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%add_1,), kwargs = {})
%getattr_65 : [num_users=1] = call_function[target=builtins.getattr](args = (%sum_1, a), kwargs = {})
%getattr_70 : [num_users=1] = call_function[target=builtins.getattr](args = (%getattr_65, b), kwargs = {})
%add_2 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%x, %getattr_70), kwargs = {})
%access_subclass_inner_tensor_default_64 : [num_users=1] = call_function[target=torch.ops.export.access_subclass_inner_tensor.default](args = (%sum_1, a), kwargs = {})
%access_subclass_inner_tensor_default_69 : [num_users=1] = call_function[target=torch.ops.export.access_subclass_inner_tensor.default](args = (%access_subclass_inner_tensor_default_64, b), kwargs = {})
%add_2 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%x, %access_subclass_inner_tensor_default_69), kwargs = {})
return (add_2,)""",
)
ep = export(m, (ref_x,))
@ -1980,8 +1979,6 @@ graph():
@testing.expectedFailureLegacyExportNonStrict # Old export doesn't work with subclasses
@testing.expectedFailureLegacyExportStrict # Old export doesn't work with subclasses
@testing.expectedFailureSerDerNonStrict # builtins.getattr is not supported T211130564
@testing.expectedFailureSerDer # builtins.getattr is not supported T211130564
def test_subclass_nested_attr_access_submodule(self):
class Bar(torch.nn.Module):
def __init__(self):
@ -2028,9 +2025,9 @@ graph():
%add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul, %p_bar_p2), kwargs = {})
%add_1 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%add, %b_bar_b1), kwargs = {})
%sum_1 : [num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%add_1,), kwargs = {})
%getattr_65 : [num_users=1] = call_function[target=builtins.getattr](args = (%sum_1, a), kwargs = {})
%getattr_70 : [num_users=1] = call_function[target=builtins.getattr](args = (%getattr_65, b), kwargs = {})
%add_2 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%x, %getattr_70), kwargs = {})
%access_subclass_inner_tensor_default_64 : [num_users=1] = call_function[target=torch.ops.export.access_subclass_inner_tensor.default](args = (%sum_1, a), kwargs = {})
%access_subclass_inner_tensor_default_69 : [num_users=1] = call_function[target=torch.ops.export.access_subclass_inner_tensor.default](args = (%access_subclass_inner_tensor_default_64, b), kwargs = {})
%add_2 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%x, %access_subclass_inner_tensor_default_69), kwargs = {})
return (add_2,)""",
)
ep = export(m, (ref_x,))
@ -2038,8 +2035,6 @@ graph():
@testing.expectedFailureLegacyExportNonStrict # Old export doesn't work with subclasses
@testing.expectedFailureLegacyExportStrict # Old export doesn't work with subclasses
@testing.expectedFailureSerDerNonStrict # builtins.getattr is not supported T211130564
@testing.expectedFailureSerDer # builtins.getattr is not supported T211130564
def test_subclass_nested_attr_access_const_metadata(self):
class Foo(torch.nn.Module):
def __init__(self):
@ -2070,9 +2065,9 @@ graph():
%mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%p_p1, 2), kwargs = {})
%add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul, %p_p2), kwargs = {})
%add_1 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%add, 4), kwargs = {})
%getattr_22 : [num_users=1] = call_function[target=builtins.getattr](args = (%add_1, elem), kwargs = {})
%getattr_27 : [num_users=1] = call_function[target=builtins.getattr](args = (%getattr_22, elem), kwargs = {})
%add_2 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%x, %getattr_27), kwargs = {})
%access_subclass_inner_tensor_default_10 : [num_users=1] = call_function[target=torch.ops.export.access_subclass_inner_tensor.default](args = (%add_1, elem), kwargs = {})
%access_subclass_inner_tensor_default_13 : [num_users=1] = call_function[target=torch.ops.export.access_subclass_inner_tensor.default](args = (%access_subclass_inner_tensor_default_10, elem), kwargs = {})
%add_2 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%x, %access_subclass_inner_tensor_default_13), kwargs = {})
return (add_2,)""",
)
ep = export(m, (ref_x,))
@ -2080,8 +2075,6 @@ graph():
@testing.expectedFailureLegacyExportNonStrict # Old export doesn't work with subclasses
@testing.expectedFailureLegacyExportStrict # Old export doesn't work with subclasses
@testing.expectedFailureSerDerNonStrict # builtins.getattr is not supported T211130564
@testing.expectedFailureSerDer # builtins.getattr is not supported T211130564
def test_subclass_nested_attr_access_const_metadata_not_top_level(self):
class Foo(torch.nn.Module):
def __init__(self):
@ -2122,8 +2115,6 @@ graph():
@testing.expectedFailureLegacyExportNonStrict # Old export doesn't work with subclasses
@testing.expectedFailureLegacyExportStrict # Old export doesn't work with subclasses
@testing.expectedFailureSerDerNonStrict # builtins.getattr is not supported T211130564
@testing.expectedFailureSerDer # builtins.getattr is not supported T211130564
def test_subclass_nested_attr_access_const_metadata_not_top_level(self):
class Foo(torch.nn.Module):
def __init__(self):
@ -2154,13 +2145,13 @@ graph():
%x : [num_users=1] = placeholder[target=x]
%mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%p_p1, 2), kwargs = {})
%add : [num_users=2] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul, %p_p2), kwargs = {})
%getattr_33 : [num_users=1] = call_function[target=builtins.getattr](args = (%add, a), kwargs = {})
%getattr_38 : [num_users=1] = call_function[target=builtins.getattr](args = (%getattr_33, elem), kwargs = {})
%add_1 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%add, %getattr_38), kwargs = {})
%access_subclass_inner_tensor_default_18 : [num_users=1] = call_function[target=torch.ops.export.access_subclass_inner_tensor.default](args = (%add, a), kwargs = {})
%access_subclass_inner_tensor_default_21 : [num_users=1] = call_function[target=torch.ops.export.access_subclass_inner_tensor.default](args = (%access_subclass_inner_tensor_default_18, elem), kwargs = {})
%add_1 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%add, %access_subclass_inner_tensor_default_21), kwargs = {})
%add_2 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_1, 4), kwargs = {})
%getattr_45 : [num_users=1] = call_function[target=builtins.getattr](args = (%add_2, a), kwargs = {})
%getattr_50 : [num_users=1] = call_function[target=builtins.getattr](args = (%getattr_45, elem), kwargs = {})
%add_3 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%x, %getattr_50), kwargs = {})
%access_subclass_inner_tensor_default_25 : [num_users=1] = call_function[target=torch.ops.export.access_subclass_inner_tensor.default](args = (%add_2, a), kwargs = {})
%access_subclass_inner_tensor_default_28 : [num_users=1] = call_function[target=torch.ops.export.access_subclass_inner_tensor.default](args = (%access_subclass_inner_tensor_default_25, elem), kwargs = {})
%add_3 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%x, %access_subclass_inner_tensor_default_28), kwargs = {})
return (add_3,)""",
)
ep = export(m, (ref_x,))
@ -2168,8 +2159,6 @@ graph():
@testing.expectedFailureLegacyExportNonStrict # Old export doesn't work with subclasses
@testing.expectedFailureLegacyExportStrict # Old export doesn't work with subclasses
@testing.expectedFailureSerDerNonStrict # builtins.getattr is not supported T211130564
@testing.expectedFailureSerDer # builtins.getattr is not supported T211130564
def test_subclass_nested_attr_access_complicated_metadata(self):
class Foo(torch.nn.Module):
def __init__(self):
@ -2199,9 +2188,9 @@ graph():
%mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%p_p1, 2), kwargs = {})
%add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%x, %mul), kwargs = {})
%add_1 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%add, %p_p2), kwargs = {})
%getattr_21 : [num_users=1] = call_function[target=builtins.getattr](args = (%add_1, elem), kwargs = {})
%getattr_26 : [num_users=1] = call_function[target=builtins.getattr](args = (%getattr_21, elem), kwargs = {})
%add_2 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%getattr_26, 4), kwargs = {})
%access_subclass_inner_tensor_default_10 : [num_users=1] = call_function[target=torch.ops.export.access_subclass_inner_tensor.default](args = (%add_1, elem), kwargs = {})
%access_subclass_inner_tensor_default_13 : [num_users=1] = call_function[target=torch.ops.export.access_subclass_inner_tensor.default](args = (%access_subclass_inner_tensor_default_10, elem), kwargs = {})
%add_2 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%access_subclass_inner_tensor_default_13, 4), kwargs = {})
return (add_2,)""",
)
ep = export(m, (ref_x,))

View File

@ -1,5 +1,4 @@
# mypy: allow-untyped-defs
import builtins
import inspect
import math
import operator
@ -140,7 +139,6 @@ class Verifier(metaclass=_VerifierMeta):
math.floor,
math.trunc,
round,
builtins.getattr,
]
def allowed_op_types(self) -> tuple[type[Any], ...]:

View File

@ -64,6 +64,8 @@ __all__ = [
"UnflattenedModule",
]
# To make sure export specific custom ops are loaded
import torch.export.custom_ops
from .decomp_utils import CustomDecompTable
from .dynamic_shapes import Constraint, Dim, dims, ShapesCollection

View File

@ -1,13 +1,11 @@
# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
import builtins
import dataclasses
import functools
import inspect
import logging
import re
import time
import types
import warnings
from contextlib import contextmanager, nullcontext
from typing import Any, Callable, Optional, Union
@ -1514,8 +1512,7 @@ def _export_to_aten_ir_make_fx(
out = original_getattr(self, attr)
if attr in attrs_to_proxy:
if torch._C._is_torch_function_mode_enabled():
# If it is a static function or method, we should always inline
if not isinstance(out, (types.FunctionType, types.MethodType)):
if isinstance(out, torch.Tensor):
# When we get here there is no guarantee that we will hit the
# PreDispatchTorchFunctionMode, so we manually peak into the torch
# function mode list and tweak the PreDispatchTorchFunctionMode.
@ -1532,7 +1529,7 @@ def _export_to_aten_ir_make_fx(
proxy = get_proxy_slot(self, tracer).proxy
inner_proxy = tracer.create_proxy(
"call_function",
builtins.getattr,
torch.ops.export.access_subclass_inner_tensor.default,
(proxy, attr),
{},
)
@ -1627,7 +1624,7 @@ def _export_to_aten_ir_make_fx(
# from subclass tensors if we carefully rewrite track_tensor_tree
# in a way that it doesn't do any tensor methods.
torch.ops.aten.detach.default,
builtins.getattr,
torch.ops.export.access_subclass_inner_tensor.default,
):
return False
return True

View File

@ -0,0 +1,23 @@
import torch
lib = torch.library.Library("export", "FRAGMENT") # noqa: TOR901
lib.define(
"access_subclass_inner_tensor(Tensor src_subclass_tensor, str attr) -> Tensor"
)
@torch.library.impl(lib, "access_subclass_inner_tensor", "Autograd")
def _access_subclass_inner_tensor(
src_subclass_tensor: torch.Tensor, attr: str
) -> torch.Tensor:
from torch.utils._python_dispatch import is_traceable_wrapper_subclass
assert is_traceable_wrapper_subclass(src_subclass_tensor)
val = getattr(src_subclass_tensor, attr, None)
if val is None or not isinstance(val, torch.Tensor):
raise RuntimeError(
f"Attribute {attr} is not a tensor or doesn't exist in {src_subclass_tensor}"
)
return val