[ONNX][DORT] Lazy-import onnxruntime (#134662)

Currently, if `onnxruntime` is installed, it gets imported whenever `torch._inductor` is imported (and `torch._inductor` may itself be imported by other libraries, e.g. transformer-engine):

```
  /mnt/c.py(53)<module>()
-> from torch._inductor.utils import maybe_profile
  /usr/local/lib/python3.10/site-packages/torch/_inductor/utils.py(49)<module>()
-> import torch._export
  /usr/local/lib/python3.10/site-packages/torch/_export/__init__.py(25)<module>()
-> import torch._dynamo
  /usr/local/lib/python3.10/site-packages/torch/_dynamo/__init__.py(2)<module>()
-> from . import convert_frame, eval_frame, resume_execution
  /usr/local/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py(48)<module>()
-> from . import config, exc, trace_rules
  /usr/local/lib/python3.10/site-packages/torch/_dynamo/trace_rules.py(52)<module>()
-> from .variables import (
  /usr/local/lib/python3.10/site-packages/torch/_dynamo/variables/__init__.py(38)<module>()
-> from .higher_order_ops import (
  /usr/local/lib/python3.10/site-packages/torch/_dynamo/variables/higher_order_ops.py(14)<module>()
-> import torch.onnx.operators
  /usr/local/lib/python3.10/site-packages/torch/onnx/__init__.py(62)<module>()
-> from ._internal.onnxruntime import (
  /usr/local/lib/python3.10/site-packages/torch/onnx/_internal/onnxruntime.py(37)<module>()
-> import onnxruntime  # type: ignore[import]
```

This breaks the generated Triton kernels: they import `torch`, and through that chain they end up importing onnxruntime and other unexpected runtime libraries as well.
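The fix defers the import: instead of probing for `onnxruntime` when the module is first imported, `torch.onnx._internal.onnxruntime` now probes lazily inside `is_onnxrt_backend_supported()` and caches the result. A condensed sketch of the pattern (simplified from the diff below):

```
import importlib
from typing import Optional

# None means "not probed yet"; set to True/False on first use.
_SUPPORT_ONNXRT: Optional[bool] = None


def is_onnxrt_backend_supported() -> bool:
    global _SUPPORT_ONNXRT
    if _SUPPORT_ONNXRT is None:
        # Import onnxruntime (and onnxscript, needed by the underlying exporter)
        # only when someone actually asks for the DORT backend.
        try:
            importlib.import_module("onnxruntime")
            importlib.import_module("onnxscript")
            _SUPPORT_ONNXRT = True
        except ImportError:
            _SUPPORT_ONNXRT = False
    return _SUPPORT_ONNXRT
```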

I've also added a test for this specific case under `test/onnx`; perhaps we should add more elsewhere?
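The new test shells out to a fresh interpreter with CPython's `-X importtime` flag, which logs every imported module to stderr, and asserts that ` onnxruntime` never shows up after `import torch.onnx`. A minimal standalone version of that check (outside the test harness) might look like:

```
import subprocess
import sys

# Spawn a clean interpreter so previously imported modules don't interfere.
r = subprocess.run(
    [sys.executable, "-X", "importtime", "-c", "import torch.onnx"],
    capture_output=True,
    text=True,
    check=True,
)
# The leading space avoids matching modules that merely contain the name,
# e.g. torch.onnx._internal.onnxruntime.
assert " onnxruntime" not in r.stderr, "onnxruntime was imported eagerly"
```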

Related issue: https://github.com/huggingface/accelerate/pull/3056
Pull Request resolved: https://github.com/pytorch/pytorch/pull/134662
Approved by: https://github.com/justinchuby
This commit is contained in:
Yichen Yan
2024-08-31 00:06:28 +00:00
committed by PyTorch MergeBot
parent 2384f77d76
commit 5dad6a5a84
3 changed files with 106 additions and 24 deletions

View File

@@ -51,17 +51,19 @@ class TestDynamoWithONNXRuntime(onnx_test_common._TestONNXRuntime):
OrtBackend.clear_cached_instances()
def test_get_ort_device_type(self):
from onnxruntime.capi import _pybind_state as ORTC
self.assertEqual(
torch.onnx._internal.onnxruntime._get_ort_device_type("cuda"),
torch.onnx._internal.onnxruntime.ORTC.OrtDevice.cuda(),
ORTC.OrtDevice.cuda(),
)
self.assertEqual(
torch.onnx._internal.onnxruntime._get_ort_device_type("cpu"),
torch.onnx._internal.onnxruntime.ORTC.OrtDevice.cpu(),
ORTC.OrtDevice.cpu(),
)
self.assertEqual(
torch.onnx._internal.onnxruntime._get_ort_device_type("maia"),
torch.onnx._internal.onnxruntime.ORTC.OrtDevice.npu(),
ORTC.OrtDevice.npu(),
)
def test_torch_compile_backend_registration(self):

View File

@@ -0,0 +1,37 @@
# Owner(s): ["module: onnx"]
import subprocess
import sys
import tempfile

import pytorch_test_common

from torch.testing._internal import common_utils


class TestLazyONNXPackages(pytorch_test_common.ExportTestCase):
    def _test_package_is_lazily_imported(self, pkg, torch_pkg="torch.onnx"):
        with tempfile.TemporaryDirectory() as wd:
            r = subprocess.run(
                [sys.executable, "-Ximporttime", "-c", f"import {torch_pkg}"],
                capture_output=True,
                text=True,
                cwd=wd,
                check=True,
            )

            # The extra space makes sure we're checking the package, not any package containing its name.
            self.assertTrue(
                f" {pkg}" not in r.stderr,
                f"`{pkg}` should not be imported, full importtime: {r.stderr}",
            )

    def test_onnxruntime_is_lazily_imported(self):
        self._test_package_is_lazily_imported("onnxruntime")

    def test_onnxscript_is_lazily_imported(self):
        self._test_package_is_lazily_imported("onnxscript")


if __name__ == "__main__":
    common_utils.run_tests()

View File

@@ -34,32 +34,18 @@ from torch.utils import _pytree
if TYPE_CHECKING:
import onnx
try:
# Use try-except to initialize package-dependent global variables.
import onnxruntime # type: ignore[import]
from onnxruntime.capi import _pybind_state as ORTC # type: ignore[import]
# This is not used directly in DORT but is needed by the underlying exporter,
# so we still need to check if it exists.
importlib.import_module("onnxscript")
import onnxruntime
from onnxruntime.capi import _pybind_state as ORTC
import torch.onnx
import torch.onnx._internal
import torch.onnx._internal._exporter_legacy
import torch.onnx._internal.diagnostics
import torch.onnx._internal.fx.decomposition_table
import torch.onnx._internal.fx.passes
from torch.onnx._internal.fx import fx_onnx_interpreter
from torch.onnx._internal.fx.type_utils import (
_TORCH_DTYPE_TO_NUMPY_DTYPE,
_TORCH_DTYPE_TO_ONNX_TENSOR_ELEMENT_TYPE,
from_python_type_to_onnx_tensor_element_type,
)
import torch.onnx._internal.fx.passes # noqa: TCH004
_SUPPORT_ONNXRT = True
except ImportError:
_SUPPORT_ONNXRT = False
_SUPPORT_ONNXRT: Optional[bool] = None
__all__ = [
"is_onnxrt_backend_supported",
@@ -87,6 +73,35 @@ def is_onnxrt_backend_supported() -> bool:
... print("pip install onnx onnxscript onnxruntime")
...
"""
global _SUPPORT_ONNXRT
if _SUPPORT_ONNXRT is None:
# `onnxruntime` might import a lot of other runtime packages,
# e.g. apex, deepspeed, transformers.
# So we lazy-import onnxruntime to avoid a possible circular import.
try:
importlib.import_module("onnxruntime")
importlib.import_module("onnxruntime.capi._pybind_state")
# This is not used directly in DORT but is needed by the underlying exporter,
# so we still need to check if it exists.
importlib.import_module("onnxscript")
import torch.onnx # noqa: F401
import torch.onnx._internal # noqa: F401
import torch.onnx._internal._exporter_legacy # noqa: F401
import torch.onnx._internal.diagnostics # noqa: F401
from torch.onnx._internal.fx import ( # noqa: F401
decomposition_table,
fx_onnx_interpreter,
passes,
type_utils,
)
_SUPPORT_ONNXRT = True
except ImportError:
_SUPPORT_ONNXRT = False
return _SUPPORT_ONNXRT
@@ -143,6 +158,8 @@ def _nvtx_range_pop():
def _get_ort_device_type(device_type: str):
from onnxruntime.capi import _pybind_state as ORTC
if device_type == "cuda":
return ORTC.OrtDevice.cuda()
if device_type == "cpu":
@@ -305,6 +322,8 @@ def _get_onnx_devices(
...,
],
) -> Tuple["ORTC.OrtDevice", ...]:
from onnxruntime.capi import _pybind_state as ORTC
def _device_id_or_zero(device_id: int) -> int:
return device_id or 0
@@ -338,6 +357,10 @@ def _get_onnx_devices(
def _get_ortvalues_from_torch_tensors(
tensors: Tuple[torch.Tensor, ...], devices: Tuple["ORTC.OrtDevice", ...]
) -> Tuple[torch.Tensor, ...]:
from onnxruntime.capi import _pybind_state as ORTC
from torch.onnx._internal.fx.type_utils import _TORCH_DTYPE_TO_NUMPY_DTYPE
ortvalues = ORTC.OrtValueVector()
ortvalues.reserve(len(tensors))
dtypes = []
@@ -436,6 +459,9 @@ def _run_onnx_session_with_ortvaluevector(
...,
],
) -> Tuple[Union[torch.Tensor, int, float, bool], ...]:
import onnxruntime
from onnxruntime.capi import _pybind_state as ORTC
_nvtx_range_push("contiguous")
inputs = tuple(
_adjust_scalar_from_fx_to_onnx(arg, value_info)
@@ -514,6 +540,8 @@ def _run_onnx_session_with_fetch(
...,
],
) -> Tuple[Union[torch.Tensor, int, float, bool], ...]:
import onnxruntime
inputs = tuple(
_adjust_scalar_from_fx_to_onnx(arg, value_info)
for arg, value_info in zip(inputs, input_value_infos)
@@ -570,6 +598,11 @@ class OrtExecutionInfoPerSession:
)
def is_supported(self, *args):
from torch.onnx._internal.fx.type_utils import (
_TORCH_DTYPE_TO_ONNX_TENSOR_ELEMENT_TYPE,
from_python_type_to_onnx_tensor_element_type,
)
# Compare the args and the input schema in ONNX model and
# return the first match.
if len(args) != len(self.input_value_infos):
@@ -728,6 +761,12 @@ class OrtBackend:
"""
def __init__(self, options: Optional[OrtBackendOptions] = None):
from onnxruntime.capi import _pybind_state as ORTC
import torch.onnx
import torch.onnx._internal._exporter_legacy
import torch.onnx._internal.fx.decomposition_table
self._options: Final = OrtBackendOptions() if options is None else options
# options.export_options contains information shared between exporter and DORT.
@@ -849,6 +888,10 @@ class OrtBackend:
it means we delegate the computation to _ort_acclerated_call and therefore
onnxruntime.InferenceSession.
"""
import onnxruntime
from torch.onnx._internal.fx import fx_onnx_interpreter, passes
cached_execution_info_per_session = (
self._all_ort_execution_info.search_reusable_session_execution_info(
graph_module, *args
@@ -867,7 +910,7 @@ class OrtBackend:
# It's first time seeing such as graph. Let's make a new session
# (type: onnxruntime.InferenceSession) for it.
graph_module = torch.onnx._internal.fx.passes.MovePlaceholderToFront(
graph_module = passes.MovePlaceholderToFront(
self._resolved_onnx_exporter_options.diagnostic_context,
graph_module,
).run()
@@ -915,7 +958,7 @@ class OrtBackend:
# Cast FX variables if they will result schema-mismatch when searching
# for ONNX operator. E.g., add(double_tensor, int_tensor) is fine in PyTorch,
# but ONNX expects add(double_tensor, double_tensor).
graph_module = torch.onnx._internal.fx.passes.InsertTypePromotion(
graph_module = passes.InsertTypePromotion(
self._resolved_onnx_exporter_options.diagnostic_context, graph_module
).run()
# Start the per-node exporting process. It's conceptually a for loop