Revert "[WIP] Automatically load and save dynamo entries via caching_precompile (#155913)"
This reverts commit e466dab164d9236bfe5817ec8e4d24c7b9d3e392. Reverted https://github.com/pytorch/pytorch/pull/155913 on behalf of https://github.com/huydhn due to Sorry for reverting your change but it seems to fail a test in trunk ([comment](https://github.com/pytorch/pytorch/pull/155913#issuecomment-3045914878))
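For orientation before the diff: the reverted change made `torch.compile` record each compiled function into a `DynamoCache` and reload it on later runs, all behind a single config flag. A rough sketch of the intended usage, reconstructed from the tests deleted below (this API no longer exists once the revert lands):

```python
import torch

# Pre-revert opt-in flag; removed by this commit.
torch._dynamo.config.caching_precompile = True

def fn(x):
    return x.sin() + x.cos()

compiled = torch.compile(fn)
compiled(torch.randn(3, 2))
# With the flag set, the dynamo entry and its compiled backends were
# persisted via DynamoCache, so a fresh process compiling fn again was
# expected to load the cached entry instead of retracing.
```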
@@ -12,8 +12,7 @@ import torch._inductor.config
 import torch._inductor.test_case
 import torch.onnx.operators
 import torch.utils.cpp_extension
-from torch._dynamo.package import CompilePackage, DiskDynamoStore, DynamoCache
-from torch._dynamo.precompile_context import PrecompileContext
+from torch._dynamo.package import CompilePackage, DiskDynamoStore
 from torch._functorch import config as functorch_config
 from torch._inductor.runtime.runtime_utils import cache_dir
 from torch.testing._internal.common_utils import (
@@ -31,38 +30,6 @@ class TestPackage(torch._inductor.test_case.TestCase):
         os.makedirs(path, exist_ok=True)
         return path

-    def setUp(self):
-        super().setUp()
-        torch._dynamo.reset()
-        torch._dynamo.utils.counters.clear()
-        DynamoCache.clear()
-        PrecompileContext.clear()
-
-    def _save_and_reload(self, expected_backends, expected_dynamo):
-        """
-        Serializes all artifacts, clears all caches, then reloads the serialized artifact.
-        Simulates a new process.
-
-        Args:
-            expected_backends: Expected number of precompile_aot_autograd_artifacts
-            expected_dynamo: Expected number of precompile_dynamo_artifacts
-        """
-        serialized = PrecompileContext.serialize()
-        assert serialized is not None
-        (bytes_, cache_info) = serialized
-        self.assertEqual(
-            len(cache_info.precompile_aot_autograd_artifacts), expected_backends
-        )
-        self.assertEqual(len(cache_info.precompile_dynamo_artifacts), expected_dynamo)
-
-        torch._dynamo.reset()
-        DynamoCache.clear()
-        PrecompileContext.clear()
-
-        deserialized = PrecompileContext.deserialize(bytes_)
-        assert deserialized is not None
-        PrecompileContext.populate_caches(deserialized)
-
     @parametrize("backend", ("eager", "inductor"))
     @parametrize("device", ("cpu", "cuda", "xpu"))
     def test_basic_fn(self, backend, device):
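The deleted `_save_and_reload` helper above is the core round trip the feature relied on. Restated as a standalone sketch (pre-revert `PrecompileContext`/`DynamoCache` API, mirroring the deleted code):

```python
import torch
from torch._dynamo.package import DynamoCache          # pre-revert import
from torch._dynamo.precompile_context import PrecompileContext

# Serialize everything recorded so far: returns (bytes, cache_info) or None.
serialized = PrecompileContext.serialize()
assert serialized is not None
bytes_, cache_info = serialized

# Simulate a new process by dropping all in-memory state.
torch._dynamo.reset()
DynamoCache.clear()
PrecompileContext.clear()

# Reload the artifacts from the serialized blob and repopulate the caches.
artifacts = PrecompileContext.deserialize(bytes_)
assert artifacts is not None
PrecompileContext.populate_caches(artifacts)
```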
@@ -338,173 +305,6 @@ def add(x, y):
         )
         ctx.load_package(fn, self.path())

-    @parametrize("device", ("cpu", "cuda", "xpu"))
-    def test_dynamo_cache_manual_load(self, device):
-        if device == "cuda" and not HAS_CUDA:
-            raise unittest.SkipTest("Requires CUDA/Triton")
-        if device == "xpu" and not HAS_XPU:
-            raise unittest.SkipTest("Requires XPU/Triton")
-
-        def fn(x):
-            return x.sin() + x.cos()
-
-        def fn2(x):
-            return x.cos() + x
-
-        package1 = CompilePackage(fn)
-        package2 = CompilePackage(fn2)
-        compiled_fn1 = torch._dynamo.optimize(backend="inductor", package=package1)(fn)
-        compiled_fn2 = torch._dynamo.optimize(backend="inductor", package=package2)(fn2)
-        arg1 = torch.randn(3, 2, device=device)
-        arg2 = torch.randn(5, 2, device=device)
-        expected = [compiled_fn1(arg1), compiled_fn2(arg2)]
-
-        DynamoCache.save(package1)
-        DynamoCache.save(package2)
-
-        self._save_and_reload(expected_backends=2, expected_dynamo=2)
-
-        # These should exist because of populate_caches
-        package1 = DynamoCache.load_and_install_package(fn)
-        package2 = DynamoCache.load_and_install_package(fn2)
-
-        with torch.compiler.set_stance("fail_on_recompile"):
-            result1 = compiled_fn1(arg1)
-            result2 = compiled_fn2(arg2)
-        self.assertEqual(expected, [result1, result2])
-
-    @parametrize("device", ("cpu", "cuda", "xpu"))
-    @torch._dynamo.config.patch(caching_precompile=True)
-    def test_automatic_dynamo_serialize(self, device):
-        if device == "cuda" and not HAS_CUDA:
-            raise unittest.SkipTest("Requires CUDA/Triton")
-        if device == "xpu" and not HAS_XPU:
-            raise unittest.SkipTest("Requires XPU/Triton")
-
-        def fn(x):
-            return x.sin() + x.cos()
-
-        def fn2(x):
-            return x.cos() + x
-
-        arg1 = torch.randn(3, 2, device=device)
-        arg2 = torch.randn(5, 2, device=device)
-        expected = [fn(arg1), fn2(arg2)]
-        compiled_fn1 = torch.compile(fn)
-        compiled_fn2 = torch.compile(fn2)
-        result = [compiled_fn1(arg1), compiled_fn2(arg2)]
-        self.assertEqual(expected, result)
-        DynamoCache.clear()
-
-        self._save_and_reload(expected_backends=2, expected_dynamo=2)
-
-        compiled_fn1 = torch.compile(fn)
-        compiled_fn2 = torch.compile(fn2)
-        with torch.compiler.set_stance("fail_on_recompile"):
-            result1 = compiled_fn1(arg1)
-            result2 = compiled_fn2(arg2)
-        self.assertEqual(expected, [result1, result2])
-
-    @parametrize("device", ("cpu", "cuda", "xpu"))
-    @torch._dynamo.config.patch(caching_precompile=True)
-    def test_automatic_dynamo_recompiles(self, device):
-        if device == "cuda" and not HAS_CUDA:
-            raise unittest.SkipTest("Requires CUDA/Triton")
-        if device == "xpu" and not HAS_XPU:
-            raise unittest.SkipTest("Requires XPU/Triton")
-
-        def fn(x):
-            return x.sin() + x.cos()
-
-        arg1 = torch.randn(3, 2, device=device)
-        arg2 = torch.randn(5, 2, device=device)
-        compiled_fn = torch.compile(fn)
-        expected1 = compiled_fn(arg1)
-
-        # Should cause a recompile
-        expected2 = compiled_fn(arg2)
-
-        self._save_and_reload(expected_backends=2, expected_dynamo=1)
-
-        compiled_fn = torch.compile(fn)
-        with torch.compiler.set_stance("fail_on_recompile"):
-            result1 = compiled_fn(arg1)
-            result2 = compiled_fn(arg2)
-            # Because of automatic dynamic, a third random shape should also not cause a recompile
-            arg3 = torch.randn(7, 2, device=device)
-            compiled_fn(arg3)
-        self.assertEqual(result1, expected1)
-        self.assertEqual(result2, expected2)
-
-    @parametrize("device", ("cpu", "cuda", "xpu"))
-    @torch._dynamo.config.patch(caching_precompile=True)
-    def test_automatic_dynamo_graph_breaks(self, device):
-        if device == "cuda" and not HAS_CUDA:
-            raise unittest.SkipTest("Requires CUDA/Triton")
-        if device == "xpu" and not HAS_XPU:
-            raise unittest.SkipTest("Requires XPU/Triton")
-
-        def fn(x, l, r):
-            if l > r:
-                return x.sum()
-            mid = (l + r) // 2
-            if x.sum() == mid:
-                return x.sum()
-            elif x.sum() < mid:
-                return fn(x, l, mid)
-            else:
-                return fn(x, mid + 1, r)
-
-        def guard_filter_fn(guards):
-            return [
-                guard.guard_type not in ("CLOSURE_MATCH", "FUNCTION_MATCH")
-                for guard in guards
-            ]
-
-        # Saving
-        compiled_fn = torch._dynamo.optimize(
-            backend="inductor", guard_filter_fn=guard_filter_fn
-        )(fn)
-        N = 10
-        args_list = [(torch.tensor(x, device=device), 0, N - 1) for x in range(N)]
-        for args in args_list:
-            compiled_fn(*args)
-
-        self._save_and_reload(expected_backends=8, expected_dynamo=1)
-
-        compiled_fn = torch._dynamo.optimize(
-            backend="inductor", guard_filter_fn=guard_filter_fn
-        )(fn)
-        with torch.compiler.set_stance("fail_on_recompile"):
-            for args in args_list:
-                self.assertEqual(compiled_fn(*args), args[0].sum())
-
-    @parametrize("device", ("cpu", "cuda", "xpu"))
-    @torch._dynamo.config.patch(caching_precompile=True)
-    def test_automatic_dynamo_lazy_backward(self, device):
-        if device == "cuda" and not HAS_CUDA:
-            raise unittest.SkipTest("Requires CUDA/Triton")
-        if device == "xpu" and not HAS_XPU:
-            raise unittest.SkipTest("Requires XPU/Triton")
-
-        def fn(x):
-            return x.sin() + x.cos()
-
-        arg1 = torch.randn(3, 2, device=device, requires_grad=True)
-        arg2 = arg1.clone().detach_().requires_grad_(True)
-
-        compiled_fn = torch.compile(fn)
-        expected1 = compiled_fn(arg1)
-        expected1.sum().backward()
-
-        self._save_and_reload(expected_backends=1, expected_dynamo=1)
-
-        compiled_fn = torch.compile(fn)
-        # Run it again, no recompile needed
-        with torch.compiler.set_stance("fail_on_recompile"):
-            expected2 = compiled_fn(arg2)
-            expected2.sum().backward()
-

 if __name__ == "__main__":
     from torch._dynamo.test_case import run_tests
@@ -542,10 +542,6 @@ fake_tensor_cache_crosscheck_enabled = (
 # the inference_mode is still respected.
 fake_tensor_disable_inference_mode = True

-# Experimental feature for running automatic caching precompile.
-# Enables automatic DynamoCache save/load
-caching_precompile = False
-
 # Enables the Compiled Autograd engine to trace autograd calls made under torch.compile().
 # Note: AOTAutograd will still trace and partition an AOT backward graph local to that
 # compiled region. But AOTAutograd traces without knowledge of backward hooks which are
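In the deleted tests the flag was scoped per test rather than set globally. Both toggle styles, for reference (the flag itself is removed by this revert):

```python
import torch

# Scoped, as the deleted tests did via the decorator/context-manager form:
with torch._dynamo.config.patch(caching_precompile=True):
    torch.compile(lambda x: x + 1)(torch.ones(2))

# Or process-wide:
torch._dynamo.config.caching_precompile = True
```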
@@ -531,6 +531,7 @@ class ConvertFrameAssert:
         skip: int = 0,
     ) -> ConvertFrameReturn:
         increment_frame()
+
         code = frame.f_code

         cache_size = compute_cache_size(frame, cache_entry)
@@ -648,7 +649,7 @@ class ConvertFrameAssert:
         dynamo_tls.traced_frame_infos.append(info)

         with compile_context(CompileContext(compile_id)):
-            result = _compile(
+            return _compile(
                 frame.f_code,
                 frame.f_globals,
                 frame.f_locals,
@@ -669,13 +670,6 @@ class ConvertFrameAssert:
                 convert_frame_box=self._box,
             )

-            if config.caching_precompile and self._package is not None:
-                from .package import DynamoCache
-
-                # Record that the dynamo package has changed
-                DynamoCache.record_package(self._package)
-            return result
-

 def convert_frame_assert(
     compiler_fn: CompilerFn,
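The removed lines implemented a record-after-compile hook: every successful frame compile registered its package so the cache could persist it later. A toy restatement of the pattern (all names here are illustrative stand-ins, not the real dynamo internals):

```python
from typing import Any, Callable

class ToyCache:
    """Stand-in for DynamoCache: remembers which packages changed."""

    def __init__(self) -> None:
        self.dirty: list[Any] = []

    def record_package(self, package: Any) -> None:
        self.dirty.append(package)  # flushed to disk at save time

TOY_CACHE = ToyCache()
CACHING_ENABLED = True

def compile_frame(fn: Callable[..., Any], package: Any) -> Callable[..., Any]:
    compiled = fn  # stand-in for the real _compile step
    if CACHING_ENABLED and package is not None:
        TOY_CACHE.record_package(package)  # mirrors DynamoCache.record_package
    return compiled

compiled = compile_frame(lambda x: x + 1, package=object())
assert TOY_CACHE.dirty
```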
@@ -655,27 +655,6 @@ class _TorchDynamoContext:
         def get_compiler_config():
             return self.compiler_config

-        from .package import DynamoCache
-
-        # If self._package is lazily initialized, we should check the dynamo cache now
-        if config.caching_precompile:
-            assert self._package is not None
-            if not self._package.is_initialized():
-                result = DynamoCache.load(fn)
-                if result is None:
-                    # Create a fresh CompilePackage
-                    self._package.initialize(fn, None, ignore_inlined_sources=False)
-                else:
-                    cache_entry, backends = result
-                    try:
-                        self._package.initialize(
-                            fn, cache_entry, ignore_inlined_sources=False
-                        )
-                        self._package.install(backends)
-                    except RuntimeError as e:
-                        log.warning("Failed to load entry from dynamo cache: %s", e)
-                        self._package.initialize(fn, None, ignore_inlined_sources=False)
-
         fn = innermost_fn(fn)

         # add context containing GraphModule to any GraphModule forward functions
@@ -1173,20 +1152,8 @@ def _optimize(
     # The backend function is stashed in the callable returned by
     # _optimize_catch_errors in the field _torchdynamo_orig_backend. This can
     # be used by eval_frame.c to insert a guard on the backend.
-
-    # With CachingPrecompile, instantiate an uninitialized CompilePackage
-    # which gets initialized by _optimize_catch_errors.__call__ once we have a function
-    if config.caching_precompile and package is None:
-        from .package import CompilePackage
-
-        package = CompilePackage(fn=None, dynamo=None, ignore_inlined_sources=False)
-
     return _optimize_catch_errors(
-        convert_frame.convert_frame(
-            backend,
-            hooks,
-            package=package,
-        ),
+        convert_frame.convert_frame(backend, hooks, package=package),
         hooks,
         backend_ctx_ctor,
         error_on_graph_break=nopython,
@@ -2102,16 +2069,6 @@ def _optimize_assert(
     # Find if backend has any extra context manager
     backend_ctx_ctor = getattr(backend, "backend_ctx_ctor", null_context)

-    if config.caching_precompile and package is None:
-        # Create an uninitialized package that will be set/filled by
-        # _OptimizeContext.__call__
-        # We need to instantiate the object here because the same CompilePackage
-        # needs to be shared between convert_frame_assert
-        # and OptimizeContext.
-        from .package import CompilePackage
-
-        package = CompilePackage(fn=None, dynamo=None, ignore_inlined_sources=False)
-
     return _optimize_catch_errors(
         convert_frame.convert_frame_assert(
             backend,
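The first removed hunk above implemented a load-or-initialize policy for lazily created packages. The control flow, restated as a standalone function (the cache, package, and logger are illustrative stand-ins):

```python
import logging

log = logging.getLogger(__name__)

def ensure_package(package, cache, fn):
    """Hydrate `package` from `cache` if possible, else start fresh."""
    if package.is_initialized():
        return
    result = cache.load(fn)
    if result is None:
        package.initialize(fn, None)         # fresh CompilePackage
        return
    cache_entry, backends = result
    try:
        package.initialize(fn, cache_entry)  # hydrate from the cache entry
        package.install(backends)            # reinstall compiled backends
    except RuntimeError as e:
        log.warning("Failed to load entry from dynamo cache: %s", e)
        package.initialize(fn, None)         # fall back to a fresh package
```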
@@ -19,16 +19,14 @@ import logging
 import os
 import pickle
 import platform
-import shutil
 import sys
 import types
 from collections.abc import Generator
-from typing import Any, Callable, NewType, Optional
+from typing import Any, NewType, Optional

 import torch
 import torch._inductor.package
 from torch._dynamo.precompile_context import PrecompileCacheArtifact, PrecompileContext
-from torch._inductor.runtime.cache_dir_utils import cache_dir
 from torch.compiler._cache import CacheArtifactFactory

 from .bytecode_transformation import get_code_keys
@@ -185,7 +183,7 @@ class CompilePackage:

     def __init__(
         self,
-        fn: Optional[Callable[..., Any]],
+        fn: Any,
         dynamo: Optional[_DynamoCacheEntry] = None,
         ignore_inlined_sources: bool = False,
     ) -> None:
@@ -199,16 +197,12 @@ class CompilePackage:
         self._cached_backends: dict[_BackendId, Any] = {}
         self._inlined_sources: set[InlinedSource] = set()
         self._resume_codes: set[types.CodeType] = set()
-        self._initialized = False
-        if fn is not None:
-            self.initialize(fn, dynamo, ignore_inlined_sources)
-            self.uninstall()
-            self.validate()
-
-    def is_initialized(self) -> bool:
-        return self._initialized
+        self._initialize(fn, dynamo, ignore_inlined_sources)
+        self.uninstall()
+        self.validate()

-    def initialize(
+    def _initialize(
         self,
         fn: Any,
         dynamo: Optional[_DynamoCacheEntry] = None,
@@ -216,7 +210,6 @@ class CompilePackage:
     ) -> None:
         from .eval_frame import innermost_fn

-        assert not self._initialized
         self._inlined_sources = set()
         self._innermost_fn = innermost_fn(fn)
         assert self._innermost_fn is not None
@@ -249,7 +242,6 @@ class CompilePackage:
         self._add_function(
             self._innermost_fn.__code__, self._innermost_fn.__module__
         )
-        self._initialized = True

     def _add_function(
         self,
@@ -281,7 +273,10 @@ class CompilePackage:
     @functools.cached_property
     def source_id(self) -> str:
         assert self._innermost_fn is not None
-        return CompilePackage.source_id_from_fn(self._innermost_fn)
+        sha256_hash = hashlib.sha256()
+        sha256_hash.update(self._innermost_fn.__qualname__.encode())
+        sha256_hash.update(str(self._innermost_fn.__code__.co_firstlineno).encode())
+        return sha256_hash.hexdigest()

     @contextlib.contextmanager
     def code_context(self, code: types.CodeType) -> Generator[None, None, None]:
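Both the restored inline `source_id` and the deleted static helper below hash the function's qualified name and first source line. The same scheme as a standalone function:

```python
import hashlib

def source_id(fn) -> str:
    # sha256 over __qualname__ plus the first line number of the function's
    # source, exactly the two inputs used in the diff above.
    h = hashlib.sha256()
    h.update(fn.__qualname__.encode())
    h.update(str(fn.__code__.co_firstlineno).encode())
    return h.hexdigest()

def example(x):
    return x + 1

print(source_id(example))  # stable across runs of the same source file
```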
@@ -440,17 +435,6 @@ class CompilePackage:
             codes=list(self._codes.values()), inlined_sources=self._inlined_sources
         )

-    @staticmethod
-    def source_id_from_fn(fn: Callable[..., Any]) -> str:
-        from .eval_frame import innermost_fn
-
-        innermost_fn_ = innermost_fn(fn)
-
-        sha256_hash = hashlib.sha256()
-        sha256_hash.update(innermost_fn_.__qualname__.encode())
-        sha256_hash.update(str(innermost_fn_.__code__.co_firstlineno).encode())
-        return sha256_hash.hexdigest()
-

 @CacheArtifactFactory.register
 class EagerCacheArtifact(PrecompileCacheArtifact[Any]):
@@ -491,9 +475,6 @@ class DynamoStore(abc.ABC):
             EagerCacheArtifact.type(), key=backend_id, content=pickled_result
         )

-    @abc.abstractmethod
-    def clear(self) -> None: ...
-
     @abc.abstractmethod
     def write(
         self,
@@ -511,11 +492,12 @@ class DynamoStore(abc.ABC):
        """
        ...

-    def save_cache_entry(self, cache_entry: _DynamoCacheEntry, key: str) -> None:
+    def save_package(self, package: CompilePackage, key: str) -> None:
        """
        Saves a package to a given path. Grabs backends from PrecompileContext.
        """
        backend_content: _Backends = {}
+        cache_entry = package.cache_entry()
        for backend_id in cache_entry.backend_ids:
            serialized_backend = PrecompileContext.serialize_artifact_by_key(backend_id)
            if serialized_backend is None:
@@ -527,14 +509,6 @@ class DynamoStore(abc.ABC):

         self.write(cache_entry, backend_content, key)

-    def save_package(self, package: CompilePackage, key: str) -> None:
-        """
-        Saves a package to a given path. Grabs backends from PrecompileContext.
-        """
-        self.record_package(package)
-        cache_entry = package.cache_entry()
-        self.save_cache_entry(cache_entry, key)
-
     @abc.abstractmethod
     def read(self, path: str) -> tuple[_DynamoCacheEntry, _Backends]:
         """
@@ -548,25 +522,17 @@ class DynamoStore(abc.ABC):
        """
        ...

-    def load_cache_entry(
-        self, key: str
-    ) -> tuple[_DynamoCacheEntry, dict[_BackendId, Any]]:
-        cache_entry, backend_content = self.read(key)
-        for backend_id, backend in backend_content.items():
-            PrecompileContext.record_artifact(
-                backend.type(), key=backend.key, content=backend.content
-            )
-            backend_content[backend_id] = backend.after_deserialization()
-
-        return cache_entry, backend_content
-
     def load_package(
         self, fn: Any, key: str
     ) -> tuple[CompilePackage, dict[_BackendId, Any]]:
        """
        Loads a package from a given path and returns it plus a list of deserialized backends
        """
-        cache_entry, backend_content = self.load_cache_entry(key)
+        cache_entry, backend_content = self.read(key)
+
+        for backend_id, backend in backend_content.items():
+            backend_content[backend_id] = backend.after_deserialization()

         package = CompilePackage(fn, cache_entry)
         return package, backend_content

@@ -579,9 +545,6 @@ class InMemoryDynamoStore(DynamoStore):
     def __init__(self) -> None:
         self.packages: dict[str, tuple[_DynamoCacheEntry, _Backends]] = {}

-    def clear(self) -> None:
-        self.packages.clear()
-
     def write(
         self,
         dynamo: _DynamoCacheEntry,
@@ -617,13 +580,6 @@ class DiskDynamoStore(DynamoStore):
         """
         self.path_prefix = path_prefix

-    def clear(self) -> None:
-        """
-        Clear all CompilePackages from disk.
-        """
-        if self.path_prefix:
-            shutil.rmtree(self.path_prefix, ignore_errors=True)
-
     def write(
         self,
         dynamo: _DynamoCacheEntry,
@@ -633,9 +589,7 @@ class DiskDynamoStore(DynamoStore):
         """
         Write dynamo cache entry and backends to disk.
         """
-        path = os.path.join(self.path_prefix, path) if self.path_prefix else path
         try:
-            os.makedirs(path, exist_ok=True)
             with open(os.path.join(path, "dynamo"), "wb") as dynamo_path:
                 pickle.dump(dynamo, dynamo_path)
             with open(os.path.join(path, "backends"), "wb") as backend_path:
@@ -647,7 +601,6 @@ class DiskDynamoStore(DynamoStore):
         """
         Read dynamo cache entry and backends from disk.
         """
-        path = os.path.join(self.path_prefix, path) if self.path_prefix else path
         try:
             with open(os.path.join(path, "dynamo"), "rb") as dynamo_path:
                 cache_entry = pickle.load(dynamo_path)
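Before and after the change, `DiskDynamoStore` keeps two pickle files per entry, named `dynamo` and `backends`, inside the entry's directory. A minimal standalone mock of that layout (illustrative, not the real classes):

```python
import os
import pickle

def write(path: str, dynamo_entry, backends) -> None:
    os.makedirs(path, exist_ok=True)
    with open(os.path.join(path, "dynamo"), "wb") as f:
        pickle.dump(dynamo_entry, f)  # the _DynamoCacheEntry
    with open(os.path.join(path, "backends"), "wb") as f:
        pickle.dump(backends, f)      # serialized backend artifacts

def read(path: str):
    with open(os.path.join(path, "dynamo"), "rb") as f:
        dynamo_entry = pickle.load(f)
    with open(os.path.join(path, "backends"), "rb") as f:
        backends = pickle.load(f)
    return dynamo_entry, backends

write("/tmp/dynamo_store_demo", {"codes": []}, {"backend0": b"..."})
print(read("/tmp/dynamo_store_demo"))
```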
@@ -657,53 +610,18 @@
         except Exception as e:
             raise RuntimeError(f"Failed to load package from path {path}: {e}") from e


-class DiskDynamoCache(DiskDynamoStore):
-    """
-    Special DiskDynamoStore which adds some helper functions for automatically
-    tracking paths of packages
-    """
-
-    def save(self, package: CompilePackage) -> None:
+    def save_package(self, package: CompilePackage, key: str) -> None:
         """
-        Saves a package to a given path. Grabs backends from PrecompileContext.
+        Save a package to disk using the path_prefix + key as the file path.
         """
-        key = package.source_id
-        logger.info("Saving CompilePackage for %s", package.source_id)
-        super().save_package(package, key)
+        full_path = os.path.join(self.path_prefix, key) if self.path_prefix else key
+        super().save_package(package, full_path)

-    def load(
-        self, fn: Callable[..., Any]
-    ) -> Optional[tuple[_DynamoCacheEntry, dict[_BackendId, Any]]]:
+    def load_package(
+        self, fn: Any, key: str
+    ) -> tuple[CompilePackage, dict[_BackendId, Any]]:
         """
-        Loads a package from a given path and returns it plus a list of deserialized backends
+        Load a package from disk using the path_prefix + key as the file path.
         """
-        key = CompilePackage.source_id_from_fn(fn)
-        logger.info("Loading CompilePackage for %s", key)
-        path = os.path.join(self.path_prefix, key)
-        if os.path.exists(path):
-            try:
-                return super().load_cache_entry(key)
-            except Exception as e:
-                logger.warning("Failed to load package from path %s: %s", path, str(e))
-                return None
-        logger.info("No package found for %s", key)
-        return None
-
-    def load_and_install_package(
-        self, fn: Callable[..., Any]
-    ) -> Optional[CompilePackage]:
-        """
-        Load directly into a package and install backends
-        """
-        results = self.load(fn)
-        if results is None:
-            return None
-        else:
-            (entry, backends) = results
-            package = CompilePackage(fn, entry)
-            package.install(backends)
-            return package
-
-
-DynamoCache = DiskDynamoCache(os.path.join(cache_dir(), "dynamo"))
+        full_path = os.path.join(self.path_prefix, key) if self.path_prefix else key
+        return super().load_package(fn, full_path)

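The removed `DiskDynamoCache` keyed entries by `package.source_id`, so callers never handled paths directly. Its calling convention, as exercised by the deleted tests (pre-revert API; none of these names survive the revert):

```python
import torch
# Pre-revert imports; DynamoCache no longer exists after this commit.
from torch._dynamo.package import CompilePackage, DynamoCache

def fn(x):
    return x.sin()

package = CompilePackage(fn)
compiled = torch._dynamo.optimize(backend="inductor", package=package)(fn)
compiled(torch.randn(3, 2))

DynamoCache.save(package)                            # keyed by fn's source_id
restored = DynamoCache.load_and_install_package(fn)  # None on a cache miss
```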
@@ -1,6 +1,5 @@
 from abc import abstractmethod
 from collections import defaultdict
-from itertools import chain
 from typing import Any, Generic, Optional, TypeVar
 from typing_extensions import override

@@ -144,34 +143,10 @@ class PrecompileContext(CacheArtifactManager):

     @staticmethod
     def populate_caches(artifacts: CacheArtifactsResult) -> CacheInfo:
-        PrecompileContext._ensure_cache_artifacts_registered()
-
-        artifacts_by_key = {}
-        cache_info = CacheInfo()
-        for artifact in chain(*artifacts.values()):
-            cache_info.add(artifact)
-            artifacts_by_key[artifact.key] = artifact
-
-        from torch._dynamo.package import _BackendId, DynamoCache
-
-        for dynamo_entry in artifacts["precompile_dynamo"]:
-            assert isinstance(dynamo_entry, PrecompileCacheArtifact)
-            cache_entry = dynamo_entry.after_deserialization()
-            # Grab backends from the dynamo cache entry
-            backends = cache_entry.backend_ids
-            backend_content: dict[_BackendId, PrecompileCacheArtifact[Any]] = {}
-            for id_ in backends:
-                assert id_ in artifacts_by_key, f"Backend {id_} not found in artifacts"
-                artifact = artifacts_by_key[id_]
-                assert isinstance(artifact, PrecompileCacheArtifact)
-                backend_content[id_] = artifact
-            DynamoCache.write(cache_entry, backend_content, dynamo_entry.key)
-
-        return cache_info
+        raise NotImplementedError("TODO")

     @classmethod
     def _ensure_cache_artifacts_registered(cls) -> None:
         from torch._dynamo.package import _DynamoCacheArtifact  # noqa: F401
         from torch._functorch._aot_autograd.autograd_cache import (  # noqa: F401
             BundledAOTAutogradCacheArtifact,
         )
@@ -121,10 +121,6 @@ def should_use_local_autograd_cache():
     return config.enable_autograd_cache


-def should_bundle_autograd_cache():
-    return config.bundled_autograd_cache or torch._dynamo.config.caching_precompile
-
-
 def check_node_safe(node: Node):
     """
     Checks that the node only uses supported operators. We are starting with very
@@ -1100,10 +1096,7 @@ class AOTAutogradCache(GuardedCache[GenericAOTAutogradCacheEntry]):
         except FXGraphCacheMiss as e:
             counters["aot_autograd"]["autograd_cache_miss"] += 1
             cache_state = "miss"
-            if (
-                config.strict_autograd_cache
-                or torch._dynamo.config.caching_precompile
-            ):
+            if config.strict_autograd_cache:
                 raise e
             # Most often this is BypassAOTAutogradCache, but
             # if there's ever different reason we can't cache,
@@ -1133,10 +1126,7 @@ class AOTAutogradCache(GuardedCache[GenericAOTAutogradCacheEntry]):
                 )
                 if remote:
                     log_cache_bypass("bypass_aot_autograd", str(e))
-                if (
-                    config.strict_autograd_cache
-                    or torch._dynamo.config.caching_precompile
-                ):
+                if config.strict_autograd_cache:
                     raise e
         if compiled_fn is None:
             # Set the cache key so we can save a cache result later
@@ -1250,7 +1240,7 @@ class AOTAutogradCache(GuardedCache[GenericAOTAutogradCacheEntry]):
                 AOTAutogradCacheArtifact.type(), key, pickled_content
             )
         if (
-            should_bundle_autograd_cache()
+            config.bundled_autograd_cache
             and aot_config is not None
             and aot_config.precompile_backend_id is not None
         ):
@@ -1292,7 +1282,7 @@ class AOTAutogradCache(GuardedCache[GenericAOTAutogradCacheEntry]):
                 AOTAutogradCacheArtifact.type(), key, content
             )
         if (
-            should_bundle_autograd_cache()
+            config.bundled_autograd_cache
             and entry.sanitized_aot_config.precompile_backend_id is not None
         ):
             precompile_key = entry.sanitized_aot_config.precompile_backend_id
@@ -1368,7 +1358,7 @@ class AOTAutogradCache(GuardedCache[GenericAOTAutogradCacheEntry]):
         num_symints_saved_for_bw: Optional[int],
         serialized_bw_module: Optional[SerializedGraphModule],
     ) -> GenericAOTAutogradCacheEntry:
-        if should_bundle_autograd_cache():
+        if config.bundled_autograd_cache:
             # Helper function to unwrap all the wrappers we added during aotdispatch
             # They get reapplied on cache load
             def unwrap_compiled_fx_graph(obj):
@@ -43,7 +43,6 @@ from .. import config
 from .autograd_cache import (
     AOTAutogradCache,
     serialize_graph_module,
-    should_bundle_autograd_cache,
     should_use_remote_autograd_cache,
 )
 from .dispatch_and_compile_graph import (
@@ -263,7 +262,7 @@ def aot_dispatch_base(
     cache_info = aot_config.cache_info

     def should_save_cache():
-        if should_bundle_autograd_cache():
+        if torch._functorch.config.bundled_autograd_cache:
             return True
         else:
             return hasattr(compiled_fw, "_fx_graph_cache_key")
@@ -1782,7 +1781,7 @@ def aot_dispatch_autograd(
     cache_info = aot_config.cache_info

     def should_save_cache():
-        if should_bundle_autograd_cache():
+        if torch._functorch.config.bundled_autograd_cache:
             return True
         else:
             return hasattr(compiled_fw_func, "_fx_graph_cache_key") and hasattr(
@@ -135,10 +135,6 @@ class CacheInfo:
     def precompile_aot_autograd_artifacts(self) -> list[str]:  # type: ignore[empty-body]
         ...

-    @property
-    def precompile_dynamo_artifacts(self) -> list[str]:  # type: ignore[empty-body]
-        ...
-
     def add(self, artifact: CacheArtifact) -> None:
         self.artifacts[artifact.type()].append(artifact.key)
