Revert "Recheck Autotune cache on Precompile serialization to prune compilation results (#158656)"

This reverts commit 664005662ad8c9aa1942015397048aa9ca14fd6d.

Reverted https://github.com/pytorch/pytorch/pull/158656 on behalf of https://github.com/seemethere due to failing internal tests, see D80486843 ([comment](https://github.com/pytorch/pytorch/pull/158656#issuecomment-3201491561))
Author: PyTorch MergeBot
Date:   2025-08-19 16:53:20 +00:00
Parent: fecc5f6001
Commit: eddaaa6c2a

7 changed files with 29 additions and 100 deletions


@@ -16,7 +16,7 @@ from torch._dynamo.package import CompilePackage, DiskDynamoStore, DynamoCache
 from torch._dynamo.precompile_context import PrecompileContext
 from torch._dynamo.testing import reduce_to_scalar_loss
 from torch._functorch import config as functorch_config
-from torch._inductor.mock_cache import global_stats, PatchCaches
+from torch._inductor.mock_cache import global_stats, PatchCaches, Stats
 from torch._inductor.runtime.runtime_utils import cache_dir
 from torch.testing._internal.common_utils import (
     instantiate_parametrized_tests,
@@ -452,33 +452,27 @@ def add(x, y):
         def fn(x, y):
             return x.sin() + y
-        arg1 = torch.randn(32, 32, device=device)
-        arg2 = torch.randn(32, 32, device=device)
+        arg1 = torch.randn(3, 3, device=device)
+        arg2 = torch.randn(3, 3, device=device)
         expected = fn(arg1, arg2).clone()
         with PatchCaches():
             compiled_fn1 = torch.compile(fn, mode="max-autotune")
             result = compiled_fn1(arg1, arg2).clone()
             self.assertEqual(expected, result)
-            self.assertEqual(global_stats.autotune_local.num_get_miss, 1)
+            self.assertEqual(global_stats.autotune_local, Stats(1, 0, 1))
             DynamoCache.clear()
             total_frames = torch._dynamo.convert_frame.FRAME_COUNTER
             self._save_and_reload(
                 expected_backends=1, expected_dynamo=1, expected_autotune=1
             )
-            # During save, we check the autotune cache another time, and now it should hit
-            self.assertEqual(global_stats.autotune_local.num_get_hit, 1)
             compiled_fn1 = torch.compile(fn, mode="max-autotune")
             with torch.compiler.set_stance("fail_on_recompile"):
                 result1 = compiled_fn1(arg1, arg2).clone()
             self.assertEqual(expected, result1)
             self.assertEqual(torch._dynamo.convert_frame.FRAME_COUNTER, total_frames)
-            # No new hits or misses
-            # Unfortunately, we don't *actually* know how many puts there will be, because
-            # it's possible the best autotune config was found by coordesc.
-            self.assertEqual(global_stats.autotune_local.num_get_hit, 1)
-            self.assertEqual(global_stats.autotune_local.num_get_miss, 1)
+            self.assertEqual(global_stats.autotune_local, Stats(2, 1, 1))
     @parametrize("device", ("cpu", "cuda", "xpu"))
     @torch._dynamo.config.patch(caching_precompile=True)
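Note: the restored assertions above compare the whole per-cache counter snapshot at once rather than individual hit/miss fields. A minimal standalone sketch of that pattern follows; the CacheStats class and its field order are illustrative assumptions, not the real Stats from torch._inductor.mock_cache.

from dataclasses import dataclass

@dataclass
class CacheStats:
    # Hypothetical counter layout; the real Stats class may differ.
    num_put: int = 0
    num_get_hit: int = 0
    num_get_miss: int = 0

    def get(self, key, store):
        if key in store:
            self.num_get_hit += 1
            return store[key]
        self.num_get_miss += 1
        return None

    def put(self, key, value, store):
        self.num_put += 1
        store[key] = value

store: dict = {}
stats = CacheStats()
assert stats.get("best_config", store) is None           # first lookup misses
stats.put("best_config", {"BLOCK": 64}, store)            # autotuner stores a result
assert stats == CacheStats(num_put=1, num_get_hit=0, num_get_miss=1)
assert stats.get("best_config", store) == {"BLOCK": 64}   # later lookup hits
assert stats == CacheStats(num_put=1, num_get_hit=1, num_get_miss=1)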


@@ -169,16 +169,7 @@ class PrecompileContext(CacheArtifactManager):
         by artifact type. This function transfers artifacts from _new_cache_artifacts_by_key to _new_cache_artifacts
         """
         for artifact in cls._new_cache_artifacts_by_key.values():
-            from torch._functorch._aot_autograd.autograd_cache import (
-                BundledAOTAutogradCacheEntry,
-            )
             if isinstance(artifact, EditablePrecompileCacheArtifact):
-                if isinstance(artifact.content, BundledAOTAutogradCacheEntry):
-                    # BundledAOTAutogradCacheEntries should update their autotune results
-                    artifact.edit_contents(
-                        BundledAOTAutogradCacheEntry.update_autotune_results
-                    )
                 artifact = artifact.real_encode()
             cls._new_cache_artifacts[artifact.__class__.type()].append(artifact)
         cls._new_cache_artifacts_by_key.clear()
@@ -204,15 +195,6 @@ class PrecompileContext(CacheArtifactManager):
         """
         result = cls._new_cache_artifacts_by_key.get(key, None)
         if isinstance(result, EditablePrecompileCacheArtifact):
-            from torch._functorch._aot_autograd.autograd_cache import (
-                BundledAOTAutogradCacheEntry,
-            )
-            if isinstance(result.content, BundledAOTAutogradCacheEntry):
-                # BundledAOTAutogradCacheEntries should update their autotune results
-                result.edit_contents(
-                    BundledAOTAutogradCacheEntry.update_autotune_results
-                )
             result = result.real_encode()
         return result
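Note: the reverted code above edited an artifact's contents in place right before its final encoding. A rough standalone sketch of that edit-then-encode pattern, using a hypothetical EditableArtifact class rather than the real EditablePrecompileCacheArtifact API:

import pickle
from typing import Any, Callable

class EditableArtifact:
    # Illustrative stand-in for an "editable" cache artifact: the payload stays
    # a live Python object until serialization time, so callers can rewrite it.
    def __init__(self, content: Any) -> None:
        self.content = content

    def edit_contents(self, edit_fn: Callable[[Any], Any]) -> None:
        # Apply a last-minute transformation (e.g. pruning autotune results).
        self.content = edit_fn(self.content)

    def real_encode(self) -> bytes:
        # Final, immutable serialized form.
        return pickle.dumps(self.content)

artifact = EditableArtifact({"configs": ["cfg_a", "cfg_b", "cfg_c"], "best": "cfg_b"})

def keep_only_best(content: dict) -> dict:
    content["configs"] = [content["best"]]
    return content

artifact.edit_contents(keep_only_best)   # what the reverted PR did before encoding
payload = artifact.real_encode()         # what remains after the revert
assert pickle.loads(payload)["configs"] == ["cfg_b"]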


@@ -535,32 +535,6 @@ class CompiledFxGraphLoadable(InductorOutput[CompiledFxGraph]):
     result: CompiledFxGraph
-    def recheck_autotune_results(self) -> None:
-        """
-        Run during PrecompileContext.serialize(). We recheck the autotune cache
-        again before saving results, to see if autotuning has completed for our generated
-        triton kernels. If so, it edits the statically compiled triton kernel so that only
-        the best config is preserved.
-        """
-        triton_bundle = self.result._triton_bundle
-        if triton_bundle is None:
-            return
-        static_autotuners = triton_bundle.static_autotuners
-        for autotuner in static_autotuners:
-            from torch._inductor.codecache import _load_triton_kernel_from_source
-            reload_kernel_from_src = functools.partial(
-                _load_triton_kernel_from_source,
-                autotuner.kernel_name,
-                autotuner.source_code,
-            )
-            autotuner.kernel.recheck_autotune_cache(
-                reload_kernel_from_src,
-            )
-            # Clear any extra state created by this check
-            autotuner.kernel.prepare_for_pickle()
-            autotuner.kernel.prepare_for_caching()
     def pre_save(self) -> None:
         disk_compiled_graph = copy(self.result)
         disk_compiled_graph.prepare_for_serialization()
@@ -1024,18 +998,6 @@ class BundledAOTAutogradCacheEntry(
     of relying on cache keys from FxGraphCache
     """
-    @staticmethod
-    def update_autotune_results(
-        entry: BundledAOTAutogradCacheEntry,
-    ) -> BundledAOTAutogradCacheEntry:
-        """
-        Update the autotune results in the cache entry.
-        """
-        entry.compiled_fw.recheck_autotune_results()
-        if entry.compiled_bw is not None:
-            entry.compiled_bw.recheck_autotune_results()
-        return entry
 @contextlib.contextmanager
 def sanitize_gm_for_cache(gm: torch.fx.GraphModule):
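Note: the removed recheck_autotune_results/update_autotune_results pair re-consulted the autotune cache at serialization time and pruned each bundled kernel down to its single best config. A simplified standalone sketch of that idea, with hypothetical KernelEntry and recheck_and_prune names:

from dataclasses import dataclass
from typing import Optional

@dataclass
class KernelEntry:
    # Illustrative stand-in for a bundled kernel: several candidate configs are
    # carried until the autotune cache reports a single winner.
    name: str
    candidate_configs: list[str]
    best_config: Optional[str] = None

def recheck_and_prune(entry: KernelEntry, autotune_cache: dict) -> None:
    # Re-consult the (possibly now populated) autotune cache; if a best config
    # is known, drop every other candidate so the serialized entry stays small.
    best = autotune_cache.get(entry.name)
    if best is not None and best in entry.candidate_configs:
        entry.best_config = best
        entry.candidate_configs = [best]

cache = {"triton_mm_kernel": "BLOCK_64x64"}
fw = KernelEntry("triton_mm_kernel", ["BLOCK_32x32", "BLOCK_64x64", "BLOCK_128x64"])
bw = KernelEntry("triton_bwd_kernel", ["BLOCK_32x32", "BLOCK_64x64"])  # not tuned yet

for entry in (fw, bw):          # analogous to updating compiled_fw and compiled_bw
    recheck_and_prune(entry, cache)

assert fw.candidate_configs == ["BLOCK_64x64"]
assert bw.candidate_configs == ["BLOCK_32x32", "BLOCK_64x64"]  # untouched: no cache hit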


@@ -401,9 +401,11 @@ class AsyncCompile:
         if (future := CompiledTritonKernels.get(source_code)) is not None:
             counters["inductor"]["async_compile_cache_hit"] += 1
+            # Set reload_kernel_from_src properly based on source_code
             if isinstance(future, StaticAutotunerFuture):
                 # Remove the future now that we've cache hit
                 CompiledTritonKernels.remove_future(source_code)
+                future.reload_kernel_from_src = reload_kernel_in_parent
             if is_parallel:
                 return future
             else:
@@ -457,7 +459,7 @@
             kernel.precompile(
                 warm_cache_only=False,
                 reload_kernel=reload_kernel_in_parent,
-                source_code=source_code,
+                static_triton_bundle_key=CompiledTritonKernels.key(source_code),
             )
             info = kernel.autotune_cache_info or {}
             info["compile_time_us"] = elapsed_us
@@ -486,7 +488,7 @@
             kernel.set_compile_info(compile_id, is_backward)
             kernel.precompile(
                 warm_cache_only=False,
-                source_code=source_code,
+                static_triton_bundle_key=CompiledTritonKernels.key(source_code),
             )
             elapsed_us = (time_ns() - start_ns) // 1000
             get_metrics_context().add_top_n(
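Note: the restored call sites above pass a precomputed bundle key (CompiledTritonKernels.key(source_code)) into precompile() instead of the raw source. A standalone sketch of that keying convention; bundle_key, FakeBundler, and this precompile() stub are illustrative, not the Inductor APIs:

import hashlib
from typing import Optional

def bundle_key(source_code: str) -> str:
    # Illustrative key derivation; the real CompiledTritonKernels.key() may be
    # computed differently, but the idea is a short, stable digest of the source.
    return hashlib.sha256(source_code.encode()).hexdigest()[:16]

class FakeBundler:
    # Stand-in for a bundler: maps a bundle key to a kernel object.
    store: dict = {}

    @classmethod
    def put_static_autotuner(cls, key: str, kernel: object) -> None:
        cls.store[key] = kernel

def precompile(kernel: object, static_triton_bundle_key: Optional[str]) -> None:
    # Restored calling convention: the caller pre-computes the key and passes it
    # in; passing None means "do not register with the bundler".
    if static_triton_bundle_key is not None:
        FakeBundler.put_static_autotuner(static_triton_bundle_key, kernel)

src = "def kernel(x): return x"
precompile(kernel=object(), static_triton_bundle_key=bundle_key(src))
assert bundle_key(src) in FakeBundler.store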


@@ -4213,28 +4213,24 @@ class StaticAutotunerFuture(CodeCacheFuture):
     A statically launchable CachingAutotuner, loaded from TritonBundler
     """
-    def __init__(
-        self, static_autotuner: CachingAutotuner, kernel_name: str, source_code: str
-    ) -> None:
+    def __init__(self, static_autotuner: CachingAutotuner) -> None:
         # Pickled version of CachingAutotuner
         self.static_autotuner = static_autotuner
-        self.kernel_name = kernel_name
-        # The python source code of the kernel is relatively small and stored by StaticallyLaunchedAutotuner.
-        # We do not store the compiled cuda code here as it's very large,
-        # it's stored via the regular TritonBundler
-        self.source_code = source_code
+        # This needs to be set in AsyncCompile.triton, in case
+        # we need to reload the CachingAutotuner from its source code
+        # We don't store the source code on the CachingAutotuner itself
+        # since it can be very large.
+        self.reload_kernel_from_src: Optional[Callable[[], Any]] = None
     def result(self) -> CachingAutotuner:
+        assert self.reload_kernel_from_src is not None
         with dynamo_timed("StaticAutotunerFuture.warm_precompile"):
-            reload_kernel_from_src = functools.partial(
-                _load_triton_kernel_from_source, self.kernel_name, self.source_code
-            )
             self.static_autotuner.recheck_autotune_cache(
-                reload_kernel_from_src=reload_kernel_from_src
+                reload_kernel_from_src=self.reload_kernel_from_src
             )
             self.static_autotuner.precompile(  # type: ignore[union-attr]
                 warm_cache_only=False,
-                reload_kernel=reload_kernel_from_src,
-                source_code=None,  # no need to save again
+                reload_kernel=self.reload_kernel_from_src,
+                static_triton_bundle_key=None,  # no need to save again
             )
             return self.static_autotuner
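Note: the restored StaticAutotunerFuture no longer carries the kernel source; the caller must attach a reload callback before result() is used. A minimal sketch of that injected-callback design, with a hypothetical ReloadableFuture class:

from typing import Any, Callable, Optional

class ReloadableFuture:
    # Illustrative sketch of the restored design: the future does not carry the
    # kernel source itself; whoever pulls it out of the cache must first attach
    # a reload callback, and result() asserts that this happened.
    def __init__(self, payload: Any) -> None:
        self.payload = payload
        self.reload_from_src: Optional[Callable[[], Any]] = None  # set by the caller

    def result(self) -> Any:
        assert self.reload_from_src is not None, "caller forgot to attach the reload hook"
        # In the real code this is where recheck/precompile run, using the
        # injected callback if the kernel has to be rebuilt from source.
        return self.payload if self.payload is not None else self.reload_from_src()

future = ReloadableFuture(payload=None)
future.reload_from_src = lambda: "kernel rebuilt from source"   # what the cache-hit path does
assert future.result() == "kernel rebuilt from source"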


@@ -386,14 +386,13 @@ class CachingAutotuner(KernelInterface):
         assert self.is_statically_launchable()
         configs = [result.config for result in self.compile_results]
-        if len(configs) <= 1:
-            return
         (cached_configs, _, autotune_cache_info) = check_autotune_cache(
             configs, self.filename, self.inductor_meta
         )
         self.autotune_cache_info = autotune_cache_info
         # I.e. there was an autotune cache hit
-        if len(cached_configs) == 1:
+        if len(cached_configs) == 1 and len(configs) > 1:
             best_config = cached_configs[0]
             # Grab the best compiled config, if it's in the list of available ones
             best_config_hash = triton_config_to_hashable(best_config)
@@ -422,7 +421,7 @@
         self,
         warm_cache_only=False,
         reload_kernel: Optional[Callable[[], CachingAutotuner]] = None,
-        source_code: Optional[str] = None,  # Used for static_triton_bundle_key
+        static_triton_bundle_key: Optional[str] = None,
     ):
         if warm_cache_only:
             self._precompile_worker()
@@ -435,9 +434,8 @@
             if reload_kernel is not None:
                 self._reload_kernel = reload_kernel
             self._precompile_worker()
-            if source_code is not None and self.is_statically_launchable():
-                TritonBundler.put_static_autotuner(source_code, self)
+            if static_triton_bundle_key is not None and self.is_statically_launchable():
+                TritonBundler.put_static_autotuner(static_triton_bundle_key, self)
             self._make_launchers()
             self._dynamic_scale_rblock()
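Note: the restored guard prunes to the cached best config only when the autotune cache nominated exactly one config and there was actually more than one candidate; the reverted version expressed the second half as an early return instead. A tiny standalone sketch of the restored condition (function and names are illustrative):

def prune_to_cached_best(configs: list, cached_configs: list) -> list:
    # Only prune when the cache nominated exactly one config AND there was
    # more than one candidate to choose from.
    if len(cached_configs) == 1 and len(configs) > 1:
        best = cached_configs[0]
        if best in configs:
            return [best]
    return configs

assert prune_to_cached_best(["a", "b", "c"], ["b"]) == ["b"]   # cache hit: prune
assert prune_to_cached_best(["a"], ["a"]) == ["a"]             # single config: nothing to do
assert prune_to_cached_best(["a", "b"], []) == ["a", "b"]      # cache miss: keep all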


@@ -53,11 +53,7 @@ class StaticallyLaunchedAutotuner:
     Statically saved here have their cubin files saved by a corresponding TritonBundleEntry.
     """
-    # We store the kernel's python source code here which we use for two things:
-    # First, to calculate a cache key for CompiledTritonKernels
-    # Second, in case we need to reload the kernel on load,
-    # we can do so by reading the source code from the cache entry.
-    source_code: str
+    cache_key: str
     kernel_name: str
     kernel: "CachingAutotuner"  # type: ignore[name-defined]  # noqa: F821
@@ -168,7 +164,7 @@
     )
     @classmethod
-    def put_static_autotuner(cls, source_code: str, kernel: "CachingAutotuner") -> None:  # type: ignore[name-defined]  # noqa: F821
+    def put_static_autotuner(cls, key: str, kernel: "CachingAutotuner") -> None:  # type: ignore[name-defined]  # noqa: F821
         from torch._inductor import config
         assert config.use_static_cuda_launcher
@@ -182,7 +178,7 @@
         entries.append(
             StaticallyLaunchedAutotuner(
-                source_code,
+                key,
                 new_kernel.inductor_meta.get("kernel_name", "unknown_kernel"),
                 new_kernel,
             )
@@ -244,9 +240,8 @@
             # kernels that are not statically launchable (i.e. cache miss)
             # can launch a worker without waiting on the blocking step of
             # StaticAutotunerFuture.result().
-            cache_key = CompiledTritonKernels.key(result.source_code)
-            CompiledTritonKernels._cache[cache_key] = StaticAutotunerFuture(
-                result.kernel, result.kernel_name, result.source_code
+            CompiledTritonKernels._cache[result.cache_key] = StaticAutotunerFuture(
+                result.kernel
             )
             counters["inductor"]["triton_bundler_load_static_autotuner"] += 1
             kernel_names.append(result.kernel_name)
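Note: after the revert, each bundled entry stores its precomputed cache_key, and the loader registers kernels under that key directly rather than recomputing one from stored source. A simplified standalone sketch with hypothetical BundleEntry and load_bundle names:

from dataclasses import dataclass

@dataclass
class BundleEntry:
    # Illustrative stand-in for StaticallyLaunchedAutotuner after the revert:
    # the precomputed cache key is stored, the kernel source is not.
    cache_key: str
    kernel_name: str
    kernel: object

kernel_cache: dict = {}   # stand-in for the compiled-kernel cache

def load_bundle(entries: list) -> list:
    loaded = []
    for entry in entries:
        # Register the kernel under the key saved at bundling time, so later
        # compilations can cache-hit without recomputing a key from source.
        kernel_cache[entry.cache_key] = entry.kernel
        loaded.append(entry.kernel_name)
    return loaded

entries = [BundleEntry("key_fwd", "triton_mm_fwd", object()),
           BundleEntry("key_bwd", "triton_mm_bwd", object())]
assert load_bundle(entries) == ["triton_mm_fwd", "triton_mm_bwd"]
assert set(kernel_cache) == {"key_fwd", "key_bwd"}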