mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
integrate kernacle into inductor (#160121)
This adds integration into inductor in two parts 1) It kicks off the best config lookup at lowering time within mm.py 2) It awaits the future at scheduling time in select_algorithm.py Notably this does not do the following 1) Support for enumerating between mm, addmm and bmm 2) Support for enumerating between exhaustive/max 3) Enumerating different hardware SKUs eg. H100, A100, etc. those will come in the next diffs Differential Revision: [D79824921](https://our.internmc.facebook.com/intern/diff/D79824921/) Pull Request resolved: https://github.com/pytorch/pytorch/pull/160121 Approved by: https://github.com/izaitsevfb
This commit is contained in:
committed by
PyTorch MergeBot
parent
ba4ccf5d67
commit
05c417715f
@ -1,7 +1,7 @@
|
||||
#include <torch/csrc/api/include/torch/types.h> // @manual=fbcode//caffe2:libtorch
|
||||
|
||||
#include <torch/csrc/inductor/aoti_torch/c/shim.h>
|
||||
#include <torch/csrc/inductor/aoti_torch/utils.h>
|
||||
#include <torch/csrc/inductor/aoti_torch/c/shim.h> // @manual
|
||||
#include <torch/csrc/inductor/aoti_torch/utils.h> // @manual
|
||||
|
||||
#include <cstdint>
|
||||
#include <iostream>
|
||||
|
176
torch/_inductor/await_utils.py
Normal file
176
torch/_inductor/await_utils.py
Normal file
@ -0,0 +1,176 @@
|
||||
import asyncio
|
||||
import sys
|
||||
import weakref
|
||||
from asyncio import AbstractEventLoop, Future
|
||||
from collections.abc import Awaitable, Coroutine, Generator, Iterator
|
||||
from contextlib import contextmanager, ExitStack
|
||||
from contextvars import Context
|
||||
from typing import Any, Callable, Optional, Protocol, TypeVar
|
||||
|
||||
from torch.utils._ordered_set import OrderedSet
|
||||
|
||||
|
||||
T = TypeVar("T")
|
||||
TCoro = Generator[Any, None, T]
|
||||
|
||||
if sys.version_info >= (3, 11):
|
||||
|
||||
class TaskFactory(Protocol):
|
||||
def __call__(
|
||||
self,
|
||||
__loop: AbstractEventLoop,
|
||||
__factory: Coroutine[None, None, object] | Generator[None, None, object],
|
||||
__context: Context | None = None,
|
||||
/,
|
||||
) -> asyncio.futures.Future[object]: ...
|
||||
|
||||
TaskFactoryType = TaskFactory
|
||||
else:
|
||||
TaskFactoryType = Callable[[AbstractEventLoop, Generator[TCoro, None, T]], Future] # type: ignore[valid-type]
|
||||
|
||||
|
||||
def await_sync(awaitable: Awaitable[T]) -> T:
|
||||
with get_loop() as loop:
|
||||
return loop.run_until_complete(awaitable)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def get_loop(
|
||||
always_create_new_loop: bool = False,
|
||||
) -> Iterator[AbstractEventLoop]:
|
||||
try:
|
||||
loop = asyncio.get_event_loop()
|
||||
except RuntimeError as re:
|
||||
if "There is no current event loop in thread" in str(re):
|
||||
with _new_loop() as loop:
|
||||
yield loop
|
||||
return
|
||||
else:
|
||||
raise
|
||||
|
||||
@contextmanager
|
||||
def _restore_loop(
|
||||
loop: asyncio.AbstractEventLoop,
|
||||
) -> Iterator[None]:
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
@contextmanager
|
||||
def _restore_running_loop() -> Iterator[None]:
|
||||
loop_from_events = asyncio.events._get_running_loop()
|
||||
asyncio.events._set_running_loop(None)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
asyncio.events._set_running_loop(loop_from_events)
|
||||
|
||||
with ExitStack() as stack:
|
||||
if loop.is_running():
|
||||
stack.enter_context(_restore_running_loop())
|
||||
stack.enter_context(_restore_loop(loop=loop))
|
||||
loop = stack.enter_context(_new_loop(loop.get_task_factory())) # type: ignore[arg-type]
|
||||
elif loop.is_closed():
|
||||
loop = stack.enter_context(_new_loop()) # type: ignore[arg-type]
|
||||
elif always_create_new_loop:
|
||||
stack.enter_context(_restore_loop(loop=loop))
|
||||
loop = stack.enter_context(_new_loop()) # type: ignore[arg-type]
|
||||
yield loop
|
||||
|
||||
|
||||
@contextmanager
|
||||
def _new_loop(
|
||||
task_factory: Optional[TaskFactoryType] = None,
|
||||
) -> Iterator[asyncio.AbstractEventLoop]:
|
||||
loop = asyncio.new_event_loop()
|
||||
tasks = _patch_loop(loop)
|
||||
|
||||
if task_factory:
|
||||
# pyre-ignore[6]
|
||||
loop.set_task_factory(task_factory) # type: ignore[arg-type]
|
||||
|
||||
asyncio.set_event_loop(loop)
|
||||
try:
|
||||
yield loop
|
||||
finally:
|
||||
try:
|
||||
_cancel_all_tasks(loop, tasks)
|
||||
finally:
|
||||
asyncio.set_event_loop(None)
|
||||
loop.close()
|
||||
|
||||
|
||||
def _cancel_all_tasks(
|
||||
loop: AbstractEventLoop,
|
||||
tasks: OrderedSet[Future], # type: ignore[type-arg]
|
||||
) -> None:
|
||||
to_cancel = [task for task in tasks if not task.done()]
|
||||
|
||||
if not to_cancel:
|
||||
return
|
||||
|
||||
# pyre-fixme[1001]: Awaitable assigned to `task` is never awaited.
|
||||
for task in to_cancel:
|
||||
task.cancel()
|
||||
|
||||
loop.run_until_complete(asyncio.gather(*to_cancel, return_exceptions=True))
|
||||
|
||||
for task in to_cancel:
|
||||
if task.cancelled():
|
||||
continue
|
||||
if task.exception() is not None:
|
||||
loop.call_exception_handler(
|
||||
{
|
||||
"message": "unhandled exception during asyncio.run() shutdown",
|
||||
"exception": task.exception(),
|
||||
"task": task,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def _patch_loop(loop: AbstractEventLoop) -> OrderedSet[Future]: # type: ignore[type-arg]
|
||||
tasks: weakref.WeakSet[Future] = weakref.WeakSet() # type: ignore[type-arg]
|
||||
|
||||
task_factories: list[Optional[TaskFactoryType]] = [None]
|
||||
|
||||
def _set_task_factory(factory: Optional[TaskFactoryType]) -> None:
|
||||
task_factories[0] = factory
|
||||
|
||||
def _get_task_factory() -> Optional[TaskFactoryType]:
|
||||
return task_factories[0]
|
||||
|
||||
def _safe_task_factory(
|
||||
loop: AbstractEventLoop,
|
||||
coro: TCoro, # type: ignore[type-arg]
|
||||
*,
|
||||
context: Context | None = None,
|
||||
) -> asyncio.Future: # type: ignore[valid-type, type-arg]
|
||||
task_factory = task_factories[0]
|
||||
if task_factory is None:
|
||||
if sys.version_info >= (3, 11):
|
||||
task = asyncio.Task(coro, loop=loop, context=context)
|
||||
else:
|
||||
task = asyncio.Task(coro, loop=loop)
|
||||
# pyre-ignore[16]: `Task` has no attribute `_source_traceback`.
|
||||
if task._source_traceback: # type: ignore[attr-defined]
|
||||
del task._source_traceback[ # type: ignore[attr-defined]
|
||||
-1
|
||||
] # pragma: no cover # type: ignore[attr-defined]
|
||||
else:
|
||||
if sys.version_info >= (3, 11):
|
||||
task = task_factory(loop, coro, context=context) # type: ignore[arg-type, call-arg, assignment]
|
||||
else:
|
||||
task = task_factory(loop, coro) # type: ignore[arg-type]
|
||||
# `Union[Task[Any], Future[Any]]`.
|
||||
tasks.add(task)
|
||||
return task
|
||||
|
||||
# pyre-ignore[6]
|
||||
loop.set_task_factory(_safe_task_factory) # type: ignore[method-assign, arg-type]
|
||||
# pyre-ignore[8]
|
||||
loop.set_task_factory = _set_task_factory # type: ignore[method-assign, assignment]
|
||||
# pyre-ignore[8]
|
||||
loop.get_task_factory = _get_task_factory # type: ignore[method-assign, assignment]
|
||||
|
||||
return tasks # type: ignore[return-value]
|
@ -81,6 +81,11 @@ disable_progress = True
|
||||
# Whether to enable printing the source code for each future
|
||||
verbose_progress = False
|
||||
|
||||
# Configurable compile worker logging path for subproc_pool
|
||||
worker_log_path = (
|
||||
"/logs/dedicated_log_torch_compile_worker_rank" if is_fbcode() else None
|
||||
)
|
||||
|
||||
# precompilation timeout
|
||||
precompilation_timeout_seconds: int = 60 * 60
|
||||
|
||||
@ -91,6 +96,8 @@ fx_graph_cache: bool = Config(
|
||||
default=True,
|
||||
)
|
||||
|
||||
remote_gemm_autotune_cache: bool = False
|
||||
|
||||
# use remote fx aot graph codegen cache
|
||||
# False: Disables the cache
|
||||
# True: Enables the cache
|
||||
|
@ -15,6 +15,7 @@ from torch._inductor.autoheuristic.autoheuristic_utils import (
|
||||
mm_operations,
|
||||
)
|
||||
from torch._inductor.codegen.cpp_gemm_template import CppGemmTemplate
|
||||
from torch._inductor.remote_gemm_autotune_cache import gen_best_config
|
||||
from torch._inductor.virtualized import V
|
||||
from torch.fx.experimental.proxy_tensor import make_fx
|
||||
from torch.torch_version import TorchVersion
|
||||
@ -836,7 +837,19 @@ def tuned_mm(mat1, mat2, *, layout=None):
|
||||
lazy_register_extern_choice(k).bind(kernel_inputs.nodes(), layout)
|
||||
)
|
||||
|
||||
return autotune_select_algorithm(name, choices, kernel_inputs.nodes(), layout)
|
||||
best_config_future = None
|
||||
# Purposely not awaiting the future here - this kicks off the best config lookup at lowering time
|
||||
# The future will be awaited at scheduling time in select_algorithm.py
|
||||
if torch._inductor.config.remote_gemm_autotune_cache:
|
||||
best_config_future = gen_best_config(mat1, mat2)
|
||||
|
||||
return autotune_select_algorithm(
|
||||
name,
|
||||
choices,
|
||||
kernel_inputs.nodes(),
|
||||
layout,
|
||||
best_config_future=best_config_future,
|
||||
)
|
||||
|
||||
|
||||
@register_lowering(aten._int_mm, type_promotion_kind=None)
|
||||
|
20
torch/_inductor/remote_gemm_autotune_cache.py
Normal file
20
torch/_inductor/remote_gemm_autotune_cache.py
Normal file
@ -0,0 +1,20 @@
|
||||
import asyncio
|
||||
from typing import TypeVar
|
||||
|
||||
import torch._inductor.config as config
|
||||
from torch._inductor import ir
|
||||
|
||||
|
||||
_T = TypeVar("_T")
|
||||
|
||||
|
||||
def gen_best_config(mat1: ir.StorageBox, mat2: ir.StorageBox) -> asyncio.Task[_T]:
|
||||
"""
|
||||
Generate the best GEMM autotune config for the given matrices.
|
||||
"""
|
||||
if config.is_fbcode():
|
||||
from torch._inductor.fb.remote_gemm_autotune_cache import gen_best_config
|
||||
|
||||
return gen_best_config(mat1, mat2)
|
||||
else:
|
||||
raise NotImplementedError("Function gen_best_config is not yet implemented")
|
@ -34,6 +34,7 @@ from torch._dynamo.utils import (
|
||||
identity,
|
||||
preserve_rng_state,
|
||||
)
|
||||
from torch._inductor.await_utils import await_sync
|
||||
from torch._inductor.utils import clear_on_fresh_cache
|
||||
from torch.utils._filelock import FileLock
|
||||
from torch.utils._ordered_set import OrderedSet
|
||||
@ -2280,6 +2281,7 @@ class AlgorithmSelectorCache(PersistentCache):
|
||||
input_gen_fns: Optional[dict[int, Callable[[ir.Buffer], torch.Tensor]]] = None,
|
||||
precompilation_timeout_seconds: int = 60 * 60,
|
||||
return_multi_template=False,
|
||||
best_config_future=None,
|
||||
):
|
||||
from .codegen.cuda.cuda_kernel import CUDATemplateCaller
|
||||
|
||||
@ -2387,6 +2389,35 @@ class AlgorithmSelectorCache(PersistentCache):
|
||||
log.debug("Prescreening elapsed time: %.02fs", prescreening_elapse)
|
||||
|
||||
autotune_start_ts = time.time()
|
||||
|
||||
if best_config_future is not None:
|
||||
best_config = await_sync(best_config_future)
|
||||
|
||||
important_keys = [
|
||||
"ACC_TYPE",
|
||||
"ALLOW_TF32",
|
||||
"BLOCK_K",
|
||||
"BLOCK_M",
|
||||
"BLOCK_N",
|
||||
"EVEN_K",
|
||||
"GROUP_M",
|
||||
"USE_FAST_ACCUM",
|
||||
"num_stages",
|
||||
"num_warps",
|
||||
"num_consumer_groups",
|
||||
"num_buffers_warp_spec",
|
||||
]
|
||||
choices = [
|
||||
choice
|
||||
for choice in choices
|
||||
if all(
|
||||
f"{k}={best_config[k]}" in choice.description
|
||||
for k in important_keys
|
||||
)
|
||||
for k in important_keys
|
||||
]
|
||||
log.info("Filtered to %d choices based on best_config", len(choices))
|
||||
|
||||
timings = self.lookup(
|
||||
choices,
|
||||
name,
|
||||
|
Reference in New Issue
Block a user