mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[MegaCache] Rename the PGO artifact when used between different jobs (#151482)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/151482 Approved by: https://github.com/bobrenjc93, https://github.com/jamesjwu
This commit is contained in:
committed by
PyTorch MergeBot
parent
fe90a5c140
commit
8404c09b15
@ -534,6 +534,56 @@ class TestFxGraphCache(TestCase):
|
||||
self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 1)
|
||||
self.assertEqual(counters["inductor"]["fxgraph_lookup_write_file"], 1)
|
||||
|
||||
@torch._dynamo.config.patch(automatic_dynamic_local_pgo=True)
@torch._functorch.config.patch({"enable_autograd_cache": False})
@config.patch({"fx_graph_cache": True, "fx_graph_remote_cache": False})
def test_cache_hot_load_pgo_swap_file_names(self):
    """
    Check that cache artifacts (including PGO ones) recorded while the MAST
    job name/version is mocked as ("foo", 5) can be hot loaded and reused
    while it is mocked as ("bar", 10), i.e. with file name swapping.
    """
    counter = torch._dynamo.testing.CompileCounterWithBackend("inductor")

    @torch.compile(backend=counter, fullgraph=True)
    def f(x):
        return x * 2

    # Phase 1: record artifacts under the first mocked job identity.
    # NOTE(review): indentation was lost in the reviewed view; saving the
    # artifacts is assumed to happen inside the recording context -- confirm.
    with mock.patch(
        "torch._utils_internal.get_mast_job_name_version", return_value=("foo", 5)
    ), fresh_inductor_cache():
        f(torch.randn(2, 3))
        f(torch.randn(2, 4))
        # Both calls are expected to trigger a compile on the fresh cache.
        self.assertEqual(counter.frame_count, 2)

        saved = torch.compiler.save_cache_artifacts()
        self.assertIsNotNone(saved)

        blob, recorded_info = saved
        self.assertEqual(len(recorded_info.pgo_artifacts), 2)

    # Start the second phase with cleared test state and compile counter.
    self.reset()
    counter.clear()

    # Clean triton kernels
    shutil.rmtree(os.path.join(cache_dir(), "triton"), ignore_errors=True)

    # Phase 2: hot load the recorded bytes under a different mocked job
    # identity and check that they are picked up.
    with mock.patch(
        "torch._utils_internal.get_mast_job_name_version", return_value=("bar", 10)
    ):
        with fresh_inductor_cache():
            loaded_info = torch.compiler.load_cache_artifacts(blob)
            self.assertEqual(len(loaded_info.pgo_artifacts), 2)

            f(torch.randn(2, 5))
            f(torch.randn(2, 6))
            # Only one frame is expected to be compiled for the two new shapes.
            self.assertEqual(counter.frame_count, 1)
|
||||
|
||||
@requires_triton()
|
||||
@config.patch({"fx_graph_cache": True})
|
||||
@config.patch({"fx_graph_remote_cache": False})
|
||||
|
@ -493,6 +493,20 @@ def get_cache_key() -> Optional[str]:
|
||||
return None
|
||||
|
||||
|
||||
def rewrite_cache_key_for_mega_cache(original_key: str) -> str:
    """
    Return the PGO cache key to use for a MegaCache artifact in the current job.

    A PGO cache artifact key written by a MAST job contains that job's name
    and version. To reuse the artifact from a different MAST job, the key has
    to be replaced by one carrying the new job's name and version.

    ``original_key`` is returned unchanged when it does not start with
    ``"mast:"`` (such a key is treated as overridden) or when
    ``get_cache_key()`` returns ``None``; otherwise the result of
    ``get_cache_key()`` is returned.
    """
    if original_key.startswith("mast:"):
        current_key = get_cache_key()
        if current_key is not None:
            return current_key
    # Overridden key, or no key available for the current job: keep the
    # key the artifact was recorded with.
    return original_key
|
||||
|
||||
|
||||
# This solely controls local PGO
|
||||
def code_state_path(cache_key: str) -> Optional[str]:
|
||||
if not torch._dynamo.config.automatic_dynamic_local_pgo:
|
||||
|
@ -206,7 +206,7 @@ class CacheArtifactManager:
|
||||
log.warning("Failed to un-pickle cache artifacts", exc_info=True)
|
||||
return None
|
||||
|
||||
from torch._dynamo.pgo import write_local_impl
|
||||
from torch._dynamo.pgo import rewrite_cache_key_for_mega_cache, write_local_impl
|
||||
from torch._functorch._aot_autograd.autograd_cache import AOTAutogradCache
|
||||
from torch._inductor.codecache import FxGraphCache
|
||||
from torch._inductor.runtime.autotune_cache import _LocalAutotuneCacheBackend
|
||||
@ -226,7 +226,9 @@ class CacheArtifactManager:
|
||||
elif artifact.type == CacheArtifactType.AOT_AUTOGRAD:
|
||||
AOTAutogradCache._write_to_local_cache(artifact.key, artifact.content)
|
||||
elif artifact.type == CacheArtifactType.PGO:
|
||||
meta = write_local_impl(artifact.key, artifact.content)
|
||||
meta = write_local_impl(
|
||||
rewrite_cache_key_for_mega_cache(artifact.key), artifact.content
|
||||
)
|
||||
assert meta is not None
|
||||
else:
|
||||
log.warning(f"Unsupported artifact type {artifact.type}") # noqa: G004
|
||||
|
Reference in New Issue
Block a user