Add multi-cache autotune test (#133868)

Summary:
The existing tests didn't cover a case where we had multiple autotunes in a single graph.  Add a test to demonstrate that case.

Also added a test dependency on redis and removed the "fake redis" from the previous PR (#133579)

Test Plan: unit tests

Reviewed By: oulgen

Differential Revision: D61178861

Pull Request resolved: https://github.com/pytorch/pytorch/pull/133868
Approved by: https://github.com/oulgen
This commit is contained in:
Aaron Orenstein
2024-08-20 10:26:45 +00:00
committed by PyTorch MergeBot
parent 11af423eca
commit 241df7e7f8
4 changed files with 62 additions and 17 deletions

View File

@ -269,6 +269,10 @@ lintrunner==0.12.5
#Pinned versions: 0.12.5
#test that import:
redis>=4.0.0
#Description: redis database
#test that import: anything that tests OSS caching/mocking (inductor/test_codecache.py, inductor/test_max_autotune.py)
rockset==1.0.3
#Description: queries Rockset
#Pinned versions: 1.0.3

View File

@ -17,7 +17,6 @@ jinja2
fsspec
lintrunner
ninja
redis
# setuptools was removed from default python install
setuptools ; python_version >= "3.12"
packaging

View File

@ -132,12 +132,6 @@ _CACHE_CONFIG_EN = (
)
def _has_redis():
import importlib
return importlib.util.find_spec("redis") is not None
class PatchCaches(contextlib.AbstractContextManager):
num_init = 0
num_put = 0
@ -187,15 +181,6 @@ class PatchCaches(contextlib.AbstractContextManager):
@classmethod
def setUp(cls):
# If we don't have redis available then fake it since we'll be mocking it anyway.
if not _has_redis():
class FakeRedisModule:
class Redis:
pass
sys.modules["redis"] = FakeRedisModule()
# If this test is using PatchCaches then disable all the caches by
# default, letting the tests turn them on explicitly. This is because
# tests using PatchCaches will often want to check stats explicitly.

View File

@ -41,7 +41,7 @@ from torch.utils._triton import has_triton
try:
from .mock_cache import PatchCaches
from .mock_cache import patch_fbcode, PatchCaches
except ImportError:
from mock_cache import PatchCaches # @manual
@ -776,6 +776,63 @@ class TestFxGraphCacheHashing(TestCase):
assert "-DNDEBUG" in cmd_parts, cmd_parts
@instantiate_parametrized_tests
class TestAutotuneCache(TestCase):
device_type = GPU_TYPE
def setUp(self):
super().setUp()
counters.clear()
PatchCaches.setUp()
def tearDown(self):
super().tearDown()
PatchCaches.tearDown()
def reset(self):
torch._dynamo.reset()
clear_inductor_caches()
@config.patch({"fx_graph_cache": False})
@config.patch({"fx_graph_remote_cache": False})
@config.patch({"autotune_local_cache": False})
@config.patch({"autotune_remote_cache": True})
@config.patch({"max_autotune": True})
@parametrize("fbcode", (False,) + (True,) * config.is_fbcode())
def test_autotune_cache(self, fbcode: bool):
if not fbcode:
self.skipTest("Redis for autotune is currently broken")
class Model(torch.nn.Module):
def forward(self, x, y, a, b):
return x + y, a + b
def f(x, y, a, b):
return Model()(x, y, a, b)
x = torch.randn(100, 100).cuda()
y = torch.randn(100, 100).cuda()
a = torch.randn(1000, 100).cuda()
b = torch.randn(1000, 100).cuda()
f_compiled = torch.compile(f, fullgraph=True)
with PatchCaches(), patch_fbcode(fbcode):
f_compiled(x, y, a, b)
PatchCaches.update()
self.assertEqual(PatchCaches.num_get_hit, 0)
self.assertEqual(PatchCaches.num_get_miss, 2)
self.assertEqual(PatchCaches.num_put, 2)
self.reset()
f_compiled(x, y, a, b)
PatchCaches.report()
self.assertEqual(PatchCaches.num_get_hit, 2)
self.assertEqual(PatchCaches.num_get_miss, 2)
self.assertEqual(PatchCaches.num_put, 2)
class TestUtils(TestCase):
@config.patch({"fx_graph_remote_cache": False})
def test_fresh_inductor_cache(self):