Revert "Store statically launchable CachingAutotuners inside CompiledFXGraph.triton_bundle (#149054)"
This reverts commit c16af5d7984872b6ae81476d6cae64bddb7ce664. Reverted https://github.com/pytorch/pytorch/pull/149054 on behalf of https://github.com/jamesjwu due to Sorry I forgot to fix one last test ([comment](https://github.com/pytorch/pytorch/pull/149054#issuecomment-2761381443))
@@ -93,16 +93,12 @@ class TestFxGraphCache(TestCase):
@requires_triton()
@config.patch({"fx_graph_cache": True})
@config.patch({"fx_graph_remote_cache": False})
@config.patch({"compile_threads": 1})
@parametrize("device", (GPU_TYPE, "cpu"))
@parametrize("dtype", (torch.float32, torch.bfloat16))
@parametrize("dynamic", (False, True))
@parametrize("bundle_triton", (False, True))
@parametrize("use_static_cuda_launcher", (False, True))
@parametrize("grad", (False, True))
def test_cache_load_function(
self, device, dtype, dynamic, bundle_triton, use_static_cuda_launcher, grad
):
def test_cache_load_function(self, device, dtype, dynamic, bundle_triton, grad):
"""
Verify that we can populate and load functions from the cache.
"""
@@ -110,10 +106,6 @@ class TestFxGraphCache(TestCase):
raise unittest.SkipTest(f"requires {GPU_TYPE}")
if device == "cuda" and dtype == torch.bfloat16 and not SM80OrLater:
raise unittest.SkipTest("requires SM80 or later")
if use_static_cuda_launcher and not (device == "cuda" and bundle_triton):
raise unittest.SkipTest(
"Static cuda launcher requires cuda and triton bundling"
)

grad_multiplier = 2 if grad else 1

@@ -124,10 +116,7 @@ class TestFxGraphCache(TestCase):
a_orig = torch.rand(25, dtype=dtype, device=device)
b_orig = torch.rand(5, 5, dtype=dtype, device=device)

with config.patch(
bundle_triton_into_fx_graph_cache=bundle_triton,
use_static_cuda_launcher=use_static_cuda_launcher,
):
with config.patch(bundle_triton_into_fx_graph_cache=bundle_triton):
compiled_fn = torch.compile(fn, dynamic=dynamic)

a1 = a_orig.clone().requires_grad_(grad)
@@ -160,14 +149,6 @@ class TestFxGraphCache(TestCase):
self.assertEqual(
counters["inductor"]["triton_bundler_read_and_emit_kernel"], 0
)
if use_static_cuda_launcher:
self.assertEqual(
counters["inductor"]["triton_bundler_save_static_autotuner"],
grad_multiplier if device == "cuda" else 0,
)
self.assertEqual(
counters["inductor"]["triton_bundler_load_static_autotuner"], 0
)

# A second call should hit. (First reset so in-memory guards
# don't prevent compilation).
@@ -208,15 +189,6 @@ class TestFxGraphCache(TestCase):
counters["inductor"]["triton_bundler_read_and_emit_kernel"],
grad_multiplier * read_and_emit_kernel_count,
)
if use_static_cuda_launcher:
self.assertEqual(
counters["inductor"]["triton_bundler_save_static_autotuner"],
grad_multiplier if device == "cuda" else 0,
)
self.assertEqual(
counters["inductor"]["triton_bundler_load_static_autotuner"],
grad_multiplier if device == "cuda" else 0,
)

self.reset()

@@ -256,15 +228,6 @@ class TestFxGraphCache(TestCase):
counters["inductor"]["triton_bundler_read_and_emit_kernel"],
grad_multiplier * read_and_emit_kernel_count,
)
if use_static_cuda_launcher:
self.assertEqual(
counters["inductor"]["triton_bundler_save_static_autotuner"],
grad_multiplier * 2 if device == "cuda" else 0,
)
self.assertEqual(
counters["inductor"]["triton_bundler_load_static_autotuner"],
grad_multiplier if device == "cuda" else 0,
)

@requires_triton()
@config.patch({"fx_graph_remote_cache": True})
@@ -272,23 +235,13 @@ class TestFxGraphCache(TestCase):
@parametrize("dtype", (torch.float32, torch.bfloat16))
@parametrize("dynamic", (False, True))
@parametrize("bundle_triton", (False, True))
@parametrize("use_static_cuda_launcher", (False, True))
@config.patch(
{"compile_threads": 1}
) # Can't check globalStats if there are workers
def test_remote_cache_load_function(
self, device, dtype, dynamic, bundle_triton, use_static_cuda_launcher
):
def test_remote_cache_load_function(self, device, dtype, dynamic, bundle_triton):
from unittest.mock import patch

if device == GPU_TYPE and not HAS_GPU:
raise unittest.SkipTest(f"requires {GPU_TYPE}")
if device == "cuda" and dtype == torch.bfloat16 and not SM80OrLater:
raise unittest.SkipTest("requires SM80 or later")
if use_static_cuda_launcher and not (device == "cuda" and bundle_triton):
raise unittest.SkipTest(
"Static cuda launcher requires cuda and triton bundling"
)

def fn(x, y):
return (x * 2, y @ y)
@@ -300,7 +253,6 @@ class TestFxGraphCache(TestCase):
{
"fx_graph_remote_cache": True,
"bundle_triton_into_fx_graph_cache": bundle_triton,
"use_static_cuda_launcher": use_static_cuda_launcher,
}
), patch.dict(os.environ), PatchCaches():
os.environ.pop("TRITON_CACHE_MANAGER", None)
@@ -816,9 +768,7 @@ class TestFxGraphCache(TestCase):

return torch.cond(x.shape[0], true_fn, false_fn, (x,))

with config.patch(
bundle_triton_into_fx_graph_cache=bundle_triton,
):
with config.patch(bundle_triton_into_fx_graph_cache=bundle_triton):
compiled_fn = torch.compile(fn, dynamic=True, fullgraph=True)

x = torch.randn(4, 4, device=GPU_TYPE)
@@ -983,10 +933,8 @@ class TestFxGraphCache(TestCase):
@requires_triton()
@config.patch({"fx_graph_cache": True})
@config.patch({"fx_graph_remote_cache": False})
@config.patch({"compile_threads": 1})
@parametrize("bundle_triton", (False, True))
@parametrize("use_static_cuda_launcher", (False, True))
def test_triton_op(self, bundle_triton, use_static_cuda_launcher):
def test_triton_op(self, bundle_triton):
libname = "my_cool_namespace"
opname = "my_triton_operator"

@@ -1004,12 +952,7 @@ class TestFxGraphCache(TestCase):
def f(x, y):
return add(x, y)

compile_threads = 1 if use_static_cuda_launcher else config.compile_threads
with config.patch(
bundle_triton_into_fx_graph_cache=bundle_triton,
use_static_cuda_launcher=use_static_cuda_launcher,
compile_threads=compile_threads,
):
with config.patch(bundle_triton_into_fx_graph_cache=bundle_triton):
compiled_fn = torch.compile(f, fullgraph=True)

x = torch.randn(4, device=GPU_TYPE)
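The test hunks above drive the FX graph cache through config.patch and then assert on torch._dynamo counters. As a rough, hedged illustration of that pattern (not code from this commit; the helper below is made up and only uses config flags that already appear in the hunks above):

import torch
import torch._dynamo
from torch._dynamo.utils import counters
from torch._inductor import config


def fn(x, y):
    # toy workload standing in for the parametrized test functions
    return (x * 2, y @ y)


def compile_twice_and_count():
    with config.patch(fx_graph_cache=True, fx_graph_remote_cache=False):
        compiled = torch.compile(fn)
        a, b = torch.rand(25), torch.rand(5, 5)
        compiled(a, b)            # first call populates the FX graph cache
        torch._dynamo.reset()     # drop in-memory guards so the next call recompiles
        compiled(a, b)            # second call should now hit the cache
    # counters["inductor"] aggregates cache / bundler event counts by name
    return dict(counters["inductor"])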
@@ -32,7 +32,6 @@ from torch._inductor.codecache import (
HalideCodeCache,
LambdaFuture,
ROCmCodeCache,
StaticAutotunerFuture,
torch_key,
)
from torch._inductor.compile_worker.subproc_pool import AnyPool, SubprocPool
@@ -149,7 +148,7 @@ class CompiledTritonKernels:
Currently, the cache stores Future objects, but it should be generalizable for any kernels.
"""

_cache: dict[str, CodeCacheFuture] = {}
_cache: dict[str, LambdaFuture] = {}

@staticmethod
def key(kernel_src: str):
@@ -162,7 +161,7 @@ class CompiledTritonKernels:
return code_hash(kernel_src, extra=torch_key())

@staticmethod
def save(kernel_src: str, future: CodeCacheFuture):
def save(kernel_src: str, future: LambdaFuture):
"""
Saves a compiled triton kernel to the cache.
TODO: We store a LambdaFuture as that's the callable returned by async_compile.triton,
@@ -175,9 +174,9 @@ class CompiledTritonKernels:
CompiledTritonKernels._cache[key] = future

@staticmethod
def get(kernel_src: str) -> Optional[CodeCacheFuture]:
def get(kernel_src: str, default: Any) -> LambdaFuture:
key = CompiledTritonKernels.key(kernel_src)
return CompiledTritonKernels._cache.get(key, None)
return CompiledTritonKernels._cache.get(key, default)

@staticmethod
def cache_clear():
@@ -186,8 +185,6 @@ class CompiledTritonKernels:
@staticmethod
def remove_future(kernel_src: str) -> None:
key = CompiledTritonKernels.key(kernel_src)

# Delete the LambdaFuture if there is one
if key in CompiledTritonKernels._cache:
del CompiledTritonKernels._cache[key]

@@ -285,14 +282,9 @@ class AsyncCompile:
- The AutotuneCache, if enabled, is constructed on each worker per triton config
and pickled by to us via `CachingAutotuner.save_cache_hook`.
"""
load_kernel = functools.partial(
_load_triton_kernel_from_source, kernel_name, source_code
)

def reload_kernel_in_parent():
# Benchmark how often this happens
with dynamo_timed("reload_kernel_in_parent"):
return load_kernel()
if future := CompiledTritonKernels.get(source_code, None):
counters["inductor"]["async_compile_cache_hit"] += 1
return future

counters["inductor"]["async_compile_cache_miss"] += 1

@@ -304,22 +296,15 @@ class AsyncCompile:
torch._inductor.codecache.PyCodeCache.load(source_code), kernel_name
)

load_kernel = functools.partial(
_load_triton_kernel_from_source, kernel_name, source_code
)
is_parallel = self.use_process_pool()
set_feature_use("parallel_compile_post_warmup", is_parallel)

compile_id = torch._guards.CompileContext.current_compile_id()
is_backward = getattr(V.graph, "is_backward", False)

if (future := CompiledTritonKernels.get(source_code)) is not None:
counters["inductor"]["async_compile_cache_hit"] += 1
# Set reload_kernel_from_src properly based on source_code
if isinstance(future, StaticAutotunerFuture):
future.reload_kernel_from_src = reload_kernel_in_parent
if is_parallel:
return future
else:
return future.result()

if is_parallel:
# We want to support changing these env vars after (and while) the
# process pool is running, so pass them to the subprocess to reset.
@@ -332,16 +317,19 @@ class AsyncCompile:
extra_env,
)

def get_result() -> CachingAutotuner:
def reload_kernel_in_parent():
# Benchmark how often this happens
with dynamo_timed("reload_kernel_in_parent"):
return load_kernel()

def get_result() -> tuple[CachingAutotuner, int]:
kernel, elapsed_us = task.result()
# Now that we've compiled, we should clear the future
# so it can't be used again
kernel.set_compile_info(compile_id, is_backward)
CompiledTritonKernels.remove_future(source_code)
kernel.set_compile_info(compile_id, is_backward)
kernel.precompile(
warm_cache_only=False,
reload_kernel=reload_kernel_in_parent,
static_triton_bundle_key=CompiledTritonKernels.key(source_code),
warm_cache_only=False, reload_kernel=reload_kernel_in_parent
)
get_metrics_context().add_top_n(
"triton_kernel_compile_times_us", kernel_name, elapsed_us
@@ -362,10 +350,7 @@ class AsyncCompile:
_set_triton_ptxas_path()
kernel = load_kernel()
kernel.set_compile_info(compile_id, is_backward)
kernel.precompile(
warm_cache_only=False,
static_triton_bundle_key=CompiledTritonKernels.key(source_code),
)
kernel.precompile(warm_cache_only=False)
elapsed_us = (time_ns() - start_ns) // 1000
get_metrics_context().add_top_n(
"triton_kernel_compile_times_us", kernel_name, elapsed_us
@@ -459,6 +444,7 @@ class AsyncCompile:
disable=config.disable_progress,
delay=0,
)

for key, result in kernels.items():
if config.verbose_progress and not isinstance(pbar, _Faketqdm):
pbar.set_postfix_str(key)
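The CompiledTritonKernels hunks above revolve around one idea: an in-process dict from a hash of the Triton kernel source to a future for the compiled kernel, consulted before a new compile is launched. A toy sketch of that shape (illustrative class and method names, not the actual torch._inductor API):

import hashlib
from concurrent.futures import Future
from typing import Optional


class KernelFutureCache:
    """Toy analogue of CompiledTritonKernels: source-hash -> in-flight future."""

    _cache: dict[str, Future] = {}

    @staticmethod
    def key(kernel_src: str) -> str:
        # the real code hashes with code_hash(kernel_src, extra=torch_key())
        return hashlib.sha256(kernel_src.encode()).hexdigest()

    @classmethod
    def get(cls, kernel_src: str) -> Optional[Future]:
        return cls._cache.get(cls.key(kernel_src))

    @classmethod
    def save(cls, kernel_src: str, future: Future) -> None:
        cls._cache[cls.key(kernel_src)] = future

    @classmethod
    def remove_future(cls, kernel_src: str) -> None:
        # dropped once the kernel has finished compiling, so a stale
        # future cannot be reused
        cls._cache.pop(cls.key(kernel_src), None)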
@@ -3351,28 +3351,3 @@ class LambdaFuture(CodeCacheFuture):

def result(self) -> Callable[..., Any]: # type: ignore[override]
return self.result_fn()


class StaticAutotunerFuture(CodeCacheFuture):
"""
A statically launchable CachingAutotuner, loaded from TritonBundler
"""

def __init__(self, static_autotuner: CachingAutotuner) -> None:
# Pickled version of CachingAutotuner
self.static_autotuner = static_autotuner
# This needs to be set in AsyncCompile.triton, in case
# we need to reload the CachingAutotuner from its source code
# We don't store the source code on the CachingAutotuner itself
# since it can be very large.
self.reload_kernel_from_src: Optional[Callable[[], Any]] = None

def result(self) -> CachingAutotuner:
assert self.reload_kernel_from_src is not None
with dynamo_timed("StaticAutotunerFuture.warm_precompile"):
self.static_autotuner.precompile( # type: ignore[union-attr]
warm_cache_only=False,
reload_kernel=self.reload_kernel_from_src,
static_triton_bundle_key=None, # no need to save again
)
return self.static_autotuner
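The deleted StaticAutotunerFuture above is essentially a "deferred finish" wrapper: the autotuner has already been deserialized from the bundle, but the reload callback is injected later and precompile only runs when result() is called. A stripped-down sketch of that pattern with hypothetical names:

from typing import Any, Callable, Optional


class DeferredKernelFuture:
    """Holds a deserialized kernel; finishes setup only when result() is called."""

    def __init__(self, kernel: Any) -> None:
        self.kernel = kernel
        # Set by the caller later, e.g. to re-read the kernel's source if needed;
        # the source is not stored on the kernel itself because it can be large.
        self.reload_kernel_from_src: Optional[Callable[[], Any]] = None

    def result(self) -> Any:
        assert self.reload_kernel_from_src is not None
        # the real code calls kernel.precompile(reload_kernel=...) here
        self.kernel.precompile(reload_kernel=self.reload_kernel_from_src)
        return self.kernel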
@@ -65,7 +65,7 @@ if TYPE_CHECKING:
from torch._library.fake_class_registry import FakeScriptObject

from .compile_fx import _CompileFxKwargs
from .triton_bundler import TritonBundle
from .triton_bundler import TritonKernelArtifacts

log = logging.getLogger(__name__)

@@ -420,7 +420,7 @@ class CompiledFxGraph(OutputCode):
inputs_to_check: Sequence[int]

_boxed_call: Optional[bool] = None
_triton_bundle: Optional[TritonBundle] = None
_triton_bundle: Optional[list[TritonKernelArtifacts]] = None

def __init__(
self,
@@ -85,17 +85,12 @@ class StaticallyLaunchedCudaKernel:
def load_kernel(self, device: int) -> None:
from torch._C import _StaticCudaLauncher

assert hasattr(self, "cubin_path")
if self.function is not None:
return

assert hasattr(self, "cubin_path")
assert self.cubin_path is not None

(self.function, self.n_regs, self.n_spills) = _StaticCudaLauncher._load_kernel(
self.cubin_path, self.name, self.shared, device
)
# Don't need the cubin path anymore now that we've loaded
self.cubin_path = None

@staticmethod
@functools.lru_cache
@@ -166,15 +161,6 @@ class StaticallyLaunchedCudaKernel:
params.append(self.extract_type(ty))
return "".join(params)

def __getstate__(self) -> dict[str, Any]:
# Remove objects that are no longer valid for pickling
state = self.__dict__.copy()
state["function"] = None
# Cubin paths aren't consistent across processes, so we clear
# and reload them.
state["cubin_path"] = None
return state

def run(
self,
grid_x: int,
@@ -204,7 +190,6 @@ class StaticallyLaunchedCudaKernel:

# TODO: can handle grid functions here or in C++, so
# that we don't need the grid handler above.

_StaticCudaLauncher._launch_kernel(
self.function,
grid_x,
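The __getstate__ and load_kernel hunks above show a standard recipe for objects that hold process-local handles: null out the handle and the absolute path when pickling, then recompute the path and reload lazily in the receiving process. A generic sketch of that recipe (not the actual StaticallyLaunchedCudaKernel):

from typing import Any, Optional


class LazyHandle:
    def __init__(self, path: str) -> None:
        self.path: Optional[str] = path
        self.handle: Optional[bytes] = None  # process-local, not picklable

    def reload_path(self, path: str) -> None:
        # mirrors reload_cubin_path(): recompute the on-disk location after unpickling
        self.path = path

    def load(self) -> bytes:
        if self.handle is None:
            assert self.path is not None, "call reload_path() after unpickling"
            with open(self.path, "rb") as f:
                self.handle = f.read()  # stand-in for _StaticCudaLauncher._load_kernel
        return self.handle

    def __getstate__(self) -> dict[str, Any]:
        state = self.__dict__.copy()
        state["handle"] = None  # loaded handles are not valid across processes
        state["path"] = None    # absolute paths differ across processes
        return state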
@@ -275,17 +275,6 @@ class CachingAutotuner(KernelInterface):
self.compile_id: Optional[CompileId] = None
self.is_backward = False

def is_statically_launchable(self):
"""
Checks if every compiled kernel is statically launchable, which
allows us to efficiently cache it in FXGraphCache
"""
if not self.compile_results:
return False
return all(
isinstance(x, StaticTritonCompileResult) for x in self.compile_results
)

def set_compile_info(
self, compile_id: Optional[CompileId], is_backward: bool
) -> None:
@@ -296,7 +285,6 @@ class CachingAutotuner(KernelInterface):
self,
warm_cache_only=False,
reload_kernel: Optional[Callable[[], CachingAutotuner]] = None,
static_triton_bundle_key: Optional[str] = None,
):
if warm_cache_only:
self._precompile_worker()
@@ -309,8 +297,6 @@ class CachingAutotuner(KernelInterface):
if reload_kernel is not None:
self._reload_kernel = reload_kernel
self._precompile_worker()
if static_triton_bundle_key is not None and self.is_statically_launchable():
TritonBundler.put_static_autotuner(static_triton_bundle_key, self)
self._make_launchers()
self._dynamic_scale_rblock()

@@ -476,24 +462,15 @@ class CachingAutotuner(KernelInterface):
raise RuntimeError(f"No valid triton configs. {type(exc).__name__}: {exc}")
self.launchers = launchers

def prepare_for_pickle(self) -> tuple[Any, Any, Any, Any, Any]:
def prepare_for_pickle(self):
"""Drop stuff from triton.JITFunction that does not pickle.
This must be called after precompile so that these things are no longer needed.
Returns a tuple of old values
"""
old_values = (
self.fn.fn,
self.fn.__globals__,
self.fn.used_global_vals,
self.fn.repr,
self.launchers,
)
self.fn.fn = None
self.fn.__globals__ = None
self.fn.used_global_vals = None
self.fn.repr = _ConstRepr(self.fn.repr(self.fn))
self.launchers = []
return old_values

def __getstate__(self) -> dict[str, Any]:
assert not self.launchers, (
@@ -1079,8 +1056,7 @@ class CompileResult(Generic[_T]):
f" grid_2 = {grid.z_grid}",
f" runner({', '.join(runner_args)})",
]
launcher_code = "\n".join(lines)
exec(launcher_code, scope)
exec("\n".join(lines), scope)
return scope["launcher"]

def _get_arg_lists(
@@ -1222,26 +1198,8 @@ class StaticTritonCompileResult(CompileResult[StaticallyLaunchedCudaKernel]):
raise e
return None

def reload_cubin_path(self):
"""
When loading from cache on disk, we want to reload cubin
files from their appropriate location on disc.
"""
cubin_location = os.path.join(
triton_cache_dir(self.compile_meta.get("device", 0)),
triton_hash_to_path_key(self.kernel.hash),
f"{self.kernel.name}.cubin",
)
if not os.path.exists(cubin_location):
raise RuntimeError(
"Cubin file saved by TritonBundler not found at %s", cubin_location
)
self.kernel.cubin_path = cubin_location

def make_launcher(self) -> LauncherType:
# Load the binary on the parent
if not self.kernel.cubin_path:
self.reload_cubin_path()
self.kernel.load_kernel(self.compile_meta.get("device", 0))
scope = {
"runner": self.kernel.run,
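prepare_for_pickle above hands back the stripped fields so the caller can deep-copy a picklable snapshot and then restore the kernel for immediate use (that caller is put_static_autotuner in the bundler changes below). A minimal sketch of that strip-copy-restore dance with made-up types:

import copy
from typing import Any


class PicklableKernel:
    """Toy stand-in for the prepare_for_pickle() contract shown above."""

    def __init__(self, fn: Any, launchers: list[Any]) -> None:
        self.fn = fn                  # imagine this holds unpicklable state
        self.launchers = launchers

    def prepare_for_pickle(self) -> tuple[Any, Any]:
        # Drop fields that will not pickle and return the old values so the
        # caller can put them back afterwards.
        old = (self.fn, self.launchers)
        self.fn = None
        self.launchers = []
        return old


def snapshot(kernel: PicklableKernel) -> PicklableKernel:
    old_fn, old_launchers = kernel.prepare_for_pickle()
    try:
        frozen = copy.deepcopy(kernel)   # safe now that unpicklable state is gone
    finally:
        # restore the live kernel, mirroring the restore in put_static_autotuner
        kernel.fn, kernel.launchers = old_fn, old_launchers
    return frozen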
@@ -1,4 +1,3 @@
import copy
import dataclasses
import logging
import os
@@ -43,21 +42,6 @@ class TritonKernelArtifact:
payload: bytes = dataclasses.field(repr=False) # Do not display binary


@dataclasses.dataclass(frozen=True)
class StaticallyLaunchedAutotuner:
"""
Represents a statically compiled CachingAutotuner object that we can
save directly in the cache. A CachingAutotuner is made up of a list of
StaticTritonCompileResults, each of which uses the cubin from a TritonKernelArtifact.

Statically saved here have their cubin files saved by a corresponding TritonBundleEntry.
"""

cache_key: str
kernel_name: str
kernel: "CachingAutotuner" # type: ignore[name-defined] # noqa: F821


@dataclasses.dataclass(frozen=True)
class TritonKernelArtifacts:
"""
@@ -76,17 +60,6 @@ class TritonBundlerMetadata:
"""

cached_kernel_names: list[str]
statically_launched_kernel_names: list[str]


@dataclasses.dataclass(frozen=True)
class TritonBundle:
"""
Serializable bundle to save into FXGraphCache
"""

kernel_artifacts: list[TritonKernelArtifacts]
static_autotuners: list[StaticallyLaunchedAutotuner]


class TritonBundler:
@@ -106,7 +79,6 @@ class TritonBundler:
"""

_entries: Optional[list[TritonBundleEntry]] = None
_static_autotuners: Optional[list[StaticallyLaunchedAutotuner]] = None

# __grp__kernel_name.json contains metadata with source code paths
# we use this as sentinal value for search and replace
@@ -140,7 +112,6 @@ class TritonBundler:
log.debug("TritonBundler.begin_compile is called")
assert cls._entries is None
cls._entries = []
cls._static_autotuners = []

@classmethod
def end_compile(cls) -> None:
@@ -150,7 +121,6 @@ class TritonBundler:
"""
log.debug("TritonBundler.end_compile is called")
cls._entries = None
cls._static_autotuners = None

@classmethod
def put(cls, kernel_hash: str, device: int) -> None:
@@ -163,93 +133,20 @@ class TritonBundler:
TritonBundleEntry(kernel_hash, device, triton_cache_dir(device))
)

@classmethod
def put_static_autotuner(cls, key: str, kernel: "CachingAutotuner") -> None: # type: ignore[name-defined] # noqa: F821
from torch._inductor import config

assert config.use_static_cuda_launcher
if (entries := cls._static_autotuners) is not None:
# Clear a bunch of unpicklable values and make a copy to save
# for FXGraphCache
old_values = kernel.prepare_for_pickle()
new_kernel = copy.deepcopy(kernel)
new_kernel._reload_kernel = None
entries.append(
StaticallyLaunchedAutotuner(
key,
new_kernel.inductor_meta.get("kernel_name", "unknown_kernel"),
new_kernel,
)
)
# Put the values back since we need it to use now
(
kernel.fn.fn,
kernel.fn.__globals__,
kernel.fn.used_global_vals,
kernel.fn.repr,
kernel.launchers,
) = old_values

@classmethod
def collect_static_autotuners(
cls,
) -> tuple[list[StaticallyLaunchedAutotuner], list[str]]:
if not cls._static_autotuners:
return [], []
else:
log.info(
"Saving %d statically launchable CachingAutotuners",
len(cls._static_autotuners),
)
static_autotuner_names = [i.kernel_name for i in cls._static_autotuners]
counters["inductor"]["triton_bundler_save_static_autotuner"] += 1
return cls._static_autotuners, static_autotuner_names

@classmethod
def load_autotuners(
cls, static_autotuners: Optional[list[StaticallyLaunchedAutotuner]]
) -> list[str]:
"""
Load statically launchable CachingAutotuners into async_compile.CompiledTritonKernels
cache.
"""
if not static_autotuners:
return []

from torch._inductor.async_compile import CompiledTritonKernels
from torch._inductor.codecache import StaticAutotunerFuture

log.info("Loading %d statically launchable autotuners", len(static_autotuners))
kernel_names = []
with dynamo_timed("TritonBundler.load_cached_static_autotuners"):
for result in static_autotuners:
# We make a future instead of returning the kernel here so that
# kernels that are not statically launchable (i.e. cache miss)
# can launch a worker without waiting on the blocking step of
# StaticAutotunerFuture.result().
CompiledTritonKernels._cache[result.cache_key] = StaticAutotunerFuture(
result.kernel
)
counters["inductor"]["triton_bundler_load_static_autotuner"] += 1
kernel_names.append(result.kernel_name)
return kernel_names

@classmethod
def collect(
cls,
) -> tuple[TritonBundle, Optional[TritonBundlerMetadata]]:
) -> tuple[list[TritonKernelArtifacts], Optional[TritonBundlerMetadata]]:
"""
This is the main function called when a cache write happens. This function
converts all the previously remembered kernels into bundled format so that
it can be written into a cache entry.
This function also finalizes the current bundle.
"""
from torch._inductor import config

if not TritonBundler.is_enabled():
cls.end_compile()
set_feature_use("triton_bundling", False)
return TritonBundle([], []), None
return [], None
set_feature_use("triton_bundling", True)

with dynamo_timed(key="TritonBundler.collect", log_pt2_compile_event=True):
@@ -302,21 +199,14 @@ class TritonBundler:
artifacts,
)
)
if config.use_static_cuda_launcher:
static_autotuners, static_kernel_names = (
cls.collect_static_autotuners()
)
else:
static_autotuners = []
static_kernel_names = []
cls.end_compile()
return TritonBundle(result, static_autotuners), TritonBundlerMetadata(
kernel_names, static_kernel_names
)
return TritonBundle([], []), None
return result, TritonBundlerMetadata(kernel_names)
return [], None

@staticmethod
def read_and_emit(bundle: TritonBundle) -> Optional[TritonBundlerMetadata]:
def read_and_emit(
bundle: list[TritonKernelArtifacts],
) -> Optional[TritonBundlerMetadata]:
"""
This is the main function called when a cache read happens. This function
converts the bundled format back into individual files and writes them
@@ -329,8 +219,6 @@ class TritonBundler:
Exclusive access means that no other process should be writing to
or reading from the target directory.
"""
from torch._inductor import config

if not TritonBundler.is_enabled():
return None

@@ -339,7 +227,7 @@ class TritonBundler:
):
kernel_names: list[str] = []

for artifacts in bundle.kernel_artifacts:
for artifacts in bundle:
basedir = triton_cache_dir(artifacts.device)
directory = os.path.join(basedir, artifacts.kernel_hash)

@@ -384,10 +272,4 @@ class TritonBundler:
# Atomic on POSIX systems
os.replace(tmp_dir, directory)

if config.use_static_cuda_launcher:
static_kernel_names = TritonBundler.load_autotuners(
bundle.static_autotuners
)
else:
static_kernel_names = []
return TritonBundlerMetadata(kernel_names, static_kernel_names)
return TritonBundlerMetadata(kernel_names)
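Taken together, the bundler changes reverted above amount to a round trip: collect() packs artifacts plus, optionally, static autotuners into a TritonBundle on cache write, while read_and_emit()/load_autotuners() unpack the bundle and pre-seed the in-memory kernel cache on cache read. A toy sketch of that round trip (hypothetical types, not the removed implementation):

import dataclasses
from typing import Any


@dataclasses.dataclass(frozen=True)
class StaticEntry:
    cache_key: str
    kernel_name: str
    kernel: Any


@dataclasses.dataclass(frozen=True)
class Bundle:
    # mirrors TritonBundle: compiled artifacts plus optional static autotuners
    kernel_artifacts: list[Any]
    static_autotuners: list[StaticEntry]


def write_side(artifacts: list[Any], autotuners: list[StaticEntry], static_enabled: bool) -> Bundle:
    # analogous to collect(): only bundle autotuners when the launcher feature is on
    return Bundle(artifacts, autotuners if static_enabled else [])


def read_side(bundle: Bundle, kernel_cache: dict[str, Any], static_enabled: bool) -> list[str]:
    # analogous to read_and_emit()/load_autotuners(): pre-seed the in-memory cache
    # so later compiles of the same source skip the worker round trip
    loaded = []
    for entry in (bundle.static_autotuners if static_enabled else []):
        kernel_cache[entry.cache_key] = entry.kernel
        loaded.append(entry.kernel_name)
    return loaded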