Compare commits

...

1 Commits

Author SHA1 Message Date
65d6fc58eb test 2025-01-28 22:13:11 -08:00
3 changed files with 79 additions and 9 deletions

View File

@ -3,7 +3,6 @@ PYTEST_DONT_REWRITE (prevents pytest from rewriting assertions, which interferes
with test_sym_bool)
"""
# Owner(s): ["oncall: export"]
import copy
import io
@ -900,6 +899,14 @@ class TestDeserialize(TestCase):
dynamic_shapes = {"x": (dim0_x, dim1_x)}
self.check_graph(Foo(), inputs, dynamic_shapes)
def test_pytree(self):
    # Exercise check_graph with a module whose single positional input is a
    # dict of tensors rather than a bare tensor.
    class Foo(torch.nn.Module):
        def forward(self, x):
            # ``x`` is the dict built below; add its two entries.
            return x["a"] + x["b"]

    # Two same-shaped all-ones tensors keyed "a" and "b" (insertion order
    # matches the original literal).
    tensor_dict = {key: torch.ones(2, 3) for key in ("a", "b")}
    example_inputs = (tensor_dict,)
    self.check_graph(Foo(), example_inputs)
def test_module(self):
class M(torch.nn.Module):
def __init__(self) -> None:

View File

@ -3,7 +3,7 @@
from dataclasses import dataclass, field
from enum import IntEnum
from typing import Annotated, Optional
from typing import Annotated, Optional, Self
from torch._export.serde.union import _Union
@ -364,6 +364,14 @@ class RangeConstraint:
max_val: Annotated[Optional[int], 20]
@dataclass
class TreeSpec:
type: Annotated[Optional[str], 10]
context: Annotated[str, 20]
children_spec: Annotated[list[Self], 30]
metadata: Annotated[dict[str, str], 40] = field(default_factory=dict)
@dataclass
class ModuleCallSignature:
inputs: Annotated[list[Argument], 10]
@ -371,8 +379,8 @@ class ModuleCallSignature:
# These are serialized by calling pytree.treespec_dumps
# And deserialized by calling pytree.treespec_loads
in_spec: Annotated[str, 30]
out_spec: Annotated[str, 40]
in_spec: Annotated[TreeSpec, 30]
out_spec: Annotated[TreeSpec, 40]
# This field is used to prettify the graph placeholders
# after we ser/der and retrace

View File

@ -34,13 +34,13 @@ import sympy
import torch
import torch.export.exported_program as ep
import torch.utils._pytree as pytree
from torch._export.verifier import load_verifier
from torch._export.non_strict_utils import _enable_graph_inputs_of_type_nn_module
from torch._library.fake_class_registry import FakeScriptObject
from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode
from torch.fx.experimental import symbolic_shapes
from torch.utils import _pytree as pytree
from torch.utils._pytree import treespec_dumps, treespec_loads
from torch.utils._pytree import treespec_loads
from torch.utils._sympy.numbers import int_oo
from torch.utils._sympy.symbol import symbol_is_type, SymT
from torch.utils._sympy.value_ranges import ValueRanges
@ -94,6 +94,7 @@ from .schema import ( # type: ignore[attr-defined]
TensorMeta,
TokenArgument,
TREESPEC_VERSION,
TreeSpec,
UserInputMutationSpec,
UserInputSpec,
UserOutputSpec,
@ -1176,6 +1177,32 @@ class GraphModuleSerializer(metaclass=Final):
else:
raise AssertionError("TODO")
def serialize_treespec(self, treespec: pytree.TreeSpec) -> TreeSpec:
    """Convert a pytree ``TreeSpec`` into export's ``TreeSpec`` schema.

    The spec is first lowered with pytree's own ``treespec_to_json`` for
    ``TREESPEC_VERSION``. Each resulting node is then copied recursively
    into the export schema dataclass, which carries an extra ``metadata``
    dict per node.

    Args:
        treespec: The pytree spec to serialize.

    Returns:
        The root export-schema ``TreeSpec`` mirroring ``treespec``.

    Raises:
        KeyError: If a node's serialized type name is missing from
            ``pytree.SERIALIZED_TYPE_TO_PYTHON_TYPE``.
    """
    protocol = pytree._SUPPORTED_PROTOCOLS[TREESPEC_VERSION]
    pytree_schema_root = protocol.treespec_to_json(treespec)

    def to_export_schema(node) -> TreeSpec:
        # Export's schema allows some additional metadata on every node.
        metadata: dict = {}
        if node.type is not None:
            python_type = pytree.SERIALIZED_TYPE_TO_PYTHON_TYPE[node.type]
            # issubclass() raises TypeError when its first argument is not a
            # class, so make sure we really have a class before asking.
            # NOTE(review): pytree appears to register the
            # ``collections.namedtuple`` factory function itself under its
            # serialized name, which would hit exactly this case — confirm.
            is_namedtuple_class = (
                isinstance(python_type, type)
                and issubclass(python_type, tuple)
                and hasattr(python_type, "_fields")
            )
            if is_namedtuple_class:
                # NamedTuple classes: remember the field names.
                # NOTE(review): the schema declares metadata as
                # dict[str, str] but a list is stored here — confirm the
                # intended encoding before adding a deserializer.
                metadata["namedtuple_fields"] = list(python_type._fields)
        return TreeSpec(
            type=node.type,
            context=node.context,
            children_spec=[to_export_schema(child) for child in node.children_spec],
            metadata=metadata,
        )

    return to_export_schema(pytree_schema_root)
def serialize_module_call_signature(
self, module_call_signature: ep.ModuleCallSignature
) -> ModuleCallSignature:
@ -1187,8 +1214,8 @@ class GraphModuleSerializer(metaclass=Final):
outputs=[
self.serialize_argument_spec(x) for x in module_call_signature.outputs
],
in_spec=treespec_dumps(module_call_signature.in_spec, TREESPEC_VERSION),
out_spec=treespec_dumps(module_call_signature.out_spec, TREESPEC_VERSION),
in_spec=self.serialize_treespec(module_call_signature.in_spec),
out_spec=self.serialize_treespec(module_call_signature.out_spec),
forward_arg_names=names if (names := module_call_signature.forward_arg_names) else None
)
@ -2496,10 +2523,36 @@ class GraphModuleDeserializer(metaclass=Final):
else:
return ep.ConstantArgument(name="", value=self.deserialize_input(x))
# def deserialize_treespec(self, serialized_treespec: TreeSpec) -> pytree.TreeSpec:
# # Use pytree's serialization mechanism to convert it to a python
# # dataclass
# serialized_pytree_old = pytree._SUPPORTED_PROTOCOLS[protocol].json_to_treespec(serialized_treespec)
# # Convert the dataclass to export's pytree serialization schema. This
# # allows us to store some additional metadata
# def convert_treespec_schema_to_export_schema(serialized_pytree):
# metadata = {}
# python_type = pytree.SERIALIZED_TYPE_TO_PYTHON_TYPE[serialized_pytree.type]
# if issubclass(python_type, tuple) and hasattr(python_type, '_fields'):
# # Check if the type is a NamedTuple, if so, we want to store the
# # field names.
# metadata["namedtuple_fields"] = list(python_type._fields)
# return TreeSpec(
# type=serialized_pytree.type,
# context=serialized_pytree.context,
# children_spec=[convert_treespec_schema_to_export_schema(child) for child in serialized_pytree.children_spec],
# metadata=metadata
# )
# return convert_treespec_schema_to_export_schema(serialized_pytree_old)
def deserialize_module_call_signature(
self, module_call_signature: ModuleCallSignature
) -> ep.ModuleCallSignature:
return ep.ModuleCallSignature(
res = ep.ModuleCallSignature(
inputs=[
self.deserialize_argument_spec(x) for x in module_call_signature.inputs
],
@ -2510,6 +2563,8 @@ class GraphModuleDeserializer(metaclass=Final):
out_spec=treespec_loads(module_call_signature.out_spec),
forward_arg_names=names if (names := module_call_signature.forward_arg_names) else None,
)
breakpoint()
return res
def deserialize_module_call_graph(
self, module_call_graph: list[ModuleCallEntry]