mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Currently, if `onnxruntime` is installed, it is imported whenever `torch._inductor` is imported (which in turn is imported by other libraries, e.g. transformer-engine): ``` /mnt/c.py(53)<module>() -> from torch._inductor.utils import maybe_profile /usr/local/lib/python3.10/site-packages/torch/_inductor/utils.py(49)<module>() -> import torch._export /usr/local/lib/python3.10/site-packages/torch/_export/__init__.py(25)<module>() -> import torch._dynamo /usr/local/lib/python3.10/site-packages/torch/_dynamo/__init__.py(2)<module>() -> from . import convert_frame, eval_frame, resume_execution /usr/local/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py(48)<module>() -> from . import config, exc, trace_rules /usr/local/lib/python3.10/site-packages/torch/_dynamo/trace_rules.py(52)<module>() -> from .variables import ( /usr/local/lib/python3.10/site-packages/torch/_dynamo/variables/__init__.py(38)<module>() -> from .higher_order_ops import ( /usr/local/lib/python3.10/site-packages/torch/_dynamo/variables/higher_order_ops.py(14)<module>() -> import torch.onnx.operators /usr/local/lib/python3.10/site-packages/torch/onnx/__init__.py(62)<module>() -> from ._internal.onnxruntime import ( /usr/local/lib/python3.10/site-packages/torch/onnx/_internal/onnxruntime.py(37)<module>() -> import onnxruntime # type: ignore[import] ``` This breaks generated Triton kernels: they import torch, which then also pulls in unexpected runtime libraries. I've also added a test for this specific case under `test/onnx`; perhaps we should add more elsewhere? Related issue: https://github.com/huggingface/accelerate/pull/3056 Pull Request resolved: https://github.com/pytorch/pytorch/pull/134662 Approved by: https://github.com/justinchuby
38 lines
1.1 KiB
Python
38 lines
1.1 KiB
Python
# Owner(s): ["module: onnx"]
|
|
|
|
import subprocess
|
|
import sys
|
|
import tempfile
|
|
|
|
import pytorch_test_common
|
|
|
|
from torch.testing._internal import common_utils
|
|
|
|
|
|
class TestLazyONNXPackages(pytorch_test_common.ExportTestCase):
    """Check that importing torch's ONNX module does not eagerly import optional ONNX packages.

    Each test starts a fresh interpreter with ``-X importtime``, imports a
    torch module, and asserts that a given third-party package name does not
    appear in the resulting import-time report on stderr.
    """

    def _test_package_is_lazily_imported(self, pkg, torch_pkg="torch.onnx"):
        """Assert that ``import <torch_pkg>`` does not import ``pkg``.

        Args:
            pkg: Name of the package that must not show up in the import-time
                report (e.g. ``"onnxruntime"``).
            torch_pkg: Module whose import is measured. Defaults to
                ``"torch.onnx"``.
        """
        # A fresh subprocess is used so modules already imported by the test
        # process cannot affect the result. ``-X importtime`` makes the
        # interpreter report every imported module on stderr.
        with tempfile.TemporaryDirectory() as wd:
            # With ``-c``, the current directory is placed on sys.path; running
            # from an empty temp dir presumably keeps stray local modules
            # (e.g. from a source checkout) out of the measured import.
            r = subprocess.run(
                # Import the requested module instead of always "torch.onnx";
                # otherwise the ``torch_pkg`` argument would be silently ignored.
                [sys.executable, "-Ximporttime", "-c", f"import {torch_pkg}"],
                capture_output=True,
                text=True,
                cwd=wd,
                check=True,
            )

        # The extra space makes sure we're checking the package, not any package containing its name.
        self.assertTrue(
            f" {pkg}" not in r.stderr,
            f"`{pkg}` should not be imported, full importtime: {r.stderr}",
        )

    def test_onnxruntime_is_lazily_imported(self):
        self._test_package_is_lazily_imported("onnxruntime")

    def test_onnxscript_is_lazily_imported(self):
        self._test_package_is_lazily_imported("onnxscript")
|
|
|
|
|
|
if __name__ == "__main__":
    # When this file is executed directly (not collected by a test runner),
    # delegate to PyTorch's shared test entry point.
    common_utils.run_tests()
|