Files
pytorch/torch/_inductor/async_compile.py
Pian Pawakapan 2a9745de3c [multi-kernel] shape-similarity kernel selection (#163090)
Introduces a variant of size-hint multi-kernel, where for novel runtime shapes, instead of performing full benchmarking to determine the optimal kernel, selects one of many kernels pre-generated from multi-kernel hints, based off similarity b/w hint / runtime input & output shapes (L1 distance in log2 space).

Some caveats/changes:
- Size-hint multi-kernel now only kicks in if the kernel has dynamic shapes
- Pre-generation still only does 1-d search over specified hints, e.g. `matmul([s0, s1], [s1, s2])` with size-hints `[64, 256]` only generates 2 kernels - based on tuning shapes ([64, 64], [64, 64]) and ([256, 256], [256, 256]). Extending this to reasonable n-d search (via user API?) is an extension

Benchmarking results, compared to multi-kernel w/ full benchmarking (hints 64, 4096), and compiling with the ground truth hint:
<img width="1902" height="1222" alt="550541081_1088709150049684_6528797079439730237_n" src="https://github.com/user-attachments/assets/056cca48-c16a-4451-9b4a-fa13a7a058a9" />

Full benchmarking doing worse is extremely weird, but we did see similar spikes in #156628

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163090
Approved by: https://github.com/bobrenjc93
2025-09-23 21:00:47 +00:00

693 lines
25 KiB
Python

# mypy: allow-untyped-defs
from __future__ import annotations
import atexit
import contextlib
import functools
import json
import logging
import multiprocessing
import os
import re
import sys
from concurrent.futures import Future, ThreadPoolExecutor
from concurrent.futures.process import BrokenProcessPool
from functools import partial
from time import time, time_ns
from typing import Any, Callable, Optional, TYPE_CHECKING
import torch
from torch._dynamo.device_interface import get_registered_device_interfaces
from torch._dynamo.utils import (
counters,
dynamo_timed,
get_metrics_context,
set_feature_use,
)
from torch._inductor import config
from torch._inductor.codecache import (
_load_triton_kernel_from_source,
code_hash,
CodeCacheFuture,
CppCodeCache,
CppPythonBindingsCodeCache,
CUDACodeCache,
HalideCodeCache,
LambdaFuture,
ROCmCodeCache,
StaticAutotunerFuture,
torch_key,
)
from torch._inductor.compile_worker.subproc_pool import (
AnyPool,
SubprocException,
SubprocPool,
)
from torch._inductor.compile_worker.tracked_process_pool import (
TrackedProcessPoolExecutor,
)
from torch._inductor.compile_worker.utils import _async_compile_initializer
from torch._inductor.runtime.compile_tasks import (
_set_triton_ptxas_path,
_worker_compile_triton,
)
from torch._inductor.utils import clear_on_fresh_cache
from torch._inductor.virtualized import V
from torch._utils_internal import log_triton_builds
from torch.hub import _Faketqdm, tqdm
from torch.utils._ordered_set import OrderedSet
from torch.utils._triton import has_triton_package
if TYPE_CHECKING:
from torch._inductor.runtime.hints import HalideMeta
from torch._inductor.runtime.triton_heuristics import CachingAutotuner
# timing metrics for time spent in the compilation
_cumulative_compile_time = 0.0
_t0: Optional[float] = None
kernel_code_log = torch._logging.getArtifactLogger(__name__, "kernel_code")
log = logging.getLogger(__name__)
_triton_kernel_metrics: Optional[dict[str, dict[str, Any]]] = None
size_hints_regex = re.compile(
r"size_hints=(\{.*?\})",
)
def pre_fork_setup():
"""
Setup that must be done prior to forking with a process pool.
"""
# ensure properties have been calculated before processes
# are forked
caching_device_properties()
# Computing the triton key can be slow. If we call it before fork,
# it will be cached for the forked subprocesses.
from torch._inductor.runtime.triton_compat import HAS_TRITON, triton_key
if HAS_TRITON:
triton_key()
def caching_device_properties():
for _, device_interface in get_registered_device_interfaces():
if device_interface.is_available():
device_interface.Worker.get_device_properties()
def _compile_start() -> None:
global _t0, _triton_kernel_metrics
if _t0 is None:
_t0 = time()
if _triton_kernel_metrics is None:
_triton_kernel_metrics = {}
def _compile_end() -> None:
global _cumulative_compile_time, _t0, _triton_kernel_metrics
if _t0 is not None:
t1 = time()
_cumulative_compile_time += t1 - _t0
_t0 = None
# print("CUMULATIVE COMPILE TIME", _cumulative_compile_time)
if _triton_kernel_metrics:
# Log triton kernel info
sorted_info = dict(sorted(_triton_kernel_metrics.items()))
torch._logging.trace_structured(
"artifact",
metadata_fn=lambda: {
"name": "triton_kernel_info",
"encoding": "json",
},
payload_fn=lambda: json.dumps(sorted_info),
)
_triton_kernel_metrics = None
def _add_triton_kernel_info(kernel_name: str, info: dict[str, Any]):
global _triton_kernel_metrics
# Must be called between _compile_start and _compile_end
if _triton_kernel_metrics is not None:
_triton_kernel_metrics[kernel_name] = info
_IS_WINDOWS = sys.platform == "win32"
log = logging.getLogger(__name__)
# Used to keep track of all process pools invoked so far.
_pool_set = OrderedSet[AnyPool]()
def shutdown_compile_workers() -> None:
"""Shut down all outstanding compile-worker pools."""
for pool in _pool_set:
pool.shutdown()
AsyncCompile._ready_future = None
after_fork()
def after_fork():
"""Reset pools to initial state without shutting them down"""
_pool_set.clear()
AsyncCompile.process_pool.cache_clear()
try:
os.register_at_fork(after_in_child=after_fork)
except AttributeError:
pass # register_at_fork does not exists on windows
def get_compile_threads() -> int:
"""
Temporary for internal rollout. Assign config.compile_threads lazily and return it.
TODO: remove after rollout.
"""
if config.compile_threads is None:
config.compile_threads = config.decide_compile_threads()
return config.compile_threads
@clear_on_fresh_cache
class CompiledTritonKernels:
"""
In memory cache for storing compiled triton kernels.
Each triton kernel is keyed by the hash of its source code. Each value stored
in the cache is a return value of AsyncCompile.triton().
Currently, the cache stores Future objects, but it should be generalizable for any kernels.
"""
_cache: dict[str, CodeCacheFuture] = {}
@staticmethod
def key(kernel_src: str):
"""
Generates a cache key given a triton kernel's full source code.
This source includes the inductor meta, compilation metadata, the kernel itself, etc.
`kernel_src` should be the exact string passed to async_compile.triton()'s first argument.
"""
# Hashes the kernel source with torch_key into a single hash key
return code_hash(kernel_src, extra=torch_key())
@staticmethod
def save(kernel_src: str, future: CodeCacheFuture):
"""
Saves a compiled triton kernel to the cache.
TODO: We store a LambdaFuture as that's the callable returned by async_compile.triton,
but the real type we want to return here is actually an abstract triton kernel.
TODO: Source code here is not just the kernel's source code, but also includes the inductor preamble, etc.
so it could be less strict.
"""
key = CompiledTritonKernels.key(kernel_src)
CompiledTritonKernels._cache[key] = future
@staticmethod
def get(kernel_src: str) -> Optional[CodeCacheFuture]:
key = CompiledTritonKernels.key(kernel_src)
return CompiledTritonKernels._cache.get(key, None)
@staticmethod
def cache_clear():
CompiledTritonKernels._cache = {}
@staticmethod
def remove_future(kernel_src: str) -> None:
key = CompiledTritonKernels.key(kernel_src)
# Delete the LambdaFuture if there is one
if key in CompiledTritonKernels._cache:
del CompiledTritonKernels._cache[key]
@contextlib.contextmanager
def async_compile_pool_manager():
"""
Context manager to quiesce the subproc pool at the end of compilation, i.e.,
when dynamo is done.
"""
try:
yield
finally:
AsyncCompile.quiesce()
class AsyncCompile:
"""
Utilities to compile in thread pools or subprocess pools (in the case of Triton).
"""
_ready_future: Optional[Future[Any]] = None
def __init__(self) -> None:
pass
@staticmethod
@functools.lru_cache(1)
def pool() -> ThreadPoolExecutor:
assert get_compile_threads() > 1
return ThreadPoolExecutor(get_compile_threads())
@staticmethod
def _get_ready():
"""No-op function to help mark when the subprocess pool is ready."""
return "ready"
@staticmethod
@functools.lru_cache(1)
def process_pool() -> AnyPool:
assert get_compile_threads() > 1
AsyncCompile._ready_future = None
log.info(
"Creating '%s' pool with %d workers",
config.worker_start_method,
get_compile_threads(),
)
pool: AnyPool
if config.worker_start_method == "subprocess":
# Wrapper around ProcessPoolExecutor forks in a new process we control
pool = SubprocPool(get_compile_threads())
else:
if config.worker_start_method == "spawn":
# Avoid creating pools in the spawned subprocs themselves:
os.environ["TORCH_WARM_POOL"] = "0"
pre_fork_setup()
ctx = multiprocessing.get_context(config.worker_start_method)
pool = TrackedProcessPoolExecutor(
get_compile_threads(),
mp_context=ctx,
initializer=partial(_async_compile_initializer, os.getpid()),
)
# when this pool is created in a subprocess object, the normal exit handler
# doesn't run, and we need to register our own handler.
# exitpriority has to be high, because another one of the finalizers will
# kill the worker thread that sends the shutdown message to the workers...
multiprocessing.util.Finalize(None, pool.shutdown, exitpriority=sys.maxsize)
_pool_set.add(pool)
return pool
@classmethod
def warm_pool(cls) -> None:
if get_compile_threads() <= 1:
return
_compile_start()
# Pool is created on first access. Note for a SubprocPool, the sidecar process starts,
# but its ProcessPoolExecutor does not initialize until a wakeup() call or the first
# job is submitted.
cls.process_pool()
_compile_end()
@classmethod
def wait_pool_ready(cls, timeout=120) -> None:
cls.use_process_pool()
if cls._ready_future is not None:
cls._ready_future.result(timeout=timeout)
@classmethod
def submit(cls, task: Callable[..., Any]) -> Any:
if get_compile_threads() <= 1:
return task()
return cls.pool().submit(task)
@classmethod
def use_process_pool(cls):
if get_compile_threads() <= 1:
return False
# Create a dummy job to check if the pool is ready. Submit it here instead of at
# pool creation so we don't launch the full pool of worker subprocesses until
# we're sure they're needed.
if not cls._ready_future:
cls._ready_future = cls.process_pool().submit(cls._get_ready)
return cls._ready_future.done()
@classmethod
def quiesce(cls) -> None:
"""
If using a SubprocPool, signal the sidecar process to shut down its
ProcessPoolExecutor.
"""
# Don't inadvertently create a process pool if it doesn't already exist:
if not cls.process_pool.cache_info().currsize:
return
if config.quiesce_async_compile_pool:
pool = cls.process_pool()
if isinstance(pool, SubprocPool):
pool.quiesce()
@classmethod
def wakeup(cls) -> None:
"""
If using a SubprocPool, signal the sidecar process to start up its
ProcessPoolExecutor.
"""
if not cls.use_process_pool():
return
pool = cls.process_pool()
if isinstance(pool, SubprocPool):
pool.wakeup()
def triton(self, kernel_name: str, source_code: str, device_str: str = "cuda"):
"""
Async_compile.triton is more complicated than the other backends because
we're trying to optimize compile time as much as possible for this hot callsite.
First of all, the function is cached by CompiledTritonKernels; if there's a kernel
already compiled, we grab it directly from the cache and return.
Otherwise, if we have multiple compile threads, we kick off triton compilations on each
worker process by giving it a kernel and source code to compile. The worker initializes
a CachingAutotuner, runs triton compilation, and pickles the kernel back to us.
We use TritonCompileResult to represent the objects being pickled back to us by each
worker.
Some maybe not obvious things that are pickled back to us:
- Most of the time, we can avoid sending back CachingAutotuner.fn and other metadata
and do not have to pay the cost of loading the triton kernel on the parent. But certain
cases, like coordesc tuning and dynamic_scale_rblock, require us to reload the function
in the parent lazily when we require it.
- The AutotuneCache, if enabled, is constructed on each worker per triton config
and pickled by to us via `CachingAutotuner.save_cache_hook`.
"""
load_kernel = functools.partial(
_load_triton_kernel_from_source, kernel_name, source_code
)
def reload_kernel_in_parent():
# Benchmark how often this happens
with dynamo_timed("reload_kernel_in_parent"):
return load_kernel()
counters["inductor"]["async_compile_cache_miss"] += 1
kernel_code_log.info("Triton Kernel:\n%s", source_code)
_compile_start()
if os.environ.get("TRITON_INTERPRET", "0") == "1":
return getattr(
torch._inductor.codecache.PyCodeCache.load(source_code), kernel_name
)
is_parallel = self.use_process_pool()
set_feature_use("parallel_compile_post_warmup", is_parallel)
compile_id = torch._guards.CompileContext.current_compile_id()
is_backward = getattr(V.graph, "is_backward", False)
if (future := CompiledTritonKernels.get(source_code)) is not None:
counters["inductor"]["async_compile_cache_hit"] += 1
# Set reload_kernel_from_src properly based on source_code
if isinstance(future, StaticAutotunerFuture):
# Remove the future now that we've cache hit
CompiledTritonKernels.remove_future(source_code)
future.reload_kernel_from_src = reload_kernel_in_parent
if is_parallel:
return future
else:
return future.result()
# Cache miss
if is_parallel:
# We want to support changing these env vars after (and while) the
# process pool is running, so pass them to the subprocess to reset.
env_vars = ["TORCHINDUCTOR_CACHE_DIR", "TRITON_CACHE_DIR"]
extra_env = {v: os.environ[v] for v in env_vars if v in os.environ}
extra_config = {
"use_static_cuda_launcher": torch._inductor.config.use_static_cuda_launcher
}
if len(torch._inductor.config.autotune_lookup_table) > 0:
m = size_hints_regex.search(source_code)
if m:
size_hints_str = m.group(1)
else:
size_hints_str = str(None)
triton_src = source_code.split("@triton.jit\n")[1]
from torch._inductor.runtime.triton_heuristics import (
generate_lookup_hash_from_source_code,
)
fn_hash = generate_lookup_hash_from_source_code(
size_hints_str, triton_src
)
if fn_hash in torch._inductor.config.autotune_lookup_table:
extra_config["autotune_lookup_table"] = { # type: ignore[assignment]
fn_hash: torch._inductor.config.autotune_lookup_table[fn_hash]
}
task = self.process_pool().submit(
_worker_compile_triton,
load_kernel,
extra_env,
extra_config,
)
def get_result() -> CachingAutotuner:
try:
kernel, elapsed_us = task.result()
except SubprocException as e:
raise e.with_name(kernel_name) from e
# Now that we've compiled, we should clear the future
# so it can't be used again
kernel.set_compile_info(compile_id, is_backward)
CompiledTritonKernels.remove_future(source_code)
kernel.restore_after_unpickle(old_values=None)
kernel.precompile(
warm_cache_only=False,
reload_kernel=reload_kernel_in_parent,
static_triton_bundle_key=CompiledTritonKernels.key(source_code),
)
info = kernel.autotune_cache_info or {}
info["compile_time_us"] = elapsed_us
_add_triton_kernel_info(kernel_name, info)
get_metrics_context().add_top_n(
"triton_kernel_compile_times_us", kernel_name, elapsed_us
)
return kernel
future = LambdaFuture(get_result, future=task)
CompiledTritonKernels.save(source_code, future)
return future
else:
with dynamo_timed(
"async_compile.precompile",
log_pt2_compile_event=True,
dynamo_compile_column_us="triton_compile_time_us",
log_waitcounter=True,
waitcounter_name_override="compile_triton",
):
fail = None
try:
start_ns = time_ns()
_set_triton_ptxas_path()
kernel = load_kernel()
kernel.set_compile_info(compile_id, is_backward)
kernel.precompile(
warm_cache_only=False,
static_triton_bundle_key=CompiledTritonKernels.key(source_code),
)
elapsed_us = (time_ns() - start_ns) // 1000
get_metrics_context().add_top_n(
"triton_kernel_compile_times_us", kernel_name, elapsed_us
)
info = kernel.autotune_cache_info or {}
info["compile_time_us"] = elapsed_us
_add_triton_kernel_info(kernel_name, info)
return kernel
except Exception as e:
fail = str(e)
raise
finally:
log_triton_builds(fail=fail)
def multi_kernel(self, *args, **kwargs) -> Any:
from torch._inductor.codegen.multi_kernel import MultiKernelCall
# no need to call this in parallel since the sub-kernels are already parallel tasks
return MultiKernelCall(*args, **kwargs)
def size_hint_multi_kernel(self, *args, **kwargs) -> Any:
from torch._inductor.codegen.multi_kernel import SizeHintMultiKernelCall
return SizeHintMultiKernelCall(*args, **kwargs)
def cpp(self, source_code: str):
kernel_code_log.info("CPP Kernel:\n%s", source_code)
if get_compile_threads() <= 1:
return CppCodeCache.load(source_code).kernel
else:
get_result = CppCodeCache.load_async(source_code, submit_fn=self.submit)
return LambdaFuture(lambda: get_result().kernel)
def cpp_pybinding(self, argtypes: list[str], source_code: str):
kernel_code_log.info("CPP+Bindings Kernel:\n%s", source_code)
if get_compile_threads() <= 1:
return CppPythonBindingsCodeCache.load_pybinding(argtypes, source_code)
else:
get_result = CppPythonBindingsCodeCache.load_pybinding_async(
argtypes, source_code, submit_fn=self.submit
)
return LambdaFuture(get_result)
def cuda(self, source_code, dst_file_ext, aot_compile=False):
kernel_code_log.info("CUDA Kernel:\n%s", source_code)
def task():
if aot_compile:
# We rely on JITInductor to compile the CUDA code,
# so that we can load it into AOTInductor.
output_path, *_ = CUDACodeCache.compile(source_code, "o")
CUDACodeCache.aot_kernels_o.append(output_path)
return CUDACodeCache.load(source_code, dst_file_ext)[0]
return self.submit(task)
def rocm(
self,
source_code,
dst_file_ext,
aot_compile=False,
):
kernel_code_log.info("ROCm Kernel:\n%s", source_code)
def task():
if aot_compile:
output_path, *_ = ROCmCodeCache.compile(source_code, dst_file_ext="o")
ROCmCodeCache.aot_kernels_o.append(output_path)
if config.rocm.generate_test_runner:
_ = ROCmCodeCache.compile(source_code, dst_file_ext="exe")
return ROCmCodeCache.load(source_code, dst_file_ext)[0]
return self.submit(task)
def halide(self, meta: HalideMeta, source_code: str):
kernel_code_log.info("Halide Kernel:\n%r\n%s", meta, source_code)
if get_compile_threads() <= 1:
return HalideCodeCache.generate_halide(meta, source_code)
else:
get_result = HalideCodeCache.generate_halide_async(
meta, source_code, submit_fn=self.submit
)
return LambdaFuture(get_result)
def cutedsl(self, kernel_name: str, source_code: str):
"""
Compile CuteDSL (CUTLASS Python DSL) kernels.
Args:
kernel_name: Name of the kernel to be defined
source_code: Source code of the CuteDSL kernel, as a string
Note:
CuteDSL currently requires source files to do its compilation, there we
use the PyCodeCache to write the source code to a file and load it.
"""
from torch._inductor.codegen.cutedsl.cutedsl_kernel import (
CuteDSLKernelWrapper,
MAIN_SUFFIX,
)
kernel_code_log.info("CuteDSL Kernel:\n%s", source_code)
def task():
key, path = torch._inductor.codecache.PyCodeCache.write(source_code)
mod = torch._inductor.codecache.PyCodeCache.load_by_key_path(key, path)
# Find our special entry point named function
main_func_name = f"{kernel_name}_{MAIN_SUFFIX}"
if not hasattr(mod, main_func_name):
available = [name for name in dir(mod) if callable(getattr(mod, name))]
raise RuntimeError(
f"Could not find CuteDSL main kernel function '{main_func_name}'. Available callables: {available}"
)
return CuteDSLKernelWrapper(getattr(mod, main_func_name), kernel_path=path)
if get_compile_threads() <= 1:
return task()
else:
future = self.submit(task)
return LambdaFuture(lambda: future.result())
def wait(self, scope: dict[str, Any]) -> None:
if get_compile_threads() > 1:
with dynamo_timed(
"async_compile.wait",
log_pt2_compile_event=True,
dynamo_compile_column_us="triton_compile_time_us",
log_waitcounter=True,
waitcounter_name_override="compile_triton",
):
self._wait_futures(scope)
_compile_end()
def _wait_futures(self, scope: dict[str, Any]) -> None:
kernels = {
key: value
for key, value in scope.items()
if isinstance(value, (Future, CodeCacheFuture))
}
pbar = tqdm(
total=len(kernels),
desc="Inductor Compilation",
disable=config.disable_progress,
delay=0,
)
for key, result in kernels.items():
if config.verbose_progress and not isinstance(pbar, _Faketqdm):
pbar.set_postfix_str(key)
try:
kernel = result.result()
scope[key] = kernel
except BrokenProcessPool as e:
raise RuntimeError(
"A compilation subprocess exited unexpectedly. This "
"is likely due to a crash. To facilitate debugging, "
"you can re-run with TORCHINDUCTOR_COMPILE_THREADS=1 "
"to cause compilation to occur in the main process."
) from e
pbar.update(1)
def maybe_warm_pool() -> None:
if (
os.environ.get("TORCH_TNT_IN_USE", "0") == "1"
or os.environ.get("TORCH_WARM_POOL", "1") != "1"
# The subprocess pool is only used for the Triton backend
or not has_triton_package()
# Skip for fbcode. We have internal reports of usages inside multiprocessing
# pools that lead a multiplicative number of compile subprocesses.
or config.is_fbcode()
):
return
AsyncCompile.warm_pool()
# TODO: This starts the SubprocPool's internal process pool as early as possible at
# the expense of creating a bunch of worker processes that might not be needed. We
# could start them lazily if we're willing to lose a small amount of compile time.
AsyncCompile.wakeup()
# On exit give the workers a chance to clean themselves up. Without this the
# resource_tracker can complain about leaked semaphores coming from the
# ProcessPoolExecutor:
# UserWarning: resource_tracker: There appear to be 5 leaked semaphore objects
# to clean up at shutdown
atexit.register(shutdown_compile_workers)