# mypy: allow-untyped-defs
from __future__ import annotations

import atexit
import ctypes
import dataclasses
import functools
import logging
import os
import pickle
import queue
import selectors
import subprocess
import sys
import time
import warnings
from collections.abc import Iterable, Sequence
from concurrent.futures import ThreadPoolExecutor
from ctypes import byref, c_size_t, c_void_p, CDLL
from typing import Any, Callable, IO, Optional, TYPE_CHECKING, Union

import torch
import torch._inductor.async_compile  # noqa: F401 required to warm up AsyncCompile pools
from torch._dynamo.device_interface import get_interface_for_device
from torch._dynamo.testing import rand_strided
from torch._inductor import ir
from torch._inductor.codecache import (
    CppCodeCache,
    CUDACodeCache,
    DLLWrapper,
    get_hash,
    PyCodeCache,
)
from torch._inductor.utils import get_gpu_type, get_ld_library_path, is_gpu
from torch._logging import getArtifactLogger
from torch.utils._ordered_set import OrderedSet


if TYPE_CHECKING:
    from types import ModuleType

    from torch._inductor.select_algorithm import TritonTemplateCaller

from . import config
from .runtime.benchmarking import benchmarker
from .virtualized import V


CUDA_VISIBLE_DEVICES = "CUDA_VISIBLE_DEVICES"

autotuning_log = getArtifactLogger(__name__, "autotuning")


class NonzeroWorkspaceNotSupportedError(Exception):
    pass


class TuningProcess:
    """
    Class to launch and interact with a benchmarking subprocess.
    """

    @staticmethod
    def process_main(read_pipe: IO[bytes], write_pipe: IO[bytes]) -> None:
        """
        Entry point for the child process.
        """
        autotuning_log.debug(
            "Started autotune subprocess %s. Visible devices: %s",
            os.getpid(),
            os.environ.get(CUDA_VISIBLE_DEVICES),
        )

        def workloop():
            while True:
                job = TuningProcess.recv(read_pipe)
                if job is None:
                    # None is a sentinel for the child to shut down
                    break
                try:
                    result = job()
                except Exception as e:
                    result = e
                TuningProcess.send(result, write_pipe)

        try:
            workloop()
        except EOFError:
            # The parent closed the pipe
            pass

    @staticmethod
    def send(obj: Any, write_pipe: IO[bytes]) -> None:
        pickle.dump(obj, write_pipe)
        write_pipe.flush()

    @staticmethod
    def recv(read_pipe: IO[bytes]) -> Any:
        return pickle.load(read_pipe)
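
    # The wire format is plain pickle over a pair of OS pipes: each put()
    # pickles one job and each get() unpickles one reply (the result, or the
    # exception the job raised), so requests and responses stay in lockstep.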

    def __init__(self, device: Optional[int]):
        self.device = device
        self.start()

    def start(self):
        """
        Start the benchmarking subprocess.
        """
        entry = os.path.join(os.path.dirname(__file__), "__autotune_main__.py")

        subproc_read_fd, write_fd = os.pipe()
        read_fd, subproc_write_fd = os.pipe()
        self.write_pipe = os.fdopen(write_fd, "wb")
        self.read_pipe = os.fdopen(read_fd, "rb")

        self.selector = selectors.DefaultSelector()
        self.selector.register(self.read_pipe, selectors.EVENT_READ)

        cmd = [
            sys.executable,
            entry,
            f"--parent={os.getpid()}",
            f"--read-fd={subproc_read_fd}",
            f"--write-fd={subproc_write_fd}",
        ]
        extra_env = {
            # We need to set the PYTHONPATH so the subprocess can find torch.
            "PYTHONPATH": os.environ.get(
                "TORCH_CUSTOM_PYTHONPATH", os.pathsep.join(sys.path)
            ),
            # We shouldn't be using the Triton async compile subprocess pool,
            # but as a precaution set the env var that disables its creation.
            "TORCH_WARM_POOL": "0",
            # Some internal usages need a modified LD_LIBRARY_PATH.
            "LD_LIBRARY_PATH": get_ld_library_path(),
            # This will cause the subprocs to profile using the profiler.
            "TORCHINDUCTOR_PROFILE_WITH_DO_BENCH_USING_PROFILING": "1"
            if config.profile_bandwidth_with_do_bench_using_profiling
            else "0",
        }
        if self.device is not None:
            extra_env[CUDA_VISIBLE_DEVICES] = str(self.device)
        self.process = subprocess.Popen(
            cmd,
            env={**os.environ, **extra_env},
            pass_fds=(subproc_read_fd, subproc_write_fd),
        )
        os.close(subproc_read_fd)
        os.close(subproc_write_fd)

        self.running = True

    def alive(self) -> bool:
        """
        True if the subprocess is still running.
        """
        return self.running and self.process.poll() is None

    def put(self, req: Any) -> None:
        """
        Push a work item to the child process.
        """
        if not self.alive():
            self.start()
        TuningProcess.send(req, self.write_pipe)

    def get(self, timeout: float = 120.0) -> Any:
        """
        Get a response from the child process. Raises TimeoutError on timeout;
        raises EOFError if the subprocess crashes.
        """
        try:
            if not self.selector.select(timeout):
                raise TimeoutError(f"Timeout in autotune subprocess {self.process.pid}")
            result = TuningProcess.recv(self.read_pipe)
        except TimeoutError:
            self.kill()
            raise
        except EOFError:
            # The subprocess crashed
            self.close()
            raise
        except Exception:
            autotuning_log.exception(
                "Unexpected exception in autotune subprocess %s", self.process.pid
            )
            self.kill()
            raise

        if isinstance(result, Exception):
            raise result
        return result

    def shutdown(self, wait: bool = True) -> None:
        """
        Signal the child process to shut down gracefully.
        """
        if self.alive():
            TuningProcess.send(None, self.write_pipe)
        if wait:
            self.wait()

    def wait(self) -> None:
        """
        Wait for the child process to exit.
        """
        if self.alive():
            self.process.wait()
        self.close()

    def close(self) -> None:
        """
        Close resources.
        """
        self.selector.close()
        self.read_pipe.close()
        self.write_pipe.close()
        self.running = False

    def kill(self) -> None:
        """
        Send a SIGKILL to the child process.
        """
        if self.alive():
            autotuning_log.error(
                "Sending SIGKILL to autotune subprocess %d",
                self.process.pid,
            )
            self.process.kill()
        self.close()


class TuningProcessPool:
    """
    Maintains a pool of TuningProcesses to benchmark kernels in parallel
    across devices. By default, we create one TuningProcess per device and
    set the sub-process environment to make only that device visible.
    """

    def __init__(self) -> None:
        """
        Start the child processes.
        """
        devices = self.get_device_list()
        autotuning_log.debug("Sub-process autotune device list: %s", devices)

        # Launch the child processes.
        self.processes = [TuningProcess(device=device) for device in devices]

        self.process_queue: queue.Queue[TuningProcess] = queue.Queue()
        for p in self.processes:
            self.process_queue.put(p)

        # Use a thread pool to manage distributing work to the subprocesses.
        # Threads block on an available process, so it makes sense to match
        # the number of threads with the number of devices.
        self.executor = ThreadPoolExecutor(max_workers=len(devices))

    @staticmethod
    def get_device_list() -> Sequence[Optional[int]]:
        """
        Gather the list of devices to be used in the pool.
        """
        if not config.autotune_multi_device:
            # Don't use multiple devices
            return [None]

        gpu_type = get_gpu_type()
        device_interface = get_interface_for_device(gpu_type)
        count = device_interface.device_count()

        # If the user specified the visible devices in the env, use those.
        if CUDA_VISIBLE_DEVICES in os.environ:
            devices = [int(d) for d in os.environ[CUDA_VISIBLE_DEVICES].split(",")]
            assert len(devices) <= count
            return devices

        return list(range(count))

    def shutdown(self) -> None:
        """
        Signal all child processes to exit.
        """
        self.executor.shutdown()

        for p in self.processes:
            p.shutdown(wait=False)
        for p in self.processes:
            p.wait()

    def target(self, choice: TritonTemplateCaller) -> float:
        """
        Entry point for the thread-pool helper threads: Wait for an open TuningProcess,
        remove it from the queue, execute the benchmark in that subprocess, and return
        the TuningProcess to the queue.
        """
        assert choice.bmreq is not None

        process = self.process_queue.get()
        process.put(choice.bmreq.benchmark)
        try:
            return process.get(
                config.max_autotune_subproc_result_timeout_seconds,
            )
        except TimeoutError:
            warnings.warn(
                f"Timed out benchmarking choice '{choice}'. It will be ignored. "
                "Please debug the root cause in case the choice can bring perf gains."
            )
            # Set to INF so this choice will be ignored
            return float("inf")
        except Exception:
            warnings.warn(
                f"Failed to benchmark choice '{choice}'. It will be ignored. "
                "Please debug the root cause in case the choice can bring perf gains."
            )
            # Set to INF so this choice will be ignored
            return float("inf")
        finally:
            self.process_queue.put(process)

    def benchmark(
        self,
        choices: list[TritonTemplateCaller],
    ) -> dict[TritonTemplateCaller, float]:
        """
        Benchmark each choice in a separate process.
        """

        # Use the ThreadPoolExecutor to spread the work across the subprocesses
        # and to grab subprocesses as soon as they're free.
        results = dict(zip(choices, self.executor.map(self.target, choices)))

        return results


LayoutOrBuffer = Union[ir.Layout, ir.Buffer]


@dataclasses.dataclass
class TensorMeta:
    device: torch.device
    dtype: torch.dtype
    sizes: torch._prims_common.ShapeType
    strides: torch._prims_common.StrideType
    offset: int
    name: Optional[str] = None

    @classmethod
    def from_irnodes(
        cls, irnodes: Union[LayoutOrBuffer, Sequence[LayoutOrBuffer]]
    ) -> Union[TensorMeta, list[TensorMeta]]:
        if isinstance(irnodes, Sequence):
            result: list[Any] = [cls.from_irnodes(x) for x in irnodes]
            assert all(isinstance(x, TensorMeta) for x in result)
            return result

        node = irnodes
        if isinstance(node, ir.Layout):
            node = ir.Buffer(name="fake", layout=node)

        dtype = node.get_dtype()
        assert dtype is not None
        device = node.get_device()
        assert device is not None

        return TensorMeta(
            device=device,
            dtype=dtype,
            sizes=V.graph.sizevars.size_hints(
                node.get_size(),
                fallback=config.unbacked_symint_fallback,
            ),
            strides=V.graph.sizevars.size_hints(
                node.get_stride(),
                fallback=config.unbacked_symint_fallback,
            ),
            offset=V.graph.sizevars.size_hint(
                node.get_layout().offset,
                fallback=config.unbacked_symint_fallback,
            ),
            name=node.get_name(),
        )

    def to_tensor(self) -> torch.Tensor:
        return rand_strided(
            self.sizes,
            self.strides,
            device=self.device,
            dtype=self.dtype,
            extra_size=self.offset,
        )
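
    # For example (a hedged sketch; the sizes and strides are made up), the
    # parent process pickles the metadata and the child rebuilds a real
    # tensor with matching layout but random contents:
    #
    #   meta = TensorMeta(
    #       device=torch.device("cuda:0"),
    #       dtype=torch.float16,
    #       sizes=(128, 256),
    #       strides=(256, 1),
    #       offset=0,
    #   )
    #   buf = meta.to_tensor()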


@dataclasses.dataclass
class BenchmarkRequest:
    """
    Only handles Triton template benchmarks for now. Extern kernel benchmarks
    can be done in the same process since they usually don't cause crashes.

    Important: Instances of this class and subclasses have to be serializable
    across process boundaries. Do not put CUDA Tensors in here!
    """

    def __init__(
        self,
        kernel_name: str,
        input_tensor_meta: Union[TensorMeta, list[TensorMeta]],
        output_tensor_meta: Union[TensorMeta, list[TensorMeta]],
        extra_args: Iterable[Any],
    ) -> None:
        # the kernel name defined in the module
        self.kernel_name = kernel_name

        if isinstance(input_tensor_meta, TensorMeta):
            input_tensor_meta = [input_tensor_meta]
        self.input_tensor_meta = input_tensor_meta

        if isinstance(output_tensor_meta, (tuple, list)):
            if len(output_tensor_meta) > 1:
                # For grouped GEMM, every output must share the same meta.
                assert all(
                    getattr(output_tensor_meta[0], attr) == getattr(x, attr)
                    for x in output_tensor_meta
                    for attr in ["device", "dtype", "sizes", "strides", "offset"]
                )
            output_tensor_meta = output_tensor_meta[0]
        self.output_tensor_meta = output_tensor_meta

        self.extra_args = extra_args

    def make_run_fn(
        self, *input_tensors: torch.Tensor, out: torch.Tensor
    ) -> Callable[[], None]:
        raise NotImplementedError

    def cleanup_run_fn(self) -> None:
        pass

    def do_bench(
        self,
        fn,
        *input_tensors: torch.Tensor,
        out: Optional[torch.Tensor] = None,
    ) -> float:
        raise NotImplementedError

    def benchmark(
        self,
        *input_tensors: torch.Tensor,
        out: Optional[torch.Tensor] = None,
    ) -> float:
        debug = autotuning_log.isEnabledFor(logging.DEBUG)
        if debug:
            start_ts = time.time()

        # create args and out tensor
        if out is None:
            assert len(input_tensors) == 0
            input_tensors = tuple(x.to_tensor() for x in self.input_tensor_meta)
            out = self.output_tensor_meta.to_tensor()

        if debug:
            create_tensor_elapse = time.time() - start_ts  # type: ignore[possibly-undefined]
            start_ts = time.time()
        try:
            fn = self.make_run_fn(*input_tensors, out=out)
        except NonzeroWorkspaceNotSupportedError:
            # Skip any op with a nonzero workspace requirement
            autotuning_log.info("Skipping op due to nonzero workspace requirement")
            return float("inf")

        if debug:
            load_elapse = time.time() - start_ts  # type: ignore[possibly-undefined]
            start_ts = time.time()

        res = self.do_bench(fn, *input_tensors, out)

        if debug:
            bench_elapse = time.time() - start_ts  # type: ignore[possibly-undefined]
            autotuning_log.debug(
                "InChildProcess %s: load %f, create tensor %f, bench %f",
                str(self),
                load_elapse,  # type: ignore[possibly-undefined]
                create_tensor_elapse,  # type: ignore[possibly-undefined]
                bench_elapse,
            )
        self.cleanup_run_fn()
        return res


class _TestBenchmarkRequest(BenchmarkRequest):
    """
    Supports unit testing. Defined in this file instead of the test file so the
    TuningProcess sub-process can unpickle these objects.
    """

    def __init__(
        self,
        result: float = 0.0,
        device: Optional[int] = None,
        sleep: Optional[float] = None,
        exc: Optional[Exception] = None,
        crash: bool = False,
    ):
        self.result = result
        self.device = device
        self.sleep = sleep
        self.exc = exc
        self.crash = crash

    def benchmark(
        self, *input_tensors: torch.Tensor, out: Optional[torch.Tensor] = None
    ) -> float:
        if self.device is not None:
            assert os.environ.get(CUDA_VISIBLE_DEVICES, None) == str(self.device)
        if self.sleep:
            time.sleep(self.sleep)
        if self.exc:
            raise self.exc
        if self.crash:
            sys.exit(1)
        return self.result


class GPUDeviceBenchmarkMixin:
    def do_bench(
        self,
        fn,
        *input_tensors: torch.Tensor,
        out: Optional[torch.Tensor] = None,
    ) -> float:
        device_idx_set = OrderedSet(
            tensor.device.index
            for tensor in [*input_tensors, out]
            if isinstance(tensor, torch.Tensor)
            and is_gpu(tensor.device.type)
            and tensor.device.index is not None
        )
        assert len(device_idx_set) <= 1, f"Cannot mix devices {device_idx_set}"
        device_type = next(
            (
                tensor.device.type
                for tensor in input_tensors
                if is_gpu(tensor.device.type)
            ),
            "cuda",
        )
        device_interface = get_interface_for_device(device_type)
        if len(device_idx_set) == 1:
            device_idx = next(iter(device_idx_set))
        else:
            device_idx = device_interface.current_device()
        with device_interface.device(device_idx):  # type: ignore[attr-defined]
            res = benchmarker.benchmark_gpu(fn)
            device_interface.synchronize()  # shake out any CUDA errors

        return res


class CPUDeviceBenchmarkMixin:
    def do_bench(
        self,
        fn,
        *input_tensors: torch.Tensor,
        out: Optional[torch.Tensor] = None,
    ) -> float:
        return benchmarker.benchmark_cpu(fn)


class TritonBenchmarkRequest(BenchmarkRequest):
    # Important: Instances of this class have to be serializable
    # across process boundaries. Do not put CUDA Tensors in here!
    def __init__(
        self,
        kernel_name: str,
        input_tensor_meta: Union[TensorMeta, list[TensorMeta]],
        output_tensor_meta: Union[TensorMeta, list[TensorMeta]],
        extra_args: Iterable[Any],
        module_path: str,  # the path of the module defining the triton kernel
        module_cache_key: str,
        num_stages: int,
        num_warps: int,
        num_consumer_groups: int = 0,
        num_buffers_warp_spec: int = 0,
        matrix_instr_nonkdim: int = 0,  # only used on HIP to choose the shape of the MFMA instruction
        waves_per_eu: int = 0,  # only used on HIP to schedule waves per execution unit
        kpack: int = 0,  # ROCm-specific GEMM parameter
    ) -> None:
        super().__init__(kernel_name, input_tensor_meta, output_tensor_meta, extra_args)
        self.module_path = module_path
        self.module_cache_key = module_cache_key
        self.num_stages = num_stages
        self.num_warps = num_warps
        self.num_consumer_groups = num_consumer_groups
        self.num_buffers_warp_spec = num_buffers_warp_spec
        self.matrix_instr_nonkdim = matrix_instr_nonkdim
        self.waves_per_eu = waves_per_eu
        self.kpack = kpack

    def make_run_fn(
        self, *input_tensors: torch.Tensor, out: torch.Tensor
    ) -> Callable[[], None]:
        mod = PyCodeCache.load_by_key_path(self.module_cache_key, self.module_path)
        autotuning_log.debug(
            "benchmark module key: %s, path: %s",
            self.module_cache_key,
            self.module_path,
        )

        run_method = getattr(mod, self.kernel_name).run
        extra_args = list(self.extra_args)
        run_method.__self__.with_bandwidth_info = False

        # Newer versions of Triton add a warmup argument to JITFunction.run.
        # This code handles backward compatibility.
        warmup_arg = {}
        import inspect

        if "warmup" in inspect.signature(run_method).parameters:
            warmup_arg["warmup"] = False

        if out.device.type == "cpu":
            stream = 0
        else:
            device_type = out.device.type
            device_interface = get_interface_for_device(device_type)
            stream = device_interface.get_raw_stream(
                self.output_tensor_meta.device.index
            )

        if isinstance(
            getattr(mod, self.kernel_name),
            torch._inductor.runtime.triton_heuristics.DebugAutotuner,
        ):
            return functools.partial(
                run_method,
                *input_tensors,
                out,
                *extra_args,
                **warmup_arg,
                stream=stream,
            )
        else:
            return functools.partial(
                run_method,
                *input_tensors,
                out,
                *extra_args,
                **warmup_arg,
                stream=stream,
                benchmark_run=True,
            )

    def precompile(self):
        mod = PyCodeCache.load_by_key_path(self.module_cache_key, self.module_path)
        getattr(mod, self.kernel_name).precompile()

    def __str__(self) -> str:
        return f"{self.kernel_name=}, {self.module_path=}, {self.module_cache_key=}"


class TritonGPUBenchmarkRequest(GPUDeviceBenchmarkMixin, TritonBenchmarkRequest):
    pass


class TritonCPUBenchmarkRequest(CPUDeviceBenchmarkMixin, TritonBenchmarkRequest):
    pass


class CUDABenchmarkRequest(GPUDeviceBenchmarkMixin, BenchmarkRequest):
    """
    A class to handle CUDA (CUTLASS) benchmark requests. This class manages
    the lifecycle of a CUDA kernel benchmark, including compiling the source
    code, managing workspace memory, and executing the kernel.

    Important: Instances of this class have to be serializable across
    process boundaries. Do not put CUDA Tensors in here!
    """

    def __init__(
        self,
        kernel_name: str,
        input_tensor_meta: Union[TensorMeta, list[TensorMeta]],
        output_tensor_meta: Union[TensorMeta, list[TensorMeta]],
        extra_args: Iterable[Any],
        source_code: str,
    ) -> None:
        super().__init__(kernel_name, input_tensor_meta, output_tensor_meta, extra_args)
        self.source_code = source_code
        self.workspace_size: int = 0
        self.workspace: Optional[torch.Tensor] = None
        self.DLL: Optional[DLLWrapper] = None
        self._workspace_size_updated = False
        self.hash_key: str = ""
        self.source_file: str = ""
        self.hash_key, self.source_file = CUDACodeCache.write(self.source_code, "so")

    def precompile(self):
        """
        Precompile the CUDA source code to populate the CUDACodeCache.
        This may happen in a separate thread pool.
        """
        autotuning_log.debug("Precompiling %s", self)
        CUDACodeCache.compile(self.source_code, "so")
        autotuning_log.debug("Done precompiling %s", self)

    def make_run_fn(
        self, *input_tensors: torch.Tensor, out: torch.Tensor
    ) -> Callable[[], None]:
        """
        Create a function to run the CUDA kernel with the given input and output tensors.
        """

        self.ensure_dll_loaded()
        self.update_workspace_size()
        args = [c_void_p(tensor.data_ptr()) for tensor in list(input_tensors) + [out]]
        autotuning_log.debug(
            "make_run_fn: self.kernel_name=%s, self.source_file=%s, self.hash_key=%s, self.DLL=%s, args=%s, self.extra_args=%s",
            self.kernel_name,
            self.source_file,
            self.hash_key,
            self.DLL,
            args,
            self.extra_args,
        )
        stream_ptr = c_void_p(torch.cuda.current_stream().cuda_stream)
        run_method = getattr(self.DLL, self.kernel_name)
        workspace_ptr = c_void_p(0)
        if self.workspace_size > 0:
            self.workspace = torch.zeros(
                (self.workspace_size + 7) // 8,
                dtype=torch.float64,
                device=out.device,
            )
            workspace_ptr = c_void_p(self.workspace.data_ptr())

        # Generate partial function.
        ret = functools.partial(
            run_method,
            *args,
            *self.extra_args,
            None,  # null workspace size ptr
            workspace_ptr,  # set workspace ptr
            stream_ptr,
        )

        # Sanity check to make sure we clean up the run fn properly.
        try:
            ret()
        except RuntimeError as e:
            err_msg = str(e)

            def raise_runtime_error():
                raise RuntimeError(err_msg)

            self.cleanup_run_fn()
            return raise_runtime_error

        return ret

    def update_workspace_size(self) -> None:
        if self._workspace_size_updated:
            return
        self.ensure_dll_loaded()
        unique_input_count = len(
            dict.fromkeys(meta.name for meta in self.input_tensor_meta)
        )
        args = [c_void_p(None) for _ in range(unique_input_count + 1)]
        stream_ptr = c_void_p(torch.cuda.current_stream().cuda_stream)

        run_method = getattr(self.DLL, self.kernel_name)
        # Retrieve workspace_size and initialize workspace.
        c_workspace_size = c_size_t()
        run_method(
            *args,  # input ptrs and output ptrs
            *self.extra_args,
            byref(
                c_workspace_size
            ),  # set workspace size ptr to retrieve workspace size
            None,  # null workspace ptr
            stream_ptr,
        )
        torch.cuda.synchronize()  # shake out any CUDA errors
        self.workspace_size = c_workspace_size.value
        autotuning_log.debug(
            "update_workspace_size called: new workspace size=%d, self.kernel_name=%s, self.source_file=%s, self.hash_key=%s, self.DLL=%s, args=%s, self.extra_args=%s",  # noqa: B950
            self.workspace_size,
            self.kernel_name,
            self.source_file,
            self.hash_key,
            self.DLL,
            args,
            self.extra_args,
        )
        self._workspace_size_updated = True

    def ensure_dll_loaded(self):
        if self.DLL is None:
            self.DLL, self.hash_key, self.source_file = CUDACodeCache.load(
                self.source_code, "so"
            )

    def cleanup_run_fn(self) -> None:
        if self.DLL is not None:
            self.DLL.close()
            self.DLL = None
        self.workspace = None

    def __str__(self) -> str:
        return f"{self.kernel_name=}, {self.source_file=}, {self.hash_key=}"


class CppBenchmarkRequest(CPUDeviceBenchmarkMixin, BenchmarkRequest):
    # Important: Instances of this class have to be serializable
    # across process boundaries. Do not put Tensors in here!

    def __init__(
        self,
        kernel_name: str,
        input_tensor_meta: Union[TensorMeta, list[TensorMeta]],
        output_tensor_meta: Union[TensorMeta, list[TensorMeta]],
        extra_args: Iterable[Any],
        source_code: str,
    ) -> None:
        super().__init__(kernel_name, input_tensor_meta, output_tensor_meta, extra_args)
        self.source_code = source_code
        self.hash_key = get_hash(source_code)
        self.DLL: Optional[Union[CDLL, ModuleType]] = None

    def precompile(self):
        # Prepopulate the CppCodeCache.
        # This may happen in a separate thread pool.
        autotuning_log.debug("Precompiling %s", self)
        CppCodeCache.load(self.source_code, device_type="cpu")
        autotuning_log.debug("Done precompiling %s", self)

    def make_run_fn(
        self, *input_tensors: torch.Tensor, out: torch.Tensor
    ) -> Callable[[], None]:
        # TODO(jgong5): use CppPythonBindingsCodeCache for better binding perf
        self.DLL = CppCodeCache.load(self.source_code, device_type="cpu")
        args = [tensor.data_ptr() for tensor in list(input_tensors) + [out]]
        autotuning_log.debug(
            "make_run_fn: self.kernel_name=%s, self.DLL=%s, args=%s, self.extra_args=%s",
            self.kernel_name,
            self.DLL,
            args,
            self.extra_args,
        )
        run_method = getattr(self.DLL, self.kernel_name)
        # Assume extra_args contains only sizes of type ctypes.c_ulonglong.
        assert all(isinstance(arg, ctypes.c_ulonglong) for arg in self.extra_args)
        run_method.argtypes = [ctypes.c_ulonglong] * (
            len(args) + len(list(self.extra_args))
        )

        # Generate partial function.
        return functools.partial(
            run_method,
            *args,
            *self.extra_args,
        )

    def cleanup_run_fn(self) -> None:
        if self.DLL is not None:
            # Check for the close attribute first: calling close() on some
            # DLL objects crashes on Windows.
            if hasattr(self.DLL, "close"):
                self.DLL.close()

    def __str__(self) -> str:
        return f"{self.kernel_name=}"


@functools.cache
def get_tuning_process_pool() -> TuningProcessPool:
    pool = TuningProcessPool()
    atexit.register(pool.shutdown)
    return pool


def benchmark_in_sub_process(
    choices: list[TritonTemplateCaller],
) -> dict[TritonTemplateCaller, float]:
    """
    Do benchmarking in a subprocess and return the perf number (latency).
    """
    return get_tuning_process_pool().benchmark(choices)
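
# Typical call site (a minimal sketch; `choices` is the list of
# TritonTemplateCaller candidates gathered during max-autotune):
#
#   timings = benchmark_in_sub_process(choices)
#   best = min(timings, key=timings.get)  # failed choices map to float("inf")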