mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Make CompiledFxGraph portable between machines (#124438)
As we prepare FxGraphCache to move to remote, we need to make sure there's no data that is on the disk. Differential Revision: [D56363808](https://our.internmc.facebook.com/intern/diff/D56363808) Pull Request resolved: https://github.com/pytorch/pytorch/pull/124438 Approved by: https://github.com/jansel
This commit is contained in:
committed by
PyTorch MergeBot
parent
c5a4ba2257
commit
0d64b82f0b
@ -1,5 +1,6 @@
|
||||
# Owner(s): ["module: inductor"]
|
||||
import functools
|
||||
import os
|
||||
import pickle
|
||||
import unittest
|
||||
from typing import List
|
||||
@ -124,13 +125,18 @@ class TestFxGraphCache(TestCase):
|
||||
self.assertEqual(fn(a, b), compiled_fn(a, b))
|
||||
self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 1)
|
||||
self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 0)
|
||||
self.assertEqual(counters["inductor"]["fxgraph_lookup_write_file"], 0)
|
||||
|
||||
# A second call should hit. (First reset so in-memory guards
|
||||
# don't prevent compilation).
|
||||
torch._dynamo.reset()
|
||||
for m in torch._inductor.codecache.PyCodeCache.cache.values():
|
||||
os.remove(m.__file__)
|
||||
torch._inductor.codecache.PyCodeCache.cache_clear()
|
||||
self.assertEqual(fn(a, b), compiled_fn(a, b))
|
||||
self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 1)
|
||||
self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 1)
|
||||
self.assertEqual(counters["inductor"]["fxgraph_lookup_write_file"], 1)
|
||||
|
||||
@requires_triton()
|
||||
@config.patch({"fx_graph_cache": True})
|
||||
|
@ -768,17 +768,21 @@ class FxGraphCache:
|
||||
|
||||
# See _save_graph(); we don't store the callable in the cache entry so
|
||||
# recreate it here from the PyCodeCache disk cache.
|
||||
artifact_path = get_path(graph.cache_key, "py")[2]
|
||||
if not os.path.exists(artifact_path):
|
||||
counters["inductor"]["fxgraph_lookup_write_file"] += 1
|
||||
write_atomic(artifact_path, graph.source_code)
|
||||
try:
|
||||
graph.current_callable = PyCodeCache.load_by_key_path(
|
||||
graph.cache_key,
|
||||
graph.artifact_path,
|
||||
artifact_path,
|
||||
graph.cache_linemap,
|
||||
graph.constants,
|
||||
).call
|
||||
except OSError:
|
||||
# Not expected, but in case the PyCodeCache entry is removed from
|
||||
# underneath us, treat it as a cache miss and recompile.
|
||||
log.error("Failed to load cached artifact: %s", graph.artifact_path)
|
||||
log.error("Failed to load cached artifact: %s", artifact_path)
|
||||
return None
|
||||
|
||||
# Now re-evaluate with the symints to add any guards to the current env.
|
||||
@ -914,7 +918,7 @@ class CompiledFxGraph:
|
||||
|
||||
current_callable: Optional[Callable[..., Any]]
|
||||
cache_key: str
|
||||
artifact_path: str
|
||||
source_code: str = dataclasses.field(repr=False) # Do not display source_code
|
||||
cache_linemap: Optional[List[Tuple[int, str]]]
|
||||
device_types: Set[str]
|
||||
device_idxs: Set[int]
|
||||
@ -943,7 +947,9 @@ class CompiledFxGraph:
|
||||
):
|
||||
self.current_callable = current_callable
|
||||
self.cache_key = graph.cache_key
|
||||
self.artifact_path = graph.cache_path
|
||||
if graph.cache_path:
|
||||
with open(graph.cache_path) as f:
|
||||
self.source_code = f.read()
|
||||
self.cache_linemap = graph.cache_linemap
|
||||
self.device_types = graph.device_types
|
||||
self.device_idxs = graph.device_idxs
|
||||
|
Reference in New Issue
Block a user