From b54e466fd04e5e736662a6206d81ab0d5fe85d91 Mon Sep 17 00:00:00 2001 From: James Wu Date: Wed, 15 Oct 2025 10:42:44 -0700 Subject: [PATCH] Megacache integration (#163533) This diff adds megacache integration for DynamoCache. Because DynamoCache requires lazy serialization, i.e. it can only be serialized once all relevant backends have been compiled and we're ready for a save, we actually do the DynamoCache saving only on a call to `torch.compiler.save_cache_artifacts`. Differential Revision: [D82735763](https://our.internmc.facebook.com/intern/diff/D82735763/) Pull Request resolved: https://github.com/pytorch/pytorch/pull/163533 Approved by: https://github.com/oulgen, https://github.com/zhxchen17 --- test/inductor/test_codecache.py | 111 ++++++++++++++++++++++++++++++++ torch/_dynamo/package.py | 51 ++++++++++++--- torch/compiler/__init__.py | 7 +- torch/compiler/_cache.py | 5 ++ 4 files changed, 165 insertions(+), 9 deletions(-) diff --git a/test/inductor/test_codecache.py b/test/inductor/test_codecache.py index 09570b98a2fb..78c2dd3de852 100644 --- a/test/inductor/test_codecache.py +++ b/test/inductor/test_codecache.py @@ -16,6 +16,8 @@ from unittest import mock import torch from torch._dynamo import reset +from torch._dynamo.package import DynamoCache +from torch._dynamo.precompile_context import PrecompileContext from torch._dynamo.utils import counters from torch._functorch import config as functorch_config from torch._functorch._aot_autograd.autograd_cache import AOTAutogradCache @@ -243,8 +245,12 @@ class TestFxGraphCache(TestCase): def setUp(self): super().setUp() counters.clear() + DynamoCache.clear() + PrecompileContext.clear() + AOTAutogradCache.clear() PatchCaches.setUp() CacheArtifactManager.clear() + torch._dynamo.reset() def tearDown(self): super().tearDown() @@ -252,6 +258,8 @@ class TestFxGraphCache(TestCase): def reset(self): AOTAutogradCache.clear() + DynamoCache.clear() + PrecompileContext.clear() PyCodeCache.cache_clear(purge=True) torch._dynamo.reset() clear_caches() @@ -595,6 +603,109 @@ class TestFxGraphCache(TestCase): self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 1) self.assertEqual(counters["inductor"]["fxgraph_lookup_write_file"], 1) + @requires_triton() + @config.patch( + { + "fx_graph_cache": True, + "fx_graph_remote_cache": False, + "autotune_local_cache": True, + } + ) + @torch._dynamo.config.patch( + { + "caching_precompile": True, + } + ) + @parametrize("dynamic", (False, True)) + @parametrize("device", (GPU_TYPE, "cpu")) + @parametrize("dtype", (torch.float32, torch.bfloat16)) + def test_cache_hot_load_caching_precompile(self, device, dtype, dynamic): + """ + Verify that we can populate and hot load functions from the cache. + """ + + if device == GPU_TYPE and not HAS_GPU: + raise unittest.SkipTest(f"requires {GPU_TYPE}") + if device == "cuda" and dtype == torch.bfloat16 and not SM80OrLater: + raise unittest.SkipTest("requires SM80 or later") + + def fn(x, y): + return x.sin() @ y + + a = torch.rand(100, 100, dtype=dtype, device=device, requires_grad=True) + b = torch.rand(100, 100, dtype=dtype, device=device, requires_grad=True) + + # Record artifacts + with fresh_cache(): + compiled_fn = torch.compile(fn, dynamic=dynamic) + + # A first call should miss in the cache. + eager_result = fn(a, b) + compiled_result = compiled_fn(a, b) + compiled_result.sum().backward() + self.assertEqual(eager_result, compiled_result) + self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], 1) + self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], 0) + self.assertEqual(counters["dynamo_cache"]["dynamo_cache_miss"], 1) + self.assertEqual(counters["dynamo_cache"]["dynamo_cache_hit"], 0) + + artifacts = torch.compiler.save_cache_artifacts() + + self.assertIsNotNone(artifacts) + + artifact_bytes, cache_info = artifacts + + autotune_expect = 2 if device == GPU_TYPE else 0 + self.assertEqual(len(cache_info.inductor_artifacts), 2) + self.assertEqual(len(cache_info.autotune_artifacts), autotune_expect) + self.assertEqual(len(cache_info.aot_autograd_artifacts), 1) + self.assertEqual(len(cache_info.pgo_artifacts), 0) + self.assertEqual(len(cache_info.precompile_artifacts), 1) + + self.reset() + + # Clean triton kernels + shutil.rmtree(os.path.join(cache_dir(), "triton"), ignore_errors=True) + + # We did not load anything so dont hit yet + with fresh_cache(): + eager_result = fn(a, b) + # With caching precompile, we have to re torch.compile the function + # to trigger cache lookup + compiled_fn = torch.compile(fn, dynamic=dynamic) + compiled_result = compiled_fn(a, b) + compiled_result.sum().backward() + self.assertEqual(eager_result, compiled_result) + self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], 2) + self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], 0) + self.assertEqual(counters["dynamo_cache"]["dynamo_cache_miss"], 2) + self.assertEqual(counters["dynamo_cache"]["dynamo_cache_hit"], 0) + self.reset() + # Clean triton kernels + shutil.rmtree(os.path.join(cache_dir(), "triton"), ignore_errors=True) + + # Hot load and hit + with fresh_cache(), torch.compiler.set_stance("fail_on_recompile"): + cache_info = torch.compiler.load_cache_artifacts(artifact_bytes) + self.assertEqual(len(cache_info.inductor_artifacts), 2) + self.assertEqual(len(cache_info.autotune_artifacts), autotune_expect) + self.assertEqual(len(cache_info.aot_autograd_artifacts), 1) + self.assertEqual(len(cache_info.pgo_artifacts), 0) + self.assertEqual(len(cache_info.precompile_artifacts), 1) + + # With caching precompile, we have to re torch.compile the function + # to trigger cache lookup + compiled_fn = torch.compile(fn, dynamic=dynamic) + + eager_result = fn(a, b) + compiled_result = compiled_fn(a, b) + compiled_result.sum().backward() + self.assertEqual(eager_result, compiled_result) + self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], 2) + self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], 0) + self.assertEqual(counters["dynamo_cache"]["dynamo_cache_miss"], 2) + self.assertEqual(counters["dynamo_cache"]["dynamo_cache_hit"], 1) + @config.patch( { "fx_graph_cache": True, diff --git a/torch/_dynamo/package.py b/torch/_dynamo/package.py index 6acc89fffac9..9c5dec0a98f9 100644 --- a/torch/_dynamo/package.py +++ b/torch/_dynamo/package.py @@ -34,7 +34,7 @@ from torch._dynamo.exc import PackageError from torch._dynamo.graph_utils import _graph_device_type from .bytecode_transformation import get_code_keys -from .utils import dynamo_timed, increment_frame +from .utils import counters, dynamo_timed, increment_frame logger = logging.getLogger(__name__) @@ -433,6 +433,23 @@ class _DynamoCacheEntry: } +from torch.compiler._cache import ( + CacheArtifact, + CacheArtifactFactory, + CacheArtifactManager, +) + + +@CacheArtifactFactory.register +class PrecompileCacheArtifact(CacheArtifact): + def populate_cache(self) -> None: + DynamoCache._write_to_local_cache(self.content, self.key) + + @staticmethod + def type() -> str: + return "precompile" + + @dataclasses.dataclass class PrecompileCacheEntry: """ @@ -1026,14 +1043,17 @@ class DiskDynamoStore(DynamoStore): Args: path_prefix: Prefix directory for where to put CompilePackages on disk """ - self.path_prefix = path_prefix + self._path_prefix = path_prefix + + def path_prefix(self) -> str: + return self._path_prefix def clear(self) -> None: """ Clear all CompilePackages from disk. """ - if self.path_prefix: - shutil.rmtree(self.path_prefix, ignore_errors=True) + if self.path_prefix(): + shutil.rmtree(self.path_prefix(), ignore_errors=True) def write( self, @@ -1043,12 +1063,21 @@ class DiskDynamoStore(DynamoStore): """ Write dynamo cache entry and backends to disk. """ + try: + pickled_content: bytes = pickle.dumps(entry) + CacheArtifactManager.record_artifact( + PrecompileCacheArtifact.type(), path, pickled_content + ) + self._write_to_local_cache(pickled_content, path) + except Exception as e: + raise RuntimeError(f"Failed to save package to {path}: {e}") from e + + def _write_to_local_cache(self, pickled_content: bytes, path: str) -> None: from torch._inductor.codecache import write_atomic - path = os.path.join(self.path_prefix, path) if self.path_prefix else path + path = os.path.join(self.path_prefix(), path) if self.path_prefix() else path try: os.makedirs(path, exist_ok=True) - pickled_content: bytes = pickle.dumps(entry) write_atomic(os.path.join(path, "entry"), pickled_content) except Exception as e: raise RuntimeError(f"Failed to save package to {path}: {e}") from e @@ -1057,7 +1086,7 @@ class DiskDynamoStore(DynamoStore): """ Read dynamo cache entry and backends from disk. """ - path = os.path.join(self.path_prefix, path) if self.path_prefix else path + path = os.path.join(self.path_prefix(), path) if self.path_prefix() else path try: with open(os.path.join(path, "entry"), "rb") as f: pickled_content = f.read() @@ -1087,15 +1116,18 @@ class DiskDynamoCache(DiskDynamoStore): """ key = CompilePackage.source_id_from_fn(fn) logger.info("Loading CompilePackage for %s", key) - path = os.path.join(self.path_prefix, key) + path = os.path.join(self.path_prefix(), key) if os.path.exists(path): try: result = super().load_cache_entry(key) + counters["dynamo_cache"]["dynamo_cache_hit"] += 1 return result except Exception as e: + counters["dynamo_cache"]["dynamo_cache_error"] += 1 logger.warning("Failed to load package from path %s: %s", path, str(e)) return None logger.info("No package found for %s", key) + counters["dynamo_cache"]["dynamo_cache_miss"] += 1 return None def load_and_install_package( @@ -1112,6 +1144,9 @@ class DiskDynamoCache(DiskDynamoStore): package.install(results.backends) return package + def path_prefix(self) -> str: + return os.path.join(cache_dir(), "dynamo") + def cache_dir() -> str: from torch._inductor.runtime.cache_dir_utils import cache_dir diff --git a/torch/compiler/__init__.py b/torch/compiler/__init__.py index 30881e06ff14..52d2645c4b71 100644 --- a/torch/compiler/__init__.py +++ b/torch/compiler/__init__.py @@ -501,7 +501,12 @@ def save_cache_artifacts() -> Optional[tuple[bytes, "CacheInfo"]]: - Execute torch.compile - Call torch.compiler.save_cache_artifacts() """ - from ._cache import CacheArtifactManager, CacheInfo + from ._cache import CacheArtifactManager + + if torch._dynamo.config.caching_precompile: + from torch._dynamo.precompile_context import PrecompileContext + + PrecompileContext.save_to_dynamo_cache() return CacheArtifactManager.serialize() diff --git a/torch/compiler/_cache.py b/torch/compiler/_cache.py index 8f978dd5690b..77cfb77d74df 100644 --- a/torch/compiler/_cache.py +++ b/torch/compiler/_cache.py @@ -130,6 +130,10 @@ class CacheInfo: def pgo_artifacts(self) -> list[str]: # type: ignore[empty-body] ... + @property + def precompile_artifacts(self) -> list[str]: # type: ignore[empty-body] + ... + def add(self, artifact: CacheArtifact) -> None: self.artifacts[artifact.type()].append(artifact.key) @@ -307,6 +311,7 @@ class CacheArtifactManager: cache artifacts are registered in the cache registry. This is done by simply importing all the cache artifacts already wrapped with register call. """ + from torch._dynamo.package import PrecompileCacheArtifact # noqa: F401 from torch._dynamo.pgo import PGOCacheArtifact # noqa: F401 from torch._functorch._aot_autograd.autograd_cache import ( # noqa: F401 AOTAutogradCacheArtifact,