Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-21 05:34:18 +08:00
Introduces a variant of size-hint multi-kernel: for novel runtime shapes, instead of performing full benchmarking to determine the optimal kernel, it selects one of several kernels pre-generated from the multi-kernel hints, based on similarity between the hint and runtime input/output shapes (L1 distance in log2 space). A rough sketch of this selection heuristic appears below. Some caveats/changes:

- Size-hint multi-kernel now only kicks in if the kernel has dynamic shapes.
- Pre-generation still only does a 1-d search over the specified hints, e.g. `matmul([s0, s1], [s1, s2])` with size hints `[64, 256]` only generates 2 kernels, based on tuning shapes ([64, 64], [64, 64]) and ([256, 256], [256, 256]). Extending this to a reasonable n-d search (via a user API?) is left as a future extension.

Benchmarking results, compared to multi-kernel with full benchmarking (hints 64, 4096) and to compiling with the ground-truth hint:

![Benchmarking results](https://github.com/user-attachments/assets/056cca48-c16a-4451-9b4a-fa13a7a058a9)

Full benchmarking doing worse is extremely weird, but we did see similar spikes in #156628.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163090
Approved by: https://github.com/bobrenjc93
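As a rough sketch of the selection heuristic described above (hypothetical helper names; the real logic lives in `SizeHintMultiKernelCall` in `torch/_inductor/codegen/multi_kernel.py` and may differ in detail), the idea is to compare the flattened runtime sizes against the sizes each pre-generated kernel was tuned for and pick the closest match in log2 space:

```python
import math


def pick_kernel(candidates: dict[str, list[int]], runtime_shape: list[int]) -> str:
    """Hypothetical sketch: pick the pre-generated kernel whose tuning sizes are
    closest to the runtime sizes, measured by L1 distance in log2 space."""

    def log2_l1(hint_shape: list[int]) -> float:
        return sum(
            abs(math.log2(max(h, 1)) - math.log2(max(r, 1)))
            for h, r in zip(hint_shape, runtime_shape)
        )

    return min(candidates, key=lambda name: log2_l1(candidates[name]))


# Two kernels tuned with hints 64 and 256; runtime sizes around 200 select the 256 variant.
kernels = {"kernel_hint_64": [64, 64], "kernel_hint_256": [256, 256]}
assert pick_kernel(kernels, [200, 180]) == "kernel_hint_256"
```

Because the distance is computed in log2 space, a runtime size of 200 is "closer" to a hint of 256 than to 64, which matches the intuition that kernels tuned for nearby orders of magnitude should transfer best.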
# mypy: allow-untyped-defs
from __future__ import annotations

import atexit
import contextlib
import functools
import json
import logging
import multiprocessing
import os
import re
import sys
from concurrent.futures import Future, ThreadPoolExecutor
from concurrent.futures.process import BrokenProcessPool
from functools import partial
from time import time, time_ns
from typing import Any, Callable, Optional, TYPE_CHECKING

import torch
from torch._dynamo.device_interface import get_registered_device_interfaces
from torch._dynamo.utils import (
    counters,
    dynamo_timed,
    get_metrics_context,
    set_feature_use,
)
from torch._inductor import config
from torch._inductor.codecache import (
    _load_triton_kernel_from_source,
    code_hash,
    CodeCacheFuture,
    CppCodeCache,
    CppPythonBindingsCodeCache,
    CUDACodeCache,
    HalideCodeCache,
    LambdaFuture,
    ROCmCodeCache,
    StaticAutotunerFuture,
    torch_key,
)
from torch._inductor.compile_worker.subproc_pool import (
    AnyPool,
    SubprocException,
    SubprocPool,
)
from torch._inductor.compile_worker.tracked_process_pool import (
    TrackedProcessPoolExecutor,
)
from torch._inductor.compile_worker.utils import _async_compile_initializer
from torch._inductor.runtime.compile_tasks import (
    _set_triton_ptxas_path,
    _worker_compile_triton,
)
from torch._inductor.utils import clear_on_fresh_cache
from torch._inductor.virtualized import V
from torch._utils_internal import log_triton_builds
from torch.hub import _Faketqdm, tqdm
from torch.utils._ordered_set import OrderedSet
from torch.utils._triton import has_triton_package


if TYPE_CHECKING:
    from torch._inductor.runtime.hints import HalideMeta
    from torch._inductor.runtime.triton_heuristics import CachingAutotuner

# timing metrics for time spent in the compilation
_cumulative_compile_time = 0.0
_t0: Optional[float] = None

kernel_code_log = torch._logging.getArtifactLogger(__name__, "kernel_code")

log = logging.getLogger(__name__)

_triton_kernel_metrics: Optional[dict[str, dict[str, Any]]] = None

size_hints_regex = re.compile(
    r"size_hints=(\{.*?\})",
)
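# For example (illustrative): a generated kernel's heuristics header may contain a line
# such as `size_hints={'x': 1024, 'r0_': 64}`; in that case
# size_hints_regex.search(source_code).group(1) returns "{'x': 1024, 'r0_': 64}".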


def pre_fork_setup():
    """
    Setup that must be done prior to forking with a process pool.
    """
    # ensure properties have been calculated before processes
    # are forked
    caching_device_properties()

    # Computing the triton key can be slow. If we call it before fork,
    # it will be cached for the forked subprocesses.
    from torch._inductor.runtime.triton_compat import HAS_TRITON, triton_key

    if HAS_TRITON:
        triton_key()


def caching_device_properties():
    for _, device_interface in get_registered_device_interfaces():
        if device_interface.is_available():
            device_interface.Worker.get_device_properties()


def _compile_start() -> None:
    global _t0, _triton_kernel_metrics
    if _t0 is None:
        _t0 = time()
    if _triton_kernel_metrics is None:
        _triton_kernel_metrics = {}


def _compile_end() -> None:
    global _cumulative_compile_time, _t0, _triton_kernel_metrics
    if _t0 is not None:
        t1 = time()
        _cumulative_compile_time += t1 - _t0
        _t0 = None
        # print("CUMULATIVE COMPILE TIME", _cumulative_compile_time)
    if _triton_kernel_metrics:
        # Log triton kernel info
        sorted_info = dict(sorted(_triton_kernel_metrics.items()))
        torch._logging.trace_structured(
            "artifact",
            metadata_fn=lambda: {
                "name": "triton_kernel_info",
                "encoding": "json",
            },
            payload_fn=lambda: json.dumps(sorted_info),
        )
    _triton_kernel_metrics = None


def _add_triton_kernel_info(kernel_name: str, info: dict[str, Any]):
    global _triton_kernel_metrics
    # Must be called between _compile_start and _compile_end
    if _triton_kernel_metrics is not None:
        _triton_kernel_metrics[kernel_name] = info


_IS_WINDOWS = sys.platform == "win32"

log = logging.getLogger(__name__)

# Used to keep track of all process pools invoked so far.
_pool_set = OrderedSet[AnyPool]()


def shutdown_compile_workers() -> None:
    """Shut down all outstanding compile-worker pools."""
    for pool in _pool_set:
        pool.shutdown()
    AsyncCompile._ready_future = None
    after_fork()


def after_fork():
    """Reset pools to initial state without shutting them down."""
    _pool_set.clear()
    AsyncCompile.process_pool.cache_clear()


try:
    os.register_at_fork(after_in_child=after_fork)
except AttributeError:
    pass  # register_at_fork does not exist on Windows


def get_compile_threads() -> int:
    """
    Temporary for internal rollout. Assign config.compile_threads lazily and return it.
    TODO: remove after rollout.
    """
    if config.compile_threads is None:
        config.compile_threads = config.decide_compile_threads()
    return config.compile_threads


@clear_on_fresh_cache
class CompiledTritonKernels:
    """
    In-memory cache for storing compiled triton kernels.

    Each triton kernel is keyed by the hash of its source code. Each value stored
    in the cache is a return value of AsyncCompile.triton().

    Currently, the cache stores Future objects, but it should be generalizable to any kernel.
    """

    _cache: dict[str, CodeCacheFuture] = {}

    @staticmethod
    def key(kernel_src: str):
        """
        Generates a cache key given a triton kernel's full source code.
        This source includes the inductor meta, compilation metadata, the kernel itself, etc.
        `kernel_src` should be the exact string passed as the first argument to async_compile.triton().
        """
        # Hashes the kernel source with torch_key into a single hash key
        return code_hash(kernel_src, extra=torch_key())

    @staticmethod
    def save(kernel_src: str, future: CodeCacheFuture):
        """
        Saves a compiled triton kernel to the cache.
        TODO: We store a LambdaFuture as that's the callable returned by async_compile.triton,
        but the real type we want to return here is actually an abstract triton kernel.

        TODO: The source code here is not just the kernel's source code, but also includes the
        inductor preamble, etc., so the key could be less strict.
        """
        key = CompiledTritonKernels.key(kernel_src)
        CompiledTritonKernels._cache[key] = future

    @staticmethod
    def get(kernel_src: str) -> Optional[CodeCacheFuture]:
        key = CompiledTritonKernels.key(kernel_src)
        return CompiledTritonKernels._cache.get(key, None)

    @staticmethod
    def cache_clear():
        CompiledTritonKernels._cache = {}

    @staticmethod
    def remove_future(kernel_src: str) -> None:
        key = CompiledTritonKernels.key(kernel_src)

        # Delete the LambdaFuture if there is one
        if key in CompiledTritonKernels._cache:
            del CompiledTritonKernels._cache[key]
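# Minimal usage sketch (illustrative), assuming `kernel_src` is the exact string that
# would be passed to async_compile.triton():
#
#     if (future := CompiledTritonKernels.get(kernel_src)) is None:
#         future = ...  # e.g. the LambdaFuture produced by AsyncCompile.triton()
#         CompiledTritonKernels.save(kernel_src, future)
#     kernel = future.result()
#
# AsyncCompile.triton() below follows this pattern and drops the entry via remove_future()
# once the compiled kernel has been materialized.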


@contextlib.contextmanager
def async_compile_pool_manager():
    """
    Context manager to quiesce the subproc pool at the end of compilation, i.e.,
    when dynamo is done.
    """
    try:
        yield
    finally:
        AsyncCompile.quiesce()
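# Minimal usage sketch (illustrative): wrap an entire compilation so the subprocess pool
# is quiesced once dynamo is done. Calling wakeup() eagerly inside the block is an
# assumption for illustration; quiesce() only takes effect for a SubprocPool and only
# when config.quiesce_async_compile_pool is set.
#
#     with async_compile_pool_manager():
#         AsyncCompile.wakeup()   # optionally start the pool's executor eagerly
#         ...                     # compile; kernels get submitted to the pool
#     # on exit: AsyncCompile.quiesce()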


class AsyncCompile:
    """
    Utilities to compile in thread pools or subprocess pools (in the case of Triton).
    """

    _ready_future: Optional[Future[Any]] = None

    def __init__(self) -> None:
        pass

    @staticmethod
    @functools.lru_cache(1)
    def pool() -> ThreadPoolExecutor:
        assert get_compile_threads() > 1
        return ThreadPoolExecutor(get_compile_threads())

    @staticmethod
    def _get_ready():
        """No-op function to help mark when the subprocess pool is ready."""
        return "ready"

    @staticmethod
    @functools.lru_cache(1)
    def process_pool() -> AnyPool:
        assert get_compile_threads() > 1
        AsyncCompile._ready_future = None
        log.info(
            "Creating '%s' pool with %d workers",
            config.worker_start_method,
            get_compile_threads(),
        )

        pool: AnyPool
        if config.worker_start_method == "subprocess":
            # Wrapper around a ProcessPoolExecutor that runs in a new process we control
            pool = SubprocPool(get_compile_threads())
        else:
            if config.worker_start_method == "spawn":
                # Avoid creating pools in the spawned subprocs themselves:
                os.environ["TORCH_WARM_POOL"] = "0"
            pre_fork_setup()
            ctx = multiprocessing.get_context(config.worker_start_method)
            pool = TrackedProcessPoolExecutor(
                get_compile_threads(),
                mp_context=ctx,
                initializer=partial(_async_compile_initializer, os.getpid()),
            )
            # when this pool is created in a subprocess object, the normal exit handler
            # doesn't run, and we need to register our own handler.
            # exitpriority has to be high, because another one of the finalizers will
            # kill the worker thread that sends the shutdown message to the workers...
            multiprocessing.util.Finalize(None, pool.shutdown, exitpriority=sys.maxsize)

        _pool_set.add(pool)
        return pool

    @classmethod
    def warm_pool(cls) -> None:
        if get_compile_threads() <= 1:
            return
        _compile_start()
        # Pool is created on first access. Note for a SubprocPool, the sidecar process starts,
        # but its ProcessPoolExecutor does not initialize until a wakeup() call or the first
        # job is submitted.
        cls.process_pool()
        _compile_end()

    @classmethod
    def wait_pool_ready(cls, timeout=120) -> None:
        cls.use_process_pool()
        if cls._ready_future is not None:
            cls._ready_future.result(timeout=timeout)

    @classmethod
    def submit(cls, task: Callable[..., Any]) -> Any:
        if get_compile_threads() <= 1:
            return task()
        return cls.pool().submit(task)

    @classmethod
    def use_process_pool(cls):
        if get_compile_threads() <= 1:
            return False

        # Create a dummy job to check if the pool is ready. Submit it here instead of at
        # pool creation so we don't launch the full pool of worker subprocesses until
        # we're sure they're needed.
        if not cls._ready_future:
            cls._ready_future = cls.process_pool().submit(cls._get_ready)
        return cls._ready_future.done()

    @classmethod
    def quiesce(cls) -> None:
        """
        If using a SubprocPool, signal the sidecar process to shut down its
        ProcessPoolExecutor.
        """
        # Don't inadvertently create a process pool if it doesn't already exist:
        if not cls.process_pool.cache_info().currsize:
            return
        if config.quiesce_async_compile_pool:
            pool = cls.process_pool()
            if isinstance(pool, SubprocPool):
                pool.quiesce()

    @classmethod
    def wakeup(cls) -> None:
        """
        If using a SubprocPool, signal the sidecar process to start up its
        ProcessPoolExecutor.
        """
        if not cls.use_process_pool():
            return
        pool = cls.process_pool()
        if isinstance(pool, SubprocPool):
            pool.wakeup()

    def triton(self, kernel_name: str, source_code: str, device_str: str = "cuda"):
        """
        Async_compile.triton is more complicated than the other backends because
        we're trying to optimize compile time as much as possible for this hot callsite.

        First of all, the function is cached by CompiledTritonKernels; if there's a kernel
        already compiled, we grab it directly from the cache and return.

        Otherwise, if we have multiple compile threads, we kick off triton compilations on each
        worker process by giving it a kernel and source code to compile. The worker initializes
        a CachingAutotuner, runs triton compilation, and pickles the kernel back to us.
        We use TritonCompileResult to represent the objects being pickled back to us by each
        worker.

        Some perhaps non-obvious things that are pickled back to us:
        - Most of the time, we can avoid sending back CachingAutotuner.fn and other metadata
        and do not have to pay the cost of loading the triton kernel on the parent. But certain
        cases, like coordesc tuning and dynamic_scale_rblock, require us to reload the function
        in the parent lazily when we need it.
        - The AutotuneCache, if enabled, is constructed on each worker per triton config
        and pickled back to us via `CachingAutotuner.save_cache_hook`.
        """
        load_kernel = functools.partial(
            _load_triton_kernel_from_source, kernel_name, source_code
        )

        def reload_kernel_in_parent():
            # Benchmark how often this happens
            with dynamo_timed("reload_kernel_in_parent"):
                return load_kernel()

        counters["inductor"]["async_compile_cache_miss"] += 1

        kernel_code_log.info("Triton Kernel:\n%s", source_code)
        _compile_start()

        if os.environ.get("TRITON_INTERPRET", "0") == "1":
            return getattr(
                torch._inductor.codecache.PyCodeCache.load(source_code), kernel_name
            )

        is_parallel = self.use_process_pool()
        set_feature_use("parallel_compile_post_warmup", is_parallel)

        compile_id = torch._guards.CompileContext.current_compile_id()
        is_backward = getattr(V.graph, "is_backward", False)

        if (future := CompiledTritonKernels.get(source_code)) is not None:
            counters["inductor"]["async_compile_cache_hit"] += 1
            # Set reload_kernel_from_src properly based on source_code
            if isinstance(future, StaticAutotunerFuture):
                # Remove the future now that we've cache hit
                CompiledTritonKernels.remove_future(source_code)
                future.reload_kernel_from_src = reload_kernel_in_parent
            if is_parallel:
                return future
            else:
                return future.result()

        # Cache miss
        if is_parallel:
            # We want to support changing these env vars after (and while) the
            # process pool is running, so pass them to the subprocess to reset.
            env_vars = ["TORCHINDUCTOR_CACHE_DIR", "TRITON_CACHE_DIR"]
            extra_env = {v: os.environ[v] for v in env_vars if v in os.environ}
            extra_config = {
                "use_static_cuda_launcher": torch._inductor.config.use_static_cuda_launcher
            }

            if len(torch._inductor.config.autotune_lookup_table) > 0:
                m = size_hints_regex.search(source_code)
                if m:
                    size_hints_str = m.group(1)
                else:
                    size_hints_str = str(None)

                triton_src = source_code.split("@triton.jit\n")[1]
                from torch._inductor.runtime.triton_heuristics import (
                    generate_lookup_hash_from_source_code,
                )

                fn_hash = generate_lookup_hash_from_source_code(
                    size_hints_str, triton_src
                )

                if fn_hash in torch._inductor.config.autotune_lookup_table:
                    extra_config["autotune_lookup_table"] = {  # type: ignore[assignment]
                        fn_hash: torch._inductor.config.autotune_lookup_table[fn_hash]
                    }

            task = self.process_pool().submit(
                _worker_compile_triton,
                load_kernel,
                extra_env,
                extra_config,
            )

            def get_result() -> CachingAutotuner:
                try:
                    kernel, elapsed_us = task.result()
                except SubprocException as e:
                    raise e.with_name(kernel_name) from e

                # Now that we've compiled, we should clear the future
                # so it can't be used again
                kernel.set_compile_info(compile_id, is_backward)
                CompiledTritonKernels.remove_future(source_code)

                kernel.restore_after_unpickle(old_values=None)

                kernel.precompile(
                    warm_cache_only=False,
                    reload_kernel=reload_kernel_in_parent,
                    static_triton_bundle_key=CompiledTritonKernels.key(source_code),
                )
                info = kernel.autotune_cache_info or {}
                info["compile_time_us"] = elapsed_us
                _add_triton_kernel_info(kernel_name, info)
                get_metrics_context().add_top_n(
                    "triton_kernel_compile_times_us", kernel_name, elapsed_us
                )
                return kernel

            future = LambdaFuture(get_result, future=task)
            CompiledTritonKernels.save(source_code, future)
            return future
        else:
            with dynamo_timed(
                "async_compile.precompile",
                log_pt2_compile_event=True,
                dynamo_compile_column_us="triton_compile_time_us",
                log_waitcounter=True,
                waitcounter_name_override="compile_triton",
            ):
                fail = None
                try:
                    start_ns = time_ns()
                    _set_triton_ptxas_path()
                    kernel = load_kernel()
                    kernel.set_compile_info(compile_id, is_backward)
                    kernel.precompile(
                        warm_cache_only=False,
                        static_triton_bundle_key=CompiledTritonKernels.key(source_code),
                    )
                    elapsed_us = (time_ns() - start_ns) // 1000
                    get_metrics_context().add_top_n(
                        "triton_kernel_compile_times_us", kernel_name, elapsed_us
                    )
                    info = kernel.autotune_cache_info or {}
                    info["compile_time_us"] = elapsed_us
                    _add_triton_kernel_info(kernel_name, info)
                    return kernel
                except Exception as e:
                    fail = str(e)
                    raise
                finally:
                    log_triton_builds(fail=fail)

    def multi_kernel(self, *args, **kwargs) -> Any:
        from torch._inductor.codegen.multi_kernel import MultiKernelCall

        # no need to call this in parallel since the sub-kernels are already parallel tasks
        return MultiKernelCall(*args, **kwargs)

    def size_hint_multi_kernel(self, *args, **kwargs) -> Any:
        from torch._inductor.codegen.multi_kernel import SizeHintMultiKernelCall

        return SizeHintMultiKernelCall(*args, **kwargs)
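    # SizeHintMultiKernelCall (torch._inductor.codegen.multi_kernel) dispatches at runtime
    # among kernels pre-generated for different size hints, picking the candidate whose
    # tuning shapes are closest to the observed shapes (L1 distance in log2 space) rather
    # than benchmarking every candidate; see the PR description at the top of this page.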

    def cpp(self, source_code: str):
        kernel_code_log.info("CPP Kernel:\n%s", source_code)
        if get_compile_threads() <= 1:
            return CppCodeCache.load(source_code).kernel
        else:
            get_result = CppCodeCache.load_async(source_code, submit_fn=self.submit)
            return LambdaFuture(lambda: get_result().kernel)

    def cpp_pybinding(self, argtypes: list[str], source_code: str):
        kernel_code_log.info("CPP+Bindings Kernel:\n%s", source_code)
        if get_compile_threads() <= 1:
            return CppPythonBindingsCodeCache.load_pybinding(argtypes, source_code)
        else:
            get_result = CppPythonBindingsCodeCache.load_pybinding_async(
                argtypes, source_code, submit_fn=self.submit
            )
            return LambdaFuture(get_result)

    def cuda(self, source_code, dst_file_ext, aot_compile=False):
        kernel_code_log.info("CUDA Kernel:\n%s", source_code)

        def task():
            if aot_compile:
                # We rely on JITInductor to compile the CUDA code,
                # so that we can load it into AOTInductor.
                output_path, *_ = CUDACodeCache.compile(source_code, "o")
                CUDACodeCache.aot_kernels_o.append(output_path)
            return CUDACodeCache.load(source_code, dst_file_ext)[0]

        return self.submit(task)

    def rocm(
        self,
        source_code,
        dst_file_ext,
        aot_compile=False,
    ):
        kernel_code_log.info("ROCm Kernel:\n%s", source_code)

        def task():
            if aot_compile:
                output_path, *_ = ROCmCodeCache.compile(source_code, dst_file_ext="o")
                ROCmCodeCache.aot_kernels_o.append(output_path)
            if config.rocm.generate_test_runner:
                _ = ROCmCodeCache.compile(source_code, dst_file_ext="exe")
            return ROCmCodeCache.load(source_code, dst_file_ext)[0]

        return self.submit(task)

    def halide(self, meta: HalideMeta, source_code: str):
        kernel_code_log.info("Halide Kernel:\n%r\n%s", meta, source_code)
        if get_compile_threads() <= 1:
            return HalideCodeCache.generate_halide(meta, source_code)
        else:
            get_result = HalideCodeCache.generate_halide_async(
                meta, source_code, submit_fn=self.submit
            )
            return LambdaFuture(get_result)

    def cutedsl(self, kernel_name: str, source_code: str):
        """
        Compile CuteDSL (CUTLASS Python DSL) kernels.

        Args:
            kernel_name: Name of the kernel to be defined
            source_code: Source code of the CuteDSL kernel, as a string

        Note:
            CuteDSL currently requires source files to do its compilation, so we
            use the PyCodeCache to write the source code to a file and load it.
        """
        from torch._inductor.codegen.cutedsl.cutedsl_kernel import (
            CuteDSLKernelWrapper,
            MAIN_SUFFIX,
        )

        kernel_code_log.info("CuteDSL Kernel:\n%s", source_code)

        def task():
            key, path = torch._inductor.codecache.PyCodeCache.write(source_code)
            mod = torch._inductor.codecache.PyCodeCache.load_by_key_path(key, path)

            # Find our special entry point named function
            main_func_name = f"{kernel_name}_{MAIN_SUFFIX}"
            if not hasattr(mod, main_func_name):
                available = [name for name in dir(mod) if callable(getattr(mod, name))]
                raise RuntimeError(
                    f"Could not find CuteDSL main kernel function '{main_func_name}'. Available callables: {available}"
                )

            return CuteDSLKernelWrapper(getattr(mod, main_func_name), kernel_path=path)

        if get_compile_threads() <= 1:
            return task()
        else:
            future = self.submit(task)
            return LambdaFuture(lambda: future.result())

    def wait(self, scope: dict[str, Any]) -> None:
        if get_compile_threads() > 1:
            with dynamo_timed(
                "async_compile.wait",
                log_pt2_compile_event=True,
                dynamo_compile_column_us="triton_compile_time_us",
                log_waitcounter=True,
                waitcounter_name_override="compile_triton",
            ):
                self._wait_futures(scope)

        _compile_end()

    def _wait_futures(self, scope: dict[str, Any]) -> None:
        kernels = {
            key: value
            for key, value in scope.items()
            if isinstance(value, (Future, CodeCacheFuture))
        }
        pbar = tqdm(
            total=len(kernels),
            desc="Inductor Compilation",
            disable=config.disable_progress,
            delay=0,
        )
        for key, result in kernels.items():
            if config.verbose_progress and not isinstance(pbar, _Faketqdm):
                pbar.set_postfix_str(key)
            try:
                kernel = result.result()
                scope[key] = kernel
            except BrokenProcessPool as e:
                raise RuntimeError(
                    "A compilation subprocess exited unexpectedly. This "
                    "is likely due to a crash. To facilitate debugging, "
                    "you can re-run with TORCHINDUCTOR_COMPILE_THREADS=1 "
                    "to cause compilation to occur in the main process."
                ) from e
            pbar.update(1)


def maybe_warm_pool() -> None:
    if (
        os.environ.get("TORCH_TNT_IN_USE", "0") == "1"
        or os.environ.get("TORCH_WARM_POOL", "1") != "1"
        # The subprocess pool is only used for the Triton backend
        or not has_triton_package()
        # Skip for fbcode. We have internal reports of usages inside multiprocessing
        # pools that lead to a multiplicative number of compile subprocesses.
        or config.is_fbcode()
    ):
        return

    AsyncCompile.warm_pool()
    # TODO: This starts the SubprocPool's internal process pool as early as possible at
    # the expense of creating a bunch of worker processes that might not be needed. We
    # could start them lazily if we're willing to lose a small amount of compile time.
    AsyncCompile.wakeup()


# On exit give the workers a chance to clean themselves up. Without this the
# resource_tracker can complain about leaked semaphores coming from the
# ProcessPoolExecutor:
# UserWarning: resource_tracker: There appear to be 5 leaked semaphore objects
# to clean up at shutdown
atexit.register(shutdown_compile_workers)
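# Minimal usage sketch (illustrative; the kernel name and source below are hypothetical).
# Inductor's generated wrapper code drives this module roughly as follows: kernels are
# kicked off asynchronously up front and then materialized in bulk with wait() before use.
#
#     from torch._inductor.async_compile import AsyncCompile
#
#     async_compile = AsyncCompile()
#     triton_poi_fused_add_0 = async_compile.triton(
#         "triton_poi_fused_add_0", triton_kernel_source, device_str="cuda"
#     )
#     async_compile.wait(globals())  # replaces any futures in scope with compiled kernels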