Revert "[WIP] Automatically load and save dynamo entries via caching_precompile (#155913)"

This reverts commit e466dab164d9236bfe5817ec8e4d24c7b9d3e392.

Reverted https://github.com/pytorch/pytorch/pull/155913 on behalf of https://github.com/huydhn due to Sorry for reverting your change but it seems to fail a test in trunk ([comment](https://github.com/pytorch/pytorch/pull/155913#issuecomment-3045914878))
PyTorch MergeBot
2025-07-07 16:53:35 +00:00
parent eda0a9cc90
commit ae1094b72b
9 changed files with 39 additions and 414 deletions
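For context, the reverted change wired torch.compile into an automatic DynamoCache save/load path gated by torch._dynamo.config.caching_precompile. A minimal sketch of the round trip it enabled, assembled from the deleted test helper in the diff below (this is the reverted API, not post-revert behavior; the input shapes are illustrative):

import torch
from torch._dynamo.package import DynamoCache
from torch._dynamo.precompile_context import PrecompileContext

def fn(x):
    return x.sin() + x.cos()

with torch._dynamo.config.patch(caching_precompile=True):
    compiled = torch.compile(fn)
    compiled(torch.randn(3, 2))  # compiling records an entry in DynamoCache

# Serialize every artifact the precompile context collected.
serialized = PrecompileContext.serialize()
assert serialized is not None
bytes_, cache_info = serialized  # cache_info summarizes the captured artifacts

# Simulate a fresh process: clear all state, then restore from the blob.
torch._dynamo.reset()
DynamoCache.clear()
PrecompileContext.clear()
artifacts = PrecompileContext.deserialize(bytes_)
assert artifacts is not None
PrecompileContext.populate_caches(artifacts)

# A newly compiled wrapper should now be served from the cache.
with torch._dynamo.config.patch(caching_precompile=True), torch.compiler.set_stance("fail_on_recompile"):
    torch.compile(fn)(torch.randn(3, 2))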

View File

@@ -12,8 +12,7 @@ import torch._inductor.config
import torch._inductor.test_case
import torch.onnx.operators
import torch.utils.cpp_extension
from torch._dynamo.package import CompilePackage, DiskDynamoStore, DynamoCache
from torch._dynamo.precompile_context import PrecompileContext
from torch._dynamo.package import CompilePackage, DiskDynamoStore
from torch._functorch import config as functorch_config
from torch._inductor.runtime.runtime_utils import cache_dir
from torch.testing._internal.common_utils import (
@@ -31,38 +30,6 @@ class TestPackage(torch._inductor.test_case.TestCase):
os.makedirs(path, exist_ok=True)
return path
def setUp(self):
super().setUp()
torch._dynamo.reset()
torch._dynamo.utils.counters.clear()
DynamoCache.clear()
PrecompileContext.clear()
def _save_and_reload(self, expected_backends, expected_dynamo):
"""
Serializes all artifacts, clears all caches, then reloads the serialized artifact
Simulates a new process.
Args:
expected_backends: Expected number of precompile_aot_autograd_artifacts
expected_dynamo: Expected number of precompile_dynamo_artifacts
"""
serialized = PrecompileContext.serialize()
assert serialized is not None
(bytes_, cache_info) = serialized
self.assertEqual(
len(cache_info.precompile_aot_autograd_artifacts), expected_backends
)
self.assertEqual(len(cache_info.precompile_dynamo_artifacts), expected_dynamo)
torch._dynamo.reset()
DynamoCache.clear()
PrecompileContext.clear()
deserialized = PrecompileContext.deserialize(bytes_)
assert deserialized is not None
PrecompileContext.populate_caches(deserialized)
@parametrize("backend", ("eager", "inductor"))
@parametrize("device", ("cpu", "cuda", "xpu"))
def test_basic_fn(self, backend, device):
@@ -338,173 +305,6 @@ def add(x, y):
)
ctx.load_package(fn, self.path())
@parametrize("device", ("cpu", "cuda", "xpu"))
def test_dynamo_cache_manual_load(self, device):
if device == "cuda" and not HAS_CUDA:
raise unittest.SkipTest("Requires CUDA/Triton")
if device == "xpu" and not HAS_XPU:
raise unittest.SkipTest("Requires XPU/Triton")
def fn(x):
return x.sin() + x.cos()
def fn2(x):
return x.cos() + x
package1 = CompilePackage(fn)
package2 = CompilePackage(fn2)
compiled_fn1 = torch._dynamo.optimize(backend="inductor", package=package1)(fn)
compiled_fn2 = torch._dynamo.optimize(backend="inductor", package=package2)(fn2)
arg1 = torch.randn(3, 2, device=device)
arg2 = torch.randn(5, 2, device=device)
expected = [compiled_fn1(arg1), compiled_fn2(arg2)]
DynamoCache.save(package1)
DynamoCache.save(package2)
self._save_and_reload(expected_backends=2, expected_dynamo=2)
# These should exist because of populate_caches
package1 = DynamoCache.load_and_install_package(fn)
package2 = DynamoCache.load_and_install_package(fn2)
with torch.compiler.set_stance("fail_on_recompile"):
result1 = compiled_fn1(arg1)
result2 = compiled_fn2(arg2)
self.assertEqual(expected, [result1, result2])
@parametrize("device", ("cpu", "cuda", "xpu"))
@torch._dynamo.config.patch(caching_precompile=True)
def test_automatic_dynamo_serialize(self, device):
if device == "cuda" and not HAS_CUDA:
raise unittest.SkipTest("Requires CUDA/Triton")
if device == "xpu" and not HAS_XPU:
raise unittest.SkipTest("Requires XPU/Triton")
def fn(x):
return x.sin() + x.cos()
def fn2(x):
return x.cos() + x
arg1 = torch.randn(3, 2, device=device)
arg2 = torch.randn(5, 2, device=device)
expected = [fn(arg1), fn2(arg2)]
compiled_fn1 = torch.compile(fn)
compiled_fn2 = torch.compile(fn2)
result = [compiled_fn1(arg1), compiled_fn2(arg2)]
self.assertEqual(expected, result)
DynamoCache.clear()
self._save_and_reload(expected_backends=2, expected_dynamo=2)
compiled_fn1 = torch.compile(fn)
compiled_fn2 = torch.compile(fn2)
with torch.compiler.set_stance("fail_on_recompile"):
result1 = compiled_fn1(arg1)
result2 = compiled_fn2(arg2)
self.assertEqual(expected, [result1, result2])
@parametrize("device", ("cpu", "cuda", "xpu"))
@torch._dynamo.config.patch(caching_precompile=True)
def test_automatic_dynamo_recompiles(self, device):
if device == "cuda" and not HAS_CUDA:
raise unittest.SkipTest("Requires CUDA/Triton")
if device == "xpu" and not HAS_XPU:
raise unittest.SkipTest("Requires XPU/Triton")
def fn(x):
return x.sin() + x.cos()
arg1 = torch.randn(3, 2, device=device)
arg2 = torch.randn(5, 2, device=device)
compiled_fn = torch.compile(fn)
expected1 = compiled_fn(arg1)
# Should cause a recompile
expected2 = compiled_fn(arg2)
self._save_and_reload(expected_backends=2, expected_dynamo=1)
compiled_fn = torch.compile(fn)
with torch.compiler.set_stance("fail_on_recompile"):
result1 = compiled_fn(arg1)
result2 = compiled_fn(arg2)
# Because of automatic dynamic, a third random shape should also not cause a recompile
arg3 = torch.randn(7, 2, device=device)
compiled_fn(arg3)
self.assertEqual(result1, expected1)
self.assertEqual(result2, expected2)
@parametrize("device", ("cpu", "cuda", "xpu"))
@torch._dynamo.config.patch(caching_precompile=True)
def test_automatic_dynamo_graph_breaks(self, device):
if device == "cuda" and not HAS_CUDA:
raise unittest.SkipTest("Requires CUDA/Triton")
if device == "xpu" and not HAS_XPU:
raise unittest.SkipTest("Requires XPU/Triton")
def fn(x, l, r):
if l > r:
return x.sum()
mid = (l + r) // 2
if x.sum() == mid:
return x.sum()
elif x.sum() < mid:
return fn(x, l, mid)
else:
return fn(x, mid + 1, r)
def guard_filter_fn(guards):
return [
guard.guard_type not in ("CLOSURE_MATCH", "FUNCTION_MATCH")
for guard in guards
]
# Saving
compiled_fn = torch._dynamo.optimize(
backend="inductor", guard_filter_fn=guard_filter_fn
)(fn)
N = 10
args_list = [(torch.tensor(x, device=device), 0, N - 1) for x in range(N)]
for args in args_list:
compiled_fn(*args)
self._save_and_reload(expected_backends=8, expected_dynamo=1)
compiled_fn = torch._dynamo.optimize(
backend="inductor", guard_filter_fn=guard_filter_fn
)(fn)
with torch.compiler.set_stance("fail_on_recompile"):
for args in args_list:
self.assertEqual(compiled_fn(*args), args[0].sum())
@parametrize("device", ("cpu", "cuda", "xpu"))
@torch._dynamo.config.patch(caching_precompile=True)
def test_automatic_dynamo_lazy_backward(self, device):
if device == "cuda" and not HAS_CUDA:
raise unittest.SkipTest("Requires CUDA/Triton")
if device == "xpu" and not HAS_XPU:
raise unittest.SkipTest("Requires XPU/Triton")
def fn(x):
return x.sin() + x.cos()
arg1 = torch.randn(3, 2, device=device, requires_grad=True)
arg2 = arg1.clone().detach_().requires_grad_(True)
compiled_fn = torch.compile(fn)
expected1 = compiled_fn(arg1)
expected1.sum().backward()
self._save_and_reload(expected_backends=1, expected_dynamo=1)
compiled_fn = torch.compile(fn)
# Run it again, no recompile needed
with torch.compiler.set_stance("fail_on_recompile"):
expected2 = compiled_fn(arg2)
expected2.sum().backward()
if __name__ == "__main__":
from torch._dynamo.test_case import run_tests
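For reference, the manual save/load path exercised by the deleted tests above condenses to the following sketch (again the reverted API; the backend choice and input shapes are illustrative):

import torch
from torch._dynamo.package import CompilePackage, DynamoCache

def fn(x):
    return x.sin() + x.cos()

package = CompilePackage(fn)
compiled = torch._dynamo.optimize(backend="inductor", package=package)(fn)
compiled(torch.randn(3, 2))

DynamoCache.save(package)  # persists the dynamo entry plus its backends to disk

# Simulate a fresh process, then reload and install the cached package.
torch._dynamo.reset()
package = DynamoCache.load_and_install_package(fn)  # reverted helper
with torch.compiler.set_stance("fail_on_recompile"):
    compiled(torch.randn(3, 2))  # served from the cache, no recompilation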

View File

@@ -542,10 +542,6 @@ fake_tensor_cache_crosscheck_enabled = (
# the inference_mode is still respected.
fake_tensor_disable_inference_mode = True
# Experimental feature for running automatic caching precompile.
# Enables automatic DynamoCache save/load
caching_precompile = False
# Enables the Compiled Autograd engine to trace autograd calls made under torch.compile().
# Note: AOTAutograd will still trace and partition an AOT backward graph local to that
# compiled region. But AOTAutograd traces without knowledge of backward hooks which are

View File

@@ -531,6 +531,7 @@ class ConvertFrameAssert:
skip: int = 0,
) -> ConvertFrameReturn:
increment_frame()
code = frame.f_code
cache_size = compute_cache_size(frame, cache_entry)
@@ -648,7 +649,7 @@ class ConvertFrameAssert:
dynamo_tls.traced_frame_infos.append(info)
with compile_context(CompileContext(compile_id)):
result = _compile(
return _compile(
frame.f_code,
frame.f_globals,
frame.f_locals,
@@ -669,13 +670,6 @@ class ConvertFrameAssert:
convert_frame_box=self._box,
)
if config.caching_precompile and self._package is not None:
from .package import DynamoCache
# Record that the dynamo package has changed
DynamoCache.record_package(self._package)
return result
def convert_frame_assert(
compiler_fn: CompilerFn,

View File

@@ -655,27 +655,6 @@ class _TorchDynamoContext:
def get_compiler_config():
return self.compiler_config
from .package import DynamoCache
# If self._package is lazily initialized, we should check the dynamo cache now
if config.caching_precompile:
assert self._package is not None
if not self._package.is_initialized():
result = DynamoCache.load(fn)
if result is None:
# Create a fresh CompilePackage
self._package.initialize(fn, None, ignore_inlined_sources=False)
else:
cache_entry, backends = result
try:
self._package.initialize(
fn, cache_entry, ignore_inlined_sources=False
)
self._package.install(backends)
except RuntimeError as e:
log.warning("Failed to load entry from dynamo cache: %s", e)
self._package.initialize(fn, None, ignore_inlined_sources=False)
fn = innermost_fn(fn)
# add context containing GraphModule to any GraphModule forward functions
@@ -1173,20 +1152,8 @@ def _optimize(
# The backend function is stashed in the callable returned by
# _optimize_catch_errors in the field _torchdynamo_orig_backend. This can
# be used by eval_frame.c to insert a guard on the backend.
# With CachingPrecompile, instantiate an uninitialized CompilePackage
# which gets initialized by _optimize_catch_errors.__call__ once we have a function
if config.caching_precompile and package is None:
from .package import CompilePackage
package = CompilePackage(fn=None, dynamo=None, ignore_inlined_sources=False)
return _optimize_catch_errors(
convert_frame.convert_frame(
backend,
hooks,
package=package,
),
convert_frame.convert_frame(backend, hooks, package=package),
hooks,
backend_ctx_ctor,
error_on_graph_break=nopython,
@@ -2102,16 +2069,6 @@ def _optimize_assert(
# Find if backend has any extra context manager
backend_ctx_ctor = getattr(backend, "backend_ctx_ctor", null_context)
if config.caching_precompile and package is None:
# Create an uninitialized package that will be set/filled by
# _OptimizeContext.__call__
# We need to instantiate the object here because the same CompilePackage
# needs to be shared between convert_frame_assert
# and OptimizeContext.
from .package import CompilePackage
package = CompilePackage(fn=None, dynamo=None, ignore_inlined_sources=False)
return _optimize_catch_errors(
convert_frame.convert_frame_assert(
backend,

View File

@@ -19,16 +19,14 @@ import logging
import os
import pickle
import platform
import shutil
import sys
import types
from collections.abc import Generator
from typing import Any, Callable, NewType, Optional
from typing import Any, NewType, Optional
import torch
import torch._inductor.package
from torch._dynamo.precompile_context import PrecompileCacheArtifact, PrecompileContext
from torch._inductor.runtime.cache_dir_utils import cache_dir
from torch.compiler._cache import CacheArtifactFactory
from .bytecode_transformation import get_code_keys
@@ -185,7 +183,7 @@ class CompilePackage:
def __init__(
self,
fn: Optional[Callable[..., Any]],
fn: Any,
dynamo: Optional[_DynamoCacheEntry] = None,
ignore_inlined_sources: bool = False,
) -> None:
@@ -199,16 +197,12 @@ class CompilePackage:
self._cached_backends: dict[_BackendId, Any] = {}
self._inlined_sources: set[InlinedSource] = set()
self._resume_codes: set[types.CodeType] = set()
self._initialized = False
if fn is not None:
self.initialize(fn, dynamo, ignore_inlined_sources)
self.uninstall()
self.validate()
def is_initialized(self) -> bool:
return self._initialized
self._initialize(fn, dynamo, ignore_inlined_sources)
self.uninstall()
self.validate()
def initialize(
def _initialize(
self,
fn: Any,
dynamo: Optional[_DynamoCacheEntry] = None,
@@ -216,7 +210,6 @@ ) -> None:
) -> None:
from .eval_frame import innermost_fn
assert not self._initialized
self._inlined_sources = set()
self._innermost_fn = innermost_fn(fn)
assert self._innermost_fn is not None
@@ -249,7 +242,6 @@ self._add_function(
self._add_function(
self._innermost_fn.__code__, self._innermost_fn.__module__
)
self._initialized = True
def _add_function(
self,
@@ -281,7 +273,10 @@ @functools.cached_property
@functools.cached_property
def source_id(self) -> str:
assert self._innermost_fn is not None
return CompilePackage.source_id_from_fn(self._innermost_fn)
sha256_hash = hashlib.sha256()
sha256_hash.update(self._innermost_fn.__qualname__.encode())
sha256_hash.update(str(self._innermost_fn.__code__.co_firstlineno).encode())
return sha256_hash.hexdigest()
@contextlib.contextmanager
def code_context(self, code: types.CodeType) -> Generator[None, None, None]:
@@ -440,17 +435,6 @@ codes=list(self._codes.values()), inlined_sources=self._inlined_sources
codes=list(self._codes.values()), inlined_sources=self._inlined_sources
)
@staticmethod
def source_id_from_fn(fn: Callable[..., Any]) -> str:
from .eval_frame import innermost_fn
innermost_fn_ = innermost_fn(fn)
sha256_hash = hashlib.sha256()
sha256_hash.update(innermost_fn_.__qualname__.encode())
sha256_hash.update(str(innermost_fn_.__code__.co_firstlineno).encode())
return sha256_hash.hexdigest()
@CacheArtifactFactory.register
class EagerCacheArtifact(PrecompileCacheArtifact[Any]):
@@ -491,9 +475,6 @@ class DynamoStore(abc.ABC):
EagerCacheArtifact.type(), key=backend_id, content=pickled_result
)
@abc.abstractmethod
def clear(self) -> None: ...
@abc.abstractmethod
def write(
self,
@@ -511,11 +492,12 @@ """
"""
...
def save_cache_entry(self, cache_entry: _DynamoCacheEntry, key: str) -> None:
def save_package(self, package: CompilePackage, key: str) -> None:
"""
Saves a package to a given path. Grabs backends from PrecompileContext.
"""
backend_content: _Backends = {}
cache_entry = package.cache_entry()
for backend_id in cache_entry.backend_ids:
serialized_backend = PrecompileContext.serialize_artifact_by_key(backend_id)
if serialized_backend is None:
@@ -527,14 +509,6 @@ self.write(cache_entry, backend_content, key)
self.write(cache_entry, backend_content, key)
def save_package(self, package: CompilePackage, key: str) -> None:
"""
Saves a package to a given path. Grabs backends from PrecompileContext.
"""
self.record_package(package)
cache_entry = package.cache_entry()
self.save_cache_entry(cache_entry, key)
@abc.abstractmethod
def read(self, path: str) -> tuple[_DynamoCacheEntry, _Backends]:
"""
@@ -548,25 +522,17 @@ """
"""
...
def load_cache_entry(
self, key: str
) -> tuple[_DynamoCacheEntry, dict[_BackendId, Any]]:
cache_entry, backend_content = self.read(key)
for backend_id, backend in backend_content.items():
PrecompileContext.record_artifact(
backend.type(), key=backend.key, content=backend.content
)
backend_content[backend_id] = backend.after_deserialization()
return cache_entry, backend_content
def load_package(
self, fn: Any, key: str
) -> tuple[CompilePackage, dict[_BackendId, Any]]:
"""
Loads a package from a given path and returns it plus a list of deserialized backends
"""
cache_entry, backend_content = self.load_cache_entry(key)
cache_entry, backend_content = self.read(key)
for backend_id, backend in backend_content.items():
backend_content[backend_id] = backend.after_deserialization()
package = CompilePackage(fn, cache_entry)
return package, backend_content
@@ -579,9 +545,6 @@ class InMemoryDynamoStore(DynamoStore):
def __init__(self) -> None:
self.packages: dict[str, tuple[_DynamoCacheEntry, _Backends]] = {}
def clear(self) -> None:
self.packages.clear()
def write(
self,
dynamo: _DynamoCacheEntry,
@@ -617,13 +580,6 @@ class DiskDynamoStore(DynamoStore):
"""
self.path_prefix = path_prefix
def clear(self) -> None:
"""
Clear all CompilePackages from disk.
"""
if self.path_prefix:
shutil.rmtree(self.path_prefix, ignore_errors=True)
def write(
self,
dynamo: _DynamoCacheEntry,
@@ -633,9 +589,7 @@ class DiskDynamoStore(DynamoStore):
"""
Write dynamo cache entry and backends to disk.
"""
path = os.path.join(self.path_prefix, path) if self.path_prefix else path
try:
os.makedirs(path, exist_ok=True)
with open(os.path.join(path, "dynamo"), "wb") as dynamo_path:
pickle.dump(dynamo, dynamo_path)
with open(os.path.join(path, "backends"), "wb") as backend_path:
@@ -647,7 +601,6 @@ """
"""
Read dynamo cache entry and backends from disk.
"""
path = os.path.join(self.path_prefix, path) if self.path_prefix else path
try:
with open(os.path.join(path, "dynamo"), "rb") as dynamo_path:
cache_entry = pickle.load(dynamo_path)
@@ -657,53 +610,18 @@ except Exception as e:
except Exception as e:
raise RuntimeError(f"Failed to load package from path {path}: {e}") from e
class DiskDynamoCache(DiskDynamoStore):
"""
Special DiskDynamoStore which adds some helper functions for automatically
tracking paths of packages
"""
def save(self, package: CompilePackage) -> None:
def save_package(self, package: CompilePackage, key: str) -> None:
"""
Saves a package to a given path. Grabs backends from PrecompileContext.
Save a package to disk using the path_prefix + key as the file path.
"""
key = package.source_id
logger.info("Saving CompilePackage for %s", package.source_id)
super().save_package(package, key)
full_path = os.path.join(self.path_prefix, key) if self.path_prefix else key
super().save_package(package, full_path)
def load(
self, fn: Callable[..., Any]
) -> Optional[tuple[_DynamoCacheEntry, dict[_BackendId, Any]]]:
def load_package(
self, fn: Any, key: str
) -> tuple[CompilePackage, dict[_BackendId, Any]]:
"""
Loads a package from a given path and returns it plus a list of deserialized backends
Load a package from disk using the path_prefix + key as the file path.
"""
key = CompilePackage.source_id_from_fn(fn)
logger.info("Loading CompilePackage for %s", key)
path = os.path.join(self.path_prefix, key)
if os.path.exists(path):
try:
return super().load_cache_entry(key)
except Exception as e:
logger.warning("Failed to load package from path %s: %s", path, str(e))
return None
logger.info("No package found for %s", key)
return None
def load_and_install_package(
self, fn: Callable[..., Any]
) -> Optional[CompilePackage]:
"""
Load directly into a package and install backends
"""
results = self.load(fn)
if results is None:
return None
else:
(entry, backends) = results
package = CompilePackage(fn, entry)
package.install(backends)
return package
DynamoCache = DiskDynamoCache(os.path.join(cache_dir(), "dynamo"))
full_path = os.path.join(self.path_prefix, key) if self.path_prefix else key
return super().load_package(fn, full_path)
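After the revert, the store surface is the explicit save_package/load_package pair shown above, keyed by a caller-chosen string under path_prefix. A minimal usage sketch, assuming a writable cache directory (the path and key are illustrative):

import torch
from torch._dynamo.package import CompilePackage, DiskDynamoStore

def fn(x):
    return x.sin() + x.cos()

store = DiskDynamoStore(path_prefix="/tmp/dynamo_store")  # illustrative prefix

package = CompilePackage(fn)
compiled = torch._dynamo.optimize(backend="inductor", package=package)(fn)
compiled(torch.randn(3, 2))
store.save_package(package, "fn_key")  # writes the dynamo entry + backends under the key

# Reload: returns a fresh CompilePackage plus its deserialized backends.
package, backends = store.load_package(fn, "fn_key")
package.install(backends)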

View File

@@ -1,6 +1,5 @@ from abc import abstractmethod
from abc import abstractmethod
from collections import defaultdict
from itertools import chain
from typing import Any, Generic, Optional, TypeVar
from typing_extensions import override
@@ -144,34 +143,10 @@ class PrecompileContext(CacheArtifactManager):
@staticmethod
def populate_caches(artifacts: CacheArtifactsResult) -> CacheInfo:
PrecompileContext._ensure_cache_artifacts_registered()
artifacts_by_key = {}
cache_info = CacheInfo()
for artifact in chain(*artifacts.values()):
cache_info.add(artifact)
artifacts_by_key[artifact.key] = artifact
from torch._dynamo.package import _BackendId, DynamoCache
for dynamo_entry in artifacts["precompile_dynamo"]:
assert isinstance(dynamo_entry, PrecompileCacheArtifact)
cache_entry = dynamo_entry.after_deserialization()
# Grab backends from the dynamo cache entry
backends = cache_entry.backend_ids
backend_content: dict[_BackendId, PrecompileCacheArtifact[Any]] = {}
for id_ in backends:
assert id_ in artifacts_by_key, f"Backend {id_} not found in artifacts"
artifact = artifacts_by_key[id_]
assert isinstance(artifact, PrecompileCacheArtifact)
backend_content[id_] = artifact
DynamoCache.write(cache_entry, backend_content, dynamo_entry.key)
return cache_info
raise NotImplementedError("TODO")
@classmethod
def _ensure_cache_artifacts_registered(cls) -> None:
from torch._dynamo.package import _DynamoCacheArtifact # noqa: F401
from torch._functorch._aot_autograd.autograd_cache import ( # noqa: F401
BundledAOTAutogradCacheArtifact,
)

View File

@@ -121,10 +121,6 @@ def should_use_local_autograd_cache():
return config.enable_autograd_cache
def should_bundle_autograd_cache():
return config.bundled_autograd_cache or torch._dynamo.config.caching_precompile
def check_node_safe(node: Node):
"""
Checks that the node only uses supported operators. We are starting with very
@@ -1100,10 +1096,7 @@ class AOTAutogradCache(GuardedCache[GenericAOTAutogradCacheEntry]):
except FXGraphCacheMiss as e:
counters["aot_autograd"]["autograd_cache_miss"] += 1
cache_state = "miss"
if (
config.strict_autograd_cache
or torch._dynamo.config.caching_precompile
):
if config.strict_autograd_cache:
raise e
# Most often this is BypassAOTAutogradCache, but
# if there's ever different reason we can't cache,
@@ -1133,10 +1126,7 @@ )
)
if remote:
log_cache_bypass("bypass_aot_autograd", str(e))
if (
config.strict_autograd_cache
or torch._dynamo.config.caching_precompile
):
if config.strict_autograd_cache:
raise e
if compiled_fn is None:
# Set the cache key so we can save a cache result later
@@ -1250,7 +1240,7 @@ class AOTAutogradCache(GuardedCache[GenericAOTAutogradCacheEntry]):
AOTAutogradCacheArtifact.type(), key, pickled_content
)
if (
should_bundle_autograd_cache()
config.bundled_autograd_cache
and aot_config is not None
and aot_config.precompile_backend_id is not None
):
@@ -1292,7 +1282,7 @@ AOTAutogradCacheArtifact.type(), key, content
AOTAutogradCacheArtifact.type(), key, content
)
if (
should_bundle_autograd_cache()
config.bundled_autograd_cache
and entry.sanitized_aot_config.precompile_backend_id is not None
):
precompile_key = entry.sanitized_aot_config.precompile_backend_id
@@ -1368,7 +1358,7 @@ num_symints_saved_for_bw: Optional[int],
num_symints_saved_for_bw: Optional[int],
serialized_bw_module: Optional[SerializedGraphModule],
) -> GenericAOTAutogradCacheEntry:
if should_bundle_autograd_cache():
if config.bundled_autograd_cache:
# Helper function to unwrap all the wrappers we added during aotdispatch
# They get reapplied on cache load
def unwrap_compiled_fx_graph(obj):

View File

@@ -43,7 +43,6 @@ from .. import config
from .autograd_cache import (
AOTAutogradCache,
serialize_graph_module,
should_bundle_autograd_cache,
should_use_remote_autograd_cache,
)
from .dispatch_and_compile_graph import (
@@ -263,7 +262,7 @@ cache_info = aot_config.cache_info
cache_info = aot_config.cache_info
def should_save_cache():
if should_bundle_autograd_cache():
if torch._functorch.config.bundled_autograd_cache:
return True
else:
return hasattr(compiled_fw, "_fx_graph_cache_key")
@@ -1782,7 +1781,7 @@ cache_info = aot_config.cache_info
cache_info = aot_config.cache_info
def should_save_cache():
if should_bundle_autograd_cache():
if torch._functorch.config.bundled_autograd_cache:
return True
else:
return hasattr(compiled_fw_func, "_fx_graph_cache_key") and hasattr(

View File

@@ -135,10 +135,6 @@ class CacheInfo:
def precompile_aot_autograd_artifacts(self) -> list[str]: # type: ignore[empty-body]
...
@property
def precompile_dynamo_artifacts(self) -> list[str]: # type: ignore[empty-body]
...
def add(self, artifact: CacheArtifact) -> None:
self.artifacts[artifact.type()].append(artifact.key)