Files
pytorch/test/onnx/test_lazy_import.py
Yichen Yan 5dad6a5a84 [ONNX][DORT] Lazy-import onnxruntime (#134662)
Currently, if it is installed, `onnxruntime` is imported whenever `torch._inductor` is imported (and `torch._inductor` is itself imported by other libraries, e.g. transformer-engine):

```
  /mnt/c.py(53)<module>()
-> from torch._inductor.utils import maybe_profile
  /usr/local/lib/python3.10/site-packages/torch/_inductor/utils.py(49)<module>()
-> import torch._export
  /usr/local/lib/python3.10/site-packages/torch/_export/__init__.py(25)<module>()
-> import torch._dynamo
  /usr/local/lib/python3.10/site-packages/torch/_dynamo/__init__.py(2)<module>()
-> from . import convert_frame, eval_frame, resume_execution
  /usr/local/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py(48)<module>()
-> from . import config, exc, trace_rules
  /usr/local/lib/python3.10/site-packages/torch/_dynamo/trace_rules.py(52)<module>()
-> from .variables import (
  /usr/local/lib/python3.10/site-packages/torch/_dynamo/variables/__init__.py(38)<module>()
-> from .higher_order_ops import (
  /usr/local/lib/python3.10/site-packages/torch/_dynamo/variables/higher_order_ops.py(14)<module>()
-> import torch.onnx.operators
  /usr/local/lib/python3.10/site-packages/torch/onnx/__init__.py(62)<module>()
-> from ._internal.onnxruntime import (
  /usr/local/lib/python3.10/site-packages/torch/onnx/_internal/onnxruntime.py(37)<module>()
-> import onnxruntime  # type: ignore[import]
```

This breaks generated Triton kernels: they import torch, which in turn pulls in unexpected runtime libraries such as onnxruntime.

I've also added a test for this specific case under `test/onnx`; perhaps we should add more elsewhere?

Related issue: https://github.com/huggingface/accelerate/pull/3056
Pull Request resolved: https://github.com/pytorch/pytorch/pull/134662
Approved by: https://github.com/justinchuby
2024-08-31 00:06:28 +00:00

38 lines
1.1 KiB
Python

# Owner(s): ["module: onnx"]
import subprocess
import sys
import tempfile
import pytorch_test_common
from torch.testing._internal import common_utils
class TestLazyONNXPackages(pytorch_test_common.ExportTestCase):
    """Check that importing ``torch.onnx`` does not eagerly import the
    optional ``onnxruntime`` / ``onnxscript`` packages."""

    def _test_package_is_lazily_imported(self, pkg, torch_pkg="torch.onnx"):
        """Assert that ``import {torch_pkg}`` does not import ``pkg``.

        The import is executed in a fresh interpreter started with
        ``-X importtime``, which writes an import-time report to stderr; the
        report is then scanned for ``pkg``.

        Args:
            pkg: Name of the package that must not be imported.
            torch_pkg: Module whose import is under test. Defaults to
                ``"torch.onnx"``.
        """
        # Run from an empty temporary directory so files in the current working
        # directory cannot shadow the modules being imported.
        with tempfile.TemporaryDirectory() as wd:
            r = subprocess.run(
                # Previously this always ran "import torch.onnx", silently
                # ignoring ``torch_pkg``.
                [sys.executable, "-Ximporttime", "-c", f"import {torch_pkg}"],
                capture_output=True,
                text=True,
                cwd=wd,
                check=True,
            )
        # The leading space avoids matching a longer name that merely ends
        # with ``pkg`` (e.g. ``foo_onnxruntime``).
        self.assertTrue(
            f" {pkg}" not in r.stderr,
            f"`{pkg}` should not be imported by `import {torch_pkg}`, "
            f"full importtime: {r.stderr}",
        )

    def test_onnxruntime_is_lazily_imported(self):
        """``import torch.onnx`` must not import ``onnxruntime``."""
        self._test_package_is_lazily_imported("onnxruntime")

    def test_onnxscript_is_lazily_imported(self):
        """``import torch.onnx`` must not import ``onnxscript``."""
        self._test_package_is_lazily_imported("onnxscript")
if "__main__" == __name__:
    # When executed directly, hand off to PyTorch's shared test runner.
    common_utils.run_tests()