Still run TritonBundler with BundledAOTAutogradCache, save autotune results (#158048)

When running BundledAOTAutogradCache with precompile, we still need to run Triton bundling so that the precompiled CompiledFxGraph contains its Triton CUDA kernels. We also pre-save the autotune results in the precompile artifact.

It would be even better to pre-trim the CUDA kernels on save and then apply them; we can work on that later.
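
For reference, a minimal sketch of the path this exercises, mirroring the new test in this PR (assumes a CUDA device with Triton available; the function and shapes are arbitrary):

    import torch

    def fn(x, y):
        return x.sin() + y

    x = torch.randn(3, 3, device="cuda")
    y = torch.randn(3, 3, device="cuda")

    # With caching_precompile enabled, max-autotune results are recorded into the
    # precompile artifact in addition to the usual autotune caches.
    with torch._dynamo.config.patch(caching_precompile=True):
        compiled = torch.compile(fn, mode="max-autotune")
        out = compiled(x, y)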

Pull Request resolved: https://github.com/pytorch/pytorch/pull/158048
Approved by: https://github.com/zhxchen17
James Wu authored on 2025-07-18 10:05:09 -07:00; committed by PyTorch MergeBot
parent d5a29fc58a
commit 8e57cdb746
4 changed files with 79 additions and 3 deletions


@ -15,6 +15,7 @@ import torch.utils.cpp_extension
from torch._dynamo.package import CompilePackage, DiskDynamoStore, DynamoCache
from torch._dynamo.precompile_context import PrecompileContext
from torch._functorch import config as functorch_config
from torch._inductor.mock_cache import global_stats, PatchCaches, Stats
from torch._inductor.runtime.runtime_utils import cache_dir
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
@ -428,6 +429,39 @@ def add(x, y):
        self.assertEqual(expected, [result1, result2])
        self.assertEqual(torch._dynamo.convert_frame.FRAME_COUNTER, total_frames)

    @parametrize("device", ("cuda", "xpu"))
    @torch._dynamo.config.patch(caching_precompile=True)
    def test_automatic_dynamo_autotune_cache(self, device):
        if device == "cuda" and not HAS_CUDA:
            raise unittest.SkipTest("Requires CUDA/Triton")
        if device == "xpu" and not HAS_XPU:
            raise unittest.SkipTest("Requires XPU/Triton")

        def fn(x, y):
            return x.sin() + y

        arg1 = torch.randn(3, 3, device=device)
        arg2 = torch.randn(3, 3, device=device)
        expected = fn(arg1, arg2).clone()

        with PatchCaches():
            compiled_fn1 = torch.compile(fn, mode="max-autotune")
            result = compiled_fn1(arg1, arg2).clone()
            self.assertEqual(expected, result)
            self.assertEqual(global_stats.autotune_local, Stats(1, 0, 1))

            DynamoCache.clear()
            total_frames = torch._dynamo.convert_frame.FRAME_COUNTER
            self._save_and_reload(
                expected_backends=1, expected_dynamo=1, expected_autotune=1
            )
            compiled_fn1 = torch.compile(fn, mode="max-autotune")
            with torch.compiler.set_stance("fail_on_recompile"):
                result1 = compiled_fn1(arg1, arg2).clone()
            self.assertEqual(expected, result1)
            self.assertEqual(torch._dynamo.convert_frame.FRAME_COUNTER, total_frames)
            self.assertEqual(global_stats.autotune_local, Stats(2, 1, 1))

    @parametrize("device", ("cpu", "cuda", "xpu"))
    @torch._dynamo.config.patch(caching_precompile=True)
    def test_automatic_dynamo_recompiles(self, device):


@ -70,7 +70,8 @@ class PrecompileContext(CacheArtifactManager):
    The following artifact types are supported by PrecompileContext:
     - BundledAOTAutogradCacheArtifact
     - CodeStateArtifact (from torch._dynamo.package once available)
     - DynamoCodeStateArtifact
     - AutotuneCacheArtifact (regular autotune results, same as Megacache)
    """

    # Protected by the compile_lock
@ -149,8 +150,12 @@ class PrecompileContext(CacheArtifactManager):
        artifacts_by_key = {}
        cache_info = CacheInfo()
        for artifact in chain(*artifacts.values()):
            if artifact.type() == "autotune":
                # Populate autotune cache artifacts
                artifact.populate_cache()
            else:
                artifacts_by_key[artifact.key] = artifact
            cache_info.add(artifact)
            artifacts_by_key[artifact.key] = artifact

        from torch._dynamo.package import _BackendId, DynamoCache
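
Distilled, the loop above does the following (the helper name and the flat artifact list are illustrative; the real code also updates CacheInfo and lives inside PrecompileContext):

    def replay_or_keep(artifacts):
        # "autotune" artifacts are replayed into the autotune cache immediately;
        # everything else stays keyed so it can be attached to a precompile entry.
        artifacts_by_key = {}
        for artifact in artifacts:
            if artifact.type() == "autotune":
                artifact.populate_cache()
            else:
                artifacts_by_key[artifact.key] = artifact
        return artifacts_by_key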


@ -909,10 +909,37 @@ def _compile_fx_inner(
            else:
                log.debug("Failed to generate FX cache key")

        if torch._functorch.config.bundled_autograd_cache:
            assert mb_compiled_graph is None
            assert cache_info is None
            # When using bundled autograd cache, we still want
            # to use the TritonBundler, but we don't want to save
            # the results here. The results will get saved directly
            # to AOTAutogradCache.
            TritonBundler.begin_compile()
            try:
                mb_compiled_graph = fx_codegen_and_compile(
                    gm, example_inputs, inputs_to_check, **graph_kwargs
                )
                assert mb_compiled_graph is not None
                (
                    triton_bundle,
                    triton_bundler_meta,
                ) = TritonBundler.collect()
                mb_compiled_graph.set_triton_bundle(triton_bundle)
            except (ShortenTraceback, SkipFrame):
                raise
            except Exception as e:
                raise InductorError(e, currentframe()).with_traceback(
                    e.__traceback__
                ) from None
            finally:
                TritonBundler.end_compile()

        # CACHE BYPASS: Compile the graph, don't save it to the cache
        # (this can happen either because cache was disabled, or we
        # determined the input is uncacheable)
        if cache_info is None or cache_info["cache_state"] == "bypass":
        elif cache_info is None or cache_info["cache_state"] == "bypass":
            assert mb_compiled_graph is None
            log.debug(
                "FX cache bypass reason: %s",
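
The bundling pattern above, in isolation (compile_one stands in for fx_codegen_and_compile; the InductorError translation is omitted):

    from torch._inductor.triton_bundler import TritonBundler

    def compile_with_bundle(compile_one):
        # Collect Triton kernels compiled inside this window so they can be embedded
        # into the resulting CompiledFxGraph (and thus into the precompile artifact).
        TritonBundler.begin_compile()
        try:
            compiled_graph = compile_one()
            bundle, _meta = TritonBundler.collect()
            compiled_graph.set_triton_bundle(bundle)
            return compiled_graph
        finally:
            TritonBundler.end_compile()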

@ -35,6 +35,7 @@ from typing import Any, Optional, TYPE_CHECKING
from typing_extensions import override
import torch
from torch._dynamo.precompile_context import PrecompileContext
from torch._inductor.runtime.runtime_utils import cache_dir
from torch.compiler._cache import (
    CacheArtifact,
@ -125,6 +126,7 @@ class AutotuneCache:
    ) -> Optional[AutotuneCache]:
        cache = AutotuneCache(configs_hash)
        key = AutotuneCache._prepare_key(filename)

        cache._setup_local_cache(inductor_meta, os.path.dirname(filename), key)
        cache._setup_remote_autotune_cache(inductor_meta, key)
        if cache.local_cache or cache.remote_cache:
@ -300,6 +302,10 @@ class AutotuneCache:
            CacheArtifactManager.record_artifact(
                AutotuneCacheArtifact.type(), autotune_artifact_key, data
            )
            if torch._dynamo.config.caching_precompile:
                PrecompileContext.record_artifact(
                    AutotuneCacheArtifact.type(), autotune_artifact_key, data
                )

        if log.isEnabledFor(logging.DEBUG):
            type_str = "coordesc" if found_by_coordesc else "heuristic"
@ -625,6 +631,10 @@ class LocalAutotuneCache(RemoteCache[JsonDataTy]):
            CacheArtifactManager.record_artifact(
                AutotuneCacheArtifact.type(), autotune_artifact_key, result
            )
            if torch._dynamo.config.caching_precompile:
                PrecompileContext.record_artifact(
                    AutotuneCacheArtifact.type(), autotune_artifact_key, result
                )
        return result

    @override