From bb3db079b17ca6b18bfaccd634dc8009f2a31cc3 Mon Sep 17 00:00:00 2001 From: Sherlock Huang Date: Thu, 25 Jan 2024 18:44:25 +0000 Subject: [PATCH] [Export] Introduce class_fqn into CustomObjArgument (#118158) Summary: Class FQN is needed when unpacking CustomObj instance. For all other Arguments, e.g. Tensor, TensorList, SymInt, we always know their exact type. However, CustomObjArgument had an opaque type. Adding this field also helps unveiling the type of this opaque object. Test Plan: CI Differential Revision: D53029847 Pull Request resolved: https://github.com/pytorch/pytorch/pull/118158 Approved by: https://github.com/zhxchen17 --- test/export/test_serialize.py | 6 ++++++ torch/_export/passes/lift_constants_pass.py | 20 ++++++++++++++----- torch/_export/serde/schema.py | 4 +++- torch/_export/serde/serialize.py | 22 ++++++++++++--------- torch/export/custom_obj.py | 5 +++-- torch/export/exported_program.py | 8 ++++++-- torch/export/graph_signature.py | 1 + 7 files changed, 47 insertions(+), 19 deletions(-) diff --git a/test/export/test_serialize.py b/test/export/test_serialize.py index 2c06049872ec..e5aec6b410ea 100644 --- a/test/export/test_serialize.py +++ b/test/export/test_serialize.py @@ -754,6 +754,11 @@ class TestSerializeCustomClass(TestCase): node.args = (arg0, custom_node) serialized_vals = serialize(ep) + + ep_str = serialized_vals.exported_program.decode("utf-8") + assert "class_fqn" in ep_str + assert custom_obj._type().qualified_name() in ep_str + deserialized_ep = deserialize(serialized_vals) for node in deserialized_ep.graph.nodes: @@ -763,6 +768,7 @@ class TestSerializeCustomClass(TestCase): ): arg = node.args[0] self.assertTrue(isinstance(arg, torch._C.ScriptObject)) + self.assertEqual(arg._type(), custom_obj._type()) self.assertEqual(arg.__getstate__(), custom_obj.__getstate__()) self.assertEqual(arg.top(), 7) diff --git a/torch/_export/passes/lift_constants_pass.py b/torch/_export/passes/lift_constants_pass.py index 42aa1af79709..380e17917abc 100644 --- a/torch/_export/passes/lift_constants_pass.py +++ b/torch/_export/passes/lift_constants_pass.py @@ -5,6 +5,7 @@ from torch._export.verifier import SpecViolationError from torch._guards import detect_fake_mode from torch.export.custom_obj import ScriptObjectMeta from torch.export.exported_program import ( + ArgumentSpec, CustomObjArgument, ExportGraphSignature, InputKind, @@ -54,12 +55,10 @@ def lift_constants_pass( if isinstance(constant_val, torch.ScriptObject): constant_name = f"_lifted_custom_obj{num_custom_obj}" constant_kind = InputKind.CUSTOM_OBJ - constant_arg_cls = CustomObjArgument # type: ignore[assignment] num_custom_obj += 1 elif isinstance(constant_val, torch.Tensor): constant_name = f"_lifted_tensor_constant{num_tensor_constants}" constant_kind = InputKind.CONSTANT_TENSOR - constant_arg_cls = TensorArgument # type: ignore[assignment] num_tensor_constants += 1 elif isinstance(constant_val, torch.fx.GraphModule): continue @@ -88,15 +87,25 @@ def lift_constants_pass( else: constant_fqn = constant_name + input_spec_arg: ArgumentSpec if isinstance(constant_val, torch.Tensor): const_placeholder_node.meta["val"] = fake_mode.from_tensor( constant_val, static_shapes=True ) const_placeholder_node.meta["val"].constant = constant_val + input_spec_arg = TensorArgument(name=const_placeholder_node.name) elif isinstance(constant_val, torch._C.ScriptObject): - const_placeholder_node.meta["val"] = ScriptObjectMeta(constant_fqn) + class_fqn = constant_val._type().qualified_name() # type: ignore[attr-defined] + const_placeholder_node.meta["val"] = ScriptObjectMeta( + constant_fqn, class_fqn + ) + input_spec_arg = CustomObjArgument( + name=const_placeholder_node.name, class_fqn=class_fqn + ) else: const_placeholder_node.meta["val"] = constant_val + # TODO: use of TensorArgument doesn't look right, what's this branch for? + input_spec_arg = TensorArgument(name=const_placeholder_node.name) node.replace_all_uses_with(const_placeholder_node) gm.graph.erase_node(node) @@ -106,7 +115,7 @@ def lift_constants_pass( first_user_input_loc, InputSpec( kind=constant_kind, - arg=constant_arg_cls(name=const_placeholder_node.name), + arg=input_spec_arg, target=constant_fqn, ), ) @@ -133,7 +142,8 @@ def rewrite_script_object_meta( continue old_meta = node.meta["val"] - new_meta = ScriptObjectMeta(node.name) + class_fqn = old_meta._type().qualified_name() # type: ignore[attr-defined] + new_meta = ScriptObjectMeta(node.name, class_fqn) constants[node.name] = old_meta node.meta["val"] = new_meta diff --git a/torch/_export/serde/schema.py b/torch/_export/serde/schema.py index fd8fd26adfdf..2b42fa83485b 100644 --- a/torch/_export/serde/schema.py +++ b/torch/_export/serde/schema.py @@ -96,7 +96,8 @@ class TensorMeta: @dataclass class ScriptObjectMeta: - constant_name: Optional[str] + constant_name: str + class_fqn: str # In most cases we will use the "as_name" field to store arguments which are @@ -147,6 +148,7 @@ class GraphArgument: @dataclass class CustomObjArgument: name: str + class_fqn: str # This is actually a union type diff --git a/torch/_export/serde/serialize.py b/torch/_export/serde/serialize.py index 5ecce5beea58..9b189f6ac4f2 100644 --- a/torch/_export/serde/serialize.py +++ b/torch/_export/serde/serialize.py @@ -347,7 +347,8 @@ class GraphModuleSerializer: elif isinstance(node.meta['val'], (int, bool, str, float, type(None))): graph_input = self.serialize_input(node.meta['val']) elif isinstance(node.meta['val'], export_ScriptObjectMeta): - graph_input = Argument.create(as_custom_obj=CustomObjArgument(name=node.name)) + class_fqn = node.meta["val"].class_fqn + graph_input = Argument.create(as_custom_obj=CustomObjArgument(name=node.name, class_fqn=class_fqn)) self.graph_state.script_object_metas[node.name] = self.serialize_script_obj_meta(node.meta["val"]) else: raise AssertionError(f"Unimplemented graph input type: {node.meta['val']}") @@ -487,7 +488,8 @@ class GraphModuleSerializer: def serialize_script_obj_meta(self, script_obj_meta: export_ScriptObjectMeta) -> ScriptObjectMeta: return ScriptObjectMeta( - constant_name=script_obj_meta.constant_name + constant_name=script_obj_meta.constant_name, + class_fqn=script_obj_meta.class_fqn, ) def serialize_sym_op_inputs(self, op, args) -> List[NamedArgument]: @@ -564,7 +566,7 @@ class GraphModuleSerializer: return Argument.create(as_sym_bool=SymBoolArgument.create(as_name=arg.name)) else: if isinstance(arg.meta["val"], export_ScriptObjectMeta): - return Argument.create(as_custom_obj=CustomObjArgument(name=arg.name)) + return Argument.create(as_custom_obj=CustomObjArgument(name=arg.name, class_fqn=arg.meta["val"].class_fqn)) return Argument.create(as_tensor=TensorArgument(name=arg.name)) elif isinstance(arg, inductor_tensor_buffers): # Other branches are for arguments in fx node. @@ -688,7 +690,8 @@ class GraphModuleSerializer: # serialize/deserialize function. custom_obj_name = f"_custom_obj_{len(self.custom_objs)}" self.custom_objs[custom_obj_name] = arg - return Argument.create(as_custom_obj=CustomObjArgument(custom_obj_name)) + class_fqn = arg._type().qualified_name() # type: ignore[attr-defined] + return Argument.create(as_custom_obj=CustomObjArgument(custom_obj_name, class_fqn)) else: raise SerializeError(f"Unsupported argument type: {type(arg)}") @@ -746,7 +749,7 @@ class GraphModuleSerializer: assert isinstance(spec.arg, ep.CustomObjArgument) return InputSpec.create( custom_obj=InputToCustomObjSpec( - arg=CustomObjArgument(name=spec.arg.name), + arg=CustomObjArgument(name=spec.arg.name, class_fqn=spec.arg.class_fqn), custom_obj_name=spec.target, ) ) @@ -811,7 +814,7 @@ class GraphModuleSerializer: elif isinstance(x, ep.ConstantArgument): return self.serialize_input(x.value) elif isinstance(x, ep.CustomObjArgument): - return Argument.create(as_custom_obj=CustomObjArgument(name=x.name)) + return Argument.create(as_custom_obj=CustomObjArgument(name=x.name, class_fqn=x.class_fqn)) else: raise AssertionError("TODO") @@ -1154,7 +1157,8 @@ class GraphModuleDeserializer: def deserialize_script_obj_meta(self, script_obj_meta: ScriptObjectMeta) -> export_ScriptObjectMeta: return export_ScriptObjectMeta( - constant_name=script_obj_meta.constant_name + constant_name=script_obj_meta.constant_name, + class_fqn=script_obj_meta.class_fqn, ) def deserialize_graph_output(self, output) -> torch.fx.Node: @@ -1296,10 +1300,10 @@ class GraphModuleDeserializer: arg=ep.TensorArgument(name=i.tensor_constant.arg.name), target=i.tensor_constant.tensor_constant_name, ) - elif i.custom_obj is not None: + elif i.type == "custom_obj": return ep.InputSpec( kind=ep.InputKind.CUSTOM_OBJ, - arg=ep.CustomObjArgument(name=i.custom_obj.arg.name), + arg=ep.CustomObjArgument(name=i.custom_obj.arg.name, class_fqn=i.custom_obj.arg.class_fqn), target=i.custom_obj.custom_obj_name, ) else: diff --git a/torch/export/custom_obj.py b/torch/export/custom_obj.py index d0a7aedbee15..8e7f2080a4ee 100644 --- a/torch/export/custom_obj.py +++ b/torch/export/custom_obj.py @@ -1,5 +1,4 @@ from dataclasses import dataclass -from typing import Optional __all__ = ["ScriptObjectMeta"] @@ -12,4 +11,6 @@ class ScriptObjectMeta: """ # Key into constants table to retrieve the real ScriptObject. - constant_name: Optional[str] + constant_name: str + + class_fqn: str diff --git a/torch/export/exported_program.py b/torch/export/exported_program.py index 7d490f6c2a53..30fd13c734fc 100644 --- a/torch/export/exported_program.py +++ b/torch/export/exported_program.py @@ -516,7 +516,9 @@ class ExportedProgram: old_input_spec = old_signature.input_specs[i] arg = ( old_input_spec.arg - if isinstance(old_input_spec.arg, ConstantArgument) + if isinstance( + old_input_spec.arg, (ConstantArgument, CustomObjArgument) + ) else type(old_input_spec.arg)(node.name) ) new_input_specs.append( @@ -534,7 +536,9 @@ class ExportedProgram: old_output_spec = old_signature.output_specs[i] arg = ( old_output_spec.arg - if isinstance(old_output_spec.arg, ConstantArgument) + if isinstance( + old_output_spec.arg, (ConstantArgument, CustomObjArgument) + ) else type(old_output_spec.arg)(node.name) ) new_output_specs.append( diff --git a/torch/export/graph_signature.py b/torch/export/graph_signature.py index cf781849ad93..4feccce20f21 100644 --- a/torch/export/graph_signature.py +++ b/torch/export/graph_signature.py @@ -31,6 +31,7 @@ class SymIntArgument: @dataclasses.dataclass class CustomObjArgument: name: str + class_fqn: str @dataclasses.dataclass