[export] Implement cpp deserializer. (#136398)

Differential Revision: D63206258

This diff introduces a mechanism to generate a JSON-compatible deserializer in C++ using nlohmann json (which is already used by AOTI).

Why do we need this? There will be many cases where people don't want to use Python to load the serialized graph (e.g. a C++ runtime); with this header they can deserialize the JSON graph directly in C++.

Every time update_schema.py is run to update the schema, the header is regenerated and included in the source files.
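
As a rough sketch of the intended C++ usage (the file name and surrounding code are illustrative; the header path, the torch::_export namespace, and the parse/get calls match what this diff generates and what pybind.cpp below does):

    #include <fstream>
    #include <string>
    #include <nlohmann/json.hpp>
    #include <torch/csrc/utils/generated_serialization_types.h>

    int main() {
      // Load a JSON-serialized ExportedProgram produced by the Python serializer.
      std::ifstream in("exported_program.json"); // illustrative path
      nlohmann::json j = nlohmann::json::parse(in);
      // The generated from_json() overloads recursively deserialize the graph.
      torch::_export::ExportedProgram ep = j.get<torch::_export::ExportedProgram>();
      // Round-trip back to a JSON string via the generated to_json() overloads.
      std::string text = nlohmann::json(ep).dump();
      return 0;
    }
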
Pull Request resolved: https://github.com/pytorch/pytorch/pull/136398
Approved by: https://github.com/angelayi
Author: Zhengxu Chen
Committed by: PyTorch MergeBot
Date: 2024-11-14 16:34:56 +00:00
parent f98c601efe
commit 3ef2dfc1ba
15 changed files with 2656 additions and 62 deletions

.gitattributes

@ -5,3 +5,4 @@
 .github/scripts/gql_mocks.json linguist-generated=true
 third_party/LICENSES_BUNDLED.txt linguist-generated=true
 tools/build/bazel/requirements.txt linguist-generated=true
+torch/csrc/utils/generated_serialization_types.h linguist-generated=true

build_variables.bzl

@ -844,6 +844,7 @@ libtorch_python_core_sources = [
"torch/csrc/fx/node.cpp",
"torch/csrc/mps/Module.cpp",
"torch/csrc/mtia/Module.cpp",
"torch/csrc/export/pybind.cpp",
"torch/csrc/inductor/aoti_package/pybind.cpp",
"torch/csrc/inductor/aoti_runner/pybind.cpp",
"torch/csrc/inductor/aoti_eager/kernel_holder.cpp",

scripts/export/update_schema.py

@ -29,7 +29,7 @@ if __name__ == "__main__":
     commit = schema_check.update_schema()
-    if os.path.exists(args.prefix + commit.path):
+    if os.path.exists(args.prefix + commit.yaml_path):
         if commit.result["SCHEMA_VERSION"] < commit.base["SCHEMA_VERSION"]:
             raise RuntimeError(
                 f"Schema version downgraded from {commit.base['SCHEMA_VERSION']} to {commit.result['SCHEMA_VERSION']}."
@ -55,17 +55,28 @@ if __name__ == "__main__":
+ f"Reason: {reason}"
)
header = (
"# @" + "generated by " + os.path.basename(__file__).rsplit(".", 1)[0] + ".py"
first_line = (
"@" + "generated by " + os.path.basename(__file__).rsplit(".", 1)[0] + ".py"
)
header += f"\n# checksum<<{commit.checksum_result}>>"
payload = dump(commit.result, Dumper=Dumper, sort_keys=False)
checksum = f"checksum<<{commit.checksum_result}>>"
yaml_header = "# " + first_line
yaml_header += "\n# " + checksum
yaml_payload = dump(commit.result, Dumper=Dumper, sort_keys=False)
content = header + "\n" + payload
cpp_header = "// " + first_line
cpp_header += "\n// " + checksum
cpp_header += "\n// clang-format off"
cpp_header += "\n" + commit.cpp_header
cpp_header += "\n// clang-format on"
cpp_header += "\n"
yaml_content = yaml_header + "\n" + yaml_payload
if args.dry_run:
print(content)
print("\nWill write the above schema to" + args.prefix + commit.path)
print(yaml_content)
print("\nWill write the above schema to" + args.prefix + commit.yaml_path)
else:
with open(args.prefix + commit.path, "w") as f:
f.write(content)
with open(args.prefix + commit.yaml_path, "w") as f:
f.write(yaml_content)
with open(args.prefix + commit.cpp_header_path, "w") as f:
f.write(cpp_header)
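
For reference, the cpp_header assembled above means the generated file starts with a preamble like this (checksum elided; the clang-format guards wrap the machine-generated body):

    // @generated by update_schema.py
    // checksum<<...>>
    // clang-format off
    ... generated type definitions ...
    // clang-format on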

test/export/test_cpp_serdes.py (new file)

@ -0,0 +1,60 @@
# Owner(s): ["oncall: export"]

import torch
from torch._export.serde.serialize import deserialize, serialize

try:
    from . import test_export, testing
except ImportError:
    import test_export  # @manual=fbcode//caffe2/test:test_export-library
    import testing  # @manual=fbcode//caffe2/test:test_export-library

from torch.export import export

test_classes = {}


def mocked_cpp_serdes_export(*args, **kwargs):
    ep = export(*args, **kwargs)
    try:
        payload = serialize(ep)
    except Exception:
        return ep
    cpp_ep = torch._C._export.deserialize_exported_program(payload.exported_program)
    loaded_json = torch._C._export.serialize_exported_program(cpp_ep)
    payload.exported_program = loaded_json.encode()
    loaded_ep = deserialize(payload)
    return loaded_ep


def make_dynamic_cls(cls):
    cls_prefix = "CppSerdes"

    test_class = testing.make_test_cls_with_mocked_export(
        cls,
        cls_prefix,
        "_cpp_serdes",
        mocked_cpp_serdes_export,
        xfail_prop="_expected_failure_cpp_serdes",
    )
    test_classes[test_class.__name__] = test_class
    # REMOVING THIS LINE WILL STOP TESTS FROM RUNNING
    globals()[test_class.__name__] = test_class
    test_class.__module__ = __name__


tests = [
    test_export.TestExport,
]
for test in tests:
    make_dynamic_cls(test)
del test

if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests

    run_tests()

test/export/test_export.py

@ -2951,6 +2951,7 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x):
         export(N(), inputs, dynamic_shapes=dynamic_shapes)

     @testing.expectedFailureSerDer  # no unbacked bindings after deserialization?
+    @testing.expectedFailureCppSerDes  # no unbacked bindings after deserialization?
     @testing.expectedFailureSerDerNonStrict
     def test_unbacked_bindings_for_divisible_u_symint(self):
         with torch.library._scoped_library("mylib", "FRAGMENT") as lib:
@ -3673,6 +3674,7 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x):
         self._test_export_same_as_eager(kw_func, args, kwargs)

     @testing.expectedFailureSerDer  # we don't save placeholder metadata
+    @testing.expectedFailureCppSerDes  # we don't save placeholder metadata
     @testing.expectedFailureSerDerNonStrict
     @testing.expectedFailureNonStrict
     @testing.expectedFailureTrainingIRToRunDecompNonStrict  # source_fn_stack failure
@ -8078,6 +8080,7 @@ def forward(self, x, y):
         export(f, (inputs,), dynamic_shapes=dynamic_shapes)

     @testing.expectedFailureRetraceabilityNonStrict
+    @testing.expectedFailureCppSerDes  # dynamic shape serialization
     def test_disable_forced_specializations_ok(self):
         # check that we don't force specialization, and defer to runtime asserts
         # with allow_complex_guards_as_runtime_asserts=True to successfully export
@ -8198,6 +8201,7 @@ def forward(self, x, y):
     # TODO requires_grad doesn't seem to work with serialization.
     @testing.expectedFailureSerDer
+    @testing.expectedFailureCppSerDes
     @testing.expectedFailureSerDerNonStrict
     def test_preserve_requires_grad_placeholders(self):
         class Module(torch.nn.Module):
@ -8536,6 +8540,7 @@ def forward(self, x, y):
             ep.graph_module.code
         )

+    @testing.expectedFailureCppSerDes
     def test_slice_with_floordiv(self):
         # slice operation emits runtime assert s0//2 <= s1
         class M1(torch.nn.Module):
@ -9105,6 +9110,7 @@ def forward(self, x):
         _load_dynamic_shapes(spec, from_dict=True)

     @testing.expectedFailureSerDer  # TODO(pianpwk): PowByNatural valuerange deserialization
+    @testing.expectedFailureCppSerDes  # TODO(pianpwk): PowByNatural valuerange deserialization
     @testing.expectedFailureSerDerNonStrict
     @testing.expectedFailureRetraceabilityNonStrict
     def test_dim_dynamic(self):

test/export/test_schema.py

@ -106,11 +106,13 @@ Example(s):
         commit = _Commit(
             result=src,
             checksum_result="",
-            path="",
+            yaml_path="",
             additions=additions,
             subtractions=subtractions,
             base=dst,
             checksum_base="",
+            cpp_header="",
+            cpp_header_path="",
         )
         next_version, _ = check(commit)
         self.assertEqual(next_version, [4, 1])
@ -138,11 +140,13 @@ Example(s):
         commit = _Commit(
             result=src,
             checksum_result="",
-            path="",
+            yaml_path="",
             additions=additions,
             subtractions=subtractions,
             base=dst,
             checksum_base="",
+            cpp_header="",
+            cpp_header_path="",
         )
         next_version, _ = check(commit)
         self.assertEqual(next_version, [4, 1])
@ -173,11 +177,13 @@ Example(s):
         commit = _Commit(
             result=src,
             checksum_result="",
-            path="",
+            yaml_path="",
             additions=additions,
             subtractions=subtractions,
             base=dst,
             checksum_base="",
+            cpp_header="",
+            cpp_header_path="",
         )
         next_version, _ = check(commit)
         self.assertEqual(next_version, [3, 3])
@ -231,11 +237,13 @@ Example(s):
         commit = _Commit(
             result=src,
             checksum_result="",
-            path="",
+            yaml_path="",
             additions=additions,
             subtractions=subtractions,
             base=dst,
             checksum_base="",
+            cpp_header="",
+            cpp_header_path="",
         )
         next_version, _ = check(commit)
         self.assertEqual(next_version, [3, 3])
@ -259,11 +267,13 @@ Example(s):
         commit = _Commit(
             result=src,
             checksum_result="",
-            path="",
+            yaml_path="",
             additions=additions,
             subtractions=subtractions,
             base=dst,
             checksum_base="",
+            cpp_header="",
+            cpp_header_path="",
         )
         next_version, _ = check(commit)
         self.assertEqual(next_version, [3, 3])
@ -294,11 +304,13 @@ Example(s):
         commit = _Commit(
             result=src,
             checksum_result="",
-            path="",
+            yaml_path="",
             additions=additions,
             subtractions=subtractions,
             base=dst,
             checksum_base="",
+            cpp_header="",
+            cpp_header_path="",
         )
         next_version, _ = check(commit)
         self.assertEqual(next_version, [3, 3])
@ -326,11 +338,13 @@ Example(s):
         commit = _Commit(
             result=src,
             checksum_result="",
-            path="",
+            yaml_path="",
             additions=additions,
             subtractions=subtractions,
             base=dst,
             checksum_base="",
+            cpp_header="",
+            cpp_header_path="",
         )
         next_version, _ = check(commit)
         self.assertEqual(next_version, [4, 1])

test/export/testing.py

@ -284,3 +284,8 @@ def expectedFailureSerDerPreDispatch(fn):
 def expectedFailurePreDispatchRunDecomp(fn):
     fn._expected_failure_pre_dispatch = True
     return fn
+
+
+def expectedFailureCppSerDes(fn):
+    fn._expected_failure_cpp_serdes = True
+    return fn

torch/_C/__init__.pyi

@ -64,6 +64,7 @@ from torch.utils._python_dispatch import TorchDispatchMode
 from . import (
     _aoti,
+    _export,
     _cpu,
     _dynamo,
     _functorch,

torch/_C/_export.pyi (new file)

@ -0,0 +1,10 @@
# Defined in torch/csrc/export/pybind.cpp

class CppExportedProgram: ...

def deserialize_exported_program(
    serialized_program: str,
) -> CppExportedProgram: ...

def serialize_exported_program(
    cpp_exported_program: CppExportedProgram,
) -> str: ...

torch/_export/serde/schema_check.py

@ -1,10 +1,11 @@
 # mypy: allow-untyped-defs
 import dataclasses
 import hashlib
+import inspect
 import re
 import typing
 from enum import IntEnum
-from typing import Any, Dict, Optional, Union
+from typing import Any, Dict, List, Optional, Tuple, Union

 from torch._export.serde import schema
 from torch._export.serde.union import _Union
@ -20,43 +21,84 @@ def _check(x, msg):
 def _staged_schema():
-    ret: Dict[str, Any] = {}
+    yaml_ret: Dict[str, Any] = {}
     defs = {}
+    cpp_enum_defs: Dict[str, str] = {}
+    cpp_class_defs: Dict[str, str] = {}
+    cpp_type_decls: List[str] = []
+    cpp_json_defs: List[str] = []

-    def _handle_aggregate(ty):
-        def dump_type(t):
+    def _handle_aggregate(ty) -> Tuple[Dict[str, Any], Dict[str, Any]]:
+        def dump_type(t) -> Tuple[str, str]:
+            TYPE_MAP = {
+                str: "std::string",
+                int: "int64_t",
+                float: "double",
+                bool: "bool",
+            }
             if isinstance(t, type):
-                return t.__name__
+                if t.__name__ in cpp_enum_defs:
+                    return t.__name__, "int64_t"
+                else:
+                    return t.__name__, TYPE_MAP.get(t, t.__name__)
             elif isinstance(t, str):
                 assert t in defs
-                return t
+                assert t not in cpp_enum_defs
+                assert "[" not in t
+                return t, f"ForwardRef<{t}>"
             elif o := typing.get_origin(t):
                 # Lemme know if there's a better way to do this.
                 if o == list:
-                    head = "List"
+                    yaml_head, cpp_head = "List", "std::vector"
                 elif o == dict:
-                    head = "Dict"
+                    yaml_head, cpp_head = "Dict", "std::unordered_map"
                 elif o == tuple:
                     if typing.get_args(t) == ():
-                        return "Tuple[()]"
-                    head = "Tuple"
+                        return "Tuple[()]", "std::tuple<>"
+                    yaml_head, cpp_head = "Tuple", "std::tuple"
                 elif o == Union:
                     args = typing.get_args(t)
                     assert len(args) == 2 and args[1] == type(None)
-                    return f"Optional[{dump_type(args[0])}]"
+                    yaml_type, cpp_type = dump_type(args[0])
+                    return f"Optional[{yaml_type}]", f"std::optional<{cpp_type}>"
                 else:
                     raise AssertionError(f"Type {t} is not supported in export schema.")
-                return (
-                    f"{head}[{', '.join([dump_type(x) for x in typing.get_args(t)])}]"
-                )
+                yaml_arg_types, cpp_arg_types = zip(
+                    *[dump_type(x) for x in typing.get_args(t)]
+                )
+                return (f"{yaml_head}[{', '.join(yaml_arg_types)}]"), (
+                    f"{cpp_head}<{', '.join(cpp_arg_types)}>"
+                )
             elif t == ():
-                return "()"
+                return "()", ""
             else:
                 raise AssertionError(f"Type {t} is not supported in export schema.")

-        def dump_field(f):
-            t = dump_type(f.type)
+        def dump_cpp_value(v) -> str:
+            if v is None:
+                return "std::nullopt"
+            elif v is True:
+                return "true"
+            elif v is False:
+                return "false"
+            elif v == {}:
+                return "{}"
+            elif v == []:
+                return "{}"
+            elif v == ():
+                return "{}"
+            elif isinstance(v, str):
+                return f'"{v}"'
+            else:
+                raise AssertionError(
+                    f"Default value {v} is not supported yet in export schema."
+                )
+
+        def dump_field(f) -> Tuple[Dict[str, Any], str, Optional[str]]:
+            t, cpp = dump_type(f.type)
             ret = {"type": t}
+            cpp_type = cpp
+            cpp_default: Optional[str] = None

             value = dataclasses.MISSING
             if f.default is not dataclasses.MISSING:
@ -67,24 +109,149 @@ def _staged_schema():
             if value is not dataclasses.MISSING:
                 default = str(value)
                 ret["default"] = default
+                cpp_default = dump_cpp_value(value)

                 if t.startswith("Optional[") and value is not None:
                     raise AssertionError(
                         f"Optional field {ty.__name__}.{f.name} must have default value to be None."
                     )
-            return ret
+            return ret, cpp_type, cpp_default

-        return {f.name: dump_field(f) for f in dataclasses.fields(ty)}
+        yaml_ret = {}
+        cpp_ret = {}
+        for f in dataclasses.fields(ty):
+            yaml_res, cpp_type, cpp_default = dump_field(f)
+            yaml_ret[f.name] = yaml_res
+            cpp_ret[f.name] = {"cpp_type": cpp_type, "cpp_default": cpp_default}
+        return yaml_ret, cpp_ret

     def _handle_int_enum(name, ty):
-        ret[name] = {"kind": "enum", "fields": {x.name: x.value for x in ty}}
+        yaml_ret[name] = {"kind": "enum", "fields": {x.name: x.value for x in ty}}
+        cpp_enum_defs[
+            name
+        ] = f"""
+enum class {name} {{
+{chr(10).join([f"  {x.name} = {x.value}," for x in ty])}
+}};
+"""

     def _handle_struct(name, ty):
-        ret[name] = {"kind": "struct", "fields": _handle_aggregate(ty)}
+        fields, cpp_fields = _handle_aggregate(ty)
+        yaml_ret[name] = {"kind": "struct", "fields": fields}
+        field_decls = "\n".join(
+            f"  {f['cpp_type']} {name}{' = ' + f['cpp_default'] if f['cpp_default'] is not None else ''};"
+            for name, f in cpp_fields.items()
+        )
+
+        def accessor(name, ty):
+            type_name = fields[name]["type"]
+            if type_name in cpp_enum_defs:
+                return f"""
+  {type_name} get_{name}() const {{
+    return static_cast<{type_name}>({name});
+  }}
+"""
+            return f"""
+  const {ty}& get_{name}() const {{
+    return {name};
+  }}
+"""
+
+        to_json_decl = f"void to_json(nlohmann::json& nlohmann_json_j, const {name}& nlohmann_json_t)"
+        to_json_def = f"""{{
+{chr(10).join([f'  nlohmann_json_j["{name}"] = nlohmann_json_t.{name};' for name, f in cpp_fields.items()])}
+}}
+"""
+        from_json_decl = f"void from_json(const nlohmann::json& nlohmann_json_j, {name}& nlohmann_json_t)"
+        from_json_def = f"""{{
+  {name} nlohmann_json_default_obj;
+{chr(10).join(
+    [f'  nlohmann_json_t.{name} = nlohmann_json_j.value("{name}", nlohmann_json_default_obj.{name});'
+     for name, f in cpp_fields.items()])}
+}}
+"""
+        cpp_class_defs[
+            name
+        ] = f"""
+class {name} {{
+ private:
+{field_decls}
+
+ public:
+{"".join([accessor(name, f["cpp_type"]) for name, f in cpp_fields.items()])}
+  friend {to_json_decl};
+  friend {from_json_decl};
+}};
+"""
+        cpp_json_defs.append(f"inline {to_json_decl} {to_json_def}")
+        cpp_json_defs.append(f"inline {from_json_decl} {from_json_def}")
+        cpp_type_decls.append(f"class {name};")

     def _handle_union(name, ty):
-        ret[name] = {"kind": "union", "fields": _handle_aggregate(ty)}
+        fields, cpp_fields = _handle_aggregate(ty)
+        yaml_ret[name] = {"kind": "union", "fields": fields}
+
+        def accessor(name, ty, idx):
+            return f"""
+  const {ty}& get_{name}() const {{
+    return std::get<{idx + 1}>(variant_);
+  }}
+"""
+
+        to_json_branches = "".join(
+            [
+                f"""
+  if (nlohmann_json_t.tag_ == Tag::{name.upper()}) {{
+    nlohmann_json_j["{name}"] = nlohmann_json_t.get_{name}();
+    return;
+  }}"""
+                for idx, (name, f) in enumerate(cpp_fields.items())
+            ]
+        )
+        from_json_branches = "".join(
+            [
+                f"""
+  if (nlohmann_json_j.contains("{name}")) {{
+    nlohmann_json_t.variant_.emplace<{idx + 1}>(nlohmann_json_j.at("{name}").template get<{f["cpp_type"]}>());
+    nlohmann_json_t.tag_ = Tag::{name.upper()};
+    return;
+  }}"""
+                for idx, (name, f) in enumerate(cpp_fields.items())
+            ]
+        )
+        cpp_class_defs[
+            name
+        ] = f"""
+class {name} {{
+  struct Void {{}};
+
+ public:
+  enum class Tag {{
+    {", ".join([name.upper() for name in cpp_fields])}
+  }};
+
+ private:
+  std::variant<Void, {", ".join(f["cpp_type"] for f in cpp_fields.values())}> variant_;
+  Tag tag_;
+
+ public:
+  Tag tag() const {{
+    return tag_;
+  }}
+{"".join([accessor(name, f["cpp_type"], idx) for idx, (name, f) in enumerate(cpp_fields.items())])}
+  friend void to_json(nlohmann::json& nlohmann_json_j, const {name}& nlohmann_json_t) {{
+    {to_json_branches}
+  }}
+  friend void from_json(const nlohmann::json& nlohmann_json_j, {name}& nlohmann_json_t) {{
+    {from_json_branches}
+  }}
+}};
+"""
+        cpp_type_decls.append(f"class {name};")

     for name in dir(schema):
         if name.startswith("_"):
@ -97,11 +264,13 @@ def _staged_schema():
             defs[name] = value

+    class_ordering = {}
     for name, value in defs.items():
         if isinstance(value, type):
             if issubclass(value, IntEnum):
                 _handle_int_enum(name, value)
             elif dataclasses.is_dataclass(value):
+                class_ordering[name] = inspect.findsource(value)[1]
                 if issubclass(value, _Union):
                     _handle_union(name, value)
                 else:
@ -113,11 +282,103 @@ def _staged_schema():
         else:
             raise AssertionError(f"Unknown variable {name}: {value}")

-    ret["SCHEMA_VERSION"] = list(defs["SCHEMA_VERSION"])
-    assert all(x > 0 for x in ret["SCHEMA_VERSION"])
-    ret["TREESPEC_VERSION"] = defs["TREESPEC_VERSION"]
-    assert ret["TREESPEC_VERSION"] > 0
-    return ret
+    yaml_ret["SCHEMA_VERSION"] = list(defs["SCHEMA_VERSION"])
+    assert all(x > 0 for x in yaml_ret["SCHEMA_VERSION"])
+    yaml_ret["TREESPEC_VERSION"] = defs["TREESPEC_VERSION"]
+    assert yaml_ret["TREESPEC_VERSION"] > 0
+
+    cpp_header = f"""
+#pragma once
+
+#include <optional>
+#include <string>
+#include <unordered_map>
+#include <variant>
+#include <vector>
+
+#include <nlohmann/json.hpp>
+
+#ifndef NLOHMANN_JSON_NAMESPACE_BEGIN
+#define NLOHMANN_JSON_NAMESPACE_BEGIN namespace nlohmann {{
+#endif
+
+#ifndef NLOHMANN_JSON_NAMESPACE_END
+#define NLOHMANN_JSON_NAMESPACE_END }}
+#endif
+
+// https://github.com/nlohmann/json/pull/2117
+NLOHMANN_JSON_NAMESPACE_BEGIN
+template <typename T>
+struct adl_serializer<std::optional<T>> {{
+  static void to_json(json& j, const std::optional<T>& opt) {{
+    if (opt == std::nullopt) {{
+      j = nullptr;
+    }} else {{
+      j = *opt; // this will call adl_serializer<T>::to_json which will
+                // find the free function to_json in T's namespace!
+    }}
+  }}
+
+  static void from_json(const json& j, std::optional<T>& opt) {{
+    if (j.is_null()) {{
+      opt = std::nullopt;
+    }} else {{
+      opt = j.template get<T>(); // same as above, but with
+                                 // adl_serializer<T>::from_json
+    }}
+  }}
+}};
+NLOHMANN_JSON_NAMESPACE_END
+
+namespace torch {{
+namespace _export {{
+
+template <typename T>
+class ForwardRef {{
+  static_assert(!std::is_reference_v<T>, "ForwardRef cannot be a reference type");
+
+ public:
+  ForwardRef(): ptr_(std::make_unique<T>()) {{}}
+  ForwardRef(ForwardRef<T>&&) = default;
+  ForwardRef(const ForwardRef<T>& other): ptr_(std::make_unique<T>(*other.ptr_)) {{}}
+  ForwardRef<T>& operator=(ForwardRef<T>&&) = default;
+  ForwardRef<T>& operator=(const ForwardRef<T>& other) {{
+    ptr_ = std::make_unique<T>(*other.ptr_);
+  }}
+  const T& operator*() const {{
+    return *ptr_;
+  }}
+  const T* operator->() const {{
+    return ptr_.get();
+  }}
+  void emplace(T&& t) {{
+    ptr_ = std::make_unique<T>(std::move(t));
+  }}
+
+ private:
+  std::unique_ptr<T> ptr_;
+}};
+
+template <typename T>
+void to_json(nlohmann::json& j, const ForwardRef<T>& p) {{
+  j = *p;
+}}
+
+template <typename T>
+void from_json(const nlohmann::json& j, ForwardRef<T>& p) {{
+  p.emplace(j.template get<T>());
+}}
+
+{chr(10).join(cpp_type_decls)}
+{"".join(cpp_enum_defs.values())}
+{"".join(dict(sorted(cpp_class_defs.items(), key=lambda x: class_ordering[x[0]])).values())}
+{chr(10).join(cpp_json_defs)}
+
+}} // namespace _export
+}} // namespace torch
+"""
+    return yaml_ret, cpp_header
def _diff_schema(dst, src):
@ -193,11 +454,13 @@ def _hash_schema(s):
 class _Commit:
     result: Dict[str, Any]
     checksum_result: str
-    path: str
+    yaml_path: str
     additions: Dict[str, Any]
     subtractions: Dict[str, Any]
     base: Dict[str, Any]
     checksum_base: Optional[str]
+    cpp_header: str
+    cpp_header_path: str
def update_schema():
@ -217,16 +480,22 @@ def update_schema():
         checksum_base = None
         dst = {"SCHEMA_VERSION": None, "TREESPEC_VERSION": None}

-    src = _staged_schema()
+    src, cpp_header = _staged_schema()
     additions, subtractions = _diff_schema(dst, src)
+    yaml_path = __package__.replace(".", "/") + "/schema.yaml"
+    torch_prefix = "torch/"
+    assert yaml_path.startswith(torch_prefix)  # sanity check
     return _Commit(
         result=src,
         checksum_result=_hash_schema(src),
-        path=__package__.replace(".", "/") + "/schema.yaml",
+        yaml_path=yaml_path,
         additions=additions,
         subtractions=subtractions,
         base=dst,
         checksum_base=checksum_base,
+        cpp_header=cpp_header,
+        cpp_header_path=torch_prefix + "csrc/utils/generated_serialization_types.h",
     )
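
To make the generated API concrete, here is a hedged sketch of using one of the emitted union classes, Argument from the export schema (the JSON literal and function name are illustrative; the Tag names come from name.upper() and the accessors from the get_<field>() templates above):

    #include <nlohmann/json.hpp>
    #include <torch/csrc/utils/generated_serialization_types.h>

    void inspect_argument() {
      using torch::_export::Argument;
      // A union is encoded as an object keyed by exactly one field name;
      // the generated from_json() picks the branch and sets the tag.
      Argument arg = nlohmann::json::parse(R"({"as_int": 7})").get<Argument>();
      if (arg.tag() == Argument::Tag::AS_INT) {
        int64_t v = arg.get_as_int(); // std::get on the underlying variant
        (void)v;
      }
      // Struct classes are analogous: private fields plus get_<field>()
      // accessors, with enum-typed fields stored as int64_t and cast back.
    }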

torch/_export/serde/serialize.py

@ -697,7 +697,7 @@ class GraphModuleSerializer(metaclass=Final):
         return inputs

     def is_sym_int_arg(self, arg) -> bool:
-        return isinstance(arg, int) or (
+        return type(arg) is int or (
             isinstance(arg, torch.fx.Node)
             and arg.name in self.graph_state.sym_int_values
         )
@ -770,13 +770,13 @@ class GraphModuleSerializer(metaclass=Final):
             # For regular FX graph, SymInt arg should be a fx.Node with
             # self.is_sym_int_arg(arg) being true
             return Argument.create(as_sym_int=SymIntArgument.create(as_name=str(arg)))
-        elif isinstance(arg, bool):
+        elif type(arg) is bool:
             return Argument.create(as_bool=arg)
-        elif isinstance(arg, str):
+        elif type(arg) is str:
             return Argument.create(as_string=arg)
-        elif isinstance(arg, int):
+        elif type(arg) is int:
             return Argument.create(as_int=arg)
-        elif isinstance(arg, float):
+        elif type(arg) is float:
             return Argument.create(as_float=arg)
         elif arg is None:
             return Argument.create(as_none=())
@ -814,14 +814,13 @@ class GraphModuleSerializer(metaclass=Final):
             )
             return Argument.create(as_tensors=[])

         # Must check bool first, as bool is also treated as int
-        if all(isinstance(a, bool) for a in arg):
+        if all(type(a) is bool for a in arg):
             return Argument.create(as_bools=list(arg))
-        elif all(isinstance(a, int) for a in arg):
+        elif all(type(a) is int for a in arg):
             return Argument.create(as_ints=list(arg))
-        elif all(isinstance(a, float) for a in arg):
+        elif all(type(a) is float for a in arg):
             return Argument.create(as_floats=list(arg))
-        elif all(isinstance(a, str) for a in arg):
+        elif all(type(a) is str for a in arg):
             return Argument.create(as_strings=list(arg))
         elif all(isinstance(a, torch.SymInt) for a in arg):
             # This is a special branch for handling SymInt args in inductor's
@ -837,7 +836,7 @@ class GraphModuleSerializer(metaclass=Final):
             for a in arg:
                 if isinstance(a, torch.fx.Node):
                     values.append(SymIntArgument.create(as_name=a.name))
-                elif isinstance(a, int):
+                elif type(a) is int:
                     values.append(SymIntArgument.create(as_int=a))
             return Argument.create(as_sym_ints=values)
         elif all(self.is_sym_bool_arg(a) for a in arg):
@ -952,13 +951,13 @@ class GraphModuleSerializer(metaclass=Final):
     def serialize_input_spec(self, spec: ep.InputSpec) -> InputSpec:
         if spec.kind == ep.InputKind.USER_INPUT:
             if isinstance(spec.arg, ep.ConstantArgument):
-                if isinstance(spec.arg.value, int):
+                if type(spec.arg.value) is int:
                     constant_spec = ConstantValue.create(as_int=spec.arg.value)
-                elif isinstance(spec.arg.value, bool):
+                elif type(spec.arg.value) is bool:
                     constant_spec = ConstantValue.create(as_bool=spec.arg.value)
-                elif isinstance(spec.arg.value, str):
+                elif type(spec.arg.value) is str:
                     constant_spec = ConstantValue.create(as_string=spec.arg.value)
-                elif isinstance(spec.arg.value, float):
+                elif type(spec.arg.value) is float:
                     constant_spec = ConstantValue.create(as_float=spec.arg.value)
                 elif spec.arg.value is None:
                     constant_spec = ConstantValue.create(as_none=())
@ -1548,7 +1547,7 @@ class GraphModuleDeserializer(metaclass=Final):
             return self.shape_env.create_symintnode(sym, hint=hint)
         elif s.type == "as_int":
-            assert isinstance(val, int)
+            assert type(val) is int
             return val
         else:
             raise SerializeError(
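
A note on the isinstance(...) → type(...) is ... changes in this file: bool is a subclass of int in Python, so isinstance(True, int) is true, while JSON and the generated C++ classes keep booleans and integers distinct. A minimal nlohmann sketch of the distinction the stricter checks preserve (function name is illustrative):

    #include <cassert>
    #include <nlohmann/json.hpp>

    void type_strictness() {
      nlohmann::json jb = true; // what a Python bool serializes to
      nlohmann::json ji = 1;    // what a Python int serializes to
      // The two JSON types never conflate, so the serializer must not either.
      assert(jb.is_boolean() && !jb.is_number_integer());
      assert(ji.is_number_integer() && !ji.is_boolean());
    }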

torch/csrc/Module.cpp

@ -68,6 +68,7 @@
 #include <torch/csrc/autograd/python_variable.h>
 #include <torch/csrc/cpu/Module.h>
 #include <torch/csrc/dynamo/init.h>
+#include <torch/csrc/export/pybind.h>
 #include <torch/csrc/functorch/init.h>
 #include <torch/csrc/fx/node.h>
 #include <torch/csrc/inductor/aoti_package/pybind.h>
@ -1773,6 +1774,7 @@ PyObject* initModule() {
   torch::profiler::initPythonBindings(module);
   torch::python::init_bindings(module);
   torch::lazy::initLazyBindings(module);
+  torch::_export::initExportBindings(module);
   torch::inductor::initAOTIRunnerBindings(module);
   torch::inductor::initAOTIPackageBindings(module);
 #ifdef USE_ITT

torch/csrc/export/pybind.cpp (new file)

@ -0,0 +1,20 @@
#include <torch/csrc/utils/generated_serialization_types.h>
#include <torch/csrc/utils/pybind.h>

namespace torch::_export {

void initExportBindings(PyObject* module) {
  auto rootModule = py::handle(module).cast<py::module>();
  auto m = rootModule.def_submodule("_export");

  py::class_<ExportedProgram>(m, "CppExportedProgram");

  m.def("deserialize_exported_program", [](const std::string& serialized) {
    return nlohmann::json::parse(serialized).get<ExportedProgram>();
  });

  m.def("serialize_exported_program", [](const ExportedProgram& ep) {
    return nlohmann::json(ep).dump();
  });
}

} // namespace torch::_export
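
One behavioral note on these bindings (a hedged sketch; the binding above does no explicit error handling): nlohmann::json::parse throws nlohmann::json::parse_error on malformed input, so callers of deserialize_exported_program see an exception rather than a partial program. The throwing behavior can be opted out of if needed:

    #include <nlohmann/json.hpp>
    #include <string>

    bool is_valid_json(const std::string& s) {
      // With allow_exceptions=false, parse() returns a discarded value
      // instead of throwing on malformed input.
      nlohmann::json j =
          nlohmann::json::parse(s, /*cb=*/nullptr, /*allow_exceptions=*/false);
      return !j.is_discarded();
    }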

torch/csrc/export/pybind.h (new file)

@ -0,0 +1,7 @@
#include <torch/csrc/python_headers.h>

namespace torch::_export {

void initExportBindings(PyObject* module);

} // namespace torch::_export

torch/csrc/utils/generated_serialization_types.h: diff suppressed because it is too large (generated file).