Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 21:14:14 +08:00
Still run TritonBundler with BundledAOTAutogradCache, save autotune results (#158048)
When running BundledAOTAutogradCache with precompile, we still need to run Triton bundling so that the precompiled CompiledFxGraph contains the Triton CUDA kernels. We also pre-save the autotune results in the precompile artifact. It would be even better to pre-trim the CUDA kernels on save and apply them, which we can work on later.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/158048
Approved by: https://github.com/zhxchen17
Committed by PyTorch MergeBot · parent d5a29fc58a · commit 8e57cdb746
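
For context, a minimal sketch of the user-visible path this change affects, adapted from the new test below; it assumes a CUDA (or XPU) device with Triton available and uses only the knobs the test itself exercises (caching_precompile and mode="max-autotune"):

import torch

def fn(x, y):
    return x.sin() + y

# Sketch only: with caching_precompile enabled, the precompile artifact is
# expected to carry both the bundled Triton kernels and the autotune results,
# which is what this PR wires up.
with torch._dynamo.config.patch(caching_precompile=True):
    compiled = torch.compile(fn, mode="max-autotune")
    x = torch.randn(3, 3, device="cuda")
    y = torch.randn(3, 3, device="cuda")
    out = compiled(x, y)
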
@@ -15,6 +15,7 @@ import torch.utils.cpp_extension
 from torch._dynamo.package import CompilePackage, DiskDynamoStore, DynamoCache
 from torch._dynamo.precompile_context import PrecompileContext
 from torch._functorch import config as functorch_config
+from torch._inductor.mock_cache import global_stats, PatchCaches, Stats
 from torch._inductor.runtime.runtime_utils import cache_dir
 from torch.testing._internal.common_utils import (
     instantiate_parametrized_tests,
@@ -428,6 +429,39 @@ def add(x, y):
         self.assertEqual(expected, [result1, result2])
         self.assertEqual(torch._dynamo.convert_frame.FRAME_COUNTER, total_frames)
 
+    @parametrize("device", ("cuda", "xpu"))
+    @torch._dynamo.config.patch(caching_precompile=True)
+    def test_automatic_dynamo_autotune_cache(self, device):
+        if device == "cuda" and not HAS_CUDA:
+            raise unittest.SkipTest("Requires CUDA/Triton")
+        if device == "xpu" and not HAS_XPU:
+            raise unittest.SkipTest("Requires XPU/Triton")
+
+        def fn(x, y):
+            return x.sin() + y
+
+        arg1 = torch.randn(3, 3, device=device)
+        arg2 = torch.randn(3, 3, device=device)
+        expected = fn(arg1, arg2).clone()
+
+        with PatchCaches():
+            compiled_fn1 = torch.compile(fn, mode="max-autotune")
+            result = compiled_fn1(arg1, arg2).clone()
+            self.assertEqual(expected, result)
+            self.assertEqual(global_stats.autotune_local, Stats(1, 0, 1))
+            DynamoCache.clear()
+
+            total_frames = torch._dynamo.convert_frame.FRAME_COUNTER
+            self._save_and_reload(
+                expected_backends=1, expected_dynamo=1, expected_autotune=1
+            )
+            compiled_fn1 = torch.compile(fn, mode="max-autotune")
+            with torch.compiler.set_stance("fail_on_recompile"):
+                result1 = compiled_fn1(arg1, arg2).clone()
+            self.assertEqual(expected, result1)
+            self.assertEqual(torch._dynamo.convert_frame.FRAME_COUNTER, total_frames)
+            self.assertEqual(global_stats.autotune_local, Stats(2, 1, 1))
+
     @parametrize("device", ("cpu", "cuda", "xpu"))
     @torch._dynamo.config.patch(caching_precompile=True)
     def test_automatic_dynamo_recompiles(self, device):
@@ -70,7 +70,8 @@ class PrecompileContext(CacheArtifactManager):
 
     The following artifact types are supported by PrecompileContext:
      - BundledAOTAutogradCacheArtifact
-     - CodeStateArtifact (from torch._dynamo.package once available)
+     - DynamoCodeStateArtifact
+     - AutotuneCacheArtifact (regular autotune results, same as Megacache)
     """
 
     # Protected by the compile_lock
@@ -149,8 +150,12 @@ class PrecompileContext(CacheArtifactManager):
         artifacts_by_key = {}
         cache_info = CacheInfo()
         for artifact in chain(*artifacts.values()):
+            if artifact.type() == "autotune":
+                # Populate autotune cache artifacts
+                artifact.populate_cache()
+            else:
+                artifacts_by_key[artifact.key] = artifact
             cache_info.add(artifact)
-            artifacts_by_key[artifact.key] = artifact
 
         from torch._dynamo.package import _BackendId, DynamoCache
 
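
For orientation, the dispatch added to populate_caches above can be summarized as: autotune artifacts are replayed into the autotune cache immediately via populate_cache(), while every other artifact type stays keyed for precompile lookup. A schematic stand-in for that control flow (toy Protocol, not the real classes):

from typing import Protocol

class Artifact(Protocol):
    key: str
    def type(self) -> str: ...
    def populate_cache(self) -> None: ...

def split_artifacts(artifacts: list[Artifact]) -> dict[str, Artifact]:
    # Autotune entries are written straight into the autotune cache;
    # everything else is kept keyed so precompile can look it up later.
    artifacts_by_key: dict[str, Artifact] = {}
    for artifact in artifacts:
        if artifact.type() == "autotune":
            artifact.populate_cache()
        else:
            artifacts_by_key[artifact.key] = artifact
    return artifacts_by_key
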
@@ -909,10 +909,37 @@ def _compile_fx_inner(
         else:
             log.debug("Failed to generate FX cache key")
 
+    if torch._functorch.config.bundled_autograd_cache:
+        assert mb_compiled_graph is None
+        assert cache_info is None
+        # When using bundled autograd cache, we still want
+        # to use the TritonBundler, but we don't want to save
+        # the results here. The results will get saved directly
+        # to AOTAutogradCache.
+        TritonBundler.begin_compile()
+        try:
+            mb_compiled_graph = fx_codegen_and_compile(
+                gm, example_inputs, inputs_to_check, **graph_kwargs
+            )
+            assert mb_compiled_graph is not None
+            (
+                triton_bundle,
+                triton_bundler_meta,
+            ) = TritonBundler.collect()
+            mb_compiled_graph.set_triton_bundle(triton_bundle)
+        except (ShortenTraceback, SkipFrame):
+            raise
+        except Exception as e:
+            raise InductorError(e, currentframe()).with_traceback(
+                e.__traceback__
+            ) from None
+        finally:
+            TritonBundler.end_compile()
+
     # CACHE BYPASS: Compile the graph, don't save it to the cache
     # (this can happen either because cache was disabled, or we
     # determined the input is uncacheable)
-    if cache_info is None or cache_info["cache_state"] == "bypass":
+    elif cache_info is None or cache_info["cache_state"] == "bypass":
         assert mb_compiled_graph is None
         log.debug(
             "FX cache bypass reason: %s",
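
Condensed, the bundling lifecycle the new bundled_autograd_cache branch follows is begin_compile, compile, collect, set_triton_bundle, end_compile. A sketch under assumptions: compile_one is a hypothetical stand-in for fx_codegen_and_compile, the exception translation is elided, and the TritonBundler import path is assumed:

from torch._inductor.triton_bundler import TritonBundler  # assumed import path

def compile_with_bundling(compile_one):
    # Start collecting Triton kernels for this compile.
    TritonBundler.begin_compile()
    try:
        compiled_graph = compile_one()
        # Attach the collected kernels to the graph so they travel with the
        # precompile artifact instead of being saved to the FX graph cache here.
        triton_bundle, _triton_bundler_meta = TritonBundler.collect()
        compiled_graph.set_triton_bundle(triton_bundle)
        return compiled_graph
    finally:
        TritonBundler.end_compile()
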
@@ -35,6 +35,7 @@ from typing import Any, Optional, TYPE_CHECKING
 from typing_extensions import override
 
 import torch
+from torch._dynamo.precompile_context import PrecompileContext
 from torch._inductor.runtime.runtime_utils import cache_dir
 from torch.compiler._cache import (
     CacheArtifact,
@@ -125,6 +126,7 @@ class AutotuneCache:
     ) -> Optional[AutotuneCache]:
         cache = AutotuneCache(configs_hash)
         key = AutotuneCache._prepare_key(filename)
 
         cache._setup_local_cache(inductor_meta, os.path.dirname(filename), key)
         cache._setup_remote_autotune_cache(inductor_meta, key)
         if cache.local_cache or cache.remote_cache:
@@ -300,6 +302,10 @@ class AutotuneCache:
         CacheArtifactManager.record_artifact(
             AutotuneCacheArtifact.type(), autotune_artifact_key, data
         )
+        if torch._dynamo.config.caching_precompile:
+            PrecompileContext.record_artifact(
+                AutotuneCacheArtifact.type(), autotune_artifact_key, data
+            )
 
         if log.isEnabledFor(logging.DEBUG):
             type_str = "coordesc" if found_by_coordesc else "heuristic"
@@ -625,6 +631,10 @@ class LocalAutotuneCache(RemoteCache[JsonDataTy]):
         CacheArtifactManager.record_artifact(
             AutotuneCacheArtifact.type(), autotune_artifact_key, result
         )
+        if torch._dynamo.config.caching_precompile:
+            PrecompileContext.record_artifact(
+                AutotuneCacheArtifact.type(), autotune_artifact_key, result
+            )
         return result
 
     @override
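
The two hunks above apply the same pattern in two places. Factored into one hypothetical helper for illustration (record_autotune_entry does not exist in the tree, and the import locations of AutotuneCacheArtifact and CacheArtifactManager are assumptions):

import torch
from torch._dynamo.precompile_context import PrecompileContext
from torch._inductor.runtime.autotune_cache import AutotuneCacheArtifact  # assumed location
from torch.compiler._cache import CacheArtifactManager  # assumed location

def record_autotune_entry(autotune_artifact_key: str, payload) -> None:
    # Every autotune result is recorded for Megacache...
    CacheArtifactManager.record_artifact(
        AutotuneCacheArtifact.type(), autotune_artifact_key, payload
    )
    # ...and, when precompile caching is on, also into the precompile artifact
    # so PrecompileContext.populate_caches can replay it on load.
    if torch._dynamo.config.caching_precompile:
        PrecompileContext.record_artifact(
            AutotuneCacheArtifact.type(), autotune_artifact_key, payload
        )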