Make CompiledFxGraph portable between machines (#124438)

As we prepare FxGraphCache to move to a remote backend, we need to make sure the cached entry contains no data that refers to the local disk (e.g. absolute artifact paths); instead, the source code itself is stored so the artifact can be rewritten on any machine.

Differential Revision: [D56363808](https://our.internmc.facebook.com/intern/diff/D56363808)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/124438
Approved by: https://github.com/jansel
This commit is contained in:
Oguz Ulgen
2024-04-18 18:33:22 -07:00
committed by PyTorch MergeBot
parent c5a4ba2257
commit 0d64b82f0b
2 changed files with 16 additions and 4 deletions

View File

@@ -1,5 +1,6 @@
# Owner(s): ["module: inductor"]
import functools
import os
import pickle
import unittest
from typing import List
@@ -124,13 +125,18 @@ class TestFxGraphCache(TestCase):
self.assertEqual(fn(a, b), compiled_fn(a, b))
self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 1)
self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 0)
self.assertEqual(counters["inductor"]["fxgraph_lookup_write_file"], 0)
# A second call should hit. (First reset so in-memory guards
# don't prevent compilation).
torch._dynamo.reset()
for m in torch._inductor.codecache.PyCodeCache.cache.values():
os.remove(m.__file__)
torch._inductor.codecache.PyCodeCache.cache_clear()
self.assertEqual(fn(a, b), compiled_fn(a, b))
self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 1)
self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 1)
self.assertEqual(counters["inductor"]["fxgraph_lookup_write_file"], 1)
@requires_triton()
@config.patch({"fx_graph_cache": True})

View File

@@ -768,17 +768,21 @@ class FxGraphCache:
# See _save_graph(); we don't store the callable in the cache entry so
# recreate it here from the PyCodeCache disk cache.
artifact_path = get_path(graph.cache_key, "py")[2]
if not os.path.exists(artifact_path):
counters["inductor"]["fxgraph_lookup_write_file"] += 1
write_atomic(artifact_path, graph.source_code)
try:
graph.current_callable = PyCodeCache.load_by_key_path(
graph.cache_key,
graph.artifact_path,
artifact_path,
graph.cache_linemap,
graph.constants,
).call
except OSError:
# Not expected, but in case the PyCodeCache entry is removed from
# underneath us, treat it as a cache miss and recompile.
log.error("Failed to load cached artifact: %s", graph.artifact_path)
log.error("Failed to load cached artifact: %s", artifact_path)
return None
# Now re-evaluate with the symints to add any guards to the current env.
@@ -914,7 +918,7 @@ class CompiledFxGraph:
current_callable: Optional[Callable[..., Any]]
cache_key: str
artifact_path: str
source_code: str = dataclasses.field(repr=False) # Do not display source_code
cache_linemap: Optional[List[Tuple[int, str]]]
device_types: Set[str]
device_idxs: Set[int]
@@ -943,7 +947,9 @@ class CompiledFxGraph:
):
self.current_callable = current_callable
self.cache_key = graph.cache_key
self.artifact_path = graph.cache_path
if graph.cache_path:
with open(graph.cache_path) as f:
self.source_code = f.read()
self.cache_linemap = graph.cache_linemap
self.device_types = graph.device_types
self.device_idxs = graph.device_idxs