From 05c417715f791875fbf28cfc3fc86142de1a3206 Mon Sep 17 00:00:00 2001 From: bobrenjc93 Date: Thu, 7 Aug 2025 11:24:21 -0700 Subject: [PATCH] integrate kernacle into inductor (#160121) This adds integration into inductor in two parts 1) It kicks off the best config lookup at lowering time within mm.py 2) It awaits the future at scheduling time in select_algorithm.py Notably this does not do the following 1) Support for enumerating between mm, addmm and bmm 2) Support for enumerating between exhaustive/max 3) Enumerating different hardware SKUs eg. H100, A100, etc. those will come in the next diffs Differential Revision: [D79824921](https://our.internmc.facebook.com/intern/diff/D79824921/) Pull Request resolved: https://github.com/pytorch/pytorch/pull/160121 Approved by: https://github.com/izaitsevfb --- test/inductor/custom_ops.cpp | 4 +- torch/_inductor/await_utils.py | 176 ++++++++++++++++++ torch/_inductor/config.py | 7 + torch/_inductor/kernel/mm.py | 15 +- torch/_inductor/remote_gemm_autotune_cache.py | 20 ++ torch/_inductor/select_algorithm.py | 31 +++ 6 files changed, 250 insertions(+), 3 deletions(-) create mode 100644 torch/_inductor/await_utils.py create mode 100644 torch/_inductor/remote_gemm_autotune_cache.py diff --git a/test/inductor/custom_ops.cpp b/test/inductor/custom_ops.cpp index ae1d00c5b634..ade7695a10d0 100644 --- a/test/inductor/custom_ops.cpp +++ b/test/inductor/custom_ops.cpp @@ -1,7 +1,7 @@ #include // @manual=fbcode//caffe2:libtorch -#include -#include +#include // @manual +#include // @manual #include #include diff --git a/torch/_inductor/await_utils.py b/torch/_inductor/await_utils.py new file mode 100644 index 000000000000..a549674d5cd7 --- /dev/null +++ b/torch/_inductor/await_utils.py @@ -0,0 +1,176 @@ +import asyncio +import sys +import weakref +from asyncio import AbstractEventLoop, Future +from collections.abc import Awaitable, Coroutine, Generator, Iterator +from contextlib import contextmanager, ExitStack +from contextvars import Context +from typing import Any, Callable, Optional, Protocol, TypeVar + +from torch.utils._ordered_set import OrderedSet + + +T = TypeVar("T") +TCoro = Generator[Any, None, T] + +if sys.version_info >= (3, 11): + + class TaskFactory(Protocol): + def __call__( + self, + __loop: AbstractEventLoop, + __factory: Coroutine[None, None, object] | Generator[None, None, object], + __context: Context | None = None, + /, + ) -> asyncio.futures.Future[object]: ... + + TaskFactoryType = TaskFactory +else: + TaskFactoryType = Callable[[AbstractEventLoop, Generator[TCoro, None, T]], Future] # type: ignore[valid-type] + + +def await_sync(awaitable: Awaitable[T]) -> T: + with get_loop() as loop: + return loop.run_until_complete(awaitable) + + +@contextmanager +def get_loop( + always_create_new_loop: bool = False, +) -> Iterator[AbstractEventLoop]: + try: + loop = asyncio.get_event_loop() + except RuntimeError as re: + if "There is no current event loop in thread" in str(re): + with _new_loop() as loop: + yield loop + return + else: + raise + + @contextmanager + def _restore_loop( + loop: asyncio.AbstractEventLoop, + ) -> Iterator[None]: + try: + yield + finally: + asyncio.set_event_loop(loop) + + @contextmanager + def _restore_running_loop() -> Iterator[None]: + loop_from_events = asyncio.events._get_running_loop() + asyncio.events._set_running_loop(None) + try: + yield + finally: + asyncio.events._set_running_loop(loop_from_events) + + with ExitStack() as stack: + if loop.is_running(): + stack.enter_context(_restore_running_loop()) + stack.enter_context(_restore_loop(loop=loop)) + loop = stack.enter_context(_new_loop(loop.get_task_factory())) # type: ignore[arg-type] + elif loop.is_closed(): + loop = stack.enter_context(_new_loop()) # type: ignore[arg-type] + elif always_create_new_loop: + stack.enter_context(_restore_loop(loop=loop)) + loop = stack.enter_context(_new_loop()) # type: ignore[arg-type] + yield loop + + +@contextmanager +def _new_loop( + task_factory: Optional[TaskFactoryType] = None, +) -> Iterator[asyncio.AbstractEventLoop]: + loop = asyncio.new_event_loop() + tasks = _patch_loop(loop) + + if task_factory: + # pyre-ignore[6] + loop.set_task_factory(task_factory) # type: ignore[arg-type] + + asyncio.set_event_loop(loop) + try: + yield loop + finally: + try: + _cancel_all_tasks(loop, tasks) + finally: + asyncio.set_event_loop(None) + loop.close() + + +def _cancel_all_tasks( + loop: AbstractEventLoop, + tasks: OrderedSet[Future], # type: ignore[type-arg] +) -> None: + to_cancel = [task for task in tasks if not task.done()] + + if not to_cancel: + return + + # pyre-fixme[1001]: Awaitable assigned to `task` is never awaited. + for task in to_cancel: + task.cancel() + + loop.run_until_complete(asyncio.gather(*to_cancel, return_exceptions=True)) + + for task in to_cancel: + if task.cancelled(): + continue + if task.exception() is not None: + loop.call_exception_handler( + { + "message": "unhandled exception during asyncio.run() shutdown", + "exception": task.exception(), + "task": task, + } + ) + + +def _patch_loop(loop: AbstractEventLoop) -> OrderedSet[Future]: # type: ignore[type-arg] + tasks: weakref.WeakSet[Future] = weakref.WeakSet() # type: ignore[type-arg] + + task_factories: list[Optional[TaskFactoryType]] = [None] + + def _set_task_factory(factory: Optional[TaskFactoryType]) -> None: + task_factories[0] = factory + + def _get_task_factory() -> Optional[TaskFactoryType]: + return task_factories[0] + + def _safe_task_factory( + loop: AbstractEventLoop, + coro: TCoro, # type: ignore[type-arg] + *, + context: Context | None = None, + ) -> asyncio.Future: # type: ignore[valid-type, type-arg] + task_factory = task_factories[0] + if task_factory is None: + if sys.version_info >= (3, 11): + task = asyncio.Task(coro, loop=loop, context=context) + else: + task = asyncio.Task(coro, loop=loop) + # pyre-ignore[16]: `Task` has no attribute `_source_traceback`. + if task._source_traceback: # type: ignore[attr-defined] + del task._source_traceback[ # type: ignore[attr-defined] + -1 + ] # pragma: no cover # type: ignore[attr-defined] + else: + if sys.version_info >= (3, 11): + task = task_factory(loop, coro, context=context) # type: ignore[arg-type, call-arg, assignment] + else: + task = task_factory(loop, coro) # type: ignore[arg-type] + # `Union[Task[Any], Future[Any]]`. + tasks.add(task) + return task + + # pyre-ignore[6] + loop.set_task_factory(_safe_task_factory) # type: ignore[method-assign, arg-type] + # pyre-ignore[8] + loop.set_task_factory = _set_task_factory # type: ignore[method-assign, assignment] + # pyre-ignore[8] + loop.get_task_factory = _get_task_factory # type: ignore[method-assign, assignment] + + return tasks # type: ignore[return-value] diff --git a/torch/_inductor/config.py b/torch/_inductor/config.py index 51a438840b04..8d3b4cd7ed49 100644 --- a/torch/_inductor/config.py +++ b/torch/_inductor/config.py @@ -81,6 +81,11 @@ disable_progress = True # Whether to enable printing the source code for each future verbose_progress = False +# Configurable compile worker logging path for subproc_pool +worker_log_path = ( + "/logs/dedicated_log_torch_compile_worker_rank" if is_fbcode() else None +) + # precompilation timeout precompilation_timeout_seconds: int = 60 * 60 @@ -91,6 +96,8 @@ fx_graph_cache: bool = Config( default=True, ) +remote_gemm_autotune_cache: bool = False + # use remote fx aot graph codegen cache # False: Disables the cache # True: Enables the cache diff --git a/torch/_inductor/kernel/mm.py b/torch/_inductor/kernel/mm.py index 6e741430f36d..e68a76174c73 100644 --- a/torch/_inductor/kernel/mm.py +++ b/torch/_inductor/kernel/mm.py @@ -15,6 +15,7 @@ from torch._inductor.autoheuristic.autoheuristic_utils import ( mm_operations, ) from torch._inductor.codegen.cpp_gemm_template import CppGemmTemplate +from torch._inductor.remote_gemm_autotune_cache import gen_best_config from torch._inductor.virtualized import V from torch.fx.experimental.proxy_tensor import make_fx from torch.torch_version import TorchVersion @@ -836,7 +837,19 @@ def tuned_mm(mat1, mat2, *, layout=None): lazy_register_extern_choice(k).bind(kernel_inputs.nodes(), layout) ) - return autotune_select_algorithm(name, choices, kernel_inputs.nodes(), layout) + best_config_future = None + # Purposely not awaiting the future here - this kicks off the best config lookup at lowering time + # The future will be awaited at scheduling time in select_algorithm.py + if torch._inductor.config.remote_gemm_autotune_cache: + best_config_future = gen_best_config(mat1, mat2) + + return autotune_select_algorithm( + name, + choices, + kernel_inputs.nodes(), + layout, + best_config_future=best_config_future, + ) @register_lowering(aten._int_mm, type_promotion_kind=None) diff --git a/torch/_inductor/remote_gemm_autotune_cache.py b/torch/_inductor/remote_gemm_autotune_cache.py new file mode 100644 index 000000000000..0ef026269b10 --- /dev/null +++ b/torch/_inductor/remote_gemm_autotune_cache.py @@ -0,0 +1,20 @@ +import asyncio +from typing import TypeVar + +import torch._inductor.config as config +from torch._inductor import ir + + +_T = TypeVar("_T") + + +def gen_best_config(mat1: ir.StorageBox, mat2: ir.StorageBox) -> asyncio.Task[_T]: + """ + Generate the best GEMM autotune config for the given matrices. + """ + if config.is_fbcode(): + from torch._inductor.fb.remote_gemm_autotune_cache import gen_best_config + + return gen_best_config(mat1, mat2) + else: + raise NotImplementedError("Function gen_best_config is not yet implemented") diff --git a/torch/_inductor/select_algorithm.py b/torch/_inductor/select_algorithm.py index 4faa251953d6..01337fc0d30b 100644 --- a/torch/_inductor/select_algorithm.py +++ b/torch/_inductor/select_algorithm.py @@ -34,6 +34,7 @@ from torch._dynamo.utils import ( identity, preserve_rng_state, ) +from torch._inductor.await_utils import await_sync from torch._inductor.utils import clear_on_fresh_cache from torch.utils._filelock import FileLock from torch.utils._ordered_set import OrderedSet @@ -2280,6 +2281,7 @@ class AlgorithmSelectorCache(PersistentCache): input_gen_fns: Optional[dict[int, Callable[[ir.Buffer], torch.Tensor]]] = None, precompilation_timeout_seconds: int = 60 * 60, return_multi_template=False, + best_config_future=None, ): from .codegen.cuda.cuda_kernel import CUDATemplateCaller @@ -2387,6 +2389,35 @@ class AlgorithmSelectorCache(PersistentCache): log.debug("Prescreening elapsed time: %.02fs", prescreening_elapse) autotune_start_ts = time.time() + + if best_config_future is not None: + best_config = await_sync(best_config_future) + + important_keys = [ + "ACC_TYPE", + "ALLOW_TF32", + "BLOCK_K", + "BLOCK_M", + "BLOCK_N", + "EVEN_K", + "GROUP_M", + "USE_FAST_ACCUM", + "num_stages", + "num_warps", + "num_consumer_groups", + "num_buffers_warp_spec", + ] + choices = [ + choice + for choice in choices + if all( + f"{k}={best_config[k]}" in choice.description + for k in important_keys + ) + for k in important_keys + ] + log.info("Filtered to %d choices based on best_config", len(choices)) + timings = self.lookup( choices, name,