mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
This reverts commit fab53dfdf1d89cecd5e82b12cced9b6dd217e87c. Reverted https://github.com/pytorch/pytorch/pull/155978 on behalf of https://github.com/guilhermeleobas due to failing in trunk ([comment](https://github.com/pytorch/pytorch/pull/155978#issuecomment-3019457531))
205 lines
7.1 KiB
Python
205 lines
7.1 KiB
Python
# mypy: allow-untyped-defs
|
|
|
|
"""Testing utilities for Dynamo, providing a specialized TestCase class and test running functionality.
|
|
|
|
This module extends PyTorch's testing framework with Dynamo-specific testing capabilities.
|
|
It includes:
|
|
- A custom TestCase class that handles Dynamo-specific setup/teardown
|
|
- Test running utilities with dependency checking
|
|
- Automatic reset of Dynamo state between tests
|
|
- Proper handling of gradient mode state
|
|
"""
|
|
|
|
import contextlib
|
|
import importlib
|
|
import inspect
|
|
import logging
|
|
import os
|
|
import re
|
|
import sys
|
|
import unittest
|
|
from typing import Union
|
|
|
|
import torch
|
|
import torch.testing
|
|
from torch._logging._internal import trace_log
|
|
from torch.testing._internal.common_utils import ( # type: ignore[attr-defined]
|
|
IS_WINDOWS,
|
|
TEST_WITH_CROSSREF,
|
|
TEST_WITH_TORCHDYNAMO,
|
|
TestCase as TorchTestCase,
|
|
)
|
|
|
|
from . import config, reset, utils
|
|
|
|
|
|
log = logging.getLogger(__name__)
|
|
|
|
|
|
def run_tests(needs: Union[str, tuple[str, ...]] = ()) -> None:
|
|
from torch.testing._internal.common_utils import run_tests
|
|
|
|
if TEST_WITH_TORCHDYNAMO or TEST_WITH_CROSSREF:
|
|
return # skip testing
|
|
|
|
if (
|
|
not torch.xpu.is_available()
|
|
and IS_WINDOWS
|
|
and os.environ.get("TORCHINDUCTOR_WINDOWS_TESTS", "0") == "0"
|
|
):
|
|
return
|
|
|
|
if isinstance(needs, str):
|
|
needs = (needs,)
|
|
for need in needs:
|
|
if need == "cuda":
|
|
if not torch.cuda.is_available():
|
|
return
|
|
else:
|
|
try:
|
|
importlib.import_module(need)
|
|
except ImportError:
|
|
return
|
|
run_tests()
|
|
|
|
|
|
class TestCase(TorchTestCase):
    """``TorchTestCase`` subclass that resets Dynamo state around every test.

    For the lifetime of the class, a ``config.patch`` with
    ``raise_on_ctx_manager_usage=True``, ``suppress_errors=False`` and
    ``log_compilation_metrics=False`` is kept active. Around each test,
    ``reset()`` is called, ``utils.counters`` is cleared, a ``NullHandler`` is
    attached to ``trace_log``, and the grad mode observed in ``setUp`` is
    restored in ``tearDown`` if the test changed it.
    """

    # Owns the class-wide config patch; opened in setUpClass, closed in
    # tearDownClass.
    _exit_stack: contextlib.ExitStack

    @classmethod
    def setUpClass(cls) -> None:
        """Enter the class-wide config patch after the parent class setup."""
        super().setUpClass()
        stack = contextlib.ExitStack()
        cls._exit_stack = stack  # type: ignore[attr-defined]
        stack.enter_context(
            config.patch(
                raise_on_ctx_manager_usage=True,
                suppress_errors=False,
                log_compilation_metrics=False,
            )
        )

    @classmethod
    def tearDownClass(cls) -> None:
        """Leave the class-wide config patch, then run the parent teardown."""
        cls._exit_stack.close()
        super().tearDownClass()

    def setUp(self) -> None:
        """Record the grad mode, reset Dynamo, and attach a null trace handler."""
        # Captured before anything else so tearDown can detect a change.
        self._prior_is_grad_enabled = torch.is_grad_enabled()
        super().setUp()
        reset()
        utils.counters.clear()
        null_handler = logging.NullHandler()
        self.handler = null_handler
        trace_log.addHandler(null_handler)

    def tearDown(self) -> None:
        """Print counters, reset Dynamo, and restore the recorded grad mode."""
        trace_log.removeHandler(self.handler)
        # Each value is expected to expose ``most_common()`` (presumably a
        # collections.Counter; confirm against ``utils.counters``).
        for name, counter in utils.counters.items():
            print(name, counter.most_common())
        reset()
        utils.counters.clear()
        super().tearDown()

        grad_before = self._prior_is_grad_enabled
        if grad_before is torch.is_grad_enabled():
            return
        log.warning("Running test changed grad mode")
        torch.set_grad_enabled(grad_before)
|
|
|
|
|
|
class CPythonTestCase(TestCase):
    """
    Test class for CPython tests located under "dynamo/cpython/<major>_<minor>/"
    (for example "dynamo/cpython/3_13/").

    This class enables specific features that are disabled by default, such as
    tracing through unittest methods.

    The version directory in the test file's path decides both whether the
    class is skipped on the running interpreter (see ``setUpClass``) and the
    prefix of the key returned by ``_dynamo_test_key``.
    """

    # Owns the ``enable_trace_unittest`` config patch; opened in setUpClass and
    # closed in tearDownClass.
    _stack: contextlib.ExitStack
    dynamo_strict_nopython = True

    # Restore original unittest methods to simplify tracing CPython test cases.
    assertEqual = unittest.TestCase.assertEqual  # type: ignore[assignment]
    assertNotEqual = unittest.TestCase.assertNotEqual  # type: ignore[assignment]
    assertTrue = unittest.TestCase.assertTrue
    assertFalse = unittest.TestCase.assertFalse
    assertIs = unittest.TestCase.assertIs
    assertIsNot = unittest.TestCase.assertIsNot
    assertIsNone = unittest.TestCase.assertIsNone
    assertIsNotNone = unittest.TestCase.assertIsNotNone
    assertIn = unittest.TestCase.assertIn
    assertNotIn = unittest.TestCase.assertNotIn
    assertIsInstance = unittest.TestCase.assertIsInstance
    assertNotIsInstance = unittest.TestCase.assertNotIsInstance
    assertAlmostEqual = unittest.TestCase.assertAlmostEqual
    assertNotAlmostEqual = unittest.TestCase.assertNotAlmostEqual
    assertGreater = unittest.TestCase.assertGreater
    assertGreaterEqual = unittest.TestCase.assertGreaterEqual
    assertLess = unittest.TestCase.assertLess
    assertLessEqual = unittest.TestCase.assertLessEqual
    assertRegex = unittest.TestCase.assertRegex
    assertNotRegex = unittest.TestCase.assertNotRegex
    assertCountEqual = unittest.TestCase.assertCountEqual
    assertMultiLineEqual = unittest.TestCase.assertMultiLineEqual
    assertSequenceEqual = unittest.TestCase.assertSequenceEqual
    assertListEqual = unittest.TestCase.assertListEqual
    assertTupleEqual = unittest.TestCase.assertTupleEqual
    assertSetEqual = unittest.TestCase.assertSetEqual
    assertDictEqual = unittest.TestCase.assertDictEqual
    assertRaises = unittest.TestCase.assertRaises
    assertRaisesRegex = unittest.TestCase.assertRaisesRegex
    assertWarns = unittest.TestCase.assertWarns
    assertWarnsRegex = unittest.TestCase.assertWarnsRegex
    assertLogs = unittest.TestCase.assertLogs
    fail = unittest.TestCase.fail
    failureException = unittest.TestCase.failureException

    @classmethod
    def _cpython_version_dir(cls) -> Union[str, None]:
        """Return the "<major>_<minor>" path component of this class's file.

        The component is looked up right after ``dynamo/cpython/`` (built with
        the platform's path separator) and must be one digit, an underscore,
        and two digits, e.g. ``"3_13"``.

        Returns:
            The matched directory name, or ``None`` when the path of the file
            defining ``cls`` contains no such component.
        """
        prefix = os.path.join("dynamo", "cpython") + os.path.sep
        found = re.search(re.escape(prefix) + r"(\d_\d{2})", inspect.getfile(cls))
        return found.group(1) if found else None

    def compile_fn(self, fn, backend, nopython):
        """Wrap the current test method with Dynamo and return ``fn`` unchanged.

        Args:
            fn: Returned as-is; it is not wrapped here.
            backend: Backend passed to ``torch._dynamo.optimize``.
            nopython: Forwarded as the ``nopython`` keyword of
                ``torch._dynamo.optimize``.

        Returns:
            The ``fn`` argument, untouched.
        """
        # We want to compile only the test function, excluding any setup code
        # from unittest
        method = getattr(self, self._testMethodName)
        method = torch._dynamo.optimize(backend, nopython=nopython)(method)
        setattr(self, self._testMethodName, method)
        return fn

    def _dynamo_test_key(self):
        """Return the parent key, prefixed with CPython version and file name.

        Returns:
            ``"CPython<ver>-<file>-<parent key>"`` where ``<ver>`` is the
            version directory without its underscore (``"3_13"`` -> ``"313"``)
            and ``<file>`` is the test file's base name up to its first dot.
            When the file path carries no version directory, the parent key is
            returned unchanged.
        """
        suffix = super()._dynamo_test_key()
        test_cls = self.__class__
        version_dir = test_cls._cpython_version_dir()
        if version_dir is None:
            return suffix
        test_file = os.path.basename(inspect.getfile(test_cls)).split(".")[0]
        py_ver = version_dir.replace("_", "")
        return f"CPython{py_ver}-{test_file}-{suffix}"

    @classmethod
    def tearDownClass(cls) -> None:
        """Leave the ``enable_trace_unittest`` patch, then run the parent teardown."""
        cls._stack.close()
        super().tearDownClass()

    @classmethod
    def setUpClass(cls) -> None:
        """Skip the class on unsuitable interpreters, else enable unittest tracing.

        Raises:
            unittest.SkipTest: If the file path has no version directory, or if
                the running interpreter is older than the version named by
                that directory.
        """
        version_dir = cls._cpython_version_dir()
        if version_dir is None:
            raise unittest.SkipTest(
                f"Test requires a specific Python version but not found in path {inspect.getfile(cls)}"
            )

        # Skip only when the running interpreter is older than the version the
        # test directory targets; newer interpreters still run the tests.
        test_py_ver = tuple(map(int, version_dir.split("_")))
        py_ver = sys.version_info[:2]
        if py_ver < test_py_ver:
            expected = ".".join(map(str, test_py_ver))
            got = ".".join(map(str, py_ver))
            raise unittest.SkipTest(
                f"Test requires Python {expected} but got Python {got}"
            )

        super().setUpClass()
        cls._stack = contextlib.ExitStack()  # type: ignore[attr-defined]
        cls._stack.enter_context(  # type: ignore[attr-defined]
            config.patch(
                enable_trace_unittest=True,
            ),
        )
|