[Export] Introduce class_fqn into CustomObjArgument (#118158)

Summary:
The class FQN (fully qualified name) is needed when unpacking a CustomObj instance.
For all other Arguments, e.g. Tensor, TensorList, SymInt, we always know their exact type. However, CustomObjArgument had an opaque type.
Adding this field also helps unveil the type of this opaque object.

Test Plan: CI

Differential Revision: D53029847

Pull Request resolved: https://github.com/pytorch/pytorch/pull/118158
Approved by: https://github.com/zhxchen17
This commit is contained in:
Sherlock Huang
2024-01-25 18:44:25 +00:00
committed by PyTorch MergeBot
parent fed0f2946f
commit bb3db079b1
7 changed files with 47 additions and 19 deletions

View File

@ -754,6 +754,11 @@ class TestSerializeCustomClass(TestCase):
node.args = (arg0, custom_node)
serialized_vals = serialize(ep)
ep_str = serialized_vals.exported_program.decode("utf-8")
assert "class_fqn" in ep_str
assert custom_obj._type().qualified_name() in ep_str
deserialized_ep = deserialize(serialized_vals)
for node in deserialized_ep.graph.nodes:
@ -763,6 +768,7 @@ class TestSerializeCustomClass(TestCase):
):
arg = node.args[0]
self.assertTrue(isinstance(arg, torch._C.ScriptObject))
self.assertEqual(arg._type(), custom_obj._type())
self.assertEqual(arg.__getstate__(), custom_obj.__getstate__())
self.assertEqual(arg.top(), 7)

View File

@ -5,6 +5,7 @@ from torch._export.verifier import SpecViolationError
from torch._guards import detect_fake_mode
from torch.export.custom_obj import ScriptObjectMeta
from torch.export.exported_program import (
ArgumentSpec,
CustomObjArgument,
ExportGraphSignature,
InputKind,
@ -54,12 +55,10 @@ def lift_constants_pass(
if isinstance(constant_val, torch.ScriptObject):
constant_name = f"_lifted_custom_obj{num_custom_obj}"
constant_kind = InputKind.CUSTOM_OBJ
constant_arg_cls = CustomObjArgument # type: ignore[assignment]
num_custom_obj += 1
elif isinstance(constant_val, torch.Tensor):
constant_name = f"_lifted_tensor_constant{num_tensor_constants}"
constant_kind = InputKind.CONSTANT_TENSOR
constant_arg_cls = TensorArgument # type: ignore[assignment]
num_tensor_constants += 1
elif isinstance(constant_val, torch.fx.GraphModule):
continue
@ -88,15 +87,25 @@ def lift_constants_pass(
else:
constant_fqn = constant_name
input_spec_arg: ArgumentSpec
if isinstance(constant_val, torch.Tensor):
const_placeholder_node.meta["val"] = fake_mode.from_tensor(
constant_val, static_shapes=True
)
const_placeholder_node.meta["val"].constant = constant_val
input_spec_arg = TensorArgument(name=const_placeholder_node.name)
elif isinstance(constant_val, torch._C.ScriptObject):
const_placeholder_node.meta["val"] = ScriptObjectMeta(constant_fqn)
class_fqn = constant_val._type().qualified_name() # type: ignore[attr-defined]
const_placeholder_node.meta["val"] = ScriptObjectMeta(
constant_fqn, class_fqn
)
input_spec_arg = CustomObjArgument(
name=const_placeholder_node.name, class_fqn=class_fqn
)
else:
const_placeholder_node.meta["val"] = constant_val
# TODO: use of TensorArgument doesn't look right, what's this branch for?
input_spec_arg = TensorArgument(name=const_placeholder_node.name)
node.replace_all_uses_with(const_placeholder_node)
gm.graph.erase_node(node)
@ -106,7 +115,7 @@ def lift_constants_pass(
first_user_input_loc,
InputSpec(
kind=constant_kind,
arg=constant_arg_cls(name=const_placeholder_node.name),
arg=input_spec_arg,
target=constant_fqn,
),
)
@ -133,7 +142,8 @@ def rewrite_script_object_meta(
continue
old_meta = node.meta["val"]
new_meta = ScriptObjectMeta(node.name)
class_fqn = old_meta._type().qualified_name() # type: ignore[attr-defined]
new_meta = ScriptObjectMeta(node.name, class_fqn)
constants[node.name] = old_meta
node.meta["val"] = new_meta

View File

@ -96,7 +96,8 @@ class TensorMeta:
@dataclass
class ScriptObjectMeta:
constant_name: Optional[str]
constant_name: str
class_fqn: str
# In most cases we will use the "as_name" field to store arguments which are
@ -147,6 +148,7 @@ class GraphArgument:
@dataclass
class CustomObjArgument:
name: str
class_fqn: str
# This is actually a union type

View File

@ -347,7 +347,8 @@ class GraphModuleSerializer:
elif isinstance(node.meta['val'], (int, bool, str, float, type(None))):
graph_input = self.serialize_input(node.meta['val'])
elif isinstance(node.meta['val'], export_ScriptObjectMeta):
graph_input = Argument.create(as_custom_obj=CustomObjArgument(name=node.name))
class_fqn = node.meta["val"].class_fqn
graph_input = Argument.create(as_custom_obj=CustomObjArgument(name=node.name, class_fqn=class_fqn))
self.graph_state.script_object_metas[node.name] = self.serialize_script_obj_meta(node.meta["val"])
else:
raise AssertionError(f"Unimplemented graph input type: {node.meta['val']}")
@ -487,7 +488,8 @@ class GraphModuleSerializer:
def serialize_script_obj_meta(self, script_obj_meta: export_ScriptObjectMeta) -> ScriptObjectMeta:
return ScriptObjectMeta(
constant_name=script_obj_meta.constant_name
constant_name=script_obj_meta.constant_name,
class_fqn=script_obj_meta.class_fqn,
)
def serialize_sym_op_inputs(self, op, args) -> List[NamedArgument]:
@ -564,7 +566,7 @@ class GraphModuleSerializer:
return Argument.create(as_sym_bool=SymBoolArgument.create(as_name=arg.name))
else:
if isinstance(arg.meta["val"], export_ScriptObjectMeta):
return Argument.create(as_custom_obj=CustomObjArgument(name=arg.name))
return Argument.create(as_custom_obj=CustomObjArgument(name=arg.name, class_fqn=arg.meta["val"].class_fqn))
return Argument.create(as_tensor=TensorArgument(name=arg.name))
elif isinstance(arg, inductor_tensor_buffers):
# Other branches are for arguments in fx node.
@ -688,7 +690,8 @@ class GraphModuleSerializer:
# serialize/deserialize function.
custom_obj_name = f"_custom_obj_{len(self.custom_objs)}"
self.custom_objs[custom_obj_name] = arg
return Argument.create(as_custom_obj=CustomObjArgument(custom_obj_name))
class_fqn = arg._type().qualified_name() # type: ignore[attr-defined]
return Argument.create(as_custom_obj=CustomObjArgument(custom_obj_name, class_fqn))
else:
raise SerializeError(f"Unsupported argument type: {type(arg)}")
@ -746,7 +749,7 @@ class GraphModuleSerializer:
assert isinstance(spec.arg, ep.CustomObjArgument)
return InputSpec.create(
custom_obj=InputToCustomObjSpec(
arg=CustomObjArgument(name=spec.arg.name),
arg=CustomObjArgument(name=spec.arg.name, class_fqn=spec.arg.class_fqn),
custom_obj_name=spec.target,
)
)
@ -811,7 +814,7 @@ class GraphModuleSerializer:
elif isinstance(x, ep.ConstantArgument):
return self.serialize_input(x.value)
elif isinstance(x, ep.CustomObjArgument):
return Argument.create(as_custom_obj=CustomObjArgument(name=x.name))
return Argument.create(as_custom_obj=CustomObjArgument(name=x.name, class_fqn=x.class_fqn))
else:
raise AssertionError("TODO")
@ -1154,7 +1157,8 @@ class GraphModuleDeserializer:
def deserialize_script_obj_meta(self, script_obj_meta: ScriptObjectMeta) -> export_ScriptObjectMeta:
return export_ScriptObjectMeta(
constant_name=script_obj_meta.constant_name
constant_name=script_obj_meta.constant_name,
class_fqn=script_obj_meta.class_fqn,
)
def deserialize_graph_output(self, output) -> torch.fx.Node:
@ -1296,10 +1300,10 @@ class GraphModuleDeserializer:
arg=ep.TensorArgument(name=i.tensor_constant.arg.name),
target=i.tensor_constant.tensor_constant_name,
)
elif i.custom_obj is not None:
elif i.type == "custom_obj":
return ep.InputSpec(
kind=ep.InputKind.CUSTOM_OBJ,
arg=ep.CustomObjArgument(name=i.custom_obj.arg.name),
arg=ep.CustomObjArgument(name=i.custom_obj.arg.name, class_fqn=i.custom_obj.arg.class_fqn),
target=i.custom_obj.custom_obj_name,
)
else:

View File

@ -1,5 +1,4 @@
from dataclasses import dataclass
from typing import Optional
__all__ = ["ScriptObjectMeta"]
@ -12,4 +11,6 @@ class ScriptObjectMeta:
"""
# Key into constants table to retrieve the real ScriptObject.
constant_name: Optional[str]
constant_name: str
class_fqn: str

View File

@ -516,7 +516,9 @@ class ExportedProgram:
old_input_spec = old_signature.input_specs[i]
arg = (
old_input_spec.arg
if isinstance(old_input_spec.arg, ConstantArgument)
if isinstance(
old_input_spec.arg, (ConstantArgument, CustomObjArgument)
)
else type(old_input_spec.arg)(node.name)
)
new_input_specs.append(
@ -534,7 +536,9 @@ class ExportedProgram:
old_output_spec = old_signature.output_specs[i]
arg = (
old_output_spec.arg
if isinstance(old_output_spec.arg, ConstantArgument)
if isinstance(
old_output_spec.arg, (ConstantArgument, CustomObjArgument)
)
else type(old_output_spec.arg)(node.name)
)
new_output_specs.append(

View File

@ -31,6 +31,7 @@ class SymIntArgument:
@dataclasses.dataclass
class CustomObjArgument:
name: str
class_fqn: str
@dataclasses.dataclass