integrate kernacle into inductor (#160121)

This adds the integration into inductor in two parts:

1) It kicks off the best-config lookup at lowering time within mm.py.
2) It awaits the future at scheduling time in select_algorithm.py.

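As a minimal sketch of that two-phase pattern (the names here are illustrative, not the actual inductor APIs):

import asyncio

async def lookup_best_config(key: str) -> dict:
    # stand-in for the remote best-config lookup
    await asyncio.sleep(0)
    return {"BLOCK_M": 64, "num_warps": 4}

loop = asyncio.new_event_loop()
# "lowering time": start the lookup without blocking
future = loop.create_task(lookup_best_config("mm/128x128x128"))
# ... other lowering work overlaps with the lookup ...
# "scheduling time": block until the result is available
best_config = loop.run_until_complete(future)
loop.close()
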
Notably, this does not yet do the following:

1) Support for enumerating between mm, addmm, and bmm
2) Support for enumerating between exhaustive/max autotune
3) Support for enumerating different hardware SKUs, e.g. H100, A100

Those will come in the next diffs.

Differential Revision: [D79824921](https://our.internmc.facebook.com/intern/diff/D79824921/)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/160121
Approved by: https://github.com/izaitsevfb
Author: bobrenjc93
Date: 2025-08-07 11:24:21 -07:00
Committed by: PyTorch MergeBot
Parent: ba4ccf5d67
Commit: 05c417715f

6 changed files with 250 additions and 3 deletions


@@ -1,7 +1,7 @@
 #include <torch/csrc/api/include/torch/types.h> // @manual=fbcode//caffe2:libtorch
-#include <torch/csrc/inductor/aoti_torch/c/shim.h>
-#include <torch/csrc/inductor/aoti_torch/utils.h>
+#include <torch/csrc/inductor/aoti_torch/c/shim.h> // @manual
+#include <torch/csrc/inductor/aoti_torch/utils.h> // @manual
 #include <cstdint>
 #include <iostream>

torch/_inductor/await_utils.py

@@ -0,0 +1,176 @@
+import asyncio
+import sys
+import weakref
+from asyncio import AbstractEventLoop, Future
+from collections.abc import Awaitable, Coroutine, Generator, Iterator
+from contextlib import contextmanager, ExitStack
+from contextvars import Context
+from typing import Any, Callable, Optional, Protocol, TypeVar
+
+from torch.utils._ordered_set import OrderedSet
+
+
+T = TypeVar("T")
+TCoro = Generator[Any, None, T]
+
+
+if sys.version_info >= (3, 11):
+
+    class TaskFactory(Protocol):
+        def __call__(
+            self,
+            __loop: AbstractEventLoop,
+            __factory: Coroutine[None, None, object] | Generator[None, None, object],
+            __context: Context | None = None,
+            /,
+        ) -> asyncio.futures.Future[object]: ...
+
+    TaskFactoryType = TaskFactory
+else:
+    TaskFactoryType = Callable[[AbstractEventLoop, Generator[TCoro, None, T]], Future]  # type: ignore[valid-type]
+
+
+def await_sync(awaitable: Awaitable[T]) -> T:
+    with get_loop() as loop:
+        return loop.run_until_complete(awaitable)
+
+
+@contextmanager
+def get_loop(
+    always_create_new_loop: bool = False,
+) -> Iterator[AbstractEventLoop]:
+    try:
+        loop = asyncio.get_event_loop()
+    except RuntimeError as re:
+        if "There is no current event loop in thread" in str(re):
+            with _new_loop() as loop:
+                yield loop
+            return
+        else:
+            raise
+
+    @contextmanager
+    def _restore_loop(
+        loop: asyncio.AbstractEventLoop,
+    ) -> Iterator[None]:
+        try:
+            yield
+        finally:
+            asyncio.set_event_loop(loop)
+
+    @contextmanager
+    def _restore_running_loop() -> Iterator[None]:
+        loop_from_events = asyncio.events._get_running_loop()
+        asyncio.events._set_running_loop(None)
+        try:
+            yield
+        finally:
+            asyncio.events._set_running_loop(loop_from_events)
+
+    with ExitStack() as stack:
+        if loop.is_running():
+            stack.enter_context(_restore_running_loop())
+            stack.enter_context(_restore_loop(loop=loop))
+            loop = stack.enter_context(_new_loop(loop.get_task_factory()))  # type: ignore[arg-type]
+        elif loop.is_closed():
+            loop = stack.enter_context(_new_loop())  # type: ignore[arg-type]
+        elif always_create_new_loop:
+            stack.enter_context(_restore_loop(loop=loop))
+            loop = stack.enter_context(_new_loop())  # type: ignore[arg-type]
+        yield loop
+
+
+@contextmanager
+def _new_loop(
+    task_factory: Optional[TaskFactoryType] = None,
+) -> Iterator[asyncio.AbstractEventLoop]:
+    loop = asyncio.new_event_loop()
+    tasks = _patch_loop(loop)
+    if task_factory:
+        # pyre-ignore[6]
+        loop.set_task_factory(task_factory)  # type: ignore[arg-type]
+    asyncio.set_event_loop(loop)
+    try:
+        yield loop
+    finally:
+        try:
+            _cancel_all_tasks(loop, tasks)
+        finally:
+            asyncio.set_event_loop(None)
+            loop.close()
+
+
+def _cancel_all_tasks(
+    loop: AbstractEventLoop,
+    tasks: OrderedSet[Future],  # type: ignore[type-arg]
+) -> None:
+    to_cancel = [task for task in tasks if not task.done()]
+    if not to_cancel:
+        return
+
+    # pyre-fixme[1001]: Awaitable assigned to `task` is never awaited.
+    for task in to_cancel:
+        task.cancel()
+    loop.run_until_complete(asyncio.gather(*to_cancel, return_exceptions=True))
+
+    for task in to_cancel:
+        if task.cancelled():
+            continue
+        if task.exception() is not None:
+            loop.call_exception_handler(
+                {
+                    "message": "unhandled exception during asyncio.run() shutdown",
+                    "exception": task.exception(),
+                    "task": task,
+                }
+            )
+
+
+def _patch_loop(loop: AbstractEventLoop) -> OrderedSet[Future]:  # type: ignore[type-arg]
+    tasks: weakref.WeakSet[Future] = weakref.WeakSet()  # type: ignore[type-arg]
+
+    task_factories: list[Optional[TaskFactoryType]] = [None]
+
+    def _set_task_factory(factory: Optional[TaskFactoryType]) -> None:
+        task_factories[0] = factory
+
+    def _get_task_factory() -> Optional[TaskFactoryType]:
+        return task_factories[0]
+
+    def _safe_task_factory(
+        loop: AbstractEventLoop,
+        coro: TCoro,  # type: ignore[type-arg]
+        *,
+        context: Context | None = None,
+    ) -> asyncio.Future:  # type: ignore[valid-type, type-arg]
+        task_factory = task_factories[0]
+        if task_factory is None:
+            if sys.version_info >= (3, 11):
+                task = asyncio.Task(coro, loop=loop, context=context)
+            else:
+                task = asyncio.Task(coro, loop=loop)
+            # pyre-ignore[16]: `Task` has no attribute `_source_traceback`.
+            if task._source_traceback:  # type: ignore[attr-defined]
+                del task._source_traceback[-1]  # pragma: no cover # type: ignore[attr-defined]
+        else:
+            if sys.version_info >= (3, 11):
+                task = task_factory(loop, coro, context=context)  # type: ignore[arg-type, call-arg, assignment]
+            else:
+                task = task_factory(loop, coro)  # type: ignore[arg-type]
+        # `Union[Task[Any], Future[Any]]`.
+        tasks.add(task)
+        return task
+
+    # pyre-ignore[6]
+    loop.set_task_factory(_safe_task_factory)  # type: ignore[method-assign, arg-type]
+    # pyre-ignore[8]
+    loop.set_task_factory = _set_task_factory  # type: ignore[method-assign, assignment]
+    # pyre-ignore[8]
+    loop.get_task_factory = _get_task_factory  # type: ignore[method-assign, assignment]
+    return tasks  # type: ignore[return-value]

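A minimal usage sketch for await_sync (the coroutine is invented for illustration; get_loop transparently handles the no-loop, closed-loop, and already-running-loop cases above):

from torch._inductor.await_utils import await_sync

async def fetch_value() -> int:
    # stand-in for real async work
    return 42

assert await_sync(fetch_value()) == 42
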
torch/_inductor/config.py

@@ -81,6 +81,11 @@ disable_progress = True
 # Whether to enable printing the source code for each future
 verbose_progress = False
 
+# Configurable compile worker logging path for subproc_pool
+worker_log_path = (
+    "/logs/dedicated_log_torch_compile_worker_rank" if is_fbcode() else None
+)
+
 # precompilation timeout
 precompilation_timeout_seconds: int = 60 * 60

@@ -91,6 +96,8 @@ fx_graph_cache: bool = Config(
     default=True,
 )
 
+remote_gemm_autotune_cache: bool = False
+
 # use remote fx aot graph codegen cache
 # False: Disables the cache
 # True: Enables the cache

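A sketch of opting in to the new flag (note that, per remote_gemm_autotune_cache.py below, the OSS build currently raises NotImplementedError, so this is only meaningful with the fbcode backend):

import torch._inductor.config as inductor_config

# kick off remote best-config lookups at lowering time for tuned_mm
inductor_config.remote_gemm_autotune_cache = True
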
torch/_inductor/kernel/mm.py

@@ -15,6 +15,7 @@ from torch._inductor.autoheuristic.autoheuristic_utils import (
     mm_operations,
 )
 from torch._inductor.codegen.cpp_gemm_template import CppGemmTemplate
+from torch._inductor.remote_gemm_autotune_cache import gen_best_config
 from torch._inductor.virtualized import V
 from torch.fx.experimental.proxy_tensor import make_fx
 from torch.torch_version import TorchVersion

@@ -836,7 +837,19 @@ def tuned_mm(mat1, mat2, *, layout=None):
             lazy_register_extern_choice(k).bind(kernel_inputs.nodes(), layout)
         )
 
-    return autotune_select_algorithm(name, choices, kernel_inputs.nodes(), layout)
+    best_config_future = None
+    # Purposely not awaiting the future here - this kicks off the best config lookup at lowering time
+    # The future will be awaited at scheduling time in select_algorithm.py
+    if torch._inductor.config.remote_gemm_autotune_cache:
+        best_config_future = gen_best_config(mat1, mat2)
+
+    return autotune_select_algorithm(
+        name,
+        choices,
+        kernel_inputs.nodes(),
+        layout,
+        best_config_future=best_config_future,
+    )
 
 
 @register_lowering(aten._int_mm, type_promotion_kind=None)

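For reference, the value the future resolves to is consumed in select_algorithm.py as a mapping of kernel parameters; a hypothetical example (the keys mirror the important_keys list below, the values are invented placeholders):

best_config = {
    "ACC_TYPE": "tl.float32",
    "ALLOW_TF32": True,
    "BLOCK_K": 32,
    "BLOCK_M": 64,
    "BLOCK_N": 64,
    "EVEN_K": True,
    "GROUP_M": 8,
    "USE_FAST_ACCUM": False,
    "num_stages": 3,
    "num_warps": 4,
    "num_consumer_groups": 0,
    "num_buffers_warp_spec": 0,
}
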
torch/_inductor/remote_gemm_autotune_cache.py

@@ -0,0 +1,20 @@
+import asyncio
+from typing import TypeVar
+
+import torch._inductor.config as config
+from torch._inductor import ir
+
+_T = TypeVar("_T")
+
+
+def gen_best_config(mat1: ir.StorageBox, mat2: ir.StorageBox) -> asyncio.Task[_T]:
+    """
+    Generate the best GEMM autotune config for the given matrices.
+    """
+    if config.is_fbcode():
+        from torch._inductor.fb.remote_gemm_autotune_cache import gen_best_config
+
+        return gen_best_config(mat1, mat2)
+    else:
+        raise NotImplementedError("Function gen_best_config is not yet implemented")

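For readers without the fbcode module, a hypothetical stand-in with the same contract (every name below is invented; await_sync also accepts a bare coroutine, which run_until_complete drives to completion):

import asyncio

async def _remote_lookup(key: str) -> dict:
    # placeholder for real network I/O against a remote config cache
    await asyncio.sleep(0)
    return {"BLOCK_M": 64, "BLOCK_N": 64, "BLOCK_K": 32}

def gen_best_config_stub(mat1, mat2):
    # a real key would encode the shapes and dtypes of mat1 and mat2
    return _remote_lookup("mm")
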
torch/_inductor/select_algorithm.py

@@ -34,6 +34,7 @@ from torch._dynamo.utils import (
     identity,
     preserve_rng_state,
 )
+from torch._inductor.await_utils import await_sync
 from torch._inductor.utils import clear_on_fresh_cache
 from torch.utils._filelock import FileLock
 from torch.utils._ordered_set import OrderedSet

@@ -2280,6 +2281,7 @@ class AlgorithmSelectorCache(PersistentCache):
         input_gen_fns: Optional[dict[int, Callable[[ir.Buffer], torch.Tensor]]] = None,
         precompilation_timeout_seconds: int = 60 * 60,
         return_multi_template=False,
+        best_config_future=None,
     ):
         from .codegen.cuda.cuda_kernel import CUDATemplateCaller
@@ -2387,6 +2389,35 @@ class AlgorithmSelectorCache(PersistentCache):
             log.debug("Prescreening elapsed time: %.02fs", prescreening_elapse)
 
         autotune_start_ts = time.time()
 
+        if best_config_future is not None:
+            best_config = await_sync(best_config_future)
+            important_keys = [
+                "ACC_TYPE",
+                "ALLOW_TF32",
+                "BLOCK_K",
+                "BLOCK_M",
+                "BLOCK_N",
+                "EVEN_K",
+                "GROUP_M",
+                "USE_FAST_ACCUM",
+                "num_stages",
+                "num_warps",
+                "num_consumer_groups",
+                "num_buffers_warp_spec",
+            ]
+            choices = [
+                choice
+                for choice in choices
+                if all(
+                    f"{k}={best_config[k]}" in choice.description
+                    for k in important_keys
+                )
+            ]
+            log.info("Filtered to %d choices based on best_config", len(choices))
+
         timings = self.lookup(
             choices,
             name,