[MegaCache] Rename the PGO artifact when used between different jobs (#151482)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/151482
Approved by: https://github.com/bobrenjc93, https://github.com/jamesjwu
Authored by Oguz Ulgen on 2025-04-16 22:05:19 -07:00, committed by PyTorch MergeBot
parent fe90a5c140
commit 8404c09b15
3 changed files with 68 additions and 2 deletions


@@ -534,6 +534,56 @@ class TestFxGraphCache(TestCase):
         self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 1)
         self.assertEqual(counters["inductor"]["fxgraph_lookup_write_file"], 1)
+    @torch._dynamo.config.patch(automatic_dynamic_local_pgo=True)
+    @torch._functorch.config.patch({"enable_autograd_cache": False})
+    @config.patch({"fx_graph_cache": True, "fx_graph_remote_cache": False})
+    def test_cache_hot_load_pgo_swap_file_names(self):
+        """
+        Verify that we can populate and hot load functions from the cache with pgo
+        with file name swapping
+        """
+        backend = torch._dynamo.testing.CompileCounterWithBackend("inductor")
+        @torch.compile(backend=backend, fullgraph=True)
+        def f(x):
+            return x * 2
+        # Record artifacts
+        with mock.patch(
+            "torch._utils_internal.get_mast_job_name_version", return_value=("foo", 5)
+        ):
+            with fresh_inductor_cache():
+                f(torch.randn(2, 3))
+                f(torch.randn(2, 4))
+        self.assertEqual(backend.frame_count, 2)
+        artifacts = torch.compiler.save_cache_artifacts()
+        self.assertIsNotNone(artifacts)
+        artifact_bytes, cache_info = artifacts
+        self.assertEqual(len(cache_info.pgo_artifacts), 2)
+        self.reset()
+        backend.clear()
+        # Clean triton kernels
+        shutil.rmtree(os.path.join(cache_dir(), "triton"), ignore_errors=True)
+        # Hot load and hit
+        with mock.patch(
+            "torch._utils_internal.get_mast_job_name_version", return_value=("bar", 10)
+        ), fresh_inductor_cache():
+            cache_info = torch.compiler.load_cache_artifacts(artifact_bytes)
+            self.assertEqual(len(cache_info.pgo_artifacts), 2)
+            f(torch.randn(2, 5))
+            f(torch.randn(2, 6))
+        self.assertEqual(backend.frame_count, 1)
     @requires_triton()
     @config.patch({"fx_graph_cache": True})
     @config.patch({"fx_graph_remote_cache": False})


@@ -493,6 +493,20 @@ def get_cache_key() -> Optional[str]:
     return None
+def rewrite_cache_key_for_mega_cache(original_key: str) -> str:
+    """
+    The PGO cache artifact key for a MAST job contains the job name and the version.
+    When we want to use the cache artifact on a different MAST job, we need to
+    update the key to use the new MAST job's name and version.
+    """
+    if not original_key.startswith("mast:"):
+        # if original_key is overridden, then don't change it
+        return original_key
+    if (new_key := get_cache_key()) is not None:
+        return new_key
+    return original_key
 # This solely controls local PGO
 def code_state_path(cache_key: str) -> Optional[str]:
     if not torch._dynamo.config.automatic_dynamic_local_pgo:

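The hunk above only shows the new helper in place; as a rough illustration of the intended behavior, the sketch below re-implements the same rewrite logic standalone. The "mast:<job_name>:<version>" key layout and the helper names are assumptions made for illustration, not necessarily the exact format that get_cache_key() produces.

from typing import Optional

# Hypothetical stand-in for torch._dynamo.pgo.get_cache_key(); the key format
# here is an assumption, used only to show how the rewrite behaves.
def get_cache_key_for_job(job_name: Optional[str], version: Optional[int]) -> Optional[str]:
    if job_name is None or version is None:
        return None
    return f"mast:{job_name}:{version}"

def rewrite_key(original_key: str, job_name: Optional[str], version: Optional[int]) -> str:
    # Same shape as rewrite_cache_key_for_mega_cache: keys that were not
    # derived from a MAST job (no "mast:" prefix) are left untouched.
    if not original_key.startswith("mast:"):
        return original_key
    new_key = get_cache_key_for_job(job_name, version)
    return new_key if new_key is not None else original_key

# Artifact recorded on job ("foo", 5), reused on job ("bar", 10):
assert rewrite_key("mast:foo:5", "bar", 10) == "mast:bar:10"
# Overridden / non-MAST keys pass through unchanged:
assert rewrite_key("my_custom_key", "bar", 10) == "my_custom_key"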

@@ -206,7 +206,7 @@ class CacheArtifactManager:
             log.warning("Failed to un-pickle cache artifacts", exc_info=True)
             return None
-        from torch._dynamo.pgo import write_local_impl
+        from torch._dynamo.pgo import rewrite_cache_key_for_mega_cache, write_local_impl
         from torch._functorch._aot_autograd.autograd_cache import AOTAutogradCache
         from torch._inductor.codecache import FxGraphCache
         from torch._inductor.runtime.autotune_cache import _LocalAutotuneCacheBackend
@@ -226,7 +226,9 @@ class CacheArtifactManager:
             elif artifact.type == CacheArtifactType.AOT_AUTOGRAD:
                 AOTAutogradCache._write_to_local_cache(artifact.key, artifact.content)
             elif artifact.type == CacheArtifactType.PGO:
-                meta = write_local_impl(artifact.key, artifact.content)
+                meta = write_local_impl(
+                    rewrite_cache_key_for_mega_cache(artifact.key), artifact.content
+                )
                 assert meta is not None
             else:
                 log.warning(f"Unsupported artifact type {artifact.type}")  # noqa: G004
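For context, the end-to-end flow the new test exercises looks roughly like the following from a user's point of view. This is a minimal sketch: how artifact_bytes travels between the two jobs is left abstract, and the blob-store helpers in the comments are hypothetical; only torch.compile, torch.compiler.save_cache_artifacts, and torch.compiler.load_cache_artifacts are taken from the diff above.

import torch

@torch.compile
def f(x):
    return x * 2

# Job A (e.g. MAST job "foo", version 5): warm the caches, then serialize them.
f(torch.randn(2, 3))
artifacts = torch.compiler.save_cache_artifacts()
if artifacts is not None:
    artifact_bytes, cache_info = artifacts
    # upload_to_blob_store(artifact_bytes)  # hypothetical transport between jobs

# Job B (e.g. MAST job "bar", version 10): hot load before compiling.
# With this change, PGO artifacts saved under job A's key are rewritten to
# job B's key on load, so the recorded dynamic-shape profiles still apply.
# artifact_bytes = download_from_blob_store()  # hypothetical
torch.compiler.load_cache_artifacts(artifact_bytes)
f(torch.randn(2, 4))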