[Dynamo, ONNX] Replace onnxrt backend with new backend from ONNXRuntime team (#106929)

In https://github.com/pytorch/pytorch/pull/106589, a new ONNXRuntime-based Dynamo backend is introduced. As mentioned in that PR, we hope to replace legacy `onnxrt` with that new backend. This PR remove legacy `onnxrt` and register the new backend under the same name.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/106929
Approved by: https://github.com/thiagocrepaldi, https://github.com/BowenBao, https://github.com/abock, https://github.com/msaroufim, https://github.com/jansel
This commit is contained in:
Wei-Sheng Chin
2023-08-15 22:50:42 +00:00
committed by PyTorch MergeBot
parent d290511ecd
commit 22f5889753
4 changed files with 107 additions and 169 deletions

View File

@ -10,6 +10,7 @@
- scripts/onnx/**
- test/onnx/**
- tools/onnx/**
- torch/_dynamo/backends/onnxrt.py
- torch/_C/__init__.pyi.in
- torch/_C/_onnx.pyi
- torch/csrc/jit/passes/onnx.*

View File

@ -57,11 +57,12 @@ nn/qat/ @jerryzh168
/torch/testing/_internal/distributed @mrshenli @zhaojuanmao @rohan-varma @H-Huang @awgu @kwen2501 @wanchaol @fegin @fduwjj
# ONNX Export
/torch/_dynamo/backends/onnxrt.py @bowenbao @abock @thiagocrepaldi @wschin
/torch/csrc/jit/passes/onnx.h @bowenbao @abock @thiagocrepaldi
/torch/csrc/jit/passes/onnx.cpp @bowenbao @abock @thiagocrepaldi
/torch/csrc/jit/passes/onnx/ @bowenbao @abock @thiagocrepaldi
/torch/onnx/ @bowenbao @abock @thiagocrepaldi
/test/onnx/ @bowenbao @abock @thiagocrepaldi
/torch/onnx/ @bowenbao @abock @thiagocrepaldi @wschin
/test/onnx/ @bowenbao @abock @thiagocrepaldi @wschin
# Docker
/.ci/docker/ @jeffdaily

View File

@ -8,6 +8,7 @@ from typing import Tuple
import torch
import torch.onnx
from parameterized import parameterized
from torch import nn
from torch.onnx._internal.onnxruntime import make_aot_ort, OrtBackend
@ -99,9 +100,13 @@ class TestDynamoWithONNXRuntime(onnx_test_common._TestONNXRuntime):
):
self.assertEqual(len(onnx_info), expected_number_of_onnx_models)
def test_elementwise_function_single_output(self):
local_aot_ort, local_ort = make_aot_ort(dynamic=True)
@parameterized.expand(
[
(True,),
(False,),
]
)
def test_elementwise_function_single_output(self, test_local_backend: bool):
example_args_collection = tuple(
(torch.randn(batch, dtype=torch.float32),) for batch in (2, 4, 6, 8, 10)
)
@ -111,12 +116,24 @@ class TestDynamoWithONNXRuntime(onnx_test_common._TestONNXRuntime):
z = y.sigmoid()
return z
if test_local_backend:
local_aot_ort, local_ort = make_aot_ort(dynamic=True)
else:
# This will use the global ONNXRuntime backend registered
# in Dynamo to compile the tested model.
local_aot_ort, local_ort = "onnxrt", None
self._test_model_numerically(
elementwise_model,
local_aot_ort,
example_args_collection,
)
# We can only check local backend's counting information
# since global backend's counting information comes from
# all compiled models.
if test_local_backend:
assert local_ort is not None
self._assert_counting_information(
local_ort,
# OrtBackend._ort_acclerated_call should have been called 5 times because
@ -130,9 +147,13 @@ class TestDynamoWithONNXRuntime(onnx_test_common._TestONNXRuntime):
number_of_exported_onnx_models_for_all_graph_modules=(1,),
)
def test_elementwise_function_multiple_output(self):
local_aot_ort, local_ort = make_aot_ort(dynamic=True)
@parameterized.expand(
[
(True,),
(False,),
]
)
def test_elementwise_function_multiple_output(self, test_local_backend: bool):
example_args_collection = tuple(
(torch.randn(batch, dtype=torch.float32),) for batch in (2, 4, 8)
)
@ -143,12 +164,19 @@ class TestDynamoWithONNXRuntime(onnx_test_common._TestONNXRuntime):
z = y * y
return x, y, z
if test_local_backend:
local_aot_ort, local_ort = make_aot_ort(dynamic=True)
else:
local_aot_ort, local_ort = "onnxrt", None
self._test_model_numerically(
elementwise_model_with_multiple_outputs,
local_aot_ort,
example_args_collection,
)
if test_local_backend:
assert local_ort is not None
self._assert_counting_information(
local_ort,
expected_execution_count=len(example_args_collection),
@ -156,9 +184,13 @@ class TestDynamoWithONNXRuntime(onnx_test_common._TestONNXRuntime):
number_of_exported_onnx_models_for_all_graph_modules=(1,),
)
def test_mlp_with_local_backend(self):
local_aot_ort, local_ort = make_aot_ort(dynamic=True)
@parameterized.expand(
[
(True,),
(False,),
]
)
def test_mlp_with_local_backend(self, test_local_backend: bool):
example_args_collection = tuple(
(torch.randn(batch, 2, dtype=torch.float32),) for batch in (1, 2, 4, 6, 8)
)
@ -176,12 +208,19 @@ class TestDynamoWithONNXRuntime(onnx_test_common._TestONNXRuntime):
tensor_x = torch.sigmoid(tensor_x)
return tensor_x
if test_local_backend:
local_aot_ort, local_ort = make_aot_ort(dynamic=True)
else:
local_aot_ort, local_ort = "onnxrt", None
self._test_model_numerically(
MLP(),
local_aot_ort,
example_args_collection,
)
if test_local_backend:
assert local_ort is not None
self._assert_counting_information(
local_ort,
# OrtBackend._ort_acclerated_call should have been called 5 times because

View File

@ -1,131 +1,28 @@
import importlib
import logging
import os
import tempfile
import torch
from .common import device_from_inputs, fake_tensor_unsupported
# This backend is maintained by ONNX team. To direct issues
# to the right people, please tag related GitHub issues with `module: onnx`.
#
# Maintainers' Github IDs: wschin, thiagocrepaldi, BowenBao, abock
from torch.onnx._internal.onnxruntime import has_onnxruntime, make_aot_ort
from .registry import register_backend
try:
import numpy as np
if has_onnxruntime():
aot_ort, ort = make_aot_ort(dynamic=True)
register_backend(name="onnxrt", compiler_fn=aot_ort)
else:
_np_dtype = {
torch.float16: np.float16,
torch.float32: np.float32,
torch.float64: np.float64,
torch.uint8: np.uint8,
torch.int8: np.int8,
torch.int16: np.int16,
torch.int32: np.int32,
torch.int64: np.longlong,
torch.bool: np.bool_,
}
except ImportError:
_np_dtype = None
log = logging.getLogger(__name__)
def default_provider(device_type):
if "ONNXRT_PROVIDER" in os.environ:
return os.environ["ONNXRT_PROVIDER"]
return {
"cpu": "CPUExecutionProvider",
"cuda": "CUDAExecutionProvider",
# "TensorrtExecutionProvider" is another option
}[device_type]
def has_onnxruntime():
try:
importlib.import_module("onnxruntime")
return True
except ImportError:
return False
@register_backend
@fake_tensor_unsupported
def onnxrt(gm, example_inputs, *, filename=None, provider=None):
if filename is None:
with tempfile.NamedTemporaryFile(suffix=".onnx") as tmp:
return onnxrt(gm, example_inputs, filename=tmp.name)
import onnxruntime # type: ignore[import]
assert _np_dtype, "requires numpy"
device_type = device_from_inputs(example_inputs).type
example_outputs = gm(*example_inputs)
if len(example_outputs) == 0:
log.warning("Explicitly fall back to eager due to zero output")
return gm.forward
output_spec = [
(o.shape, o.dtype, o.layout, o.device, o.requires_grad) for o in example_outputs
]
input_names = [f"i{i}" for i in range(len(example_inputs))]
output_names = [f"o{x}" for x in range(len(example_outputs))]
torch.onnx.export(
torch.jit.script(gm),
example_inputs,
filename,
input_names=input_names,
output_names=output_names,
def information_displaying_backend(*args, **kwargs):
raise ImportError(
"onnxrt is not registered as a backend. "
"Please make sure all dependencies such as "
"numpy, onnx, onnxscript, and onnxruntime-training are installed. "
"Suggested procedure to fix dependency problem: "
"(1) pip or conda install numpy onnx onnxscript onnxruntime-training. "
"(2) open a new python terminal "
"(3) Run `from torch.onnx._internal.onnxruntime import has_onnxruntime`. "
"(4) Run `has_onnxruntime()`. "
"(5) If has_onnxruntime() returns True, then you can use `onnxrt` backend. "
"(6) If has_onnxruntime() returns False, please execute the package importing section in "
"torch/onnx/_internal/onnxruntime.py under pdb line-by-line to see which import fails."
)
del example_inputs, example_outputs
if provider is None:
provider = default_provider(device_type)
assert provider in onnxruntime.get_available_providers()
session = onnxruntime.InferenceSession(filename, providers=[provider])
def _call(*initial_args):
binding = session.io_binding()
active_inputs = {inp.name for inp in session.get_inputs()}
args = [a.contiguous() for a in initial_args]
for name, value in zip(input_names, args):
if name not in active_inputs:
log.warning(
"input %s skipped as not found in onnx inference session", name
)
continue
dev = value.device
binding.bind_input(
name,
dev.type,
dev.index or 0,
_np_dtype[value.dtype],
value.size(),
value.data_ptr(),
)
outputs = [
torch.empty(
shape,
dtype=dtype,
layout=layout,
device=device,
requires_grad=requires_grad,
)
for shape, dtype, layout, device, requires_grad in output_spec
]
for name, value in zip(output_names, outputs):
dev = value.device
binding.bind_output(
name,
dev.type,
dev.index or 0,
_np_dtype[value.dtype],
value.size(),
value.data_ptr(),
)
session.run_with_iobinding(binding)
if device_type == "cpu":
binding.copy_outputs_to_cpu()
return outputs
return _call
register_backend(name="onnxrt", compiler_fn=information_displaying_backend)