mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Add test_imports (#77728)
That validates that every PyTorch submodule can be imported Prevents regressions like the one described in https://github.com/pytorch/pytorch/issues/77441 from happening in the future Pull Request resolved: https://github.com/pytorch/pytorch/pull/77728 Approved by: https://github.com/seemethere, https://github.com/janeyx99
This commit is contained in:
committed by
PyTorch MergeBot
parent
1d845253d8
commit
45eab670ac
@ -3,10 +3,13 @@
|
||||
import collections
|
||||
import doctest
|
||||
import functools
|
||||
import importlib
|
||||
import inspect
|
||||
import itertools
|
||||
import math
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
import unittest.mock
|
||||
from typing import Any, Callable, Iterator, List, Tuple
|
||||
|
||||
@ -1762,5 +1765,43 @@ instantiate_parametrized_tests(TestTestParametrization)
|
||||
instantiate_device_type_tests(TestTestParametrizationDeviceType, globals())
|
||||
|
||||
|
||||
class TestImports(TestCase):
|
||||
def test_circular_dependencies(self) -> None:
|
||||
""" Checks that all modules inside torch can be imported
|
||||
Prevents regression reported in https://github.com/pytorch/pytorch/issues/77441 """
|
||||
ignored_modules = ["torch.utils.tensorboard", # deps on tensorboard
|
||||
"torch.distributed.elastic.rendezvous", # depps on etcd
|
||||
"torch.backends._coreml", # depends on pycoreml
|
||||
"torch.contrib.", # something weird
|
||||
"torch.testing._internal.common_fx2trt", # needs fx
|
||||
"torch.testing._internal.distributed.", # just fails
|
||||
]
|
||||
# See https://github.com/pytorch/pytorch/issues/77801
|
||||
if not sys.version_info >= (3, 9):
|
||||
ignored_modules.append("torch.utils.benchmark")
|
||||
if IS_WINDOWS:
|
||||
# Distributed does not work on Windows
|
||||
ignored_modules.append("torch.distributed.")
|
||||
ignored_modules.append("torch.testing._internal.dist_utils")
|
||||
|
||||
torch_dir = os.path.dirname(torch.__file__)
|
||||
for base, folders, files in os.walk(torch_dir):
|
||||
prefix = os.path.relpath(base, os.path.dirname(torch_dir)).replace(os.path.sep, ".")
|
||||
for f in files:
|
||||
if not f.endswith(".py"):
|
||||
continue
|
||||
mod_name = f"{prefix}.{f[:-3]}" if f != "__init__.py" else prefix
|
||||
# Do not attempt to import executable modules
|
||||
if f == "__main__.py":
|
||||
continue
|
||||
if any(mod_name.startswith(x) for x in ignored_modules):
|
||||
continue
|
||||
try:
|
||||
mod = importlib.import_module(mod_name)
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Failed to import {mod_name}: {e}") from e
|
||||
self.assertTrue(inspect.ismodule(mod))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
run_tests()
|
||||
|
Reference in New Issue
Block a user