mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
PR to enable default workflow PyTorch 2.0 unit tests for the ROCm stack. - Enables all the dynamo unit test suites - Enables some of the inductor unit test suites - `test_config` - `test_cpp_wrapper` (cpu only) - `test_minifier` - `test_standalone_compile` - `test_torchinductor_dynamic_shapes` - `test_torchinductor_opinfo` - `test_torchinductor` - `test_triton_wrapper` - Introduces TEST_WITH_ROCM conditions for unit test skip/fail dictionaries in test_torchinductor_dynamic_shapes.py and test_torchinductor_opinfo.py Note this PR follows on from the discussions for the previous UT enablement PR https://github.com/pytorch/pytorch/pull/97988, we have opted to only enable a few inductor suites at the moment to ease the upstreaming effort as these files are changing very quickly. Pull Request resolved: https://github.com/pytorch/pytorch/pull/100981 Approved by: https://github.com/jithunnair-amd, https://github.com/malfet
66 lines
1.5 KiB
Python
66 lines
1.5 KiB
Python
import contextlib
|
|
import importlib
|
|
import sys
|
|
|
|
import torch
|
|
import torch.testing
|
|
from torch.testing._internal.common_utils import (
|
|
IS_WINDOWS,
|
|
TEST_WITH_CROSSREF,
|
|
TEST_WITH_TORCHDYNAMO,
|
|
TestCase as TorchTestCase,
|
|
)
|
|
|
|
from . import config, reset, utils
|
|
|
|
|
|
def run_tests(needs=()):
|
|
from torch.testing._internal.common_utils import run_tests
|
|
|
|
if (
|
|
TEST_WITH_TORCHDYNAMO
|
|
or IS_WINDOWS
|
|
or TEST_WITH_CROSSREF
|
|
or sys.version_info >= (3, 12)
|
|
):
|
|
return # skip testing
|
|
|
|
if isinstance(needs, str):
|
|
needs = (needs,)
|
|
for need in needs:
|
|
if need == "cuda" and not torch.cuda.is_available():
|
|
return
|
|
else:
|
|
try:
|
|
importlib.import_module(need)
|
|
except ImportError:
|
|
return
|
|
run_tests()
|
|
|
|
|
|
class TestCase(TorchTestCase):
    """Base test case that isolates tests from this package's global state.

    For the lifetime of each test class, enters
    ``config.patch(raise_on_ctx_manager_usage=True, suppress_errors=False)``.
    Before each test it calls ``reset()`` and clears ``utils.counters``. After
    each test it prints every counter (``name, counter.most_common()``), then
    calls ``reset()`` and clears the counters again.
    """

    # Holds the class-level config patch: created in setUpClass, closed in
    # tearDownClass.
    _exit_stack: contextlib.ExitStack

    @classmethod
    def tearDownClass(cls):
        try:
            cls._exit_stack.close()
        finally:
            # Run the base-class cleanup even if undoing the config patch
            # raises, so state does not leak into subsequent test classes.
            super().tearDownClass()

    @classmethod
    def setUpClass(cls):
        super().setUpClass()
        cls._exit_stack = contextlib.ExitStack()
        cls._exit_stack.enter_context(
            config.patch(raise_on_ctx_manager_usage=True, suppress_errors=False),
        )

    def setUp(self):
        super().setUp()
        reset()
        utils.counters.clear()

    def tearDown(self):
        try:
            # Dump whatever counters the test accumulated, for debugging.
            for name, counter in utils.counters.items():
                print(name, counter.most_common())
            reset()
            utils.counters.clear()
        finally:
            # Always run the base-class teardown, even if the reset above fails.
            super().tearDown()
|