Megacache integration (#163533)

This diff adds megacache integration for DynamoCache.

Because DynamoCache requires lazy serialization (it can only be serialized once all relevant backends have been compiled and we are ready to save), the actual DynamoCache save happens only on a call to `torch.compiler.save_cache_artifacts`.
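
For illustration, a minimal save/load round trip under `caching_precompile` (a sketch: the function and shapes here are made up, while the config flag and the `save_cache_artifacts`/`load_cache_artifacts` calls are the ones exercised by the new test):

```python
import torch

# Assumption for this sketch: caching precompile must be enabled for
# DynamoCache to participate in megacache.
torch._dynamo.config.caching_precompile = True

def fn(x, y):
    return x.sin() @ y

compiled_fn = torch.compile(fn)
compiled_fn(torch.rand(100, 100), torch.rand(100, 100))

# DynamoCache is serialized here, once all relevant backends are compiled.
artifacts = torch.compiler.save_cache_artifacts()
assert artifacts is not None
artifact_bytes, cache_info = artifacts

# Later, e.g. in a fresh process: hot load the artifacts, then call
# torch.compile again on the function to trigger the cache lookup.
torch.compiler.load_cache_artifacts(artifact_bytes)
compiled_fn = torch.compile(fn)
compiled_fn(torch.rand(100, 100), torch.rand(100, 100))  # should now hit
```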

Differential Revision: [D82735763](https://our.internmc.facebook.com/intern/diff/D82735763/)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163533
Approved by: https://github.com/oulgen, https://github.com/zhxchen17
James Wu
2025-10-15 10:42:44 -07:00
committed by PyTorch MergeBot
parent 53f9ae0e50
commit b54e466fd0
4 changed files with 165 additions and 9 deletions

@@ -16,6 +16,8 @@ from unittest import mock
import torch
from torch._dynamo import reset
from torch._dynamo.package import DynamoCache
from torch._dynamo.precompile_context import PrecompileContext
from torch._dynamo.utils import counters
from torch._functorch import config as functorch_config
from torch._functorch._aot_autograd.autograd_cache import AOTAutogradCache
@@ -243,8 +245,12 @@ class TestFxGraphCache(TestCase):
    def setUp(self):
        super().setUp()
        counters.clear()
        DynamoCache.clear()
        PrecompileContext.clear()
        AOTAutogradCache.clear()
        PatchCaches.setUp()
        CacheArtifactManager.clear()
        torch._dynamo.reset()

    def tearDown(self):
        super().tearDown()
@@ -252,6 +258,8 @@ class TestFxGraphCache(TestCase):
    def reset(self):
        AOTAutogradCache.clear()
        DynamoCache.clear()
        PrecompileContext.clear()
        PyCodeCache.cache_clear(purge=True)
        torch._dynamo.reset()
        clear_caches()
@@ -595,6 +603,109 @@ class TestFxGraphCache(TestCase):
        self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 1)
        self.assertEqual(counters["inductor"]["fxgraph_lookup_write_file"], 1)

    @requires_triton()
    @config.patch(
        {
            "fx_graph_cache": True,
            "fx_graph_remote_cache": False,
            "autotune_local_cache": True,
        }
    )
    @torch._dynamo.config.patch(
        {
            "caching_precompile": True,
        }
    )
    @parametrize("dynamic", (False, True))
    @parametrize("device", (GPU_TYPE, "cpu"))
    @parametrize("dtype", (torch.float32, torch.bfloat16))
    def test_cache_hot_load_caching_precompile(self, device, dtype, dynamic):
        """
        Verify that we can populate and hot load functions from the cache.
        """
        if device == GPU_TYPE and not HAS_GPU:
            raise unittest.SkipTest(f"requires {GPU_TYPE}")
        if device == "cuda" and dtype == torch.bfloat16 and not SM80OrLater:
            raise unittest.SkipTest("requires SM80 or later")

        def fn(x, y):
            return x.sin() @ y

        a = torch.rand(100, 100, dtype=dtype, device=device, requires_grad=True)
        b = torch.rand(100, 100, dtype=dtype, device=device, requires_grad=True)

        # Record artifacts
        with fresh_cache():
            compiled_fn = torch.compile(fn, dynamic=dynamic)

            # A first call should miss in the cache.
            eager_result = fn(a, b)
            compiled_result = compiled_fn(a, b)
            compiled_result.sum().backward()
            self.assertEqual(eager_result, compiled_result)
            self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], 1)
            self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], 0)
            self.assertEqual(counters["dynamo_cache"]["dynamo_cache_miss"], 1)
            self.assertEqual(counters["dynamo_cache"]["dynamo_cache_hit"], 0)

            artifacts = torch.compiler.save_cache_artifacts()
            self.assertIsNotNone(artifacts)

            artifact_bytes, cache_info = artifacts
            autotune_expect = 2 if device == GPU_TYPE else 0
            self.assertEqual(len(cache_info.inductor_artifacts), 2)
            self.assertEqual(len(cache_info.autotune_artifacts), autotune_expect)
            self.assertEqual(len(cache_info.aot_autograd_artifacts), 1)
            self.assertEqual(len(cache_info.pgo_artifacts), 0)
            self.assertEqual(len(cache_info.precompile_artifacts), 1)

        self.reset()
        # Clean triton kernels
        shutil.rmtree(os.path.join(cache_dir(), "triton"), ignore_errors=True)

        # We did not load anything, so don't hit yet
        with fresh_cache():
            eager_result = fn(a, b)
            # With caching precompile, we have to re-torch.compile the function
            # to trigger cache lookup
            compiled_fn = torch.compile(fn, dynamic=dynamic)
            compiled_result = compiled_fn(a, b)
            compiled_result.sum().backward()
            self.assertEqual(eager_result, compiled_result)
            self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], 2)
            self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], 0)
            self.assertEqual(counters["dynamo_cache"]["dynamo_cache_miss"], 2)
            self.assertEqual(counters["dynamo_cache"]["dynamo_cache_hit"], 0)

        self.reset()
        # Clean triton kernels
        shutil.rmtree(os.path.join(cache_dir(), "triton"), ignore_errors=True)

        # Hot load and hit
        with fresh_cache(), torch.compiler.set_stance("fail_on_recompile"):
            cache_info = torch.compiler.load_cache_artifacts(artifact_bytes)

            self.assertEqual(len(cache_info.inductor_artifacts), 2)
            self.assertEqual(len(cache_info.autotune_artifacts), autotune_expect)
            self.assertEqual(len(cache_info.aot_autograd_artifacts), 1)
            self.assertEqual(len(cache_info.pgo_artifacts), 0)
            self.assertEqual(len(cache_info.precompile_artifacts), 1)

            # With caching precompile, we have to re-torch.compile the function
            # to trigger cache lookup
            compiled_fn = torch.compile(fn, dynamic=dynamic)
            eager_result = fn(a, b)
            compiled_result = compiled_fn(a, b)
            compiled_result.sum().backward()
            self.assertEqual(eager_result, compiled_result)
            self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], 2)
            self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], 0)
            self.assertEqual(counters["dynamo_cache"]["dynamo_cache_miss"], 2)
            self.assertEqual(counters["dynamo_cache"]["dynamo_cache_hit"], 1)

    @config.patch(
        {
            "fx_graph_cache": True,

@@ -34,7 +34,7 @@ from torch._dynamo.exc import PackageError
from torch._dynamo.graph_utils import _graph_device_type

from .bytecode_transformation import get_code_keys
-from .utils import dynamo_timed, increment_frame
+from .utils import counters, dynamo_timed, increment_frame

logger = logging.getLogger(__name__)
@@ -433,6 +433,23 @@ class _DynamoCacheEntry:
        }


+from torch.compiler._cache import (
+    CacheArtifact,
+    CacheArtifactFactory,
+    CacheArtifactManager,
+)
+
+
+@CacheArtifactFactory.register
+class PrecompileCacheArtifact(CacheArtifact):
+    def populate_cache(self) -> None:
+        DynamoCache._write_to_local_cache(self.content, self.key)
+
+    @staticmethod
+    def type() -> str:
+        return "precompile"
+
+
@dataclasses.dataclass
class PrecompileCacheEntry:
    """
@@ -1026,14 +1043,17 @@ class DiskDynamoStore(DynamoStore):
        Args:
            path_prefix: Prefix directory for where to put CompilePackages on disk
        """
-        self.path_prefix = path_prefix
+        self._path_prefix = path_prefix
+
+    def path_prefix(self) -> str:
+        return self._path_prefix

    def clear(self) -> None:
        """
        Clear all CompilePackages from disk.
        """
-        if self.path_prefix:
-            shutil.rmtree(self.path_prefix, ignore_errors=True)
+        if self.path_prefix():
+            shutil.rmtree(self.path_prefix(), ignore_errors=True)
    def write(
        self,
@@ -1043,12 +1063,21 @@ class DiskDynamoStore(DynamoStore):
        """
        Write dynamo cache entry and backends to disk.
        """
        try:
            pickled_content: bytes = pickle.dumps(entry)
            CacheArtifactManager.record_artifact(
                PrecompileCacheArtifact.type(), path, pickled_content
            )
            self._write_to_local_cache(pickled_content, path)
        except Exception as e:
            raise RuntimeError(f"Failed to save package to {path}: {e}") from e

    def _write_to_local_cache(self, pickled_content: bytes, path: str) -> None:
        from torch._inductor.codecache import write_atomic

-        path = os.path.join(self.path_prefix, path) if self.path_prefix else path
+        path = os.path.join(self.path_prefix(), path) if self.path_prefix() else path
        try:
            os.makedirs(path, exist_ok=True)
-            pickled_content: bytes = pickle.dumps(entry)
            write_atomic(os.path.join(path, "entry"), pickled_content)
        except Exception as e:
            raise RuntimeError(f"Failed to save package to {path}: {e}") from e
@@ -1057,7 +1086,7 @@ class DiskDynamoStore(DynamoStore):
        """
        Read dynamo cache entry and backends from disk.
        """
-        path = os.path.join(self.path_prefix, path) if self.path_prefix else path
+        path = os.path.join(self.path_prefix(), path) if self.path_prefix() else path
        try:
            with open(os.path.join(path, "entry"), "rb") as f:
                pickled_content = f.read()
@@ -1087,15 +1116,18 @@ class DiskDynamoCache(DiskDynamoStore):
        """
        key = CompilePackage.source_id_from_fn(fn)
        logger.info("Loading CompilePackage for %s", key)
-        path = os.path.join(self.path_prefix, key)
+        path = os.path.join(self.path_prefix(), key)
        if os.path.exists(path):
            try:
                result = super().load_cache_entry(key)
+                counters["dynamo_cache"]["dynamo_cache_hit"] += 1
                return result
            except Exception as e:
+                counters["dynamo_cache"]["dynamo_cache_error"] += 1
                logger.warning("Failed to load package from path %s: %s", path, str(e))
                return None
        logger.info("No package found for %s", key)
+        counters["dynamo_cache"]["dynamo_cache_miss"] += 1
        return None
    def load_and_install_package(
@@ -1112,6 +1144,9 @@ class DiskDynamoCache(DiskDynamoStore):
        package.install(results.backends)
        return package

+    def path_prefix(self) -> str:
+        return os.path.join(cache_dir(), "dynamo")


def cache_dir() -> str:
    from torch._inductor.runtime.cache_dir_utils import cache_dir
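
(Aside: the new `dynamo_cache` counters recorded above are observable in the usual way; a small sketch, with the toy function being an assumption:)

```python
import torch
from torch._dynamo.utils import counters

# Assumed for this sketch: DynamoCache is only consulted under caching_precompile.
torch._dynamo.config.caching_precompile = True

@torch.compile
def f(x):
    return x.sin()

f(torch.rand(8))
# Cold cache: expect a dynamo_cache_miss; after a hot load, a dynamo_cache_hit.
print(dict(counters["dynamo_cache"]))
```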

@@ -501,7 +501,12 @@ def save_cache_artifacts() -> Optional[tuple[bytes, "CacheInfo"]]:
    - Execute torch.compile
    - Call torch.compiler.save_cache_artifacts()
    """
-    from ._cache import CacheArtifactManager, CacheInfo
+    from ._cache import CacheArtifactManager
+
+    if torch._dynamo.config.caching_precompile:
+        from torch._dynamo.precompile_context import PrecompileContext
+
+        PrecompileContext.save_to_dynamo_cache()

    return CacheArtifactManager.serialize()

@@ -130,6 +130,10 @@ class CacheInfo:
    def pgo_artifacts(self) -> list[str]:  # type: ignore[empty-body]
        ...

+    @property
+    def precompile_artifacts(self) -> list[str]:  # type: ignore[empty-body]
+        ...
+
    def add(self, artifact: CacheArtifact) -> None:
        self.artifacts[artifact.type()].append(artifact.key)
@@ -307,6 +311,7 @@ class CacheArtifactManager:
        cache artifacts are registered in the cache registry. This is done by
        simply importing all the cache artifacts already wrapped with register call.
        """
+        from torch._dynamo.package import PrecompileCacheArtifact  # noqa: F401
        from torch._dynamo.pgo import PGOCacheArtifact  # noqa: F401
        from torch._functorch._aot_autograd.autograd_cache import (  # noqa: F401
            AOTAutogradCacheArtifact,