mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Megacache integration (#163533)
This diff adds megacache integration for DynamoCache. Because DynamoCache requires lazy serialization, i.e. it can only be serialized once all relevant backends have been compiled and we're ready for a save, we actually do the DynamoCache saving only on a call to `torch.compiler.save_cache_artifacts`. Differential Revision: [D82735763](https://our.internmc.facebook.com/intern/diff/D82735763/) Pull Request resolved: https://github.com/pytorch/pytorch/pull/163533 Approved by: https://github.com/oulgen, https://github.com/zhxchen17
This commit is contained in:
committed by
PyTorch MergeBot
parent
53f9ae0e50
commit
b54e466fd0
@ -16,6 +16,8 @@ from unittest import mock
|
||||
|
||||
import torch
|
||||
from torch._dynamo import reset
|
||||
from torch._dynamo.package import DynamoCache
|
||||
from torch._dynamo.precompile_context import PrecompileContext
|
||||
from torch._dynamo.utils import counters
|
||||
from torch._functorch import config as functorch_config
|
||||
from torch._functorch._aot_autograd.autograd_cache import AOTAutogradCache
|
||||
@ -243,8 +245,12 @@ class TestFxGraphCache(TestCase):
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
counters.clear()
|
||||
DynamoCache.clear()
|
||||
PrecompileContext.clear()
|
||||
AOTAutogradCache.clear()
|
||||
PatchCaches.setUp()
|
||||
CacheArtifactManager.clear()
|
||||
torch._dynamo.reset()
|
||||
|
||||
def tearDown(self):
|
||||
super().tearDown()
|
||||
@ -252,6 +258,8 @@ class TestFxGraphCache(TestCase):
|
||||
|
||||
def reset(self):
|
||||
AOTAutogradCache.clear()
|
||||
DynamoCache.clear()
|
||||
PrecompileContext.clear()
|
||||
PyCodeCache.cache_clear(purge=True)
|
||||
torch._dynamo.reset()
|
||||
clear_caches()
|
||||
@ -595,6 +603,109 @@ class TestFxGraphCache(TestCase):
|
||||
self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 1)
|
||||
self.assertEqual(counters["inductor"]["fxgraph_lookup_write_file"], 1)
|
||||
|
||||
@requires_triton()
|
||||
@config.patch(
|
||||
{
|
||||
"fx_graph_cache": True,
|
||||
"fx_graph_remote_cache": False,
|
||||
"autotune_local_cache": True,
|
||||
}
|
||||
)
|
||||
@torch._dynamo.config.patch(
|
||||
{
|
||||
"caching_precompile": True,
|
||||
}
|
||||
)
|
||||
@parametrize("dynamic", (False, True))
|
||||
@parametrize("device", (GPU_TYPE, "cpu"))
|
||||
@parametrize("dtype", (torch.float32, torch.bfloat16))
|
||||
def test_cache_hot_load_caching_precompile(self, device, dtype, dynamic):
|
||||
"""
|
||||
Verify that we can populate and hot load functions from the cache.
|
||||
"""
|
||||
|
||||
if device == GPU_TYPE and not HAS_GPU:
|
||||
raise unittest.SkipTest(f"requires {GPU_TYPE}")
|
||||
if device == "cuda" and dtype == torch.bfloat16 and not SM80OrLater:
|
||||
raise unittest.SkipTest("requires SM80 or later")
|
||||
|
||||
def fn(x, y):
|
||||
return x.sin() @ y
|
||||
|
||||
a = torch.rand(100, 100, dtype=dtype, device=device, requires_grad=True)
|
||||
b = torch.rand(100, 100, dtype=dtype, device=device, requires_grad=True)
|
||||
|
||||
# Record artifacts
|
||||
with fresh_cache():
|
||||
compiled_fn = torch.compile(fn, dynamic=dynamic)
|
||||
|
||||
# A first call should miss in the cache.
|
||||
eager_result = fn(a, b)
|
||||
compiled_result = compiled_fn(a, b)
|
||||
compiled_result.sum().backward()
|
||||
self.assertEqual(eager_result, compiled_result)
|
||||
self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], 1)
|
||||
self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], 0)
|
||||
self.assertEqual(counters["dynamo_cache"]["dynamo_cache_miss"], 1)
|
||||
self.assertEqual(counters["dynamo_cache"]["dynamo_cache_hit"], 0)
|
||||
|
||||
artifacts = torch.compiler.save_cache_artifacts()
|
||||
|
||||
self.assertIsNotNone(artifacts)
|
||||
|
||||
artifact_bytes, cache_info = artifacts
|
||||
|
||||
autotune_expect = 2 if device == GPU_TYPE else 0
|
||||
self.assertEqual(len(cache_info.inductor_artifacts), 2)
|
||||
self.assertEqual(len(cache_info.autotune_artifacts), autotune_expect)
|
||||
self.assertEqual(len(cache_info.aot_autograd_artifacts), 1)
|
||||
self.assertEqual(len(cache_info.pgo_artifacts), 0)
|
||||
self.assertEqual(len(cache_info.precompile_artifacts), 1)
|
||||
|
||||
self.reset()
|
||||
|
||||
# Clean triton kernels
|
||||
shutil.rmtree(os.path.join(cache_dir(), "triton"), ignore_errors=True)
|
||||
|
||||
# We did not load anything so dont hit yet
|
||||
with fresh_cache():
|
||||
eager_result = fn(a, b)
|
||||
# With caching precompile, we have to re torch.compile the function
|
||||
# to trigger cache lookup
|
||||
compiled_fn = torch.compile(fn, dynamic=dynamic)
|
||||
compiled_result = compiled_fn(a, b)
|
||||
compiled_result.sum().backward()
|
||||
self.assertEqual(eager_result, compiled_result)
|
||||
self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], 2)
|
||||
self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], 0)
|
||||
self.assertEqual(counters["dynamo_cache"]["dynamo_cache_miss"], 2)
|
||||
self.assertEqual(counters["dynamo_cache"]["dynamo_cache_hit"], 0)
|
||||
self.reset()
|
||||
# Clean triton kernels
|
||||
shutil.rmtree(os.path.join(cache_dir(), "triton"), ignore_errors=True)
|
||||
|
||||
# Hot load and hit
|
||||
with fresh_cache(), torch.compiler.set_stance("fail_on_recompile"):
|
||||
cache_info = torch.compiler.load_cache_artifacts(artifact_bytes)
|
||||
self.assertEqual(len(cache_info.inductor_artifacts), 2)
|
||||
self.assertEqual(len(cache_info.autotune_artifacts), autotune_expect)
|
||||
self.assertEqual(len(cache_info.aot_autograd_artifacts), 1)
|
||||
self.assertEqual(len(cache_info.pgo_artifacts), 0)
|
||||
self.assertEqual(len(cache_info.precompile_artifacts), 1)
|
||||
|
||||
# With caching precompile, we have to re torch.compile the function
|
||||
# to trigger cache lookup
|
||||
compiled_fn = torch.compile(fn, dynamic=dynamic)
|
||||
|
||||
eager_result = fn(a, b)
|
||||
compiled_result = compiled_fn(a, b)
|
||||
compiled_result.sum().backward()
|
||||
self.assertEqual(eager_result, compiled_result)
|
||||
self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], 2)
|
||||
self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], 0)
|
||||
self.assertEqual(counters["dynamo_cache"]["dynamo_cache_miss"], 2)
|
||||
self.assertEqual(counters["dynamo_cache"]["dynamo_cache_hit"], 1)
|
||||
|
||||
@config.patch(
|
||||
{
|
||||
"fx_graph_cache": True,
|
||||
|
@ -34,7 +34,7 @@ from torch._dynamo.exc import PackageError
|
||||
from torch._dynamo.graph_utils import _graph_device_type
|
||||
|
||||
from .bytecode_transformation import get_code_keys
|
||||
from .utils import dynamo_timed, increment_frame
|
||||
from .utils import counters, dynamo_timed, increment_frame
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@ -433,6 +433,23 @@ class _DynamoCacheEntry:
|
||||
}
|
||||
|
||||
|
||||
from torch.compiler._cache import (
|
||||
CacheArtifact,
|
||||
CacheArtifactFactory,
|
||||
CacheArtifactManager,
|
||||
)
|
||||
|
||||
|
||||
@CacheArtifactFactory.register
|
||||
class PrecompileCacheArtifact(CacheArtifact):
|
||||
def populate_cache(self) -> None:
|
||||
DynamoCache._write_to_local_cache(self.content, self.key)
|
||||
|
||||
@staticmethod
|
||||
def type() -> str:
|
||||
return "precompile"
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class PrecompileCacheEntry:
|
||||
"""
|
||||
@ -1026,14 +1043,17 @@ class DiskDynamoStore(DynamoStore):
|
||||
Args:
|
||||
path_prefix: Prefix directory for where to put CompilePackages on disk
|
||||
"""
|
||||
self.path_prefix = path_prefix
|
||||
self._path_prefix = path_prefix
|
||||
|
||||
def path_prefix(self) -> str:
|
||||
return self._path_prefix
|
||||
|
||||
def clear(self) -> None:
|
||||
"""
|
||||
Clear all CompilePackages from disk.
|
||||
"""
|
||||
if self.path_prefix:
|
||||
shutil.rmtree(self.path_prefix, ignore_errors=True)
|
||||
if self.path_prefix():
|
||||
shutil.rmtree(self.path_prefix(), ignore_errors=True)
|
||||
|
||||
def write(
|
||||
self,
|
||||
@ -1043,12 +1063,21 @@ class DiskDynamoStore(DynamoStore):
|
||||
"""
|
||||
Write dynamo cache entry and backends to disk.
|
||||
"""
|
||||
try:
|
||||
pickled_content: bytes = pickle.dumps(entry)
|
||||
CacheArtifactManager.record_artifact(
|
||||
PrecompileCacheArtifact.type(), path, pickled_content
|
||||
)
|
||||
self._write_to_local_cache(pickled_content, path)
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Failed to save package to {path}: {e}") from e
|
||||
|
||||
def _write_to_local_cache(self, pickled_content: bytes, path: str) -> None:
|
||||
from torch._inductor.codecache import write_atomic
|
||||
|
||||
path = os.path.join(self.path_prefix, path) if self.path_prefix else path
|
||||
path = os.path.join(self.path_prefix(), path) if self.path_prefix() else path
|
||||
try:
|
||||
os.makedirs(path, exist_ok=True)
|
||||
pickled_content: bytes = pickle.dumps(entry)
|
||||
write_atomic(os.path.join(path, "entry"), pickled_content)
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Failed to save package to {path}: {e}") from e
|
||||
@ -1057,7 +1086,7 @@ class DiskDynamoStore(DynamoStore):
|
||||
"""
|
||||
Read dynamo cache entry and backends from disk.
|
||||
"""
|
||||
path = os.path.join(self.path_prefix, path) if self.path_prefix else path
|
||||
path = os.path.join(self.path_prefix(), path) if self.path_prefix() else path
|
||||
try:
|
||||
with open(os.path.join(path, "entry"), "rb") as f:
|
||||
pickled_content = f.read()
|
||||
@ -1087,15 +1116,18 @@ class DiskDynamoCache(DiskDynamoStore):
|
||||
"""
|
||||
key = CompilePackage.source_id_from_fn(fn)
|
||||
logger.info("Loading CompilePackage for %s", key)
|
||||
path = os.path.join(self.path_prefix, key)
|
||||
path = os.path.join(self.path_prefix(), key)
|
||||
if os.path.exists(path):
|
||||
try:
|
||||
result = super().load_cache_entry(key)
|
||||
counters["dynamo_cache"]["dynamo_cache_hit"] += 1
|
||||
return result
|
||||
except Exception as e:
|
||||
counters["dynamo_cache"]["dynamo_cache_error"] += 1
|
||||
logger.warning("Failed to load package from path %s: %s", path, str(e))
|
||||
return None
|
||||
logger.info("No package found for %s", key)
|
||||
counters["dynamo_cache"]["dynamo_cache_miss"] += 1
|
||||
return None
|
||||
|
||||
def load_and_install_package(
|
||||
@ -1112,6 +1144,9 @@ class DiskDynamoCache(DiskDynamoStore):
|
||||
package.install(results.backends)
|
||||
return package
|
||||
|
||||
def path_prefix(self) -> str:
|
||||
return os.path.join(cache_dir(), "dynamo")
|
||||
|
||||
|
||||
def cache_dir() -> str:
|
||||
from torch._inductor.runtime.cache_dir_utils import cache_dir
|
||||
|
@ -501,7 +501,12 @@ def save_cache_artifacts() -> Optional[tuple[bytes, "CacheInfo"]]:
|
||||
- Execute torch.compile
|
||||
- Call torch.compiler.save_cache_artifacts()
|
||||
"""
|
||||
from ._cache import CacheArtifactManager, CacheInfo
|
||||
from ._cache import CacheArtifactManager
|
||||
|
||||
if torch._dynamo.config.caching_precompile:
|
||||
from torch._dynamo.precompile_context import PrecompileContext
|
||||
|
||||
PrecompileContext.save_to_dynamo_cache()
|
||||
|
||||
return CacheArtifactManager.serialize()
|
||||
|
||||
|
@ -130,6 +130,10 @@ class CacheInfo:
|
||||
def pgo_artifacts(self) -> list[str]: # type: ignore[empty-body]
|
||||
...
|
||||
|
||||
@property
|
||||
def precompile_artifacts(self) -> list[str]: # type: ignore[empty-body]
|
||||
...
|
||||
|
||||
def add(self, artifact: CacheArtifact) -> None:
|
||||
self.artifacts[artifact.type()].append(artifact.key)
|
||||
|
||||
@ -307,6 +311,7 @@ class CacheArtifactManager:
|
||||
cache artifacts are registered in the cache registry. This is done by
|
||||
simply importing all the cache artifacts already wrapped with register call.
|
||||
"""
|
||||
from torch._dynamo.package import PrecompileCacheArtifact # noqa: F401
|
||||
from torch._dynamo.pgo import PGOCacheArtifact # noqa: F401
|
||||
from torch._functorch._aot_autograd.autograd_cache import ( # noqa: F401
|
||||
AOTAutogradCacheArtifact,
|
||||
|
Reference in New Issue
Block a user