mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[Inductor] Record Triton’s Base32 Cache Key in .best_config for Debugging (#154618)
This is a follow-up to the reverted PR https://github.com/pytorch/pytorch/pull/148981, re-opened for visibility: it modifies TorchInductor’s autotuning flow so that each best_config JSON file also includes the Triton “base32” (or base64) cache key. Motivation — Debugging & Analysis: With this change, we can quickly identify which compiled binary and IRs belong to a given best config. The impact is minimal since it is only an extra field in .best_config. It can help with advanced performance tuning or kernel-level debugging. Also, since Triton already stores cubin/hsaco in its cache, developers/researchers can avoid setting store_cubin = True, because they can get the cubin/hsaco from the Triton cache, and with the code provided in this PR they can easily match the best_config with the right Triton cache directory for the "best" kernel. Pull Request resolved: https://github.com/pytorch/pytorch/pull/154618 Approved by: https://github.com/jansel
This commit is contained in:
committed by
PyTorch MergeBot
parent
d6edefefbf
commit
f57754e815
96
test/inductor/test_best_config.py
Normal file
96
test/inductor/test_best_config.py
Normal file
@ -0,0 +1,96 @@
|
||||
# Owner(s): ["module: inductor"]
|
||||
|
||||
import glob
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
from torch._inductor import config
|
||||
from torch.testing._internal.common_utils import IS_LINUX, skipIfXpu
|
||||
from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_GPU
|
||||
|
||||
|
||||
try:
|
||||
import triton # noqa: F401
|
||||
except ImportError as e:
|
||||
if __name__ == "__main__":
|
||||
sys.exit(0)
|
||||
raise unittest.SkipTest("requires triton") from e
|
||||
|
||||
from torch._inductor.test_case import run_tests, TestCase
|
||||
|
||||
|
||||
def trivial_kernel(x):
    """Return the elementwise sum ``sin(x) + cos(x)`` of the input tensor."""
    sine = torch.sin(x)
    cosine = torch.cos(x)
    return sine + cosine
|
||||
|
||||
|
||||
class TestKernelBestConfig(TestCase):
    """Check that autotuning writes a usable ``triton_cache_hash`` to .best_config.

    The test compiles a small kernel with ``max_autotune`` enabled, then verifies
    that every ``*.best_config`` file found under the inductor cache directory
    carries a ``triton_cache_hash`` entry naming an existing path inside the
    Triton cache directory.
    """

    device_type = GPU_TYPE

    @classmethod
    def setUpClass(cls):
        # Save the original configuration and environment variables.
        cls.original_compile_threads = config.compile_threads
        cls.original_max_autotune = config.max_autotune
        # Keep None (rather than "") for variables that were not set, so that
        # tearDownClass can tell "unset" apart from "set to an empty string".
        cls.original_inductor_env = os.environ.get("TORCHINDUCTOR_CACHE_DIR")
        cls.original_triton_env = os.environ.get("TRITON_CACHE_DIR")
        super().setUpClass()

    @staticmethod
    def _restore_env_var(name, original):
        """Put environment variable ``name`` back to ``original``.

        ``original`` is None when the variable did not exist beforehand; in
        that case the variable is removed instead of being left behind as an
        empty string.
        """
        if original is None:
            os.environ.pop(name, None)
        else:
            os.environ[name] = original

    @classmethod
    def tearDownClass(cls):
        # Restore the original configuration and environment variables.
        cls._restore_env_var("TORCHINDUCTOR_CACHE_DIR", cls.original_inductor_env)
        cls._restore_env_var("TRITON_CACHE_DIR", cls.original_triton_env)
        config.compile_threads = cls.original_compile_threads
        config.max_autotune = cls.original_max_autotune
        super().tearDownClass()

    @skipIfXpu
    def test_best_config_has_triton_cache_key(self):
        """Each .best_config must reference an existing Triton cache entry."""
        with tempfile.TemporaryDirectory() as tmpdir:
            # Point both caches at a throwaway directory so that only files
            # produced by this test are inspected below.
            os.environ["TORCHINDUCTOR_CACHE_DIR"] = tmpdir
            triton_cache_dir = os.path.join(tmpdir, "triton_cache")
            os.environ["TRITON_CACHE_DIR"] = triton_cache_dir

            config.compile_threads = 0
            config.max_autotune = True

            compiled_fn = torch.compile(trivial_kernel)

            x = torch.randn(32, 10, device=GPU_TYPE)
            compiled_fn(x)

            # Search for .best_config files in the inductor cache directory.
            best_config_files = glob.glob(
                os.path.join(tmpdir, "**", "*.best_config"), recursive=True
            )
            self.assertGreater(
                len(best_config_files),
                0,
                f"No best_config files found in {tmpdir}. Directory contents: {os.listdir(tmpdir)}",
            )

            # Validate that each best_config file contains a real triton_cache_hash,
            # and that a corresponding Triton cache directory exists.
            for file_path in best_config_files:
                with open(file_path) as f:
                    data = json.load(f)
                self.assertIn(
                    "triton_cache_hash",
                    data,
                    f"Missing triton_cache_hash in {os.path.basename(file_path)}",
                )
                cache_hash = data["triton_cache_hash"]
                # A null hash would otherwise surface as a TypeError from
                # os.path.join; fail with an explicit message instead.
                self.assertIsInstance(
                    cache_hash,
                    str,
                    f"triton_cache_hash is not a string in {os.path.basename(file_path)}",
                )
                expected_path = os.path.join(triton_cache_dir, cache_hash)
                self.assertTrue(
                    os.path.exists(expected_path),
                    f"Triton cache directory missing: {expected_path}",
                )
|
||||
|
||||
|
||||
# Run the suite only when executed as a script on a Linux host with a GPU.
if __name__ == "__main__" and IS_LINUX and HAS_GPU:
    run_tests()
|
@ -213,6 +213,7 @@ S390X_BLOCKLIST = [
|
||||
"test_unary_ufuncs",
|
||||
# these tests fail when cuda is not available
|
||||
"inductor/test_aot_inductor",
|
||||
"inductor/test_best_config",
|
||||
"inductor/test_cudacodecache",
|
||||
"inductor/test_inductor_utils",
|
||||
"inductor/test_inplacing_pass",
|
||||
|
@ -242,7 +242,11 @@ class AutotuneCache:
|
||||
|
||||
# Save the config in the caches
|
||||
def save(
|
||||
self, config: Config, time_taken_ns: int, found_by_coordesc: bool = False
|
||||
self,
|
||||
config: Config,
|
||||
time_taken_ns: int,
|
||||
found_by_coordesc: bool = False,
|
||||
triton_cache_hash: Optional[str] = None,
|
||||
) -> None:
|
||||
data = {
|
||||
**config.kwargs,
|
||||
@ -251,6 +255,7 @@ class AutotuneCache:
|
||||
"configs_hash": self.configs_hash,
|
||||
"found_by_coordesc": found_by_coordesc,
|
||||
"time_taken_ms": time_taken_ns // 1000000, # Convert from NS to MS
|
||||
"triton_cache_hash": triton_cache_hash,
|
||||
}
|
||||
if HAS_WARP_SPEC:
|
||||
data.update(
|
||||
@ -514,6 +519,8 @@ def _load_cached_autotuning(
|
||||
# Remove time taken for comparison
|
||||
best_config.pop("time_taken_ms", None)
|
||||
|
||||
best_config.pop("triton_cache_hash", None)
|
||||
|
||||
if inductor_meta.get("coordinate_descent_tuning") and best_config.pop(
|
||||
"found_by_coordesc", False
|
||||
):
|
||||
|
@ -969,7 +969,11 @@ class CachingAutotuner(KernelInterface):
|
||||
)
|
||||
|
||||
if self.save_cache_hook:
|
||||
self.save_cache_hook(launcher.config, self.autotune_time_taken_ns)
|
||||
self.save_cache_hook(
|
||||
launcher.config,
|
||||
self.autotune_time_taken_ns,
|
||||
triton_cache_hash=launcher.cache_hash,
|
||||
)
|
||||
|
||||
def save_gpu_kernel(self, stream, launcher):
|
||||
key = self.inductor_meta.get("kernel_name", None) # unique kernel name
|
||||
@ -1432,6 +1436,7 @@ class StaticTritonCompileResult(CompileResult[StaticallyLaunchedCudaKernel]):
|
||||
launcher.n_regs = self.kernel.n_regs # type: ignore[attr-defined]
|
||||
launcher.n_spills = self.kernel.n_spills # type: ignore[attr-defined]
|
||||
launcher.shared = self.kernel.shared # type: ignore[attr-defined]
|
||||
launcher.cache_hash = triton_hash_to_path_key(self.kernel.hash) # type: ignore[attr-defined]
|
||||
launcher.store_cubin = False # type: ignore[attr-defined]
|
||||
launcher._is_static = True # type: ignore[attr-defined]
|
||||
return launcher
|
||||
@ -1620,6 +1625,7 @@ class TritonCompileResult(CompileResult[CompiledKernel]):
|
||||
launcher.n_regs = getattr(binary, "n_regs", None)
|
||||
launcher.n_spills = getattr(binary, "n_spills", None)
|
||||
launcher.shared = binary_shared
|
||||
launcher.cache_hash = triton_hash_to_path_key(binary.hash)
|
||||
launcher.store_cubin = self.inductor_meta.get("store_cubin", False)
|
||||
# store this global variable to avoid the high overhead of reading it when calling run
|
||||
if launcher.store_cubin:
|
||||
|
Reference in New Issue
Block a user