mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Simplify PrecompileContext to no longer be a CacheArtifactManager (#162886)
Summary: This diff does a big refactor of PrecompileContext to make it considerably simpler: instead of being a CacheArtifactManager and managing a bunch of bytes, it simply stores two things: dynamo cache entries and backend cache entries. When asked, it stitches them together into PrecompileCacheEntries, which are stored by DynamoCache. This structure then allows us to register DynamoCache to the regular Megacache API, instead of having two separate APIs that are confusing. It also lets us remove the autotune cache integration, since MegaCache API will automatically store autotune cache entries. The intent here is that users who want to use caching precompile will simply be able to use torch.compiler.save_cache_artifacts as before, just with `torch.dynamo.config.caching_precompile` set to True. They can also directly interact with PrecompileContext if they wish to specifically only load Precompile entries, using PrecompileContext.create_cache_entries(). Saving single entries and such with DynamoCache still works normally. Test Plan: All existing unit tests pass. Rollback Plan: Differential Revision: D82380307 Pull Request resolved: https://github.com/pytorch/pytorch/pull/162886 Approved by: https://github.com/zhxchen17
This commit is contained in:
committed by
PyTorch MergeBot
parent
8225a26835
commit
bfe9e60ffb
@ -3580,18 +3580,10 @@ def process_caching_precompile():
|
||||
)
|
||||
from torch._dynamo.precompile_context import PrecompileContext
|
||||
|
||||
# Serialize all callables, clear PrecompileContext
|
||||
# TODO: put this under torch.compiler API once ready
|
||||
serialized = PrecompileContext.serialize()
|
||||
PrecompileContext.clear()
|
||||
if serialized is not None:
|
||||
artifacts, info = serialized
|
||||
print(
|
||||
f"Saving {len(info.precompile_dynamo_artifacts)} Precompile Artifact(s)..."
|
||||
)
|
||||
results = PrecompileContext.deserialize(artifacts)
|
||||
assert results is not None
|
||||
PrecompileContext.populate_caches(results)
|
||||
debug_info = PrecompileContext.save_to_dynamo_cache()
|
||||
print(
|
||||
f"Saved {len(debug_info['dynamo'])} precompile artifacts with {len(debug_info['backends'])} backends"
|
||||
)
|
||||
|
||||
|
||||
def process_entry(rank, runner, original_dir, args):
|
||||
|
@ -16,13 +16,10 @@ from torch._dynamo.package import CompilePackage, DiskDynamoStore, DynamoCache
|
||||
from torch._dynamo.precompile_context import PrecompileContext
|
||||
from torch._dynamo.testing import reduce_to_scalar_loss
|
||||
from torch._functorch import config as functorch_config
|
||||
from torch._inductor.mock_cache import global_stats, PatchCaches, Stats
|
||||
from torch._inductor.runtime.runtime_utils import cache_dir
|
||||
from torch.testing._internal.common_utils import (
|
||||
instantiate_parametrized_tests,
|
||||
parametrize,
|
||||
skipIfRocm,
|
||||
skipIfXpu,
|
||||
)
|
||||
from torch.testing._internal.inductor_utils import (
|
||||
HAS_CUDA_AND_TRITON,
|
||||
@ -50,9 +47,7 @@ class TestPackage(torch._inductor.test_case.TestCase):
|
||||
DynamoCache.clear()
|
||||
PrecompileContext.clear()
|
||||
|
||||
def _save_and_reload(
|
||||
self, expected_backends, expected_dynamo, expected_autotune=None
|
||||
):
|
||||
def _save_and_reload(self, expected_backends, expected_dynamo):
|
||||
"""
|
||||
Serializes all artifacts, clears all caches, then reloads the serialized artifact
|
||||
Simulates a new process.
|
||||
@ -61,24 +56,12 @@ class TestPackage(torch._inductor.test_case.TestCase):
|
||||
expected_backends: Expected number of precompile_aot_autograd_artifacts
|
||||
expected_dynamo: Expected number of precompile_dynamo_artifacts
|
||||
"""
|
||||
serialized = PrecompileContext.serialize()
|
||||
assert serialized is not None
|
||||
(bytes_, cache_info) = serialized
|
||||
self.assertEqual(
|
||||
len(cache_info.precompile_aot_autograd_artifacts), expected_backends
|
||||
)
|
||||
self.assertEqual(len(cache_info.precompile_dynamo_artifacts), expected_dynamo)
|
||||
if expected_autotune is not None:
|
||||
self.assertEqual(len(cache_info.autotune_artifacts), expected_autotune)
|
||||
|
||||
debug_info = PrecompileContext.save_to_dynamo_cache()
|
||||
self.assertEqual(len(debug_info["dynamo"]), expected_dynamo)
|
||||
self.assertEqual(len(debug_info["backends"]), expected_backends)
|
||||
torch._dynamo.reset()
|
||||
DynamoCache.clear()
|
||||
PrecompileContext.clear()
|
||||
|
||||
deserialized = PrecompileContext.deserialize(bytes_)
|
||||
assert deserialized is not None
|
||||
PrecompileContext.populate_caches(deserialized)
|
||||
|
||||
@unittest.expectedFailure # FUNCTION_MATCH guard not serializable today
|
||||
def test_nn_module(self):
|
||||
class MyModule(torch.nn.Module):
|
||||
@ -440,41 +423,6 @@ def add(x, y):
|
||||
self.assertEqual(expected, [result1, result2])
|
||||
self.assertEqual(torch._dynamo.convert_frame.FRAME_COUNTER, total_frames)
|
||||
|
||||
@parametrize("device", ("cuda", "xpu"))
|
||||
@torch._dynamo.config.patch(caching_precompile=True)
|
||||
@skipIfXpu
|
||||
@skipIfRocm
|
||||
def test_automatic_dynamo_autotune_cache(self, device):
|
||||
if device == "cuda" and not HAS_CUDA_AND_TRITON:
|
||||
raise unittest.SkipTest("Requires CUDA/Triton")
|
||||
if device == "xpu" and not HAS_XPU_AND_TRITON:
|
||||
raise unittest.SkipTest("Requires XPU/Triton")
|
||||
|
||||
def fn(x, y):
|
||||
return x.sin() + y
|
||||
|
||||
arg1 = torch.randn(3, 3, device=device)
|
||||
arg2 = torch.randn(3, 3, device=device)
|
||||
expected = fn(arg1, arg2).clone()
|
||||
|
||||
with PatchCaches():
|
||||
compiled_fn1 = torch.compile(fn, mode="max-autotune")
|
||||
result = compiled_fn1(arg1, arg2).clone()
|
||||
self.assertEqual(expected, result)
|
||||
self.assertEqual(global_stats.autotune_local, Stats(1, 0, 1))
|
||||
DynamoCache.clear()
|
||||
|
||||
total_frames = torch._dynamo.convert_frame.FRAME_COUNTER
|
||||
self._save_and_reload(
|
||||
expected_backends=1, expected_dynamo=1, expected_autotune=1
|
||||
)
|
||||
compiled_fn1 = torch.compile(fn, mode="max-autotune")
|
||||
with torch.compiler.set_stance("fail_on_recompile"):
|
||||
result1 = compiled_fn1(arg1, arg2).clone()
|
||||
self.assertEqual(expected, result1)
|
||||
self.assertEqual(torch._dynamo.convert_frame.FRAME_COUNTER, total_frames)
|
||||
self.assertEqual(global_stats.autotune_local, Stats(2, 1, 1))
|
||||
|
||||
@parametrize("device", ("cpu", "cuda", "xpu"))
|
||||
@torch._dynamo.config.patch(caching_precompile=True)
|
||||
def test_automatic_dynamo_recompiles(self, device):
|
||||
|
@ -1,16 +1,9 @@
|
||||
# Owner(s): ["module: dynamo"]
|
||||
|
||||
import pickle
|
||||
|
||||
import torch
|
||||
import torch._dynamo
|
||||
import torch._dynamo.test_case
|
||||
import torch._functorch
|
||||
from torch._dynamo.precompile_context import (
|
||||
EditablePrecompileCacheArtifact,
|
||||
PrecompileCacheArtifact,
|
||||
PrecompileContext,
|
||||
)
|
||||
from torch._dynamo.precompile_context import BackendCacheArtifact, PrecompileContext
|
||||
from torch._functorch import config as functorch_config
|
||||
from torch._functorch._aot_autograd.autograd_cache import (
|
||||
BundledAOTAutogradCacheArtifact,
|
||||
@ -49,31 +42,11 @@ class PrecompileContextTests(InductorTestCase):
|
||||
result.sum().backward()
|
||||
self.assertEqual(len(PrecompileContext._dynamo_cache_entries), 1)
|
||||
self.assertEqual(len(PrecompileContext._backend_artifacts_by_key), 1)
|
||||
self.assertEqual(len(PrecompileContext._new_cache_artifacts), 0)
|
||||
|
||||
result = PrecompileContext.serialize()
|
||||
assert result is not None
|
||||
serialized, cache_info = result
|
||||
self.assertEqual(len(cache_info.precompile_aot_autograd_artifacts), 1)
|
||||
|
||||
artifacts = PrecompileContext.deserialize(serialized)
|
||||
assert artifacts is not None
|
||||
deserialized = artifacts["precompile_aot_autograd"]
|
||||
assert len(deserialized) == 1
|
||||
entry = deserialized[0]
|
||||
assert isinstance(entry, BundledAOTAutogradCacheArtifact)
|
||||
entry = entry.after_deserialization()
|
||||
# Now that we've serialized, there should be no new cache artifacts
|
||||
self.assertEqual(
|
||||
len(PrecompileContext._new_cache_artifacts["precompile_aot_autograd"]), 0
|
||||
)
|
||||
cache_entries, _ = PrecompileContext.create_cache_entries()
|
||||
self.assertEqual(len(cache_entries), 1)
|
||||
|
||||
@requires_triton()
|
||||
def test_serialize_by_key(self):
|
||||
"""
|
||||
Test that after torch.compile, PrecompileContext._new_cache_artifacts length is 1
|
||||
"""
|
||||
|
||||
def simple_function(x):
|
||||
return x.sin() + x.cos()
|
||||
|
||||
@ -87,14 +60,12 @@ class PrecompileContextTests(InductorTestCase):
|
||||
self.assertEqual(len(PrecompileContext._backend_artifacts_by_key), 1)
|
||||
for key in PrecompileContext._backend_artifacts_by_key.keys():
|
||||
result = PrecompileContext.serialize_artifact_by_key(key)
|
||||
assert isinstance(result, PrecompileCacheArtifact)
|
||||
assert isinstance(result, BackendCacheArtifact)
|
||||
self.assertEqual(result.key, key)
|
||||
|
||||
self.assertEqual(len(PrecompileContext._new_cache_artifacts), 0)
|
||||
result = PrecompileContext.serialize()
|
||||
assert result is not None
|
||||
_, cache_info = result
|
||||
self.assertEqual(len(cache_info.precompile_aot_autograd_artifacts), 1)
|
||||
# This should still work
|
||||
result, _ = PrecompileContext.create_cache_entries()
|
||||
assert len(result) == 1
|
||||
|
||||
@requires_triton()
|
||||
def test_editable(self):
|
||||
@ -114,11 +85,7 @@ class PrecompileContextTests(InductorTestCase):
|
||||
self.assertEqual(len(PrecompileContext._dynamo_cache_entries), 1)
|
||||
self.assertEqual(len(PrecompileContext._backend_artifacts_by_key), 1)
|
||||
# Find the key for the artifact of type "precompile_aot_autograd"
|
||||
key = next(
|
||||
k
|
||||
for k, v in PrecompileContext._backend_artifacts_by_key.items()
|
||||
if isinstance(v, EditablePrecompileCacheArtifact)
|
||||
)
|
||||
key = next(iter(PrecompileContext._backend_artifacts_by_key))
|
||||
|
||||
def edit_fn(x):
|
||||
x._my_private_field = 42
|
||||
@ -130,24 +97,12 @@ class PrecompileContextTests(InductorTestCase):
|
||||
assert isinstance(result, BundledAOTAutogradCacheArtifact)
|
||||
self.assertEqual(result.key, key)
|
||||
|
||||
self.assertEqual(len(PrecompileContext._new_cache_artifacts), 0)
|
||||
result = PrecompileContext.serialize()
|
||||
assert result is not None
|
||||
artifacts, cache_info = result
|
||||
self.assertEqual(len(cache_info.precompile_aot_autograd_artifacts), 1)
|
||||
|
||||
deserialized = PrecompileContext.deserialize(artifacts)
|
||||
assert deserialized is not None
|
||||
aot_autograd_artifacts = deserialized["precompile_aot_autograd"]
|
||||
result, _ = PrecompileContext.create_cache_entries()
|
||||
assert len(result) == 1
|
||||
aot_autograd_artifacts = next(iter(result.values())).backends
|
||||
assert len(aot_autograd_artifacts) == 1
|
||||
entry = aot_autograd_artifacts[0]
|
||||
assert isinstance(entry, BundledAOTAutogradCacheArtifact)
|
||||
raw_entry = pickle.loads(entry.content)
|
||||
self.assertEqual(raw_entry._my_private_field, 42)
|
||||
# Now that we've serialized, there should be no new cache artifacts
|
||||
self.assertEqual(
|
||||
len(PrecompileContext._new_cache_artifacts["precompile_aot_autograd"]), 0
|
||||
)
|
||||
entry = next(iter(aot_autograd_artifacts.values())).content
|
||||
self.assertEqual(entry._my_private_field, 42)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -182,7 +182,9 @@ class BundledAOTAutogradSerializableCallable(SerializableCallable):
|
||||
deserialize_bundled_cache_entry,
|
||||
)
|
||||
|
||||
compiled_fn = deserialize_bundled_cache_entry(data)
|
||||
entry = pickle.loads(data)
|
||||
|
||||
compiled_fn = deserialize_bundled_cache_entry(entry)
|
||||
return cls(compiled_fn)
|
||||
|
||||
def __call__(self, *args: Any, **kwargs: Any) -> Any:
|
||||
|
@ -747,12 +747,11 @@ class _TorchDynamoContext:
|
||||
# Create a fresh CompilePackage
|
||||
self._package.initialize(fn, None, ignore_inlined_sources=False)
|
||||
else:
|
||||
cache_entry, backends = result
|
||||
try:
|
||||
self._package.initialize(
|
||||
fn, cache_entry, ignore_inlined_sources=False
|
||||
fn, result.dynamo, ignore_inlined_sources=False
|
||||
)
|
||||
self._package.install(backends)
|
||||
self._package.install(result.backends)
|
||||
except RuntimeError as e:
|
||||
log.warning("Failed to load entry from dynamo cache: %s", e)
|
||||
self._package.initialize(fn, None, ignore_inlined_sources=False)
|
||||
|
@ -316,6 +316,7 @@ class SystemInfo:
|
||||
def current(cls) -> "SystemInfo":
|
||||
"""Create a SystemInfo instance with current system information."""
|
||||
# Get GPU name if CUDA or XPU is available
|
||||
gpu_name = None
|
||||
from torch.utils._triton import get_triton_version
|
||||
|
||||
gpu_name, toolkit_version = None, None
|
||||
@ -406,6 +407,19 @@ class _DynamoCacheEntry:
|
||||
}
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class PrecompileCacheEntry:
|
||||
"""
|
||||
A full cache entry for caching precompile, for a toplevel torch.compile.
|
||||
Consists of a _DynamoCacheEntry, which contains all the dynamo related contents,
|
||||
and a set of backends content. In general, the backend content here will always
|
||||
be of type precompile_context.BackendCacheArtifact
|
||||
"""
|
||||
|
||||
dynamo: _DynamoCacheEntry
|
||||
backends: dict[_BackendId, Any]
|
||||
|
||||
|
||||
def _hash_source(source: str) -> str:
|
||||
sha256_hash = hashlib.sha256()
|
||||
sha256_hash.update(source.encode())
|
||||
@ -817,10 +831,8 @@ class DynamoStore(abc.ABC):
|
||||
PrecompileContext,
|
||||
)
|
||||
|
||||
pickled_result = pickle.dumps(backend)
|
||||
PrecompileContext.record_artifact(
|
||||
EagerCacheArtifact.type(), key=backend_id, content=pickled_result
|
||||
)
|
||||
result = EagerCacheArtifact(key=backend_id, content=backend)
|
||||
PrecompileContext.record_artifact(result)
|
||||
|
||||
@abc.abstractmethod
|
||||
def clear(self) -> None: ...
|
||||
@ -828,8 +840,7 @@ class DynamoStore(abc.ABC):
|
||||
@abc.abstractmethod
|
||||
def write(
|
||||
self,
|
||||
dynamo: _DynamoCacheEntry,
|
||||
backends: _Backends,
|
||||
cache_entry: PrecompileCacheEntry,
|
||||
path: str,
|
||||
) -> None:
|
||||
"""
|
||||
@ -847,7 +858,7 @@ class DynamoStore(abc.ABC):
|
||||
Saves a package to a given path. Grabs backends from PrecompileContext.
|
||||
"""
|
||||
from torch._dynamo.precompile_context import (
|
||||
PrecompileCacheArtifact,
|
||||
BackendCacheArtifact,
|
||||
PrecompileContext,
|
||||
)
|
||||
|
||||
@ -858,10 +869,12 @@ class DynamoStore(abc.ABC):
|
||||
raise RuntimeError(
|
||||
f"Backend {backend_id} is not found in the given backends"
|
||||
)
|
||||
assert isinstance(serialized_backend, PrecompileCacheArtifact)
|
||||
assert isinstance(serialized_backend, BackendCacheArtifact)
|
||||
backend_content[backend_id] = serialized_backend
|
||||
|
||||
self.write(cache_entry, backend_content, key)
|
||||
entry = PrecompileCacheEntry(cache_entry, backend_content)
|
||||
|
||||
self.write(entry, key)
|
||||
|
||||
def save_package(self, package: CompilePackage, key: str) -> None:
|
||||
"""
|
||||
@ -872,7 +885,7 @@ class DynamoStore(abc.ABC):
|
||||
self.save_cache_entry(cache_entry, key)
|
||||
|
||||
@abc.abstractmethod
|
||||
def read(self, path: str) -> tuple[_DynamoCacheEntry, _Backends]:
|
||||
def read(self, path: str) -> PrecompileCacheEntry:
|
||||
"""
|
||||
Abstract method to read dynamo cache entry and backends from storage.
|
||||
|
||||
@ -884,19 +897,18 @@ class DynamoStore(abc.ABC):
|
||||
"""
|
||||
...
|
||||
|
||||
def load_cache_entry(
|
||||
self, key: str
|
||||
) -> tuple[_DynamoCacheEntry, dict[_BackendId, Any]]:
|
||||
from torch._dynamo.precompile_context import PrecompileContext
|
||||
def load_cache_entry(self, key: str) -> PrecompileCacheEntry:
|
||||
from torch._dynamo.precompile_context import (
|
||||
BackendCacheArtifact,
|
||||
PrecompileContext,
|
||||
)
|
||||
|
||||
cache_entry, backend_content = self.read(key)
|
||||
for backend_id, backend in backend_content.items():
|
||||
PrecompileContext.record_artifact(
|
||||
backend.type(), key=backend.key, content=backend.content
|
||||
)
|
||||
backend_content[backend_id] = backend
|
||||
precompile_entry = self.read(key)
|
||||
for backend in precompile_entry.backends.values():
|
||||
assert isinstance(backend, BackendCacheArtifact)
|
||||
PrecompileContext.record_artifact(backend)
|
||||
|
||||
return cache_entry, backend_content
|
||||
return precompile_entry
|
||||
|
||||
def load_package(
|
||||
self, fn: Any, key: str
|
||||
@ -904,9 +916,9 @@ class DynamoStore(abc.ABC):
|
||||
"""
|
||||
Loads a package from a given path and returns it plus a list of deserialized backends
|
||||
"""
|
||||
cache_entry, backend_content = self.load_cache_entry(key)
|
||||
package = CompilePackage(fn, cache_entry)
|
||||
return package, backend_content
|
||||
entry = self.load_cache_entry(key)
|
||||
package = CompilePackage(fn, entry.dynamo)
|
||||
return package, entry.backends
|
||||
|
||||
|
||||
class InMemoryDynamoStore(DynamoStore):
|
||||
@ -915,23 +927,22 @@ class InMemoryDynamoStore(DynamoStore):
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.packages: dict[str, tuple[_DynamoCacheEntry, _Backends]] = {}
|
||||
self.packages: dict[str, PrecompileCacheEntry] = {}
|
||||
|
||||
def clear(self) -> None:
|
||||
self.packages.clear()
|
||||
|
||||
def write(
|
||||
self,
|
||||
dynamo: _DynamoCacheEntry,
|
||||
backends: _Backends,
|
||||
entry: PrecompileCacheEntry,
|
||||
path: str,
|
||||
) -> None:
|
||||
"""
|
||||
Store the dynamo cache entry and backends in memory instead of writing to disk.
|
||||
"""
|
||||
self.packages[path] = (dynamo, backends)
|
||||
self.packages[path] = entry
|
||||
|
||||
def read(self, path: str) -> tuple[_DynamoCacheEntry, _Backends]:
|
||||
def read(self, path: str) -> PrecompileCacheEntry:
|
||||
"""
|
||||
Read dynamo cache entry and backends from memory.
|
||||
"""
|
||||
@ -964,34 +975,32 @@ class DiskDynamoStore(DynamoStore):
|
||||
|
||||
def write(
|
||||
self,
|
||||
dynamo: _DynamoCacheEntry,
|
||||
backends: _Backends,
|
||||
entry: PrecompileCacheEntry,
|
||||
path: str,
|
||||
) -> None:
|
||||
"""
|
||||
Write dynamo cache entry and backends to disk.
|
||||
"""
|
||||
from torch._inductor.codecache import write_atomic
|
||||
|
||||
path = os.path.join(self.path_prefix, path) if self.path_prefix else path
|
||||
try:
|
||||
os.makedirs(path, exist_ok=True)
|
||||
with open(os.path.join(path, "dynamo"), "wb") as dynamo_path:
|
||||
pickle.dump(dynamo, dynamo_path)
|
||||
with open(os.path.join(path, "backends"), "wb") as backend_path:
|
||||
pickle.dump(backends, backend_path)
|
||||
pickled_content: bytes = pickle.dumps(entry)
|
||||
write_atomic(os.path.join(path, "entry"), pickled_content)
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Failed to save package to {path}: {e}") from e
|
||||
|
||||
def read(self, path: str) -> tuple[_DynamoCacheEntry, _Backends]:
|
||||
def read(self, path: str) -> PrecompileCacheEntry:
|
||||
"""
|
||||
Read dynamo cache entry and backends from disk.
|
||||
"""
|
||||
path = os.path.join(self.path_prefix, path) if self.path_prefix else path
|
||||
try:
|
||||
with open(os.path.join(path, "dynamo"), "rb") as dynamo_path:
|
||||
cache_entry = pickle.load(dynamo_path)
|
||||
with open(os.path.join(path, "backends"), "rb") as backend_path:
|
||||
backend_content = pickle.load(backend_path)
|
||||
return cache_entry, backend_content
|
||||
with open(os.path.join(path, "entry"), "rb") as f:
|
||||
pickled_content = f.read()
|
||||
entry = pickle.loads(pickled_content)
|
||||
return entry
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Failed to load package from path {path}: {e}") from e
|
||||
|
||||
@ -1010,9 +1019,7 @@ class DiskDynamoCache(DiskDynamoStore):
|
||||
logger.info("Saving CompilePackage for %s", package.source_id)
|
||||
super().save_package(package, key)
|
||||
|
||||
def load(
|
||||
self, fn: Callable[..., Any]
|
||||
) -> Optional[tuple[_DynamoCacheEntry, dict[_BackendId, Any]]]:
|
||||
def load(self, fn: Callable[..., Any]) -> Optional[PrecompileCacheEntry]:
|
||||
"""
|
||||
Loads a package from a given path and returns it plus a list of deserialized backends
|
||||
"""
|
||||
@ -1039,9 +1046,8 @@ class DiskDynamoCache(DiskDynamoStore):
|
||||
if results is None:
|
||||
return None
|
||||
else:
|
||||
(entry, backends) = results
|
||||
package = CompilePackage(fn, entry)
|
||||
package.install(backends)
|
||||
package = CompilePackage(fn, results.dynamo)
|
||||
package.install(results.backends)
|
||||
return package
|
||||
|
||||
|
||||
|
@ -1,25 +1,18 @@
|
||||
import copy
|
||||
import json
|
||||
import logging
|
||||
import pickle
|
||||
from abc import abstractmethod
|
||||
from collections import defaultdict
|
||||
from itertools import chain
|
||||
from typing import Any, Callable, Generic, Optional, TypeVar, Union
|
||||
from typing_extensions import override
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Callable, Generic, Optional, TypeVar
|
||||
|
||||
import torch
|
||||
from torch._dynamo.package import _DynamoCacheEntry
|
||||
from torch.compiler._cache import (
|
||||
_serialize_single_cache,
|
||||
CacheArtifact,
|
||||
CacheArtifactFactory,
|
||||
CacheArtifactManager,
|
||||
CacheArtifactsResult,
|
||||
CacheInfo,
|
||||
from torch._dynamo.package import (
|
||||
_BackendId,
|
||||
_DynamoCacheEntry,
|
||||
DynamoCache,
|
||||
PrecompileCacheEntry,
|
||||
)
|
||||
from torch.utils._appending_byte_serializer import AppendingByteSerializer
|
||||
from torch.utils._ordered_set import OrderedSet
|
||||
|
||||
|
||||
"""
|
||||
@ -30,14 +23,12 @@ T = TypeVar("T")
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class PrecompileCacheArtifact(CacheArtifact, Generic[T]):
|
||||
@dataclass
|
||||
class BackendCacheArtifact(Generic[T]):
|
||||
"""
|
||||
Data for each cache artifact that will be serialized and deserialized by
|
||||
PrecompileContext, rather than CacheArtifactManager.
|
||||
T represents the deserialized type of the artifact, i.e. the return type of after_deserialization
|
||||
|
||||
PrecompileCacheArtifact is a frozen dataclass - you can add new serializable fields and metadata specific to your own artifacts
|
||||
as needed, and use them in after_deserialization.
|
||||
Represents a single serializable backend artifact from a dynamo backend.
|
||||
Each BackendCacheArtifact has a key associated with it along with some
|
||||
serializable content.
|
||||
|
||||
Example implementation:
|
||||
|
||||
@ -51,13 +42,8 @@ class PrecompileCacheArtifact(CacheArtifact, Generic[T]):
|
||||
return result
|
||||
"""
|
||||
|
||||
@override
|
||||
def populate_cache(self) -> None:
|
||||
raise RuntimeError("Precompile cache artifacts do not populate caches")
|
||||
|
||||
@override
|
||||
def precompile_compatible(self) -> bool:
|
||||
return True
|
||||
key: str
|
||||
content: Any
|
||||
|
||||
@abstractmethod
|
||||
def after_deserialization(self) -> T:
|
||||
@ -67,63 +53,23 @@ class PrecompileCacheArtifact(CacheArtifact, Generic[T]):
|
||||
"""
|
||||
...
|
||||
|
||||
|
||||
@CacheArtifactFactory.register
|
||||
class EagerCacheArtifact(PrecompileCacheArtifact[Any]):
|
||||
@staticmethod
|
||||
def type() -> str:
|
||||
return "precompile_eager"
|
||||
|
||||
def after_deserialization(self) -> Any:
|
||||
return pickle.loads(self.content)
|
||||
|
||||
|
||||
class EditablePrecompileCacheArtifact(Generic[T]):
|
||||
"""
|
||||
A PrecompileCacheArtifact whose content isn't encoded until we call PrecompileContext.serialize()
|
||||
"""
|
||||
|
||||
def __init__(self, artifact_type: str, content: Any, key: str) -> None:
|
||||
# Deepcopy the content for now, but don't pickle it yet.
|
||||
# This allows us to make changes to self.content before true serialization
|
||||
self.content = copy.deepcopy(content)
|
||||
self.key = key
|
||||
self.artifact_type = artifact_type
|
||||
|
||||
def real_encode(self) -> PrecompileCacheArtifact[T]:
|
||||
"""
|
||||
Actually encode the object
|
||||
"""
|
||||
content = pickle.dumps(self.content)
|
||||
artifact = CacheArtifactFactory.encode_create(
|
||||
self.artifact_type, self.key, content
|
||||
)
|
||||
assert isinstance(artifact, PrecompileCacheArtifact)
|
||||
return artifact
|
||||
|
||||
def edit_contents(self, edit_fn: Callable[..., Any]) -> None:
|
||||
"""
|
||||
Edit the content of an existing artifact
|
||||
Edit the contents of the artifact.
|
||||
"""
|
||||
self.content = edit_fn(self.content)
|
||||
|
||||
|
||||
@CacheArtifactFactory.register
|
||||
class _DynamoCacheArtifact(PrecompileCacheArtifact[_DynamoCacheEntry]):
|
||||
@staticmethod
|
||||
def type() -> str:
|
||||
return "precompile_dynamo"
|
||||
|
||||
def after_deserialization(self) -> _DynamoCacheEntry:
|
||||
result = pickle.loads(self.content)
|
||||
return result
|
||||
class EagerCacheArtifact(BackendCacheArtifact[Any]):
|
||||
def after_deserialization(self) -> Any:
|
||||
return self.content
|
||||
|
||||
|
||||
class BypassDynamoCacheEntry(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class PrecompileContext(CacheArtifactManager):
|
||||
class PrecompileContext:
|
||||
"""
|
||||
PrecompileContext is a special CacheArtifactManager for handling precompilation
|
||||
It uses the same interface as CacheArtifactManager, but handles deserialization differently: instead
|
||||
@ -136,69 +82,32 @@ class PrecompileContext(CacheArtifactManager):
|
||||
|
||||
The following artifact types are supported by PrecompileContext:
|
||||
- BundledAOTAutogradCacheArtifact
|
||||
- AutotuneCacheArtifact (regular autotune results, same as Megacache)
|
||||
|
||||
"""
|
||||
|
||||
# Protected by the compile_lock
|
||||
# _backend_artifacts_by_key organizes results by the key of each artifact.
|
||||
# This allows us to implement serialize_by_key easily.
|
||||
# On call to `serialize()`, all cache artifacts in _backend_artifacts_by_key
|
||||
# are transferred to _new_cache_artifacts before serialization.
|
||||
_backend_artifacts_by_key: dict[
|
||||
str, Union[EditablePrecompileCacheArtifact[object], CacheArtifact]
|
||||
] = {}
|
||||
# Each object here must be serializable
|
||||
_backend_artifacts_by_key: dict[str, BackendCacheArtifact[Any]] = {}
|
||||
|
||||
# On call to `serialize()`, all cache artifacts in _dynamo_cache_entries are converted
|
||||
# into DynamoCacheArtifacts and added to _new_cache_artifacts for serialization
|
||||
_dynamo_cache_entries: dict[str, _DynamoCacheEntry] = {}
|
||||
|
||||
_new_cache_artifacts: CacheArtifactsResult = defaultdict(list)
|
||||
# Keep a separate seen artifacts list to make avoid unnecessary duplicates
|
||||
# This list will not be cleared between serialize() calls
|
||||
_seen_artifacts: OrderedSet[CacheArtifact] = OrderedSet()
|
||||
# When serialize() is called, artifacts are transferred from _cache_artifacts to
|
||||
# internal data structure of the _serializer
|
||||
# This allows us to only pay the cost of serialization if serialize() is called
|
||||
_serializer: AppendingByteSerializer[tuple[str, list[CacheArtifact]]] = (
|
||||
AppendingByteSerializer(serialize_fn=_serialize_single_cache)
|
||||
)
|
||||
_cache_info: CacheInfo = CacheInfo()
|
||||
|
||||
@classmethod
|
||||
def clear(cls) -> None:
|
||||
cls._backend_artifacts_by_key.clear()
|
||||
cls._dynamo_cache_entries.clear()
|
||||
super().clear()
|
||||
|
||||
@override
|
||||
@classmethod
|
||||
def record_artifact(
|
||||
cls,
|
||||
artifact_type: str,
|
||||
key: str,
|
||||
content: Any,
|
||||
editable: bool = False,
|
||||
artifact: BackendCacheArtifact[Any],
|
||||
) -> None:
|
||||
"""
|
||||
Called from each caching operation to record the artifact in this
|
||||
"mega" list
|
||||
Records a backend artifact to be used with dynamo cache entries
|
||||
"""
|
||||
artifact: Union[EditablePrecompileCacheArtifact[object], CacheArtifact]
|
||||
if editable:
|
||||
artifact = EditablePrecompileCacheArtifact(artifact_type, content, key)
|
||||
else:
|
||||
artifact = CacheArtifactFactory.encode_create(artifact_type, key, content)
|
||||
# TODO: although this covers completely same artifacts, it's possible
|
||||
# with AOTAutogradCacheEntries to have multiple artifacts whose keys
|
||||
# (i.e. backend_ids) are different, but whose contents are equal.
|
||||
# In those cases, it would be much better if we only serialize once instead
|
||||
# of N times.
|
||||
if artifact in cls._seen_artifacts:
|
||||
return
|
||||
cls._seen_artifacts.add(artifact)
|
||||
|
||||
cls._backend_artifacts_by_key[key] = artifact
|
||||
cls._backend_artifacts_by_key[artifact.key] = copy.deepcopy(artifact)
|
||||
|
||||
@classmethod
|
||||
def record_dynamo_cache_entry(
|
||||
@ -206,36 +115,6 @@ class PrecompileContext(CacheArtifactManager):
|
||||
) -> None:
|
||||
cls._dynamo_cache_entries[key] = cache_entry
|
||||
|
||||
@classmethod
|
||||
def _save_artifacts_by_type(cls) -> None:
|
||||
"""
|
||||
We normally record artifacts by key, but serialization expects them to be organized
|
||||
by artifact type. This function transfers artifacts from _backend_artifacts_by_key to _new_cache_artifacts
|
||||
"""
|
||||
for key, cache_entry in cls._dynamo_cache_entries.items():
|
||||
backends = cache_entry.backend_ids
|
||||
try:
|
||||
for id_ in backends:
|
||||
if id_ not in cls._backend_artifacts_by_key:
|
||||
logger.warning(
|
||||
"Bypassing %s because backend %s not found in artifacts"
|
||||
)
|
||||
raise BypassDynamoCacheEntry
|
||||
except BypassDynamoCacheEntry:
|
||||
continue
|
||||
pickled_result = pickle.dumps(cache_entry)
|
||||
dynamo_artifact = _DynamoCacheArtifact(key, pickled_result)
|
||||
cls._new_cache_artifacts[_DynamoCacheArtifact.type()].append(
|
||||
dynamo_artifact
|
||||
)
|
||||
|
||||
# Save all the backend artifacts
|
||||
for artifact in cls._backend_artifacts_by_key.values():
|
||||
if isinstance(artifact, EditablePrecompileCacheArtifact):
|
||||
artifact = artifact.real_encode()
|
||||
cls._new_cache_artifacts[artifact.__class__.type()].append(artifact)
|
||||
cls._backend_artifacts_by_key.clear()
|
||||
|
||||
@classmethod
|
||||
def edit_artifact(cls, key: str, edit_fn: Callable[..., Any]) -> None:
|
||||
"""
|
||||
@ -243,53 +122,19 @@ class PrecompileContext(CacheArtifactManager):
|
||||
"""
|
||||
assert key in cls._backend_artifacts_by_key, f"Key {key} not found in artifacts"
|
||||
artifact = cls._backend_artifacts_by_key[key]
|
||||
assert isinstance(artifact, EditablePrecompileCacheArtifact), (
|
||||
"Artifact is not editable"
|
||||
)
|
||||
artifact.edit_contents(edit_fn)
|
||||
|
||||
@classmethod
|
||||
def serialize_artifact_by_key(cls, key: str) -> Optional[CacheArtifact]:
|
||||
def serialize_artifact_by_key(cls, key: str) -> Optional[BackendCacheArtifact[Any]]:
|
||||
"""
|
||||
Serialize all backend artifacts with the given key returned in a list.
|
||||
Return the backend cache artifact with the associated key
|
||||
"""
|
||||
result = cls._backend_artifacts_by_key.get(key, None)
|
||||
if isinstance(result, EditablePrecompileCacheArtifact):
|
||||
result = result.real_encode()
|
||||
return result
|
||||
|
||||
@classmethod
|
||||
def serialize(cls) -> Optional[tuple[bytes, CacheInfo]]:
|
||||
if not cls._dynamo_cache_entries:
|
||||
return None
|
||||
|
||||
debug_info = cls.dump_debug_info(
|
||||
cls._dynamo_cache_entries, cls._backend_artifacts_by_key
|
||||
)
|
||||
artifacts = json.dumps({"artifacts": debug_info})
|
||||
torch._logging.trace_structured(
|
||||
"artifact",
|
||||
metadata_fn=lambda: {
|
||||
"name": "dynamo_cache_save_contents",
|
||||
"encoding": "json",
|
||||
},
|
||||
payload_fn=lambda: artifacts,
|
||||
expect_trace_id=False,
|
||||
)
|
||||
cls._save_artifacts_by_type()
|
||||
|
||||
result = super().serialize()
|
||||
assert result is not None
|
||||
data, info = result
|
||||
|
||||
return data, info
|
||||
return cls._backend_artifacts_by_key.get(key, None)
|
||||
|
||||
@staticmethod
|
||||
def dump_debug_info(
|
||||
dynamo_entries: dict[str, _DynamoCacheEntry],
|
||||
backend_artifacts: dict[
|
||||
str, Union[EditablePrecompileCacheArtifact[object], CacheArtifact]
|
||||
],
|
||||
backend_artifacts: dict[str, BackendCacheArtifact[Any]],
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Return a JSON serializable debug dump of all entries in the precompile context
|
||||
@ -300,36 +145,32 @@ class PrecompileContext(CacheArtifactManager):
|
||||
for key, cache_entry in dynamo_entries.items():
|
||||
info = cache_entry.debug_info()
|
||||
info["key"] = key
|
||||
debug_info["precompile_dynamo"].append(info)
|
||||
debug_info["dynamo"].append(info)
|
||||
|
||||
for artifact in backend_artifacts.values():
|
||||
if isinstance(artifact, EditablePrecompileCacheArtifact):
|
||||
debug_info[artifact.artifact_type].append(artifact.key)
|
||||
else:
|
||||
debug_info[artifact.__class__.type()].append(artifact.key)
|
||||
debug_info["backends"].append(artifact.key)
|
||||
|
||||
return debug_info
|
||||
|
||||
@staticmethod
|
||||
def populate_caches(artifacts: CacheArtifactsResult) -> CacheInfo:
|
||||
PrecompileContext._ensure_cache_artifacts_registered()
|
||||
@classmethod
|
||||
def save_to_dynamo_cache(cls) -> dict[str, Any]:
|
||||
precompile_cache_entries, debug_info = cls.create_cache_entries()
|
||||
for key, entry in precompile_cache_entries.items():
|
||||
DynamoCache.write(entry, key)
|
||||
return debug_info
|
||||
|
||||
backend_artifacts: dict[str, Any] = {}
|
||||
dynamo_entries: dict[str, _DynamoCacheEntry] = {}
|
||||
cache_info = CacheInfo()
|
||||
for artifact in chain(*artifacts.values()):
|
||||
if artifact.type() == "autotune":
|
||||
# Populate autotune cache artifacts
|
||||
artifact.populate_cache()
|
||||
elif artifact.type() == "precompile_dynamo":
|
||||
assert isinstance(artifact, _DynamoCacheArtifact)
|
||||
cache_entry: _DynamoCacheEntry = artifact.after_deserialization()
|
||||
dynamo_entries[artifact.key] = cache_entry
|
||||
else:
|
||||
backend_artifacts[artifact.key] = artifact
|
||||
cache_info.add(artifact)
|
||||
@classmethod
|
||||
def create_cache_entries(
|
||||
cls,
|
||||
) -> tuple[dict[str, PrecompileCacheEntry], dict[str, Any]]:
|
||||
"""
|
||||
Grabs all the cache entries in the precompile context and
|
||||
stitches them together into full PrecompileCacheEntries.
|
||||
"""
|
||||
dynamo_entries = cls._dynamo_cache_entries
|
||||
backend_artifacts = cls._backend_artifacts_by_key
|
||||
|
||||
num_artifacts = len(artifacts["precompile_dynamo"])
|
||||
num_artifacts = len(dynamo_entries)
|
||||
|
||||
debug_info = PrecompileContext.dump_debug_info(
|
||||
dynamo_entries, backend_artifacts
|
||||
@ -349,12 +190,13 @@ class PrecompileContext(CacheArtifactManager):
|
||||
payload_fn=lambda: debug_str,
|
||||
expect_trace_id=False,
|
||||
)
|
||||
from torch._dynamo.package import _BackendId, DynamoCache
|
||||
|
||||
precompile_cache_entries = {}
|
||||
|
||||
for key, cache_entry in dynamo_entries.items():
|
||||
try:
|
||||
backends = cache_entry.backend_ids
|
||||
backend_content: dict[_BackendId, PrecompileCacheArtifact[Any]] = {}
|
||||
backend_content: dict[_BackendId, BackendCacheArtifact[Any]] = {}
|
||||
for id_ in backends:
|
||||
if id_ not in backend_artifacts:
|
||||
debug_str = json.dumps(
|
||||
@ -375,11 +217,13 @@ class PrecompileContext(CacheArtifactManager):
|
||||
)
|
||||
continue
|
||||
artifact = backend_artifacts[id_]
|
||||
assert isinstance(artifact, PrecompileCacheArtifact)
|
||||
assert isinstance(artifact, BackendCacheArtifact)
|
||||
backend_content[id_] = artifact
|
||||
DynamoCache.write(cache_entry, backend_content, key)
|
||||
precompile_cache_entries[key] = PrecompileCacheEntry(
|
||||
dynamo=cache_entry, backends=backend_content
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning("Failed to deserialize cache entry %s: %s", key, str(e))
|
||||
logger.warning("Failed to create cache entry %s: %s", key, str(e))
|
||||
|
||||
error = e
|
||||
data = json.dumps(
|
||||
@ -397,11 +241,4 @@ class PrecompileContext(CacheArtifactManager):
|
||||
payload_fn=lambda: data,
|
||||
)
|
||||
continue
|
||||
|
||||
return cache_info
|
||||
|
||||
@classmethod
|
||||
def _ensure_cache_artifacts_registered(cls) -> None:
|
||||
from torch._functorch._aot_autograd.autograd_cache import ( # noqa: F401
|
||||
BundledAOTAutogradCacheArtifact,
|
||||
)
|
||||
return precompile_cache_entries, debug_info
|
||||
|
@ -22,7 +22,7 @@ from typing import Any, Callable, Generic, Optional, TYPE_CHECKING, TypeVar, Uni
|
||||
from typing_extensions import override
|
||||
|
||||
import torch
|
||||
from torch._dynamo.precompile_context import PrecompileCacheArtifact, PrecompileContext
|
||||
from torch._dynamo.precompile_context import BackendCacheArtifact, PrecompileContext
|
||||
from torch._dynamo.trace_rules import torch_non_c_binding_in_graph_functions
|
||||
from torch._dynamo.utils import (
|
||||
chromium_event_log_active,
|
||||
@ -51,7 +51,7 @@ from torch._inductor.output_code import (
|
||||
OutputCode,
|
||||
)
|
||||
from torch._inductor.runtime.runtime_utils import cache_dir
|
||||
from torch._inductor.utils import should_use_remote_fx_graph_cache
|
||||
from torch._inductor.utils import BoxedBool, should_use_remote_fx_graph_cache
|
||||
from torch._logging import LazyString
|
||||
from torch._utils_internal import log_cache_bypass
|
||||
from torch.compiler._cache import (
|
||||
@ -81,7 +81,6 @@ from .utils import simple_wraps
|
||||
if TYPE_CHECKING:
|
||||
from torch._inductor.compile_fx import _CompileFxKwargs
|
||||
from torch._inductor.remote_cache import JsonDataTy, RemoteCache
|
||||
from torch._inductor.utils import BoxedBool
|
||||
from torch.fx.node import Node
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
@ -1059,15 +1058,14 @@ class AOTAutogradCacheArtifact(CacheArtifact):
|
||||
return "aot_autograd"
|
||||
|
||||
|
||||
def deserialize_bundled_cache_entry(data: bytes) -> Callable:
|
||||
entry = pickle.loads(data)
|
||||
def deserialize_bundled_cache_entry(entry: BundledAOTAutogradCacheEntry) -> Callable:
|
||||
# In the precompile use case, guards are already serialized
|
||||
# by dynamo, so we don't need to add them to the environment
|
||||
entry.guards_expr = None
|
||||
# TODO: this isn't exactly right, because cudagraphs needs to be a shared config
|
||||
# which is set by compile_fx. But in precompile, we never actually call compile_fx
|
||||
# so we don't have a place to track cudagraphs here.
|
||||
cudagraphs = torch._inductor.config.triton.cudagraphs
|
||||
cudagraphs = BoxedBool(torch._inductor.config.triton.cudagraphs)
|
||||
boxed_forward_device_index = BoxedDeviceIndex(None)
|
||||
compiled_fn = entry.wrap_post_compile(
|
||||
[],
|
||||
@ -1090,14 +1088,8 @@ def deserialize_bundled_cache_entry(data: bytes) -> Callable:
|
||||
return forward
|
||||
|
||||
|
||||
@CacheArtifactFactory.register
|
||||
class BundledAOTAutogradCacheArtifact(PrecompileCacheArtifact[Callable]):
|
||||
@override
|
||||
@staticmethod
|
||||
def type():
|
||||
return "precompile_aot_autograd"
|
||||
|
||||
@override
|
||||
@dataclass
|
||||
class BundledAOTAutogradCacheArtifact(BackendCacheArtifact[Callable]):
|
||||
def after_deserialization(self) -> Callable:
|
||||
return deserialize_bundled_cache_entry(self.content)
|
||||
|
||||
@ -1375,9 +1367,9 @@ class AOTAutogradCache(GuardedCache[GenericAOTAutogradCacheEntry]):
|
||||
# 1. because we set it to None on save 2. even if we didn't, this new run
|
||||
# that cache hit has a *new* backend id associated with it.
|
||||
PrecompileContext.record_artifact(
|
||||
BundledAOTAutogradCacheArtifact.type(),
|
||||
aot_config.precompile_backend_id,
|
||||
pickled_content,
|
||||
BundledAOTAutogradCacheArtifact(
|
||||
aot_config.precompile_backend_id, entry
|
||||
),
|
||||
)
|
||||
except Exception as e:
|
||||
log.info("AOTAutograd cache unable to load compiled graph: %s", e)
|
||||
@ -1413,15 +1405,11 @@ class AOTAutogradCache(GuardedCache[GenericAOTAutogradCacheEntry]):
|
||||
and entry.sanitized_aot_config.precompile_backend_id is not None
|
||||
):
|
||||
precompile_key = entry.sanitized_aot_config.precompile_backend_id
|
||||
artifact = BundledAOTAutogradCacheArtifact(precompile_key, entry)
|
||||
# Now that we're saving it, the precompile_backend_id field is no longer
|
||||
# useful, remove it from the entry.
|
||||
entry.sanitized_aot_config.precompile_backend_id = None
|
||||
PrecompileContext.record_artifact(
|
||||
BundledAOTAutogradCacheArtifact.type(),
|
||||
precompile_key,
|
||||
entry,
|
||||
editable=True,
|
||||
)
|
||||
PrecompileContext.record_artifact(artifact)
|
||||
AOTAutogradCache._write_to_local_cache(key, content)
|
||||
counters["aot_autograd"]["autograd_cache_saved"] += 1
|
||||
except BypassAOTAutogradCache as e:
|
||||
|
@ -35,7 +35,6 @@ from typing import Any, Optional, TYPE_CHECKING
|
||||
from typing_extensions import override
|
||||
|
||||
import torch
|
||||
from torch._dynamo.precompile_context import PrecompileContext
|
||||
from torch._inductor.runtime.runtime_utils import cache_dir
|
||||
from torch.compiler._cache import (
|
||||
CacheArtifact,
|
||||
@ -302,10 +301,6 @@ class AutotuneCache:
|
||||
CacheArtifactManager.record_artifact(
|
||||
AutotuneCacheArtifact.type(), autotune_artifact_key, data
|
||||
)
|
||||
if torch._dynamo.config.caching_precompile:
|
||||
PrecompileContext.record_artifact(
|
||||
AutotuneCacheArtifact.type(), autotune_artifact_key, data
|
||||
)
|
||||
|
||||
if log.isEnabledFor(logging.DEBUG):
|
||||
type_str = "coordesc" if found_by_coordesc else "heuristic"
|
||||
@ -631,10 +626,6 @@ class LocalAutotuneCache(RemoteCache[JsonDataTy]):
|
||||
CacheArtifactManager.record_artifact(
|
||||
AutotuneCacheArtifact.type(), autotune_artifact_key, result
|
||||
)
|
||||
if torch._dynamo.config.caching_precompile:
|
||||
PrecompileContext.record_artifact(
|
||||
AutotuneCacheArtifact.type(), autotune_artifact_key, result
|
||||
)
|
||||
return result
|
||||
|
||||
@override
|
||||
|
@ -48,9 +48,6 @@ class CacheArtifact(ABC):
|
||||
def populate_cache(self) -> None:
|
||||
pass
|
||||
|
||||
def precompile_compatible(self) -> bool:
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def type() -> str:
|
||||
"""
|
||||
@ -131,14 +128,6 @@ class CacheInfo:
|
||||
def pgo_artifacts(self) -> list[str]: # type: ignore[empty-body]
|
||||
...
|
||||
|
||||
@property
|
||||
def precompile_aot_autograd_artifacts(self) -> list[str]: # type: ignore[empty-body]
|
||||
...
|
||||
|
||||
@property
|
||||
def precompile_dynamo_artifacts(self) -> list[str]: # type: ignore[empty-body]
|
||||
...
|
||||
|
||||
def add(self, artifact: CacheArtifact) -> None:
|
||||
self.artifacts[artifact.type()].append(artifact.key)
|
||||
|
||||
|
Reference in New Issue
Block a user