mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Requested by Simon on a different PR. Pull Request resolved: https://github.com/pytorch/pytorch/pull/156128. Approved by: https://github.com/xmfan
49 lines
1.3 KiB
Python
49 lines
1.3 KiB
Python
import contextlib
|
|
import os
|
|
from typing import Union
|
|
|
|
from torch._dynamo.test_case import (
|
|
run_tests as dynamo_run_tests,
|
|
TestCase as DynamoTestCase,
|
|
)
|
|
from torch._functorch import config as functorch_config
|
|
from torch._inductor import config
|
|
from torch._inductor.utils import fresh_cache
|
|
|
|
|
|
def run_tests(needs: Union[str, tuple[str, ...]] = ()) -> None:
    """
    Entry point for inductor test files.

    This is a thin wrapper: ``needs`` is handed to ``dynamo_run_tests``
    unchanged and nothing is returned.

    Args:
        needs: A single string or a tuple of strings, forwarded as-is to
            ``dynamo_run_tests``. Defaults to an empty tuple.
    """
    dynamo_run_tests(needs)
|
|
|
|
|
|
class TestCase(DynamoTestCase):
    """
    A base TestCase for inductor tests. Enables FX graph caching and isolates
    the cache directory for each test.

    ``setUp`` registers every patch it applies on one ``contextlib.ExitStack``
    kept in ``self._inductor_test_stack``; ``tearDown`` closes that stack so
    the patches are undone once the test has finished.
    """

    def setUp(self) -> None:
        """
        Run the parent ``setUp`` and then apply the per-test patches.

        Patches applied, in order:

        * functorch config ``enable_autograd_cache`` is set to ``True``;
        * inductor config ``fx_graph_cache`` is set to ``True``, unless the
          ``TORCHINDUCTOR_FX_GRAPH_CACHE`` environment variable is present
          (whatever its value);
        * ``fresh_cache()`` is entered, unless ``INDUCTOR_TEST_DISABLE_FRESH_CACHE``
          or ``TORCH_COMPILE_DEBUG`` is set to ``"1"``.

        If applying any patch raises, the patches already applied are rolled
        back before the exception propagates.
        """
        super().setUp()
        self._inductor_test_stack = contextlib.ExitStack()
        try:
            self._inductor_test_stack.enter_context(
                functorch_config.patch(
                    {
                        "enable_autograd_cache": True,
                    }
                )
            )

            # Respect an explicit choice made through the environment.
            if "TORCHINDUCTOR_FX_GRAPH_CACHE" not in os.environ:
                self._inductor_test_stack.enter_context(
                    config.patch({"fx_graph_cache": True})
                )

            if (
                os.environ.get("INDUCTOR_TEST_DISABLE_FRESH_CACHE") != "1"
                and os.environ.get("TORCH_COMPILE_DEBUG") != "1"
            ):
                self._inductor_test_stack.enter_context(fresh_cache())
        except BaseException:
            # unittest does not call tearDown() when setUp() raises, so undo
            # whatever was already entered instead of leaking it into the
            # tests that run afterwards.
            self._inductor_test_stack.close()
            raise

    def tearDown(self) -> None:
        """
        Run the parent ``tearDown`` and then undo the patches from ``setUp``.

        The exit stack is closed even when the parent ``tearDown`` raises, so
        a failing teardown cannot leave the patches active.
        """
        try:
            super().tearDown()
        finally:
            self._inductor_test_stack.close()
|