Files
pytorch/torch/_inductor/await_utils.py
bobrenjc93 05c417715f integrate kernacle into inductor (#160121)
This adds integration into inductor in two parts:

1) It kicks off the best config lookup at lowering time within mm.py
2) It awaits the future at scheduling time in select_algorithm.py

Notably, this does not do the following:

1) Support for enumerating between mm, addmm and bmm
2) Support for enumerating between exhaustive/max
3) Enumerating different hardware SKUs, e.g., H100, A100, etc.

Those will come in the next diffs.

Differential Revision: [D79824921](https://our.internmc.facebook.com/intern/diff/D79824921/)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/160121
Approved by: https://github.com/izaitsevfb
2025-08-08 02:14:44 +00:00

177 lines
5.6 KiB
Python

import asyncio
import sys
import weakref
from asyncio import AbstractEventLoop, Future
from collections.abc import Awaitable, Coroutine, Generator, Iterator
from contextlib import contextmanager, ExitStack
from contextvars import Context
from typing import Any, Callable, Optional, Protocol, TypeVar
from torch.utils._ordered_set import OrderedSet
# Generic result type used by the helpers in this module.
T = TypeVar("T")
# A generator-based coroutine: yields Any, is sent None, and returns T.
TCoro = Generator[Any, None, T]

if sys.version_info < (3, 11):
    # Before 3.11 the event loop invokes a task factory as ``factory(loop, coro)``
    # (no context argument), so a plain Callable alias is enough.
    TaskFactoryType = Callable[[AbstractEventLoop, Generator[TCoro, None, T]], Future]  # type: ignore[valid-type]
else:

    class TaskFactory(Protocol):
        """Structural type for a callable usable with ``loop.set_task_factory``.

        It takes the loop, the coroutine (or generator) to wrap, and an optional
        ``contextvars.Context``, and returns the future/task it created.
        """

        def __call__(
            self,
            __loop: AbstractEventLoop,
            __factory: Coroutine[None, None, object] | Generator[None, None, object],
            __context: Context | None = None,
            /,
        ) -> asyncio.futures.Future[object]: ...

    TaskFactoryType = TaskFactory
def await_sync(awaitable: Awaitable[T]) -> T:
    """Synchronously drive *awaitable* to completion and return its result.

    The loop used is whatever ``get_loop()`` provides; any exception raised by
    the awaitable propagates to the caller after the loop context is exited.
    """
    with get_loop() as event_loop:
        outcome = event_loop.run_until_complete(awaitable)
    return outcome
@contextmanager
def get_loop(
    always_create_new_loop: bool = False,
) -> Iterator[AbstractEventLoop]:
    """Yield an event loop that the caller can block on with ``run_until_complete``.

    Resolution order:

    * If ``asyncio.get_event_loop()`` raises the "There is no current event
      loop in thread" ``RuntimeError``, a fresh loop from ``_new_loop()`` is
      yielded.
    * If the current loop is already running (e.g. this was reached from code
      executing inside that loop), the thread's running-loop marker is cleared
      and the current loop is saved. Then a fresh loop that inherits the current
      loop's task factory is yielded, and both are restored on exit.
    * If the current loop is closed, a fresh loop is yielded.
    * If ``always_create_new_loop`` is true, a fresh loop is yielded and the
      previous current loop is reinstalled on exit.
    * Otherwise the current loop itself is yielded unchanged.

    Any loop created here is cancelled/closed by ``_new_loop`` on exit.

    Args:
        always_create_new_loop: force a fresh loop even when the current loop
            is idle and usable.

    Raises:
        RuntimeError: re-raised from ``asyncio.get_event_loop()`` when it is
            not the "no current event loop" error.
    """
    try:
        loop = asyncio.get_event_loop()
    except RuntimeError as re:
        # Only the "no current event loop" error is recoverable: fall back to a
        # brand-new loop that ``_new_loop`` creates, installs and tears down.
        if "There is no current event loop in thread" in str(re):
            with _new_loop() as loop:
                yield loop
            return
        else:
            raise

    @contextmanager
    def _restore_loop(
        loop: asyncio.AbstractEventLoop,
    ) -> Iterator[None]:
        # On exit, reinstall *loop* as this thread's current event loop
        # (undoing the ``set_event_loop`` calls made by ``_new_loop``).
        try:
            yield
        finally:
            asyncio.set_event_loop(loop)

    @contextmanager
    def _restore_running_loop() -> Iterator[None]:
        # Temporarily clear asyncio's per-thread "running loop" marker so a
        # second loop can be driven with ``run_until_complete`` while the outer
        # loop is mid-callback; the marker is put back on exit.
        # NOTE: relies on private ``asyncio.events`` helpers.
        loop_from_events = asyncio.events._get_running_loop()
        asyncio.events._set_running_loop(None)
        try:
            yield
        finally:
            asyncio.events._set_running_loop(loop_from_events)

    with ExitStack() as stack:
        if loop.is_running():
            # ExitStack unwinds in reverse order: the new loop is closed first,
            # then the original loop is reinstalled as current, and finally the
            # running-loop marker is restored.
            stack.enter_context(_restore_running_loop())
            stack.enter_context(_restore_loop(loop=loop))
            # The new loop inherits the outer loop's task factory.
            loop = stack.enter_context(_new_loop(loop.get_task_factory()))  # type: ignore[arg-type]
        elif loop.is_closed():
            # A closed loop cannot run anything; use a fresh one. The closed
            # loop is not reinstalled afterwards (``_new_loop`` resets the
            # current loop to None on exit).
            loop = stack.enter_context(_new_loop())  # type: ignore[arg-type]
        elif always_create_new_loop:
            stack.enter_context(_restore_loop(loop=loop))
            loop = stack.enter_context(_new_loop())  # type: ignore[arg-type]
        # Otherwise the current, idle loop is yielded as-is.
        yield loop
@contextmanager
def _new_loop(
    task_factory: Optional[TaskFactoryType] = None,
) -> Iterator[asyncio.AbstractEventLoop]:
    """Create a fresh event loop, make it current, and tear it down on exit.

    The loop is patched with ``_patch_loop`` so every task it creates is
    tracked. If *task_factory* is given, it is installed through the loop's
    (patched) ``set_task_factory``.

    On exit, mirroring ``asyncio.run``:
    1. Tracked tasks that are still pending are cancelled and awaited.
    2. Async generators still suspended on the loop are finalized.
    3. The current event loop is reset to ``None``.
    4. The loop is closed.

    Yields:
        The newly created event loop.
    """
    loop = asyncio.new_event_loop()
    try:
        tasks = _patch_loop(loop)
        if task_factory:
            # pyre-ignore[6]
            loop.set_task_factory(task_factory)  # type: ignore[arg-type]
        asyncio.set_event_loop(loop)
    except BaseException:
        # Setup failed before the loop was handed out; close it instead of
        # leaking it.
        loop.close()
        raise
    try:
        yield loop
    finally:
        try:
            _cancel_all_tasks(loop, tasks)
            # Finalize async generators while the loop can still run their
            # ``finally`` blocks; otherwise they are finalized later against a
            # closed loop (this is what asyncio.run does as well).
            loop.run_until_complete(loop.shutdown_asyncgens())
        finally:
            asyncio.set_event_loop(None)
            loop.close()
def _cancel_all_tasks(
loop: AbstractEventLoop,
tasks: OrderedSet[Future], # type: ignore[type-arg]
) -> None:
to_cancel = [task for task in tasks if not task.done()]
if not to_cancel:
return
# pyre-fixme[1001]: Awaitable assigned to `task` is never awaited.
for task in to_cancel:
task.cancel()
loop.run_until_complete(asyncio.gather(*to_cancel, return_exceptions=True))
for task in to_cancel:
if task.cancelled():
continue
if task.exception() is not None:
loop.call_exception_handler(
{
"message": "unhandled exception during asyncio.run() shutdown",
"exception": task.exception(),
"task": task,
}
)
def _patch_loop(loop: AbstractEventLoop) -> OrderedSet[Future]:  # type: ignore[type-arg]
    """Install a task-tracking factory on *loop* and return the tracked tasks.

    After this call:
    * Every task *loop* creates goes through ``_safe_task_factory`` and is
      recorded in the returned collection.
    * ``loop.set_task_factory`` / ``loop.get_task_factory`` are replaced, so a
      factory set later is stored and called from inside the tracking wrapper
      instead of replacing it.

    NOTE: despite the ``OrderedSet`` annotation, the returned object is a
    ``weakref.WeakSet``, so tasks drop out of it once garbage collected.
    """
    tasks: weakref.WeakSet[Future] = weakref.WeakSet()  # type: ignore[type-arg]
    # Single-slot holder for the user-supplied factory (a list so the nested
    # closures can rebind it without ``nonlocal``).
    task_factories: list[Optional[TaskFactoryType]] = [None]

    def _set_task_factory(factory: Optional[TaskFactoryType]) -> None:
        task_factories[0] = factory

    def _get_task_factory() -> Optional[TaskFactoryType]:
        return task_factories[0]

    def _safe_task_factory(
        loop: AbstractEventLoop,
        coro: TCoro,  # type: ignore[type-arg]
        *,
        context: Context | None = None,
        **kwargs: Any,
    ) -> asyncio.Future:  # type: ignore[valid-type, type-arg]
        # ``kwargs`` carries any further keywords the loop forwards to task
        # factories (e.g. ``name``/``eager_start`` on newer Pythons) so they
        # are passed through instead of raising TypeError here.
        task_factory = task_factories[0]
        if task_factory is None:
            if sys.version_info >= (3, 11):
                task = asyncio.Task(coro, loop=loop, context=context, **kwargs)
            else:
                task = asyncio.Task(coro, loop=loop, **kwargs)
            # Drop this wrapper's frame from the debug-mode creation traceback.
            # pyre-ignore[16]: `Task` has no attribute `_source_traceback`.
            if task._source_traceback:  # type: ignore[attr-defined]
                del task._source_traceback[-1]  # type: ignore[attr-defined]  # pragma: no cover
        else:
            # Like BaseEventLoop.create_task, only pass ``context`` when one was
            # supplied, so legacy ``factory(loop, coro)`` factories still work.
            if context is not None:
                kwargs["context"] = context
            task = task_factory(loop, coro, **kwargs)  # type: ignore[arg-type, call-arg, assignment]
        # ``task`` may be any Future returned by a user factory, not only a Task.
        tasks.add(task)
        return task

    # pyre-ignore[6]
    loop.set_task_factory(_safe_task_factory)  # type: ignore[method-assign, arg-type]
    # pyre-ignore[8]
    loop.set_task_factory = _set_task_factory  # type: ignore[method-assign, assignment]
    # pyre-ignore[8]
    loop.get_task_factory = _get_task_factory  # type: ignore[method-assign, assignment]
    return tasks  # type: ignore[return-value]