mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
This adds support for Triton after https://github.com/triton-lang/triton/pull/7258 landed. https://github.com/triton-lang/triton/pull/7258 adds a new argument to all the Triton kernels - a profile_scratch argument, similar to global_scratch. This PR updates the static cuda launcher and the AOTI kernel callers to pass in these arguments when calling the Triton kernel. Tests: https://github.com/pytorch/pytorch/pull/159158. I also verified these test locally with triton 3.2, 3.3, and 3.4. Fixes: * static_cuda_launcher (test/repro: `python tools/dynamo/verify_dynamo.py`) * AOTI calling logic (test/repro: `TORCHINDUCTOR_CPP_WRAPPER=1 python test/inductor/test_torchinductor_opinfo.py -k test_comprehensive_linalg_vander_cuda_float32`) Differential Revision: [D79825121](https://our.internmc.facebook.com/intern/diff/D79825121) Pull Request resolved: https://github.com/pytorch/pytorch/pull/159772 Approved by: https://github.com/NikhilAPatel, https://github.com/eellison
3181 lines
116 KiB
Python
3181 lines
116 KiB
Python
# mypy: allow-untyped-defs
|
|
from __future__ import annotations
|
|
|
|
import builtins
|
|
import copy
|
|
import dataclasses
|
|
import functools
|
|
import hashlib
|
|
import inspect
|
|
import itertools
|
|
import logging
|
|
import math
|
|
import operator
|
|
import os
|
|
import os.path
|
|
import re
|
|
import sys
|
|
import threading
|
|
import time
|
|
from collections import namedtuple
|
|
from typing import (
|
|
Any,
|
|
Callable,
|
|
Generic,
|
|
Literal,
|
|
Optional,
|
|
TYPE_CHECKING,
|
|
TypeVar,
|
|
Union,
|
|
)
|
|
|
|
import torch
|
|
from torch._dynamo.utils import set_feature_use
|
|
from torch._prims_common import compute_required_storage_length
|
|
from torch.utils._ordered_set import OrderedSet
|
|
|
|
from ..triton_bundler import TritonBundler
|
|
from ..utils import prefix_is_reduction, triton_version_uses_attrs_dict
|
|
from . import triton_helpers
|
|
from .autotune_cache import AutotuneCache
|
|
from .benchmarking import benchmarker
|
|
from .coordinate_descent_tuner import CoordescTuner
|
|
from .hints import (
|
|
_NUM_THREADS_PER_WARP,
|
|
AutotuneHint,
|
|
DeviceProperties,
|
|
HeuristicType,
|
|
ReductionHint,
|
|
TileHint,
|
|
TRITON_MAX_BLOCK,
|
|
TRITON_MAX_RSPLIT,
|
|
)
|
|
from .runtime_utils import (
|
|
ceildiv,
|
|
conditional_product,
|
|
create_bandwidth_info_str,
|
|
dynamo_timed,
|
|
get_first_attr,
|
|
get_max_y_grid,
|
|
get_num_bytes,
|
|
next_power_of_2,
|
|
triton_cache_dir,
|
|
triton_config_to_hashable,
|
|
triton_hash_to_path_key,
|
|
validate_triton_config,
|
|
)
|
|
from .static_cuda_launcher import StaticallyLaunchedCudaKernel
|
|
from .triton_compat import (
|
|
ASTSource,
|
|
autograd_profiler,
|
|
cc_warp_size,
|
|
CompiledKernel,
|
|
Config,
|
|
GPUTarget,
|
|
HAS_WARP_SPEC,
|
|
KernelInterface,
|
|
knobs,
|
|
OutOfResources,
|
|
PTXASError,
|
|
triton,
|
|
)
|
|
|
|
|
|
class NoTritonConfigsError(RuntimeError):
|
|
pass
|
|
|
|
|
|
if TYPE_CHECKING:
|
|
from collections.abc import Container, Hashable
|
|
|
|
from torch._guards import CompileId
|
|
|
|
LauncherType = Any
|
|
|
|
_KernelType = Union[CompiledKernel, StaticallyLaunchedCudaKernel]
|
|
_T = TypeVar("_T", bound=_KernelType)
|
|
|
|
log = logging.getLogger(__name__)
|
|
|
|
triton_name_sub = re.compile(r"^def [^(]+\(")
|
|
|
|
|
|
def generate_lookup_hash_from_source_code(size_hints_str: str, source_code: str) -> str:
|
|
# Name agnostic + strip white space
|
|
fn_strip_name = re.sub(triton_name_sub, "(", source_code.strip(), count=1)
|
|
hash_str = size_hints_str + fn_strip_name
|
|
fn_hash = hashlib.sha256(hash_str.encode("utf-8")).hexdigest()
|
|
|
|
return fn_hash
|
|
|
|
|
|
def lookup_autotune_config(size_hints, fn) -> Optional[Config]:
|
|
lookup_table = torch._inductor.config.autotune_lookup_table
|
|
cached_config = None
|
|
if len(lookup_table) > 0 and "_fused_" in fn.src:
|
|
fn_hash = generate_lookup_hash_from_source_code(str(size_hints), fn.src)
|
|
if fn_hash in lookup_table:
|
|
config_dict = lookup_table[fn_hash]
|
|
block_configs = {k: v for k, v in config_dict.items() if "BLOCK" in k}
|
|
cached_config = Config(
|
|
block_configs,
|
|
num_warps=config_dict["num_warps"],
|
|
num_stages=config_dict["num_stages"],
|
|
)
|
|
|
|
return cached_config
|
|
|
|
|
|
def get_total_reduction_numel(numels: dict[str, int]) -> int:
|
|
return conditional_product(
|
|
*[numel for prefix, numel in numels.items() if prefix_is_reduction(prefix)]
|
|
)
|
|
|
|
|
|
def autotune_hints_to_configs(
|
|
hints: OrderedSet[AutotuneHint],
|
|
size_hints,
|
|
block_size: int,
|
|
device_props: DeviceProperties,
|
|
) -> list[Config]:
|
|
"""
|
|
AutotuneHints can be attached to the metadata of triton kernels for providing
|
|
suggestions about what to try for autotuning. One reason to do this is if there are
|
|
some configs that are only useful in specific scenarios, in which case we can avoid
|
|
wasting compile time on autotuning unless we know we are in one of those scenarios.
|
|
|
|
Based on those hints, this function will generate a list of additional autotuning
|
|
configs to try.
|
|
"""
|
|
xyz_options: tuple[tuple[int, Optional[int], Optional[int]], ...]
|
|
configs: list[Config] = []
|
|
for hint in hints:
|
|
if hint == AutotuneHint.ONE_ELEMENT_PER_THREAD:
|
|
if len(size_hints) == 1:
|
|
xyz_options = ((block_size // 4, None, None),)
|
|
elif len(size_hints) == 2:
|
|
xyz_options = ((block_size // 4, 1, None), (1, block_size // 4, None))
|
|
elif len(size_hints) == 3:
|
|
xyz_options = (
|
|
(block_size // 4, 1, 1),
|
|
(1, block_size // 4, 1),
|
|
(1, 1, block_size // 4),
|
|
)
|
|
configs.extend(
|
|
triton_config(
|
|
size_hints,
|
|
*xyz,
|
|
num_elements_per_warp=(
|
|
device_props.warp_size if device_props.warp_size else 32
|
|
),
|
|
)
|
|
for xyz in xyz_options
|
|
)
|
|
|
|
return configs
|
|
|
|
|
|
def disable_pointwise_autotuning(inductor_meta):
|
|
# Autotuning can give different benchmarking results from run to run, and
|
|
# therefore we disable autotuning when use_deterministic flag is on.
|
|
if inductor_meta.get("are_deterministic_algorithms_enabled"):
|
|
return True
|
|
return not inductor_meta.get("autotune_pointwise", True)
|
|
|
|
|
|
def _dump_launch_params(args, kwargs, launcher, kernel_name, grid):
|
|
call_args = []
|
|
call_kwargs = {}
|
|
for arg in args:
|
|
if isinstance(arg, (int, bool)):
|
|
call_args.append(str(arg))
|
|
else:
|
|
call_args.append("T")
|
|
for k, v in kwargs.items():
|
|
if isinstance(arg, (int, bool)):
|
|
call_kwargs[k] = v
|
|
else:
|
|
call_kwargs[k] = v
|
|
if not triton_version_uses_attrs_dict():
|
|
call_kwargs.update(launcher.config.kwargs)
|
|
call_kwargs["num_warps"] = launcher.config.num_warps
|
|
call_kwargs["num_stages"] = launcher.config.num_stages
|
|
if HAS_WARP_SPEC:
|
|
call_kwargs["num_consumer_groups"] = getattr(
|
|
launcher.config, "num_consumer_groups", 0
|
|
)
|
|
call_kwargs["num_buffers_warp_spec"] = getattr(
|
|
launcher.config, "num_buffers_warp_spec", 0
|
|
)
|
|
args_str = [*call_args]
|
|
args_str.extend(f"{k}={v}" for k, v in call_kwargs.items())
|
|
args_str = ", ".join(args_str)
|
|
abs_path = os.path.abspath(sys.argv[0])
|
|
with open(f"{abs_path}.launch_params", "a") as f:
|
|
f.write(f"{kernel_name} | {args_str} | {grid!r}\n")
|
|
|
|
|
|
def check_autotune_cache(
|
|
configs: list[Config], filename: Optional[str], inductor_meta: dict[str, Any]
|
|
) -> tuple[list[Config], Optional[AutotuneCache], dict[str, Any]]:
|
|
"""
|
|
Given a list of configs, checks autotune cache and return metadata
|
|
"""
|
|
autotune_cache = None
|
|
autotune_cache_info = {}
|
|
disabled = inductor_meta.get("force_disable_caches", False)
|
|
if (
|
|
not disabled
|
|
and filename is not None
|
|
and (len(configs) > 1 or inductor_meta.get("coordinate_descent_tuning"))
|
|
and not os.environ.get("TRITON_INTERPRET", "0") == "1"
|
|
):
|
|
configs_hash = hash_configs(configs)
|
|
|
|
autotune_cache = AutotuneCache.create(inductor_meta, filename, configs_hash)
|
|
if autotune_cache:
|
|
if best_config := autotune_cache.read_best(inductor_meta, configs):
|
|
configs = [best_config]
|
|
autotune_cache_info["best_config"] = triton_config_to_hashable(
|
|
best_config
|
|
)
|
|
autotune_cache_info["autotune_cache_state"] = "hit"
|
|
|
|
else:
|
|
autotune_cache_info["autotune_cache_state"] = "miss"
|
|
autotune_cache_info["num_configs"] = len(configs)
|
|
if inductor_meta.get("coordinate_descent_tuning"):
|
|
autotune_cache_info["coordesc_tuning"] = True
|
|
if len(configs) == 1:
|
|
# This is the config that coordinate descent tuning started at, which
|
|
# is not the same as the final config chosen (i.e. only_config, best_config)
|
|
autotune_cache_info["coordesc_tuning_start_config"] = (
|
|
triton_config_to_hashable(configs[0])
|
|
)
|
|
else:
|
|
if len(configs) == 1:
|
|
autotune_cache_info["autotune_cache_state"] = "only 1 config"
|
|
autotune_cache_info["only_config"] = triton_config_to_hashable(configs[0])
|
|
|
|
if disabled:
|
|
autotune_cache_info["autotune_cache_state"] = "force_disabled"
|
|
log.debug("autotune caching is disabled by config.force_disable_caches")
|
|
|
|
return configs, autotune_cache, autotune_cache_info
|
|
|
|
|
|
class CachingAutotuner(KernelInterface):
|
|
"""
|
|
Simplified version of Triton autotuner that has no invalidation
|
|
key and caches the best config to disk to improve cold start times.
|
|
Unlike the main triton Autotuner, this version can precompile all
|
|
configs, and does not rely on the Triton JIT.
|
|
"""
|
|
|
|
def __init__(
|
|
self,
|
|
fn,
|
|
triton_meta, # passed directly to triton
|
|
configs,
|
|
save_cache_hook,
|
|
mutated_arg_names: list[str], # see [Note: clone mutated buffers]
|
|
optimize_mem,
|
|
heuristic_type,
|
|
size_hints=None,
|
|
inductor_meta=None, # metadata not relevant to triton
|
|
custom_kernel=False, # whether the kernel is inductor-generated or custom
|
|
filename: Optional[str] = None,
|
|
reset_to_zero_arg_names: Optional[list[str]] = None,
|
|
autotune_cache_info: Optional[dict[str, Any]] = None,
|
|
):
|
|
super().__init__()
|
|
|
|
assert len(configs) > 0, "Non-empty TritonConfig list required for compiling"
|
|
# makes sure there are no pre-hooks on any of the triton configs
|
|
for cfg in configs:
|
|
validate_triton_config(cfg)
|
|
|
|
self.fn = fn
|
|
self.device_props: DeviceProperties = triton_meta["device"]
|
|
self.triton_meta = {
|
|
**triton_meta,
|
|
"device": self.device_props.index,
|
|
"device_type": self.device_props.type,
|
|
}
|
|
self.inductor_meta = {} if inductor_meta is None else inductor_meta
|
|
self.save_cache_hook = save_cache_hook
|
|
self.mutated_arg_names = mutated_arg_names
|
|
self.reset_to_zero_arg_names = (
|
|
[] if reset_to_zero_arg_names is None else reset_to_zero_arg_names
|
|
)
|
|
self.optimize_mem = optimize_mem
|
|
cached_config = lookup_autotune_config(size_hints, fn)
|
|
self.configs = [cached_config] if cached_config else configs
|
|
|
|
self.heuristic_type = heuristic_type
|
|
self.custom_kernel = custom_kernel
|
|
self.cuda_kernel_saved = False
|
|
self.autotune_cache_info = autotune_cache_info
|
|
if log.isEnabledFor(logging.DEBUG):
|
|
log.debug(
|
|
"CachingAutotuner gets %d configs for %s",
|
|
len(self.configs),
|
|
self.fn.__name__,
|
|
)
|
|
for c in self.configs:
|
|
log.debug(c)
|
|
|
|
self.compile_results: list[CompileResult[_KernelType]] = []
|
|
self.launchers: list[LauncherType] = []
|
|
self.lock = threading.Lock()
|
|
if os.getenv("TRITON_CACHE_DIR") is None:
|
|
os.environ["TRITON_CACHE_DIR"] = triton_cache_dir(
|
|
self.triton_meta.get("device", 0)
|
|
)
|
|
log.debug("Triton cache dir: %s", os.environ["TRITON_CACHE_DIR"])
|
|
|
|
self.size_hints = size_hints
|
|
self.coordesc_tuner = CoordescTuner(
|
|
is_mm=False,
|
|
name=self.fn.__name__,
|
|
size_hints=size_hints,
|
|
inductor_meta=self.inductor_meta,
|
|
)
|
|
self.filename = filename
|
|
|
|
# used for profiling
|
|
self.kernel_hash: str = ""
|
|
|
|
# Kernels are stored in the codecache with the filename as a hash of the code.
|
|
# We rely on this to obtain the kernel hash
|
|
if self.filename is not None:
|
|
base_name = os.path.basename(self.filename)
|
|
if ".py" in base_name:
|
|
self.kernel_hash = os.path.splitext(base_name)[0]
|
|
|
|
self.precompile_time_taken_ns = 0
|
|
self.autotune_time_taken_ns = 0
|
|
# Dumps the launch configs after autotuning.
|
|
self.dump_launch_params = (
|
|
os.environ.get("TORCHINDUCTOR_DUMP_LAUNCH_PARAMS", "0") == "1"
|
|
)
|
|
|
|
self.triton_interpret = os.environ.get("TRITON_INTERPRET", "0") == "1"
|
|
|
|
# Compile-time info included in runtime logginging
|
|
self.compile_id: Optional[CompileId] = None
|
|
self.is_backward = False
|
|
|
|
def is_statically_launchable(self):
|
|
"""
|
|
Checks if every compiled kernel is statically launchable, which
|
|
allows us to efficiently cache it in FXGraphCache
|
|
"""
|
|
if not self.compile_results:
|
|
return False
|
|
return all(
|
|
isinstance(x, StaticTritonCompileResult) for x in self.compile_results
|
|
)
|
|
|
|
def recheck_autotune_cache(
|
|
self, reload_kernel_from_src: Callable[[], CachingAutotuner]
|
|
) -> None:
|
|
"""
|
|
On cache load on static autotuner, we need to recheck the autotune cache, since
|
|
a best config could have been found from a previous run
|
|
"""
|
|
assert self.is_statically_launchable()
|
|
|
|
configs = [result.config for result in self.compile_results]
|
|
|
|
(cached_configs, _, autotune_cache_info) = check_autotune_cache(
|
|
configs, self.filename, self.inductor_meta
|
|
)
|
|
self.autotune_cache_info = autotune_cache_info
|
|
# I.e. there was an autotune cache hit
|
|
if len(cached_configs) == 1 and len(configs) > 1:
|
|
best_config = cached_configs[0]
|
|
# Grab the best compiled config, if it's in the list of available ones
|
|
best_config_hash = triton_config_to_hashable(best_config)
|
|
|
|
for compile_result in self.compile_results:
|
|
if triton_config_to_hashable(compile_result.config) == best_config_hash:
|
|
self.compile_results = [compile_result]
|
|
return
|
|
|
|
# If the best config isn't in our list of compile results,
|
|
# it's likely because it was found by coordesc after the cache
|
|
# already saved
|
|
if best_config.found_by_coordesc:
|
|
with dynamo_timed("CachingAutotuner.slow_precompile_config"):
|
|
if self.fn.fn is None:
|
|
self.fn = reload_kernel_from_src().fn
|
|
self.compile_results = [self._precompile_config(best_config)]
|
|
|
|
def set_compile_info(
|
|
self, compile_id: Optional[CompileId], is_backward: bool
|
|
) -> None:
|
|
self.compile_id = compile_id
|
|
self.is_backward = is_backward
|
|
|
|
def precompile(
|
|
self,
|
|
warm_cache_only=False,
|
|
reload_kernel: Optional[Callable[[], CachingAutotuner]] = None,
|
|
static_triton_bundle_key: Optional[str] = None,
|
|
):
|
|
if warm_cache_only:
|
|
self._precompile_worker()
|
|
return
|
|
with self.lock:
|
|
# Helper function for reloading a kernel generated in a worker
|
|
# in the parent class. Normally we don't need to reload the kernel
|
|
# in the parent process, but in certain cases (coordesc tuning, dynamic_scale_rblock),
|
|
# we need to actually run compilation on the parent process
|
|
if reload_kernel is not None:
|
|
self._reload_kernel = reload_kernel
|
|
self._precompile_worker()
|
|
if static_triton_bundle_key is not None and self.is_statically_launchable():
|
|
TritonBundler.put_static_autotuner(static_triton_bundle_key, self)
|
|
self._make_launchers()
|
|
self._dynamic_scale_rblock()
|
|
|
|
def _precompile_worker(self):
|
|
if self.compile_results:
|
|
for result in self.compile_results:
|
|
TritonBundler.put(
|
|
triton_hash_to_path_key(result.kernel.hash), # type: ignore[attr-defined]
|
|
self.triton_meta.get("device", 0),
|
|
)
|
|
return
|
|
assert not self.launchers
|
|
if not self.configs:
|
|
raise NoTritonConfigsError("No triton configs are available")
|
|
|
|
compile_results = []
|
|
exc = None
|
|
for c in self.configs:
|
|
try:
|
|
compile_results.append(self._precompile_config(c))
|
|
except (OutOfResources, PTXASError) as e:
|
|
exc = e
|
|
if len(compile_results) == 0:
|
|
raise NoTritonConfigsError(
|
|
f"No valid triton configs. {type(exc).__name__}: {exc}"
|
|
)
|
|
self.compile_results = compile_results
|
|
self.configs = None
|
|
|
|
def _dynamic_scale_rblock(self):
|
|
# TODO(jansel): we should find a way to move this extra compile into the worker process
|
|
# Currently it relies on _make_launchers(), which requires a cuda context, to populate nreg.
|
|
device_prop = self.device_props
|
|
if (
|
|
self.inductor_meta.get("dynamic_scale_rblock", True)
|
|
and not self.inductor_meta.get("persistent_reduction")
|
|
and self.heuristic_type == HeuristicType.REDUCTION
|
|
and self.size_hints is not None
|
|
# Disable for Intel as Triton is not ready to return n_regs for a compiled_binary.
|
|
and device_prop.type in ["cuda", "hip"]
|
|
and device_prop.major
|
|
and (device_prop.major >= 8 or torch.version.hip)
|
|
and device_prop.regs_per_multiprocessor is not None
|
|
):
|
|
assert device_prop.regs_per_multiprocessor
|
|
assert device_prop.max_threads_per_multi_processor
|
|
assert device_prop.multi_processor_count
|
|
seen_config_hashes: Optional[OrderedSet[Hashable]] = None
|
|
warp_size = device_prop.warp_size or 32
|
|
for result in self.compile_results:
|
|
triton_config = result.config
|
|
compiled_binary = result.kernel
|
|
assert len(self.size_hints) >= 2
|
|
xblock = triton_config.kwargs.get("XBLOCK", 1)
|
|
reduction_kwargs = [
|
|
kwarg for kwarg in triton_config.kwargs if kwarg.startswith("R")
|
|
]
|
|
rblocks = [triton_config.kwargs[kwarg] for kwarg in reduction_kwargs]
|
|
total_block = (self.size_hints["x"] + xblock - 1) // xblock
|
|
nreg = getattr(compiled_binary, "n_regs", None)
|
|
if nreg is None:
|
|
continue
|
|
|
|
# make sure rblocks are not too small
|
|
if conditional_product(*rblocks) <= 64:
|
|
continue
|
|
|
|
# each SM of A100 has 65536 32-bit registers. To maximize
|
|
# the theoretical occupancy, we need run 2048 threads on each
|
|
# SM. So each thread should use no more than 65536 / 2048
|
|
# = 32 registers. In cases where occupancy matters, and each
|
|
# thread uses too many registers, reduce R0_BLOCK to reduce
|
|
# the register usage.
|
|
# For kernel https://gist.github.com/shunting314/e4cccc031fe30d378b9b23c08c238cbd
|
|
# from PLBartForCausalLM, latency improve from
|
|
# 7.795ms to 4.883ms.
|
|
#
|
|
if (
|
|
nreg
|
|
<= device_prop.regs_per_multiprocessor
|
|
// device_prop.max_threads_per_multi_processor
|
|
):
|
|
continue
|
|
|
|
nreg_per_warp = nreg * warp_size
|
|
nreg_per_block = nreg_per_warp * triton_config.num_warps
|
|
|
|
# Previously we set max_blocks_per_sm to 'max_threads_per_multi_processo / (32 * num_warps)'
|
|
# The formula below is a tighter upper bound since we have the assumption that
|
|
# nreg > device_prop.regs_per_multiprocessor // device_prop.max_threads_per_multi_processor
|
|
# due to the if condition above and:
|
|
# regs_per_multiprocessor / nreg_per_block
|
|
# = regs_per_multiprocessor / (nreg * 32 * num_warps)
|
|
# < regs_per_multiprocessor / ((regs_per_multiprocessor / max_threads_per_multi_processor) * 32 * num_warps)
|
|
# = max_threads_per_multi_processor / (32 * num_warps)
|
|
# Using a tigher upper bound can reveal more optimization opportunities.
|
|
max_blocks_per_sm = max(
|
|
device_prop.regs_per_multiprocessor // nreg_per_block, 1
|
|
)
|
|
|
|
if total_block <= max_blocks_per_sm * device_prop.multi_processor_count:
|
|
# no need to improve occupancy
|
|
continue
|
|
new_config = copy.deepcopy(triton_config)
|
|
|
|
# Reduce the largest Rn_BLOCK by a factor of 2.
|
|
largest_rkwarg: str = max(
|
|
reduction_kwargs, key=triton_config.kwargs.__getitem__
|
|
)
|
|
new_config.kwargs[largest_rkwarg] //= 2
|
|
|
|
if seen_config_hashes is None:
|
|
seen_config_hashes = OrderedSet(
|
|
[
|
|
triton_config_to_hashable(x.config)
|
|
for x in self.compile_results
|
|
]
|
|
)
|
|
new_config_hash = triton_config_to_hashable(new_config)
|
|
if new_config_hash in seen_config_hashes:
|
|
continue
|
|
seen_config_hashes.add(new_config_hash)
|
|
log.debug(
|
|
"Dynamically scale down %s from TritonConfig(%s) and get a new TritonConfig(%s)",
|
|
largest_rkwarg,
|
|
triton_config,
|
|
new_config,
|
|
)
|
|
if self.fn.fn is None:
|
|
"""
|
|
We are in the parent process, while this program was compiled in a worker
|
|
and the fn was dropped in prepare_for_pickle(). We haven't loaded the module
|
|
containing the real fn yet.
|
|
"""
|
|
assert hasattr(self, "_reload_kernel")
|
|
assert callable(self._reload_kernel)
|
|
self.fn = self._reload_kernel().fn
|
|
self.compile_results.append(self._precompile_config(new_config)) # noqa: B909
|
|
|
|
self._make_launchers()
|
|
|
|
def _make_launchers(self):
|
|
if len(self.launchers) == len(self.compile_results):
|
|
return
|
|
|
|
from torch._dynamo.device_interface import DeviceGuard
|
|
|
|
device_interface = self.get_device_interface()
|
|
|
|
# load binary to the correct device
|
|
with DeviceGuard(device_interface, self.triton_meta["device"]):
|
|
# need to initialize context
|
|
with dynamo_timed(
|
|
"CachingAutotuner.synchronize",
|
|
# Deliberately avoid overloading pt2_compile_events:
|
|
log_pt2_compile_event=False,
|
|
):
|
|
device_interface.synchronize(device_interface.current_device())
|
|
|
|
launchers = []
|
|
exc = None
|
|
for result in self.compile_results:
|
|
try:
|
|
launchers.append(result.make_launcher())
|
|
|
|
except (OutOfResources, PTXASError, torch.cuda.OutOfMemoryError) as e:
|
|
exc = e
|
|
if len(launchers) == 0:
|
|
raise RuntimeError(f"No valid triton configs. {type(exc).__name__}: {exc}")
|
|
self.launchers = launchers
|
|
|
|
def prepare_for_pickle(self) -> tuple[Any, Any, Any, Any, Any]:
|
|
"""Drop stuff from triton.JITFunction that does not pickle.
|
|
This must be called after precompile so that these things are no longer needed.
|
|
Returns a tuple of old values
|
|
"""
|
|
old_values = (
|
|
self.fn.fn,
|
|
self.fn.__globals__,
|
|
self.fn.used_global_vals,
|
|
self.fn.repr,
|
|
self.launchers,
|
|
)
|
|
self.fn.fn = None
|
|
self.fn.__globals__ = None
|
|
self.fn.used_global_vals = None
|
|
self.fn.repr = _ConstRepr(self.fn.repr(self.fn))
|
|
self.launchers = []
|
|
return old_values
|
|
|
|
def prepare_for_caching(self) -> None:
|
|
"""
|
|
Statically Launched CUDA Kernels have a raw cubin on them
|
|
that we don't need to store in the cache(since TritonBundler handles the collection for us)
|
|
"""
|
|
for result in self.compile_results:
|
|
if isinstance(result, StaticTritonCompileResult):
|
|
# Don't save this in the inductor cache, as it is very large
|
|
result.kernel.cubin_raw = None
|
|
|
|
def __getstate__(self) -> dict[str, Any]:
|
|
assert not self.launchers, (
|
|
"pickle should not be called with after make_launchers()"
|
|
)
|
|
return {
|
|
**self.__dict__,
|
|
"lock": None,
|
|
}
|
|
|
|
def __setstate__(self, state: dict[str, Any]) -> None:
|
|
self.__dict__.update(state)
|
|
self.lock = threading.Lock()
|
|
|
|
def get_device_interface(self):
|
|
# this code cannot run in compile workers, because it imports from torch
|
|
from torch._dynamo.device_interface import get_interface_for_device
|
|
|
|
return get_interface_for_device(self.device_props.type.replace("hip", "cuda"))
|
|
|
|
def _precompile_config(self, cfg: Config) -> CompileResult[_KernelType]:
|
|
"""Ahead of time compile a given autotuner config."""
|
|
compile_meta = copy.deepcopy(self.triton_meta)
|
|
cfg_kwargs = cfg.kwargs
|
|
if self.device_props.type == "hip":
|
|
cfg_kwargs = {**cfg_kwargs}
|
|
for k in ("matrix_instr_nonkdim", "waves_per_eu", "kpack"):
|
|
if k in cfg_kwargs:
|
|
compile_meta[k] = cfg_kwargs.pop(k)
|
|
compile_meta["constants"].update(cfg_kwargs)
|
|
for i in self.fn.constexprs:
|
|
arg_name = self.fn.arg_names[i]
|
|
if arg_name not in compile_meta["constants"] and (
|
|
arg_name == "num_warps" or arg_name == "num_stages"
|
|
):
|
|
compile_meta["constants"][arg_name] = getattr(cfg, arg_name)
|
|
compile_meta["num_warps"] = cfg.num_warps
|
|
compile_meta["num_stages"] = cfg.num_stages
|
|
if HAS_WARP_SPEC:
|
|
compile_meta["num_consumer_groups"] = getattr(cfg, "num_consumer_groups", 0)
|
|
compile_meta["num_buffers_warp_spec"] = getattr(
|
|
cfg, "num_buffers_warp_spec", 0
|
|
)
|
|
compile_meta["debug"] = self.inductor_meta.get(
|
|
"assert_indirect_indexing", True
|
|
) and not self.inductor_meta.get("is_hip", False)
|
|
|
|
# device type will be "hip" rather than "cuda" here
|
|
compile_meta["device_type"] = self.device_props.type
|
|
compile_meta["cc"] = self.device_props.cc
|
|
|
|
if self.device_props.type == "cpu":
|
|
triton_helpers.set_driver_to_cpu()
|
|
else:
|
|
triton_helpers.set_driver_to_gpu()
|
|
|
|
if not ASTSource:
|
|
raise RuntimeError("Installed triton version too old, please upgrade")
|
|
|
|
compile_args = (
|
|
ASTSource(
|
|
self.fn,
|
|
compile_meta["signature"],
|
|
compile_meta["constants"],
|
|
compile_meta["configs"][0],
|
|
),
|
|
)
|
|
|
|
if self.device_props.type == "mtia":
|
|
from mtia.host_runtime.torch_mtia.acc_flags import ( # type: ignore[import-not-found]
|
|
build_codename,
|
|
)
|
|
|
|
arch = build_codename()
|
|
else:
|
|
arch = compile_meta["cc"]
|
|
|
|
target = GPUTarget(
|
|
compile_meta["device_type"],
|
|
arch,
|
|
cc_warp_size(compile_meta["cc"]),
|
|
)
|
|
|
|
options = {
|
|
"num_warps": compile_meta["num_warps"],
|
|
"num_stages": compile_meta["num_stages"],
|
|
"debug": compile_meta["debug"],
|
|
"sanitize_overflow": False, # turn off additional asserts added for overflow checks
|
|
}
|
|
if HAS_WARP_SPEC:
|
|
options.update(
|
|
{
|
|
"num_consumer_groups": compile_meta.get("num_consumer_groups", 0),
|
|
"num_buffers_warp_spec": compile_meta.get(
|
|
"num_buffers_warp_spec", 0
|
|
),
|
|
}
|
|
)
|
|
if self.device_props.type == "hip":
|
|
if "waves_per_eu" in compile_meta:
|
|
options["waves_per_eu"] = compile_meta["waves_per_eu"]
|
|
if "matrix_instr_nonkdim" in compile_meta:
|
|
options["matrix_instr_nonkdim"] = compile_meta["matrix_instr_nonkdim"]
|
|
compile_kwargs = {
|
|
"target": target,
|
|
"options": options,
|
|
}
|
|
|
|
try:
|
|
binary = triton.compile(*compile_args, **compile_kwargs)
|
|
except Exception:
|
|
log.exception(
|
|
"Triton compilation failed: %s\n%s\nmetadata: %s",
|
|
self.inductor_meta.get("kernel_name", "triton_"),
|
|
self.fn.src,
|
|
compile_meta,
|
|
)
|
|
raise
|
|
TritonBundler.put(
|
|
triton_hash_to_path_key(binary.hash), self.triton_meta.get("device", 0)
|
|
)
|
|
# If the binary has a cubin file to directly launch, save it on the binary
|
|
static_launcher = StaticTritonCompileResult.can_statically_launch(
|
|
binary, self.inductor_meta, self.triton_meta, self.heuristic_type
|
|
)
|
|
|
|
if static_launcher is not None:
|
|
result = StaticTritonCompileResult(
|
|
static_launcher, cfg, compile_meta, self.inductor_meta
|
|
)
|
|
return result
|
|
|
|
return TritonCompileResult(binary, cfg, compile_meta, self.inductor_meta)
|
|
|
|
def _get_args_with_constexprs(self, args, launcher):
|
|
"""
|
|
`args` is passed in with only the non-constexpr args (because the constexpr arg values
|
|
depend on the config). However, in later triton versions, the constexpr args need to be
|
|
added into the args list.
|
|
"""
|
|
if triton_version_uses_attrs_dict():
|
|
# first: aggregate the constexpr args in (index, val) pairs
|
|
# so we can sort them by index.
|
|
constexpr_args: list[tuple[int, Any]] = []
|
|
for arg_name, arg_val in launcher.config.kwargs.items():
|
|
if arg_name in self.fn.arg_names:
|
|
constexpr_args.append((self.fn.arg_names.index(arg_name), arg_val))
|
|
|
|
constexpr_args.sort()
|
|
new_args = [*args]
|
|
for arg_idx, arg_val in constexpr_args:
|
|
new_args.insert(arg_idx, arg_val)
|
|
|
|
return new_args
|
|
return args
|
|
|
|
def bench(self, launcher, *args, with_profiler=False, **kwargs):
|
|
"""Measure the performance of a given launcher"""
|
|
# we don't skip configs with spilled registers when auto-tuning custom
|
|
# (user-written) Triton kernels, as (i) we don't have any knowledge or
|
|
# control over the kernel code; (ii) there is empirical evidence that
|
|
# for some (complicated) custom Triton kernels, a register-spilling
|
|
# config may yield the best latency.
|
|
if not self.custom_kernel and launcher.n_spills > self.inductor_meta.get(
|
|
"spill_threshold", 16
|
|
):
|
|
log.debug(
|
|
"Skip config %s because of register spilling: %d",
|
|
launcher.config,
|
|
launcher.n_spills,
|
|
)
|
|
return float("inf")
|
|
|
|
device_interface = self.get_device_interface()
|
|
stream = device_interface.get_raw_stream(device_interface.current_device())
|
|
|
|
cpu_copies = self.copy_args_to_cpu_if_needed(*args, **kwargs)
|
|
|
|
def kernel_call():
|
|
cloned_args, cloned_kwargs = self.maybe_clone_args(
|
|
cpu_copies, *args, **kwargs
|
|
)
|
|
# reset to zero before evaluating any config
|
|
self.reset_to_zero_args(*args, **kwargs)
|
|
args_with_constexprs = self._get_args_with_constexprs(cloned_args, launcher)
|
|
if autograd_profiler._is_profiler_enabled:
|
|
profiler_kwargs = self.get_profiler_kwargs(stream, launcher)
|
|
with torch._C._profiler._RecordFunctionFast(
|
|
self.inductor_meta.get("kernel_name", "triton kernel"),
|
|
args_with_constexprs,
|
|
profiler_kwargs,
|
|
):
|
|
launcher(
|
|
*args_with_constexprs,
|
|
**cloned_kwargs,
|
|
stream=stream,
|
|
)
|
|
|
|
else:
|
|
launcher(
|
|
*args_with_constexprs,
|
|
**cloned_kwargs,
|
|
stream=stream,
|
|
)
|
|
self.restore_args_from_cpu(cpu_copies)
|
|
|
|
# only use profiler when not already in a profiler instance
|
|
if with_profiler and not autograd_profiler._is_profiler_enabled:
|
|
from torch._inductor.utils import do_bench_using_profiling
|
|
|
|
return do_bench_using_profiling(kernel_call, warmup=10, rep=40)
|
|
|
|
if self.device_props.type == "cpu":
|
|
return benchmarker.benchmark_cpu(kernel_call)
|
|
|
|
return benchmarker.benchmark_gpu(kernel_call, rep=40)
|
|
|
|
def copy_args_to_cpu_if_needed(self, *args, **kwargs):
|
|
"""
|
|
To support benchmarking in the presence of mutated args, we need to avoid
|
|
autotuning contanminating them. We try to pass cloned args to the kernel.
|
|
If those clones would increase the peak memory usage, however, we instead
|
|
copy to cpu and restore them after each iteration. Figure out the args
|
|
to be copied and do the copying.
|
|
"""
|
|
if not self.optimize_mem:
|
|
return {}
|
|
|
|
copies = {}
|
|
budget = torch.cuda.max_memory_allocated() - torch.cuda.memory_allocated()
|
|
|
|
def maybe_copy(name, arg):
|
|
if name in self.mutated_arg_names and arg.is_cuda:
|
|
nonlocal budget
|
|
assert isinstance(arg, torch.Tensor)
|
|
required_storage_length = compute_required_storage_length(
|
|
arg.size(),
|
|
arg.stride(),
|
|
0,
|
|
)
|
|
size = required_storage_length * arg.element_size()
|
|
if size > budget:
|
|
cpu_arg = torch.empty_strided(
|
|
(required_storage_length,),
|
|
(1,),
|
|
dtype=arg.dtype,
|
|
device="cpu",
|
|
pin_memory=True,
|
|
)
|
|
cpu_arg.copy_(
|
|
arg.as_strided((required_storage_length,), (1,)),
|
|
non_blocking=True,
|
|
)
|
|
copies[name] = (arg, cpu_arg)
|
|
else:
|
|
budget -= size
|
|
|
|
for name, arg in zip(self.fn.arg_names, args):
|
|
maybe_copy(name, arg)
|
|
|
|
for name, arg in kwargs.items():
|
|
maybe_copy(name, arg)
|
|
|
|
return copies
|
|
|
|
def restore_args_from_cpu(self, cpu_copies):
|
|
for pair in cpu_copies.values():
|
|
arg, cpu_arg = pair
|
|
required_storage_length = compute_required_storage_length(
|
|
arg.size(),
|
|
arg.stride(),
|
|
0,
|
|
)
|
|
arg.as_strided((required_storage_length,), (1,)).copy_(
|
|
cpu_arg, non_blocking=True
|
|
)
|
|
|
|
def reset_to_zero_args(self, *args, **kwargs):
|
|
if not self.reset_to_zero_arg_names:
|
|
return
|
|
for i, arg in enumerate(args):
|
|
if self.fn.arg_names[i] in self.reset_to_zero_arg_names:
|
|
assert isinstance(
|
|
arg,
|
|
torch.Tensor,
|
|
), (
|
|
"self.reset_to_zero_arg_names should only contain valid argument names"
|
|
)
|
|
arg.zero_()
|
|
|
|
for name, arg in kwargs.items():
|
|
if name in self.reset_to_zero_arg_names:
|
|
assert isinstance(
|
|
arg,
|
|
torch.Tensor,
|
|
), (
|
|
"self.reset_to_zero_arg_names should only contain valid argument names"
|
|
)
|
|
arg.zero_()
|
|
|
|
def maybe_clone_args(
|
|
self, exclude: Container[str], *args, **kwargs
|
|
) -> tuple[list[Any], dict[str, Any]]:
|
|
"""
|
|
Prepare new args and kwargs by cloning any in-place buffers
|
|
(that are not in the provided exclusion list), to avoid autotune
|
|
contaminating them. Avoid cloning the other buffers because it
|
|
leads to increased memory usage.
|
|
"""
|
|
from ..compile_fx import clone_preserve_strides
|
|
|
|
def prepare_arg(name, arg):
|
|
if name in self.mutated_arg_names and name not in exclude:
|
|
assert isinstance(arg, torch.Tensor)
|
|
return clone_preserve_strides(arg)
|
|
else:
|
|
return arg
|
|
|
|
cloned_args = [
|
|
prepare_arg(name, arg)
|
|
for name, arg in itertools.zip_longest(self.fn.arg_names[: len(args)], args)
|
|
]
|
|
cloned_kwargs = {name: prepare_arg(name, arg) for name, arg in kwargs.items()}
|
|
return cloned_args, cloned_kwargs
|
|
|
|
def clone_args(self, *args, **kwargs) -> tuple[list[Any], dict[str, Any]]:
|
|
return self.maybe_clone_args(OrderedSet(), *args, **kwargs)
|
|
|
|
def benchmark_all_configs(self, *args, **kwargs):
|
|
with (
|
|
dynamo_timed(
|
|
"CachingAutotuner.benchmark_all_configs",
|
|
log_pt2_compile_event=True,
|
|
metadata={"kernel_name": self.inductor_meta.get("kernel_name")},
|
|
dynamo_compile_column_us="runtime_triton_autotune_time_us",
|
|
compile_id=self.compile_id,
|
|
is_backward=self.is_backward,
|
|
log_waitcounter=True,
|
|
waitcounter_name_override="triton_autotuner",
|
|
),
|
|
# Temporarily disable due to spam
|
|
# compilation_callback.callback_handler.install_callbacks(
|
|
# compilation_callback.CallbackTrigger.TRITON_AUTOTUNING,
|
|
# str(self.compile_id),
|
|
# ),
|
|
):
|
|
timings = {
|
|
launcher: self.bench(launcher, *args, **kwargs)
|
|
for launcher in self.launchers
|
|
}
|
|
|
|
for k, v in timings.items():
|
|
self.coordesc_tuner.cache_benchmark_result(k.config, v)
|
|
|
|
if log.isEnabledFor(logging.DEBUG):
|
|
log.debug("Benchmark all input configs for %s, get:", self.fn.__name__)
|
|
for k, v in timings.items():
|
|
log.debug(
|
|
"%s: %f, nreg %d, nspill %d, #shared-mem %s",
|
|
k.config,
|
|
v,
|
|
k.n_regs,
|
|
k.n_spills,
|
|
k.shared,
|
|
)
|
|
|
|
self.reset_to_zero_args(*args, **kwargs)
|
|
return timings
|
|
|
|
def autotune_to_one_config(self, *args, **kwargs):
|
|
"""Do the actual autotuning"""
|
|
start_time = time.time_ns()
|
|
timings = self.benchmark_all_configs(*args, **kwargs)
|
|
benchmark_time_taken_ns = time.time_ns() - start_time
|
|
self.launchers = [builtins.min(timings, key=timings.get)]
|
|
self.autotune_time_taken_ns = (
|
|
self.precompile_time_taken_ns + benchmark_time_taken_ns
|
|
)
|
|
|
|
# log the best config
|
|
launcher = self.launchers[0]
|
|
log.debug(
|
|
"Best config for %s: %s: %f, nreg %d, nspill %d, #shared-mem %s",
|
|
self.fn.__name__,
|
|
launcher.config,
|
|
timings[launcher],
|
|
launcher.n_regs,
|
|
launcher.n_spills,
|
|
launcher.shared,
|
|
)
|
|
|
|
if self.save_cache_hook:
|
|
self.save_cache_hook(
|
|
launcher.config,
|
|
self.autotune_time_taken_ns,
|
|
triton_cache_hash=launcher.cache_hash,
|
|
)
|
|
|
|
def save_gpu_kernel(self, stream, launcher):
|
|
key = self.inductor_meta.get("kernel_name", None) # unique kernel name
|
|
assert key is not None, "kernel_name can not be None"
|
|
params = {
|
|
"mangled_name": (
|
|
launcher.bin.metadata.name
|
|
if hasattr(launcher.bin.metadata, "name")
|
|
else launcher.bin.metadata["name"]
|
|
),
|
|
"num_warps": (
|
|
launcher.bin.num_warps
|
|
if hasattr(launcher.bin, "num_warps")
|
|
else launcher.bin.metadata.num_warps
|
|
),
|
|
"shared_mem": (
|
|
launcher.bin.shared
|
|
if hasattr(launcher.bin, "shared")
|
|
else launcher.bin.metadata.shared
|
|
),
|
|
"stream": stream,
|
|
# User defined triton kernels will have arbitrary kwarg names
|
|
"config": config_to_dict(launcher.config),
|
|
"inductor_meta": self.inductor_meta,
|
|
"triton_meta": self.triton_meta,
|
|
"def_args": launcher.def_args,
|
|
"call_args": launcher.call_args,
|
|
"global_scratch": launcher.global_scratch,
|
|
"profile_scratch": launcher.profile_scratch,
|
|
}
|
|
from torch._inductor.codecache import CudaKernelParamCache
|
|
|
|
bin_type = {"hip": "hsaco", "xpu": "spv"}.get(self.device_props.type, "cubin")
|
|
binary = launcher.bin.asm[bin_type]
|
|
# Also store asm code which can be used for debugging and generating cpp package
|
|
asm_type = {"hip": "amdgcn", "cuda": "ptx", "xpu": "spv"}.get(
|
|
self.device_props.type, None
|
|
)
|
|
asm = launcher.bin.asm.get(asm_type, None)
|
|
|
|
CudaKernelParamCache.set(key, params, binary, bin_type, asm, asm_type)
|
|
self.cuda_kernel_saved = True
|
|
|
|
def coordinate_descent_tuning(self, launcher, *args, **kwargs):
|
|
"""
|
|
Coordinate descent tuning can be run with or without max-autotune.
|
|
|
|
The only difference between these two is the starting config for coordinate_descent tuning.
|
|
E.g., assuming regular autotune only get one config C1; while max-autotune get 4 configs C1, C2, C3, C4
|
|
and max-autotune figure out C3 is the best.
|
|
|
|
Then if coordinate desecnt tuning is run with max-autotune disabled, it will start from C1;
|
|
while if coordinate descent tuning is run with max-autotune enabled, it will start from C3.
|
|
"""
|
|
if (
|
|
self.heuristic_type == HeuristicType.TEMPLATE
|
|
or self.heuristic_type == HeuristicType.USER_AUTOTUNE
|
|
):
|
|
# skip triton template
|
|
return launcher
|
|
|
|
with dynamo_timed(
|
|
"CachingAutotuner.coordinate_descent_tuning",
|
|
# These generate too many pt2_compile_event logs:
|
|
log_pt2_compile_event=False,
|
|
metadata={"kernel_name": self.inductor_meta.get("kernel_name")},
|
|
dynamo_compile_column_us="runtime_triton_autotune_time_us",
|
|
compile_id=self.compile_id,
|
|
is_backward=self.is_backward,
|
|
log_waitcounter=True,
|
|
waitcounter_name_override="triton_autotuner",
|
|
):
|
|
return self._coordinate_descent_tuning(launcher, *args, **kwargs)
|
|
|
|
def _coordinate_descent_tuning(self, launcher, *args, **kwargs):
|
|
config2launcher = {launcher.config: launcher}
|
|
|
|
# TODO: should we just load the kernels ahead of time if we know we're going to call this?
|
|
if self.fn.fn is None:
|
|
"""
|
|
We are in the parent process, while this program was compiled in a worker
|
|
and the fn was dropped in prepare_for_pickle(). We haven't loaded the module
|
|
containing the real fn yet.
|
|
"""
|
|
assert hasattr(self, "_reload_kernel")
|
|
assert callable(self._reload_kernel)
|
|
self.fn = self._reload_kernel().fn
|
|
|
|
def benchmark_one_config(config):
|
|
with self.lock:
|
|
launcher = self._precompile_config(config).make_launcher()
|
|
config2launcher[config] = launcher
|
|
|
|
out = self.bench(launcher, *args, **kwargs)
|
|
log.debug(
|
|
"COORDESC: %s: %f, nreg %d, nspill %d, #shared-mem %d",
|
|
launcher.config,
|
|
out,
|
|
launcher.n_regs,
|
|
launcher.n_spills,
|
|
launcher.shared,
|
|
)
|
|
return out
|
|
|
|
assert not (
|
|
self.heuristic_type == HeuristicType.PERSISTENT_REDUCTION
|
|
and "R0_BLOCK" in launcher.config.kwargs
|
|
), (
|
|
"Coordinate descent tuner relies on the assumption that persistent reduction's triton config does not have R0_BLOCK"
|
|
)
|
|
start_time = time.time_ns()
|
|
best_config = self.coordesc_tuner.autotune(
|
|
benchmark_one_config, launcher.config, None
|
|
)
|
|
coordesc_time_taken_ns = time.time_ns() - start_time
|
|
best_config.found_by_coordesc = True
|
|
|
|
if self.save_cache_hook:
|
|
self.save_cache_hook(
|
|
best_config,
|
|
self.autotune_time_taken_ns + coordesc_time_taken_ns,
|
|
found_by_coordesc=True,
|
|
)
|
|
|
|
if best_config not in config2launcher:
|
|
# On a Coordesc cache hit, we might not have loaded the launcher
|
|
# This can happen because PyCodeCache saves CachingAutotuners in memory,
|
|
# even for separate compile IDs (which can have different inputs without changing output code)
|
|
config2launcher[best_config] = self._precompile_config(
|
|
best_config
|
|
).make_launcher()
|
|
|
|
fn_hash = generate_lookup_hash_from_source_code(
|
|
str(self.size_hints), self.fn.src
|
|
)
|
|
log.debug("Function hash %s has best config %s", fn_hash, best_config)
|
|
return config2launcher[best_config]
|
|
|
|
def get_profiler_kwargs(self, stream, launcher):
|
|
kernel_kwargs_str = ",".join(
|
|
f"{k}={v}" for (k, v) in launcher.config.kwargs.items()
|
|
)
|
|
|
|
ret = {
|
|
"kernel_file": (self.filename or ""),
|
|
"kernel_hash": self.kernel_hash,
|
|
"kernel_backend": "triton",
|
|
"stream": stream,
|
|
"num_warps": launcher.config.num_warps,
|
|
"num_stages": launcher.config.num_stages,
|
|
"kernel_kwargs": kernel_kwargs_str,
|
|
}
|
|
if "kernel_name" in self.inductor_meta:
|
|
ret["kernel_name"] = self.inductor_meta["kernel_name"]
|
|
if "kernel_flop" in self.inductor_meta:
|
|
ret["kernel_flop"] = self.inductor_meta["kernel_flop"]
|
|
if "kernel_num_gb" in self.inductor_meta:
|
|
ret["kernel_num_gb"] = self.inductor_meta["kernel_num_gb"]
|
|
return ret
|
|
|
|
def run(
|
|
self,
|
|
*args,
|
|
stream,
|
|
benchmark_run=False,
|
|
**kwargs,
|
|
): # type:ignore[override]
|
|
if hasattr(triton, "set_allocator"):
|
|
|
|
def alloc_fn(size: int, align: int, stream: Optional[int]):
|
|
return torch.empty(
|
|
size, dtype=torch.int8, device=self.device_props.type
|
|
)
|
|
|
|
triton.set_allocator(alloc_fn)
|
|
|
|
if self.triton_interpret:
|
|
args, grid = self._interpret_args_grid(args, self.configs[0])
|
|
return self.fn[grid](
|
|
*args,
|
|
**kwargs,
|
|
**self.configs[0].kwargs,
|
|
)
|
|
|
|
if len(self.launchers) != 1:
|
|
if len(self.launchers) == 0:
|
|
start_time = time.time_ns()
|
|
self.precompile()
|
|
self.precompile_time_taken_ns = time.time_ns() - start_time
|
|
if len(self.launchers) > 1:
|
|
self.autotune_to_one_config(*args, **kwargs)
|
|
|
|
if not getattr(
|
|
self.launchers[0].config, "found_by_coordesc", False
|
|
) and self.inductor_meta.get("coordinate_descent_tuning", False):
|
|
self.launchers = [
|
|
self.coordinate_descent_tuning(self.launchers[0], *args, **kwargs)
|
|
]
|
|
|
|
(launcher,) = self.launchers
|
|
if launcher.store_cubin and (not benchmark_run or not self.cuda_kernel_saved):
|
|
self.save_gpu_kernel(stream, launcher)
|
|
|
|
# PyTorch execution trace replay calls CachingAutotuner::run() instead of calls launcher
|
|
# so _RecordFunctionFast need to capture the args into CachingAutotuner::run()
|
|
# make a copy here to avoid mutating the original args
|
|
args_without_constexprs = tuple(args)
|
|
args = self._get_args_with_constexprs(args, launcher)
|
|
|
|
if self.dump_launch_params:
|
|
new_args, grid = self._interpret_args_grid(args, launcher.config)
|
|
_dump_launch_params(new_args, kwargs, launcher, self.fn.__name__, grid)
|
|
|
|
# it is faster than entering and exiting a context manager, even if the context
|
|
# manager is a nullcontext.
|
|
if autograd_profiler._is_profiler_enabled:
|
|
profiler_kwargs = self.get_profiler_kwargs(stream, launcher)
|
|
|
|
with torch._C._profiler._RecordFunctionFast(
|
|
self.inductor_meta.get("kernel_name", "triton kernel"),
|
|
args_without_constexprs,
|
|
profiler_kwargs,
|
|
):
|
|
return launcher(
|
|
*args,
|
|
**kwargs,
|
|
stream=stream,
|
|
)
|
|
else:
|
|
return launcher(
|
|
*args,
|
|
**kwargs,
|
|
stream=stream,
|
|
)
|
|
|
|
def _interpret_args_grid(
|
|
self, args: tuple[Any, ...], cfg: Config
|
|
) -> tuple[tuple[Any, ...], tuple[int, int, int]]:
|
|
grid = GridExpr.from_meta(self.inductor_meta, cfg).eval_slow(
|
|
dict(
|
|
zip(
|
|
[
|
|
*self.triton_meta["signature"].keys(),
|
|
*self.inductor_meta.get("extra_launcher_args", ()),
|
|
],
|
|
args,
|
|
)
|
|
)
|
|
)
|
|
if self.inductor_meta.get("extra_launcher_args"):
|
|
args = args[: -len(self.inductor_meta["extra_launcher_args"])]
|
|
return args, grid
|
|
|
|
|
|
class _ConstRepr:
|
|
def __init__(self, value: str):
|
|
self.value = value
|
|
|
|
def __call__(self, _=None) -> str:
|
|
return self.value
|
|
|
|
|
|
class CompileResult(Generic[_T]):
|
|
def __init__(
|
|
self,
|
|
kernel: _T,
|
|
config: Config,
|
|
compile_meta: dict[str, Any],
|
|
inductor_meta: dict[str, Any],
|
|
):
|
|
self.kernel = kernel
|
|
self.config = config
|
|
self.compile_meta = compile_meta
|
|
self.inductor_meta = inductor_meta
|
|
|
|
def make_launcher(self) -> LauncherType: ...
|
|
|
|
def _gen_launcher_code(self, scope, def_args, runner_args) -> LauncherType:
|
|
grid = GridExpr.from_meta(self.inductor_meta, self.config)
|
|
# grid.prefix is usually empty, grid.x_grid is something like `-(xnumel//-1024)`
|
|
lines = [
|
|
f"def launcher({', '.join(def_args)}, stream):",
|
|
*[f" {line}" for line in grid.prefix],
|
|
f" grid_0 = {grid.x_grid}",
|
|
f" grid_1 = {grid.y_grid}",
|
|
f" grid_2 = {grid.z_grid}",
|
|
f" runner({', '.join(runner_args)})",
|
|
]
|
|
launcher_code = "\n".join(lines)
|
|
exec(launcher_code, scope)
|
|
return scope["launcher"]
|
|
|
|
def _get_arg_lists(
|
|
self, arg_names, constexprs
|
|
) -> tuple[list[str], list[str], OrderedSet[str]]:
|
|
"""
|
|
Return a bunch of intermediate lists of args needed for generating
|
|
launcher code.
|
|
"""
|
|
compile_meta = self.compile_meta
|
|
cfg = self.config
|
|
known_constants = OrderedSet(
|
|
arg for i, arg in enumerate(arg_names) if i in constexprs
|
|
)
|
|
|
|
"""
|
|
https://github.com/pytorch/pytorch/issues/115344
|
|
|
|
self.fn.constexprs doesn't properly deal with None args, so when we filter out
|
|
an arg in UserDefinedTritonKernel.codegen, we need to filter it here as well.
|
|
We also don't want to modify self.fn.
|
|
|
|
We know that we removed something from the signature if:
|
|
1. It's in compile_meta["constants"]
|
|
2. It isn't a constant we already know about
|
|
Note: The value of interest has already been added to compile_meta['constants'],
|
|
so we use self.fn.constexprs instead.
|
|
3. It isn't in the compile_meta signature
|
|
"""
|
|
none_args = OrderedSet(
|
|
k
|
|
for k, v in compile_meta["constants"].items()
|
|
if v is None and k not in known_constants
|
|
)
|
|
none_args = none_args.difference(OrderedSet(compile_meta["signature"].keys()))
|
|
|
|
if triton_version_uses_attrs_dict():
|
|
call_args = arg_names
|
|
def_args = arg_names
|
|
if (
|
|
"num_warps" in compile_meta["constants"]
|
|
or "num_stages" in compile_meta["constants"]
|
|
):
|
|
# num_warps/num_stages are special implicit args that are not in the signature
|
|
# see test_triton_kernel_special_params
|
|
def_args = [
|
|
arg for arg in def_args if arg not in ("num_warps", "num_stages")
|
|
]
|
|
repl = {
|
|
k: str(compile_meta["constants"].get(k))
|
|
for k in ("num_warps", "num_stages")
|
|
}
|
|
call_args = [repl.get(arg, arg) for arg in call_args]
|
|
else:
|
|
call_args = [
|
|
arg
|
|
for i, arg in enumerate(arg_names)
|
|
if i not in constexprs and arg not in none_args
|
|
]
|
|
cfg_dict = config_to_dict(cfg)
|
|
def_args = [
|
|
name
|
|
for name in arg_names
|
|
if name not in cfg_dict and name not in none_args
|
|
]
|
|
|
|
if "extra_launcher_args" in self.inductor_meta:
|
|
def_args = [*def_args, *self.inductor_meta["extra_launcher_args"]]
|
|
|
|
return call_args, def_args, none_args
|
|
|
|
|
|
class CannotStaticallyLaunchKernel(Exception):
|
|
pass
|
|
|
|
|
|
class StaticTritonCompileResult(CompileResult[StaticallyLaunchedCudaKernel]):
|
|
"""
|
|
TritonCompileResult that uses StaticCudaLauncher,
|
|
which vastly simplifies the setup and metadata needed to be kept.
|
|
"""
|
|
|
|
@staticmethod
|
|
def can_statically_launch(
|
|
kernel: CompiledKernel,
|
|
inductor_meta: dict[str, Any],
|
|
triton_meta: dict[str, Any],
|
|
heuristic_type: HeuristicType,
|
|
) -> Optional[StaticallyLaunchedCudaKernel]:
|
|
if not torch._inductor.config.use_static_cuda_launcher:
|
|
return None
|
|
|
|
def check_can_launch() -> StaticallyLaunchedCudaKernel:
|
|
if triton_meta.get("device_type", None) != "cuda":
|
|
# Only cuda kernels
|
|
raise CannotStaticallyLaunchKernel("Non-cuda device")
|
|
|
|
if torch._inductor.config.cpp_wrapper:
|
|
# If we're running with cpp wrapper, it doesn't
|
|
# make sense to statically compile since everything
|
|
# is codegenned anyway
|
|
raise CannotStaticallyLaunchKernel("Cpp wrapper enabled")
|
|
|
|
if (
|
|
heuristic_type == HeuristicType.USER_AUTOTUNE
|
|
and not torch._inductor.config.static_launch_user_defined_triton_kernels
|
|
):
|
|
# Don't support user defined triton kernels yet
|
|
raise CannotStaticallyLaunchKernel("User defined triton kernel")
|
|
|
|
if inductor_meta.get("store_cubin", None):
|
|
# Requires storing the entire binary
|
|
raise CannotStaticallyLaunchKernel("store_cubin is enabled")
|
|
|
|
cubin_location = os.path.join(
|
|
triton_cache_dir(triton_meta.get("device", 0)),
|
|
triton_hash_to_path_key(kernel.hash),
|
|
f"{kernel.src.fn.__name__}.cubin",
|
|
)
|
|
|
|
if not os.path.exists(cubin_location):
|
|
raise CannotStaticallyLaunchKernel(
|
|
f"Cubin path not found: {cubin_location}"
|
|
)
|
|
|
|
else:
|
|
kernel._cubin_path = cubin_location
|
|
|
|
try:
|
|
static_kernel = StaticallyLaunchedCudaKernel(kernel)
|
|
except NotImplementedError as e:
|
|
raise CannotStaticallyLaunchKernel(f"NotImplemented: {str(e)}") from e
|
|
|
|
return static_kernel
|
|
|
|
try:
|
|
result = check_can_launch()
|
|
return result
|
|
except CannotStaticallyLaunchKernel as e:
|
|
log.info("Bypassing StaticallyLaunchedCudaKernel due to %s", str(e))
|
|
if torch._inductor.config.strict_static_cuda_launcher:
|
|
raise e
|
|
return None
|
|
|
|
def reload_cubin_path(self):
|
|
"""
|
|
When loading from cache on disk, we want to reload cubin
|
|
files from their appropriate location on disc.
|
|
"""
|
|
cubin_location = os.path.join(
|
|
triton_cache_dir(self.compile_meta.get("device", 0)),
|
|
triton_hash_to_path_key(self.kernel.hash),
|
|
f"{self.kernel.name}.cubin",
|
|
)
|
|
if not os.path.exists(cubin_location):
|
|
if self.kernel.cubin_raw is not None:
|
|
# We saved the raw cubin, so write it to he appropriate location
|
|
self.kernel.reload_cubin_from_raw(cubin_location)
|
|
else:
|
|
raise RuntimeError(
|
|
"Cubin file saved by TritonBundler not found at %s", cubin_location
|
|
)
|
|
self.kernel.cubin_path = cubin_location
|
|
|
|
def make_launcher(self) -> LauncherType:
|
|
# If at least one static make_launcher call occurs,
|
|
# we're sure static cuda launcher was used for this compile
|
|
set_feature_use("static_cuda_launcher", True)
|
|
# Load the binary on the parent
|
|
if not self.kernel.cubin_path:
|
|
self.reload_cubin_path()
|
|
device = self.compile_meta.get("device", 0)
|
|
if device is None:
|
|
device = 0
|
|
self.kernel.load_kernel(device)
|
|
scope = {
|
|
"runner": self.kernel.run,
|
|
}
|
|
|
|
# NOTE: Constexpr handling for triton and static cuda launcher
|
|
|
|
# Triton kernels have two types of constexprs: *declared* ones, which are ones the user
|
|
# has explicitly declared as tl.constexpr, and *implied* ones, which are expressions triton
|
|
# deems constant while compiling/analyzing the code (i.e. unused parameters, for example)
|
|
|
|
# Triton kernels handle constexprs slightly differently depending on which version of triton
|
|
# we care about (we support 3.2.0 and 3.3.0).
|
|
|
|
# In 3.2.0, triton kernels do not require passing any declared constexprs into the kernel
|
|
# In 3.3.0, triton kernels require all declared constexprs be passed into the kernel, where
|
|
# they are subsequently ignored.
|
|
# When statically launching, since we're launching from the triton generated cubin, we actually want to
|
|
# always get rid of all const exprs, declared or implied, since the underlying cubin file has all
|
|
# of the constants stripped away anyway.
|
|
|
|
# But CachingAutotuner.run will pass us a different number of arguments depending on
|
|
# whether or not we're in triton 3.2.0 or later, so we grab def_args with the same logic
|
|
# as the (non static) TritonCompileResult. We then generate call_args ourselves, since we
|
|
# want only a subset of the arguments passed to triton.
|
|
# Here, arg_names is exactly fn.src.arg_names and declared_constexprs is exactly fn.src.constexprs,
|
|
# which matches behavior with regular TritonCompileResult
|
|
_, def_args, none_args = self._get_arg_lists(
|
|
self.kernel.arg_names, self.kernel.declared_constexprs
|
|
)
|
|
|
|
call_args = [
|
|
arg
|
|
for i, arg in enumerate(self.kernel.arg_names)
|
|
if i not in self.kernel.full_constexprs and arg not in none_args
|
|
]
|
|
|
|
# StaticallyLaunchedCudaKernel.run takes in order grid_0, grid_1, grid_2, stream, and call_args
|
|
runner_args = ["grid_0", "grid_1", "grid_2", "stream", *call_args]
|
|
launcher = self._gen_launcher_code(scope, def_args, runner_args)
|
|
launcher.config = self.config # type: ignore[attr-defined]
|
|
launcher.n_regs = self.kernel.n_regs # type: ignore[attr-defined]
|
|
launcher.n_spills = self.kernel.n_spills # type: ignore[attr-defined]
|
|
launcher.shared = self.kernel.shared # type: ignore[attr-defined]
|
|
launcher.cache_hash = triton_hash_to_path_key(self.kernel.hash) # type: ignore[attr-defined]
|
|
launcher.store_cubin = False # type: ignore[attr-defined]
|
|
launcher._is_static = True # type: ignore[attr-defined]
|
|
return launcher
|
|
|
|
|
|
class TritonCompileResult(CompileResult[CompiledKernel]):
|
|
"""
|
|
Upstream Triton CompileKernel can not be pickled. This is a wrapper
|
|
to support serialization and generate the launcher function.
|
|
"""
|
|
|
|
@staticmethod
|
|
@functools.lru_cache(32)
|
|
def _kernel_metadata_cls(fields: tuple[str, ...]) -> Any:
|
|
return namedtuple("KernelMetadata", sorted(fields))
|
|
|
|
@staticmethod
|
|
def _serialize_metadata(metadata):
|
|
"""
|
|
Triton uses a nested class called KernelMetadata to store metadata information.
|
|
Pickle does not work well with nested namedtuples, as the namedtuple doesn't appear
|
|
in the toplevel namespace of the module. So these serialization/deser functions
|
|
are used to convert the namedtuples to a dict and back.
|
|
|
|
As for packed_metadata, depending on the triton backend, KernelMetadata can be
|
|
a namedtuple, or a regular tuple! So the serialization function branches on whether
|
|
the metadata to be serialized is a namedtuple or regular, serializable one.
|
|
"""
|
|
|
|
def is_namedtuple(obj) -> bool:
|
|
return (
|
|
isinstance(obj, tuple)
|
|
and hasattr(obj, "_asdict")
|
|
and hasattr(obj, "_fields")
|
|
)
|
|
|
|
if is_namedtuple(metadata):
|
|
return metadata._asdict()
|
|
else:
|
|
return metadata
|
|
|
|
@staticmethod
|
|
def _deserialize_metadata(metadata):
|
|
if isinstance(metadata, dict):
|
|
return TritonCompileResult._kernel_metadata_cls(tuple(metadata.keys()))(
|
|
**metadata
|
|
)
|
|
else:
|
|
return metadata
|
|
|
|
def __getstate__(self) -> dict[str, Any]:
|
|
kernel = self.kernel
|
|
# replace the fields that don't pickle nicely
|
|
kernel_state = {
|
|
**kernel.__dict__,
|
|
# See doc about serializing metadata above
|
|
"metadata": self._serialize_metadata(kernel.metadata),
|
|
"packed_metadata": self._serialize_metadata(
|
|
getattr(kernel, "packed_metadata", None)
|
|
),
|
|
"module": None, # regenerated by kernel._init_handles()
|
|
"function": None, # regenerated by kernel._init_handles()
|
|
"run": None, # regenerated by kernel._init_handles()
|
|
}
|
|
return {**self.__dict__, "kernel": kernel_state} # type: ignore[dict-item]
|
|
|
|
def __setstate__(self, state: dict[str, Any]) -> None:
|
|
# src = ASTSource.__new__(ASTSource)
|
|
# src.__setstate__(state["kernel"]["src"])
|
|
# TODO(jansel): need to fixup src.fn which is now None
|
|
kernel = CompiledKernel.__new__(CompiledKernel)
|
|
metadata = state["kernel"]["metadata"]
|
|
packed_metadata = state["kernel"]["packed_metadata"]
|
|
kernel.__dict__.update(
|
|
{
|
|
**state["kernel"],
|
|
# "src": src,
|
|
"metadata": self._deserialize_metadata(metadata),
|
|
"packed_metadata": self._deserialize_metadata(packed_metadata),
|
|
}
|
|
)
|
|
self.__dict__.update(state)
|
|
self.kernel = kernel
|
|
|
|
def make_launcher(self) -> LauncherType:
|
|
"""
|
|
Launching triton kernels is performance sensitive, we compile
|
|
a custom Python function get the grid() and reorder the args to
|
|
the underlying wrapper.
|
|
"""
|
|
cfg = self.config
|
|
compile_meta = self.compile_meta
|
|
binary = self.kernel
|
|
fn = binary.src.fn
|
|
binary._init_handles()
|
|
(call_args, def_args, none_args) = self._get_arg_lists(
|
|
fn.arg_names, fn.constexprs
|
|
)
|
|
binary_shared = (
|
|
binary.shared if hasattr(binary, "shared") else binary.metadata.shared
|
|
)
|
|
|
|
if knobs is None:
|
|
launch_enter = binary.__class__.launch_enter_hook
|
|
launch_exit = binary.__class__.launch_exit_hook
|
|
else:
|
|
launch_enter = knobs.runtime.launch_enter_hook
|
|
launch_exit = knobs.runtime.launch_exit_hook
|
|
|
|
import math as math_lib
|
|
|
|
import torch as torch_lib
|
|
|
|
scope = {
|
|
"grid_meta": cfg.kwargs,
|
|
"bin": binary,
|
|
"launch_enter_hook": launch_enter,
|
|
"launch_exit_hook": launch_exit,
|
|
"metadata": (
|
|
binary.packed_metadata
|
|
if hasattr(binary, "packed_metadata")
|
|
else binary.metadata
|
|
),
|
|
"shared": binary_shared,
|
|
"num_warps": (
|
|
binary.num_warps
|
|
if hasattr(binary, "num_warps")
|
|
else binary.metadata.num_warps
|
|
),
|
|
"cta_args": (
|
|
(
|
|
binary.num_ctas,
|
|
*get_first_attr(binary, "cluster_dims", "clusterDims"),
|
|
)
|
|
if hasattr(binary, "num_ctas")
|
|
else (
|
|
(binary.metadata.num_ctas, *binary.metadata.cluster_dims)
|
|
if hasattr(binary, "metadata")
|
|
else ()
|
|
)
|
|
),
|
|
"function": get_first_attr(binary, "function", "cu_function"),
|
|
"runner": get_first_attr(binary, "run", "c_wrapper"),
|
|
"math": math_lib,
|
|
"torch": torch_lib,
|
|
}
|
|
|
|
if not hasattr(binary, "launch_metadata"):
|
|
# launch args before CompiledKernel.launch_metadata is added.
|
|
# TODO(jansel): delete this branch in mid-2025
|
|
runner_args = [
|
|
"grid_0",
|
|
"grid_1",
|
|
"grid_2",
|
|
"num_warps",
|
|
"*cta_args",
|
|
"shared",
|
|
"stream",
|
|
"function",
|
|
"launch_enter_hook",
|
|
"launch_exit_hook",
|
|
"metadata",
|
|
*call_args,
|
|
]
|
|
else: # args after CompiledKernel.launch_metadata: https://github.com/triton-lang/triton/pull/3492
|
|
# Getting the kernel launch args is extremely perf-sensitive. Evaluating
|
|
# `bin.launch_metadata` is relatively expensive, and returns None unless a
|
|
# `launch_enter_hook` is installed. So if we don't have that hook installed,
|
|
# we want to burn None into the launch args with zero overhead.
|
|
# See https://github.com/pytorch/pytorch/issues/123597
|
|
if launch_enter:
|
|
launch_metadata = f"bin.launch_metadata((grid_0, grid_1, grid_2), stream, {', '.join(call_args)})"
|
|
else:
|
|
launch_metadata = "None"
|
|
runner_args = [
|
|
"grid_0",
|
|
"grid_1",
|
|
"grid_2",
|
|
"stream",
|
|
"function",
|
|
"metadata",
|
|
launch_metadata,
|
|
"launch_enter_hook",
|
|
"launch_exit_hook",
|
|
*call_args,
|
|
]
|
|
|
|
launcher = self._gen_launcher_code(scope, def_args, runner_args)
|
|
|
|
launcher = scope["launcher"]
|
|
launcher.config = cfg
|
|
launcher.n_regs = getattr(binary, "n_regs", None)
|
|
launcher.n_spills = getattr(binary, "n_spills", None)
|
|
launcher.shared = binary_shared
|
|
launcher.cache_hash = triton_hash_to_path_key(binary.hash)
|
|
launcher.store_cubin = self.inductor_meta.get("store_cubin", False)
|
|
# store this global variable to avoid the high overhead of reading it when calling run
|
|
if launcher.store_cubin:
|
|
launcher.fn = fn
|
|
launcher.bin = binary
|
|
if triton_version_uses_attrs_dict():
|
|
# arg filtering wasn't done above
|
|
cfg_dict = config_to_dict(cfg)
|
|
def_args = [x for x in def_args if x not in cfg_dict]
|
|
call_args = [
|
|
x
|
|
for x in call_args
|
|
if compile_meta["signature"].get(x, "constexpr") != "constexpr"
|
|
and x not in none_args
|
|
]
|
|
launcher.def_args = def_args
|
|
launcher.call_args = call_args
|
|
kernel_metadata = getattr(self.kernel, "metadata", None)
|
|
|
|
# for the scratch arguments: None indicates that the kernel doesn't
|
|
# take any scratch argument; otherwise a number indicates the number
|
|
# of bytes of scratch that need to be provided.
|
|
|
|
# in AMD's Triton backend, the global scratch size is never provided
|
|
# (but for AMD it's safe to pass an extra null arg, so always include it)
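# Illustration (the byte counts are hypothetical): global_scratch = None means the
# kernel takes no global-scratch argument at all, while global_scratch = 256 means
# at least 256 bytes of scratch space must be provided when launching; the same
# convention applies to profile_scratch.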
|
|
global_scratch: Optional[int] = getattr(
|
|
kernel_metadata,
|
|
"global_scratch_size",
|
|
(0 if torch.version.hip else None),
|
|
)
|
|
profile_scratch: Optional[int] = getattr(
|
|
kernel_metadata, "profile_scratch_size", None
|
|
)
|
|
launcher.global_scratch = global_scratch
|
|
launcher.profile_scratch = profile_scratch
|
|
return launcher
|
|
|
|
|
|
def _find_names(obj):
|
|
import gc
|
|
import inspect
|
|
|
|
frame = inspect.currentframe()
|
|
while frame is not None:
|
|
frame.f_locals  # materialize f_locals so the frames' locals dicts are visible to gc.get_referrers below
|
|
frame = frame.f_back
|
|
obj_names = []
|
|
for referrer in gc.get_referrers(obj):
|
|
if isinstance(referrer, dict):
|
|
for k, v in referrer.items():
|
|
if v is obj:
|
|
obj_names.append(k)
|
|
return obj_names
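# Hedged example: for a generated wrapper module containing something like
#   triton_poi_fused_add_0 = async_compile.triton("triton_poi_fused_add_0", ...)
# _find_names(<that CachingAutotuner>) would typically include
# "triton_poi_fused_add_0"; DebugAutotuner below picks the longest such name as
# the kernel name for its bandwidth report.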
|
|
|
|
|
|
collected_calls: list[Any] = []
|
|
|
|
|
|
def start_graph():
|
|
collected_calls.clear()
|
|
|
|
|
|
def end_graph(output_file):
|
|
if len(collected_calls) == 0:
|
|
return
|
|
overall_time = sum(call[0] for call in collected_calls)
|
|
overall_gb = sum(call[1] for call in collected_calls)
|
|
cur_file = inspect.stack()[1].filename
|
|
summary_str = (
|
|
f"SUMMARY ({cur_file})\n"
|
|
f"{overall_time:.2f}ms \t {overall_gb:.2f} GB\t {overall_gb / (overall_time / 1e3):.2f}GB/s"
|
|
)
|
|
log.info(
|
|
"%s",
|
|
summary_str,
|
|
)
|
|
if output_file is not None:
|
|
# sort perf numbers in descending order, i.e. placing the
|
|
# most runtime-heavy kernels at the top of the list
|
|
sorted_calls = sorted(collected_calls, key=lambda c: float(c[0]), reverse=True)
|
|
try:
|
|
with open(output_file, "a") as file:
|
|
log.info(
|
|
"Save profile bandwidth results to %s",
|
|
output_file,
|
|
)
|
|
file.write("====================\n")
|
|
file.write(f"TRITON KERNELS BANDWIDTH INFO ({cur_file})\n")
|
|
for ms, num_gb, gb_per_s, kernel_name in sorted_calls:
|
|
# also display the runtime percentage for each kernel
|
|
percentage = f"{ms / overall_time * 100:.2f}%"
|
|
suffix = f" \t {percentage} \t {kernel_name}"
|
|
bw_info_str = create_bandwidth_info_str(
|
|
ms,
|
|
num_gb,
|
|
gb_per_s,
|
|
suffix=suffix,
|
|
color=False,
|
|
)
|
|
file.write(bw_info_str + "\n")
|
|
file.write(f"{summary_str}\n\n")
|
|
except Exception as e:
|
|
log.warning(
|
|
"failed to write profile bandwidth result into %s: %s",
|
|
output_file,
|
|
e,
|
|
)
|
|
|
|
|
|
class DebugAutotuner(CachingAutotuner):
|
|
def __init__(
|
|
self,
|
|
*args,
|
|
regex_filter="",
|
|
with_profiler=False,
|
|
with_bandwidth_info=True,
|
|
**kwargs,
|
|
):
|
|
self.regex_filter = regex_filter
|
|
self.with_profiler = with_profiler
|
|
self.with_bandwidth_info = with_bandwidth_info
|
|
super().__init__(*args, **kwargs)
|
|
self.cached = None
|
|
|
|
def run(self, *args, stream, **kwargs):
|
|
if not self.with_bandwidth_info:
|
|
super().run(*args, stream=stream, **kwargs, benchmark_run=True)
|
|
return
|
|
else:
|
|
possible_names = _find_names(self)
|
|
kernel_name = f"{max(possible_names, key=len)}"
|
|
if not re.match(self.regex_filter, kernel_name):
|
|
return
|
|
|
|
if len(self.launchers) != 1:
|
|
if len(self.launchers) == 0:
|
|
start_time = time.time_ns()
|
|
self.precompile()
|
|
self.precompile_time_taken_ns = time.time_ns() - start_time
|
|
if len(self.launchers) > 1:
|
|
self.autotune_to_one_config(*args, **kwargs)
|
|
(launcher,) = self.launchers
|
|
|
|
if launcher.store_cubin:
|
|
self.save_gpu_kernel(stream, launcher)
|
|
|
|
if self.cached is None:
|
|
ms = self.bench(launcher, *args, with_profiler=self.with_profiler)
|
|
num_in_out_ptrs = len(
|
|
[
|
|
arg_name
|
|
for arg_name in self.fn.arg_names
|
|
if arg_name.startswith("in_out_ptr")
|
|
]
|
|
)
|
|
num_gb = self.inductor_meta.get("kernel_num_gb", None)
|
|
if num_gb is None:
|
|
num_gb = get_num_bytes(*args, num_in_out_args=num_in_out_ptrs) / 1e9
|
|
gb_per_s = num_gb / (ms / 1e3)
|
|
self.cached = ms, num_gb, gb_per_s, kernel_name
|
|
collected_calls.append((ms, num_gb, gb_per_s, kernel_name))
|
|
log.info(
|
|
"%s",
|
|
create_bandwidth_info_str(
|
|
ms, num_gb, gb_per_s, suffix=f" \t {kernel_name}"
|
|
),
|
|
)
|
|
else:
|
|
# in AOTI, we will call the kernel again, and its timing info has already been cached
|
|
collected_calls.append(self.cached)
|
|
|
|
|
|
def hash_configs(configs: list[Config]):
|
|
"""
|
|
Hash used to check for changes in configurations
|
|
"""
|
|
hasher = hashlib.sha256()
|
|
for cfg in configs:
|
|
hasher.update(
|
|
f"{sorted(cfg.kwargs.items())} {cfg.num_warps} {cfg.num_stages}\n".encode()
|
|
)
|
|
return hasher.hexdigest()
|
|
|
|
|
|
def cached_autotune(
|
|
size_hints: Optional[list[int]],
|
|
configs: list[Config],
|
|
triton_meta,
|
|
heuristic_type,
|
|
filename=None,
|
|
inductor_meta=None,
|
|
custom_kernel=False,
|
|
):
|
|
"""
|
|
A copy of triton.autotune that calls our subclass. Our subclass
|
|
has additional debugging, error handling, and on-disk caching.
|
|
"""
|
|
configs = unique_configs(configs)
|
|
assert len(configs) == 1 or filename
|
|
inductor_meta = {} if inductor_meta is None else inductor_meta
|
|
|
|
configs, autotune_cache, autotune_cache_info = check_autotune_cache(
|
|
configs, filename, inductor_meta
|
|
)
|
|
mutated_arg_names = inductor_meta.pop("mutated_arg_names", ())
|
|
optimize_mem = inductor_meta.pop("optimize_mem", True)
|
|
|
|
if "restore_value" in triton_meta:
|
|
mutated_arg_names += triton_meta.pop("restore_value")
|
|
|
|
reset_to_zero_arg_names: list[str] = []
|
|
if "reset_to_zero" in triton_meta:
|
|
reset_to_zero_arg_names.extend(triton_meta.pop("reset_to_zero"))
|
|
|
|
def decorator(fn):
|
|
# Remove XBLOCK from config if it's not a function argument.
|
|
# This way, coordinate descent tuning will not try to tune it.
|
|
#
|
|
# Context: When TritonKernel.no_x_dim is True, we hardcode XBLOCK to 1.
|
|
import inspect
|
|
|
|
if "XBLOCK" not in inspect.signature(fn.fn).parameters:
|
|
for tconfig in configs:
|
|
if "XBLOCK" in tconfig.kwargs:
|
|
assert tconfig.kwargs["XBLOCK"] == 1
|
|
tconfig.kwargs.pop("XBLOCK")
|
|
|
|
if inductor_meta.get("profile_bandwidth"):
|
|
return DebugAutotuner(
|
|
fn,
|
|
triton_meta=triton_meta,
|
|
inductor_meta=inductor_meta,
|
|
regex_filter=inductor_meta["profile_bandwidth_regex"],
|
|
with_profiler=inductor_meta[
|
|
"profile_bandwidth_with_do_bench_using_profiling"
|
|
],
|
|
configs=configs,
|
|
save_cache_hook=autotune_cache and autotune_cache.save,
|
|
mutated_arg_names=mutated_arg_names,
|
|
reset_to_zero_arg_names=reset_to_zero_arg_names,
|
|
optimize_mem=optimize_mem,
|
|
heuristic_type=heuristic_type,
|
|
size_hints=size_hints,
|
|
custom_kernel=custom_kernel,
|
|
filename=filename,
|
|
with_bandwidth_info=True,
|
|
)
|
|
return CachingAutotuner(
|
|
fn,
|
|
triton_meta=triton_meta,
|
|
inductor_meta=inductor_meta,
|
|
configs=configs,
|
|
save_cache_hook=autotune_cache and autotune_cache.save,
|
|
mutated_arg_names=mutated_arg_names,
|
|
reset_to_zero_arg_names=reset_to_zero_arg_names,
|
|
optimize_mem=optimize_mem,
|
|
heuristic_type=heuristic_type,
|
|
size_hints=size_hints,
|
|
custom_kernel=custom_kernel,
|
|
filename=filename,
|
|
autotune_cache_info=autotune_cache_info,
|
|
)
|
|
|
|
return decorator
|
|
|
|
|
|
def unique_configs(configs: list[Config]):
|
|
"""Remove duplicate configurations"""
|
|
seen: OrderedSet[Hashable] = OrderedSet()
|
|
pruned_configs = []
|
|
|
|
for cfg in configs:
|
|
key = triton_config_to_hashable(cfg)
|
|
if key not in seen:
|
|
seen.add(key)
|
|
pruned_configs.append(cfg)
|
|
return pruned_configs
|
|
|
|
|
|
def check_config(cfg, *, xnumel=None, ynumel=None, znumel=None):
|
|
for numel, label in zip((xnumel, ynumel, znumel), "XYZ"):
|
|
if numel is None:
|
|
continue
|
|
block = cfg[f"{label}BLOCK"]
|
|
if numel == 1:
|
|
assert block == 1, (
|
|
f"TritonKernel.indexing assumes numel == 1 => BLOCK == 1"
|
|
f" but {label.lower()}numel=={numel} and {label}BLOCK={block} (cfg={cfg})."
|
|
)
|
|
max_block = TRITON_MAX_BLOCK[label]
|
|
max_block_str = f'config.triton.max_block["{label}"]'
|
|
assert max_block % block == 0, (
|
|
f"TritonKernel.indexing assumes {label}BLOCK divides {max_block_str}"
|
|
f" but {label}BLOCK={block} and {max_block_str}={max_block} (cfg={cfg})."
|
|
)
|
|
|
|
|
|
def check_max_block(cfg: dict[str, int]):
|
|
"""
|
|
Check that block sizes are within the maximum allowed.
|
|
"""
|
|
for var, val in cfg.items():
|
|
block_suffix = "BLOCK"
|
|
if block_suffix in var:
|
|
prefix = var.removesuffix(block_suffix)
|
|
max_block = TRITON_MAX_BLOCK[prefix]
|
|
assert val <= max_block, (
|
|
f"'{var}' too large. Maximum: {max_block}. Actual: {val}."
|
|
)
|
|
|
|
|
|
def _num_warps(num_warps, max_num_warps=8, min_num_warps=2, register_intensive=False):
|
|
# On AMD GPUs each warp has 64 lanes, double the size on NVIDIA GPUs,
# so we correspondingly use half as many warps here.
|
|
if torch.version.hip:
|
|
max_num_warps = (max_num_warps + 1) // 2
|
|
min_num_warps = (min_num_warps + 1) // 2
|
|
# persistent reduction is register intensive
|
|
if register_intensive:
|
|
max_num_warps = max_num_warps // 2
|
|
return next_power_of_2(min(max(num_warps, min_num_warps), max_num_warps))
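# Worked examples of the clamping above: on CUDA, _num_warps(6) clamps 6 into
# [2, 8] and rounds to the next power of two, giving 8; on ROCm the bounds halve
# to [1, 4], so _num_warps(6) gives 4; with register_intensive=True the CUDA cap
# drops to 4, so _num_warps(16, register_intensive=True) also gives 4.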
|
|
|
|
|
|
def _check_max_grid_x(size_hints, x, num_warps):
|
|
# Check if maxGridSize is exceeded - if so, we must scale XBLOCK up further
|
|
max_grid_x = 2147483647
|
|
warp_size = (
|
|
64 if torch.version.hip else 32
|
|
) # TODO: query warp size once #129663 is merged
|
|
num_blocks = (size_hints["x"] + x - 1) // x
|
|
|
|
while (num_blocks * num_warps * warp_size) > max_grid_x and x < size_hints["x"]:
|
|
x *= 2 # Scale up XBLOCK if grid exceeds limits
|
|
num_blocks = num_blocks // 2
|
|
if (num_blocks * num_warps * warp_size) > max_grid_x:
|
|
raise AssertionError(
|
|
"Reduction config exceeds cudaDeviceProp maxGridSize. Please raise a pytorch issue"
|
|
)
|
|
return x, num_blocks
|
|
|
|
|
|
def triton_config(
|
|
size_hints,
|
|
x,
|
|
y=None,
|
|
z=None,
|
|
num_stages=1,
|
|
num_elements_per_warp=256,
|
|
min_elem_per_thread=0,
|
|
) -> Config:
|
|
"""
|
|
Construct a pointwise triton config with some adjustment heuristics
|
|
based on size_hints. size_hints is a dict mapping each tile dimension to its
numel, each of which will be rounded up to the nearest power of 2.
|
|
|
|
num_elements_per_warp is a suggestion for controlling how many warps
|
|
the triton config should contain. e.g.: if x=16, y=8, z=4 then
|
|
num_elements = 16*8*4 = 512. Then if we set num_elements_per_warp=128,
|
|
we'll launch 512 (elem) / 128 (elem/warp) = 4 warps. Note that it's
|
|
just a suggestion, and sometimes other adjustment heuristics will
|
|
override the num_elements_per_warp.
|
|
|
|
min_elem_per_thread controls the minimum number of elements
|
|
processed by each thread. It's always enforced.
|
|
"""
|
|
# Ideally we want to read this from some device config
|
|
|
|
maxGridSize = [2147483647, 65535, 65535]
|
|
|
|
target = conditional_product(x, y, z)
|
|
if conditional_product(*size_hints.values()) < target:
|
|
target //= 8
|
|
|
|
# shrink sizes to size hints
|
|
x = min(x, size_hints["x"])
|
|
if y:
|
|
y = min(y, size_hints["y"])
|
|
if z:
|
|
z = min(z, size_hints["z"])
|
|
|
|
# if we are below original block size, scale up where we can;
|
|
# or if the calculated grid size is larger than the limit, we bump up the corresponding dimension
|
|
while x < min(size_hints["x"], TRITON_MAX_BLOCK["X"]) and (
|
|
x * maxGridSize[0] < size_hints["x"] or conditional_product(x, y, z) < target
|
|
):
|
|
x *= 2
|
|
while (
|
|
y
|
|
and y < min(size_hints["y"], TRITON_MAX_BLOCK["Y"])
|
|
and (
|
|
y * maxGridSize[1] < size_hints["y"]
|
|
or conditional_product(x, y, z) < target
|
|
)
|
|
):
|
|
y *= 2
|
|
while (
|
|
z
|
|
and z < min(size_hints["z"], TRITON_MAX_BLOCK["Z"])
|
|
and (
|
|
z * maxGridSize[2] < size_hints["z"]
|
|
or conditional_product(x, y, z) < target
|
|
)
|
|
):
|
|
z *= 2
|
|
|
|
num_warps = _num_warps(
|
|
conditional_product(x, y, z) // num_elements_per_warp, min_num_warps=1
|
|
)
|
|
# We only end up with 2 warps when the block size was shrunk because numel is
# too small. However, to work around some ptx bugs we still want at least 4
# warps when there are enough elements per thread. Given that this is a rare
# situation, we don't expect it to affect perf in general.
# see https://github.com/pytorch/pytorch/pull/97950
|
|
if conditional_product(x, y, z) >= 128 and not torch.version.hip:
|
|
num_warps = max(num_warps, 4)
|
|
xnumel = size_hints["x"]
|
|
ynumel = size_hints.get("y")
|
|
znumel = size_hints.get("z")
|
|
|
|
# Increase x to satisfy min_elem_per_thread requirements.
|
|
block_size = max(
|
|
conditional_product(x, y, z),
|
|
min_elem_per_thread * _NUM_THREADS_PER_WARP * num_warps,
|
|
)
|
|
x *= math.ceil(block_size / conditional_product(x, y, z))
|
|
|
|
x, _num_blocks = _check_max_grid_x(size_hints, x, num_warps)
|
|
x = min(x, size_hints["x"])
|
|
|
|
cfg = {"XBLOCK": x}
|
|
if y:
|
|
cfg["YBLOCK"] = y
|
|
if z:
|
|
cfg["ZBLOCK"] = z
|
|
check_max_block(cfg)
|
|
check_config(cfg, xnumel=xnumel, ynumel=ynumel, znumel=znumel)
|
|
return Config(cfg, num_warps=num_warps, num_stages=num_stages)
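# Hedged usage sketch: for a 1D pointwise kernel with size_hints={"x": 4096},
# triton_config({"x": 4096}, 1024) yields roughly Config({"XBLOCK": 1024},
# num_warps=4, num_stages=1): 1024 elements at the default 256 elements per warp
# gives 4 warps, and XBLOCK stays at 1024 since it already fits the block cap.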
|
|
|
|
|
|
def _get_nd_reduction_numels(r: int, size_hints: dict[str, int]) -> dict[str, int]:
|
|
"""
|
|
Converts a linear reduction numel to ND, in row major order.
|
|
This order is often desirable as it presents opportunities to coalesce memory
|
|
accesses.
|
|
For example, if r = 64 and size_hints = [32,32], this function returns [32, 2].
|
|
This unraveling works because both r and size_hints are powers of 2.
|
|
"""
|
|
# Shrink r to size_hints.
|
|
r = min(r, get_total_reduction_numel(size_hints))
|
|
num_reduction_dims = len(
|
|
[prefix for prefix in size_hints if prefix_is_reduction(prefix)]
|
|
)
|
|
|
|
remaining = r
|
|
rnumels = {}
|
|
for idx in range(num_reduction_dims - 1, -1, -1):
|
|
prefix = f"r{idx}_"
|
|
max_size = min(size_hints[prefix], TRITON_MAX_BLOCK[prefix.upper()])
|
|
dim = min(max_size, remaining)
|
|
assert remaining % dim == 0, (
|
|
f"Expected dimension '{dim}' to divide remaining size '{remaining}'"
|
|
)
|
|
rnumels[prefix] = dim
|
|
remaining //= dim
|
|
|
|
# Sanity check the results.
|
|
final_numel = conditional_product(*rnumels.values())
|
|
assert r == final_numel, (
|
|
f"Expected ND reduction size ({rnumels}) to have {r} elements."
|
|
)
|
|
assert all(rnumels[prefix] <= size_hints[prefix] for prefix in rnumels), (
|
|
f"rnumels exceed size_hints. {rnumels} > {size_hints}"
|
|
)
|
|
|
|
return rnumels
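# Worked example: _get_nd_reduction_numels(64, {"x": 8, "r0_": 32, "r1_": 32})
# fills the innermost reduction dim first and returns {"r1_": 32, "r0_": 2},
# since 64 = 32 * 2 and each entry stays within its size hint.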
|
|
|
|
|
|
def triton_config_reduction(
|
|
size_hints,
|
|
x: int,
|
|
r: int,
|
|
num_stages=1,
|
|
num_warps=None,
|
|
register_intensive=False,
|
|
) -> Config:
|
|
"""
|
|
Construct a reduction triton config with some adjustment heuristics
|
|
based on size_hints. size_hints is a dict mapping each tile dimension to its
numel, each of which will be rounded up to the nearest power of 2.
|
|
"""
|
|
# Convert the linear reduction numel into a multi-dimensional block.
|
|
rnumels = _get_nd_reduction_numels(r, size_hints)
|
|
|
|
# shrink sizes to size hints
|
|
x = min(x, size_hints["x"])
|
|
|
|
def total_numel() -> int:
|
|
return conditional_product(x, *rnumels.values())
|
|
|
|
target = total_numel()
|
|
if conditional_product(*size_hints.values()) < target:
|
|
target //= 8
|
|
|
|
# if we are below original block size, scale up where we can
|
|
while x < size_hints["x"] and total_numel() < target:
|
|
x *= 2
|
|
for prefix in sorted(rnumels):
|
|
while rnumels[prefix] < size_hints[prefix] and total_numel() < target:
|
|
rnumels[prefix] *= 2
|
|
|
|
if num_warps is None:
|
|
num_warps = total_numel() // 128
|
|
num_warps = _num_warps(
|
|
num_warps, max_num_warps=16, register_intensive=register_intensive
|
|
)
|
|
|
|
x, _num_blocks = _check_max_grid_x(size_hints, x, num_warps)
|
|
|
|
for prefix in sorted(rnumels):
|
|
while total_numel() > target:
|
|
if rnumels[prefix] == 1:
|
|
break
|
|
rnumels[prefix] //= 2
|
|
|
|
cfg = _get_config({"x": x, **rnumels})
|
|
check_max_block(cfg)
|
|
check_config(cfg, xnumel=size_hints["x"])
|
|
return Config(cfg, num_warps=num_warps, num_stages=num_stages)
|
|
|
|
|
|
def _get_config(numels: dict[str, int]) -> dict[str, int]:
|
|
"""
|
|
Convert numels ("x", "r0_", etc.) to block sizes ("XBLOCK", "R0_BLOCK"), etc.
|
|
"""
|
|
|
|
return {prefix.upper() + "BLOCK": numel for prefix, numel in numels.items()}
|
|
|
|
|
|
def triton_config_tiled_reduction(
|
|
size_hints, x, y, r, num_stages=1, register_intensive=False
|
|
):
|
|
"""
|
|
Construct a tile reduction triton config with some adjustment
|
|
heuristics based on size_hints. size_hints is a dict mapping each tile
dimension to its numel, each of which will be rounded up to the nearest power of 2.
|
|
"""
|
|
# Convert the linear reduction numel into a multi-dimensional block.
|
|
rnumels = _get_nd_reduction_numels(r, size_hints)
|
|
|
|
# shrink sizes to size hints
|
|
x = min(x, size_hints["x"])
|
|
y = min(y, size_hints["y"])
|
|
|
|
def total_numel() -> int:
|
|
return conditional_product(x, y, *rnumels.values())
|
|
|
|
target = total_numel()
|
|
if conditional_product(*size_hints.values()) < target:
|
|
target //= 8
|
|
|
|
# if we are below original block size, scale up where we can
|
|
while x < size_hints["x"] and total_numel() < target:
|
|
x *= 2
|
|
for prefix in sorted(rnumels):
|
|
while rnumels[prefix] < size_hints[prefix] and total_numel() < target:
|
|
rnumels[prefix] *= 2
|
|
while y < size_hints["y"] and total_numel() < target:
|
|
y *= 2
|
|
|
|
cfg = _get_config({"x": x, "y": y, **rnumels})
|
|
num_warps = _num_warps(total_numel() // 256, min_num_warps=1)
|
|
num_warps = _num_warps(
|
|
num_warps, max_num_warps=16, register_intensive=register_intensive
|
|
)
|
|
check_config(cfg, xnumel=size_hints["x"], ynumel=size_hints["y"])
|
|
check_max_block(cfg)
|
|
return Config(cfg, num_warps=num_warps, num_stages=num_stages)
|
|
|
|
|
|
def _maybe_filter_configs_for_tma_restrictions(inductor_meta, configs: list[Config]):
|
|
tma_min_block_sizes: dict[str, int]
|
|
if (tma_min_block_sizes := inductor_meta.get("tma_min_block_sizes")) and configs:
|
|
# Rn blocks are not provided to the kernel for persistent reductions
|
|
if inductor_meta.get("persistent_reduction"):
|
|
tma_min_block_sizes = {
|
|
block_type: block_size
|
|
for block_type, block_size in tma_min_block_sizes
|
|
if not prefix_is_reduction(block_type.lower())
|
|
}
|
|
|
|
assert all(
|
|
block_type in configs[0].kwargs for block_type in tma_min_block_sizes.keys()
|
|
)
|
|
|
|
# Add a config that is guaranteed to compile
|
|
example_config = configs[0]
|
|
config_block_sizes = {**example_config.kwargs}
|
|
config_block_sizes.update(tma_min_block_sizes)
|
|
new_configs = [
|
|
Config(
|
|
config_block_sizes,
|
|
num_warps=example_config.num_warps,
|
|
num_stages=example_config.num_stages,
|
|
maxnreg=example_config.maxnreg,
|
|
pre_hook=example_config.pre_hook,
|
|
)
|
|
]
|
|
# Remove configs that will not compile
|
|
for c in configs:
|
|
if all(
|
|
c.kwargs.get(block_type) >= min_block_value
|
|
for block_type, min_block_value in tma_min_block_sizes.items()
|
|
):
|
|
new_configs.append(c)
|
|
|
|
log.debug(
|
|
"Filtering configs for TMA API restrictions. Input configs size: %d. Output configs size: %d",
|
|
len(configs),
|
|
len(new_configs),
|
|
)
|
|
return new_configs
|
|
return configs
|
|
|
|
|
|
def pointwise(
|
|
size_hints,
|
|
triton_meta,
|
|
tile_hint=None,
|
|
filename=None,
|
|
min_elem_per_thread=0,
|
|
inductor_meta=None,
|
|
):
|
|
"""
|
|
Construct @triton.heuristics() based on size_hints.
|
|
"""
|
|
inductor_meta = {} if inductor_meta is None else inductor_meta
|
|
assert not inductor_meta.get("no_x_dim")
|
|
|
|
numel = functools.reduce(operator.mul, size_hints.values())
|
|
bs = max(256, min(numel // 128, 1024))
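# Worked example of the block-size cap above: numel = 2**20 elements gives
# bs = max(256, min(2**20 // 128, 1024)) = 1024.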
|
|
|
|
hinted_configs = autotune_hints_to_configs(
|
|
inductor_meta.get("autotune_hints", OrderedSet()),
|
|
size_hints,
|
|
bs,
|
|
triton_meta["device"],
|
|
)
|
|
|
|
triton_config_with_settings = functools.partial(
|
|
triton_config, min_elem_per_thread=min_elem_per_thread
|
|
)
|
|
|
|
configs = None
|
|
if len(size_hints) == 1:
|
|
if disable_pointwise_autotuning(inductor_meta) and not (
|
|
inductor_meta.get("max_autotune")
|
|
or inductor_meta.get("max_autotune_pointwise")
|
|
):
|
|
configs = [triton_config_with_settings(size_hints, bs)]
|
|
else:
|
|
configs = [
|
|
triton_config_with_settings(size_hints, bs, num_elements_per_warp=256),
|
|
triton_config_with_settings(
|
|
size_hints, bs // 2, num_elements_per_warp=64
|
|
),
|
|
*hinted_configs,
|
|
]
|
|
if len(size_hints) == 2:
|
|
if (
|
|
disable_pointwise_autotuning(inductor_meta) or tile_hint == TileHint.SQUARE
|
|
) and not (
|
|
inductor_meta.get("max_autotune")
|
|
or inductor_meta.get("max_autotune_pointwise")
|
|
):
|
|
configs = [triton_config_with_settings(size_hints, 32, 32)]
|
|
else:
|
|
configs = [
|
|
triton_config_with_settings(size_hints, 32, 32),
|
|
triton_config_with_settings(size_hints, 64, 64), # ~8% better for fp16
|
|
triton_config_with_settings(size_hints, 256, 16),
|
|
triton_config_with_settings(size_hints, 16, 256),
|
|
triton_config_with_settings(size_hints, bs, 1),
|
|
triton_config_with_settings(size_hints, 1, bs),
|
|
*hinted_configs,
|
|
]
|
|
if len(size_hints) == 3:
|
|
if disable_pointwise_autotuning(inductor_meta):
|
|
configs = [triton_config_with_settings(size_hints, 16, 16, 16)]
|
|
else:
|
|
configs = [
|
|
triton_config_with_settings(size_hints, 16, 16, 16),
|
|
triton_config_with_settings(size_hints, 64, 8, 8),
|
|
triton_config_with_settings(size_hints, 8, 64, 8),
|
|
triton_config_with_settings(size_hints, 8, 8, 64),
|
|
triton_config_with_settings(size_hints, bs, 1, 1),
|
|
triton_config_with_settings(size_hints, 1, bs, 1),
|
|
triton_config_with_settings(size_hints, 1, 1, bs),
|
|
*hinted_configs,
|
|
]
|
|
|
|
if not configs:
|
|
raise NotImplementedError(f"size_hints: {size_hints}")
|
|
|
|
configs = _maybe_filter_configs_for_tma_restrictions(inductor_meta, configs)
|
|
|
|
return cached_autotune(
|
|
size_hints,
|
|
configs,
|
|
triton_meta=triton_meta,
|
|
inductor_meta=inductor_meta,
|
|
heuristic_type=HeuristicType.POINTWISE,
|
|
filename=filename,
|
|
)
|
|
|
|
|
|
def _reduction_configs(
|
|
*, size_hints: dict[str, int], inductor_meta: dict[str, Any]
|
|
) -> list[Config]:
|
|
reduction_hint = inductor_meta.get("reduction_hint", None)
|
|
|
|
# Convert reductions to 1D, to simplify heuristics.
|
|
rnumel = get_total_reduction_numel(size_hints)
|
|
|
|
register_intensive = False
|
|
MAX_R0_BLOCK = 2048
|
|
if (
|
|
size_hints["x"] >= 1024
|
|
and inductor_meta.get("num_load", 0) + inductor_meta.get("num_reduction", 0)
|
|
>= 10
|
|
):
|
|
# A heuristic to reduce R0_BLOCK if a kernel potentially needs many registers.
# Consider loads and reductions, since a load needs to move data into registers
# and a reduction needs an accumulator.
#
# The magic numbers are a bit arbitrary.
#
# We cannot rely on dynamically scaling down R0_BLOCK later, since sometimes
# triton makes the kernel use fewer registers at the cost of worse perf. Check:
# https://github.com/pytorch/pytorch/issues/126463
#
# The heuristic is a very simple one, since registers can be reused. But
# hopefully it can be a good enough indicator.
|
|
MAX_R0_BLOCK = 1024
|
|
register_intensive = True
|
|
|
|
def make_config(x, r, num_warps=None, num_stages=1, register_intensive=False):
|
|
# For 3D case with tiling scores, create an adapted version
|
|
if "y" in size_hints:
|
|
assert "tiling_scores" in inductor_meta
|
|
return adapt_config_for_tiling(
|
|
size_hints,
|
|
inductor_meta["tiling_scores"],
|
|
x,
|
|
r,
|
|
num_warps=num_warps,
|
|
num_stages=num_stages,
|
|
register_intensive=register_intensive,
|
|
)
|
|
else:
|
|
# For other cases, use the original function
|
|
return triton_config_reduction(
|
|
size_hints,
|
|
x,
|
|
r,
|
|
num_warps=num_warps,
|
|
num_stages=num_stages,
|
|
register_intensive=register_intensive,
|
|
)
|
|
|
|
contiguous_config = make_config(
|
|
1,
|
|
min(rnumel, MAX_R0_BLOCK),
|
|
register_intensive=register_intensive,
|
|
)
|
|
outer_config = make_config(64, 8, register_intensive=register_intensive)
|
|
tiny_config = make_config(
|
|
2 * (256 // rnumel) if rnumel <= 256 else 1,
|
|
min(rnumel, MAX_R0_BLOCK),
|
|
register_intensive=register_intensive,
|
|
)
|
|
# For 3d tiling, default to more autotuning initially
|
|
if "y" in size_hints:
|
|
pass
|
|
elif inductor_meta.get("max_autotune") or inductor_meta.get(
|
|
"max_autotune_pointwise"
|
|
):
|
|
pass # skip all these cases
|
|
elif reduction_hint == ReductionHint.INNER:
|
|
return [contiguous_config]
|
|
elif reduction_hint == ReductionHint.OUTER:
|
|
return [outer_config]
|
|
elif reduction_hint == ReductionHint.OUTER_TINY:
|
|
return [tiny_config]
|
|
if disable_pointwise_autotuning(inductor_meta):
|
|
return [make_config(32, 128)]
|
|
return [
|
|
contiguous_config,
|
|
outer_config,
|
|
tiny_config,
|
|
make_config(64, 64),
|
|
make_config(8, 512),
|
|
# halve the XBLOCK/Rn_BLOCK compared to outer_config
|
|
# TODO: this may only be beneficial when each iteration of the reduction
|
|
# is quite heavy. E.g. https://gist.github.com/shunting314/189a8ef69f90db9d614a823385147a72
|
|
make_config(64, 4, num_warps=8),
|
|
]
|
|
|
|
|
|
def match_target_block_product(
|
|
size_hints, tiling_scores, target_block_product, min_block_size=1
|
|
):
|
|
"""
|
|
Distribute block sizes across dimensions according to tiling scores,
|
|
aiming to match a target product of block sizes.
|
|
"""
|
|
total_score = sum(tiling_scores.values())
|
|
if total_score == 0:
|
|
# just assume even score with no minimum block size
|
|
min_block_size = 1
|
|
tiling_scores = dict.fromkeys(tiling_scores.keys(), target_block_product)
|
|
|
|
# First, give each coalescing dimension at least min_block_size
|
|
block_sizes = {}
|
|
relative_scores = {}
|
|
curr_block_product = 1
|
|
|
|
for dim, score in tiling_scores.items():
|
|
if score == 0:
|
|
block_sizes[dim] = 1
|
|
continue
|
|
|
|
block_sizes[dim] = min_block_size
|
|
curr_block_product *= min_block_size
|
|
relative_scores[dim] = score / total_score
|
|
|
|
# Scale up dimensions by their relative scores until we reach the target
|
|
while curr_block_product < target_block_product and len(relative_scores):
|
|
dim, score = max(relative_scores.items(), key=lambda item: item[1])
|
|
|
|
# Check if we've hit the max for this dimension
|
|
if (
|
|
block_sizes[dim] >= TRITON_MAX_BLOCK[dim.capitalize()]
|
|
or block_sizes[dim] >= size_hints[dim]
|
|
):
|
|
del relative_scores[dim]
|
|
continue
|
|
|
|
block_sizes[dim] *= 2
|
|
relative_scores[dim] /= 2
|
|
curr_block_product *= 2
|
|
|
|
return block_sizes
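# Worked example: match_target_block_product({"x": 1024, "y": 1024},
# {"x": 3, "y": 1}, 64) doubles the highest-scoring dimension first and returns
# {"x": 16, "y": 4}: the block sizes multiply to the requested 64 and follow the
# 3:1 tiling-score ratio as closely as powers of two allow.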
|
|
|
|
|
|
def adapt_config_for_tiling(
|
|
size_hints,
|
|
tiling_scores,
|
|
original_x,
|
|
original_r,
|
|
num_warps=None,
|
|
num_stages=1,
|
|
register_intensive=False,
|
|
persistent_reduction=False,
|
|
) -> Config:
|
|
"""
|
|
Create an adapted configuration based on tiling scores,
|
|
redistributing the same total block size (x * r) according to tiling scores.
|
|
"""
|
|
assert all(s in tiling_scores for s in size_hints)
|
|
target_block_product = original_x * original_r
|
|
block_sizes = match_target_block_product(
|
|
size_hints, tiling_scores, target_block_product
|
|
)
|
|
|
|
return triton_config_tiled_reduction(
|
|
size_hints,
|
|
block_sizes["x"],
|
|
block_sizes["y"],
|
|
block_sizes["r0_"],
|
|
num_stages=num_stages,
|
|
register_intensive=register_intensive,
|
|
)
|
|
|
|
|
|
def reduction(
|
|
size_hints,
|
|
reduction_hint=False,
|
|
triton_meta=None,
|
|
filename=None,
|
|
inductor_meta=None,
|
|
):
|
|
"""args to @triton.heuristics()"""
|
|
inductor_meta = {} if inductor_meta is None else inductor_meta
|
|
inductor_meta["reduction_hint"] = reduction_hint
|
|
if inductor_meta.get("no_x_dim"):
|
|
size_hints["x"] = 1
|
|
|
|
assert triton_meta is not None
|
|
|
|
configs = _reduction_configs(size_hints=size_hints, inductor_meta=inductor_meta)
|
|
configs = _maybe_filter_configs_for_tma_restrictions(inductor_meta, configs)
|
|
return cached_autotune(
|
|
size_hints,
|
|
configs=configs,
|
|
triton_meta=triton_meta,
|
|
inductor_meta=inductor_meta,
|
|
heuristic_type=HeuristicType.REDUCTION,
|
|
filename=filename,
|
|
)
|
|
|
|
|
|
def cooperative_reduction(
|
|
size_hints,
|
|
reduction_hint,
|
|
triton_meta,
|
|
filename,
|
|
inductor_meta,
|
|
):
|
|
inductor_meta = {} if inductor_meta is None else inductor_meta
|
|
inductor_meta["reduction_hint"] = reduction_hint
|
|
if inductor_meta.get("no_x_dim"):
|
|
size_hints["x"] = 1
|
|
|
|
# Cooperative reductions currently only support a single reduction dimension.
|
|
assert len(size_hints) == 2, (
|
|
"Cooperative reductions don't support tiling reduction dims"
|
|
)
|
|
xnumel, rnumel = size_hints["x"], size_hints["r0_"]
|
|
|
|
# TODO(jansel): we should base target on the SM count of the local GPU
|
|
target = 64
|
|
split = max(1, min(target // xnumel, TRITON_MAX_RSPLIT))
|
|
assert rnumel >= split
|
|
assert split <= TRITON_MAX_RSPLIT
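# Worked example (assuming TRITON_MAX_RSPLIT >= 16): xnumel=4, rnumel=8192 gives
# split = min(64 // 4, TRITON_MAX_RSPLIT) = 16, so the configs below are built
# for a per-split reduction of 8192 // 16 = 512 elements and each gets RSPLIT=16.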
|
|
if inductor_meta["persistent_reduction"]:
|
|
configs = _persistent_reduction_configs(
|
|
{"x": xnumel, "r0_": rnumel // split}, reduction_hint, inductor_meta
|
|
)
|
|
else:
|
|
configs = _reduction_configs(
|
|
size_hints={"x": xnumel, "r0_": rnumel // split},
|
|
inductor_meta=inductor_meta,
|
|
)
|
|
for config in configs:
|
|
config.kwargs["RSPLIT"] = split
|
|
# TODO(jansel): add more configs in max_autotune
|
|
|
|
configs = _maybe_filter_configs_for_tma_restrictions(inductor_meta, configs)
|
|
return cached_autotune(
|
|
size_hints,
|
|
configs=configs,
|
|
triton_meta=triton_meta,
|
|
inductor_meta=inductor_meta,
|
|
heuristic_type=HeuristicType.REDUCTION,
|
|
filename=filename,
|
|
)
|
|
|
|
|
|
def _persistent_reduction_configs(
|
|
size_hints,
|
|
reduction_hint=False,
|
|
inductor_meta=None,
|
|
):
|
|
xnumel = size_hints["x"]
|
|
rnumel = get_total_reduction_numel(size_hints)
|
|
|
|
MAX_PERSISTENT_BLOCK_NUMEL = 4096
|
|
|
|
if "y" not in size_hints:
|
|
configs = [
|
|
triton_config_reduction(size_hints, xblock, rnumel, register_intensive=True)
|
|
for xblock in (1, 8, 32, 128)
|
|
if xblock == 1
|
|
or (rnumel * xblock <= MAX_PERSISTENT_BLOCK_NUMEL and xblock <= xnumel)
|
|
]
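# Worked example of the filter above: xnumel=128, rnumel=512 keeps the XBLOCK
# candidates 1 and 8 (512 * 8 = 4096 fits MAX_PERSISTENT_BLOCK_NUMEL) and drops
# 32 and 128, whose persistent blocks would exceed the 4096-element cap.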
|
|
else:
|
|
configs = []
|
|
assert "tiling_scores" in inductor_meta
|
|
x_y_scores = {dim: inductor_meta["tiling_scores"][dim] for dim in ("x", "y")}
|
|
for target_block_size in (1, 8, 32, 64, 128):
|
|
if target_block_size * rnumel > MAX_PERSISTENT_BLOCK_NUMEL:
|
|
continue
|
|
|
|
block_sizes = match_target_block_product(
|
|
size_hints, x_y_scores, target_block_size
|
|
)
|
|
configs.append(
|
|
triton_config_tiled_reduction(
|
|
size_hints, block_sizes["x"], block_sizes["y"], rnumel
|
|
)
|
|
)
|
|
|
|
# defer to more autotuning, initially
|
|
if "y" in size_hints:
|
|
pass
|
|
# TODO(jansel): we should be able to improve these heuristics
|
|
elif reduction_hint == ReductionHint.INNER and rnumel >= 256:
|
|
configs = configs[:1]
|
|
elif reduction_hint == ReductionHint.OUTER:
|
|
configs = configs[-1:]
|
|
elif reduction_hint == ReductionHint.OUTER_TINY:
|
|
configs = [
|
|
triton_config_reduction(
|
|
size_hints,
|
|
2 * (256 // rnumel) if rnumel <= 256 else 1,
|
|
rnumel,
|
|
)
|
|
]
|
|
for c in configs:
|
|
# we don't need Rn_BLOCK for persistent reduction
|
|
for prefix in size_hints:
|
|
if prefix_is_reduction(prefix):
|
|
c.kwargs.pop(f"{prefix.upper()}BLOCK")
|
|
|
|
if disable_pointwise_autotuning(inductor_meta):
|
|
configs = configs[:1]
|
|
|
|
return configs
|
|
|
|
|
|
def persistent_reduction(
|
|
size_hints,
|
|
reduction_hint=False,
|
|
triton_meta=None,
|
|
filename=None,
|
|
inductor_meta=None,
|
|
):
|
|
inductor_meta = {} if inductor_meta is None else inductor_meta
|
|
inductor_meta["reduction_hint"] = reduction_hint
|
|
if inductor_meta.get("no_x_dim"):
|
|
size_hints["x"] = 1
|
|
|
|
configs = _persistent_reduction_configs(size_hints, reduction_hint, inductor_meta)
|
|
|
|
# This key is not normally added to the inductor meta, as it's clear from the
# heuristic choice that the kernel is persistent. Add it and remove it below so
# that persistent configs can be filtered appropriately by
# _maybe_filter_configs_for_tma_restrictions.
|
|
persistent_reduction_key = "persistent_reduction"
|
|
inductor_meta[persistent_reduction_key] = True
|
|
configs = _maybe_filter_configs_for_tma_restrictions(inductor_meta, configs)
|
|
inductor_meta.pop(persistent_reduction_key)
|
|
|
|
return cached_autotune(
|
|
size_hints,
|
|
configs,
|
|
triton_meta=triton_meta,
|
|
inductor_meta=inductor_meta,
|
|
filename=filename,
|
|
heuristic_type=HeuristicType.PERSISTENT_REDUCTION,
|
|
)
|
|
|
|
|
|
def split_scan(
|
|
size_hints,
|
|
reduction_hint=False,
|
|
triton_meta=None,
|
|
filename=None,
|
|
inductor_meta=None,
|
|
):
|
|
"""Heuristic for TritonSplitScanKernel"""
|
|
inductor_meta = {} if inductor_meta is None else inductor_meta
|
|
inductor_meta["reduction_hint"] = reduction_hint
|
|
if inductor_meta.get("no_x_dim"):
|
|
size_hints["x"] = 1
|
|
|
|
assert triton_meta is not None
|
|
if len(size_hints) != 2:
|
|
raise NotImplementedError(f"size_hints: {size_hints}")
|
|
|
|
configs = _reduction_configs(size_hints=size_hints, inductor_meta=inductor_meta)
|
|
|
|
# Fixup configs to enforce the minimum Rn_BLOCK size
|
|
min_rblock = inductor_meta.get("min_split_scan_rblock", 256)
|
|
for cfg in configs:
|
|
for var in list(cfg.kwargs.keys()):
|
|
if var.startswith("R") and cfg.kwargs[var] < min_rblock:
|
|
cfg.kwargs[var] = min_rblock
|
|
|
|
configs = _maybe_filter_configs_for_tma_restrictions(inductor_meta, configs)
|
|
return cached_autotune(
|
|
size_hints,
|
|
configs=configs,
|
|
triton_meta=triton_meta,
|
|
inductor_meta=inductor_meta,
|
|
heuristic_type=HeuristicType.SPLIT_SCAN,
|
|
filename=filename,
|
|
)
|
|
|
|
|
|
def template(
|
|
num_stages,
|
|
num_warps,
|
|
triton_meta,
|
|
num_consumer_groups=0,
|
|
num_buffers_warp_spec=0,
|
|
filename=None,
|
|
inductor_meta=None,
|
|
):
|
|
"""
|
|
Compile a triton template
|
|
"""
|
|
# Prepare the base configuration
|
|
config_args = {
|
|
"num_stages": num_stages,
|
|
"num_warps": num_warps,
|
|
}
|
|
|
|
# Conditionally add arguments based on HAS_WARP_SPEC
|
|
if HAS_WARP_SPEC:
|
|
config_args.update(
|
|
{
|
|
"num_consumer_groups": num_consumer_groups,
|
|
"num_buffers_warp_spec": num_buffers_warp_spec,
|
|
}
|
|
)
|
|
return cached_autotune(
|
|
None,
|
|
[triton.Config({}, **config_args)],
|
|
triton_meta=triton_meta,
|
|
inductor_meta=inductor_meta,
|
|
heuristic_type=HeuristicType.TEMPLATE,
|
|
filename=filename,
|
|
)
|
|
|
|
|
|
def _pop_config_kwargs(config: dict[str, Any]) -> dict[str, Any]:
|
|
"""Extract triton.Config options that should become kwargs"""
|
|
popped = {}
|
|
for key in (
|
|
"num_warps",
|
|
"num_stages",
|
|
"num_ctas",
|
|
"maxnreg",
|
|
"num_consumer_groups",
|
|
"num_buffers_warp_spec",
|
|
):
|
|
val = config.pop(key, None)
|
|
if val is not None:
|
|
popped[key] = val
|
|
return popped
|
|
|
|
|
|
def config_to_dict(config: Config) -> dict[str, Any]:
|
|
config_dict = {
|
|
**config.kwargs,
|
|
"num_warps": config.num_warps,
|
|
"num_stages": config.num_stages,
|
|
}
|
|
if HAS_WARP_SPEC:
|
|
config_dict.update(
|
|
{
|
|
"num_consumer_groups": getattr(config, "num_consumer_groups", 0),
|
|
"num_buffers_warp_spec": getattr(config, "num_buffers_warp_spec", 0),
|
|
}
|
|
)
|
|
return config_dict
|
|
|
|
|
|
def config_from_dict(config: dict[str, Any]) -> Config:
|
|
config = {**config}
|
|
return Config(config, **_pop_config_kwargs(config))
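# Round-trip sketch: config_from_dict({"XBLOCK": 64, "num_warps": 4, "num_stages": 2})
# builds Config({"XBLOCK": 64}, num_warps=4, num_stages=2), and config_to_dict maps
# it back to the same dict (plus the warp-spec fields when HAS_WARP_SPEC is set).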
|
|
|
|
|
|
def fixed_config(config, filename, triton_meta, inductor_meta):
|
|
"""
|
|
Used when the configuration is already decided at compile time
|
|
"""
|
|
config = {**config}
|
|
return cached_autotune(
|
|
None,
|
|
[triton.Config(config, **_pop_config_kwargs(config))],
|
|
triton_meta=triton_meta,
|
|
inductor_meta=inductor_meta,
|
|
heuristic_type=HeuristicType.FIXED,
|
|
filename=filename,
|
|
)
|
|
|
|
|
|
def user_autotune(
|
|
configs, triton_meta, filename=None, inductor_meta=None, custom_kernel=False
|
|
):
|
|
"""
|
|
Compile a user defined triton kernel
|
|
"""
|
|
if len(configs) == 0:
|
|
configs = [triton.Config({})]
|
|
else:
|
|
configs = [*map(config_from_dict, configs)]
|
|
return cached_autotune(
|
|
None,
|
|
configs,
|
|
triton_meta=triton_meta,
|
|
heuristic_type=HeuristicType.USER_AUTOTUNE,
|
|
filename=filename,
|
|
inductor_meta=inductor_meta,
|
|
custom_kernel=custom_kernel,
|
|
)
|
|
|
|
|
|
def foreach(triton_meta, num_warps, filename=None, inductor_meta=None):
|
|
"""
|
|
Compile a triton foreach kernel
|
|
"""
|
|
return cached_autotune(
|
|
None,
|
|
[triton.Config({}, num_stages=1, num_warps=num_warps)],
|
|
triton_meta=triton_meta,
|
|
inductor_meta=inductor_meta,
|
|
heuristic_type=HeuristicType.TEMPLATE,
|
|
filename=filename,
|
|
)
|
|
|
|
|
|
@dataclasses.dataclass
|
|
class GridExpr:
|
|
"""Generate code for grid size expressions in launcher"""
|
|
|
|
inductor_meta: dict[str, Any]
|
|
mode: Literal["python", "cpp"] = "python"
|
|
prefix: list[str] = dataclasses.field(default_factory=list)
|
|
x_grid: Union[str, int] = 1
|
|
y_grid: Union[str, int] = 1
|
|
z_grid: Union[str, int] = 1
|
|
|
|
def __post_init__(self) -> None:
|
|
assert self.mode in ("python", "cpp")
|
|
|
|
def generate(self, meta: dict[str, int]) -> None:
|
|
raise NotImplementedError
|
|
|
|
def ceildiv(
|
|
self, numel: Union[str, int], block: Union[None, int, str]
|
|
) -> Union[str, int]:
|
|
if block is None or block == 1:
|
|
return numel
|
|
if isinstance(numel, int) and isinstance(block, int):
|
|
return ceildiv(numel, block) # constant fold
|
|
if self.mode == "python":
|
|
return f"-(({numel}) // -({block}))"
|
|
# trick above doesn't work in C++ due to rounding differences
|
|
return f"(({numel} + ({block} - 1)) / ({block}))"
|
|
|
|
def maximum(self, seq: list[Union[int, str]]) -> Union[int, str]:
|
|
"""Codegen for max function with constant folding, constants are represented as int"""
|
|
items = self._constant_fold(max, seq)
|
|
if len(items) <= 1:
|
|
return items[0]
|
|
if self.mode == "python":
|
|
return f"max({', '.join(map(str, items))})"
|
|
return functools.reduce(lambda x, y: f"std::max({x}, {y})", items)
|
|
|
|
def summation(self, seq: list[Union[int, str]]) -> Union[int, str]:
|
|
"""Codegen for sum function with constant folding, constants are represented as int"""
|
|
items = self._constant_fold(sum, seq)
|
|
if len(items) <= 1:
|
|
return items[0]
|
|
return " + ".join(map(str, items))
|
|
|
|
def _constant_fold(
|
|
self, fn: Callable[[list[int]], int], seq: list[Union[int, str]]
|
|
) -> list[Union[int, str]]:
|
|
"""Constant fold through a commutative fn where ints are constants"""
|
|
items: list[Union[int, str]] = [x for x in seq if not isinstance(x, int)]
|
|
const_items = [x for x in seq if isinstance(x, int)]
|
|
if const_items:
|
|
items.append(fn(const_items))
|
|
return items
|
|
|
|
def assign_tmp(self, name: str, expr: Union[str, int]) -> str:
|
|
# Grid functions are one per kernel, so name collisions are fine
|
|
if self.mode == "python":
|
|
return f"{name} = {expr}"
|
|
if self.mode == "cpp":
|
|
return f"uint32_t {name} = {expr};"
|
|
raise AssertionError(f"invalid mode {self.mode}")
|
|
|
|
@staticmethod
|
|
def from_meta(
|
|
inductor_meta: dict[str, Any],
|
|
cfg: Union[Config, dict[str, int]],
|
|
mode: Literal["python", "cpp"] = "python",
|
|
) -> GridExpr:
|
|
grid_cls = globals()[inductor_meta["grid_type"]]
|
|
assert issubclass(grid_cls, GridExpr)
|
|
grid = grid_cls(inductor_meta=inductor_meta, mode=mode)
|
|
if isinstance(cfg, Config):
|
|
cfg = config_to_dict(cfg)
|
|
grid.generate(cfg)
|
|
return grid
|
|
|
|
def eval_slow(self, meta: dict[str, int]) -> tuple[int, int, int]:
|
|
scope = {**meta}
|
|
for line in self.prefix:
|
|
exec(line, scope)
|
|
exec(f"grid_0 = {self.x_grid}", scope)
|
|
exec(f"grid_1 = {self.y_grid}", scope)
|
|
exec(f"grid_2 = {self.z_grid}", scope)
|
|
return scope["grid_0"], scope["grid_1"], scope["grid_2"]
|
|
|
|
|
|
class Grid1D(GridExpr):
|
|
def generate(self, meta: dict[str, int]) -> None:
|
|
self.x_grid = self.ceildiv("xnumel", meta.get("XBLOCK"))
|
|
|
|
|
|
class Grid2D(GridExpr):
|
|
def generate(self, meta: dict[str, int]) -> None:
|
|
self.x_grid = self.ceildiv("xnumel", meta.get("XBLOCK"))
|
|
self.y_grid = self.ceildiv("ynumel", meta.get("YBLOCK"))
|
|
|
|
|
|
class Grid3D(GridExpr):
|
|
def generate(self, meta: dict[str, int]) -> None:
|
|
self.x_grid = self.ceildiv("xnumel", meta.get("XBLOCK"))
|
|
self.y_grid = self.ceildiv("ynumel", meta.get("YBLOCK"))
|
|
self.z_grid = self.ceildiv("znumel", meta.get("ZBLOCK"))
|
|
|
|
|
|
class Grid2DWithYZOverflow(GridExpr):
|
|
def generate(self, meta: dict[str, int]) -> None:
|
|
self.x_grid = self.ceildiv("xnumel", meta.get("XBLOCK"))
|
|
self.prefix.extend(
|
|
[
|
|
self.assign_tmp(
|
|
"y_grid_raw_", self.ceildiv("ynumel", meta.get("YBLOCK"))
|
|
),
|
|
self.assign_tmp(
|
|
"y_grid_div_", self.ceildiv("y_grid_raw_", get_max_y_grid())
|
|
),
|
|
]
|
|
)
|
|
self.y_grid = self.ceildiv("y_grid_raw_", "y_grid_div_")
|
|
self.z_grid = "y_grid_div_"
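# Hedged illustration of the overflow handling, assuming the usual 65535 y-grid
# limit from get_max_y_grid(): with ynumel=10_000_000 and YBLOCK=32 the raw y
# grid is 312500, so y_grid_div_ = ceildiv(312500, 65535) = 5 and the launch
# uses grid_y=62500, grid_z=5, with the kernel recombining the two indices.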
|
|
|
|
|
|
class CooperativeReductionGrid(GridExpr):
|
|
def generate(self, meta: dict[str, int]) -> None:
|
|
self.x_grid = str(meta["RSPLIT"])
|
|
self.y_grid = self.ceildiv("xnumel", meta.get("XBLOCK"))
|
|
|
|
|
|
class SplitScanGrid(GridExpr):
|
|
def generate(self, meta: dict[str, int]) -> None:
|
|
assert meta.get("XBLOCK", 1) == 1
|
|
self.x_grid = self.ceildiv("r0_numel", meta.get("R0_BLOCK"))
|
|
self.y_grid = "xnumel"
|
|
|
|
|
|
class FixedGrid(GridExpr):
|
|
@staticmethod
|
|
def setup_grid_as_args() -> dict[str, Any]:
|
|
"""Inductor meta so the launcher takes three extra grid arguments"""
|
|
return {
|
|
"grid_type": FixedGrid.__name__,
|
|
"fixed_grid": ["_grid_0", "_grid_1", "_grid_2"],
|
|
"extra_launcher_args": ["_grid_0", "_grid_1", "_grid_2"],
|
|
}
|
|
|
|
def generate(self, meta: dict[str, int]) -> None:
|
|
self.x_grid, self.y_grid, self.z_grid = self.inductor_meta["fixed_grid"]
|
|
|
|
|
|
class PrecomputedGrid(GridExpr):
|
|
def generate(self, meta: dict[str, int]) -> None:
|
|
for candidate in self.inductor_meta["precomputed_grids"]:
|
|
if all(meta.get(k) == v for k, v in candidate["config"].items()):
|
|
self.x_grid, self.y_grid, self.z_grid = candidate[self.mode]
|
|
return
|
|
raise AssertionError(
|
|
f"Precomputed grid not found for {meta} in {self.inductor_meta['precomputed_grids']}"
|
|
)
|
|
|
|
|
|
class ComboKernelGrid(GridExpr):
|
|
def generate(self, meta: dict[str, int]):
|
|
combo_meta = self.inductor_meta["combo_grid_meta"]
|
|
if combo_meta["default_config"]:
|
|
meta = {**combo_meta["default_config"], **meta}
|
|
no_x_dims = []
|
|
xnumels = []
|
|
ynumels = []
|
|
znumels = []
|
|
for num in range(combo_meta["num_kernels"]):
|
|
assert (
|
|
combo_meta[f"xnumel_{num}"] is None or combo_meta[f"xnumel_{num}"] > 0
|
|
)
|
|
no_x_dims.append(combo_meta[f"no_x_dim_{num}"])
|
|
xnumels.append(combo_meta[f"xnumel_{num}"] or f"xnumel_{num}")
|
|
if f"ynumel_{num}" in combo_meta:
|
|
ynumels.append(combo_meta[f"ynumel_{num}"] or f"ynumel_{num}")
|
|
if f"znumel_{num}" in combo_meta:
|
|
znumels.append(combo_meta[f"znumel_{num}"] or f"znumel_{num}")
|
|
|
|
self.x_grid = self.combo_x_grid(xnumels, no_x_dims, meta)
|
|
if combo_meta["min_blocks"]:
|
|
self.x_grid = self.maximum([self.x_grid, combo_meta["min_blocks"]])
|
|
if ynumels:
|
|
self.y_grid = self.ceildiv(self.maximum(ynumels), meta.get("YBLOCK"))
|
|
if znumels:
|
|
self.z_grid = self.ceildiv(self.maximum(znumels), meta.get("ZBLOCK"))
|
|
|
|
def combo_x_grid(
|
|
self,
|
|
xnumels: list[Union[int, str]],
|
|
no_x_dims: list[bool],
|
|
meta: dict[str, int],
|
|
) -> Union[str, int]:
|
|
raise NotImplementedError
|
|
|
|
|
|
class SequentialComboKernelGrid(ComboKernelGrid):
|
|
def combo_x_grid(
|
|
self,
|
|
xnumels: list[Union[int, str]],
|
|
no_x_dims: list[bool],
|
|
meta: dict[str, int],
|
|
) -> Union[str, int]:
|
|
assert len(xnumels) == len(no_x_dims)
|
|
return self.summation(
|
|
[
|
|
self.ceildiv(x, 1 if no_x_dim else meta.get("XBLOCK"))
|
|
for x, no_x_dim in zip(xnumels, no_x_dims)
|
|
]
|
|
)
|
|
|
|
|
|
class RoundRobinComboKernelGrid(ComboKernelGrid):
|
|
def combo_x_grid(
|
|
self,
|
|
xnumels: list[Union[int, str]],
|
|
no_x_dims: list[bool],
|
|
meta: dict[str, int],
|
|
) -> str:
|
|
assert len(xnumels) == len(no_x_dims)
|
|
num_kernels = self.inductor_meta["combo_grid_meta"]["num_kernels"]
|
|
exprs = [x for x, no_x_dim in zip(xnumels, no_x_dims) if no_x_dim]
|
|
xnumels_x_dim = [x for x, no_x_dim in zip(xnumels, no_x_dims) if not no_x_dim]
|
|
if xnumels_x_dim:
|
|
exprs.append(self.ceildiv(self.maximum(xnumels_x_dim), meta.get("XBLOCK")))
|
|
return f"({self.maximum(exprs)}) * {num_kernels}"
|