Add test_imports (#77728)

That validates that every PyTorch submodule can be imported

Prevents regressions like the one described in https://github.com/pytorch/pytorch/issues/77441 from happening in the future

Pull Request resolved: https://github.com/pytorch/pytorch/pull/77728
Approved by: https://github.com/seemethere, https://github.com/janeyx99
This commit is contained in:
Nikita Shulga
2022-05-21 02:11:34 +00:00
committed by PyTorch MergeBot
parent 1d845253d8
commit 45eab670ac

View File

@ -3,10 +3,13 @@
import collections
import doctest
import functools
import importlib
import inspect
import itertools
import math
import os
import re
import sys
import unittest.mock
from typing import Any, Callable, Iterator, List, Tuple
@ -1762,5 +1765,43 @@ instantiate_parametrized_tests(TestTestParametrization)
instantiate_device_type_tests(TestTestParametrizationDeviceType, globals())
class TestImports(TestCase):
def test_circular_dependencies(self) -> None:
""" Checks that all modules inside torch can be imported
Prevents regression reported in https://github.com/pytorch/pytorch/issues/77441 """
ignored_modules = ["torch.utils.tensorboard", # deps on tensorboard
"torch.distributed.elastic.rendezvous", # depps on etcd
"torch.backends._coreml", # depends on pycoreml
"torch.contrib.", # something weird
"torch.testing._internal.common_fx2trt", # needs fx
"torch.testing._internal.distributed.", # just fails
]
# See https://github.com/pytorch/pytorch/issues/77801
if not sys.version_info >= (3, 9):
ignored_modules.append("torch.utils.benchmark")
if IS_WINDOWS:
# Distributed does not work on Windows
ignored_modules.append("torch.distributed.")
ignored_modules.append("torch.testing._internal.dist_utils")
torch_dir = os.path.dirname(torch.__file__)
for base, folders, files in os.walk(torch_dir):
prefix = os.path.relpath(base, os.path.dirname(torch_dir)).replace(os.path.sep, ".")
for f in files:
if not f.endswith(".py"):
continue
mod_name = f"{prefix}.{f[:-3]}" if f != "__init__.py" else prefix
# Do not attempt to import executable modules
if f == "__main__.py":
continue
if any(mod_name.startswith(x) for x in ignored_modules):
continue
try:
mod = importlib.import_module(mod_name)
except Exception as e:
raise RuntimeError(f"Failed to import {mod_name}: {e}") from e
self.assertTrue(inspect.ismodule(mod))
if __name__ == '__main__':
run_tests()