Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 21:14:14 +08:00
Revert "Recheck Autotune cache on Precompile serialization to prune compilation results (#158656)"
This reverts commit 664005662ad8c9aa1942015397048aa9ca14fd6d. Reverted https://github.com/pytorch/pytorch/pull/158656 on behalf of https://github.com/seemethere due to failing internal tests, see D80486843 ([comment](https://github.com/pytorch/pytorch/pull/158656#issuecomment-3201491561))
@@ -16,7 +16,7 @@ from torch._dynamo.package import CompilePackage, DiskDynamoStore, DynamoCache
from torch._dynamo.precompile_context import PrecompileContext
from torch._dynamo.testing import reduce_to_scalar_loss
from torch._functorch import config as functorch_config
from torch._inductor.mock_cache import global_stats, PatchCaches
from torch._inductor.mock_cache import global_stats, PatchCaches, Stats
from torch._inductor.runtime.runtime_utils import cache_dir
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
@@ -452,33 +452,27 @@ def add(x, y):
        def fn(x, y):
            return x.sin() + y

        arg1 = torch.randn(32, 32, device=device)
        arg2 = torch.randn(32, 32, device=device)
        arg1 = torch.randn(3, 3, device=device)
        arg2 = torch.randn(3, 3, device=device)
        expected = fn(arg1, arg2).clone()

        with PatchCaches():
            compiled_fn1 = torch.compile(fn, mode="max-autotune")
            result = compiled_fn1(arg1, arg2).clone()
            self.assertEqual(expected, result)
            self.assertEqual(global_stats.autotune_local.num_get_miss, 1)
            self.assertEqual(global_stats.autotune_local, Stats(1, 0, 1))
            DynamoCache.clear()

            total_frames = torch._dynamo.convert_frame.FRAME_COUNTER
            self._save_and_reload(
                expected_backends=1, expected_dynamo=1, expected_autotune=1
            )
            # During save, we check the autotune cache another time, and now it should hit
            self.assertEqual(global_stats.autotune_local.num_get_hit, 1)
            compiled_fn1 = torch.compile(fn, mode="max-autotune")
            with torch.compiler.set_stance("fail_on_recompile"):
                result1 = compiled_fn1(arg1, arg2).clone()
            self.assertEqual(expected, result1)
            self.assertEqual(torch._dynamo.convert_frame.FRAME_COUNTER, total_frames)
            # No new hits or misses
            # Unfortunately, we don't *actually* know how many puts there will be, because
            # it's possible the best autotune config was found by coordesc.
            self.assertEqual(global_stats.autotune_local.num_get_hit, 1)
            self.assertEqual(global_stats.autotune_local.num_get_miss, 1)
            self.assertEqual(global_stats.autotune_local, Stats(2, 1, 1))

    @parametrize("device", ("cpu", "cuda", "xpu"))
    @torch._dynamo.config.patch(caching_precompile=True)
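For context, the reverted-out test asserted on individual counters of the mocked local autotune cache, while the restored test compares the whole Stats record at once. A minimal sketch of the two assertion styles, using a stand-in Stats dataclass (the field order shown is an assumption, not necessarily that of torch._inductor.mock_cache):

# Illustrative stand-in, not the real torch._inductor.mock_cache.Stats.
from dataclasses import dataclass

@dataclass(frozen=True)
class Stats:
    num_put: int = 0       # assumed field order; illustrative only
    num_get_hit: int = 0
    num_get_miss: int = 0

observed = Stats(num_put=1, num_get_hit=0, num_get_miss=1)

# Restored style: pin every counter at once via structural equality.
assert observed == Stats(1, 0, 1)

# Style from the reverted PR: assert only the counters that are
# deterministic (put counts can vary when coordinate descent finds the best config).
assert observed.num_get_miss == 1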
@@ -169,16 +169,7 @@ class PrecompileContext(CacheArtifactManager):
        by artifact type. This function transfers artifacts from _new_cache_artifacts_by_key to _new_cache_artifacts
        """
        for artifact in cls._new_cache_artifacts_by_key.values():
            from torch._functorch._aot_autograd.autograd_cache import (
                BundledAOTAutogradCacheEntry,
            )

            if isinstance(artifact, EditablePrecompileCacheArtifact):
                if isinstance(artifact.content, BundledAOTAutogradCacheEntry):
                    # BundledAOTAutogradCacheEntries should update their autotune results
                    artifact.edit_contents(
                        BundledAOTAutogradCacheEntry.update_autotune_results
                    )
                artifact = artifact.real_encode()
            cls._new_cache_artifacts[artifact.__class__.type()].append(artifact)
        cls._new_cache_artifacts_by_key.clear()
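The lines removed here applied a last-minute edit to editable artifacts before encoding them for the cache. A rough sketch of that edit-then-encode pattern with a stand-in artifact class (only edit_contents and real_encode are taken from the diff; everything else is hypothetical):

# Illustrative stand-in, not the real EditablePrecompileCacheArtifact.
import pickle
from typing import Any, Callable

class EditableArtifact:
    def __init__(self, key: str, content: Any) -> None:
        self.key = key
        self.content = content  # kept decoded so it can still be edited

    def edit_contents(self, edit_fn: Callable[[Any], Any]) -> None:
        # Apply a final transformation before the artifact is frozen.
        self.content = edit_fn(self.content)

    def real_encode(self) -> bytes:
        # Only now is the content serialized for the cache.
        return pickle.dumps(self.content)

artifact = EditableArtifact("fw_graph", {"autotune_results": None})
artifact.edit_contents(lambda c: {**c, "autotune_results": "best_config"})
payload = artifact.real_encode()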
@@ -204,15 +195,6 @@ class PrecompileContext(CacheArtifactManager):
        """
        result = cls._new_cache_artifacts_by_key.get(key, None)
        if isinstance(result, EditablePrecompileCacheArtifact):
            from torch._functorch._aot_autograd.autograd_cache import (
                BundledAOTAutogradCacheEntry,
            )

            if isinstance(result.content, BundledAOTAutogradCacheEntry):
                # BundledAOTAutogradCacheEntries should update their autotune results
                result.edit_contents(
                    BundledAOTAutogradCacheEntry.update_autotune_results
                )
            result = result.real_encode()
        return result
@@ -535,32 +535,6 @@ class CompiledFxGraphLoadable(InductorOutput[CompiledFxGraph]):

    result: CompiledFxGraph

    def recheck_autotune_results(self) -> None:
        """
        Run during PrecompileContext.serialize(). We recheck the autotune cache
        again before saving results, to see if autotuning has completed for our generated
        triton kernels. If so, it edits the statically compiled triton kernel so that only
        the best config is preserved.
        """
        triton_bundle = self.result._triton_bundle
        if triton_bundle is None:
            return
        static_autotuners = triton_bundle.static_autotuners
        for autotuner in static_autotuners:
            from torch._inductor.codecache import _load_triton_kernel_from_source

            reload_kernel_from_src = functools.partial(
                _load_triton_kernel_from_source,
                autotuner.kernel_name,
                autotuner.source_code,
            )
            autotuner.kernel.recheck_autotune_cache(
                reload_kernel_from_src,
            )
            # Clear any extra state created by this check
            autotuner.kernel.prepare_for_pickle()
            autotuner.kernel.prepare_for_caching()

    def pre_save(self) -> None:
        disk_compiled_graph = copy(self.result)
        disk_compiled_graph.prepare_for_serialization()
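The removed method builds a zero-argument reload callback with functools.partial so that reloading a kernel from source (potentially expensive) only happens if the recheck actually needs it. A minimal sketch of that deferred-callback pattern; load_kernel_from_source and recheck below are stand-ins, not the real inductor functions:

import functools
from typing import Callable

def load_kernel_from_source(kernel_name: str, source_code: str) -> str:
    # Stand-in for torch._inductor.codecache._load_triton_kernel_from_source.
    print(f"recompiling {kernel_name} from source")
    return f"<kernel {kernel_name}>"

def recheck(reload_kernel: Callable[[], str], cache_hit: bool) -> None:
    # The reload callback is only invoked when it is actually needed.
    if not cache_hit:
        reload_kernel()

reload_cb = functools.partial(
    load_kernel_from_source, "triton_poi_fused_add_0", "def kernel(): ..."
)
recheck(reload_cb, cache_hit=True)   # no reload happens
recheck(reload_cb, cache_hit=False)  # reload happens here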
@@ -1024,18 +998,6 @@ class BundledAOTAutogradCacheEntry(
    of relying on cache keys from FxGraphCache
    """

    @staticmethod
    def update_autotune_results(
        entry: BundledAOTAutogradCacheEntry,
    ) -> BundledAOTAutogradCacheEntry:
        """
        Update the autotune results in the cache entry.
        """
        entry.compiled_fw.recheck_autotune_results()
        if entry.compiled_bw is not None:
            entry.compiled_bw.recheck_autotune_results()
        return entry


@contextlib.contextmanager
def sanitize_gm_for_cache(gm: torch.fx.GraphModule):
@@ -401,9 +401,11 @@ class AsyncCompile:

        if (future := CompiledTritonKernels.get(source_code)) is not None:
            counters["inductor"]["async_compile_cache_hit"] += 1
            # Set reload_kernel_from_src properly based on source_code
            if isinstance(future, StaticAutotunerFuture):
                # Remove the future now that we've cache hit
                CompiledTritonKernels.remove_future(source_code)
                future.reload_kernel_from_src = reload_kernel_in_parent
            if is_parallel:
                return future
            else:
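In the restored cache-hit path, the future cannot rebuild the kernel on its own, so the parent process injects the reload callback before handing the future back. A rough sketch of that lookup-and-inject step; the cache dict, StaticFuture class, and lookup function are stand-ins for CompiledTritonKernels and StaticAutotunerFuture:

# Rough sketch of the restored cache-hit path (all names here are stand-ins).
_cache: dict[str, object] = {}

class StaticFuture:
    def __init__(self) -> None:
        self.reload_kernel_from_src = None  # must be injected by the caller

def lookup(source_code: str, reload_kernel_in_parent):
    if (future := _cache.get(source_code)) is not None:
        if isinstance(future, StaticFuture):
            _cache.pop(source_code)  # remove the future now that we've hit
            # The parent hands the future its reload callback here, since the
            # future no longer stores the kernel source itself.
            future.reload_kernel_from_src = reload_kernel_in_parent
        return future
    return None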
@@ -457,7 +459,7 @@ class AsyncCompile:
            kernel.precompile(
                warm_cache_only=False,
                reload_kernel=reload_kernel_in_parent,
                source_code=source_code,
                static_triton_bundle_key=CompiledTritonKernels.key(source_code),
            )
            info = kernel.autotune_cache_info or {}
            info["compile_time_us"] = elapsed_us
@@ -486,7 +488,7 @@ class AsyncCompile:
            kernel.set_compile_info(compile_id, is_backward)
            kernel.precompile(
                warm_cache_only=False,
                source_code=source_code,
                static_triton_bundle_key=CompiledTritonKernels.key(source_code),
            )
            elapsed_us = (time_ns() - start_ns) // 1000
            get_metrics_context().add_top_n(
@@ -4213,28 +4213,24 @@ class StaticAutotunerFuture(CodeCacheFuture):
    A statically launchable CachingAutotuner, loaded from TritonBundler
    """

    def __init__(
        self, static_autotuner: CachingAutotuner, kernel_name: str, source_code: str
    ) -> None:
    def __init__(self, static_autotuner: CachingAutotuner) -> None:
        # Pickled version of CachingAutotuner
        self.static_autotuner = static_autotuner
        self.kernel_name = kernel_name
        # The python source code of the kernel is relatively small and stored by StaticallyLaunchedAutotuner.
        # We do not store the compiled cuda code here as it's very large,
        # it's stored via the regular TritonBundler
        self.source_code = source_code
        # This needs to be set in AsyncCompile.triton, in case
        # we need to reload the CachingAutotuner from its source code
        # We don't store the source code on the CachingAutotuner itself
        # since it can be very large.
        self.reload_kernel_from_src: Optional[Callable[[], Any]] = None

    def result(self) -> CachingAutotuner:
        assert self.reload_kernel_from_src is not None
        with dynamo_timed("StaticAutotunerFuture.warm_precompile"):
            reload_kernel_from_src = functools.partial(
                _load_triton_kernel_from_source, self.kernel_name, self.source_code
            )
            self.static_autotuner.recheck_autotune_cache(
                reload_kernel_from_src=reload_kernel_from_src
                reload_kernel_from_src=self.reload_kernel_from_src
            )
            self.static_autotuner.precompile(  # type: ignore[union-attr]
                warm_cache_only=False,
                reload_kernel=reload_kernel_from_src,
                source_code=None,  # no need to save again
                reload_kernel=self.reload_kernel_from_src,
                static_triton_bundle_key=None,  # no need to save again
            )
            return self.static_autotuner
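Both variants of this future finish precompilation lazily in result(); the difference is where the reload callback comes from (stored kernel name and source in the reverted PR, an externally injected callback in the restored code). A condensed sketch of the restored behaviour, with CachingAutotuner replaced by a stub and method signatures reduced to what result() needs:

from typing import Any, Callable, Optional

class StubAutotuner:
    # Stand-in for CachingAutotuner, exposing only the two calls result() makes.
    def recheck_autotune_cache(self, reload_kernel_from_src: Callable[[], Any]) -> None:
        pass

    def precompile(self, warm_cache_only: bool, reload_kernel: Callable[[], Any],
                   static_triton_bundle_key: Optional[str]) -> None:
        pass

class StaticAutotunerFutureSketch:
    def __init__(self, static_autotuner: StubAutotuner) -> None:
        self.static_autotuner = static_autotuner
        self.reload_kernel_from_src: Optional[Callable[[], Any]] = None

    def result(self) -> StubAutotuner:
        # The caller (AsyncCompile.triton) must have injected the callback first.
        assert self.reload_kernel_from_src is not None
        self.static_autotuner.recheck_autotune_cache(self.reload_kernel_from_src)
        self.static_autotuner.precompile(
            warm_cache_only=False,
            reload_kernel=self.reload_kernel_from_src,
            static_triton_bundle_key=None,
        )
        return self.static_autotuner

future = StaticAutotunerFutureSketch(StubAutotuner())
future.reload_kernel_from_src = lambda: StubAutotuner()
kernel = future.result()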
@@ -386,14 +386,13 @@ class CachingAutotuner(KernelInterface):
        assert self.is_statically_launchable()

        configs = [result.config for result in self.compile_results]
        if len(configs) <= 1:
            return

        (cached_configs, _, autotune_cache_info) = check_autotune_cache(
            configs, self.filename, self.inductor_meta
        )
        self.autotune_cache_info = autotune_cache_info
        # I.e. there was an autotune cache hit
        if len(cached_configs) == 1:
        if len(cached_configs) == 1 and len(configs) > 1:
            best_config = cached_configs[0]
            # Grab the best compiled config, if it's in the list of available ones
            best_config_hash = triton_config_to_hashable(best_config)
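When the autotune cache reports a single best config, the surrounding method prunes the compiled results down to that one config. A toy sketch of that pruning step, with triton_config_to_hashable replaced by a simple tuple key and configs modelled as plain dicts:

# Toy pruning step: keep only the compile result whose config matches the
# single config returned by the autotune cache.
def config_key(config: dict) -> tuple:
    # Stand-in for triton_config_to_hashable.
    return tuple(sorted(config.items()))

compile_results = [
    {"config": {"BLOCK": 64, "num_warps": 4}, "binary": "<cubin A>"},
    {"config": {"BLOCK": 128, "num_warps": 8}, "binary": "<cubin B>"},
]
cached_configs = [{"BLOCK": 128, "num_warps": 8}]  # pretend cache hit

if len(cached_configs) == 1 and len(compile_results) > 1:
    best_key = config_key(cached_configs[0])
    compile_results = [r for r in compile_results if config_key(r["config"]) == best_key]

assert len(compile_results) == 1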
@@ -422,7 +421,7 @@ class CachingAutotuner(KernelInterface):
        self,
        warm_cache_only=False,
        reload_kernel: Optional[Callable[[], CachingAutotuner]] = None,
        source_code: Optional[str] = None,  # Used for static_triton_bundle_key
        static_triton_bundle_key: Optional[str] = None,
    ):
        if warm_cache_only:
            self._precompile_worker()
@@ -435,9 +434,8 @@ class CachingAutotuner(KernelInterface):
        if reload_kernel is not None:
            self._reload_kernel = reload_kernel
        self._precompile_worker()

        if source_code is not None and self.is_statically_launchable():
            TritonBundler.put_static_autotuner(source_code, self)
        if static_triton_bundle_key is not None and self.is_statically_launchable():
            TritonBundler.put_static_autotuner(static_triton_bundle_key, self)
        self._make_launchers()
        self._dynamic_scale_rblock()
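The restored interface keys the static autotuner bundle by a precomputed static_triton_bundle_key rather than by the raw kernel source. A hypothetical sketch of deriving such a key by hashing the source; the real CompiledTritonKernels.key may compute its key differently, and put_static_autotuner below is a simplified stand-in:

import hashlib

def bundle_key(source_code: str) -> str:
    # Illustrative only; the real CompiledTritonKernels.key may hash differently.
    return hashlib.sha256(source_code.encode()).hexdigest()

_static_autotuners: dict[str, object] = {}

def put_static_autotuner(key: str, kernel: object) -> None:
    # Restored calling convention: callers pass the precomputed key, not raw source.
    _static_autotuners[key] = kernel

put_static_autotuner(bundle_key("def kernel(): ..."), object())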
@@ -53,11 +53,7 @@ class StaticallyLaunchedAutotuner:
    Statically saved here have their cubin files saved by a corresponding TritonBundleEntry.
    """

    # We store the kernel's python source code here which we use for two things:
    # First, to calculate a cache key for CompiledTritonKernels
    # Second, in case we need to reload the kernel on load,
    # we can do so by reading the source code from the cache entry.
    source_code: str
    cache_key: str
    kernel_name: str
    kernel: "CachingAutotuner"  # type: ignore[name-defined] # noqa: F821
@@ -168,7 +164,7 @@ class TritonBundler:
        )

    @classmethod
    def put_static_autotuner(cls, source_code: str, kernel: "CachingAutotuner") -> None:  # type: ignore[name-defined] # noqa: F821
    def put_static_autotuner(cls, key: str, kernel: "CachingAutotuner") -> None:  # type: ignore[name-defined] # noqa: F821
        from torch._inductor import config

        assert config.use_static_cuda_launcher
@@ -182,7 +178,7 @@ class TritonBundler:

            entries.append(
                StaticallyLaunchedAutotuner(
                    source_code,
                    key,
                    new_kernel.inductor_meta.get("kernel_name", "unknown_kernel"),
                    new_kernel,
                )
@@ -244,9 +240,8 @@ class TritonBundler:
                # kernels that are not statically launchable (i.e. cache miss)
                # can launch a worker without waiting on the blocking step of
                # StaticAutotunerFuture.result().
                cache_key = CompiledTritonKernels.key(result.source_code)
                CompiledTritonKernels._cache[cache_key] = StaticAutotunerFuture(
                    result.kernel, result.kernel_name, result.source_code
                CompiledTritonKernels._cache[result.cache_key] = StaticAutotunerFuture(
                    result.kernel
                )
                counters["inductor"]["triton_bundler_load_static_autotuner"] += 1
                kernel_names.append(result.kernel_name)