diff --git a/test/dynamo/test_package.py b/test/dynamo/test_package.py
index eace2e3cdc42..fdd01135ea2f 100644
--- a/test/dynamo/test_package.py
+++ b/test/dynamo/test_package.py
@@ -16,7 +16,7 @@ from torch._dynamo.package import CompilePackage, DiskDynamoStore, DynamoCache
 from torch._dynamo.precompile_context import PrecompileContext
 from torch._dynamo.testing import reduce_to_scalar_loss
 from torch._functorch import config as functorch_config
-from torch._inductor.mock_cache import global_stats, PatchCaches
+from torch._inductor.mock_cache import global_stats, PatchCaches, Stats
 from torch._inductor.runtime.runtime_utils import cache_dir
 from torch.testing._internal.common_utils import (
     instantiate_parametrized_tests,
@@ -452,33 +452,27 @@ def add(x, y):
         def fn(x, y):
             return x.sin() + y

-        arg1 = torch.randn(32, 32, device=device)
-        arg2 = torch.randn(32, 32, device=device)
+        arg1 = torch.randn(3, 3, device=device)
+        arg2 = torch.randn(3, 3, device=device)
         expected = fn(arg1, arg2).clone()

         with PatchCaches():
             compiled_fn1 = torch.compile(fn, mode="max-autotune")
             result = compiled_fn1(arg1, arg2).clone()
             self.assertEqual(expected, result)
-            self.assertEqual(global_stats.autotune_local.num_get_miss, 1)
+            self.assertEqual(global_stats.autotune_local, Stats(1, 0, 1))

             DynamoCache.clear()
             total_frames = torch._dynamo.convert_frame.FRAME_COUNTER
             self._save_and_reload(
                 expected_backends=1, expected_dynamo=1, expected_autotune=1
             )
-            # During save, we check the autotune cache another time, and now it should hit
-            self.assertEqual(global_stats.autotune_local.num_get_hit, 1)
             compiled_fn1 = torch.compile(fn, mode="max-autotune")
             with torch.compiler.set_stance("fail_on_recompile"):
                 result1 = compiled_fn1(arg1, arg2).clone()
                 self.assertEqual(expected, result1)
             self.assertEqual(torch._dynamo.convert_frame.FRAME_COUNTER, total_frames)
-            # No new hits or misses
-            # Unfortunately, we don't *actually* know how many puts there will be, because
-            # it's possible the best autotune config was found by coordesc.
-            self.assertEqual(global_stats.autotune_local.num_get_hit, 1)
-            self.assertEqual(global_stats.autotune_local.num_get_miss, 1)
+            self.assertEqual(global_stats.autotune_local, Stats(2, 1, 1))

     @parametrize("device", ("cpu", "cuda", "xpu"))
     @torch._dynamo.config.patch(caching_precompile=True)
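Reviewer note: the `Stats(...)` literals above are positional, which keeps the assertions terse but hides the field names. A minimal sketch of how they read, assuming `torch._inductor.mock_cache.Stats` is a dataclass with fields in the order `(num_put, num_get_hit, num_get_miss)` (the class below is an illustrative stand-in, not an import of the real one):

```python
from dataclasses import dataclass

# Stand-in mirroring the assumed field order of mock_cache.Stats.
@dataclass
class Stats:
    num_put: int = 0
    num_get_hit: int = 0
    num_get_miss: int = 0

# First compile: one autotune-cache miss, then the chosen config is put.
assert Stats(1, 0, 1) == Stats(num_put=1, num_get_hit=0, num_get_miss=1)
# After save/reload and recompile: one extra put and one hit, same single miss.
assert Stats(2, 1, 1) == Stats(num_put=2, num_get_hit=1, num_get_miss=1)
```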
diff --git a/torch/_dynamo/precompile_context.py b/torch/_dynamo/precompile_context.py
index 55fb5dbbda06..38f97e583375 100644
--- a/torch/_dynamo/precompile_context.py
+++ b/torch/_dynamo/precompile_context.py
@@ -169,16 +169,7 @@ class PrecompileContext(CacheArtifactManager):
         by artifact type.
         This function transfers artifacts from _new_cache_artifacts_by_key to _new_cache_artifacts
         """
         for artifact in cls._new_cache_artifacts_by_key.values():
-            from torch._functorch._aot_autograd.autograd_cache import (
-                BundledAOTAutogradCacheEntry,
-            )
-
             if isinstance(artifact, EditablePrecompileCacheArtifact):
-                if isinstance(artifact.content, BundledAOTAutogradCacheEntry):
-                    # BundledAOTAutogradCacheEntries should update their autotune results
-                    artifact.edit_contents(
-                        BundledAOTAutogradCacheEntry.update_autotune_results
-                    )
                 artifact = artifact.real_encode()
             cls._new_cache_artifacts[artifact.__class__.type()].append(artifact)
         cls._new_cache_artifacts_by_key.clear()
@@ -204,15 +195,6 @@ class PrecompileContext(CacheArtifactManager):
         """
         result = cls._new_cache_artifacts_by_key.get(key, None)
         if isinstance(result, EditablePrecompileCacheArtifact):
-            from torch._functorch._aot_autograd.autograd_cache import (
-                BundledAOTAutogradCacheEntry,
-            )
-
-            if isinstance(result.content, BundledAOTAutogradCacheEntry):
-                # BundledAOTAutogradCacheEntries should update their autotune results
-                result.edit_contents(
-                    BundledAOTAutogradCacheEntry.update_autotune_results
-                )
             result = result.real_encode()
         return result
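Reviewer note: with the `BundledAOTAutogradCacheEntry` special case gone, the transfer loop reduces to "encode editable artifacts, bucket by type". A minimal sketch of that flow under stand-in types (`FakeEditableArtifact` and the two dicts are illustrative, not the real torch classes, and every artifact here happens to be editable):

```python
from collections import defaultdict

class FakeEditableArtifact:
    """Stand-in for EditablePrecompileCacheArtifact."""
    @classmethod
    def type(cls):
        return "fake"

    def real_encode(self):
        # The real method pickles the artifact's content; identity suffices here.
        return self

by_key = {"some_key": FakeEditableArtifact()}
by_type = defaultdict(list)

for artifact in by_key.values():
    # No per-type edit pass anymore: encoding is all that happens.
    artifact = artifact.real_encode()
    by_type[artifact.__class__.type()].append(artifact)
by_key.clear()
```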
- """ - entry.compiled_fw.recheck_autotune_results() - if entry.compiled_bw is not None: - entry.compiled_bw.recheck_autotune_results() - return entry - @contextlib.contextmanager def sanitize_gm_for_cache(gm: torch.fx.GraphModule): diff --git a/torch/_inductor/async_compile.py b/torch/_inductor/async_compile.py index ee365b35c5e6..09bf4b1c9e28 100644 --- a/torch/_inductor/async_compile.py +++ b/torch/_inductor/async_compile.py @@ -401,9 +401,11 @@ class AsyncCompile: if (future := CompiledTritonKernels.get(source_code)) is not None: counters["inductor"]["async_compile_cache_hit"] += 1 + # Set reload_kernel_from_src properly based on source_code if isinstance(future, StaticAutotunerFuture): # Remove the future now that we've cache hit CompiledTritonKernels.remove_future(source_code) + future.reload_kernel_from_src = reload_kernel_in_parent if is_parallel: return future else: @@ -457,7 +459,7 @@ class AsyncCompile: kernel.precompile( warm_cache_only=False, reload_kernel=reload_kernel_in_parent, - source_code=source_code, + static_triton_bundle_key=CompiledTritonKernels.key(source_code), ) info = kernel.autotune_cache_info or {} info["compile_time_us"] = elapsed_us @@ -486,7 +488,7 @@ class AsyncCompile: kernel.set_compile_info(compile_id, is_backward) kernel.precompile( warm_cache_only=False, - source_code=source_code, + static_triton_bundle_key=CompiledTritonKernels.key(source_code), ) elapsed_us = (time_ns() - start_ns) // 1000 get_metrics_context().add_top_n( diff --git a/torch/_inductor/codecache.py b/torch/_inductor/codecache.py index 6d829b95cc3c..312dc2aaeb0c 100644 --- a/torch/_inductor/codecache.py +++ b/torch/_inductor/codecache.py @@ -4213,28 +4213,24 @@ class StaticAutotunerFuture(CodeCacheFuture): A statically launchable CachingAutotuner, loaded from TritonBundler """ - def __init__( - self, static_autotuner: CachingAutotuner, kernel_name: str, source_code: str - ) -> None: + def __init__(self, static_autotuner: CachingAutotuner) -> None: # Pickled version of CachingAutotuner self.static_autotuner = static_autotuner - self.kernel_name = kernel_name - # The python source code of the kernel is relatively small and stored by StaticallyLaunchedAutotuner. - # We do not store the compiled cuda code here as it's very large, - # it's stored via the regular TritonBundler - self.source_code = source_code + # This needs to be set in AsyncCompile.triton, in case + # we need to reload the CachingAutotuner from its source code + # We don't store the source code on the CachingAutotuner itself + # since it can be very large. 
diff --git a/torch/_inductor/codecache.py b/torch/_inductor/codecache.py
index 6d829b95cc3c..312dc2aaeb0c 100644
--- a/torch/_inductor/codecache.py
+++ b/torch/_inductor/codecache.py
@@ -4213,28 +4213,24 @@ class StaticAutotunerFuture(CodeCacheFuture):
     A statically launchable CachingAutotuner, loaded from TritonBundler
     """

-    def __init__(
-        self, static_autotuner: CachingAutotuner, kernel_name: str, source_code: str
-    ) -> None:
+    def __init__(self, static_autotuner: CachingAutotuner) -> None:
         # Pickled version of CachingAutotuner
         self.static_autotuner = static_autotuner
-        self.kernel_name = kernel_name
-        # The python source code of the kernel is relatively small and stored by StaticallyLaunchedAutotuner.
-        # We do not store the compiled cuda code here as it's very large,
-        # it's stored via the regular TritonBundler
-        self.source_code = source_code
+        # This needs to be set in AsyncCompile.triton, in case we need to
+        # reload the CachingAutotuner from its source code. We don't store
+        # the source code on the CachingAutotuner itself, since it can be
+        # very large.
+        self.reload_kernel_from_src: Optional[Callable[[], Any]] = None

     def result(self) -> CachingAutotuner:
+        assert self.reload_kernel_from_src is not None
         with dynamo_timed("StaticAutotunerFuture.warm_precompile"):
-            reload_kernel_from_src = functools.partial(
-                _load_triton_kernel_from_source, self.kernel_name, self.source_code
-            )
             self.static_autotuner.recheck_autotune_cache(
-                reload_kernel_from_src=reload_kernel_from_src
+                reload_kernel_from_src=self.reload_kernel_from_src
             )
             self.static_autotuner.precompile(  # type: ignore[union-attr]
                 warm_cache_only=False,
-                reload_kernel=reload_kernel_from_src,
-                source_code=None,  # no need to save again
+                reload_kernel=self.reload_kernel_from_src,
+                static_triton_bundle_key=None,  # no need to save again
             )
             return self.static_autotuner
diff --git a/torch/_inductor/runtime/triton_heuristics.py b/torch/_inductor/runtime/triton_heuristics.py
index 76a811b72d36..d9e3d6734449 100644
--- a/torch/_inductor/runtime/triton_heuristics.py
+++ b/torch/_inductor/runtime/triton_heuristics.py
@@ -386,14 +386,13 @@ class CachingAutotuner(KernelInterface):

         assert self.is_statically_launchable()
         configs = [result.config for result in self.compile_results]
-        if len(configs) <= 1:
-            return
+
         (cached_configs, _, autotune_cache_info) = check_autotune_cache(
             configs, self.filename, self.inductor_meta
         )
         self.autotune_cache_info = autotune_cache_info
         # I.e. there was an autotune cache hit
-        if len(cached_configs) == 1:
+        if len(cached_configs) == 1 and len(configs) > 1:
             best_config = cached_configs[0]
             # Grab the best compiled config, if it's in the list of available ones
             best_config_hash = triton_config_to_hashable(best_config)
@@ -422,7 +421,7 @@ class CachingAutotuner(KernelInterface):
         self,
         warm_cache_only=False,
         reload_kernel: Optional[Callable[[], CachingAutotuner]] = None,
-        source_code: Optional[str] = None,  # Used for static_triton_bundle_key
+        static_triton_bundle_key: Optional[str] = None,
     ):
         if warm_cache_only:
             self._precompile_worker()
@@ -435,9 +434,8 @@ class CachingAutotuner(KernelInterface):
         if reload_kernel is not None:
             self._reload_kernel = reload_kernel
         self._precompile_worker()
-
-        if source_code is not None and self.is_statically_launchable():
-            TritonBundler.put_static_autotuner(source_code, self)
+        if static_triton_bundle_key is not None and self.is_statically_launchable():
+            TritonBundler.put_static_autotuner(static_triton_bundle_key, self)
         self._make_launchers()
         self._dynamic_scale_rblock()
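Reviewer note: two invariants are worth calling out here. First, `result()` must not run before `AsyncCompile.triton` has set `reload_kernel_from_src` (hence the new assert). Second, `recheck_autotune_cache` now runs even for single-config kernels, so `autotune_cache_info` is always populated; only the prune-to-best-config path keeps the `len(configs) > 1` guard. A minimal sketch of the load-time ordering this relies on (`load_static_autotuner` and `future` are hypothetical stand-ins, not the real API):

```python
def load_static_autotuner(future):
    # Invariant: AsyncCompile.triton wired the reload callback in before result().
    assert future.reload_kernel_from_src is not None
    autotuner = future.static_autotuner
    # 1) Consult the autotune cache; with >1 configs this may prune to the best one.
    autotuner.recheck_autotune_cache(
        reload_kernel_from_src=future.reload_kernel_from_src
    )
    # 2) Precompile; key=None because the kernel is already in the bundle.
    autotuner.precompile(
        warm_cache_only=False,
        reload_kernel=future.reload_kernel_from_src,
        static_triton_bundle_key=None,
    )
    return autotuner
```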
diff --git a/torch/_inductor/triton_bundler.py b/torch/_inductor/triton_bundler.py
index 79962b60ca1c..b5ccb873e33f 100644
--- a/torch/_inductor/triton_bundler.py
+++ b/torch/_inductor/triton_bundler.py
@@ -53,11 +53,7 @@ class StaticallyLaunchedAutotuner:
     Statically saved here have their cubin files saved by a corresponding TritonBundleEntry.
     """

-    # We store the kernel's python source code here which we use for two things:
-    # First, to calculate a cache key for CompiledTritonKernels
-    # Second, in case we need to reload the kernel on load,
-    # we can do so by reading the source code from the cache entry.
-    source_code: str
+    cache_key: str
     kernel_name: str
     kernel: "CachingAutotuner"  # type: ignore[name-defined]  # noqa: F821
@@ -168,7 +164,7 @@ class TritonBundler:
             )

     @classmethod
-    def put_static_autotuner(cls, source_code: str, kernel: "CachingAutotuner") -> None:  # type: ignore[name-defined]  # noqa: F821
+    def put_static_autotuner(cls, key: str, kernel: "CachingAutotuner") -> None:  # type: ignore[name-defined]  # noqa: F821
         from torch._inductor import config

         assert config.use_static_cuda_launcher
@@ -182,7 +178,7 @@ class TritonBundler:
         entries.append(
             StaticallyLaunchedAutotuner(
-                source_code,
+                key,
                 new_kernel.inductor_meta.get("kernel_name", "unknown_kernel"),
                 new_kernel,
             )
@@ -244,9 +240,8 @@ class TritonBundler:
                 # kernels that are not statically launchable (i.e. cache miss)
                 # can launch a worker without waiting on the blocking step of
                 # StaticAutotunerFuture.result().
-                cache_key = CompiledTritonKernels.key(result.source_code)
-                CompiledTritonKernels._cache[cache_key] = StaticAutotunerFuture(
-                    result.kernel, result.kernel_name, result.source_code
+                CompiledTritonKernels._cache[result.cache_key] = StaticAutotunerFuture(
+                    result.kernel
                 )
                 counters["inductor"]["triton_bundler_load_static_autotuner"] += 1
                 kernel_names.append(result.kernel_name)
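Reviewer note: storing the precomputed `cache_key` means deserialization no longer needs the kernel source at all. A minimal sketch of the resulting load path under stand-in types (`StaticEntry` and `load_static_autotuners` are illustrative, not the real classes):

```python
from dataclasses import dataclass
from typing import Any

@dataclass
class StaticEntry:
    cache_key: str    # precomputed as CompiledTritonKernels.key(source_code) at put time
    kernel_name: str
    kernel: Any

def load_static_autotuners(entries, future_cache):
    for entry in entries:
        # Key directly with the stored value; no re-hash of source code on load.
        future_cache[entry.cache_key] = entry.kernel  # real code wraps this in StaticAutotunerFuture
    return [entry.kernel_name for entry in entries]
```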