mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[Export] Introduce class_fqn into CustomObjArgument (#118158)
Summary: Class FQN is needed when unpacking CustomObj instance. For all other Arguments, e.g. Tensor, TensorList, SymInt, we always know their exact type. However, CustomObjArgument had an opaque type. Adding this field also helps unveiling the type of this opaque object. Test Plan: CI Differential Revision: D53029847 Pull Request resolved: https://github.com/pytorch/pytorch/pull/118158 Approved by: https://github.com/zhxchen17
This commit is contained in:
committed by
PyTorch MergeBot
parent
fed0f2946f
commit
bb3db079b1
@ -754,6 +754,11 @@ class TestSerializeCustomClass(TestCase):
|
||||
node.args = (arg0, custom_node)
|
||||
|
||||
serialized_vals = serialize(ep)
|
||||
|
||||
ep_str = serialized_vals.exported_program.decode("utf-8")
|
||||
assert "class_fqn" in ep_str
|
||||
assert custom_obj._type().qualified_name() in ep_str
|
||||
|
||||
deserialized_ep = deserialize(serialized_vals)
|
||||
|
||||
for node in deserialized_ep.graph.nodes:
|
||||
@ -763,6 +768,7 @@ class TestSerializeCustomClass(TestCase):
|
||||
):
|
||||
arg = node.args[0]
|
||||
self.assertTrue(isinstance(arg, torch._C.ScriptObject))
|
||||
self.assertEqual(arg._type(), custom_obj._type())
|
||||
self.assertEqual(arg.__getstate__(), custom_obj.__getstate__())
|
||||
self.assertEqual(arg.top(), 7)
|
||||
|
||||
|
@ -5,6 +5,7 @@ from torch._export.verifier import SpecViolationError
|
||||
from torch._guards import detect_fake_mode
|
||||
from torch.export.custom_obj import ScriptObjectMeta
|
||||
from torch.export.exported_program import (
|
||||
ArgumentSpec,
|
||||
CustomObjArgument,
|
||||
ExportGraphSignature,
|
||||
InputKind,
|
||||
@ -54,12 +55,10 @@ def lift_constants_pass(
|
||||
if isinstance(constant_val, torch.ScriptObject):
|
||||
constant_name = f"_lifted_custom_obj{num_custom_obj}"
|
||||
constant_kind = InputKind.CUSTOM_OBJ
|
||||
constant_arg_cls = CustomObjArgument # type: ignore[assignment]
|
||||
num_custom_obj += 1
|
||||
elif isinstance(constant_val, torch.Tensor):
|
||||
constant_name = f"_lifted_tensor_constant{num_tensor_constants}"
|
||||
constant_kind = InputKind.CONSTANT_TENSOR
|
||||
constant_arg_cls = TensorArgument # type: ignore[assignment]
|
||||
num_tensor_constants += 1
|
||||
elif isinstance(constant_val, torch.fx.GraphModule):
|
||||
continue
|
||||
@ -88,15 +87,25 @@ def lift_constants_pass(
|
||||
else:
|
||||
constant_fqn = constant_name
|
||||
|
||||
input_spec_arg: ArgumentSpec
|
||||
if isinstance(constant_val, torch.Tensor):
|
||||
const_placeholder_node.meta["val"] = fake_mode.from_tensor(
|
||||
constant_val, static_shapes=True
|
||||
)
|
||||
const_placeholder_node.meta["val"].constant = constant_val
|
||||
input_spec_arg = TensorArgument(name=const_placeholder_node.name)
|
||||
elif isinstance(constant_val, torch._C.ScriptObject):
|
||||
const_placeholder_node.meta["val"] = ScriptObjectMeta(constant_fqn)
|
||||
class_fqn = constant_val._type().qualified_name() # type: ignore[attr-defined]
|
||||
const_placeholder_node.meta["val"] = ScriptObjectMeta(
|
||||
constant_fqn, class_fqn
|
||||
)
|
||||
input_spec_arg = CustomObjArgument(
|
||||
name=const_placeholder_node.name, class_fqn=class_fqn
|
||||
)
|
||||
else:
|
||||
const_placeholder_node.meta["val"] = constant_val
|
||||
# TODO: use of TensorArgument doesn't look right, what's this branch for?
|
||||
input_spec_arg = TensorArgument(name=const_placeholder_node.name)
|
||||
|
||||
node.replace_all_uses_with(const_placeholder_node)
|
||||
gm.graph.erase_node(node)
|
||||
@ -106,7 +115,7 @@ def lift_constants_pass(
|
||||
first_user_input_loc,
|
||||
InputSpec(
|
||||
kind=constant_kind,
|
||||
arg=constant_arg_cls(name=const_placeholder_node.name),
|
||||
arg=input_spec_arg,
|
||||
target=constant_fqn,
|
||||
),
|
||||
)
|
||||
@ -133,7 +142,8 @@ def rewrite_script_object_meta(
|
||||
continue
|
||||
|
||||
old_meta = node.meta["val"]
|
||||
new_meta = ScriptObjectMeta(node.name)
|
||||
class_fqn = old_meta._type().qualified_name() # type: ignore[attr-defined]
|
||||
new_meta = ScriptObjectMeta(node.name, class_fqn)
|
||||
constants[node.name] = old_meta
|
||||
node.meta["val"] = new_meta
|
||||
|
||||
|
@ -96,7 +96,8 @@ class TensorMeta:
|
||||
|
||||
@dataclass
|
||||
class ScriptObjectMeta:
|
||||
constant_name: Optional[str]
|
||||
constant_name: str
|
||||
class_fqn: str
|
||||
|
||||
|
||||
# In most cases we will use the "as_name" field to store arguments which are
|
||||
@ -147,6 +148,7 @@ class GraphArgument:
|
||||
@dataclass
|
||||
class CustomObjArgument:
|
||||
name: str
|
||||
class_fqn: str
|
||||
|
||||
|
||||
# This is actually a union type
|
||||
|
@ -347,7 +347,8 @@ class GraphModuleSerializer:
|
||||
elif isinstance(node.meta['val'], (int, bool, str, float, type(None))):
|
||||
graph_input = self.serialize_input(node.meta['val'])
|
||||
elif isinstance(node.meta['val'], export_ScriptObjectMeta):
|
||||
graph_input = Argument.create(as_custom_obj=CustomObjArgument(name=node.name))
|
||||
class_fqn = node.meta["val"].class_fqn
|
||||
graph_input = Argument.create(as_custom_obj=CustomObjArgument(name=node.name, class_fqn=class_fqn))
|
||||
self.graph_state.script_object_metas[node.name] = self.serialize_script_obj_meta(node.meta["val"])
|
||||
else:
|
||||
raise AssertionError(f"Unimplemented graph input type: {node.meta['val']}")
|
||||
@ -487,7 +488,8 @@ class GraphModuleSerializer:
|
||||
|
||||
def serialize_script_obj_meta(self, script_obj_meta: export_ScriptObjectMeta) -> ScriptObjectMeta:
|
||||
return ScriptObjectMeta(
|
||||
constant_name=script_obj_meta.constant_name
|
||||
constant_name=script_obj_meta.constant_name,
|
||||
class_fqn=script_obj_meta.class_fqn,
|
||||
)
|
||||
|
||||
def serialize_sym_op_inputs(self, op, args) -> List[NamedArgument]:
|
||||
@ -564,7 +566,7 @@ class GraphModuleSerializer:
|
||||
return Argument.create(as_sym_bool=SymBoolArgument.create(as_name=arg.name))
|
||||
else:
|
||||
if isinstance(arg.meta["val"], export_ScriptObjectMeta):
|
||||
return Argument.create(as_custom_obj=CustomObjArgument(name=arg.name))
|
||||
return Argument.create(as_custom_obj=CustomObjArgument(name=arg.name, class_fqn=arg.meta["val"].class_fqn))
|
||||
return Argument.create(as_tensor=TensorArgument(name=arg.name))
|
||||
elif isinstance(arg, inductor_tensor_buffers):
|
||||
# Other branches are for arguments in fx node.
|
||||
@ -688,7 +690,8 @@ class GraphModuleSerializer:
|
||||
# serialize/deserialize function.
|
||||
custom_obj_name = f"_custom_obj_{len(self.custom_objs)}"
|
||||
self.custom_objs[custom_obj_name] = arg
|
||||
return Argument.create(as_custom_obj=CustomObjArgument(custom_obj_name))
|
||||
class_fqn = arg._type().qualified_name() # type: ignore[attr-defined]
|
||||
return Argument.create(as_custom_obj=CustomObjArgument(custom_obj_name, class_fqn))
|
||||
else:
|
||||
raise SerializeError(f"Unsupported argument type: {type(arg)}")
|
||||
|
||||
@ -746,7 +749,7 @@ class GraphModuleSerializer:
|
||||
assert isinstance(spec.arg, ep.CustomObjArgument)
|
||||
return InputSpec.create(
|
||||
custom_obj=InputToCustomObjSpec(
|
||||
arg=CustomObjArgument(name=spec.arg.name),
|
||||
arg=CustomObjArgument(name=spec.arg.name, class_fqn=spec.arg.class_fqn),
|
||||
custom_obj_name=spec.target,
|
||||
)
|
||||
)
|
||||
@ -811,7 +814,7 @@ class GraphModuleSerializer:
|
||||
elif isinstance(x, ep.ConstantArgument):
|
||||
return self.serialize_input(x.value)
|
||||
elif isinstance(x, ep.CustomObjArgument):
|
||||
return Argument.create(as_custom_obj=CustomObjArgument(name=x.name))
|
||||
return Argument.create(as_custom_obj=CustomObjArgument(name=x.name, class_fqn=x.class_fqn))
|
||||
else:
|
||||
raise AssertionError("TODO")
|
||||
|
||||
@ -1154,7 +1157,8 @@ class GraphModuleDeserializer:
|
||||
|
||||
def deserialize_script_obj_meta(self, script_obj_meta: ScriptObjectMeta) -> export_ScriptObjectMeta:
|
||||
return export_ScriptObjectMeta(
|
||||
constant_name=script_obj_meta.constant_name
|
||||
constant_name=script_obj_meta.constant_name,
|
||||
class_fqn=script_obj_meta.class_fqn,
|
||||
)
|
||||
|
||||
def deserialize_graph_output(self, output) -> torch.fx.Node:
|
||||
@ -1296,10 +1300,10 @@ class GraphModuleDeserializer:
|
||||
arg=ep.TensorArgument(name=i.tensor_constant.arg.name),
|
||||
target=i.tensor_constant.tensor_constant_name,
|
||||
)
|
||||
elif i.custom_obj is not None:
|
||||
elif i.type == "custom_obj":
|
||||
return ep.InputSpec(
|
||||
kind=ep.InputKind.CUSTOM_OBJ,
|
||||
arg=ep.CustomObjArgument(name=i.custom_obj.arg.name),
|
||||
arg=ep.CustomObjArgument(name=i.custom_obj.arg.name, class_fqn=i.custom_obj.arg.class_fqn),
|
||||
target=i.custom_obj.custom_obj_name,
|
||||
)
|
||||
else:
|
||||
|
@ -1,5 +1,4 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
|
||||
__all__ = ["ScriptObjectMeta"]
|
||||
@ -12,4 +11,6 @@ class ScriptObjectMeta:
|
||||
"""
|
||||
|
||||
# Key into constants table to retrieve the real ScriptObject.
|
||||
constant_name: Optional[str]
|
||||
constant_name: str
|
||||
|
||||
class_fqn: str
|
||||
|
@ -516,7 +516,9 @@ class ExportedProgram:
|
||||
old_input_spec = old_signature.input_specs[i]
|
||||
arg = (
|
||||
old_input_spec.arg
|
||||
if isinstance(old_input_spec.arg, ConstantArgument)
|
||||
if isinstance(
|
||||
old_input_spec.arg, (ConstantArgument, CustomObjArgument)
|
||||
)
|
||||
else type(old_input_spec.arg)(node.name)
|
||||
)
|
||||
new_input_specs.append(
|
||||
@ -534,7 +536,9 @@ class ExportedProgram:
|
||||
old_output_spec = old_signature.output_specs[i]
|
||||
arg = (
|
||||
old_output_spec.arg
|
||||
if isinstance(old_output_spec.arg, ConstantArgument)
|
||||
if isinstance(
|
||||
old_output_spec.arg, (ConstantArgument, CustomObjArgument)
|
||||
)
|
||||
else type(old_output_spec.arg)(node.name)
|
||||
)
|
||||
new_output_specs.append(
|
||||
|
@ -31,6 +31,7 @@ class SymIntArgument:
|
||||
@dataclasses.dataclass
|
||||
class CustomObjArgument:
|
||||
name: str
|
||||
class_fqn: str
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
|
Reference in New Issue
Block a user