Add multi-cache autotune test (#133868)
Summary: The existing tests didn't cover a case where we had multiple autotunes in a single graph. Add a test to demonstrate that case. Also added a test dependency on redis and removed the "fake redis" from the previous PR (#133579).

Test Plan: unit tests

Reviewed By: oulgen

Differential Revision: D61178861

Pull Request resolved: https://github.com/pytorch/pytorch/pull/133868
Approved by: https://github.com/oulgen
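For readers skimming the diff, here is a minimal, self-contained sketch (not part of this commit) of the scenario the new test covers: one compiled graph containing two differently shaped pointwise ops, so the max-autotune path tunes two kernels whose tuning results are cached separately. The config keys mirror the decorators in the test added below; the standalone driver and the assumption of a CUDA device are illustrative.

# Minimal sketch (assumes CUDA and Triton are available; this is not the test itself).
# Two additions with different shapes compile into one graph, so Inductor's
# max-autotune path tunes two kernels, and each tuning result is cached
# separately -- which is what the multi-cache assertions below count.
import torch
import torch._inductor.config as inductor_config


def f(x, y, a, b):
    # Independent shapes -> two autotuned kernels in a single compiled graph.
    return x + y, a + b


def main():
    x = torch.randn(100, 100, device="cuda")
    y = torch.randn(100, 100, device="cuda")
    a = torch.randn(1000, 100, device="cuda")
    b = torch.randn(1000, 100, device="cuda")

    with inductor_config.patch(
        {
            "fx_graph_cache": False,        # isolate autotune caching from FX graph caching
            "fx_graph_remote_cache": False,
            "autotune_local_cache": False,
            "autotune_remote_cache": True,  # the cache exercised by the new test
            "max_autotune": True,
        }
    ):
        compiled = torch.compile(f, fullgraph=True)
        compiled(x, y, a, b)   # first run: expect two cache misses and two puts
        torch._dynamo.reset()  # drop compiled artifacts, keep the remote cache
        compiled(x, y, a, b)   # recompile: expect two cache hits


if __name__ == "__main__":
    main()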
Committed by: PyTorch MergeBot
Parent: 11af423eca
Commit: 241df7e7f8
@@ -269,6 +269,10 @@ lintrunner==0.12.5
 #Pinned versions: 0.12.5
 #test that import:

+redis>=4.0.0
+#Description: redis database
+#test that import: anything that tests OSS caching/mocking (inductor/test_codecache.py, inductor/test_max_autotune.py)
+
 rockset==1.0.3
 #Description: queries Rockset
 #Pinned versions: 1.0.3
@@ -17,7 +17,6 @@ jinja2
 fsspec
 lintrunner
 ninja
-redis
 # setuptools was removed from default python install
 setuptools ; python_version >= "3.12"
 packaging
@@ -132,12 +132,6 @@ _CACHE_CONFIG_EN = (
 )


-def _has_redis():
-    import importlib
-
-    return importlib.util.find_spec("redis") is not None
-
-
 class PatchCaches(contextlib.AbstractContextManager):
     num_init = 0
     num_put = 0
@@ -187,15 +181,6 @@ class PatchCaches(contextlib.AbstractContextManager):

     @classmethod
     def setUp(cls):
-        # If we don't have redis available then fake it since we'll be mocking it anyway.
-        if not _has_redis():
-
-            class FakeRedisModule:
-                class Redis:
-                    pass
-
-            sys.modules["redis"] = FakeRedisModule()
-
         # If this test is using PatchCaches then disable all the caches by
         # default, letting the tests turn them on explicitly. This is because
         # tests using PatchCaches will often want to check stats explicitly.
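The removed _has_redis helper and FakeRedisModule shim existed only so these tests could stand in for redis when it wasn't installed; with redis pinned in requirements-ci.txt above, that fallback is no longer needed. For reference, a hedged, standalone sketch of the underlying probe-and-skip pattern (the test class, test name, and assertion here are illustrative, not from this commit):

# Optional-dependency probe via importlib -- the same check _has_redis performed.
import importlib.util
import unittest

HAS_REDIS = importlib.util.find_spec("redis") is not None


class RedisDependentTest(unittest.TestCase):
    @unittest.skipUnless(HAS_REDIS, "requires the redis package")
    def test_redis_client_importable(self):
        import redis  # real client; no FakeRedisModule stand-in needed

        self.assertTrue(hasattr(redis, "Redis"))


if __name__ == "__main__":
    unittest.main()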
@@ -41,7 +41,7 @@ from torch.utils._triton import has_triton


 try:
-    from .mock_cache import PatchCaches
+    from .mock_cache import patch_fbcode, PatchCaches
 except ImportError:
     from mock_cache import PatchCaches  # @manual

@@ -776,6 +776,63 @@ class TestFxGraphCacheHashing(TestCase):
         assert "-DNDEBUG" in cmd_parts, cmd_parts


+@instantiate_parametrized_tests
+class TestAutotuneCache(TestCase):
+    device_type = GPU_TYPE
+
+    def setUp(self):
+        super().setUp()
+        counters.clear()
+        PatchCaches.setUp()
+
+    def tearDown(self):
+        super().tearDown()
+        PatchCaches.tearDown()
+
+    def reset(self):
+        torch._dynamo.reset()
+        clear_inductor_caches()
+
+    @config.patch({"fx_graph_cache": False})
+    @config.patch({"fx_graph_remote_cache": False})
+    @config.patch({"autotune_local_cache": False})
+    @config.patch({"autotune_remote_cache": True})
+    @config.patch({"max_autotune": True})
+    @parametrize("fbcode", (False,) + (True,) * config.is_fbcode())
+    def test_autotune_cache(self, fbcode: bool):
+        if not fbcode:
+            self.skipTest("Redis for autotune is currently broken")
+
+        class Model(torch.nn.Module):
+            def forward(self, x, y, a, b):
+                return x + y, a + b
+
+        def f(x, y, a, b):
+            return Model()(x, y, a, b)
+
+        x = torch.randn(100, 100).cuda()
+        y = torch.randn(100, 100).cuda()
+        a = torch.randn(1000, 100).cuda()
+        b = torch.randn(1000, 100).cuda()
+        f_compiled = torch.compile(f, fullgraph=True)
+
+        with PatchCaches(), patch_fbcode(fbcode):
+            f_compiled(x, y, a, b)
+
+            PatchCaches.update()
+            self.assertEqual(PatchCaches.num_get_hit, 0)
+            self.assertEqual(PatchCaches.num_get_miss, 2)
+            self.assertEqual(PatchCaches.num_put, 2)
+
+            self.reset()
+            f_compiled(x, y, a, b)
+
+            PatchCaches.report()
+            self.assertEqual(PatchCaches.num_get_hit, 2)
+            self.assertEqual(PatchCaches.num_get_miss, 2)
+            self.assertEqual(PatchCaches.num_put, 2)
+
+
 class TestUtils(TestCase):
     @config.patch({"fx_graph_remote_cache": False})
     def test_fresh_inductor_cache(self):