[ONNX] Merge 'initializers' into 'TorchScriptGraph' (#95676)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/95676
Approved by: https://github.com/titaiwangms, https://github.com/wschin
This commit is contained in:
BowenBao
2023-03-08 09:55:07 -08:00
committed by PyTorch MergeBot
parent e9e6b3b6c5
commit 0f4652f498
2 changed files with 7 additions and 19 deletions

View File

@ -64,7 +64,7 @@ if [[ "$BUILD_ENVIRONMENT" == *onnx* ]]; then
# TODO: change this when onnx 1.13.1 is released.
pip install --no-use-pep517 'onnx @ git+https://github.com/onnx/onnx@e192ba01e438d22ca2dedd7956e28e3551626c91'
# TODO: change this when onnx-script is on testPypi
pip install 'onnx-script @ git+https://github.com/microsoft/onnx-script@0298154caf6b46fc4e30abba034095c1290c26e3'
pip install 'onnx-script @ git+https://github.com/microsoft/onnx-script@29241e15f5182be1384f1cf6ba203d7e2e125196'
# numba requires numpy <= 1.20, onnxruntime requires numpy >= 1.21.
# We don't actually need it for our tests, but it's imported if it's present, so uninstall.
pip uninstall -q --yes numba

View File

@ -252,9 +252,6 @@ def _export_fx_node_to_onnxscript(
fx_name_to_onnxscipt_value: Dict[
str, Union[torch._C.Value, Tuple[torch._C.Value, ...]]
],
onnxscript_value_name_to_real_tensor: Dict[
str, Union[torch.Tensor, Tuple[torch._C.Value, ...]]
],
tracer: graph_building.TorchScriptTracingEvaluator,
fx_module_with_metadata: torch.fx.GraphModule,
options: options.ExportOptions,
@ -388,7 +385,9 @@ def _export_fx_node_to_onnxscript(
assert isinstance(input_, graph_building.TorchScriptTensor)
assert isinstance(input_, onnxscript.tensor.Tensor)
fx_name_to_onnxscipt_value[node.name] = input_
onnxscript_value_name_to_real_tensor[input_.name] = current_attr # type: ignore[assignment]
# FIXME: Refactor logic getting 'current_attr'.
assert isinstance(current_attr, torch.Tensor)
onnxscript_graph.add_initializer(input_.name, current_attr)
else:
# TODO(wechi): Support get_attr, call_module, call_method.
raise RuntimeError(f"Found node type not defined in torch.fx: {node.op}")
@ -413,18 +412,11 @@ def _export_fx_to_onnxscript(
fx_name_to_onnxscipt_value: Dict[
str, Union[torch._C.Value, Tuple[torch._C.Value, ...]]
] = {}
# Similar to fx_name_to_onnxscipt_value, we need a mapping fo real tensors (usually tensor parameters
# in nn.Module). Note that TorchScript's cannot store real tensors; TorchScript values are all
# symbolic. This is passed into ONNX ModelProto as the initializers.
onnxscript_value_name_to_real_tensor: Dict[
str, Union[torch.Tensor, Tuple[torch._C.Value, ...]]
] = {}
for node in fx_module_with_metadata.graph.nodes:
_export_fx_node_to_onnxscript(
node,
onnxscript_graph,
fx_name_to_onnxscipt_value,
onnxscript_value_name_to_real_tensor,
tracer,
fx_module_with_metadata,
options,
@ -439,7 +431,7 @@ def _export_fx_to_onnxscript(
opset_version=options.opset_version,
)
return onnxscript_graph, onnxscript_value_name_to_real_tensor
return onnxscript_graph
@_beartype.beartype
@ -531,13 +523,9 @@ def _export(
# ONNX exporter used in _ts_graph_to_onnx_model_in_protobuf is not compatible
# with FakeTensorMode.
with torch.utils._mode_utils.no_dispatch():
onnxscript_graph, initializers = _export_fx_to_onnxscript(
decomposed_module, export_options
)
onnxscript_graph = _export_fx_to_onnxscript(decomposed_module, export_options)
# Export TorchScript graph to ONNX ModelProto.
onnx_model = onnxscript_graph.to_model_proto(
initializers, export_options.opset_version
)
onnx_model = onnxscript_graph.to_model_proto(export_options.opset_version)
if export_options.use_binary_format:
# Return ModelProto in binary format.