# mypy: allow-untyped-defs
|
|
from __future__ import annotations
|
|
|
|
import builtins
|
|
import copy
|
|
import functools
|
|
import hashlib
|
|
import inspect
|
|
import logging
|
|
import math
|
|
import operator
|
|
import os
|
|
import os.path
|
|
import re
|
|
import sys
|
|
import threading
|
|
import time
|
|
from collections import namedtuple
|
|
from typing import Any, Callable, Optional, TYPE_CHECKING
|
|
|
|
import torch
|
|
from torch.utils._ordered_set import OrderedSet
|
|
|
|
from ..triton_bundler import TritonBundler
|
|
from ..utils import prefix_is_reduction
|
|
from . import triton_helpers
|
|
from .autotune_cache import AutotuneCache
|
|
from .benchmarking import benchmarker
|
|
from .coordinate_descent_tuner import CoordescTuner
|
|
from .hints import (
|
|
_NUM_THREADS_PER_WARP,
|
|
AutotuneHint,
|
|
DeviceProperties,
|
|
HeuristicType,
|
|
ReductionHint,
|
|
TileHint,
|
|
TRITON_MAX_BLOCK,
|
|
TRITON_MAX_RSPLIT,
|
|
)
|
|
from .runtime_utils import (
|
|
ceildiv,
|
|
conditional_product,
|
|
create_bandwidth_info_str,
|
|
dynamo_timed,
|
|
get_first_attr,
|
|
get_max_y_grid,
|
|
get_num_bytes,
|
|
next_power_of_2,
|
|
triton_cache_dir,
|
|
triton_config_to_hashable,
|
|
triton_hash_to_path_key,
|
|
validate_triton_config,
|
|
)
|
|
from .triton_compat import (
|
|
ASTSource,
|
|
autograd_profiler,
|
|
cc_warp_size,
|
|
CompiledKernel,
|
|
Config,
|
|
GPUTarget,
|
|
KernelInterface,
|
|
OutOfResources,
|
|
PTXASError,
|
|
triton,
|
|
)
|
|
|
|
|
|
if TYPE_CHECKING:
|
|
from collections.abc import Container, Hashable
|
|
|
|
LauncherType = Any
|
|
|
|
|
|
log = logging.getLogger(__name__)
|
|
|
|
|
|
def get_total_reduction_numel(numels: dict[str, int]) -> int:
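    # Multiplies only the reduction dimensions ("r0_", "r1_", ...) of `numels`.
    # Illustrative example (hypothetical values): {"x": 1024, "r0_": 64, "r1_": 2} -> 128.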
|
|
return conditional_product(
|
|
*[numel for prefix, numel in numels.items() if prefix_is_reduction(prefix)]
|
|
)
|
|
|
|
|
|
def autotune_hints_to_configs(
|
|
hints: OrderedSet[AutotuneHint],
|
|
size_hints,
|
|
block_size: int,
|
|
device_props: DeviceProperties,
|
|
) -> list[Config]:
|
|
"""
|
|
AutotuneHints can be attached to the metadata of triton kernels for providing
|
|
suggestions about what to try for autotuning. One reason to do this is if there are
|
|
some configs that are only useful in specific scenarios, in which case we can avoid
|
|
wasting compile time on autotuning unless we know we are in one of those scenarios.
|
|
|
|
Based on those hints, this function will generate a list of additional autotuning
|
|
configs to try.
|
|
"""
|
|
xyz_options: tuple[tuple[int, Optional[int], Optional[int]], ...]
|
|
configs: list[Config] = []
|
|
for hint in hints:
|
|
if hint == AutotuneHint.ONE_ELEMENT_PER_THREAD:
|
|
if len(size_hints) == 1:
|
|
xyz_options = ((block_size // 4, None, None),)
|
|
elif len(size_hints) == 2:
|
|
xyz_options = ((block_size // 4, 1, None), (1, block_size // 4, None))
|
|
elif len(size_hints) == 3:
|
|
xyz_options = (
|
|
(block_size // 4, 1, 1),
|
|
(1, block_size // 4, 1),
|
|
(1, 1, block_size // 4),
|
|
)
|
|
configs.extend(
|
|
triton_config(
|
|
size_hints,
|
|
*xyz,
|
|
num_elements_per_warp=(
|
|
device_props.warp_size if device_props.warp_size else 32
|
|
),
|
|
)
|
|
for xyz in xyz_options
|
|
)
|
|
|
|
return configs
|
|
|
|
|
|
def disable_pointwise_autotuning(inductor_meta):
|
|
# Autotuning can give different benchmarking results from run to run, and
|
|
# therefore we disable autotuning when use_deterministic flag is on.
|
|
if inductor_meta.get("are_deterministic_algorithms_enabled"):
|
|
return True
|
|
return not inductor_meta.get("autotune_pointwise", True)
|
|
|
|
|
|
def _dump_launch_params(args, kwargs, launcher, kernel_name):
|
|
call_args = []
|
|
call_kwargs = {}
|
|
for arg in args:
|
|
if isinstance(arg, (int, bool)):
|
|
call_args.append(str(arg))
|
|
else:
|
|
call_args.append("T")
|
|
for k, v in kwargs.items():
|
|
        if isinstance(v, (int, bool)):
|
|
call_kwargs[k] = v
|
|
else:
|
|
call_kwargs[k] = v
|
|
for k, v in launcher.config.kwargs.items():
|
|
call_kwargs[k] = v
|
|
call_kwargs["num_warps"] = launcher.config.num_warps
|
|
call_kwargs["num_stages"] = launcher.config.num_stages
|
|
args_str = ""
|
|
args_str += ", ".join(call_args)
|
|
for k, v in call_kwargs.items():
|
|
args_str += f", {k}={v}"
|
|
|
|
abs_path = os.path.abspath(sys.argv[0])
|
|
with open(f"{abs_path}.launch_params", "a") as f:
|
|
f.write(f"{kernel_name} | {args_str}\n")
|
|
|
|
|
|
class CachingAutotuner(KernelInterface):
|
|
"""
|
|
Simplified version of Triton autotuner that has no invalidation
|
|
key and caches the best config to disk to improve cold start times.
|
|
Unlike the main triton Autotuner, this version can precompile all
|
|
configs, and does not rely on the Triton JIT.
|
|
"""
|
|
|
|
def __init__(
|
|
self,
|
|
fn,
|
|
triton_meta, # passed directly to triton
|
|
configs,
|
|
save_cache_hook,
|
|
mutated_arg_names: list[str], # see [Note: clone mutated buffers]
|
|
optimize_mem,
|
|
heuristic_type,
|
|
size_hints=None,
|
|
inductor_meta=None, # metadata not relevant to triton
|
|
custom_kernel=False, # whether the kernel is inductor-generated or custom
|
|
filename: Optional[str] = None,
|
|
reset_to_zero_arg_names: Optional[list[str]] = None,
|
|
):
|
|
super().__init__()
|
|
|
|
assert len(configs) > 0, "Non-empty TritonConfig list required for compiling"
|
|
# makes sure there are no pre-hooks on any of the triton configs
|
|
for cfg in configs:
|
|
validate_triton_config(cfg)
|
|
|
|
self.fn = fn
|
|
self.device_props: DeviceProperties = triton_meta["device"]
|
|
self.triton_meta = {
|
|
**triton_meta,
|
|
"device": self.device_props.index,
|
|
"device_type": self.device_props.type,
|
|
}
|
|
self.inductor_meta = {} if inductor_meta is None else inductor_meta
|
|
self.save_cache_hook = save_cache_hook
|
|
self.mutated_arg_names = mutated_arg_names
|
|
self.reset_to_zero_arg_names = (
|
|
[] if reset_to_zero_arg_names is None else reset_to_zero_arg_names
|
|
)
|
|
self.optimize_mem = optimize_mem
|
|
self.configs = configs
|
|
self.heuristic_type = heuristic_type
|
|
self.custom_kernel = custom_kernel
|
|
self.cuda_kernel_saved = False
|
|
if log.isEnabledFor(logging.DEBUG):
|
|
log.debug(
|
|
"CachingAutotuner gets %d configs for %s",
|
|
len(self.configs),
|
|
self.fn.__name__,
|
|
)
|
|
for c in self.configs:
|
|
log.debug(c)
|
|
|
|
self.compile_results: list[TritonCompileResult] = []
|
|
self.launchers: list[LauncherType] = []
|
|
self.lock = threading.Lock()
|
|
if os.getenv("TRITON_CACHE_DIR") is None:
|
|
os.environ["TRITON_CACHE_DIR"] = triton_cache_dir(
|
|
self.triton_meta.get("device", 0)
|
|
)
|
|
log.debug("Triton cache dir: %s", os.environ["TRITON_CACHE_DIR"])
|
|
|
|
self.size_hints = size_hints
|
|
self.coordesc_tuner = CoordescTuner(
|
|
is_mm=False,
|
|
name=self.fn.__name__,
|
|
size_hints=size_hints,
|
|
inductor_meta=self.inductor_meta,
|
|
)
|
|
self.filename = filename
|
|
|
|
# used for profiling
|
|
self.kernel_hash: str = ""
|
|
|
|
# Kernels are stored in the codecache with the filename as a hash of the code.
|
|
# We rely on this to obtain the kernel hash
|
|
if self.filename is not None:
|
|
base_name = os.path.basename(self.filename)
|
|
if ".py" in base_name:
|
|
self.kernel_hash = os.path.splitext(base_name)[0]
|
|
|
|
self.precompile_time_taken_ns = 0
|
|
self.autotune_time_taken_ns = 0
|
|
# Dumps the launch configs after autotuning.
|
|
self.dump_launch_params = (
|
|
os.environ.get("TORCHINDUCTOR_DUMP_LAUNCH_PARAMS", "0") == "1"
|
|
)
|
|
|
|
self.triton_interpret = os.environ.get("TRITON_INTERPRET", "0") == "1"
|
|
|
|
def precompile(
|
|
self,
|
|
warm_cache_only=False,
|
|
reload_in_parent: Optional[Callable[[], CachingAutotuner]] = None,
|
|
):
|
|
if warm_cache_only:
|
|
self._precompile_worker()
|
|
return
|
|
with self.lock:
|
|
self._precompile_worker()
|
|
self._make_launchers()
|
|
self._dynamic_scale_rblock(reload_in_parent)
|
|
|
|
def _precompile_worker(self):
|
|
if self.compile_results:
|
|
return
|
|
assert not self.launchers
|
|
if not self.configs:
|
|
raise RuntimeError("No triton configs are available")
|
|
|
|
compile_results = []
|
|
exc = None
|
|
for c in self.configs:
|
|
try:
|
|
compile_results.append(self._precompile_config(c))
|
|
except (OutOfResources, PTXASError) as e:
|
|
exc = e
|
|
if len(compile_results) == 0:
|
|
raise RuntimeError(f"No valid triton configs. {type(exc).__name__}: {exc}")
|
|
self.compile_results = compile_results
|
|
self.configs = None
|
|
|
|
def _dynamic_scale_rblock(
|
|
self, reload_in_parent: Optional[Callable[[], CachingAutotuner]] = None
|
|
):
|
|
# TODO(jansel): we should find a way to move this extra compile into the worker process
|
|
# Currently it relies on _make_launchers(), which requires a cuda context, to populate nreg.
|
|
device_prop = self.device_props
|
|
if (
|
|
self.inductor_meta.get("dynamic_scale_rblock", True)
|
|
and not self.inductor_meta.get("persistent_reduction")
|
|
and self.heuristic_type == HeuristicType.REDUCTION
|
|
and self.size_hints is not None
|
|
# Disable for Intel as Triton is not ready to return n_regs for a compiled_binary.
|
|
and device_prop.type in ["cuda", "hip"]
|
|
and device_prop.major
|
|
and (device_prop.major >= 8 or torch.version.hip)
|
|
and device_prop.regs_per_multiprocessor is not None
|
|
):
|
|
assert device_prop.regs_per_multiprocessor
|
|
assert device_prop.max_threads_per_multi_processor
|
|
assert device_prop.multi_processor_count
|
|
seen_configs: OrderedSet[Config] = OrderedSet(self.configs)
|
|
warp_size = device_prop.warp_size or 32
|
|
for result in self.compile_results:
|
|
triton_config = result.config
|
|
compiled_binary = result.kernel
|
|
assert len(self.size_hints) >= 2
|
|
xblock = triton_config.kwargs.get("XBLOCK", 1)
|
|
reduction_kwargs = [
|
|
kwarg for kwarg in triton_config.kwargs if kwarg.startswith("R")
|
|
]
|
|
rblocks = [triton_config.kwargs[kwarg] for kwarg in reduction_kwargs]
|
|
total_block = (self.size_hints["x"] + xblock - 1) // xblock
|
|
nreg = getattr(compiled_binary, "n_regs", None)
|
|
if nreg is None:
|
|
continue
|
|
|
|
# make sure rblocks are not too small
|
|
if conditional_product(*rblocks) <= 64:
|
|
continue
|
|
|
|
# each SM of A100 has 65536 32-bit registers. To maximize
|
|
                # the theoretical occupancy, we need to run 2048 threads on each
|
|
# SM. So each thread should use no more than 65536 / 2048
|
|
# = 32 registers. In cases where occupancy matters, and each
|
|
# thread uses too many registers, reduce R0_BLOCK to reduce
|
|
# the register usage.
|
|
# For kernel https://gist.github.com/shunting314/e4cccc031fe30d378b9b23c08c238cbd
|
|
                # from PLBartForCausalLM, latency improves from
|
|
# 7.795ms to 4.883ms.
|
|
#
|
|
if (
|
|
nreg
|
|
<= device_prop.regs_per_multiprocessor
|
|
// device_prop.max_threads_per_multi_processor
|
|
):
|
|
continue
|
|
|
|
nreg_per_warp = nreg * warp_size
|
|
nreg_per_block = nreg_per_warp * triton_config.num_warps
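                # Worked example with hypothetical numbers: nreg=40, warp_size=32, num_warps=8
                # gives nreg_per_block = 40 * 32 * 8 = 10240, so an A100-class SM with 65536
                # registers fits at most 65536 // 10240 = 6 such blocks concurrently.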
|
|
|
|
                # Previously we set max_blocks_per_sm to 'max_threads_per_multi_processor / (32 * num_warps)'
|
|
# The formula below is a tighter upper bound since we have the assumption that
|
|
# nreg > device_prop.regs_per_multiprocessor // device_prop.max_threads_per_multi_processor
|
|
# due to the if condition above and:
|
|
# regs_per_multiprocessor / nreg_per_block
|
|
# = regs_per_multiprocessor / (nreg * 32 * num_warps)
|
|
# < regs_per_multiprocessor / ((regs_per_multiprocessor / max_threads_per_multi_processor) * 32 * num_warps)
|
|
# = max_threads_per_multi_processor / (32 * num_warps)
|
|
                # Using a tighter upper bound can reveal more optimization opportunities.
|
|
max_blocks_per_sm = max(
|
|
device_prop.regs_per_multiprocessor // nreg_per_block, 1
|
|
)
|
|
|
|
if total_block <= max_blocks_per_sm * device_prop.multi_processor_count:
|
|
# no need to improve occupancy
|
|
continue
|
|
new_config = copy.deepcopy(triton_config)
|
|
|
|
# Reduce the largest Rn_BLOCK by a factor of 2.
|
|
largest_rkwarg: str = max(
|
|
reduction_kwargs, key=triton_config.kwargs.__getitem__
|
|
)
|
|
new_config.kwargs[largest_rkwarg] //= 2
|
|
|
|
if new_config in seen_configs:
|
|
continue
|
|
seen_configs.add(new_config)
|
|
log.debug(
|
|
"Dynamically scale down %s from TritonConfig(%s) and get a new TritonConfig(%s)",
|
|
largest_rkwarg,
|
|
triton_config,
|
|
new_config,
|
|
)
|
|
if self.fn.fn is None:
|
|
"""
|
|
We are in the parent process, while this program was compiled in a worker
|
|
and the fn was dropped in prepare_for_pickle(). We haven't loaded the module
|
|
containing the real fn yet.
|
|
"""
|
|
assert reload_in_parent
|
|
self.fn = reload_in_parent().fn
|
|
self.compile_results.append(self._precompile_config(new_config))
|
|
|
|
self._make_launchers()
|
|
|
|
def _make_launchers(self):
|
|
if len(self.launchers) == len(self.compile_results):
|
|
return
|
|
|
|
from torch._dynamo.device_interface import DeviceGuard
|
|
|
|
device_interface = self.get_device_interface()
|
|
|
|
# load binary to the correct device
|
|
with DeviceGuard(device_interface, self.triton_meta["device"]):
|
|
# need to initialize context
|
|
device_interface.synchronize(device_interface.current_device())
|
|
launchers = []
|
|
exc = None
|
|
for result in self.compile_results:
|
|
try:
|
|
launchers.append(result.make_launcher())
|
|
except (OutOfResources, PTXASError) as e:
|
|
exc = e
|
|
if len(launchers) == 0:
|
|
raise RuntimeError(f"No valid triton configs. {type(exc).__name__}: {exc}")
|
|
self.launchers = launchers
|
|
|
|
def prepare_for_pickle(self):
|
|
"""Drop stuff from triton.JITFunction that does not pickle.
|
|
This must be called after precompile so that these things are no longer needed.
|
|
"""
|
|
self.fn.fn = None
|
|
self.fn.__globals__ = None
|
|
self.fn.used_global_vals = None
|
|
self.fn.repr = _ConstRepr(self.fn.repr(self.fn))
|
|
self.launchers = []
|
|
|
|
def __getstate__(self) -> dict[str, Any]:
|
|
assert (
|
|
not self.launchers
|
|
), "pickle should not be called with after make_launchers()"
|
|
return {
|
|
**self.__dict__,
|
|
"lock": None,
|
|
}
|
|
|
|
def __setstate__(self, state: dict[str, Any]) -> None:
|
|
self.__dict__.update(state)
|
|
self.lock = threading.Lock()
|
|
|
|
def get_device_interface(self):
|
|
# this code cannot run in compile workers, because it imports from torch
|
|
from torch._dynamo.device_interface import get_interface_for_device
|
|
|
|
return get_interface_for_device(self.device_props.type.replace("hip", "cuda"))
|
|
|
|
def _precompile_config(self, cfg: Config) -> TritonCompileResult:
|
|
"""Ahead of time compile a given autotuner config."""
|
|
compile_meta = copy.deepcopy(self.triton_meta)
|
|
cfg_kwargs = cfg.kwargs
|
|
if self.device_props.type == "hip":
|
|
cfg_kwargs = {**cfg_kwargs}
|
|
for k in ("matrix_instr_nonkdim", "waves_per_eu", "kpack"):
|
|
if k in cfg_kwargs:
|
|
compile_meta[k] = cfg_kwargs.pop(k)
|
|
compile_meta["constants"].update(cfg_kwargs)
|
|
compile_meta["num_warps"] = cfg.num_warps
|
|
compile_meta["num_stages"] = cfg.num_stages
|
|
compile_meta["debug"] = self.inductor_meta.get(
|
|
"assert_indirect_indexing", True
|
|
) and not self.inductor_meta.get("is_hip", False)
|
|
|
|
# device type will be "hip" rather than "cuda" here
|
|
compile_meta["device_type"] = self.device_props.type
|
|
compile_meta["cc"] = self.device_props.cc
|
|
|
|
if self.device_props.type == "cpu":
|
|
triton_helpers.set_driver_to_cpu()
|
|
else:
|
|
triton_helpers.set_driver_to_gpu()
|
|
|
|
if not ASTSource:
|
|
raise RuntimeError("Installed triton version too old, please upgrade")
|
|
|
|
compile_args = (
|
|
ASTSource(
|
|
self.fn,
|
|
compile_meta["signature"],
|
|
compile_meta["constants"],
|
|
compile_meta["configs"][0],
|
|
),
|
|
)
|
|
|
|
target = GPUTarget(
|
|
compile_meta["device_type"],
|
|
compile_meta["cc"],
|
|
cc_warp_size(compile_meta["cc"]),
|
|
)
|
|
|
|
options = {
|
|
"num_warps": compile_meta["num_warps"],
|
|
"num_stages": compile_meta["num_stages"],
|
|
"debug": compile_meta["debug"],
|
|
"sanitize_overflow": False, # turn off additional asserts added for overflow checks
|
|
}
|
|
if self.device_props.type == "hip":
|
|
if "waves_per_eu" in compile_meta:
|
|
options["waves_per_eu"] = compile_meta["waves_per_eu"]
|
|
if "matrix_instr_nonkdim" in compile_meta:
|
|
options["matrix_instr_nonkdim"] = compile_meta["matrix_instr_nonkdim"]
|
|
compile_kwargs = {
|
|
"target": target,
|
|
"options": options,
|
|
}
|
|
|
|
try:
|
|
binary = triton.compile(*compile_args, **compile_kwargs)
|
|
except Exception:
|
|
log.exception(
|
|
"Triton compilation failed: %s\n%s\nmetadata: %s",
|
|
self.inductor_meta.get("kernel_name", "triton_"),
|
|
self.fn.src,
|
|
compile_meta,
|
|
)
|
|
raise
|
|
|
|
TritonBundler.put(
|
|
triton_hash_to_path_key(binary.hash), self.triton_meta.get("device", 0)
|
|
)
|
|
return TritonCompileResult(binary, cfg, compile_meta, self.inductor_meta)
|
|
|
|
def bench(self, launcher, *args, grid, with_profiler=False, **kwargs):
|
|
"""Measure the performance of a given launcher"""
|
|
# we don't skip configs with spilled registers when auto-tuning custom
|
|
# (user-written) Triton kernels, as (i) we don't have any knowledge or
|
|
# control over the kernel code; (ii) there is empirical evidence that
|
|
# for some (complicated) custom Triton kernels, a register-spilling
|
|
# config may yield the best latency.
|
|
if not self.custom_kernel and launcher.n_spills > self.inductor_meta.get(
|
|
"spill_threshold", 16
|
|
):
|
|
log.debug(
|
|
"Skip config %s because of register spilling: %d",
|
|
launcher.config,
|
|
launcher.n_spills,
|
|
)
|
|
return float("inf")
|
|
|
|
device_interface = self.get_device_interface()
|
|
stream = device_interface.get_raw_stream(device_interface.current_device())
|
|
|
|
cpu_copies = self.copy_args_to_cpu_if_needed(*args, **kwargs)
|
|
|
|
def kernel_call():
|
|
cloned_args, cloned_kwargs = self.maybe_clone_args(
|
|
cpu_copies, *args, **kwargs
|
|
)
|
|
# reset to zero before evaluating any config
|
|
self.reset_to_zero_args(*args, **kwargs)
|
|
launcher(
|
|
*cloned_args,
|
|
**cloned_kwargs,
|
|
grid=grid,
|
|
stream=stream,
|
|
)
|
|
self.restore_args_from_cpu(cpu_copies)
|
|
|
|
if with_profiler:
|
|
from torch._inductor.utils import do_bench_using_profiling
|
|
|
|
return do_bench_using_profiling(kernel_call, warmup=10, rep=40)
|
|
|
|
if self.device_props.type == "cpu":
|
|
return benchmarker.benchmark_cpu(kernel_call)
|
|
|
|
return benchmarker.benchmark_gpu(kernel_call, rep=40)
|
|
|
|
def copy_args_to_cpu_if_needed(self, *args, **kwargs):
|
|
"""
|
|
To support benchmarking in the presence of mutated args, we need to avoid
|
|
        autotuning contaminating them. We try to pass cloned args to the kernel.
|
|
If those clones would increase the peak memory usage, however, we instead
|
|
        copy to cpu and restore them after each iteration. Figure out the args
|
|
to be copied and do the copying.
|
|
"""
|
|
if not self.optimize_mem:
|
|
return {}
|
|
|
|
copies = {}
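        # Budget is the headroom below the observed peak allocation; cloning args that
        # fit within it should not raise peak memory (see the docstring above).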
|
|
budget = torch.cuda.max_memory_allocated() - torch.cuda.memory_allocated()
|
|
|
|
def maybe_copy(name, arg):
|
|
if name in self.mutated_arg_names and arg.is_cuda:
|
|
nonlocal budget
|
|
assert isinstance(arg, torch.Tensor)
|
|
size = arg.numel() * arg.element_size()
|
|
if size > budget:
|
|
cpu_arg = torch.empty_strided(
|
|
arg.size(),
|
|
arg.stride(),
|
|
dtype=arg.dtype,
|
|
device="cpu",
|
|
pin_memory=True,
|
|
)
|
|
cpu_arg.copy_(arg, non_blocking=True)
|
|
copies[name] = (arg, cpu_arg)
|
|
else:
|
|
budget -= size
|
|
|
|
for i, arg in enumerate(args):
|
|
maybe_copy(self.fn.arg_names[i], arg)
|
|
|
|
for name, arg in kwargs.items():
|
|
maybe_copy(name, arg)
|
|
|
|
return copies
|
|
|
|
def restore_args_from_cpu(self, cpu_copies):
|
|
for pair in cpu_copies.values():
|
|
arg, cpu_arg = pair
|
|
arg.copy_(cpu_arg, non_blocking=True)
|
|
|
|
def reset_to_zero_args(self, *args, **kwargs):
|
|
if not self.reset_to_zero_arg_names:
|
|
return
|
|
for i, arg in enumerate(args):
|
|
if self.fn.arg_names[i] in self.reset_to_zero_arg_names:
|
|
assert isinstance(
|
|
arg,
|
|
torch.Tensor,
|
|
), "self.reset_to_zero_arg_names should only contain valid argument names"
|
|
arg.zero_()
|
|
|
|
for name, arg in kwargs.items():
|
|
if name in self.reset_to_zero_arg_names:
|
|
assert isinstance(
|
|
arg,
|
|
torch.Tensor,
|
|
), "self.reset_to_zero_arg_names should only contain valid argument names"
|
|
arg.zero_()
|
|
|
|
def maybe_clone_args(
|
|
self, exclude: Container[str], *args, **kwargs
|
|
) -> tuple[list[Any], dict[str, Any]]:
|
|
"""
|
|
Prepare new args and kwargs by cloning any in-place buffers
|
|
(that are not in the provided exclusion list), to avoid autotune
|
|
contaminating them. Avoid cloning the other buffers because it
|
|
leads to increased memory usage.
|
|
"""
|
|
from ..compile_fx import clone_preserve_strides
|
|
|
|
def prepare_arg(name, arg):
|
|
if name in self.mutated_arg_names and name not in exclude:
|
|
assert isinstance(arg, torch.Tensor)
|
|
return clone_preserve_strides(arg)
|
|
else:
|
|
return arg
|
|
|
|
cloned_args = [
|
|
prepare_arg(self.fn.arg_names[i], arg) for i, arg in enumerate(args)
|
|
]
|
|
cloned_kwargs = {name: prepare_arg(name, arg) for name, arg in kwargs.items()}
|
|
|
|
return cloned_args, cloned_kwargs
|
|
|
|
def clone_args(self, *args, **kwargs) -> tuple[list[Any], dict[str, Any]]:
|
|
return self.maybe_clone_args(OrderedSet(), *args, **kwargs)
|
|
|
|
def benchmark_all_configs(self, *args, **kwargs):
|
|
with dynamo_timed(
|
|
"CachingAutotuner.benchmark_all_configs",
|
|
log_pt2_compile_event=True,
|
|
metadata={"kernel_name": self.inductor_meta.get("kernel_name")},
|
|
# TODO(masnesral): Enable this when we figure out how to get the CompileId:
|
|
# dynamo_compile_runtime_column_us="runtime_triton_autotune_time_us",
|
|
):
|
|
timings = {
|
|
launcher: self.bench(launcher, *args, **kwargs)
|
|
for launcher in self.launchers
|
|
}
|
|
|
|
for k, v in timings.items():
|
|
self.coordesc_tuner.cache_benchmark_result(k.config, v)
|
|
|
|
if log.isEnabledFor(logging.DEBUG):
|
|
log.debug("Benchmark all input configs for %s, get:", self.fn.__name__)
|
|
for k, v in timings.items():
|
|
log.debug(
|
|
"%s: %f, nreg %d, nspill %d, #shared-mem %s",
|
|
k.config,
|
|
v,
|
|
k.n_regs,
|
|
k.n_spills,
|
|
k.shared,
|
|
)
|
|
|
|
self.reset_to_zero_args(*args, **kwargs)
|
|
return timings
|
|
|
|
def autotune_to_one_config(self, *args, **kwargs):
|
|
"""Do the actual autotuning"""
|
|
start_time = time.time_ns()
|
|
timings = self.benchmark_all_configs(*args, **kwargs)
|
|
benchmark_time_taken_ns = time.time_ns() - start_time
|
|
self.launchers = [builtins.min(timings, key=timings.get)]
|
|
self.autotune_time_taken_ns = (
|
|
self.precompile_time_taken_ns + benchmark_time_taken_ns
|
|
)
|
|
if self.save_cache_hook:
|
|
self.save_cache_hook(self.launchers[0].config, self.autotune_time_taken_ns)
|
|
|
|
def save_gpu_kernel(self, grid, stream, launcher):
|
|
if callable(grid):
|
|
grid_x, grid_y, grid_z = grid(launcher.config.kwargs)
|
|
else:
|
|
grid_x, grid_y, grid_z = grid
|
|
|
|
key = self.inductor_meta.get("kernel_name", None) # unique kernel name
|
|
assert key is not None, "kernel_name can not be None"
|
|
params = {
|
|
"mangled_name": (
|
|
launcher.bin.metadata.name
|
|
if hasattr(launcher.bin.metadata, "name")
|
|
else launcher.bin.metadata["name"]
|
|
),
|
|
"grid_x": grid_x,
|
|
"grid_y": grid_y,
|
|
"grid_z": grid_z,
|
|
"x_block": launcher.config.kwargs.get("XBLOCK", 1),
|
|
"y_block": launcher.config.kwargs.get("YBLOCK", None),
|
|
"z_block": launcher.config.kwargs.get("ZBLOCK", None),
|
|
"r_block": launcher.config.kwargs.get("RBLOCK", None),
|
|
"num_warps": (
|
|
launcher.bin.num_warps
|
|
if hasattr(launcher.bin, "num_warps")
|
|
else launcher.bin.metadata.num_warps
|
|
),
|
|
"shared_mem": (
|
|
launcher.bin.shared
|
|
if hasattr(launcher.bin, "shared")
|
|
else launcher.bin.metadata.shared
|
|
),
|
|
"stream": stream,
|
|
# User defined triton kernels will have arbitrary kwarg names
|
|
"meta": launcher.config.kwargs,
|
|
}
|
|
from torch._inductor.codecache import CudaKernelParamCache
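        # Pick the serialized binary format per backend: ROCm -> "hsaco", XPU -> "spv",
        # otherwise (CUDA) -> "cubin".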
|
|
|
|
bin_type = {"hip": "hsaco", "xpu": "spv"}.get(self.device_props.type, "cubin")
|
|
binary = launcher.bin.asm[bin_type]
|
|
CudaKernelParamCache.set(key, params, binary, bin_type)
|
|
|
|
self.cuda_kernel_saved = True
|
|
|
|
def coordinate_descent_tuning(self, launcher, *args, **kwargs):
|
|
"""
|
|
Coordinate descent tuning can be run with or without max-autotune.
|
|
|
|
The only difference between these two is the starting config for coordinate_descent tuning.
|
|
        E.g., assume regular autotune only gets one config C1, while max-autotune gets 4 configs C1, C2, C3, C4
        and max-autotune figures out that C3 is the best.

        Then if coordinate descent tuning is run with max-autotune disabled, it will start from C1;
        while if coordinate descent tuning is run with max-autotune enabled, it will start from C3.
|
|
"""
|
|
if (
|
|
self.heuristic_type == HeuristicType.TEMPLATE
|
|
or self.heuristic_type == HeuristicType.USER_AUTOTUNE
|
|
):
|
|
# skip triton template
|
|
return launcher
|
|
|
|
config2launcher = {launcher.config: launcher}
|
|
|
|
def benchmark_one_config(config):
|
|
with self.lock:
|
|
launcher = self._precompile_config(config).make_launcher()
|
|
config2launcher[config] = launcher
|
|
|
|
out = self.bench(launcher, *args, **kwargs)
|
|
log.debug(
|
|
"COORDESC: %s: %f, nreg %d, nspill %d, #shared-mem %d",
|
|
launcher.config,
|
|
out,
|
|
launcher.n_regs,
|
|
launcher.n_spills,
|
|
launcher.shared,
|
|
)
|
|
return out
|
|
|
|
assert not (
|
|
self.heuristic_type == HeuristicType.PERSISTENT_REDUCTION
|
|
and "R0_BLOCK" in launcher.config.kwargs
|
|
), "Coordinate descent tuner relies on the assumption that persistent reduction's triton config does not have R0_BLOCK"
|
|
start_time = time.time_ns()
|
|
best_config = self.coordesc_tuner.autotune(
|
|
benchmark_one_config, launcher.config, None
|
|
)
|
|
coordesc_time_taken_ns = time.time_ns() - start_time
|
|
best_config.found_by_coordesc = True
|
|
|
|
if self.save_cache_hook:
|
|
self.save_cache_hook(
|
|
best_config,
|
|
self.autotune_time_taken_ns + coordesc_time_taken_ns,
|
|
found_by_coordesc=True,
|
|
)
|
|
return config2launcher.get(best_config)
|
|
|
|
def run(
|
|
self, *args, grid, stream, benchmark_run=False, **kwargs
|
|
): # type:ignore[override]
|
|
if self.triton_interpret:
|
|
return self.fn[grid](
|
|
*args,
|
|
**kwargs,
|
|
**self.configs[0].kwargs,
|
|
)
|
|
|
|
if len(self.launchers) != 1:
|
|
if len(self.launchers) == 0:
|
|
start_time = time.time_ns()
|
|
self.precompile()
|
|
self.precompile_time_taken_ns = time.time_ns() - start_time
|
|
if len(self.launchers) > 1:
|
|
self.autotune_to_one_config(*args, grid=grid, **kwargs)
|
|
|
|
if not getattr(
|
|
self.launchers[0].config, "found_by_coordesc", False
|
|
) and self.inductor_meta.get("coordinate_descent_tuning", False):
|
|
self.launchers = [
|
|
self.coordinate_descent_tuning(
|
|
self.launchers[0], *args, grid=grid, **kwargs
|
|
)
|
|
]
|
|
|
|
(launcher,) = self.launchers
|
|
if launcher.store_cubin and (not benchmark_run or not self.cuda_kernel_saved):
|
|
self.save_gpu_kernel(grid, stream, launcher)
|
|
|
|
if self.dump_launch_params:
|
|
_dump_launch_params(args, kwargs, launcher, self.fn.__name__)
|
|
|
|
        # Checking the profiler flag directly here is faster than entering and exiting
        # a context manager, even if the context manager is a nullcontext.
|
|
if autograd_profiler._is_profiler_enabled:
|
|
# grid can be a tuple of ints or a string.
|
|
if isinstance(grid, tuple):
|
|
grid_info = str(grid)
|
|
else:
|
|
grid_info = getattr(grid, "grid_fn_str", "")
|
|
|
|
with torch._C._profiler._RecordFunctionFast(
|
|
self.inductor_meta.get("kernel_name", "triton kernel"),
|
|
args,
|
|
{
|
|
"kernel_file": (self.filename or ""),
|
|
"kernel_hash": self.kernel_hash,
|
|
"kernel_backend": "triton",
|
|
"grid": grid_info,
|
|
"stream": stream,
|
|
},
|
|
):
|
|
return launcher(
|
|
*args,
|
|
**kwargs,
|
|
grid=grid,
|
|
stream=stream,
|
|
)
|
|
else:
|
|
return launcher(
|
|
*args,
|
|
**kwargs,
|
|
grid=grid,
|
|
stream=stream,
|
|
)
|
|
|
|
|
|
class _ConstRepr:
|
|
def __init__(self, value: str):
|
|
self.value = value
|
|
|
|
def __call__(self, _=None) -> str:
|
|
return self.value
|
|
|
|
|
|
class TritonCompileResult:
|
|
"""
|
|
    Upstream Triton CompiledKernel cannot be pickled. This is a wrapper
    to support serialization and to generate the launcher function.
|
|
"""
|
|
|
|
@staticmethod
|
|
@functools.lru_cache(32)
|
|
def _kernel_metadata_cls(fields: tuple[str, ...]) -> Any:
|
|
return namedtuple("KernelMetadata", sorted(fields))
|
|
|
|
def __init__(
|
|
self,
|
|
kernel: CompiledKernel,
|
|
config: Config,
|
|
compile_meta: dict[str, Any],
|
|
inductor_meta: dict[str, Any],
|
|
) -> None:
|
|
super().__init__()
|
|
self.kernel = kernel
|
|
self.config = config
|
|
self.compile_meta = compile_meta
|
|
self.inductor_meta = inductor_meta
|
|
|
|
def __getstate__(self) -> dict[str, Any]:
|
|
kernel = self.kernel
|
|
# replace the fields that don't pickle nicely
|
|
kernel_state = {
|
|
**kernel.__dict__,
|
|
"metadata": kernel.metadata._asdict(),
|
|
"module": None, # regenerated by kernel._init_handles()
|
|
"function": None, # regenerated by kernel._init_handles()
|
|
"run": None, # regenerated by kernel._init_handles()
|
|
}
|
|
return {**self.__dict__, "kernel": kernel_state} # type: ignore[dict-item]
|
|
|
|
def __setstate__(self, state: dict[str, Any]) -> None:
|
|
# src = ASTSource.__new__(ASTSource)
|
|
# src.__setstate__(state["kernel"]["src"])
|
|
# TODO(jansel): need to fixup src.fn which is now None
|
|
kernel = CompiledKernel.__new__(CompiledKernel)
|
|
metadata = state["kernel"]["metadata"]
|
|
kernel.__dict__.update(
|
|
{
|
|
**state["kernel"],
|
|
# "src": src,
|
|
"metadata": self._kernel_metadata_cls(tuple(metadata.keys()))(
|
|
**metadata
|
|
),
|
|
}
|
|
)
|
|
self.__dict__.update(state)
|
|
self.kernel = kernel
|
|
|
|
def make_launcher(self) -> LauncherType:
|
|
"""
|
|
        Launching triton kernels is performance sensitive, so we compile
        a custom Python function that computes the grid() and reorders the args
        for the underlying wrapper.
|
|
"""
|
|
cfg = self.config
|
|
compile_meta = self.compile_meta
|
|
binary = self.kernel
|
|
fn = binary.src.fn
|
|
binary._init_handles()
|
|
"""
|
|
https://github.com/pytorch/pytorch/issues/115344
|
|
|
|
self.fn.constexprs doesn't properly deal with None args, so when we filter out
|
|
an arg in UserDefinedTritonKernel.codegen, we need to filter it here as well.
|
|
We also don't want to modify self.fn.
|
|
|
|
We know that we removed something from the signature if:
|
|
1. It's in compile_meta["constants"]
|
|
2. It isn't a constant we already know about
|
|
Note: The value of interest has already been added to compile_meta['constants'],
|
|
so we use self.fn.constexprs instead.
|
|
3. It isn't in the compile_meta signature
|
|
"""
|
|
known_constants = OrderedSet(
|
|
arg for i, arg in enumerate(fn.arg_names) if i in fn.constexprs
|
|
)
|
|
none_args = OrderedSet(
|
|
k
|
|
for k, v in compile_meta["constants"].items()
|
|
if v is None and k not in known_constants
|
|
)
|
|
none_args = none_args.difference(OrderedSet(compile_meta["signature"].keys()))
|
|
|
|
call_args = [
|
|
arg
|
|
for i, arg in enumerate(fn.arg_names)
|
|
if i not in fn.constexprs and arg not in none_args
|
|
]
|
|
|
|
def_args = [
|
|
name
|
|
for name in fn.arg_names
|
|
if name not in cfg.kwargs and name not in none_args
|
|
]
|
|
binary_shared = (
|
|
binary.shared if hasattr(binary, "shared") else binary.metadata.shared
|
|
)
|
|
|
|
scope = {
|
|
"grid_meta": cfg.kwargs,
|
|
"bin": binary,
|
|
"launch_enter_hook": binary.__class__.launch_enter_hook,
|
|
"launch_exit_hook": binary.__class__.launch_exit_hook,
|
|
"metadata": (
|
|
binary.packed_metadata
|
|
if hasattr(binary, "packed_metadata")
|
|
else binary.metadata
|
|
),
|
|
"shared": binary_shared,
|
|
"num_warps": (
|
|
binary.num_warps
|
|
if hasattr(binary, "num_warps")
|
|
else binary.metadata.num_warps
|
|
),
|
|
"cta_args": (
|
|
(
|
|
binary.num_ctas,
|
|
*get_first_attr(binary, "cluster_dims", "clusterDims"),
|
|
)
|
|
if hasattr(binary, "num_ctas")
|
|
else (
|
|
(binary.metadata.num_ctas, *binary.metadata.cluster_dims)
|
|
if hasattr(binary, "metadata")
|
|
else ()
|
|
)
|
|
),
|
|
"function": get_first_attr(binary, "function", "cu_function"),
|
|
"runner": get_first_attr(binary, "run", "c_wrapper"),
|
|
}
|
|
|
|
if not hasattr(binary, "launch_metadata"):
|
|
# launch args before CompiledKernel.launch_metadata is added.
|
|
# TODO(jansel): delete this branch in mid-2025
|
|
runner_args = [
|
|
"grid_0",
|
|
"grid_1",
|
|
"grid_2",
|
|
"num_warps",
|
|
"*cta_args",
|
|
"shared",
|
|
"stream",
|
|
"function",
|
|
"launch_enter_hook",
|
|
"launch_exit_hook",
|
|
"metadata",
|
|
*call_args,
|
|
]
|
|
else: # args after CompiledKernel.launch_metadata: https://github.com/openai/triton/pull/3492
|
|
# Getting the kernel launch args is extremely perf-sensitive. Evaluating
|
|
# `bin.launch_metadata` is relatively expensive, and returns None unless a
|
|
# `launch_enter_hook` is installed. So if we don't have that hook installed,
|
|
            # we want to burn None into the launch args with zero overhead.
|
|
# See https://github.com/pytorch/pytorch/issues/123597
|
|
if binary.__class__.launch_enter_hook:
|
|
launch_metadata = (
|
|
f"bin.launch_metadata(grid, stream, {', '.join(call_args)})"
|
|
)
|
|
else:
|
|
launch_metadata = "None"
|
|
runner_args = [
|
|
"grid_0",
|
|
"grid_1",
|
|
"grid_2",
|
|
"stream",
|
|
"function",
|
|
"metadata",
|
|
launch_metadata,
|
|
"launch_enter_hook",
|
|
"launch_exit_hook",
|
|
*call_args,
|
|
]
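        # The exec() below generates a launcher whose source looks roughly like this
        # (argument names are illustrative, and runner_args here corresponds to the
        # newer branch with launch_metadata burned in as None):
        #
        #   def launcher(in_ptr0, out_ptr0, xnumel, grid, stream):
        #       if callable(grid):
        #           grid_0, grid_1, grid_2 = grid(grid_meta)
        #       else:
        #           grid_0, grid_1, grid_2 = grid
        #       runner(grid_0, grid_1, grid_2, stream, function, metadata, None,
        #              launch_enter_hook, launch_exit_hook, in_ptr0, out_ptr0, xnumel)
        #       return bin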
|
|
|
|
exec(
|
|
f"""
|
|
def launcher({', '.join(def_args)}, grid, stream):
|
|
if callable(grid):
|
|
grid_0, grid_1, grid_2 = grid(grid_meta)
|
|
else:
|
|
grid_0, grid_1, grid_2 = grid
|
|
runner({', '.join(runner_args)})
|
|
return bin
|
|
""".lstrip(),
|
|
scope,
|
|
)
|
|
|
|
launcher = scope["launcher"]
|
|
launcher.config = cfg
|
|
launcher.n_regs = getattr(binary, "n_regs", None)
|
|
launcher.n_spills = getattr(binary, "n_spills", None)
|
|
launcher.shared = binary_shared
|
|
launcher.store_cubin = self.inductor_meta.get("store_cubin", False)
|
|
# store this global variable to avoid the high overhead of reading it when calling run
|
|
if launcher.store_cubin:
|
|
launcher.fn = fn
|
|
launcher.bin = binary
|
|
return launcher
|
|
|
|
|
|
def _find_names(obj):
|
|
import gc
|
|
import inspect
|
|
|
|
frame = inspect.currentframe()
|
|
while frame is not None:
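        # Accessing f_locals materializes each frame's local variables into a dict,
        # so gc.get_referrers() below can find the object through those dicts
        # (this appears to be the intent of the otherwise unused access).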
|
|
frame.f_locals
|
|
frame = frame.f_back
|
|
obj_names = []
|
|
for referrer in gc.get_referrers(obj):
|
|
if isinstance(referrer, dict):
|
|
for k, v in referrer.items():
|
|
if v is obj:
|
|
obj_names.append(k)
|
|
return obj_names
|
|
|
|
|
|
collected_calls: list[Any] = []
|
|
|
|
|
|
def start_graph():
|
|
collected_calls.clear()
|
|
|
|
|
|
def end_graph(output_file):
|
|
if len(collected_calls) == 0:
|
|
return
|
|
overall_time = sum(call[0] for call in collected_calls)
|
|
overall_gb = sum(call[1] for call in collected_calls)
|
|
cur_file = inspect.stack()[1].filename
|
|
summary_str = (
|
|
f"SUMMARY ({cur_file})\n"
|
|
f"{overall_time:.2f}ms \t {overall_gb:.2f} GB\t {overall_gb / (overall_time / 1e3):.2f}GB/s"
|
|
)
|
|
log.info(
|
|
"%s",
|
|
summary_str,
|
|
)
|
|
if output_file is not None:
|
|
# sort perf numbers in descending order, i.e. placing the
|
|
# most runtime-heavy kernels at the top of the list
|
|
sorted_calls = sorted(collected_calls, key=lambda c: float(c[0]), reverse=True)
|
|
try:
|
|
with open(output_file, "a") as file:
|
|
log.info(
|
|
"Save profile bandwidth results to %s",
|
|
output_file,
|
|
)
|
|
file.write("====================\n")
|
|
file.write(f"TRITON KERNELS BANDWIDTH INFO ({cur_file})\n")
|
|
for ms, num_gb, gb_per_s, kernel_name in sorted_calls:
|
|
# also display the runtime percentage for each kernel
|
|
percentage = f"{ms / overall_time * 100:.2f}%"
|
|
suffix = f" \t {percentage} \t {kernel_name}"
|
|
bw_info_str = create_bandwidth_info_str(
|
|
ms,
|
|
num_gb,
|
|
gb_per_s,
|
|
suffix=suffix,
|
|
color=False,
|
|
)
|
|
file.write(bw_info_str + "\n")
|
|
file.write(f"{summary_str}\n\n")
|
|
except Exception as e:
|
|
log.warning(
|
|
"failed to write profile bandwidth result into %s: %s",
|
|
output_file,
|
|
e,
|
|
)
|
|
|
|
|
|
class DebugAutotuner(CachingAutotuner):
|
|
def __init__(
|
|
self,
|
|
*args,
|
|
regex_filter="",
|
|
with_profiler=False,
|
|
with_bandwidth_info=True,
|
|
**kwargs,
|
|
):
|
|
self.regex_filter = regex_filter
|
|
self.with_profiler = with_profiler
|
|
self.with_bandwidth_info = with_bandwidth_info
|
|
super().__init__(*args, **kwargs)
|
|
self.cached = None
|
|
|
|
def run(self, *args, grid, stream, **kwargs):
|
|
if not self.with_bandwidth_info:
|
|
super().run(*args, grid=grid, stream=stream, **kwargs, benchmark_run=True)
|
|
return
|
|
else:
|
|
possible_names = _find_names(self)
|
|
kernel_name = f"{max(possible_names, key=len)}"
|
|
if not re.match(self.regex_filter, kernel_name):
|
|
return
|
|
|
|
if len(self.launchers) != 1:
|
|
if len(self.launchers) == 0:
|
|
start_time = time.time_ns()
|
|
self.precompile()
|
|
self.precompile_time_taken_ns = time.time_ns() - start_time
|
|
if len(self.launchers) > 1:
|
|
self.autotune_to_one_config(*args, grid=grid, **kwargs)
|
|
(launcher,) = self.launchers
|
|
|
|
if launcher.store_cubin:
|
|
self.save_gpu_kernel(grid, stream, launcher)
|
|
|
|
if self.cached is None:
|
|
ms = self.bench(
|
|
launcher, *args, grid=grid, with_profiler=self.with_profiler
|
|
)
|
|
num_in_out_ptrs = len(
|
|
[
|
|
arg_name
|
|
for arg_name in self.fn.arg_names
|
|
if arg_name.startswith("in_out_ptr")
|
|
]
|
|
)
|
|
num_gb = self.inductor_meta.get("kernel_num_gb", None)
|
|
if num_gb is None:
|
|
num_gb = get_num_bytes(*args, num_in_out_args=num_in_out_ptrs) / 1e9
|
|
gb_per_s = num_gb / (ms / 1e3)
|
|
self.cached = ms, num_gb, gb_per_s, kernel_name
|
|
collected_calls.append((ms, num_gb, gb_per_s, kernel_name))
|
|
log.info(
|
|
"%s",
|
|
create_bandwidth_info_str(
|
|
ms, num_gb, gb_per_s, suffix=f" \t {kernel_name}"
|
|
),
|
|
)
|
|
else:
|
|
# in AOTI, we will call the kernel and its timing info has been cached already
|
|
collected_calls.append(self.cached)
|
|
|
|
|
|
def hash_configs(configs: list[Config]):
|
|
"""
|
|
Hash used to check for changes in configurations
|
|
"""
|
|
hasher = hashlib.sha256()
|
|
for cfg in configs:
|
|
hasher.update(
|
|
f"{sorted(cfg.kwargs.items())} {cfg.num_warps} {cfg.num_stages}\n".encode()
|
|
)
|
|
return hasher.hexdigest()
|
|
|
|
|
|
def cached_autotune(
|
|
size_hints: Optional[list[int]],
|
|
configs: list[Config],
|
|
triton_meta,
|
|
heuristic_type,
|
|
filename=None,
|
|
inductor_meta=None,
|
|
custom_kernel=False,
|
|
):
|
|
"""
|
|
A copy of triton.autotune that calls our subclass. Our subclass
|
|
has additional debugging, error handling, and on-disk caching.
|
|
"""
|
|
configs = unique_configs(configs)
|
|
assert len(configs) == 1 or filename
|
|
inductor_meta = {} if inductor_meta is None else inductor_meta
|
|
|
|
disabled = inductor_meta.get("force_disable_caches", False)
|
|
|
|
# on disk caching logic and/or remote caching
|
|
autotune_cache = None
|
|
if (
|
|
not disabled
|
|
and filename is not None
|
|
and (len(configs) > 1 or inductor_meta.get("coordinate_descent_tuning"))
|
|
and not os.environ.get("TRITON_INTERPRET", "0") == "1"
|
|
):
|
|
configs_hash = hash_configs(configs)
|
|
|
|
autotune_cache = AutotuneCache.create(inductor_meta, filename, configs_hash)
|
|
if autotune_cache:
|
|
if best_config := autotune_cache.read_best(inductor_meta, configs):
|
|
configs = [best_config]
|
|
|
|
else:
|
|
if disabled:
|
|
log.debug("autotune caching is disabled by config.force_disable_caches")
|
|
|
|
mutated_arg_names = inductor_meta.pop("mutated_arg_names", ())
|
|
optimize_mem = inductor_meta.pop("optimize_mem", True)
|
|
|
|
if "restore_value" in triton_meta:
|
|
mutated_arg_names += triton_meta.pop("restore_value")
|
|
|
|
reset_to_zero_arg_names: list[str] = []
|
|
if "reset_to_zero" in triton_meta:
|
|
reset_to_zero_arg_names.extend(triton_meta.pop("reset_to_zero"))
|
|
|
|
def decorator(fn):
|
|
# Remove XBLOCK from config if it's not a function argument.
|
|
# This way, coordinate descent tuning will not try to tune it.
|
|
#
|
|
# Context: When TritonKernel.no_x_dim is True, we hardcode XBLOCK to 1.
|
|
import inspect
|
|
|
|
if "XBLOCK" not in inspect.signature(fn.fn).parameters:
|
|
for tconfig in configs:
|
|
if "XBLOCK" in tconfig.kwargs:
|
|
assert tconfig.kwargs["XBLOCK"] == 1
|
|
tconfig.kwargs.pop("XBLOCK")
|
|
|
|
if inductor_meta.get("profile_bandwidth"):
|
|
return DebugAutotuner(
|
|
fn,
|
|
triton_meta=triton_meta,
|
|
inductor_meta=inductor_meta,
|
|
regex_filter=inductor_meta["profile_bandwidth_regex"],
|
|
with_profiler=inductor_meta[
|
|
"profile_bandwidth_with_do_bench_using_profiling"
|
|
],
|
|
configs=configs,
|
|
save_cache_hook=autotune_cache and autotune_cache.save,
|
|
mutated_arg_names=mutated_arg_names,
|
|
reset_to_zero_arg_names=reset_to_zero_arg_names,
|
|
optimize_mem=optimize_mem,
|
|
heuristic_type=heuristic_type,
|
|
size_hints=size_hints,
|
|
custom_kernel=custom_kernel,
|
|
filename=filename,
|
|
with_bandwidth_info=True,
|
|
)
|
|
return CachingAutotuner(
|
|
fn,
|
|
triton_meta=triton_meta,
|
|
inductor_meta=inductor_meta,
|
|
configs=configs,
|
|
save_cache_hook=autotune_cache and autotune_cache.save,
|
|
mutated_arg_names=mutated_arg_names,
|
|
reset_to_zero_arg_names=reset_to_zero_arg_names,
|
|
optimize_mem=optimize_mem,
|
|
heuristic_type=heuristic_type,
|
|
size_hints=size_hints,
|
|
custom_kernel=custom_kernel,
|
|
filename=filename,
|
|
)
|
|
|
|
return decorator
|
|
|
|
|
|
def unique_configs(configs: list[Config]):
|
|
"""Remove duplicate configurations"""
|
|
seen: OrderedSet[Hashable] = OrderedSet()
|
|
pruned_configs = []
|
|
|
|
for cfg in configs:
|
|
key = triton_config_to_hashable(cfg)
|
|
if key not in seen:
|
|
seen.add(key)
|
|
pruned_configs.append(cfg)
|
|
return pruned_configs
|
|
|
|
|
|
def check_config(cfg, *, xnumel=None, ynumel=None, znumel=None):
|
|
for numel, label in zip((xnumel, ynumel, znumel), "XYZ"):
|
|
if numel is None:
|
|
continue
|
|
block = cfg[f"{label}BLOCK"]
|
|
if numel == 1:
|
|
assert block == 1, (
|
|
f"TritonKernel.indexing assumes numel == 1 => BLOCK == 1"
|
|
f" but {label.lower()}numel=={numel} and {label}BLOCK={block} (cfg={cfg})."
|
|
)
|
|
max_block = TRITON_MAX_BLOCK[label]
|
|
max_block_str = f'config.triton.max_block["{label}"]'
|
|
assert max_block % block == 0, (
|
|
f"TritonKernel.indexing assumes {label}BLOCK divides {max_block_str}"
|
|
f" but {label}BLOCK={block} and {max_block_str}={max_block} (cfg={cfg})."
|
|
)
|
|
|
|
|
|
def check_max_block(cfg: dict[str, int]):
|
|
"""
|
|
Check that block sizes are within the maximum allowed.
|
|
"""
|
|
for var, val in cfg.items():
|
|
block_suffix = "BLOCK"
|
|
if block_suffix in var:
|
|
prefix = var.removesuffix(block_suffix)
|
|
max_block = TRITON_MAX_BLOCK[prefix]
|
|
assert (
|
|
val <= max_block
|
|
), f"'{var}' too large. Maximum: {max_block}. Actual: {val}."
|
|
|
|
|
|
def _num_warps(num_warps, max_num_warps=8, min_num_warps=2, register_intensive=False):
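    # Clamps num_warps to [min_num_warps, max_num_warps] and rounds to a power of 2.
    # For example, with the NVIDIA defaults, _num_warps(5) clamps to 5 and then
    # next_power_of_2 rounds it up to 8.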
|
|
    # On AMD GPUs each warp has 64 lanes, which is double the size on NVIDIA GPUs,
    # so we correspondingly use half the number of warps here.
|
|
if torch.version.hip:
|
|
max_num_warps = (max_num_warps + 1) // 2
|
|
min_num_warps = (min_num_warps + 1) // 2
|
|
# persistent reduction is register intensive
|
|
if register_intensive:
|
|
max_num_warps = max_num_warps // 2
|
|
return next_power_of_2(min(max(num_warps, min_num_warps), max_num_warps))
|
|
|
|
|
|
def _check_max_grid_x(size_hints, x, num_warps):
|
|
# Check if maxGridSize is exceeded - if so then must scale XBLOCK further
|
|
max_grid_x = 2147483647
|
|
warp_size = (
|
|
64 if torch.version.hip else 32
|
|
) # TODO: query warp size once #129663 is merged
|
|
num_blocks = (size_hints["x"] + x - 1) // x
|
|
|
|
while (num_blocks * num_warps * warp_size) > max_grid_x and x < size_hints["x"]:
|
|
x *= 2 # Scale up XBLOCK if grid exceeds limits
|
|
num_blocks = num_blocks // 2
|
|
if (num_blocks * num_warps * warp_size) > max_grid_x:
|
|
raise AssertionError(
|
|
"Reduction config exceeds cudaDeviceProp maxGridSize. Please raise a pytorch issue"
|
|
)
|
|
return x, num_blocks
|
|
|
|
|
|
def triton_config(
|
|
size_hints,
|
|
x,
|
|
y=None,
|
|
z=None,
|
|
num_stages=1,
|
|
num_elements_per_warp=256,
|
|
min_elem_per_thread=0,
|
|
) -> Config:
|
|
"""
|
|
Construct a pointwise triton config with some adjustment heuristics
|
|
    based on size_hints. Size_hints is a dict mapping each tile dimension
    prefix to its numel, which will be rounded up to the nearest power of 2.
|
|
|
|
num_elements_per_warp is a suggestion for controlling how many warps
|
|
the triton config should contain. e.g.: if x=16, y=8, z=4 then
|
|
num_elements = 16*8*4 = 512. Then if we set num_elements_per_warp=128,
|
|
we'll launch 512 (elem) / 128 (elem/warp) = 4 warps. Note that it's
|
|
just a suggestion, and sometimes other adjustment heuristics will
|
|
override the num_elements_per_warp.
|
|
|
|
min_elem_per_thread controls the minimum number of elements
|
|
processed by each thread. It's always enforced.
|
|
"""
|
|
# Ideally we want to read this from some device config
|
|
|
|
maxGridSize = [2147483647, 65535, 65535]
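    # These are the CUDA-style grid limits (2**31 - 1, 65535, 65535) for (x, y, z);
    # as noted above, ideally they would be queried from the device.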
|
|
|
|
target = conditional_product(x, y, z)
|
|
if conditional_product(*size_hints.values()) < target:
|
|
target //= 8
|
|
|
|
# shrink sizes to size hints
|
|
x = min(x, size_hints["x"])
|
|
if y:
|
|
y = min(y, size_hints["y"])
|
|
if z:
|
|
z = min(z, size_hints["z"])
|
|
|
|
# if we are below original block size, scale up where we can;
|
|
# or if the calculated grid size is larger than the limit, we bump up the corresponding dimension
|
|
while x < min(size_hints["x"], TRITON_MAX_BLOCK["X"]) and (
|
|
x * maxGridSize[0] < size_hints["x"] or conditional_product(x, y, z) < target
|
|
):
|
|
x *= 2
|
|
while (
|
|
y
|
|
and y < min(size_hints["y"], TRITON_MAX_BLOCK["Y"])
|
|
and (
|
|
y * maxGridSize[1] < size_hints["y"]
|
|
or conditional_product(x, y, z) < target
|
|
)
|
|
):
|
|
y *= 2
|
|
while (
|
|
z
|
|
and z < min(size_hints["z"], TRITON_MAX_BLOCK["Z"])
|
|
and (
|
|
z * maxGridSize[2] < size_hints["z"]
|
|
or conditional_product(x, y, z) < target
|
|
)
|
|
):
|
|
z *= 2
|
|
|
|
num_warps = _num_warps(
|
|
conditional_product(x, y, z) // num_elements_per_warp, min_num_warps=1
|
|
)
|
|
    # we only arrive at 2 warps if bs was too small because numel was too small.
    # However, to work around some ptx bugs we still want at least 4 warps
    # when there are enough elements per thread. Given that this is a rare
    # situation, we don't expect this to affect perf in general.
|
|
# see https://github.com/pytorch/pytorch/pull/97950
|
|
if conditional_product(x, y, z) >= 128 and not torch.version.hip:
|
|
num_warps = max(num_warps, 4)
|
|
xnumel = size_hints["x"]
|
|
ynumel = size_hints.get("y")
|
|
znumel = size_hints.get("z")
|
|
|
|
# Increase x to satisfy min_elem_per_thread requirements.
|
|
block_size = max(
|
|
conditional_product(x, y, z),
|
|
min_elem_per_thread * _NUM_THREADS_PER_WARP * num_warps,
|
|
)
|
|
x *= math.ceil(block_size / conditional_product(x, y, z))
|
|
|
|
x, _num_blocks = _check_max_grid_x(size_hints, x, num_warps)
|
|
x = min(x, size_hints["x"])
|
|
|
|
cfg = {"XBLOCK": x}
|
|
if y:
|
|
cfg["YBLOCK"] = y
|
|
if z:
|
|
cfg["ZBLOCK"] = z
|
|
check_max_block(cfg)
|
|
check_config(cfg, xnumel=xnumel, ynumel=ynumel, znumel=znumel)
|
|
return Config(cfg, num_warps=num_warps, num_stages=num_stages)
|
|
|
|
|
|
def _get_nd_reduction_numels(r: int, size_hints: dict[str, int]) -> dict[str, int]:
|
|
"""
|
|
Converts a linear reduction numel to ND, in row major order.
|
|
This order is often desirable as it presents opportunities to coalesce memory
|
|
accesses.
|
|
For example, if r = 64 and size_hints = [32,32], this function returns [32, 2].
|
|
This unraveling works because both r and size_hints are powers of 2.
|
|
"""
|
|
# Shrink r to size_hints.
|
|
r = min(r, get_total_reduction_numel(size_hints))
|
|
num_reduction_dims = len(
|
|
[prefix for prefix in size_hints if prefix_is_reduction(prefix)]
|
|
)
|
|
|
|
remaining = r
|
|
rnumels = {}
|
|
for idx in range(num_reduction_dims - 1, -1, -1):
|
|
prefix = f"r{idx}_"
|
|
max_size = min(size_hints[prefix], TRITON_MAX_BLOCK[prefix.upper()])
|
|
dim = min(max_size, remaining)
|
|
assert (
|
|
remaining % dim == 0
|
|
), f"Expected dimension '{dim}' to divide remaining size '{remaining}'"
|
|
rnumels[prefix] = dim
|
|
remaining //= dim
|
|
|
|
# Sanity check the results.
|
|
final_numel = conditional_product(*rnumels.values())
|
|
assert (
|
|
r == final_numel
|
|
), f"Expected ND reduction size ({rnumels}) to have {r} elements."
|
|
assert all(
|
|
rnumels[prefix] <= size_hints[prefix] for prefix in rnumels
|
|
), f"rnumels exceed size_hints. {rnumels} > {size_hints}"
|
|
|
|
return rnumels
|
|
|
|
|
|
def triton_config_reduction(
|
|
size_hints,
|
|
x: int,
|
|
r: int,
|
|
num_stages=1,
|
|
num_warps=None,
|
|
register_intensive=False,
|
|
) -> Config:
|
|
"""
|
|
Construct a reduction triton config with some adjustment heuristics
|
|
    based on size_hints. Size_hints is a dict mapping each tile dimension
    prefix to its numel, which will be rounded up to the nearest power of 2.
|
|
"""
|
|
# Convert the linear reduction numel into a multi-dimensional block.
|
|
rnumels = _get_nd_reduction_numels(r, size_hints)
|
|
|
|
# shrink sizes to size hints
|
|
x = min(x, size_hints["x"])
|
|
|
|
def total_numel() -> int:
|
|
return conditional_product(x, *rnumels.values())
|
|
|
|
target = total_numel()
|
|
if conditional_product(*size_hints.values()) < target:
|
|
target //= 8
|
|
|
|
# if we are below original block size, scale up where we can
|
|
while x < size_hints["x"] and total_numel() < target:
|
|
x *= 2
|
|
for prefix in sorted(rnumels):
|
|
while rnumels[prefix] < size_hints[prefix] and total_numel() < target:
|
|
rnumels[prefix] *= 2
|
|
|
|
if num_warps is None:
|
|
num_warps = total_numel() // 128
|
|
num_warps = _num_warps(
|
|
num_warps, max_num_warps=16, register_intensive=register_intensive
|
|
)
|
|
|
|
x, _num_blocks = _check_max_grid_x(size_hints, x, num_warps)
|
|
|
|
for prefix in sorted(rnumels):
|
|
while total_numel() > target:
|
|
if rnumels[prefix] == 1:
|
|
break
|
|
rnumels[prefix] //= 2
|
|
|
|
cfg = _get_config({"x": x, **rnumels})
|
|
check_max_block(cfg)
|
|
check_config(cfg, xnumel=size_hints["x"])
|
|
return Config(cfg, num_warps=num_warps, num_stages=num_stages)
|
|
|
|
|
|
def _get_config(numels: dict[str, int]) -> dict[str, int]:
|
|
"""
|
|
Convert numels ("x", "r0_", etc.) to block sizes ("XBLOCK", "R0_BLOCK"), etc.
|
|
"""
|
|
|
|
return {prefix.upper() + "BLOCK": numel for prefix, numel in numels.items()}
|
|
|
|
|
|
def triton_config_tiled_reduction(size_hints, x, y, r, num_stages=1):
|
|
"""
|
|
Construct a tile reduction triton config with some adjustment
|
|
    heuristics based on size_hints. Size_hints is a dict mapping each tile
    dimension prefix to its numel, which will be rounded up to the nearest power of 2.
|
|
"""
|
|
# Convert the linear reduction numel into a multi-dimensional block.
|
|
rnumels = _get_nd_reduction_numels(r, size_hints)
|
|
|
|
# shrink sizes to size hints
|
|
x = min(x, size_hints["x"])
|
|
y = min(y, size_hints["y"])
|
|
|
|
def total_numel() -> int:
|
|
return conditional_product(x, y, *rnumels.values())
|
|
|
|
target = total_numel()
|
|
if conditional_product(*size_hints.values()) < target:
|
|
target //= 8
|
|
|
|
# if we are below original block size, scale up where we can
|
|
while x < size_hints["x"] and total_numel() < target:
|
|
x *= 2
|
|
for prefix in sorted(rnumels):
|
|
while rnumels[prefix] < size_hints[prefix] and total_numel() < target:
|
|
rnumels[prefix] *= 2
|
|
    while y < size_hints["y"] and total_numel() < target:
|
|
y *= 2
|
|
|
|
cfg = _get_config({"x": x, "y": y, **rnumels})
|
|
num_warps = _num_warps(total_numel() // 256, min_num_warps=1)
|
|
    check_config(cfg, xnumel=size_hints["x"], ynumel=size_hints["y"])
|
|
check_max_block(cfg)
|
|
return Config(cfg, num_warps=num_warps, num_stages=num_stages)
|
|
|
|
|
|
def pointwise(
|
|
size_hints,
|
|
triton_meta,
|
|
tile_hint=None,
|
|
filename=None,
|
|
min_elem_per_thread=0,
|
|
inductor_meta=None,
|
|
):
|
|
"""
|
|
Construct @triton.heuristics() based on size_hints.
|
|
"""
|
|
inductor_meta = {} if inductor_meta is None else inductor_meta
|
|
assert not inductor_meta.get("no_x_dim")
|
|
|
|
numel = functools.reduce(operator.mul, size_hints.values())
|
|
bs = max(256, min(numel // 128, 1024))
|
|
|
|
hinted_configs = autotune_hints_to_configs(
|
|
inductor_meta.get("autotune_hints", OrderedSet()),
|
|
size_hints,
|
|
bs,
|
|
triton_meta["device"],
|
|
)
|
|
|
|
triton_config_with_settings = functools.partial(
|
|
triton_config, min_elem_per_thread=min_elem_per_thread
|
|
)
|
|
|
|
configs = None
|
|
if len(size_hints) == 1:
|
|
if disable_pointwise_autotuning(inductor_meta) and not (
|
|
inductor_meta.get("max_autotune")
|
|
or inductor_meta.get("max_autotune_pointwise")
|
|
):
|
|
configs = [triton_config_with_settings(size_hints, bs)]
|
|
else:
|
|
configs = [
|
|
triton_config_with_settings(size_hints, bs, num_elements_per_warp=256),
|
|
triton_config_with_settings(
|
|
size_hints, bs // 2, num_elements_per_warp=64
|
|
),
|
|
*hinted_configs,
|
|
]
|
|
if len(size_hints) == 2:
|
|
if (
|
|
disable_pointwise_autotuning(inductor_meta) or tile_hint == TileHint.SQUARE
|
|
) and not (
|
|
inductor_meta.get("max_autotune")
|
|
or inductor_meta.get("max_autotune_pointwise")
|
|
):
|
|
configs = [triton_config_with_settings(size_hints, 32, 32)]
|
|
else:
|
|
configs = [
|
|
triton_config_with_settings(size_hints, 32, 32),
|
|
triton_config_with_settings(size_hints, 64, 64), # ~8% better for fp16
|
|
triton_config_with_settings(size_hints, 256, 16),
|
|
triton_config_with_settings(size_hints, 16, 256),
|
|
triton_config_with_settings(size_hints, bs, 1),
|
|
triton_config_with_settings(size_hints, 1, bs),
|
|
*hinted_configs,
|
|
]
|
|
if len(size_hints) == 3:
|
|
if disable_pointwise_autotuning(inductor_meta):
|
|
configs = [triton_config_with_settings(size_hints, 16, 16, 16)]
|
|
else:
|
|
configs = [
|
|
triton_config_with_settings(size_hints, 16, 16, 16),
|
|
triton_config_with_settings(size_hints, 64, 8, 8),
|
|
triton_config_with_settings(size_hints, 8, 64, 8),
|
|
triton_config_with_settings(size_hints, 8, 8, 64),
|
|
triton_config_with_settings(size_hints, bs, 1, 1),
|
|
triton_config_with_settings(size_hints, 1, bs, 1),
|
|
triton_config_with_settings(size_hints, 1, 1, bs),
|
|
*hinted_configs,
|
|
]
|
|
|
|
if not configs:
|
|
raise NotImplementedError(f"size_hints: {size_hints}")
|
|
return cached_autotune(
|
|
size_hints,
|
|
configs,
|
|
triton_meta=triton_meta,
|
|
inductor_meta=inductor_meta,
|
|
heuristic_type=HeuristicType.POINTWISE,
|
|
filename=filename,
|
|
)
|
|
|
|
|
|
def _reduction_configs(
|
|
*, size_hints: dict[str, int], inductor_meta: dict[str, Any]
|
|
) -> list[Config]:
|
|
reduction_hint = inductor_meta.get("reduction_hint", None)
|
|
|
|
# Convert reductions to 1D, to simplify heuristics.
|
|
rnumel = get_total_reduction_numel(size_hints)
|
|
|
|
register_intensive = False
|
|
MAX_R0_BLOCK = 2048
|
|
if (
|
|
size_hints["x"] >= 1024
|
|
and inductor_meta.get("num_load", 0) + inductor_meta.get("num_reduction", 0)
|
|
>= 10
|
|
):
|
|
        # A heuristic to reduce R0_BLOCK if a kernel potentially needs many registers.
        # Consider loads and reductions, since loads need to move data into registers
        # and reductions need an accumulator.
|
|
#
|
|
# The magic numbers are a bit arbitrary.
|
|
#
|
|
        # We cannot rely on dynamically scaling down R0_BLOCK later, since sometimes
        # triton ends up using fewer registers but with worse perf. Check:
|
|
# https://github.com/pytorch/pytorch/issues/126463
|
|
#
|
|
# The heuristic is a very simple one since registers can be reused. But
|
|
# hopefully it can be a good enough indicator.
|
|
MAX_R0_BLOCK = 1024
|
|
register_intensive = True
|
|
|
|
contiguous_config = triton_config_reduction(
|
|
size_hints,
|
|
1,
|
|
rnumel if 256 <= rnumel < MAX_R0_BLOCK else MAX_R0_BLOCK,
|
|
register_intensive=register_intensive,
|
|
)
|
|
outer_config = triton_config_reduction(
|
|
size_hints, 64, 8, register_intensive=register_intensive
|
|
)
|
|
tiny_config = triton_config_reduction(
|
|
size_hints,
|
|
2 * (256 // rnumel) if rnumel <= 256 else 1,
|
|
min(rnumel, MAX_R0_BLOCK),
|
|
register_intensive=register_intensive,
|
|
)
|
|
if inductor_meta.get("max_autotune") or inductor_meta.get("max_autotune_pointwise"):
|
|
pass # skip all these cases
|
|
elif reduction_hint == ReductionHint.INNER:
|
|
return [contiguous_config]
|
|
elif reduction_hint == ReductionHint.OUTER:
|
|
return [outer_config]
|
|
elif reduction_hint == ReductionHint.OUTER_TINY:
|
|
return [tiny_config]
|
|
if disable_pointwise_autotuning(inductor_meta):
|
|
return [triton_config_reduction(size_hints, 32, 128)]
|
|
return [
|
|
contiguous_config,
|
|
outer_config,
|
|
tiny_config,
|
|
triton_config_reduction(size_hints, 64, 64),
|
|
triton_config_reduction(size_hints, 8, 512),
|
|
# halve the XBLOCK/Rn_BLOCK compared to outer_config
|
|
# TODO: this may only be beneficial when each iteration of the reduction
|
|
# is quite heavy. E.g. https://gist.github.com/shunting314/189a8ef69f90db9d614a823385147a72
|
|
triton_config_reduction(size_hints, 64, 4, num_warps=8),
|
|
]
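
# Worked example for the heuristic above (inputs are assumed for illustration):
# with ``size_hints == {"x": 2048, "r0_": 4096}``, fewer than 10 loads plus
# reductions, ``reduction_hint == ReductionHint.INNER`` and max-autotune off,
# the kernel is not treated as register intensive, so MAX_R0_BLOCK stays 2048
# and only the contiguous config is returned:
#
#     [triton_config_reduction(size_hints, 1, 2048, register_intensive=False)]
#
# i.e. one x element per program, looping over the reduction 2048 elements at a time.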


def reduction(
    size_hints,
    reduction_hint=False,
    triton_meta=None,
    filename=None,
    inductor_meta=None,
):
    """args to @triton.heuristics()"""
    inductor_meta = {} if inductor_meta is None else inductor_meta
    inductor_meta["reduction_hint"] = reduction_hint
    if inductor_meta.get("no_x_dim"):
        size_hints["x"] = 1

    assert triton_meta is not None

    configs = _reduction_configs(size_hints=size_hints, inductor_meta=inductor_meta)
    return cached_autotune(
        size_hints,
        configs=configs,
        triton_meta=triton_meta,
        inductor_meta=inductor_meta,
        heuristic_type=HeuristicType.REDUCTION,
        filename=filename,
    )


def cooperative_reduction(
    size_hints,
    reduction_hint,
    triton_meta,
    filename,
    inductor_meta,
):
    inductor_meta = {} if inductor_meta is None else inductor_meta
    inductor_meta["reduction_hint"] = reduction_hint
    if inductor_meta.get("no_x_dim"):
        size_hints["x"] = 1

    # Cooperative reductions currently only support a single reduction dimension.
    assert (
        len(size_hints) == 2
    ), "Cooperative reductions don't support tiling reduction dims"
    xnumel, rnumel = size_hints["x"], size_hints["r0_"]

    # TODO(jansel): we should base target on the SM count of the local GPU
    target = 64
    split = max(1, min(target // xnumel, TRITON_MAX_RSPLIT))
    assert rnumel >= split
    assert split <= TRITON_MAX_RSPLIT
    if inductor_meta["persistent_reduction"]:
        configs = _persistent_reduction_configs(
            {"x": xnumel, "r0_": rnumel // split}, reduction_hint, inductor_meta
        )
    else:
        configs = _reduction_configs(
            size_hints={"x": xnumel, "r0_": rnumel // split},
            inductor_meta=inductor_meta,
        )
    for config in configs:
        config.kwargs["RSPLIT"] = split
    # TODO(jansel): add more configs in max_autotune

    return cached_autotune(
        size_hints,
        configs=configs,
        triton_meta=triton_meta,
        inductor_meta=inductor_meta,
        heuristic_type=HeuristicType.REDUCTION,
        filename=filename,
    )
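
# Sketch of the split computation above (shapes are assumed for illustration):
# for ``xnumel == 4`` and ``rnumel == 8192`` the target of 64 programs gives
#
#     split = max(1, min(64 // 4, TRITON_MAX_RSPLIT))  # 16 if TRITON_MAX_RSPLIT >= 16
#
# so the candidate configs are generated for a per-split reduction of
# ``8192 // 16 == 512`` elements and each gets ``RSPLIT=16``, letting 16 thread
# blocks cooperate on every row of the reduction.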


def _persistent_reduction_configs(
    size_hints,
    reduction_hint=False,
    inductor_meta=None,
):
    xnumel = size_hints["x"]
    rnumel = get_total_reduction_numel(size_hints)

    configs = [
        triton_config_reduction(size_hints, xblock, rnumel, register_intensive=True)
        for xblock in (1, 8, 32, 128)
        if xblock == 1 or (rnumel * xblock <= 4096 and xblock <= xnumel)
    ]

    # TODO(jansel): we should be able to improve these heuristics
    if reduction_hint == ReductionHint.INNER and rnumel >= 256:
        configs = configs[:1]
    elif reduction_hint == ReductionHint.OUTER:
        configs = configs[-1:]
    elif reduction_hint == ReductionHint.OUTER_TINY:
        configs = [
            triton_config_reduction(
                size_hints,
                2 * (256 // rnumel) if rnumel <= 256 else 1,
                rnumel,
            )
        ]
    for c in configs:
        # we don't need Rn_BLOCK for persistent reduction
        for prefix in size_hints:
            if prefix_is_reduction(prefix):
                c.kwargs.pop(f"{prefix.upper()}BLOCK")

    if disable_pointwise_autotuning(inductor_meta):
        configs = configs[:1]

    return configs
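
# Worked example for the candidate generation above (sizes are assumed): with
# ``size_hints == {"x": 1024, "r0_": 512}`` only ``xblock in (1, 8)`` satisfies
# ``rnumel * xblock <= 4096``, so two configs are produced and their reduction
# block kwargs are popped, since a persistent reduction materializes the whole
# row at once.  With ``reduction_hint == ReductionHint.INNER`` (rnumel >= 256)
# only the ``xblock == 1`` candidate would survive.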


def persistent_reduction(
    size_hints,
    reduction_hint=False,
    triton_meta=None,
    filename=None,
    inductor_meta=None,
):
    inductor_meta = {} if inductor_meta is None else inductor_meta
    inductor_meta["reduction_hint"] = reduction_hint
    if inductor_meta.get("no_x_dim"):
        size_hints["x"] = 1

    configs = _persistent_reduction_configs(size_hints, reduction_hint, inductor_meta)

    return cached_autotune(
        size_hints,
        configs,
        triton_meta=triton_meta,
        inductor_meta=inductor_meta,
        filename=filename,
        heuristic_type=HeuristicType.PERSISTENT_REDUCTION,
    )


def split_scan(
    size_hints,
    reduction_hint=False,
    triton_meta=None,
    filename=None,
    inductor_meta=None,
):
    """Heuristic for TritonSplitScanKernel"""
    inductor_meta = {} if inductor_meta is None else inductor_meta
    inductor_meta["reduction_hint"] = reduction_hint
    if inductor_meta.get("no_x_dim"):
        size_hints["x"] = 1

    assert triton_meta is not None
    if len(size_hints) != 2:
        raise NotImplementedError(f"size_hints: {size_hints}")

    configs = _reduction_configs(size_hints=size_hints, inductor_meta=inductor_meta)

    # Fixup configs to enforce the minimum Rn_BLOCK size
    min_rblock = inductor_meta.get("min_split_scan_rblock", 256)
    for cfg in configs:
        for var in list(cfg.kwargs.keys()):
            if var.startswith("R") and cfg.kwargs[var] < min_rblock:
                cfg.kwargs[var] = min_rblock

    return cached_autotune(
        size_hints,
        configs=configs,
        triton_meta=triton_meta,
        inductor_meta=inductor_meta,
        heuristic_type=HeuristicType.SPLIT_SCAN,
        filename=filename,
    )
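
# Sketch of the fixup loop above (the config values are assumed): if a reduction
# config came back with ``{"XBLOCK": 1, "R0_BLOCK": 64}`` and
# ``min_split_scan_rblock`` is left at its default of 256, the loop rewrites the
# reduction block to ``R0_BLOCK=256`` so the split-scan kernel never launches
# with a smaller reduction block than it requires.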


def template(num_stages, num_warps, triton_meta, filename=None, inductor_meta=None):
    """
    Compile a triton template
    """
    return cached_autotune(
        None,
        [triton.Config({}, num_stages=num_stages, num_warps=num_warps)],
        triton_meta=triton_meta,
        inductor_meta=inductor_meta,
        heuristic_type=HeuristicType.TEMPLATE,
        filename=filename,
    )


def _pop_config_kwargs(config: dict[str, Any]) -> dict[str, Any]:
    """Extract triton.Config options that should become kwargs"""
    popped = {}
    for key in ("num_warps", "num_stages", "num_ctas", "maxnreg"):
        val = config.pop(key, None)
        if val is not None:
            popped[key] = val
    return popped
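
# Example of the split performed above (input values are assumed):
#
#     config = {"XBLOCK": 64, "num_warps": 4, "num_stages": 2}
#     _pop_config_kwargs(config)  # -> {"num_warps": 4, "num_stages": 2}
#
# returns the launcher options and mutates ``config`` in place, leaving only the
# kernel kwargs ({"XBLOCK": 64}) for ``triton.Config``.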


def fixed_config(config, filename, triton_meta, inductor_meta):
    """
    Used when the configuration is already decided at compile time
    """
    config = {**config}
    return cached_autotune(
        None,
        [triton.Config(config, **_pop_config_kwargs(config))],
        triton_meta=triton_meta,
        inductor_meta=inductor_meta,
        heuristic_type=HeuristicType.FIXED,
        filename=filename,
    )
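
# Illustrative usage (the values are assumed): ``fixed_config({"XBLOCK": 128,
# "num_warps": 8}, ...)`` builds a single ``triton.Config({"XBLOCK": 128},
# num_warps=8)``, so no benchmarking happens at runtime; the ``{**config}`` copy
# keeps the caller's dict untouched while ``_pop_config_kwargs`` strips the
# launcher options out of the kernel kwargs.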


def user_autotune(
    configs, triton_meta, filename=None, inductor_meta=None, custom_kernel=False
):
    """
    Compile a user defined triton kernel
    """
    if len(configs) == 0:
        configs = [triton.Config({})]
    else:
        configs = [
            triton.Config(c.get("kwargs", {}), **_pop_config_kwargs({**c}))
            for c in configs
        ]
    return cached_autotune(
        None,
        configs,
        triton_meta=triton_meta,
        heuristic_type=HeuristicType.USER_AUTOTUNE,
        filename=filename,
        inductor_meta=inductor_meta,
        custom_kernel=custom_kernel,
    )
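
# Example of the expected ``configs`` format (values are assumed): each entry is
# a dict with an optional "kwargs" mapping plus launcher options, e.g.
#
#     [{"kwargs": {"BLOCK": 64}, "num_warps": 4},
#      {"kwargs": {"BLOCK": 128}, "num_warps": 8}]
#
# which is translated into ``triton.Config({"BLOCK": 64}, num_warps=4)`` and
# ``triton.Config({"BLOCK": 128}, num_warps=8)`` before autotuning.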


def foreach(triton_meta, num_warps, filename=None, inductor_meta=None):
    """
    Compile a triton foreach kernel
    """
    return cached_autotune(
        None,
        [triton.Config({}, num_stages=1, num_warps=num_warps)],
        triton_meta=triton_meta,
        inductor_meta=inductor_meta,
        heuristic_type=HeuristicType.TEMPLATE,
        filename=filename,
    )


def grid(*numels):
    """Helper function to compute triton grids"""
    if len(numels) == 1:
        xnumel, ynumel, znumel = numels[0], None, None
    elif len(numels) == 2:
        xnumel, ynumel, znumel = numels[1], numels[0], None
    elif len(numels) == 3:
        xnumel, ynumel, znumel = numels[2], numels[1], numels[0]
    else:
        raise AssertionError(f"invalid size for numels {len(numels)}")

    def get_grid_dim(numel, block):
        if numel is None:
            return 1
        if block is None:
            return numel
        return ceildiv(numel, block)

    def grid_fn(meta):
        x_grid = get_grid_dim(xnumel, meta.get("XBLOCK", 1))
        y_grid = get_grid_dim(ynumel, meta.get("YBLOCK", None))

        max_y_grid = get_max_y_grid()
        if znumel is None:
            div = ceildiv(y_grid, max_y_grid)
            y_grid = ceildiv(y_grid, div)
            z_grid = div
        else:
            z_grid = get_grid_dim(znumel, meta.get("ZBLOCK", None))
            torch._check(
                y_grid <= max_y_grid,
                lambda: f"Generated y grid beyond 2^16 ({y_grid}) not supported with z dimension present. File issue",
            )

        return (
            x_grid,
            y_grid,
            z_grid,
        )

    setattr(grid_fn, "grid_fn_str", f"grid{numels}")  # noqa: B010

    return grid_fn
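
# Worked example for ``grid`` above (sizes are assumed): the last numel maps to
# x and the one before it to y, so
#
#     grid_fn = grid(16, 4096)                # ynumel=16, xnumel=4096
#     grid_fn({"XBLOCK": 128, "YBLOCK": 16})  # -> (32, 1, 1)
#
# gives an x dimension of ceildiv(4096, 128) == 32; any y count beyond the
# hardware limit would be folded into the z dimension.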


def cooperative_reduction_grid(xnumel):
    def grid_fn(meta):
        return (meta["RSPLIT"], ceildiv(xnumel, meta.get("XBLOCK", 1)), 1)

    grid_fn_str = f"cooperative_reduction_grid({xnumel})"
    setattr(grid_fn, "grid_fn_str", grid_fn_str)  # noqa: B010
    return grid_fn
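
# Illustrative launch shape (values are assumed): ``cooperative_reduction_grid(512)``
# with ``meta == {"RSPLIT": 8, "XBLOCK": 64}`` returns ``(8, 8, 1)``: the x grid
# dimension carries the 8 cooperating blocks per row and the y dimension covers
# the 512 rows in ceildiv(512, 64) == 8 blocks.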


def maybe_cooperative_reduction_grid(xnumel):
    def grid_fn(meta):
        if "RSPLIT" in meta:
            return coop_grid(meta)
        return normal_grid(meta)

    coop_grid = cooperative_reduction_grid(xnumel)
    normal_grid = grid(xnumel)
    grid_fn_str = f"maybe_cooperative_reduction_grid({xnumel})"
    setattr(grid_fn, "grid_fn_str", grid_fn_str)  # noqa: B010
    return grid_fn


def split_scan_grid(xnumel, rnumel):
    def grid_fn(meta):
        assert meta.get("XBLOCK", 1) == 1
        return (ceildiv(rnumel, meta.get("R0_BLOCK", 1)), xnumel, 1)

    grid_fn_str = f"split_scan_grid({xnumel}, {rnumel})"
    setattr(grid_fn, "grid_fn_str", grid_fn_str)  # noqa: B010

    return grid_fn
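
# Illustrative launch shape (values are assumed): ``split_scan_grid(4, 10000)``
# with ``meta == {"XBLOCK": 1, "R0_BLOCK": 256}`` returns
# ``(ceildiv(10000, 256), 4, 1) == (40, 4, 1)``: the scanned dimension is split
# across the x grid while each of the 4 batch rows maps to the y grid.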


def grid_combo_kernels(
    *numels, num_kernels, min_blocks, is_sequential, default_meta=None
):
    """min_blocks is the minimal size of the grid x dimension"""
    if not is_sequential:
        # round-robin dispatch
        numels_agg = list(numels)
        for i in range(len(numels_agg)):
            if isinstance(numels_agg[i], (list, tuple)):
                numels_agg[i] = max(max(numels_agg[i]), 0)  # noqa: PLW3301
        kernel_grid_fn = grid(*numels_agg)

        if isinstance(numels[-1], (list, tuple)):
            min_blocks_d = max(-min(numels[-1]), 0) * num_kernels
        else:
            min_blocks_d = None
        if min_blocks is None:
            assert min_blocks_d is not None
            min_blocks = min_blocks_d
        else:
            assert (
                min_blocks_d is None or min_blocks == min_blocks_d
            ), f"inconsistent min_blocks {min_blocks} vs x grid {numels[-1]}"
    else:
        # sequential dispatch
        seq_numels = list(numels)
        # x numels are not used here, just a placeholder
        seq_numels[-1] = 1024
        for i in range(len(seq_numels) - 1):
            if isinstance(seq_numels[i], (list, tuple)):
                seq_numels[i] = max(seq_numels[i])

        kernel_grid_fn = grid(*seq_numels)

    def get_grid_dim(numel, block):
        if numel is None:
            return 1
        if block is None:
            return numel
        return ceildiv(numel, block)

    def grid_fn(meta):
        assert min_blocks is not None, "min_blocks must be a number"
        cuda_grid = list(kernel_grid_fn(meta))
        cuda_grid[0] = max(num_kernels * cuda_grid[0], min_blocks)
        return tuple(cuda_grid)

    def seq_grid_fn(meta):
        cuda_grid = list(kernel_grid_fn(meta))
        # x <= 0 means this kernel's x grid is not tunable (x_no_dim is true)
        x_grid = sum(
            [
                -x if x <= 0 else get_grid_dim(x, meta.get("XBLOCK", 1))
                for x in numels[-1]
            ]
        )
        cuda_grid[0] = x_grid
        return tuple(cuda_grid)

    def grid_fn_default_meta(meta):
        return grid_fn(default_meta)

    def seq_grid_fn_default_meta(meta):
        return seq_grid_fn(default_meta)

    if default_meta is None:
        return grid_fn if not is_sequential else seq_grid_fn
    else:
        return grid_fn_default_meta if not is_sequential else seq_grid_fn_default_meta
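
# Sketch of the round-robin path above (shapes are assumed for illustration):
# for three fused sub-kernels whose x numels are ``[1024, 2048, 512]``,
#
#     combo_grid = grid_combo_kernels(
#         [1024, 2048, 512], num_kernels=3, min_blocks=None, is_sequential=False
#     )
#     combo_grid({"XBLOCK": 256})  # -> (24, 1, 1)
#
# the grid is sized for the largest sub-kernel (ceildiv(2048, 256) == 8 blocks)
# and multiplied by ``num_kernels`` so blocks are dispatched round-robin.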