mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[ONNX] Merge 'initializers' into 'TorchScriptGraph' (#95676)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/95676 Approved by: https://github.com/titaiwangms, https://github.com/wschin
This commit is contained in:
committed by
PyTorch MergeBot
parent
e9e6b3b6c5
commit
0f4652f498
@ -64,7 +64,7 @@ if [[ "$BUILD_ENVIRONMENT" == *onnx* ]]; then
|
||||
# TODO: change this when onnx 1.13.1 is released.
|
||||
pip install --no-use-pep517 'onnx @ git+https://github.com/onnx/onnx@e192ba01e438d22ca2dedd7956e28e3551626c91'
|
||||
# TODO: change this when onnx-script is on testPypi
|
||||
pip install 'onnx-script @ git+https://github.com/microsoft/onnx-script@0298154caf6b46fc4e30abba034095c1290c26e3'
|
||||
pip install 'onnx-script @ git+https://github.com/microsoft/onnx-script@29241e15f5182be1384f1cf6ba203d7e2e125196'
|
||||
# numba requires numpy <= 1.20, onnxruntime requires numpy >= 1.21.
|
||||
# We don't actually need it for our tests, but it's imported if it's present, so uninstall.
|
||||
pip uninstall -q --yes numba
|
||||
|
@ -252,9 +252,6 @@ def _export_fx_node_to_onnxscript(
|
||||
fx_name_to_onnxscipt_value: Dict[
|
||||
str, Union[torch._C.Value, Tuple[torch._C.Value, ...]]
|
||||
],
|
||||
onnxscript_value_name_to_real_tensor: Dict[
|
||||
str, Union[torch.Tensor, Tuple[torch._C.Value, ...]]
|
||||
],
|
||||
tracer: graph_building.TorchScriptTracingEvaluator,
|
||||
fx_module_with_metadata: torch.fx.GraphModule,
|
||||
options: options.ExportOptions,
|
||||
@ -388,7 +385,9 @@ def _export_fx_node_to_onnxscript(
|
||||
assert isinstance(input_, graph_building.TorchScriptTensor)
|
||||
assert isinstance(input_, onnxscript.tensor.Tensor)
|
||||
fx_name_to_onnxscipt_value[node.name] = input_
|
||||
onnxscript_value_name_to_real_tensor[input_.name] = current_attr # type: ignore[assignment]
|
||||
# FIXME: Refactor logic getting 'current_attr'.
|
||||
assert isinstance(current_attr, torch.Tensor)
|
||||
onnxscript_graph.add_initializer(input_.name, current_attr)
|
||||
else:
|
||||
# TODO(wechi): Support get_attr, call_module, call_method.
|
||||
raise RuntimeError(f"Found node type not defined in torch.fx: {node.op}")
|
||||
@ -413,18 +412,11 @@ def _export_fx_to_onnxscript(
|
||||
fx_name_to_onnxscipt_value: Dict[
|
||||
str, Union[torch._C.Value, Tuple[torch._C.Value, ...]]
|
||||
] = {}
|
||||
# Similar to fx_name_to_onnxscipt_value, we need a mapping fo real tensors (usually tensor parameters
|
||||
# in nn.Module). Note that TorchScript's cannot store real tensors; TorchScript values are all
|
||||
# symbolic. This is passed into ONNX ModelProto as the initializers.
|
||||
onnxscript_value_name_to_real_tensor: Dict[
|
||||
str, Union[torch.Tensor, Tuple[torch._C.Value, ...]]
|
||||
] = {}
|
||||
for node in fx_module_with_metadata.graph.nodes:
|
||||
_export_fx_node_to_onnxscript(
|
||||
node,
|
||||
onnxscript_graph,
|
||||
fx_name_to_onnxscipt_value,
|
||||
onnxscript_value_name_to_real_tensor,
|
||||
tracer,
|
||||
fx_module_with_metadata,
|
||||
options,
|
||||
@ -439,7 +431,7 @@ def _export_fx_to_onnxscript(
|
||||
opset_version=options.opset_version,
|
||||
)
|
||||
|
||||
return onnxscript_graph, onnxscript_value_name_to_real_tensor
|
||||
return onnxscript_graph
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
@ -531,13 +523,9 @@ def _export(
|
||||
# ONNX exporter used in _ts_graph_to_onnx_model_in_protobuf is not compatible
|
||||
# with FakeTensorMode.
|
||||
with torch.utils._mode_utils.no_dispatch():
|
||||
onnxscript_graph, initializers = _export_fx_to_onnxscript(
|
||||
decomposed_module, export_options
|
||||
)
|
||||
onnxscript_graph = _export_fx_to_onnxscript(decomposed_module, export_options)
|
||||
# Export TorchScript graph to ONNX ModelProto.
|
||||
onnx_model = onnxscript_graph.to_model_proto(
|
||||
initializers, export_options.opset_version
|
||||
)
|
||||
onnx_model = onnxscript_graph.to_model_proto(export_options.opset_version)
|
||||
|
||||
if export_options.use_binary_format:
|
||||
# Return ModelProto in binary format.
|
||||
|
Reference in New Issue
Block a user