Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
[ONNX][DORT] Lazy-import onnxruntime (#134662)
Currently, if `onnxruntime` is installed, it will be imported whenever `torch._inductor` is imported (which is in turn imported by other libraries, e.g. transformer-engine):

```
/mnt/c.py(53)<module>()
-> from torch._inductor.utils import maybe_profile
/usr/local/lib/python3.10/site-packages/torch/_inductor/utils.py(49)<module>()
-> import torch._export
/usr/local/lib/python3.10/site-packages/torch/_export/__init__.py(25)<module>()
-> import torch._dynamo
/usr/local/lib/python3.10/site-packages/torch/_dynamo/__init__.py(2)<module>()
-> from . import convert_frame, eval_frame, resume_execution
/usr/local/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py(48)<module>()
-> from . import config, exc, trace_rules
/usr/local/lib/python3.10/site-packages/torch/_dynamo/trace_rules.py(52)<module>()
-> from .variables import (
/usr/local/lib/python3.10/site-packages/torch/_dynamo/variables/__init__.py(38)<module>()
-> from .higher_order_ops import (
/usr/local/lib/python3.10/site-packages/torch/_dynamo/variables/higher_order_ops.py(14)<module>()
-> import torch.onnx.operators
/usr/local/lib/python3.10/site-packages/torch/onnx/__init__.py(62)<module>()
-> from ._internal.onnxruntime import (
/usr/local/lib/python3.10/site-packages/torch/onnx/_internal/onnxruntime.py(37)<module>()
-> import onnxruntime  # type: ignore[import]
```

This breaks the generated Triton kernel, because it ends up importing torch and other unexpected runtime libraries as well. I've also added a test for this specific case under `test/onnx`; perhaps we should add more somewhere else?

Related issue: https://github.com/huggingface/accelerate/pull/3056

Pull Request resolved: https://github.com/pytorch/pytorch/pull/134662
Approved by: https://github.com/justinchuby
Committed by: PyTorch MergeBot
Parent: 2384f77d76
Commit: 5dad6a5a84
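The change itself follows a simple pattern: instead of probing for `onnxruntime` at module-import time with a top-level `try`/`except` around the imports, the probe is cached and deferred until the DORT backend is actually requested. A minimal sketch of that pattern (simplified from the diff below, not the verbatim patched code):

```python
import importlib
from typing import Optional

# Cached result of the capability probe; None means "not checked yet".
_SUPPORT_ONNXRT: Optional[bool] = None


def is_onnxrt_backend_supported() -> bool:
    """Probe for onnxruntime/onnxscript the first time DORT is actually requested."""
    global _SUPPORT_ONNXRT
    if _SUPPORT_ONNXRT is None:
        try:
            # Importing here, rather than at module import time, keeps
            # `import torch._inductor` from dragging in onnxruntime and
            # whatever it transitively imports.
            importlib.import_module("onnxruntime")
            importlib.import_module("onnxscript")
            _SUPPORT_ONNXRT = True
        except ImportError:
            _SUPPORT_ONNXRT = False
    return _SUPPORT_ONNXRT
```

The actual patch additionally moves the ORT and exporter imports into the functions and methods that need them, as the hunks below show.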
@@ -51,17 +51,19 @@ class TestDynamoWithONNXRuntime(onnx_test_common._TestONNXRuntime):
         OrtBackend.clear_cached_instances()

     def test_get_ort_device_type(self):
+        from onnxruntime.capi import _pybind_state as ORTC
+
         self.assertEqual(
             torch.onnx._internal.onnxruntime._get_ort_device_type("cuda"),
-            torch.onnx._internal.onnxruntime.ORTC.OrtDevice.cuda(),
+            ORTC.OrtDevice.cuda(),
         )
         self.assertEqual(
             torch.onnx._internal.onnxruntime._get_ort_device_type("cpu"),
-            torch.onnx._internal.onnxruntime.ORTC.OrtDevice.cpu(),
+            ORTC.OrtDevice.cpu(),
         )
         self.assertEqual(
             torch.onnx._internal.onnxruntime._get_ort_device_type("maia"),
-            torch.onnx._internal.onnxruntime.ORTC.OrtDevice.npu(),
+            ORTC.OrtDevice.npu(),
         )

     def test_torch_compile_backend_registration(self):
test/onnx/test_lazy_import.py (new file, 37 lines)
@@ -0,0 +1,37 @@
+# Owner(s): ["module: onnx"]
+
+import subprocess
+import sys
+import tempfile
+
+import pytorch_test_common
+
+from torch.testing._internal import common_utils
+
+
+class TestLazyONNXPackages(pytorch_test_common.ExportTestCase):
+    def _test_package_is_lazily_imported(self, pkg, torch_pkg="torch.onnx"):
+        with tempfile.TemporaryDirectory() as wd:
+            r = subprocess.run(
+                [sys.executable, "-Ximporttime", "-c", "import torch.onnx"],
+                capture_output=True,
+                text=True,
+                cwd=wd,
+                check=True,
+            )
+
+            # The extra space makes sure we're checking the package, not any package containing its name.
+            self.assertTrue(
+                f" {pkg}" not in r.stderr,
+                f"`{pkg}` should not be imported, full importtime: {r.stderr}",
+            )
+
+    def test_onnxruntime_is_lazily_imported(self):
+        self._test_package_is_lazily_imported("onnxruntime")
+
+    def test_onnxscript_is_lazily_imported(self):
+        self._test_package_is_lazily_imported("onnxscript")
+
+
+if __name__ == "__main__":
+    common_utils.run_tests()
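The test relies on CPython's `-X importtime` flag, which writes one line to stderr for every module imported during startup. A rough standalone illustration of the same check (a sketch only, not part of the test suite; the parsing assumes the standard `import time: ... | ... | module` output format):

```python
import subprocess
import sys

# Run a child interpreter and capture its import trace from stderr.
proc = subprocess.run(
    [sys.executable, "-X", "importtime", "-c", "import torch.onnx"],
    capture_output=True,
    text=True,
)

# Each trace line ends with the imported module name as the last |-separated field.
eagerly_imported = [
    line.rsplit("|", 1)[-1].strip()
    for line in proc.stderr.splitlines()
    if line.startswith("import time:")
]
print(
    "onnxruntime eagerly imported?",
    any(mod.startswith("onnxruntime") for mod in eagerly_imported),
)
```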
torch/onnx/_internal/onnxruntime.py

@@ -34,32 +34,18 @@ from torch.utils import _pytree

 if TYPE_CHECKING:
     import onnx
-
-try:
-    # Use try-except to initialize package-dependent global variables.
-    import onnxruntime  # type: ignore[import]
-    from onnxruntime.capi import _pybind_state as ORTC  # type: ignore[import]
-
-    # This is not use directly in DORT but needed by underlying exporter,
-    # so we still need to check if it exists.
-    importlib.import_module("onnxscript")
+    import onnxruntime
+    from onnxruntime.capi import _pybind_state as ORTC

     import torch.onnx
     import torch.onnx._internal
     import torch.onnx._internal._exporter_legacy
     import torch.onnx._internal.diagnostics
     import torch.onnx._internal.fx.decomposition_table
-    import torch.onnx._internal.fx.passes
-    from torch.onnx._internal.fx import fx_onnx_interpreter
-    from torch.onnx._internal.fx.type_utils import (
-        _TORCH_DTYPE_TO_NUMPY_DTYPE,
-        _TORCH_DTYPE_TO_ONNX_TENSOR_ELEMENT_TYPE,
-        from_python_type_to_onnx_tensor_element_type,
-    )
+    import torch.onnx._internal.fx.passes  # noqa: TCH004

-    _SUPPORT_ONNXRT = True
-except ImportError:
-    _SUPPORT_ONNXRT = False
+
+_SUPPORT_ONNXRT: Optional[bool] = None

 __all__ = [
     "is_onnxrt_backend_supported",
@@ -87,6 +73,35 @@ def is_onnxrt_backend_supported() -> bool:
         ...     print("pip install onnx onnxscript onnxruntime")
         ...
     """
+    global _SUPPORT_ONNXRT
+
+    if _SUPPORT_ONNXRT is None:
+        # `onnxruntime` might import a lot of other runtime packages,
+        # e.g. apex, deepspeed, transformers.
+        # So lazy-importing onnxruntime to avoid possible circular import.
+        try:
+            importlib.import_module("onnxruntime")
+            importlib.import_module("onnxruntime.capi._pybind_state")
+
+            # This is not use directly in DORT but needed by underlying exporter,
+            # so we still need to check if it exists.
+            importlib.import_module("onnxscript")
+
+            import torch.onnx  # noqa: F401
+            import torch.onnx._internal  # noqa: F401
+            import torch.onnx._internal._exporter_legacy  # noqa: F401
+            import torch.onnx._internal.diagnostics  # noqa: F401
+            from torch.onnx._internal.fx import (  # noqa: F401
+                decomposition_table,
+                fx_onnx_interpreter,
+                passes,
+                type_utils,
+            )
+
+            _SUPPORT_ONNXRT = True
+        except ImportError:
+            _SUPPORT_ONNXRT = False
+
     return _SUPPORT_ONNXRT

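For reference, this check guards the public "onnxrt" `torch.compile` backend, as the `pip install onnx onnxscript onnxruntime` hint in the docstring above suggests. A small usage sketch (assuming `onnx`, `onnxscript`, and `onnxruntime` are installed; illustrative only):

```python
import torch

model = torch.nn.Linear(4, 2)
example = torch.randn(1, 4)

if torch.onnx.is_onnxrt_backend_supported():
    # DORT: dynamo traces the model and onnxruntime executes the exported ONNX graph.
    compiled = torch.compile(model, backend="onnxrt")
    print(compiled(example))
else:
    print("pip install onnx onnxscript onnxruntime")
```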
@@ -143,6 +158,8 @@ def _nvtx_range_pop():


 def _get_ort_device_type(device_type: str):
+    from onnxruntime.capi import _pybind_state as ORTC
+
     if device_type == "cuda":
         return ORTC.OrtDevice.cuda()
     if device_type == "cpu":
@@ -305,6 +322,8 @@ def _get_onnx_devices(
         ...,
     ],
 ) -> Tuple["ORTC.OrtDevice", ...]:
+    from onnxruntime.capi import _pybind_state as ORTC
+
     def _device_id_or_zero(device_id: int) -> int:
         return device_id or 0

@@ -338,6 +357,10 @@ def _get_onnx_devices(
 def _get_ortvalues_from_torch_tensors(
     tensors: Tuple[torch.Tensor, ...], devices: Tuple["ORTC.OrtDevice", ...]
 ) -> Tuple[torch.Tensor, ...]:
+    from onnxruntime.capi import _pybind_state as ORTC
+
+    from torch.onnx._internal.fx.type_utils import _TORCH_DTYPE_TO_NUMPY_DTYPE
+
     ortvalues = ORTC.OrtValueVector()
     ortvalues.reserve(len(tensors))
     dtypes = []
@@ -436,6 +459,9 @@ def _run_onnx_session_with_ortvaluevector(
         ...,
     ],
 ) -> Tuple[Union[torch.Tensor, int, float, bool], ...]:
+    import onnxruntime
+    from onnxruntime.capi import _pybind_state as ORTC
+
     _nvtx_range_push("contiguous")
     inputs = tuple(
         _adjust_scalar_from_fx_to_onnx(arg, value_info)
@@ -514,6 +540,8 @@ def _run_onnx_session_with_fetch(
         ...,
     ],
 ) -> Tuple[Union[torch.Tensor, int, float, bool], ...]:
+    import onnxruntime
+
     inputs = tuple(
         _adjust_scalar_from_fx_to_onnx(arg, value_info)
         for arg, value_info in zip(inputs, input_value_infos)
@@ -570,6 +598,11 @@ class OrtExecutionInfoPerSession:
         )

     def is_supported(self, *args):
+        from torch.onnx._internal.fx.type_utils import (
+            _TORCH_DTYPE_TO_ONNX_TENSOR_ELEMENT_TYPE,
+            from_python_type_to_onnx_tensor_element_type,
+        )
+
         # Compare the args and the input schema in ONNX model and
         # return the first match.
         if len(args) != len(self.input_value_infos):
@@ -728,6 +761,12 @@ class OrtBackend:
     """

     def __init__(self, options: Optional[OrtBackendOptions] = None):
+        from onnxruntime.capi import _pybind_state as ORTC
+
+        import torch.onnx
+        import torch.onnx._internal._exporter_legacy
+        import torch.onnx._internal.fx.decomposition_table
+
         self._options: Final = OrtBackendOptions() if options is None else options

         # options.export_options contains information shared between exporter and DORT.
@@ -849,6 +888,10 @@ class OrtBackend:
         it means we delegate the computation to _ort_acclerated_call and therefore
         onnxruntime.InferenceSession.
         """
+        import onnxruntime
+
+        from torch.onnx._internal.fx import fx_onnx_interpreter, passes
+
         cached_execution_info_per_session = (
             self._all_ort_execution_info.search_reusable_session_execution_info(
                 graph_module, *args
@@ -867,7 +910,7 @@ class OrtBackend:
             # It's first time seeing such as graph. Let's make a new session
             # (type: onnxruntime.InferenceSession) for it.

-            graph_module = torch.onnx._internal.fx.passes.MovePlaceholderToFront(
+            graph_module = passes.MovePlaceholderToFront(
                 self._resolved_onnx_exporter_options.diagnostic_context,
                 graph_module,
             ).run()
@@ -915,7 +958,7 @@ class OrtBackend:
             # Cast FX variables if they will result schema-mismatch when searching
             # for ONNX operator. E.g., add(double_tensor, int_tensor) is fine in PyTorch,
             # but ONNX expects add(double_tensor, double_tensor).
-            graph_module = torch.onnx._internal.fx.passes.InsertTypePromotion(
+            graph_module = passes.InsertTypePromotion(
                 self._resolved_onnx_exporter_options.diagnostic_context, graph_module
             ).run()
             # Start the per-node exporting process. It's conceptually a for loop