[Dynamo, ONNX] Replace onnxrt backend with new backend from ONNXRuntime team (#106929)
In https://github.com/pytorch/pytorch/pull/106589, a new ONNXRuntime-based Dynamo backend was introduced. As mentioned in that PR, we hope to replace the legacy `onnxrt` backend with the new one. This PR removes the legacy `onnxrt` implementation and registers the new backend under the same name.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/106929
Approved by: https://github.com/thiagocrepaldi, https://github.com/BowenBao, https://github.com/abock, https://github.com/msaroufim, https://github.com/jansel
Committed by: PyTorch MergeBot
Parent: d290511ecd
Commit: 22f5889753
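With this change, `torch.compile` reaches the ONNXRuntime team's backend through the same "onnxrt" name as before. A minimal usage sketch (the `MLP` module and input shape below are illustrative, modeled on the test in this diff, and are not part of the change itself):

import torch

class MLP(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = torch.nn.Linear(2, 4)
        self.fc2 = torch.nn.Linear(4, 2)

    def forward(self, x):
        # Same structure as the MLP used in the tests: linear -> sigmoid -> linear.
        return self.fc2(torch.sigmoid(self.fc1(x)))

# Requires numpy, onnx, onnxscript, and onnxruntime-training to be installed;
# otherwise the fallback backend registered by this PR raises an ImportError.
compiled = torch.compile(MLP(), backend="onnxrt")
print(compiled(torch.randn(4, 2)).shape)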
.github/merge_rules.yaml

@@ -10,6 +10,7 @@
   - scripts/onnx/**
   - test/onnx/**
   - tools/onnx/**
+  - torch/_dynamo/backends/onnxrt.py
   - torch/_C/__init__.pyi.in
   - torch/_C/_onnx.pyi
   - torch/csrc/jit/passes/onnx.*
CODEOWNERS

@@ -57,11 +57,12 @@ nn/qat/ @jerryzh168
 /torch/testing/_internal/distributed @mrshenli @zhaojuanmao @rohan-varma @H-Huang @awgu @kwen2501 @wanchaol @fegin @fduwjj
 
 # ONNX Export
+/torch/_dynamo/backends/onnxrt.py @bowenbao @abock @thiagocrepaldi @wschin
 /torch/csrc/jit/passes/onnx.h @bowenbao @abock @thiagocrepaldi
 /torch/csrc/jit/passes/onnx.cpp @bowenbao @abock @thiagocrepaldi
 /torch/csrc/jit/passes/onnx/ @bowenbao @abock @thiagocrepaldi
-/torch/onnx/ @bowenbao @abock @thiagocrepaldi
-/test/onnx/ @bowenbao @abock @thiagocrepaldi
+/torch/onnx/ @bowenbao @abock @thiagocrepaldi @wschin
+/test/onnx/ @bowenbao @abock @thiagocrepaldi @wschin
 
 # Docker
 /.ci/docker/ @jeffdaily
@@ -8,6 +8,7 @@ from typing import Tuple
 
 import torch
 import torch.onnx
+from parameterized import parameterized
 from torch import nn
 
 from torch.onnx._internal.onnxruntime import make_aot_ort, OrtBackend
@@ -99,9 +100,13 @@ class TestDynamoWithONNXRuntime(onnx_test_common._TestONNXRuntime):
         ):
             self.assertEqual(len(onnx_info), expected_number_of_onnx_models)
 
-    def test_elementwise_function_single_output(self):
-        local_aot_ort, local_ort = make_aot_ort(dynamic=True)
-
+    @parameterized.expand(
+        [
+            (True,),
+            (False,),
+        ]
+    )
+    def test_elementwise_function_single_output(self, test_local_backend: bool):
         example_args_collection = tuple(
             (torch.randn(batch, dtype=torch.float32),) for batch in (2, 4, 6, 8, 10)
         )
@@ -111,12 +116,24 @@ class TestDynamoWithONNXRuntime(onnx_test_common._TestONNXRuntime):
             z = y.sigmoid()
             return z
 
+        if test_local_backend:
+            local_aot_ort, local_ort = make_aot_ort(dynamic=True)
+        else:
+            # This will use the global ONNXRuntime backend registered
+            # in Dynamo to compile the tested model.
+            local_aot_ort, local_ort = "onnxrt", None
+
         self._test_model_numerically(
             elementwise_model,
             local_aot_ort,
             example_args_collection,
         )
 
+        # We can only check local backend's counting information
+        # since global backend's counting information comes from
+        # all compiled models.
+        if test_local_backend:
+            assert local_ort is not None
             self._assert_counting_information(
                 local_ort,
                 # OrtBackend._ort_acclerated_call should have been called 5 times because
@@ -130,9 +147,13 @@ class TestDynamoWithONNXRuntime(onnx_test_common._TestONNXRuntime):
                 number_of_exported_onnx_models_for_all_graph_modules=(1,),
             )
 
-    def test_elementwise_function_multiple_output(self):
-        local_aot_ort, local_ort = make_aot_ort(dynamic=True)
-
+    @parameterized.expand(
+        [
+            (True,),
+            (False,),
+        ]
+    )
+    def test_elementwise_function_multiple_output(self, test_local_backend: bool):
         example_args_collection = tuple(
             (torch.randn(batch, dtype=torch.float32),) for batch in (2, 4, 8)
         )
@@ -143,12 +164,19 @@ class TestDynamoWithONNXRuntime(onnx_test_common._TestONNXRuntime):
             z = y * y
             return x, y, z
 
+        if test_local_backend:
+            local_aot_ort, local_ort = make_aot_ort(dynamic=True)
+        else:
+            local_aot_ort, local_ort = "onnxrt", None
+
         self._test_model_numerically(
             elementwise_model_with_multiple_outputs,
             local_aot_ort,
             example_args_collection,
         )
 
+        if test_local_backend:
+            assert local_ort is not None
             self._assert_counting_information(
                 local_ort,
                 expected_execution_count=len(example_args_collection),
@@ -156,9 +184,13 @@ class TestDynamoWithONNXRuntime(onnx_test_common._TestONNXRuntime):
                 number_of_exported_onnx_models_for_all_graph_modules=(1,),
             )
 
-    def test_mlp_with_local_backend(self):
-        local_aot_ort, local_ort = make_aot_ort(dynamic=True)
-
+    @parameterized.expand(
+        [
+            (True,),
+            (False,),
+        ]
+    )
+    def test_mlp_with_local_backend(self, test_local_backend: bool):
         example_args_collection = tuple(
             (torch.randn(batch, 2, dtype=torch.float32),) for batch in (1, 2, 4, 6, 8)
         )
@@ -176,12 +208,19 @@ class TestDynamoWithONNXRuntime(onnx_test_common._TestONNXRuntime):
                 tensor_x = torch.sigmoid(tensor_x)
                 return tensor_x
 
+        if test_local_backend:
+            local_aot_ort, local_ort = make_aot_ort(dynamic=True)
+        else:
+            local_aot_ort, local_ort = "onnxrt", None
+
         self._test_model_numerically(
             MLP(),
             local_aot_ort,
             example_args_collection,
         )
 
+        if test_local_backend:
+            assert local_ort is not None
             self._assert_counting_information(
                 local_ort,
                 # OrtBackend._ort_acclerated_call should have been called 5 times because
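The parameterized tests above run each model through two paths: a per-test backend pair returned by make_aot_ort, whose OrtBackend handle can be inspected for the counting assertions, and the globally registered "onnxrt" string backend, whose counters are shared across all compiled models. A rough sketch of that pattern outside the test harness; the model is illustrative, and it assumes the first value returned by make_aot_ort can be passed directly as a torch.compile backend, which mirrors how the tests use it:

import torch
from torch.onnx._internal.onnxruntime import make_aot_ort


def elementwise_model(x: torch.Tensor):
    y = x.relu()
    return y.sigmoid()


# Local backend: keep a handle on the OrtBackend (local_ort) so per-model
# statistics such as the number of exported ONNX models can be inspected.
local_aot_ort, local_ort = make_aot_ort(dynamic=True)
compiled_local = torch.compile(elementwise_model, backend=local_aot_ort, dynamic=True)
compiled_local(torch.randn(4))

# Global backend: the "onnxrt" string resolves to the backend registered by this PR;
# its counters accumulate over every model compiled with it, so they are not asserted on.
compiled_global = torch.compile(elementwise_model, backend="onnxrt", dynamic=True)
compiled_global(torch.randn(4))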
torch/_dynamo/backends/onnxrt.py

@@ -1,131 +1,28 @@
-import importlib
-import logging
-import os
-import tempfile
-
-import torch
-from .common import device_from_inputs, fake_tensor_unsupported
+# This backend is maintained by ONNX team. To direct issues
+# to the right people, please tag related GitHub issues with `module: onnx`.
+#
+# Maintainers' Github IDs: wschin, thiagocrepaldi, BowenBao, abock
+from torch.onnx._internal.onnxruntime import has_onnxruntime, make_aot_ort
 from .registry import register_backend
 
-try:
-    import numpy as np
+if has_onnxruntime():
+    aot_ort, ort = make_aot_ort(dynamic=True)
+    register_backend(name="onnxrt", compiler_fn=aot_ort)
+else:
-
-    _np_dtype = {
-        torch.float16: np.float16,
-        torch.float32: np.float32,
-        torch.float64: np.float64,
-        torch.uint8: np.uint8,
-        torch.int8: np.int8,
-        torch.int16: np.int16,
-        torch.int32: np.int32,
-        torch.int64: np.longlong,
-        torch.bool: np.bool_,
-    }
-
-except ImportError:
-    _np_dtype = None
-
-
-log = logging.getLogger(__name__)
-
-
-def default_provider(device_type):
-    if "ONNXRT_PROVIDER" in os.environ:
-        return os.environ["ONNXRT_PROVIDER"]
-    return {
-        "cpu": "CPUExecutionProvider",
-        "cuda": "CUDAExecutionProvider",
-        # "TensorrtExecutionProvider" is another option
-    }[device_type]
-
-
-def has_onnxruntime():
-    try:
-        importlib.import_module("onnxruntime")
-        return True
-    except ImportError:
-        return False
-
-
-@register_backend
-@fake_tensor_unsupported
-def onnxrt(gm, example_inputs, *, filename=None, provider=None):
-    if filename is None:
-        with tempfile.NamedTemporaryFile(suffix=".onnx") as tmp:
-            return onnxrt(gm, example_inputs, filename=tmp.name)
-
-    import onnxruntime  # type: ignore[import]
-
-    assert _np_dtype, "requires numpy"
-
-    device_type = device_from_inputs(example_inputs).type
-    example_outputs = gm(*example_inputs)
-    if len(example_outputs) == 0:
-        log.warning("Explicitly fall back to eager due to zero output")
-        return gm.forward
-    output_spec = [
-        (o.shape, o.dtype, o.layout, o.device, o.requires_grad) for o in example_outputs
-    ]
-    input_names = [f"i{i}" for i in range(len(example_inputs))]
-    output_names = [f"o{x}" for x in range(len(example_outputs))]
-
-    torch.onnx.export(
-        torch.jit.script(gm),
-        example_inputs,
-        filename,
-        input_names=input_names,
-        output_names=output_names,
+    def information_displaying_backend(*args, **kwargs):
+        raise ImportError(
+            "onnxrt is not registered as a backend. "
+            "Please make sure all dependencies such as "
+            "numpy, onnx, onnxscript, and onnxruntime-training are installed. "
+            "Suggested procedure to fix dependency problem: "
+            "(1) pip or conda install numpy onnx onnxscript onnxruntime-training. "
+            "(2) open a new python terminal "
+            "(3) Run `from torch.onnx._internal.onnxruntime import has_onnxruntime`. "
+            "(4) Run `has_onnxruntime()`. "
+            "(5) If has_onnxruntime() returns True, then you can use `onnxrt` backend. "
+            "(6) If has_onnxruntime() returns False, please execute the package importing section in "
+            "torch/onnx/_internal/onnxruntime.py under pdb line-by-line to see which import fails."
+        )
-    del example_inputs, example_outputs
-
-    if provider is None:
-        provider = default_provider(device_type)
-    assert provider in onnxruntime.get_available_providers()
-    session = onnxruntime.InferenceSession(filename, providers=[provider])
-
-    def _call(*initial_args):
-        binding = session.io_binding()
-        active_inputs = {inp.name for inp in session.get_inputs()}
-        args = [a.contiguous() for a in initial_args]
-        for name, value in zip(input_names, args):
-            if name not in active_inputs:
-                log.warning(
-                    "input %s skipped as not found in onnx inference session", name
-                )
-                continue
-            dev = value.device
-            binding.bind_input(
-                name,
-                dev.type,
-                dev.index or 0,
-                _np_dtype[value.dtype],
-                value.size(),
-                value.data_ptr(),
-            )
-        outputs = [
-            torch.empty(
-                shape,
-                dtype=dtype,
-                layout=layout,
-                device=device,
-                requires_grad=requires_grad,
-            )
-            for shape, dtype, layout, device, requires_grad in output_spec
-        ]
-
-        for name, value in zip(output_names, outputs):
-            dev = value.device
-            binding.bind_output(
-                name,
-                dev.type,
-                dev.index or 0,
-                _np_dtype[value.dtype],
-                value.size(),
-                value.data_ptr(),
-            )
-        session.run_with_iobinding(binding)
-        if device_type == "cpu":
-            binding.copy_outputs_to_cpu()
-        return outputs
-
-    return _call
+
+    register_backend(name="onnxrt", compiler_fn=information_displaying_backend)
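The error message raised by the fallback backend doubles as a diagnostic recipe. A quick way to check in a fresh Python session whether the real backend will be registered, mirroring steps (3)-(5) of that message (a sketch, not part of the diff):

# pip or conda install numpy onnx onnxscript onnxruntime-training, then:
from torch.onnx._internal.onnxruntime import has_onnxruntime

# True  -> torch.compile(..., backend="onnxrt") will use the ONNXRuntime backend.
# False -> one of the dependencies above failed to import; step through
#          torch/onnx/_internal/onnxruntime.py under pdb to find which one.
print(has_onnxruntime())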