mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Prereqs: - https://github.com/pytorch/pytorch/pull/152708 Features: 1. Adds inductor's estimate of flops and bandwidth to the json trace events that perfetto uses. 1. Only use the tflops estimation from triton if we don't have the info from the datasheet because Triton's estimates are inaccurate. I have a backlog item to fix triton flops estimation upstream. New `DeviceInfo` class, and new function `get_device_tflops`. 1. New helpers `countable_fx` and `count_flops_fx` helps get the flops of an `fx.Node`. 1. Extends Triton `torch.profiler` logging to `DebugAutotuner`. 1. New script `profile_analysis.py`: `--augment_trace` adds perf estimates to any perfetto json trace, `--analyze` creates a summary table of these perf estimates, and `--diff` will compare two traces side by side: ```python Device(NVIDIA H100, 0): Kernel Name | resnet Kernel Count | resnet FLOPS | resnet bw gbps | resnet Dur (ms) | resnet Achieved FLOPS % | resnet Achieved Bandwidth % | newresnet Kernel Count | newresnet FLOPS | newresnet bw gbps | newresnet Dur (ms) | newresnet Achieved FLOPS % | newresnet Achieved Bandwidth % --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- triton_poi_fused__native_batch_norm_legi | 24 | 0 | 0.11395268248131513 | 2.5919166666666666 | 0 | 0.003401572611382541 | 24 | 0 | 0.11395268248131513 | 2.5919166666666666 | 0 | 0.003401572611382541 sm90_xmma_fprop_implicit_gemm_f32f32_tf3 | 142 | 16932673552.422373 | 0.2585007824198784 | 12.441619718309857 | 0.08683422334575583 | 0.007716441266265022 | 142 | 16932673552.422373 | 0.2585007824198784 | 12.441619718309857 | 0.08683422334575583 | 0.007716441266265022 triton_red_fused__native_batch_norm_legi | 39 | 0 | 0.13990024992108846 | 5.752589743589743 | 0 | 0.004176126863316074 | 39 | 0 | 0.13990024992108846 | 5.752589743589743 | 0 | 0.004176126863316074 triton_poi_fused__native_batch_norm_legi | 25 | 0 | 0.31824055917536503 | 2.5291999999999994 | 0 | 0.009499718184339253 | 25 | 0 | 0.31824055917536503 | 2.5291999999999994 | 0 | 0.009499718184339253 void cutlass::Kernel2<cutlass_80_tensoro | 98 | 16211056473.596165 | 0.42972434051025826 | 7.130408163265306 | 0.08313362294151874 | 0.012827592254037562 | 98 | 16211056473.596165 | 0.42972434051025826 | 7.130408163265306 | 0.08313362294151874 | 0.012827592254037562 triton_red_fused__native_batch_norm_legi | 73 | 0 | 0.3225381327611705 | 9.987068493150682 | 0 | 0.009628003963020014 | 73 | 0 | 0.3225381327611705 | 9.987068493150682 | 0 | 0.009628003963020014 triton_poi_fused__native_batch_norm_legi | 15 | 0 | 1.4491211346487216 | 4.439333333333333 | 0 | 0.043257347302946926 | 15 | 0 | 1.4491211346487216 | 4.439333333333333 | 0 | 0.043257347302946926 void cutlass::Kernel2<cutlass_80_tensoro | 186 | 14501701145.337954 | 0.2667131401910989 | 7.873865591397849 | 0.07436769818122027 | 0.007961586274361157 | 186 | 14501701145.337954 | 0.2667131401910989 | 7.873865591397849 | 0.07436769818122027 | 0.007961586274361157 triton_poi_fused__native_batch_norm_legi | 33 | 0 | 1.4924556538193923 | 4.3101515151515155 | 0 | 0.044550915039384846 | 33 | 0 | 1.4924556538193923 | 4.3101515151515155 | 0 | 0.044550915039384846 triton_red_fused__native_batch_norm_legi | 29 | 0 | 0.25562590522631107 | 6.296275862068965 | 0 | 0.007630624036606301 
| 29 | 0 | 0.25562590522631107 | 6.296275862068965 | 0 | 0.007630624036606301 triton_poi_fused__native_batch_norm_legi | 13 | 0 | 0.5870562174192726 | 2.7397692307692307 | 0 | 0.01752406619162008 | 13 | 0 | 0.5870562174192726 | 2.7397692307692307 | 0 | 0.01752406619162008 triton_poi_fused__native_batch_norm_legi | 34 | 0 | 0.41409928846284 | 2.853588235294117 | 0 | 0.012361172789935523 | 34 | 0 | 0.41409928846284 | 2.853588235294117 | 0 | 0.012361172789935523 triton_per_fused__native_batch_norm_legi | 34 | 0 | 0.11705315007018151 | 3.460647058823529 | 0 | 0.0034941238826919864 | 34 | 0 | 0.11705315007018151 | 3.460647058823529 | 0 | 0.0034941238826919864 triton_poi_fused__native_batch_norm_legi | 16 | 0 | 0.17207853197124584 | 2.3459375000000002 | 0 | 0.005136672596156592 | 16 | 0 | 0.17207853197124584 | 2.3459375000000002 | 0 | 0.005136672596156592 triton_per_fused__native_batch_norm_legi | 30 | 0 | 0.2639714322022256 | 6.131199999999999 | 0 | 0.007879744244842555 | 30 | 0 | 0.2639714322022256 | 6.131199999999999 | 0 | 0.007879744244842555 sm90_xmma_fprop_implicit_gemm_f32f32_tf3 | 100 | 11875430356.891787 | 0.19494470869421385 | 16.36534 | 0.06089964285585531 | 0.005819245035648175 | 100 | 11875430356.891787 | 0.19494470869421385 | 16.36534 | 0.06089964285585531 | 0.005819245035648175 triton_poi_fused__native_batch_norm_legi | 8 | 0 | 0.9854096626224687 | 3.2757500000000004 | 0 | 0.029415213809625928 | 8 | 0 | 0.9854096626224687 | 3.2757500000000004 | 0 | 0.029415213809625928 void cublasLt::splitKreduce_kernel<32, 1 | 56 | 34377923395.147064 | 0.8310300045762317 | 3.4199999999999986 | 0.17629704305203628 | 0.024806865808245714 | 56 | 34377923395.147064 | 0.8310300045762317 | 3.4199999999999986 | 0.17629704305203628 | 0.024806865808245714 triton_poi_fused__native_batch_norm_legi | 23 | 0 | 0.9944002965861103 | 3.2431304347826084 | 0 | 0.02968359094286896 | 23 | 0 | 0.9944002965861103 | 3.2431304347826084 | 0 | 0.02968359094286896 triton_per_fused__native_batch_norm_legi | 10 | 0 | 0.1826801058931057 | 4.428800000000001 | 0 | 0.00545313748934644 | 10 | 0 | 0.1826801058931057 | 4.428800000000001 | 0 | 0.00545313748934644 triton_poi_fused__native_batch_norm_legi | 10 | 0 | 0.3168973585366449 | 2.5471999999999997 | 0 | 0.009459622642884923 | 10 | 0 | 0.3168973585366449 | 2.5471999999999997 | 0 | 0.009459622642884923 triton_poi_fused__native_batch_norm_legi | 34 | 0 | 1.1463614897015777 | 4.124323529411764 | 0 | 0.03421974596124114 | 34 | 0 | 1.1463614897015777 | 4.124323529411764 | 0 | 0.03421974596124114 void cask_plugin_cudnn::xmma_cudnn::init | 44 | 44045510816.64277 | 2.0661232850348643 | 3.6887499999999993 | 0.22587441444432194 | 0.06167532194133924 | 44 | 44045510816.64277 | 2.0661232850348643 | 3.6887499999999993 | 0.22587441444432194 | 0.06167532194133924 sm90_xmma_fprop_implicit_gemm_f32f32_tf3 | 95 | 7876855400.165316 | 0.4694941555946739 | 18.224315789473682 | 0.04039413025725802 | 0.014014750913273854 | 95 | 7876855400.165316 | 0.4694941555946739 | 18.224315789473682 | 0.04039413025725802 | 0.014014750913273854 triton_per_fused__native_batch_norm_legi | 41 | 0 | 0.06825669875995298 | 3.0384146341463416 | 0 | 0.002037513395819492 | 41 | 0 | 0.06825669875995298 | 3.0384146341463416 | 0 | 0.002037513395819492 triton_poi_fused__native_batch_norm_legi | 23 | 0 | 0.08808154712430301 | 2.3275652173913044 | 0 | 0.0026292999141582997 | 23 | 0 | 0.08808154712430301 | 2.3275652173913044 | 0 | 0.0026292999141582997 triton_per_fused__native_batch_norm_legi | 40 | 0 | 0.18179321034952417 | 
4.556825 | 0 | 0.005426662995508183 | 40 | 0 | 0.18179321034952417 | 4.556825 | 0 | 0.005426662995508183 triton_poi_fused__native_batch_norm_legi | 15 | 0 | 0.5887415155454232 | 2.783866666666667 | 0 | 0.017574373598370836 | 15 | 0 | 0.5887415155454232 | 2.783866666666667 | 0 | 0.017574373598370836 void cutlass::Kernel2<cutlass_80_tensoro | 38 | 14242013806.264643 | 0.256592404353939 | 7.217631578947369 | 0.0730359682372546 | 0.007659474756834 | 38 | 14242013806.264643 | 0.256592404353939 | 7.217631578947369 | 0.0730359682372546 | 0.007659474756834 triton_poi_fused__native_batch_norm_legi | 21 | 0 | 0.5842860973430516 | 2.7779047619047623 | 0 | 0.017441376040091088 | 21 | 0 | 0.5842860973430516 | 2.7779047619047623 | 0 | 0.017441376040091088 triton_per_fused__native_batch_norm_legi | 16 | 0 | 0.11509365173486417 | 3.5959375000000002 | 0 | 0.0034356313950705724 | 16 | 0 | 0.11509365173486417 | 3.5959375000000002 | 0 | 0.0034356313950705724 triton_poi_fused__native_batch_norm_legi | 14 | 0 | 0.1704672000243914 | 2.4044285714285714 | 0 | 0.00508857313505646 | 14 | 0 | 0.1704672000243914 | 2.4044285714285714 | 0 | 0.00508857313505646 triton_poi_fused__native_batch_norm_legi | 58 | 0 | 2.307520779930795 | 8.190706896551722 | 0 | 0.06888121731136704 | 58 | 0 | 2.307520779930795 | 8.190706896551722 | 0 | 0.06888121731136704 triton_per_fused__native_batch_norm_legi | 29 | 0 | 0.037243248971881276 | 3.0277586206896556 | 0 | 0.001111738775280038 | 29 | 0 | 0.037243248971881276 | 3.0277586206896556 | 0 | 0.001111738775280038 triton_poi_fused__native_batch_norm_legi | 20 | 0 | 0.04741699795428918 | 2.2911500000000005 | 0 | 0.0014154327747549007 | 20 | 0 | 0.04741699795428918 | 2.2911500000000005 | 0 | 0.0014154327747549007 triton_per_fused__native_batch_norm_legi | 25 | 0 | 0.13357016893727824 | 3.37536 | 0 | 0.003987169222008305 | 25 | 0 | 0.13357016893727824 | 3.37536 | 0 | 0.003987169222008305 triton_poi_fused__native_batch_norm_legi | 13 | 0 | 0.3089862268300253 | 2.8111538461538457 | 0 | 0.009223469457612694 | 13 | 0 | 0.3089862268300253 | 2.8111538461538457 | 0 | 0.009223469457612694 triton_poi_fused__native_batch_norm_legi | 17 | 0 | 0.3129385387909844 | 2.673 | 0 | 0.009341448919133863 | 17 | 0 | 0.3129385387909844 | 2.673 | 0 | 0.009341448919133863 triton_per_fused__native_batch_norm_legi | 19 | 0 | 0.2215568162533158 | 3.8837368421052636 | 0 | 0.0066136363060691275 | 19 | 0 | 0.2215568162533158 | 3.8837368421052636 | 0 | 0.0066136363060691275 std::enable_if<!(false), void>::type int | 23 | 504916805.19297093 | 1.0118296096314707 | 8.113913043478261 | 0.0025893169497075447 | 0.030203868944223014 | 23 | 504916805.19297093 | 1.0118296096314707 | 8.113913043478261 | 0.0025893169497075447 | 0.030203868944223014 triton_poi_fused_add_copy__38 | 56 | 0 | 0 | 2.132482142857143 | 0 | 0 | 56 | 0 | 0 | 2.132482142857143 | 0 | 0 triton_poi_fused_convolution_0 | 18 | 0 | 0.43458610794936897 | 2.773333333333334 | 0 | 0.012972719640279667 | 18 | 0 | 0.43458610794936897 | 2.773333333333334 | 0 | 0.012972719640279667 triton_poi_fused_convolution_1 | 17 | 0 | 0.028816312469162712 | 2.6145882352941174 | 0 | 0.0008601884319153051 | 17 | 0 | 0.028816312469162712 | 2.6145882352941174 | 0 | 0.0008601884319153051 void convolve_common_engine_float_NHWC<f | 44 | 8641868995.31118 | 0.024730540008465626 | 25.87327272727273 | 0.04431727689903169 | 0.0007382250748795709 | 44 | 8641868995.31118 | 0.024730540008465626 | 25.87327272727273 | 0.04431727689903169 | 0.0007382250748795709 triton_per_fused__native_batch_norm_legi 
| 12 | 0 | 0.6809930918986744 | 4.82675 | 0 | 0.020328151996975356 | 12 | 0 | 0.6809930918986744 | 4.82675 | 0 | 0.020328151996975356 triton_per_fused__native_batch_norm_legi | 14 | 0 | 0.02883030597936608 | 2.6651428571428575 | 0 | 0.0008606061486377935 | 14 | 0 | 0.02883030597936608 | 2.6651428571428575 | 0 | 0.0008606061486377935 triton_per_fused__native_batch_norm_legi | 16 | 0 | 0.0014658988233201874 | 2.098 | 0 | 4.375817383045335e-05 | 16 | 0 | 0.0014658988233201874 | 2.098 | 0 | 4.375817383045335e-05 triton_poi_fused__native_batch_norm_legi | 13 | 0 | 0.9926297180284697 | 3.2367692307692306 | 0 | 0.02963073785159611 | 13 | 0 | 0.9926297180284697 | 3.2367692307692306 | 0 | 0.02963073785159611 triton_poi_fused__native_batch_norm_legi | 9 | 0 | 1.3008817095666507 | 3.0863333333333336 | 0 | 0.03883228983781048 | 9 | 0 | 1.3008817095666507 | 3.0863333333333336 | 0 | 0.03883228983781048 void at::native::(anonymous namespace):: | 98 | 0 | 0.09174335613709389 | 4.408520408163265 | 0 | 0.0027386076458833994 | 98 | 0 | 0.09174335613709389 | 4.408520408163265 | 0 | 0.0027386076458833994 void at::native::vectorized_elementwise_ | 7 | 0 | 0 | 1.7278571428571428 | 0 | 0 | 7 | 0 | 0 | 1.7278571428571428 | 0 | 0 ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/149697 Approved by: https://github.com/eellison, https://github.com/shunting314
3048 lines
110 KiB
Python
# mypy: allow-untyped-defs
|
|
from __future__ import annotations
|
|
|
|
import builtins
|
|
import copy
|
|
import dataclasses
|
|
import functools
|
|
import hashlib
|
|
import inspect
|
|
import itertools
|
|
import logging
|
|
import math
|
|
import operator
|
|
import os
|
|
import os.path
|
|
import re
|
|
import sys
|
|
import threading
|
|
import time
|
|
from collections import namedtuple
|
|
from typing import (
|
|
Any,
|
|
Callable,
|
|
Generic,
|
|
Literal,
|
|
Optional,
|
|
TYPE_CHECKING,
|
|
TypeVar,
|
|
Union,
|
|
)
|
|
|
|
import torch
|
|
from torch._dynamo.utils import set_feature_use
|
|
from torch._prims_common import compute_required_storage_length
|
|
from torch.utils._ordered_set import OrderedSet
|
|
|
|
from ..triton_bundler import TritonBundler
|
|
from ..utils import prefix_is_reduction, triton_version_uses_attrs_dict
|
|
from . import triton_helpers
|
|
from .autotune_cache import AutotuneCache
|
|
from .benchmarking import benchmarker
|
|
from .coordinate_descent_tuner import CoordescTuner
|
|
from .hints import (
|
|
_NUM_THREADS_PER_WARP,
|
|
AutotuneHint,
|
|
DeviceProperties,
|
|
HeuristicType,
|
|
ReductionHint,
|
|
TileHint,
|
|
TRITON_MAX_BLOCK,
|
|
TRITON_MAX_RSPLIT,
|
|
)
|
|
from .runtime_utils import (
|
|
ceildiv,
|
|
conditional_product,
|
|
create_bandwidth_info_str,
|
|
dynamo_timed,
|
|
get_first_attr,
|
|
get_max_y_grid,
|
|
get_num_bytes,
|
|
next_power_of_2,
|
|
triton_cache_dir,
|
|
triton_config_to_hashable,
|
|
triton_hash_to_path_key,
|
|
validate_triton_config,
|
|
)
|
|
from .static_cuda_launcher import StaticallyLaunchedCudaKernel
|
|
from .triton_compat import (
|
|
ASTSource,
|
|
autograd_profiler,
|
|
cc_warp_size,
|
|
CompiledKernel,
|
|
Config,
|
|
GPUTarget,
|
|
HAS_WARP_SPEC,
|
|
KernelInterface,
|
|
knobs,
|
|
OutOfResources,
|
|
PTXASError,
|
|
triton,
|
|
)
|
|
|
|
|
|
class NoTritonConfigsError(RuntimeError):
|
|
pass
|
|
|
|
|
|
if TYPE_CHECKING:
|
|
from collections.abc import Container, Hashable
|
|
|
|
from torch._guards import CompileId
|
|
|
|
LauncherType = Any
|
|
|
|
_KernelType = Union[CompiledKernel, StaticallyLaunchedCudaKernel]
|
|
_T = TypeVar("_T", bound=_KernelType)
|
|
|
|
log = logging.getLogger(__name__)
|
|
|
|
|
|
def get_total_reduction_numel(numels: dict[str, int]) -> int:
|
|
return conditional_product(
|
|
*[numel for prefix, numel in numels.items() if prefix_is_reduction(prefix)]
|
|
)
|
|
|
|
|
|
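# Hedged, illustrative example (added for exposition, not part of the original module):
# it shows the intended arithmetic of get_total_reduction_numel on a hypothetical numel
# dict, assuming prefix_is_reduction() treats the "r*_" prefixes as reduction dimensions.
def _example_total_reduction_numel() -> int:
    numels = {"x": 1024, "r0_": 256, "r1_": 8}
    # only the reduction prefixes ("r0_", "r1_") multiply together: 256 * 8 == 2048
    return get_total_reduction_numel(numels)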
def autotune_hints_to_configs(
|
|
hints: OrderedSet[AutotuneHint],
|
|
size_hints,
|
|
block_size: int,
|
|
device_props: DeviceProperties,
|
|
) -> list[Config]:
|
|
"""
|
|
AutotuneHints can be attached to the metadata of triton kernels for providing
|
|
suggestions about what to try for autotuning. One reason to do this is if there are
|
|
some configs that are only useful in specific scenarios, in which case we can avoid
|
|
wasting compile time on autotuning unless we know we are in one of those scenarios.
|
|
|
|
Based on those hints, this function will generate a list of additional autotuning
|
|
configs to try.
|
|
"""
|
|
xyz_options: tuple[tuple[int, Optional[int], Optional[int]], ...]
|
|
configs: list[Config] = []
|
|
for hint in hints:
|
|
if hint == AutotuneHint.ONE_ELEMENT_PER_THREAD:
|
|
if len(size_hints) == 1:
|
|
xyz_options = ((block_size // 4, None, None),)
|
|
elif len(size_hints) == 2:
|
|
xyz_options = ((block_size // 4, 1, None), (1, block_size // 4, None))
|
|
elif len(size_hints) == 3:
|
|
xyz_options = (
|
|
(block_size // 4, 1, 1),
|
|
(1, block_size // 4, 1),
|
|
(1, 1, block_size // 4),
|
|
)
|
|
configs.extend(
|
|
triton_config(
|
|
size_hints,
|
|
*xyz,
|
|
num_elements_per_warp=(
|
|
device_props.warp_size if device_props.warp_size else 32
|
|
),
|
|
)
|
|
for xyz in xyz_options
|
|
)
|
|
|
|
return configs
|
|
|
|
|
|
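# Illustrative note (added for exposition): with hints = {AutotuneHint.ONE_ELEMENT_PER_THREAD},
# 2-D size_hints, and block_size = 1024, the helper above would produce two extra configs,
# roughly XBLOCK=256/YBLOCK=1 and XBLOCK=1/YBLOCK=256 (block_size // 4 == 256), assuming
# triton_config() maps the x/y sizes onto XBLOCK/YBLOCK so each thread handles ~one element.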
def disable_pointwise_autotuning(inductor_meta):
|
|
# Autotuning can give different benchmarking results from run to run, and
|
|
# therefore we disable autotuning when use_deterministic flag is on.
|
|
if inductor_meta.get("are_deterministic_algorithms_enabled"):
|
|
return True
|
|
return not inductor_meta.get("autotune_pointwise", True)
|
|
|
|
|
|
def _dump_launch_params(args, kwargs, launcher, kernel_name, grid):
|
|
call_args = []
|
|
call_kwargs = {}
|
|
for arg in args:
|
|
if isinstance(arg, (int, bool)):
|
|
call_args.append(str(arg))
|
|
else:
|
|
call_args.append("T")
|
|
    for k, v in kwargs.items():
        # mirror the positional-arg handling above: keep scalar values,
        # abbreviate tensors and other objects as "T"
        if isinstance(v, (int, bool)):
            call_kwargs[k] = v
        else:
            call_kwargs[k] = "T"
|
|
if not triton_version_uses_attrs_dict():
|
|
call_kwargs.update(launcher.config.kwargs)
|
|
call_kwargs["num_warps"] = launcher.config.num_warps
|
|
call_kwargs["num_stages"] = launcher.config.num_stages
|
|
if HAS_WARP_SPEC:
|
|
call_kwargs["num_consumer_groups"] = getattr(
|
|
launcher.config, "num_consumer_groups", 0
|
|
)
|
|
call_kwargs["num_buffers_warp_spec"] = getattr(
|
|
launcher.config, "num_buffers_warp_spec", 0
|
|
)
|
|
args_str = [*call_args]
|
|
args_str.extend(f"{k}={v}" for k, v in call_kwargs.items())
|
|
args_str = ", ".join(args_str)
|
|
abs_path = os.path.abspath(sys.argv[0])
|
|
with open(f"{abs_path}.launch_params", "a") as f:
|
|
f.write(f"{kernel_name} | {args_str} | {grid!r}\n")
|
|
|
|
|
|
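# Illustrative example (added for exposition) of one line appended to
# "<script>.launch_params" by the helper above, with made-up values:
#   triton_poi_fused_add_0 | T, T, 4096, XBLOCK=1024, num_warps=4, num_stages=1 | (4, 1, 1)
# Tensor arguments are abbreviated as "T"; the trailing tuple is the launch grid.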
def check_autotune_cache(
|
|
configs: list[Config], filename: Optional[str], inductor_meta: dict[str, Any]
|
|
) -> tuple[list[Config], Optional[AutotuneCache], dict[str, Any]]:
|
|
"""
|
|
    Given a list of configs, checks the autotune cache and returns metadata
|
|
"""
|
|
autotune_cache = None
|
|
autotune_cache_info = {}
|
|
disabled = inductor_meta.get("force_disable_caches", False)
|
|
if (
|
|
not disabled
|
|
and filename is not None
|
|
and (len(configs) > 1 or inductor_meta.get("coordinate_descent_tuning"))
|
|
and not os.environ.get("TRITON_INTERPRET", "0") == "1"
|
|
):
|
|
configs_hash = hash_configs(configs)
|
|
|
|
autotune_cache = AutotuneCache.create(inductor_meta, filename, configs_hash)
|
|
if autotune_cache:
|
|
if best_config := autotune_cache.read_best(inductor_meta, configs):
|
|
configs = [best_config]
|
|
autotune_cache_info["best_config"] = triton_config_to_hashable(
|
|
best_config
|
|
)
|
|
autotune_cache_info["autotune_cache_state"] = "hit"
|
|
|
|
else:
|
|
autotune_cache_info["autotune_cache_state"] = "miss"
|
|
autotune_cache_info["num_configs"] = len(configs)
|
|
if inductor_meta.get("coordinate_descent_tuning"):
|
|
autotune_cache_info["coordesc_tuning"] = True
|
|
if len(configs) == 1:
|
|
# This is the config that coordinate descent tuning started at, which
|
|
# is not the same as the final config chosen (i.e. only_config, best_config)
|
|
autotune_cache_info["coordesc_tuning_start_config"] = (
|
|
triton_config_to_hashable(configs[0])
|
|
)
|
|
else:
|
|
if len(configs) == 1:
|
|
autotune_cache_info["autotune_cache_state"] = "only 1 config"
|
|
autotune_cache_info["only_config"] = triton_config_to_hashable(configs[0])
|
|
|
|
if disabled:
|
|
autotune_cache_info["autotune_cache_state"] = "force_disabled"
|
|
log.debug("autotune caching is disabled by config.force_disable_caches")
|
|
|
|
return configs, autotune_cache, autotune_cache_info
|
|
|
|
|
|
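# Illustrative note (added for exposition): on a cache hit the returned
# autotune_cache_info looks roughly like
#   {"best_config": <hashable config>, "autotune_cache_state": "hit"}
# and on a miss like
#   {"autotune_cache_state": "miss", "num_configs": 4}
# with only the best config kept in the returned config list on a hit.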
class CachingAutotuner(KernelInterface):
|
|
"""
|
|
Simplified version of Triton autotuner that has no invalidation
|
|
key and caches the best config to disk to improve cold start times.
|
|
Unlike the main triton Autotuner, this version can precompile all
|
|
configs, and does not rely on the Triton JIT.
|
|
"""
|
|
|
|
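    # Rough usage sketch (added for exposition, simplified):
    #   tuner.precompile()                # compile every candidate config ahead of time
    #   tuner.run(*args, stream=stream)   # first call benchmarks configs and keeps the best launcher
    #   tuner.run(*args, stream=stream)   # later calls reuse the single cached launcher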
def __init__(
|
|
self,
|
|
fn,
|
|
triton_meta, # passed directly to triton
|
|
configs,
|
|
save_cache_hook,
|
|
mutated_arg_names: list[str], # see [Note: clone mutated buffers]
|
|
optimize_mem,
|
|
heuristic_type,
|
|
size_hints=None,
|
|
inductor_meta=None, # metadata not relevant to triton
|
|
custom_kernel=False, # whether the kernel is inductor-generated or custom
|
|
filename: Optional[str] = None,
|
|
reset_to_zero_arg_names: Optional[list[str]] = None,
|
|
autotune_cache_info: Optional[dict[str, Any]] = None,
|
|
):
|
|
super().__init__()
|
|
|
|
assert len(configs) > 0, "Non-empty TritonConfig list required for compiling"
|
|
# makes sure there are no pre-hooks on any of the triton configs
|
|
for cfg in configs:
|
|
validate_triton_config(cfg)
|
|
|
|
self.fn = fn
|
|
self.device_props: DeviceProperties = triton_meta["device"]
|
|
self.triton_meta = {
|
|
**triton_meta,
|
|
"device": self.device_props.index,
|
|
"device_type": self.device_props.type,
|
|
}
|
|
self.inductor_meta = {} if inductor_meta is None else inductor_meta
|
|
self.save_cache_hook = save_cache_hook
|
|
self.mutated_arg_names = mutated_arg_names
|
|
self.reset_to_zero_arg_names = (
|
|
[] if reset_to_zero_arg_names is None else reset_to_zero_arg_names
|
|
)
|
|
self.optimize_mem = optimize_mem
|
|
self.configs = configs
|
|
self.heuristic_type = heuristic_type
|
|
self.custom_kernel = custom_kernel
|
|
self.cuda_kernel_saved = False
|
|
self.autotune_cache_info = autotune_cache_info
|
|
if log.isEnabledFor(logging.DEBUG):
|
|
log.debug(
|
|
"CachingAutotuner gets %d configs for %s",
|
|
len(self.configs),
|
|
self.fn.__name__,
|
|
)
|
|
for c in self.configs:
|
|
log.debug(c)
|
|
|
|
self.compile_results: list[CompileResult[_KernelType]] = []
|
|
self.launchers: list[LauncherType] = []
|
|
self.lock = threading.Lock()
|
|
if os.getenv("TRITON_CACHE_DIR") is None:
|
|
os.environ["TRITON_CACHE_DIR"] = triton_cache_dir(
|
|
self.triton_meta.get("device", 0)
|
|
)
|
|
log.debug("Triton cache dir: %s", os.environ["TRITON_CACHE_DIR"])
|
|
|
|
self.size_hints = size_hints
|
|
self.coordesc_tuner = CoordescTuner(
|
|
is_mm=False,
|
|
name=self.fn.__name__,
|
|
size_hints=size_hints,
|
|
inductor_meta=self.inductor_meta,
|
|
)
|
|
self.filename = filename
|
|
|
|
# used for profiling
|
|
self.kernel_hash: str = ""
|
|
|
|
# Kernels are stored in the codecache with the filename as a hash of the code.
|
|
# We rely on this to obtain the kernel hash
|
|
if self.filename is not None:
|
|
base_name = os.path.basename(self.filename)
|
|
if ".py" in base_name:
|
|
self.kernel_hash = os.path.splitext(base_name)[0]
|
|
|
|
self.precompile_time_taken_ns = 0
|
|
self.autotune_time_taken_ns = 0
|
|
# Dumps the launch configs after autotuning.
|
|
self.dump_launch_params = (
|
|
os.environ.get("TORCHINDUCTOR_DUMP_LAUNCH_PARAMS", "0") == "1"
|
|
)
|
|
|
|
self.triton_interpret = os.environ.get("TRITON_INTERPRET", "0") == "1"
|
|
|
|
        # Compile-time info included in runtime logging
|
|
self.compile_id: Optional[CompileId] = None
|
|
self.is_backward = False
|
|
|
|
def is_statically_launchable(self):
|
|
"""
|
|
Checks if every compiled kernel is statically launchable, which
|
|
allows us to efficiently cache it in FXGraphCache
|
|
"""
|
|
if not self.compile_results:
|
|
return False
|
|
return all(
|
|
isinstance(x, StaticTritonCompileResult) for x in self.compile_results
|
|
)
|
|
|
|
def recheck_autotune_cache(
|
|
self, reload_kernel_from_src: Callable[[], CachingAutotuner]
|
|
) -> None:
|
|
"""
|
|
        When loading a statically launchable autotuner from cache, we need to recheck
        the autotune cache, since a best config could have been found by a previous run
|
|
"""
|
|
assert self.is_statically_launchable()
|
|
|
|
configs = [result.config for result in self.compile_results]
|
|
|
|
(cached_configs, _, autotune_cache_info) = check_autotune_cache(
|
|
configs, self.filename, self.inductor_meta
|
|
)
|
|
self.autotune_cache_info = autotune_cache_info
|
|
# I.e. there was an autotune cache hit
|
|
if len(cached_configs) == 1 and len(configs) > 1:
|
|
best_config = cached_configs[0]
|
|
# Grab the best compiled config, if it's in the list of available ones
|
|
best_config_hash = triton_config_to_hashable(best_config)
|
|
|
|
for compile_result in self.compile_results:
|
|
if triton_config_to_hashable(compile_result.config) == best_config_hash:
|
|
self.compile_results = [compile_result]
|
|
return
|
|
|
|
# If the best config isn't in our list of compile results,
|
|
# it's likely because it was found by coordesc after the cache
|
|
# already saved
|
|
if best_config.found_by_coordesc:
|
|
with dynamo_timed("CachingAutotuner.slow_precompile_config"):
|
|
if self.fn.fn is None:
|
|
self.fn = reload_kernel_from_src().fn
|
|
self.compile_results = [self._precompile_config(best_config)]
|
|
|
|
def set_compile_info(
|
|
self, compile_id: Optional[CompileId], is_backward: bool
|
|
) -> None:
|
|
self.compile_id = compile_id
|
|
self.is_backward = is_backward
|
|
|
|
def precompile(
|
|
self,
|
|
warm_cache_only=False,
|
|
reload_kernel: Optional[Callable[[], CachingAutotuner]] = None,
|
|
static_triton_bundle_key: Optional[str] = None,
|
|
):
|
|
if warm_cache_only:
|
|
self._precompile_worker()
|
|
return
|
|
with self.lock:
|
|
# Helper function for reloading a kernel generated in a worker
|
|
# in the parent class. Normally we don't need to reload the kernel
|
|
# in the parent process, but in certain cases (coordesc tuning, dynamic_scale_rblock),
|
|
# we need to actually run compilation on the parent process
|
|
if reload_kernel is not None:
|
|
self._reload_kernel = reload_kernel
|
|
self._precompile_worker()
|
|
if static_triton_bundle_key is not None and self.is_statically_launchable():
|
|
TritonBundler.put_static_autotuner(static_triton_bundle_key, self)
|
|
self._make_launchers()
|
|
self._dynamic_scale_rblock()
|
|
|
|
def _precompile_worker(self):
|
|
if self.compile_results:
|
|
for result in self.compile_results:
|
|
TritonBundler.put(
|
|
triton_hash_to_path_key(result.kernel.hash), # type: ignore[attr-defined]
|
|
self.triton_meta.get("device", 0),
|
|
)
|
|
return
|
|
assert not self.launchers
|
|
if not self.configs:
|
|
raise NoTritonConfigsError("No triton configs are available")
|
|
|
|
compile_results = []
|
|
exc = None
|
|
for c in self.configs:
|
|
try:
|
|
compile_results.append(self._precompile_config(c))
|
|
except (OutOfResources, PTXASError) as e:
|
|
exc = e
|
|
if len(compile_results) == 0:
|
|
raise NoTritonConfigsError(
|
|
f"No valid triton configs. {type(exc).__name__}: {exc}"
|
|
)
|
|
self.compile_results = compile_results
|
|
self.configs = None
|
|
|
|
def _dynamic_scale_rblock(self):
|
|
# TODO(jansel): we should find a way to move this extra compile into the worker process
|
|
# Currently it relies on _make_launchers(), which requires a cuda context, to populate nreg.
|
|
device_prop = self.device_props
|
|
if (
|
|
self.inductor_meta.get("dynamic_scale_rblock", True)
|
|
and not self.inductor_meta.get("persistent_reduction")
|
|
and self.heuristic_type == HeuristicType.REDUCTION
|
|
and self.size_hints is not None
|
|
# Disable for Intel as Triton is not ready to return n_regs for a compiled_binary.
|
|
and device_prop.type in ["cuda", "hip"]
|
|
and device_prop.major
|
|
and (device_prop.major >= 8 or torch.version.hip)
|
|
and device_prop.regs_per_multiprocessor is not None
|
|
):
|
|
assert device_prop.regs_per_multiprocessor
|
|
assert device_prop.max_threads_per_multi_processor
|
|
assert device_prop.multi_processor_count
|
|
seen_config_hashes: Optional[OrderedSet[Hashable]] = None
|
|
warp_size = device_prop.warp_size or 32
|
|
for result in self.compile_results:
|
|
triton_config = result.config
|
|
compiled_binary = result.kernel
|
|
assert len(self.size_hints) >= 2
|
|
xblock = triton_config.kwargs.get("XBLOCK", 1)
|
|
reduction_kwargs = [
|
|
kwarg for kwarg in triton_config.kwargs if kwarg.startswith("R")
|
|
]
|
|
rblocks = [triton_config.kwargs[kwarg] for kwarg in reduction_kwargs]
|
|
total_block = (self.size_hints["x"] + xblock - 1) // xblock
|
|
nreg = getattr(compiled_binary, "n_regs", None)
|
|
if nreg is None:
|
|
continue
|
|
|
|
# make sure rblocks are not too small
|
|
if conditional_product(*rblocks) <= 64:
|
|
continue
|
|
|
|
# each SM of A100 has 65536 32-bit registers. To maximize
|
|
# the theoretical occupancy, we need run 2048 threads on each
|
|
# SM. So each thread should use no more than 65536 / 2048
|
|
# = 32 registers. In cases where occupancy matters, and each
|
|
# thread uses too many registers, reduce R0_BLOCK to reduce
|
|
# the register usage.
|
|
# For kernel https://gist.github.com/shunting314/e4cccc031fe30d378b9b23c08c238cbd
|
|
# from PLBartForCausalLM, latency improve from
|
|
# 7.795ms to 4.883ms.
|
|
#
|
|
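                # Worked example (added for exposition): on an A100 with 65536 regs/SM
                # and 2048 max threads/SM, the threshold below is 65536 // 2048 == 32
                # registers per thread. A kernel using 40 regs/thread with num_warps=8
                # and warp_size=32 has nreg_per_block = 40 * 32 * 8 = 10240, so
                # max_blocks_per_sm = 65536 // 10240 = 6.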
if (
|
|
nreg
|
|
<= device_prop.regs_per_multiprocessor
|
|
// device_prop.max_threads_per_multi_processor
|
|
):
|
|
continue
|
|
|
|
nreg_per_warp = nreg * warp_size
|
|
nreg_per_block = nreg_per_warp * triton_config.num_warps
|
|
|
|
                # Previously we set max_blocks_per_sm to 'max_threads_per_multi_processor / (32 * num_warps)'
|
|
# The formula below is a tighter upper bound since we have the assumption that
|
|
# nreg > device_prop.regs_per_multiprocessor // device_prop.max_threads_per_multi_processor
|
|
# due to the if condition above and:
|
|
# regs_per_multiprocessor / nreg_per_block
|
|
# = regs_per_multiprocessor / (nreg * 32 * num_warps)
|
|
# < regs_per_multiprocessor / ((regs_per_multiprocessor / max_threads_per_multi_processor) * 32 * num_warps)
|
|
# = max_threads_per_multi_processor / (32 * num_warps)
|
|
                # Using a tighter upper bound can reveal more optimization opportunities.
|
|
max_blocks_per_sm = max(
|
|
device_prop.regs_per_multiprocessor // nreg_per_block, 1
|
|
)
|
|
|
|
if total_block <= max_blocks_per_sm * device_prop.multi_processor_count:
|
|
# no need to improve occupancy
|
|
continue
|
|
new_config = copy.deepcopy(triton_config)
|
|
|
|
# Reduce the largest Rn_BLOCK by a factor of 2.
|
|
largest_rkwarg: str = max(
|
|
reduction_kwargs, key=triton_config.kwargs.__getitem__
|
|
)
|
|
new_config.kwargs[largest_rkwarg] //= 2
|
|
|
|
if seen_config_hashes is None:
|
|
seen_config_hashes = OrderedSet(
|
|
[
|
|
triton_config_to_hashable(x.config)
|
|
for x in self.compile_results
|
|
]
|
|
)
|
|
new_config_hash = triton_config_to_hashable(new_config)
|
|
if new_config_hash in seen_config_hashes:
|
|
continue
|
|
seen_config_hashes.add(new_config_hash)
|
|
log.debug(
|
|
"Dynamically scale down %s from TritonConfig(%s) and get a new TritonConfig(%s)",
|
|
largest_rkwarg,
|
|
triton_config,
|
|
new_config,
|
|
)
|
|
if self.fn.fn is None:
|
|
"""
|
|
We are in the parent process, while this program was compiled in a worker
|
|
and the fn was dropped in prepare_for_pickle(). We haven't loaded the module
|
|
containing the real fn yet.
|
|
"""
|
|
assert hasattr(self, "_reload_kernel")
|
|
assert callable(self._reload_kernel)
|
|
self.fn = self._reload_kernel().fn
|
|
self.compile_results.append(self._precompile_config(new_config))
|
|
|
|
self._make_launchers()
|
|
|
|
def _make_launchers(self):
|
|
if len(self.launchers) == len(self.compile_results):
|
|
return
|
|
|
|
from torch._dynamo.device_interface import DeviceGuard
|
|
|
|
device_interface = self.get_device_interface()
|
|
|
|
# load binary to the correct device
|
|
with DeviceGuard(device_interface, self.triton_meta["device"]):
|
|
# need to initialize context
|
|
device_interface.synchronize(device_interface.current_device())
|
|
launchers = []
|
|
exc = None
|
|
for result in self.compile_results:
|
|
try:
|
|
launchers.append(result.make_launcher())
|
|
|
|
except (OutOfResources, PTXASError, torch.cuda.OutOfMemoryError) as e:
|
|
exc = e
|
|
if len(launchers) == 0:
|
|
raise RuntimeError(f"No valid triton configs. {type(exc).__name__}: {exc}")
|
|
self.launchers = launchers
|
|
|
|
def prepare_for_pickle(self) -> tuple[Any, Any, Any, Any, Any]:
|
|
"""Drop stuff from triton.JITFunction that does not pickle.
|
|
This must be called after precompile so that these things are no longer needed.
|
|
Returns a tuple of old values
|
|
"""
|
|
old_values = (
|
|
self.fn.fn,
|
|
self.fn.__globals__,
|
|
self.fn.used_global_vals,
|
|
self.fn.repr,
|
|
self.launchers,
|
|
)
|
|
self.fn.fn = None
|
|
self.fn.__globals__ = None
|
|
self.fn.used_global_vals = None
|
|
self.fn.repr = _ConstRepr(self.fn.repr(self.fn))
|
|
self.launchers = []
|
|
return old_values
|
|
|
|
def prepare_for_caching(self) -> None:
|
|
"""
|
|
Statically Launched CUDA Kernels have a raw cubin on them
|
|
        that we don't need to store in the cache (since TritonBundler handles the collection for us)
|
|
"""
|
|
for result in self.compile_results:
|
|
if isinstance(result, StaticTritonCompileResult):
|
|
# Don't save this in the inductor cache, as it is very large
|
|
result.kernel.cubin_raw = None
|
|
|
|
def __getstate__(self) -> dict[str, Any]:
|
|
        assert not self.launchers, (
            "pickle should not be called after make_launchers()"
        )
|
|
return {
|
|
**self.__dict__,
|
|
"lock": None,
|
|
}
|
|
|
|
def __setstate__(self, state: dict[str, Any]) -> None:
|
|
self.__dict__.update(state)
|
|
self.lock = threading.Lock()
|
|
|
|
def get_device_interface(self):
|
|
# this code cannot run in compile workers, because it imports from torch
|
|
from torch._dynamo.device_interface import get_interface_for_device
|
|
|
|
return get_interface_for_device(self.device_props.type.replace("hip", "cuda"))
|
|
|
|
def _precompile_config(self, cfg: Config) -> CompileResult[_KernelType]:
|
|
"""Ahead of time compile a given autotuner config."""
|
|
compile_meta = copy.deepcopy(self.triton_meta)
|
|
cfg_kwargs = cfg.kwargs
|
|
if self.device_props.type == "hip":
|
|
cfg_kwargs = {**cfg_kwargs}
|
|
for k in ("matrix_instr_nonkdim", "waves_per_eu", "kpack"):
|
|
if k in cfg_kwargs:
|
|
compile_meta[k] = cfg_kwargs.pop(k)
|
|
compile_meta["constants"].update(cfg_kwargs)
|
|
for i in self.fn.constexprs:
|
|
arg_name = self.fn.arg_names[i]
|
|
if arg_name not in compile_meta["constants"] and (
|
|
arg_name == "num_warps" or arg_name == "num_stages"
|
|
):
|
|
compile_meta["constants"][arg_name] = getattr(cfg, arg_name)
|
|
compile_meta["num_warps"] = cfg.num_warps
|
|
compile_meta["num_stages"] = cfg.num_stages
|
|
if HAS_WARP_SPEC:
|
|
compile_meta["num_consumer_groups"] = getattr(cfg, "num_consumer_groups", 0)
|
|
compile_meta["num_buffers_warp_spec"] = getattr(
|
|
cfg, "num_buffers_warp_spec", 0
|
|
)
|
|
compile_meta["debug"] = self.inductor_meta.get(
|
|
"assert_indirect_indexing", True
|
|
) and not self.inductor_meta.get("is_hip", False)
|
|
|
|
# device type will be "hip" rather than "cuda" here
|
|
compile_meta["device_type"] = self.device_props.type
|
|
compile_meta["cc"] = self.device_props.cc
|
|
|
|
if self.device_props.type == "cpu":
|
|
triton_helpers.set_driver_to_cpu()
|
|
else:
|
|
triton_helpers.set_driver_to_gpu()
|
|
|
|
if not ASTSource:
|
|
raise RuntimeError("Installed triton version too old, please upgrade")
|
|
|
|
compile_args = (
|
|
ASTSource(
|
|
self.fn,
|
|
compile_meta["signature"],
|
|
compile_meta["constants"],
|
|
compile_meta["configs"][0],
|
|
),
|
|
)
|
|
|
|
if self.device_props.type == "mtia":
|
|
from mtia.host_runtime.torch_mtia.acc_flags import ( # type: ignore[import-not-found]
|
|
build_codename,
|
|
)
|
|
|
|
arch = build_codename()
|
|
else:
|
|
arch = compile_meta["cc"]
|
|
|
|
target = GPUTarget(
|
|
compile_meta["device_type"],
|
|
arch,
|
|
cc_warp_size(compile_meta["cc"]),
|
|
)
|
|
|
|
options = {
|
|
"num_warps": compile_meta["num_warps"],
|
|
"num_stages": compile_meta["num_stages"],
|
|
"debug": compile_meta["debug"],
|
|
"sanitize_overflow": False, # turn off additional asserts added for overflow checks
|
|
}
|
|
if HAS_WARP_SPEC:
|
|
options.update(
|
|
{
|
|
"num_consumer_groups": compile_meta.get("num_consumer_groups", 0),
|
|
"num_buffers_warp_spec": compile_meta.get(
|
|
"num_buffers_warp_spec", 0
|
|
),
|
|
}
|
|
)
|
|
if self.device_props.type == "hip":
|
|
if "waves_per_eu" in compile_meta:
|
|
options["waves_per_eu"] = compile_meta["waves_per_eu"]
|
|
if "matrix_instr_nonkdim" in compile_meta:
|
|
options["matrix_instr_nonkdim"] = compile_meta["matrix_instr_nonkdim"]
|
|
compile_kwargs = {
|
|
"target": target,
|
|
"options": options,
|
|
}
|
|
|
|
try:
|
|
binary = triton.compile(*compile_args, **compile_kwargs)
|
|
except Exception:
|
|
log.exception(
|
|
"Triton compilation failed: %s\n%s\nmetadata: %s",
|
|
self.inductor_meta.get("kernel_name", "triton_"),
|
|
self.fn.src,
|
|
compile_meta,
|
|
)
|
|
raise
|
|
TritonBundler.put(
|
|
triton_hash_to_path_key(binary.hash), self.triton_meta.get("device", 0)
|
|
)
|
|
# If the binary has a cubin file to directly launch, save it on the binary
|
|
static_launcher = StaticTritonCompileResult.can_statically_launch(
|
|
binary, self.inductor_meta, self.triton_meta, self.heuristic_type
|
|
)
|
|
|
|
if static_launcher is not None:
|
|
result = StaticTritonCompileResult(
|
|
static_launcher, cfg, compile_meta, self.inductor_meta
|
|
)
|
|
return result
|
|
|
|
return TritonCompileResult(binary, cfg, compile_meta, self.inductor_meta)
|
|
|
|
def _get_args_with_constexprs(self, args, launcher):
|
|
"""
|
|
`args` is passed in with only the non-constexpr args (because the constexpr arg values
|
|
depend on the config). However, in later triton versions, the constexpr args need to be
|
|
added into the args list.
|
|
"""
|
|
if triton_version_uses_attrs_dict():
|
|
# first: aggregate the constexpr args in (index, val) pairs
|
|
# so we can sort them by index.
|
|
constexpr_args: list[tuple[int, Any]] = []
|
|
for arg_name, arg_val in launcher.config.kwargs.items():
|
|
if arg_name in self.fn.arg_names:
|
|
constexpr_args.append((self.fn.arg_names.index(arg_name), arg_val))
|
|
|
|
constexpr_args.sort()
|
|
new_args = [*args]
|
|
for arg_idx, arg_val in constexpr_args:
|
|
new_args.insert(arg_idx, arg_val)
|
|
|
|
return new_args
|
|
return args
|
|
|
|
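    # Illustrative note (added for exposition): if self.fn.arg_names were
    # ["in_ptr0", "out_ptr0", "xnumel", "XBLOCK"] and the chosen config had
    # kwargs {"XBLOCK": 1024}, then on Triton versions that expect constexprs in the
    # argument list, _get_args_with_constexprs((in_ptr0, out_ptr0, xnumel), launcher)
    # would return [in_ptr0, out_ptr0, xnumel, 1024], inserting 1024 at index 3.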
def bench(self, launcher, *args, with_profiler=False, **kwargs):
|
|
"""Measure the performance of a given launcher"""
|
|
# we don't skip configs with spilled registers when auto-tuning custom
|
|
# (user-written) Triton kernels, as (i) we don't have any knowledge or
|
|
# control over the kernel code; (ii) there is empirical evidence that
|
|
# for some (complicated) custom Triton kernels, a register-spilling
|
|
# config may yield the best latency.
|
|
if not self.custom_kernel and launcher.n_spills > self.inductor_meta.get(
|
|
"spill_threshold", 16
|
|
):
|
|
log.debug(
|
|
"Skip config %s because of register spilling: %d",
|
|
launcher.config,
|
|
launcher.n_spills,
|
|
)
|
|
return float("inf")
|
|
|
|
device_interface = self.get_device_interface()
|
|
stream = device_interface.get_raw_stream(device_interface.current_device())
|
|
|
|
cpu_copies = self.copy_args_to_cpu_if_needed(*args, **kwargs)
|
|
|
|
def kernel_call():
|
|
cloned_args, cloned_kwargs = self.maybe_clone_args(
|
|
cpu_copies, *args, **kwargs
|
|
)
|
|
# reset to zero before evaluating any config
|
|
self.reset_to_zero_args(*args, **kwargs)
|
|
args_with_constexprs = self._get_args_with_constexprs(cloned_args, launcher)
|
|
if autograd_profiler._is_profiler_enabled:
|
|
profiler_kwargs = self.get_profiler_kwargs(stream, launcher)
|
|
with torch._C._profiler._RecordFunctionFast(
|
|
self.inductor_meta.get("kernel_name", "triton kernel"),
|
|
args_with_constexprs,
|
|
profiler_kwargs,
|
|
):
|
|
launcher(
|
|
*args_with_constexprs,
|
|
**cloned_kwargs,
|
|
stream=stream,
|
|
)
|
|
|
|
else:
|
|
launcher(
|
|
*args_with_constexprs,
|
|
**cloned_kwargs,
|
|
stream=stream,
|
|
)
|
|
self.restore_args_from_cpu(cpu_copies)
|
|
|
|
# only use profiler when not already in a profiler instance
|
|
if with_profiler and not autograd_profiler._is_profiler_enabled:
|
|
from torch._inductor.utils import do_bench_using_profiling
|
|
|
|
return do_bench_using_profiling(kernel_call, warmup=10, rep=40)
|
|
|
|
if self.device_props.type == "cpu":
|
|
return benchmarker.benchmark_cpu(kernel_call)
|
|
|
|
return benchmarker.benchmark_gpu(kernel_call, rep=40)
|
|
|
|
def copy_args_to_cpu_if_needed(self, *args, **kwargs):
|
|
"""
|
|
To support benchmarking in the presence of mutated args, we need to avoid
|
|
        autotuning contaminating them. We try to pass cloned args to the kernel.
|
|
If those clones would increase the peak memory usage, however, we instead
|
|
copy to cpu and restore them after each iteration. Figure out the args
|
|
to be copied and do the copying.
|
|
"""
|
|
if not self.optimize_mem:
|
|
return {}
|
|
|
|
copies = {}
|
|
budget = torch.cuda.max_memory_allocated() - torch.cuda.memory_allocated()
|
|
|
|
def maybe_copy(name, arg):
|
|
if name in self.mutated_arg_names and arg.is_cuda:
|
|
nonlocal budget
|
|
assert isinstance(arg, torch.Tensor)
|
|
required_storage_length = compute_required_storage_length(
|
|
arg.size(),
|
|
arg.stride(),
|
|
0,
|
|
)
|
|
size = required_storage_length * arg.element_size()
|
|
if size > budget:
|
|
cpu_arg = torch.empty_strided(
|
|
(required_storage_length,),
|
|
(1,),
|
|
dtype=arg.dtype,
|
|
device="cpu",
|
|
pin_memory=True,
|
|
)
|
|
cpu_arg.copy_(
|
|
arg.as_strided((required_storage_length,), (1,)),
|
|
non_blocking=True,
|
|
)
|
|
copies[name] = (arg, cpu_arg)
|
|
else:
|
|
budget -= size
|
|
|
|
for name, arg in zip(self.fn.arg_names, args):
|
|
maybe_copy(name, arg)
|
|
|
|
for name, arg in kwargs.items():
|
|
maybe_copy(name, arg)
|
|
|
|
return copies
|
|
|
|
def restore_args_from_cpu(self, cpu_copies):
|
|
for pair in cpu_copies.values():
|
|
arg, cpu_arg = pair
|
|
required_storage_length = compute_required_storage_length(
|
|
arg.size(),
|
|
arg.stride(),
|
|
0,
|
|
)
|
|
arg.as_strided((required_storage_length,), (1,)).copy_(
|
|
cpu_arg, non_blocking=True
|
|
)
|
|
|
|
def reset_to_zero_args(self, *args, **kwargs):
|
|
if not self.reset_to_zero_arg_names:
|
|
return
|
|
for i, arg in enumerate(args):
|
|
if self.fn.arg_names[i] in self.reset_to_zero_arg_names:
|
|
assert isinstance(
|
|
arg,
|
|
torch.Tensor,
|
|
), (
|
|
"self.reset_to_zero_arg_names should only contain valid argument names"
|
|
)
|
|
arg.zero_()
|
|
|
|
for name, arg in kwargs.items():
|
|
if name in self.reset_to_zero_arg_names:
|
|
assert isinstance(
|
|
arg,
|
|
torch.Tensor,
|
|
), (
|
|
"self.reset_to_zero_arg_names should only contain valid argument names"
|
|
)
|
|
arg.zero_()
|
|
|
|
def maybe_clone_args(
|
|
self, exclude: Container[str], *args, **kwargs
|
|
) -> tuple[list[Any], dict[str, Any]]:
|
|
"""
|
|
Prepare new args and kwargs by cloning any in-place buffers
|
|
(that are not in the provided exclusion list), to avoid autotune
|
|
contaminating them. Avoid cloning the other buffers because it
|
|
leads to increased memory usage.
|
|
"""
|
|
from ..compile_fx import clone_preserve_strides
|
|
|
|
def prepare_arg(name, arg):
|
|
if name in self.mutated_arg_names and name not in exclude:
|
|
assert isinstance(arg, torch.Tensor)
|
|
return clone_preserve_strides(arg)
|
|
else:
|
|
return arg
|
|
|
|
cloned_args = [
|
|
prepare_arg(name, arg)
|
|
for name, arg in itertools.zip_longest(self.fn.arg_names[: len(args)], args)
|
|
]
|
|
cloned_kwargs = {name: prepare_arg(name, arg) for name, arg in kwargs.items()}
|
|
return cloned_args, cloned_kwargs
|
|
|
|
def clone_args(self, *args, **kwargs) -> tuple[list[Any], dict[str, Any]]:
|
|
return self.maybe_clone_args(OrderedSet(), *args, **kwargs)
|
|
|
|
def benchmark_all_configs(self, *args, **kwargs):
|
|
with (
|
|
dynamo_timed(
|
|
"CachingAutotuner.benchmark_all_configs",
|
|
log_pt2_compile_event=True,
|
|
metadata={"kernel_name": self.inductor_meta.get("kernel_name")},
|
|
dynamo_compile_column_us="runtime_triton_autotune_time_us",
|
|
compile_id=self.compile_id,
|
|
is_backward=self.is_backward,
|
|
log_waitcounter=True,
|
|
waitcounter_name_override="triton_autotuner",
|
|
),
|
|
# Temporarily disable due to spam
|
|
# compilation_callback.callback_handler.install_callbacks(
|
|
# compilation_callback.CallbackTrigger.TRITON_AUTOTUNING,
|
|
# str(self.compile_id),
|
|
# ),
|
|
):
|
|
timings = {
|
|
launcher: self.bench(launcher, *args, **kwargs)
|
|
for launcher in self.launchers
|
|
}
|
|
|
|
for k, v in timings.items():
|
|
self.coordesc_tuner.cache_benchmark_result(k.config, v)
|
|
|
|
if log.isEnabledFor(logging.DEBUG):
|
|
log.debug("Benchmark all input configs for %s, get:", self.fn.__name__)
|
|
for k, v in timings.items():
|
|
log.debug(
|
|
"%s: %f, nreg %d, nspill %d, #shared-mem %s",
|
|
k.config,
|
|
v,
|
|
k.n_regs,
|
|
k.n_spills,
|
|
k.shared,
|
|
)
|
|
|
|
self.reset_to_zero_args(*args, **kwargs)
|
|
return timings
|
|
|
|
def autotune_to_one_config(self, *args, **kwargs):
|
|
"""Do the actual autotuning"""
|
|
start_time = time.time_ns()
|
|
timings = self.benchmark_all_configs(*args, **kwargs)
|
|
benchmark_time_taken_ns = time.time_ns() - start_time
|
|
self.launchers = [builtins.min(timings, key=timings.get)]
|
|
self.autotune_time_taken_ns = (
|
|
self.precompile_time_taken_ns + benchmark_time_taken_ns
|
|
)
|
|
|
|
# log the best config
|
|
launcher = self.launchers[0]
|
|
log.debug(
|
|
"Best config for %s: %s: %f, nreg %d, nspill %d, #shared-mem %s",
|
|
self.fn.__name__,
|
|
launcher.config,
|
|
timings[launcher],
|
|
launcher.n_regs,
|
|
launcher.n_spills,
|
|
launcher.shared,
|
|
)
|
|
|
|
if self.save_cache_hook:
|
|
self.save_cache_hook(
|
|
launcher.config,
|
|
self.autotune_time_taken_ns,
|
|
triton_cache_hash=launcher.cache_hash,
|
|
)
|
|
|
|
def save_gpu_kernel(self, stream, launcher):
|
|
key = self.inductor_meta.get("kernel_name", None) # unique kernel name
|
|
assert key is not None, "kernel_name can not be None"
|
|
params = {
|
|
"mangled_name": (
|
|
launcher.bin.metadata.name
|
|
if hasattr(launcher.bin.metadata, "name")
|
|
else launcher.bin.metadata["name"]
|
|
),
|
|
"num_warps": (
|
|
launcher.bin.num_warps
|
|
if hasattr(launcher.bin, "num_warps")
|
|
else launcher.bin.metadata.num_warps
|
|
),
|
|
"shared_mem": (
|
|
launcher.bin.shared
|
|
if hasattr(launcher.bin, "shared")
|
|
else launcher.bin.metadata.shared
|
|
),
|
|
"stream": stream,
|
|
# User defined triton kernels will have arbitrary kwarg names
|
|
"config": config_to_dict(launcher.config),
|
|
"inductor_meta": self.inductor_meta,
|
|
"triton_meta": self.triton_meta,
|
|
"def_args": launcher.def_args,
|
|
"call_args": launcher.call_args,
|
|
"global_scratch": launcher.global_scratch,
|
|
}
|
|
from torch._inductor.codecache import CudaKernelParamCache
|
|
|
|
bin_type = {"hip": "hsaco", "xpu": "spv"}.get(self.device_props.type, "cubin")
|
|
binary = launcher.bin.asm[bin_type]
|
|
# Also store asm code which can be used for debugging and generating cpp package
|
|
asm_type = {"hip": "amdgcn", "cuda": "ptx", "xpu": "spv"}.get(
|
|
self.device_props.type, None
|
|
)
|
|
asm = launcher.bin.asm.get(asm_type, None)
|
|
|
|
CudaKernelParamCache.set(key, params, binary, bin_type, asm, asm_type)
|
|
self.cuda_kernel_saved = True
|
|
|
|
def coordinate_descent_tuning(self, launcher, *args, **kwargs):
|
|
"""
|
|
Coordinate descent tuning can be run with or without max-autotune.
|
|
|
|
The only difference between these two is the starting config for coordinate_descent tuning.
|
|
        E.g., assuming regular autotune only gets one config C1, while max-autotune gets 4 configs C1, C2, C3, C4
        and max-autotune figures out that C3 is the best.

        Then if coordinate descent tuning is run with max-autotune disabled, it will start from C1;
        while if coordinate descent tuning is run with max-autotune enabled, it will start from C3.
|
|
"""
|
|
if (
|
|
self.heuristic_type == HeuristicType.TEMPLATE
|
|
or self.heuristic_type == HeuristicType.USER_AUTOTUNE
|
|
):
|
|
# skip triton template
|
|
return launcher
|
|
|
|
config2launcher = {launcher.config: launcher}
|
|
|
|
# TODO: should we just load the kernels ahead of time if we know we're going to call this?
|
|
if self.fn.fn is None:
|
|
"""
|
|
We are in the parent process, while this program was compiled in a worker
|
|
and the fn was dropped in prepare_for_pickle(). We haven't loaded the module
|
|
containing the real fn yet.
|
|
"""
|
|
assert hasattr(self, "_reload_kernel")
|
|
assert callable(self._reload_kernel)
|
|
self.fn = self._reload_kernel().fn
|
|
|
|
def benchmark_one_config(config):
|
|
with self.lock:
|
|
launcher = self._precompile_config(config).make_launcher()
|
|
config2launcher[config] = launcher
|
|
|
|
out = self.bench(launcher, *args, **kwargs)
|
|
log.debug(
|
|
"COORDESC: %s: %f, nreg %d, nspill %d, #shared-mem %d",
|
|
launcher.config,
|
|
out,
|
|
launcher.n_regs,
|
|
launcher.n_spills,
|
|
launcher.shared,
|
|
)
|
|
return out
|
|
|
|
assert not (
|
|
self.heuristic_type == HeuristicType.PERSISTENT_REDUCTION
|
|
and "R0_BLOCK" in launcher.config.kwargs
|
|
), (
|
|
"Coordinate descent tuner relies on the assumption that persistent reduction's triton config does not have R0_BLOCK"
|
|
)
|
|
start_time = time.time_ns()
|
|
best_config = self.coordesc_tuner.autotune(
|
|
benchmark_one_config, launcher.config, None
|
|
)
|
|
coordesc_time_taken_ns = time.time_ns() - start_time
|
|
best_config.found_by_coordesc = True
|
|
|
|
if self.save_cache_hook:
|
|
self.save_cache_hook(
|
|
best_config,
|
|
self.autotune_time_taken_ns + coordesc_time_taken_ns,
|
|
found_by_coordesc=True,
|
|
)
|
|
|
|
if best_config not in config2launcher:
|
|
# On a Coordesc cache hit, we might not have loaded the launcher
|
|
# This can happen because PyCodeCache saves CachingAutotuners in memory,
|
|
# even for separate compile IDs (which can have different inputs without changing output code)
|
|
config2launcher[best_config] = self._precompile_config(
|
|
best_config
|
|
).make_launcher()
|
|
return config2launcher[best_config]
|
|
|
|
def get_profiler_kwargs(self, stream, launcher):
|
|
kernel_kwargs_str = ",".join(
|
|
f"{k}={v}" for (k, v) in launcher.config.kwargs.items()
|
|
)
|
|
|
|
ret = {
|
|
"kernel_file": (self.filename or ""),
|
|
"kernel_hash": self.kernel_hash,
|
|
"kernel_backend": "triton",
|
|
"stream": stream,
|
|
"num_warps": launcher.config.num_warps,
|
|
"num_stages": launcher.config.num_stages,
|
|
"kernel_kwargs": kernel_kwargs_str,
|
|
}
|
|
if "kernel_name" in self.inductor_meta:
|
|
ret["kernel_name"] = self.inductor_meta["kernel_name"]
|
|
if "kernel_flop" in self.inductor_meta:
|
|
ret["kernel_flop"] = self.inductor_meta["kernel_flop"]
|
|
if "kernel_num_gb" in self.inductor_meta:
|
|
ret["kernel_num_gb"] = self.inductor_meta["kernel_num_gb"]
|
|
return ret
|
|
|
|
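    # Illustrative example (added for exposition) of the metadata dict built above,
    # with made-up values; these fields (including the optional kernel_flop /
    # kernel_num_gb estimates) are attached to the profiler event for the launch:
    #   {"kernel_file": "/tmp/torchinductor/.../abc123.py", "kernel_hash": "abc123",
    #    "kernel_backend": "triton", "stream": 7, "num_warps": 4, "num_stages": 2,
    #    "kernel_kwargs": "XBLOCK=1024", "kernel_name": "triton_poi_fused_add_0",
    #    "kernel_flop": 0, "kernel_num_gb": 0.008}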
def run(
|
|
self,
|
|
*args,
|
|
stream,
|
|
benchmark_run=False,
|
|
**kwargs,
|
|
): # type:ignore[override]
|
|
if hasattr(triton, "set_allocator"):
|
|
|
|
def alloc_fn(size: int, align: int, stream: Optional[int]):
|
|
return torch.empty(
|
|
size, dtype=torch.int8, device=self.device_props.type
|
|
)
|
|
|
|
triton.set_allocator(alloc_fn)
|
|
|
|
if self.triton_interpret:
|
|
args, grid = self._interpret_args_grid(args, self.configs[0])
|
|
return self.fn[grid](
|
|
*args,
|
|
**kwargs,
|
|
**self.configs[0].kwargs,
|
|
)
|
|
|
|
if len(self.launchers) != 1:
|
|
if len(self.launchers) == 0:
|
|
start_time = time.time_ns()
|
|
self.precompile()
|
|
self.precompile_time_taken_ns = time.time_ns() - start_time
|
|
if len(self.launchers) > 1:
|
|
self.autotune_to_one_config(*args, **kwargs)
|
|
|
|
if not getattr(
|
|
self.launchers[0].config, "found_by_coordesc", False
|
|
) and self.inductor_meta.get("coordinate_descent_tuning", False):
|
|
self.launchers = [
|
|
self.coordinate_descent_tuning(self.launchers[0], *args, **kwargs)
|
|
]
|
|
|
|
(launcher,) = self.launchers
|
|
if launcher.store_cubin and (not benchmark_run or not self.cuda_kernel_saved):
|
|
self.save_gpu_kernel(stream, launcher)
|
|
|
|
args = self._get_args_with_constexprs(args, launcher)
|
|
|
|
if self.dump_launch_params:
|
|
new_args, grid = self._interpret_args_grid(args, launcher.config)
|
|
_dump_launch_params(new_args, kwargs, launcher, self.fn.__name__, grid)
|
|
|
|
        # Checking the profiler flag directly is faster than entering and exiting
        # a context manager, even if the context manager is a nullcontext.
|
|
if autograd_profiler._is_profiler_enabled:
|
|
profiler_kwargs = self.get_profiler_kwargs(stream, launcher)
|
|
|
|
with torch._C._profiler._RecordFunctionFast(
|
|
self.inductor_meta.get("kernel_name", "triton kernel"),
|
|
args,
|
|
profiler_kwargs,
|
|
):
|
|
return launcher(
|
|
*args,
|
|
**kwargs,
|
|
stream=stream,
|
|
)
|
|
else:
|
|
return launcher(
|
|
*args,
|
|
**kwargs,
|
|
stream=stream,
|
|
)
|
|
|
|
def _interpret_args_grid(
|
|
self, args: tuple[Any, ...], cfg: Config
|
|
) -> tuple[tuple[Any, ...], tuple[int, int, int]]:
|
|
grid = GridExpr.from_meta(self.inductor_meta, cfg).eval_slow(
|
|
dict(
|
|
zip(
|
|
[
|
|
*self.triton_meta["signature"].keys(),
|
|
*self.inductor_meta.get("extra_launcher_args", ()),
|
|
],
|
|
args,
|
|
)
|
|
)
|
|
)
|
|
if self.inductor_meta.get("extra_launcher_args"):
|
|
args = args[: -len(self.inductor_meta["extra_launcher_args"])]
|
|
return args, grid
|
|
|
|
|
|
class _ConstRepr:
|
|
def __init__(self, value: str):
|
|
self.value = value
|
|
|
|
def __call__(self, _=None) -> str:
|
|
return self.value
|
|
|
|
|
|
class CompileResult(Generic[_T]):
|
|
def __init__(
|
|
self,
|
|
kernel: _T,
|
|
config: Config,
|
|
compile_meta: dict[str, Any],
|
|
inductor_meta: dict[str, Any],
|
|
):
|
|
self.kernel = kernel
|
|
self.config = config
|
|
self.compile_meta = compile_meta
|
|
self.inductor_meta = inductor_meta
|
|
|
|
def make_launcher(self) -> LauncherType: ...
|
|
|
|
def _gen_launcher_code(self, scope, def_args, runner_args) -> LauncherType:
|
|
grid = GridExpr.from_meta(self.inductor_meta, self.config)
|
|
# grid.prefix is usually empty, grid.x_grid is something like `-(xnumel//-1024)`
|
|
lines = [
|
|
f"def launcher({', '.join(def_args)}, stream):",
|
|
*[f" {line}" for line in grid.prefix],
|
|
f" grid_0 = {grid.x_grid}",
|
|
f" grid_1 = {grid.y_grid}",
|
|
f" grid_2 = {grid.z_grid}",
|
|
f" runner({', '.join(runner_args)})",
|
|
]
|
|
launcher_code = "\n".join(lines)
|
|
exec(launcher_code, scope)
|
|
return scope["launcher"]
|
|
|
|
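    # Illustrative sketch (added for exposition) of the source generated above; the
    # grid expressions and the exact runner argument list depend on the config and
    # the caller-provided runner_args:
    #   def launcher(in_ptr0, out_ptr0, xnumel, stream):
    #       grid_0 = -(xnumel // -1024)
    #       grid_1 = 1
    #       grid_2 = 1
    #       runner(...)  # forwards the args plus grid_0/1/2 and stream to the kernel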
def _get_arg_lists(
|
|
self, arg_names, constexprs
|
|
) -> tuple[list[str], list[str], OrderedSet[str]]:
|
|
"""
|
|
Return a bunch of intermediate lists of args needed for generating
|
|
launcher code.
|
|
"""
|
|
compile_meta = self.compile_meta
|
|
cfg = self.config
|
|
known_constants = OrderedSet(
|
|
arg for i, arg in enumerate(arg_names) if i in constexprs
|
|
)
|
|
|
|
"""
|
|
https://github.com/pytorch/pytorch/issues/115344
|
|
|
|
self.fn.constexprs doesn't properly deal with None args, so when we filter out
|
|
an arg in UserDefinedTritonKernel.codegen, we need to filter it here as well.
|
|
We also don't want to modify self.fn.
|
|
|
|
We know that we removed something from the signature if:
|
|
1. It's in compile_meta["constants"]
|
|
2. It isn't a constant we already know about
|
|
Note: The value of interest has already been added to compile_meta['constants'],
|
|
so we use self.fn.constexprs instead.
|
|
3. It isn't in the compile_meta signature
|
|
"""
|
|
none_args = OrderedSet(
|
|
k
|
|
for k, v in compile_meta["constants"].items()
|
|
if v is None and k not in known_constants
|
|
)
|
|
none_args = none_args.difference(OrderedSet(compile_meta["signature"].keys()))
|
|
|
|
if triton_version_uses_attrs_dict():
|
|
call_args = arg_names
|
|
def_args = arg_names
|
|
if (
|
|
"num_warps" in compile_meta["constants"]
|
|
or "num_stages" in compile_meta["constants"]
|
|
):
|
|
# num_warps/num_stages are special implicit args that are not in the signature
|
|
# see test_triton_kernel_special_params
|
|
def_args = [
|
|
arg for arg in def_args if arg not in ("num_warps", "num_stages")
|
|
]
|
|
repl = {
|
|
k: str(compile_meta["constants"].get(k))
|
|
for k in ("num_warps", "num_stages")
|
|
}
|
|
call_args = [repl.get(arg, arg) for arg in call_args]
|
|
else:
|
|
call_args = [
|
|
arg
|
|
for i, arg in enumerate(arg_names)
|
|
if i not in constexprs and arg not in none_args
|
|
]
|
|
cfg_dict = config_to_dict(cfg)
|
|
def_args = [
|
|
name
|
|
for name in arg_names
|
|
if name not in cfg_dict and name not in none_args
|
|
]
|
|
|
|
if "extra_launcher_args" in self.inductor_meta:
|
|
def_args = [*def_args, *self.inductor_meta["extra_launcher_args"]]
|
|
|
|
return call_args, def_args, none_args
|
|
|
|
|
|
class CannotStaticallyLaunchKernel(Exception):
|
|
pass
|
|
|
|
|
|
class StaticTritonCompileResult(CompileResult[StaticallyLaunchedCudaKernel]):
|
|
"""
|
|
TritonCompileResult that uses StaticCudaLauncher,
|
|
which vastly simplifies the setup and metadata needed to be kept.
|
|
"""
|
|
|
|
@staticmethod
|
|
def can_statically_launch(
|
|
kernel: CompiledKernel,
|
|
inductor_meta: dict[str, Any],
|
|
triton_meta: dict[str, Any],
|
|
heuristic_type: HeuristicType,
|
|
) -> Optional[StaticallyLaunchedCudaKernel]:
|
|
if not torch._inductor.config.use_static_cuda_launcher:
|
|
return None
|
|
|
|
def check_can_launch() -> StaticallyLaunchedCudaKernel:
|
|
if triton_meta.get("device_type", None) != "cuda":
|
|
# Only cuda kernels
|
|
raise CannotStaticallyLaunchKernel("Non-cuda device")
|
|
|
|
if torch._inductor.config.cpp_wrapper:
|
|
# If we're running with cpp wrapper, it doesn't
|
|
# make sense to statically compile since everything
|
|
# is codegenned anyway
|
|
raise CannotStaticallyLaunchKernel("Cpp wrapper enabled")
|
|
|
|
if (
|
|
heuristic_type == HeuristicType.USER_AUTOTUNE
|
|
and not torch._inductor.config.static_launch_user_defined_triton_kernels
|
|
):
|
|
# Don't support user defined triton kernels yet
|
|
raise CannotStaticallyLaunchKernel("User defined triton kernel")
|
|
|
|
if inductor_meta.get("store_cubin", None):
|
|
# Requires storing the entire binary
|
|
raise CannotStaticallyLaunchKernel("store_cubin is enabled")
|
|
|
|
cubin_location = os.path.join(
|
|
triton_cache_dir(triton_meta.get("device", 0)),
|
|
triton_hash_to_path_key(kernel.hash),
|
|
f"{kernel.src.fn.__name__}.cubin",
|
|
)
|
|
|
|
if not os.path.exists(cubin_location):
|
|
raise CannotStaticallyLaunchKernel(
|
|
f"Cubin path not found: {cubin_location}"
|
|
)
|
|
|
|
else:
|
|
kernel._cubin_path = cubin_location
|
|
|
|
try:
|
|
static_kernel = StaticallyLaunchedCudaKernel(kernel)
|
|
except NotImplementedError as e:
|
|
raise CannotStaticallyLaunchKernel(f"NotImplemented: {str(e)}") from e
|
|
|
|
return static_kernel
|
|
|
|
try:
|
|
result = check_can_launch()
|
|
return result
|
|
except CannotStaticallyLaunchKernel as e:
|
|
log.info("Bypassing StaticallyLaunchedCudaKernel due to %s", str(e))
|
|
if torch._inductor.config.strict_static_cuda_launcher:
|
|
raise e
|
|
return None
|
|
|
|
def reload_cubin_path(self):
|
|
"""
|
|
When loading from cache on disk, we want to reload cubin
|
|
files from their appropriate location on disc.
|
|
"""
|
|
cubin_location = os.path.join(
|
|
triton_cache_dir(self.compile_meta.get("device", 0)),
|
|
triton_hash_to_path_key(self.kernel.hash),
|
|
f"{self.kernel.name}.cubin",
|
|
)
|
|
if not os.path.exists(cubin_location):
|
|
if self.kernel.cubin_raw is not None:
|
|
                # We saved the raw cubin, so write it to the appropriate location
|
|
self.kernel.reload_cubin_from_raw(cubin_location)
|
|
else:
|
|
                raise RuntimeError(
                    f"Cubin file saved by TritonBundler not found at {cubin_location}"
                )
|
|
self.kernel.cubin_path = cubin_location
|
|
|
|
def make_launcher(self) -> LauncherType:
|
|
# If at least one static make_launcher call occurs,
|
|
# we're sure static cuda launcher was used for this compile
|
|
set_feature_use("static_cuda_launcher", True)
|
|
# Load the binary on the parent
|
|
if not self.kernel.cubin_path:
|
|
self.reload_cubin_path()
|
|
device = self.compile_meta.get("device", 0)
|
|
if device is None:
|
|
device = 0
|
|
self.kernel.load_kernel(device)
|
|
scope = {
|
|
"runner": self.kernel.run,
|
|
}
|
|
|
|
        # NOTE: Constexpr handling for triton and the static cuda launcher
        #
        # Triton kernels have two types of constexprs: *declared* ones, which the user
        # has explicitly declared as tl.constexpr, and *implied* ones, which are expressions
        # triton deems constant while compiling/analyzing the code (e.g. unused parameters).
        #
        # Triton kernels handle constexprs slightly differently depending on which version of
        # triton we care about (we support 3.2.0 and 3.3.0):
        # in 3.2.0, triton kernels do not require passing any declared constexprs into the kernel;
        # in 3.3.0, triton kernels require all declared constexprs to be passed into the kernel,
        # where they are subsequently ignored.
        # When statically launching, since we're launching from the triton-generated cubin, we
        # always want to get rid of all constexprs, declared or implied, since the underlying
        # cubin file has all of the constants stripped away anyway.
        #
        # But CachingAutotuner.run will pass us a different number of arguments depending on
        # whether we're on triton 3.2.0 or a later version, so we grab def_args with the same
        # logic as the (non-static) TritonCompileResult. We then generate call_args ourselves,
        # since we want only a subset of the arguments passed to triton.
        # Here, arg_names is exactly fn.src.arg_names and declared_constexprs is exactly
        # fn.src.constexprs, which matches the behavior of regular TritonCompileResult.
|
|
_, def_args, none_args = self._get_arg_lists(
|
|
self.kernel.arg_names, self.kernel.declared_constexprs
|
|
)
|
|
|
|
call_args = [
|
|
arg
|
|
for i, arg in enumerate(self.kernel.arg_names)
|
|
if i not in self.kernel.full_constexprs and arg not in none_args
|
|
]
|
|
|
|
# StaticallyLaunchedCudaKernel.run takes in order grid_0, grid_1, grid_2, stream, and call_args
|
|
runner_args = ["grid_0", "grid_1", "grid_2", "stream", *call_args]
|
|
launcher = self._gen_launcher_code(scope, def_args, runner_args)
|
|
launcher.config = self.config # type: ignore[attr-defined]
|
|
launcher.n_regs = self.kernel.n_regs # type: ignore[attr-defined]
|
|
launcher.n_spills = self.kernel.n_spills # type: ignore[attr-defined]
|
|
launcher.shared = self.kernel.shared # type: ignore[attr-defined]
|
|
launcher.cache_hash = triton_hash_to_path_key(self.kernel.hash) # type: ignore[attr-defined]
|
|
launcher.store_cubin = False # type: ignore[attr-defined]
|
|
launcher._is_static = True # type: ignore[attr-defined]
|
|
return launcher
|
|
|
|
|
|
class TritonCompileResult(CompileResult[CompiledKernel]):
|
|
"""
|
|
Upstream Triton CompileKernel can not be pickled. This is a wrapper
|
|
to support serialization and generate the launcher function.
|
|
"""
|
|
|
|
@staticmethod
|
|
@functools.lru_cache(32)
|
|
def _kernel_metadata_cls(fields: tuple[str, ...]) -> Any:
|
|
return namedtuple("KernelMetadata", sorted(fields))
|
|
|
|
@staticmethod
|
|
def _serialize_metadata(metadata):
|
|
"""
|
|
Triton uses a nested class called KernelMetadata to store metadata information.
|
|
Pickle does not work well with nested namedtuples, as the namedtuple doesn't appear
|
|
in the toplevel namespace of the module. So these serialization/deser functions
|
|
are used to convert the namedtuples to a dict and back.
|
|
|
|
As for packed_metadata, depending on the triton backend, KernelMetadata can be
|
|
a namedtuple, or a regular tuple! So the serialization function branches on whether
|
|
the metadata to be serialized is a namedtuple or regular, serializable one.
|
|
"""
|
|
|
|
def is_namedtuple(obj) -> bool:
|
|
return (
|
|
isinstance(obj, tuple)
|
|
and hasattr(obj, "_asdict")
|
|
and hasattr(obj, "_fields")
|
|
)
|
|
|
|
if is_namedtuple(metadata):
|
|
return metadata._asdict()
|
|
else:
|
|
return metadata
|
|
|
|
@staticmethod
|
|
def _deserialize_metadata(metadata):
|
|
if isinstance(metadata, dict):
|
|
return TritonCompileResult._kernel_metadata_cls(tuple(metadata.keys()))(
|
|
**metadata
|
|
)
|
|
else:
|
|
return metadata
|
|
|
|
def __getstate__(self) -> dict[str, Any]:
|
|
kernel = self.kernel
|
|
# replace the fields that don't pickle nicely
|
|
kernel_state = {
|
|
**kernel.__dict__,
|
|
# See doc about serializing metadata above
|
|
"metadata": self._serialize_metadata(kernel.metadata),
|
|
"packed_metadata": self._serialize_metadata(
|
|
getattr(kernel, "packed_metadata", None)
|
|
),
|
|
"module": None, # regenerated by kernel._init_handles()
|
|
"function": None, # regenerated by kernel._init_handles()
|
|
"run": None, # regenerated by kernel._init_handles()
|
|
}
|
|
return {**self.__dict__, "kernel": kernel_state} # type: ignore[dict-item]
|
|
|
|
def __setstate__(self, state: dict[str, Any]) -> None:
|
|
# src = ASTSource.__new__(ASTSource)
|
|
# src.__setstate__(state["kernel"]["src"])
|
|
# TODO(jansel): need to fixup src.fn which is now None
|
|
kernel = CompiledKernel.__new__(CompiledKernel)
|
|
metadata = state["kernel"]["metadata"]
|
|
packed_metadata = state["kernel"]["packed_metadata"]
|
|
kernel.__dict__.update(
|
|
{
|
|
**state["kernel"],
|
|
# "src": src,
|
|
"metadata": self._deserialize_metadata(metadata),
|
|
"packed_metadata": self._deserialize_metadata(packed_metadata),
|
|
}
|
|
)
|
|
self.__dict__.update(state)
|
|
self.kernel = kernel
|
|
|
|
def make_launcher(self) -> LauncherType:
|
|
"""
|
|
Launching triton kernels is performance sensitive, we compile
|
|
a custom Python function get the grid() and reorder the args to
|
|
the underlying wrapper.
|
|
"""
|
|
cfg = self.config
|
|
compile_meta = self.compile_meta
|
|
binary = self.kernel
|
|
fn = binary.src.fn
|
|
binary._init_handles()
|
|
(call_args, def_args, none_args) = self._get_arg_lists(
|
|
fn.arg_names, fn.constexprs
|
|
)
|
|
binary_shared = (
|
|
binary.shared if hasattr(binary, "shared") else binary.metadata.shared
|
|
)
|
|
|
|
if knobs is None:
|
|
launch_enter = binary.__class__.launch_enter_hook
|
|
launch_exit = binary.__class__.launch_exit_hook
|
|
else:
|
|
launch_enter = knobs.runtime.launch_enter_hook
|
|
launch_exit = knobs.runtime.launch_exit_hook
|
|
|
|
import math as math_lib
|
|
|
|
import torch as torch_lib
|
|
|
|
scope = {
|
|
"grid_meta": cfg.kwargs,
|
|
"bin": binary,
|
|
"launch_enter_hook": launch_enter,
|
|
"launch_exit_hook": launch_exit,
|
|
"metadata": (
|
|
binary.packed_metadata
|
|
if hasattr(binary, "packed_metadata")
|
|
else binary.metadata
|
|
),
|
|
"shared": binary_shared,
|
|
"num_warps": (
|
|
binary.num_warps
|
|
if hasattr(binary, "num_warps")
|
|
else binary.metadata.num_warps
|
|
),
|
|
"cta_args": (
|
|
(
|
|
binary.num_ctas,
|
|
*get_first_attr(binary, "cluster_dims", "clusterDims"),
|
|
)
|
|
if hasattr(binary, "num_ctas")
|
|
else (
|
|
(binary.metadata.num_ctas, *binary.metadata.cluster_dims)
|
|
if hasattr(binary, "metadata")
|
|
else ()
|
|
)
|
|
),
|
|
"function": get_first_attr(binary, "function", "cu_function"),
|
|
"runner": get_first_attr(binary, "run", "c_wrapper"),
|
|
"math": math_lib,
|
|
"torch": torch_lib,
|
|
}
|
|
|
|
if not hasattr(binary, "launch_metadata"):
|
|
# launch args before CompiledKernel.launch_metadata is added.
|
|
# TODO(jansel): delete this branch in mid-2025
|
|
runner_args = [
|
|
"grid_0",
|
|
"grid_1",
|
|
"grid_2",
|
|
"num_warps",
|
|
"*cta_args",
|
|
"shared",
|
|
"stream",
|
|
"function",
|
|
"launch_enter_hook",
|
|
"launch_exit_hook",
|
|
"metadata",
|
|
*call_args,
|
|
]
|
|
else: # args after CompiledKernel.launch_metadata: https://github.com/triton-lang/triton/pull/3492
|
|
# Getting the kernel launch args is extremely perf-sensitive. Evaluating
|
|
# `bin.launch_metadata` is relatively expensive, and returns None unless a
|
|
# `launch_enter_hook` is installed. So if we don't have that hook installed,
|
|
            # we want to burn None into the launch args with zero overhead.
|
|
# See https://github.com/pytorch/pytorch/issues/123597
|
|
if launch_enter:
|
|
launch_metadata = f"bin.launch_metadata((grid_0, grid_1, grid_2), stream, {', '.join(call_args)})"
|
|
else:
|
|
launch_metadata = "None"
|
|
runner_args = [
|
|
"grid_0",
|
|
"grid_1",
|
|
"grid_2",
|
|
"stream",
|
|
"function",
|
|
"metadata",
|
|
launch_metadata,
|
|
"launch_enter_hook",
|
|
"launch_exit_hook",
|
|
*call_args,
|
|
]
|
|
|
|
launcher = self._gen_launcher_code(scope, def_args, runner_args)
|
|
|
|
|
|
launcher.config = cfg
|
|
launcher.n_regs = getattr(binary, "n_regs", None)
|
|
launcher.n_spills = getattr(binary, "n_spills", None)
|
|
launcher.shared = binary_shared
|
|
launcher.cache_hash = triton_hash_to_path_key(binary.hash)
|
|
launcher.store_cubin = self.inductor_meta.get("store_cubin", False)
|
|
# store this global variable to avoid the high overhead of reading it when calling run
|
|
if launcher.store_cubin:
|
|
launcher.fn = fn
|
|
launcher.bin = binary
|
|
if triton_version_uses_attrs_dict():
|
|
# arg filtering wasn't done above
|
|
cfg_dict = config_to_dict(cfg)
|
|
def_args = [x for x in def_args if x not in cfg_dict]
|
|
call_args = [
|
|
x
|
|
for x in call_args
|
|
if compile_meta["signature"].get(x, "constexpr") != "constexpr"
|
|
and x not in none_args
|
|
]
|
|
launcher.def_args = def_args
|
|
launcher.call_args = call_args
|
|
kernel_metadata = getattr(self.kernel, "metadata", None)
|
|
launcher.global_scratch = getattr(
|
|
kernel_metadata, "global_scratch_size", None
|
|
)
|
|
return launcher
|
|
|
|
|
|
def _find_names(obj):
|
|
import gc
|
|
import inspect
|
|
|
|
frame = inspect.currentframe()
|
|
while frame is not None:
|
|
        # Touching f_locals forces each frame to materialize its locals so that
        # gc.get_referrers below can find the dicts that reference obj.
        frame.f_locals
|
|
frame = frame.f_back
|
|
obj_names = []
|
|
for referrer in gc.get_referrers(obj):
|
|
if isinstance(referrer, dict):
|
|
for k, v in referrer.items():
|
|
if v is obj:
|
|
obj_names.append(k)
|
|
return obj_names
|
|
|
|
|
|
collected_calls: list[Any] = []
|
|
|
|
|
|
def start_graph():
|
|
collected_calls.clear()
|
|
|
|
|
|
def end_graph(output_file):
|
|
if len(collected_calls) == 0:
|
|
return
|
|
overall_time = sum(call[0] for call in collected_calls)
|
|
overall_gb = sum(call[1] for call in collected_calls)
|
|
cur_file = inspect.stack()[1].filename
|
|
summary_str = (
|
|
f"SUMMARY ({cur_file})\n"
|
|
f"{overall_time:.2f}ms \t {overall_gb:.2f} GB\t {overall_gb / (overall_time / 1e3):.2f}GB/s"
|
|
)
|
|
log.info(
|
|
"%s",
|
|
summary_str,
|
|
)
|
|
if output_file is not None:
|
|
# sort perf numbers in descending order, i.e. placing the
|
|
# most runtime-heavy kernels at the top of the list
|
|
sorted_calls = sorted(collected_calls, key=lambda c: float(c[0]), reverse=True)
|
|
try:
|
|
with open(output_file, "a") as file:
|
|
log.info(
|
|
"Save profile bandwidth results to %s",
|
|
output_file,
|
|
)
|
|
file.write("====================\n")
|
|
file.write(f"TRITON KERNELS BANDWIDTH INFO ({cur_file})\n")
|
|
for ms, num_gb, gb_per_s, kernel_name in sorted_calls:
|
|
# also display the runtime percentage for each kernel
|
|
percentage = f"{ms / overall_time * 100:.2f}%"
|
|
suffix = f" \t {percentage} \t {kernel_name}"
|
|
bw_info_str = create_bandwidth_info_str(
|
|
ms,
|
|
num_gb,
|
|
gb_per_s,
|
|
suffix=suffix,
|
|
color=False,
|
|
)
|
|
file.write(bw_info_str + "\n")
|
|
file.write(f"{summary_str}\n\n")
|
|
except Exception as e:
|
|
log.warning(
|
|
"failed to write profile bandwidth result into %s: %s",
|
|
output_file,
|
|
e,
|
|
)
|
|
|
|
|
|
class DebugAutotuner(CachingAutotuner):
|
|
def __init__(
|
|
self,
|
|
*args,
|
|
regex_filter="",
|
|
with_profiler=False,
|
|
with_bandwidth_info=True,
|
|
**kwargs,
|
|
):
|
|
self.regex_filter = regex_filter
|
|
self.with_profiler = with_profiler
|
|
self.with_bandwidth_info = with_bandwidth_info
|
|
super().__init__(*args, **kwargs)
|
|
self.cached = None
|
|
|
|
def run(self, *args, stream, **kwargs):
|
|
if not self.with_bandwidth_info:
|
|
super().run(*args, stream=stream, **kwargs, benchmark_run=True)
|
|
return
|
|
else:
|
|
possible_names = _find_names(self)
|
|
kernel_name = f"{max(possible_names, key=len)}"
|
|
if not re.match(self.regex_filter, kernel_name):
|
|
return
|
|
|
|
if len(self.launchers) != 1:
|
|
if len(self.launchers) == 0:
|
|
start_time = time.time_ns()
|
|
self.precompile()
|
|
self.precompile_time_taken_ns = time.time_ns() - start_time
|
|
if len(self.launchers) > 1:
|
|
self.autotune_to_one_config(*args, **kwargs)
|
|
(launcher,) = self.launchers
|
|
|
|
if launcher.store_cubin:
|
|
self.save_gpu_kernel(stream, launcher)
|
|
|
|
if self.cached is None:
|
|
ms = self.bench(launcher, *args, with_profiler=self.with_profiler)
|
|
num_in_out_ptrs = len(
|
|
[
|
|
arg_name
|
|
for arg_name in self.fn.arg_names
|
|
if arg_name.startswith("in_out_ptr")
|
|
]
|
|
)
|
|
num_gb = self.inductor_meta.get("kernel_num_gb", None)
|
|
if num_gb is None:
|
|
num_gb = get_num_bytes(*args, num_in_out_args=num_in_out_ptrs) / 1e9
|
|
gb_per_s = num_gb / (ms / 1e3)
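                # e.g. (a sketch with assumed numbers) 0.5 GB moved in 0.25 ms
                # gives 0.5 / (0.25 / 1e3) = 2000 GB/s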
|
|
self.cached = ms, num_gb, gb_per_s, kernel_name
|
|
collected_calls.append((ms, num_gb, gb_per_s, kernel_name))
|
|
log.info(
|
|
"%s",
|
|
create_bandwidth_info_str(
|
|
ms, num_gb, gb_per_s, suffix=f" \t {kernel_name}"
|
|
),
|
|
)
|
|
else:
|
|
                # in AOTI, the kernel has already been run and its timing info cached
|
|
collected_calls.append(self.cached)
|
|
|
|
|
|
def hash_configs(configs: list[Config]):
|
|
"""
|
|
Hash used to check for changes in configurations
|
|
"""
|
|
hasher = hashlib.sha256()
|
|
for cfg in configs:
|
|
hasher.update(
|
|
f"{sorted(cfg.kwargs.items())} {cfg.num_warps} {cfg.num_stages}\n".encode()
|
|
)
|
|
return hasher.hexdigest()
|
|
|
|
|
|
def cached_autotune(
|
|
size_hints: Optional[list[int]],
|
|
configs: list[Config],
|
|
triton_meta,
|
|
heuristic_type,
|
|
filename=None,
|
|
inductor_meta=None,
|
|
custom_kernel=False,
|
|
):
|
|
"""
|
|
A copy of triton.autotune that calls our subclass. Our subclass
|
|
has additional debugging, error handling, and on-disk caching.
|
|
"""
|
|
configs = unique_configs(configs)
|
|
assert len(configs) == 1 or filename
|
|
inductor_meta = {} if inductor_meta is None else inductor_meta
|
|
|
|
configs, autotune_cache, autotune_cache_info = check_autotune_cache(
|
|
configs, filename, inductor_meta
|
|
)
|
|
mutated_arg_names = inductor_meta.pop("mutated_arg_names", ())
|
|
optimize_mem = inductor_meta.pop("optimize_mem", True)
|
|
|
|
if "restore_value" in triton_meta:
|
|
mutated_arg_names += triton_meta.pop("restore_value")
|
|
|
|
reset_to_zero_arg_names: list[str] = []
|
|
if "reset_to_zero" in triton_meta:
|
|
reset_to_zero_arg_names.extend(triton_meta.pop("reset_to_zero"))
|
|
|
|
def decorator(fn):
|
|
# Remove XBLOCK from config if it's not a function argument.
|
|
# This way, coordinate descent tuning will not try to tune it.
|
|
#
|
|
# Context: When TritonKernel.no_x_dim is True, we hardcode XBLOCK to 1.
|
|
import inspect
|
|
|
|
if "XBLOCK" not in inspect.signature(fn.fn).parameters:
|
|
for tconfig in configs:
|
|
if "XBLOCK" in tconfig.kwargs:
|
|
assert tconfig.kwargs["XBLOCK"] == 1
|
|
tconfig.kwargs.pop("XBLOCK")
|
|
|
|
if inductor_meta.get("profile_bandwidth"):
|
|
return DebugAutotuner(
|
|
fn,
|
|
triton_meta=triton_meta,
|
|
inductor_meta=inductor_meta,
|
|
regex_filter=inductor_meta["profile_bandwidth_regex"],
|
|
with_profiler=inductor_meta[
|
|
"profile_bandwidth_with_do_bench_using_profiling"
|
|
],
|
|
configs=configs,
|
|
save_cache_hook=autotune_cache and autotune_cache.save,
|
|
mutated_arg_names=mutated_arg_names,
|
|
reset_to_zero_arg_names=reset_to_zero_arg_names,
|
|
optimize_mem=optimize_mem,
|
|
heuristic_type=heuristic_type,
|
|
size_hints=size_hints,
|
|
custom_kernel=custom_kernel,
|
|
filename=filename,
|
|
with_bandwidth_info=True,
|
|
)
|
|
return CachingAutotuner(
|
|
fn,
|
|
triton_meta=triton_meta,
|
|
inductor_meta=inductor_meta,
|
|
configs=configs,
|
|
save_cache_hook=autotune_cache and autotune_cache.save,
|
|
mutated_arg_names=mutated_arg_names,
|
|
reset_to_zero_arg_names=reset_to_zero_arg_names,
|
|
optimize_mem=optimize_mem,
|
|
heuristic_type=heuristic_type,
|
|
size_hints=size_hints,
|
|
custom_kernel=custom_kernel,
|
|
filename=filename,
|
|
autotune_cache_info=autotune_cache_info,
|
|
)
|
|
|
|
return decorator
|
|
|
|
|
|
def unique_configs(configs: list[Config]):
|
|
"""Remove duplicate configurations"""
|
|
seen: OrderedSet[Hashable] = OrderedSet()
|
|
pruned_configs = []
|
|
|
|
for cfg in configs:
|
|
key = triton_config_to_hashable(cfg)
|
|
if key not in seen:
|
|
seen.add(key)
|
|
pruned_configs.append(cfg)
|
|
return pruned_configs
|
|
|
|
|
|
def check_config(cfg, *, xnumel=None, ynumel=None, znumel=None):
|
|
for numel, label in zip((xnumel, ynumel, znumel), "XYZ"):
|
|
if numel is None:
|
|
continue
|
|
block = cfg[f"{label}BLOCK"]
|
|
if numel == 1:
|
|
assert block == 1, (
|
|
f"TritonKernel.indexing assumes numel == 1 => BLOCK == 1"
|
|
f" but {label.lower()}numel=={numel} and {label}BLOCK={block} (cfg={cfg})."
|
|
)
|
|
max_block = TRITON_MAX_BLOCK[label]
|
|
max_block_str = f'config.triton.max_block["{label}"]'
|
|
assert max_block % block == 0, (
|
|
f"TritonKernel.indexing assumes {label}BLOCK divides {max_block_str}"
|
|
f" but {label}BLOCK={block} and {max_block_str}={max_block} (cfg={cfg})."
|
|
)
|
|
|
|
|
|
def check_max_block(cfg: dict[str, int]):
|
|
"""
|
|
Check that block sizes are within the maximum allowed.
|
|
"""
|
|
for var, val in cfg.items():
|
|
block_suffix = "BLOCK"
|
|
if block_suffix in var:
|
|
prefix = var.removesuffix(block_suffix)
|
|
max_block = TRITON_MAX_BLOCK[prefix]
|
|
assert val <= max_block, (
|
|
f"'{var}' too large. Maximum: {max_block}. Actual: {val}."
|
|
)
|
|
|
|
|
|
def _num_warps(num_warps, max_num_warps=8, min_num_warps=2, register_intensive=False):
|
|
    # On AMD GPUs each warp has 64 lanes, which is double the size on NV GPUs,
    # so we correspondingly use half the number of warps here.
|
|
if torch.version.hip:
|
|
max_num_warps = (max_num_warps + 1) // 2
|
|
min_num_warps = (min_num_warps + 1) // 2
|
|
# persistent reduction is register intensive
|
|
if register_intensive:
|
|
max_num_warps = max_num_warps // 2
|
|
return next_power_of_2(min(max(num_warps, min_num_warps), max_num_warps))
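# Illustrative values (a sketch, not exhaustive): on NVIDIA, _num_warps(3) clamps to
# [2, 8] and rounds up to 4; with register_intensive=True the upper bound drops to 4;
# on ROCm both bounds are roughly halved before clamping.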
|
|
|
|
|
|
def _check_max_grid_x(size_hints, x, num_warps):
|
|
    # Check if maxGridSize is exceeded - if so, we must scale XBLOCK further
|
|
max_grid_x = 2147483647
|
|
warp_size = (
|
|
64 if torch.version.hip else 32
|
|
) # TODO: query warp size once #129663 is merged
|
|
num_blocks = (size_hints["x"] + x - 1) // x
|
|
|
|
while (num_blocks * num_warps * warp_size) > max_grid_x and x < size_hints["x"]:
|
|
x *= 2 # Scale up XBLOCK if grid exceeds limits
|
|
num_blocks = num_blocks // 2
|
|
if (num_blocks * num_warps * warp_size) > max_grid_x:
|
|
raise AssertionError(
|
|
"Reduction config exceeds cudaDeviceProp maxGridSize. Please raise a pytorch issue"
|
|
)
|
|
return x, num_blocks
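# Sketch with assumed numbers: for size_hints["x"] = 2**31, x = 1024 and num_warps = 8,
# num_blocks is 2**21 and num_blocks * num_warps * warp_size stays well under max_grid_x,
# so x is returned unchanged; only a very large xnumel with a small XBLOCK triggers the
# doubling loop above.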
|
|
|
|
|
|
def triton_config(
|
|
size_hints,
|
|
x,
|
|
y=None,
|
|
z=None,
|
|
num_stages=1,
|
|
num_elements_per_warp=256,
|
|
min_elem_per_thread=0,
|
|
) -> Config:
|
|
"""
|
|
Construct a pointwise triton config with some adjustment heuristics
|
|
    based on size_hints. size_hints is a dict of numels in each tile
|
|
dimension and will be rounded up to the nearest power of 2.
|
|
|
|
num_elements_per_warp is a suggestion for controlling how many warps
|
|
the triton config should contain. e.g.: if x=16, y=8, z=4 then
|
|
num_elements = 16*8*4 = 512. Then if we set num_elements_per_warp=128,
|
|
we'll launch 512 (elem) / 128 (elem/warp) = 4 warps. Note that it's
|
|
just a suggestion, and sometimes other adjustment heuristics will
|
|
override the num_elements_per_warp.
|
|
|
|
min_elem_per_thread controls the minimum number of elements
|
|
processed by each thread. It's always enforced.
|
|
"""
|
|
# Ideally we want to read this from some device config
|
|
|
|
maxGridSize = [2147483647, 65535, 65535]
|
|
|
|
target = conditional_product(x, y, z)
|
|
if conditional_product(*size_hints.values()) < target:
|
|
target //= 8
|
|
|
|
# shrink sizes to size hints
|
|
x = min(x, size_hints["x"])
|
|
if y:
|
|
y = min(y, size_hints["y"])
|
|
if z:
|
|
z = min(z, size_hints["z"])
|
|
|
|
# if we are below original block size, scale up where we can;
|
|
# or if the calculated grid size is larger than the limit, we bump up the corresponding dimension
|
|
while x < min(size_hints["x"], TRITON_MAX_BLOCK["X"]) and (
|
|
x * maxGridSize[0] < size_hints["x"] or conditional_product(x, y, z) < target
|
|
):
|
|
x *= 2
|
|
while (
|
|
y
|
|
and y < min(size_hints["y"], TRITON_MAX_BLOCK["Y"])
|
|
and (
|
|
y * maxGridSize[1] < size_hints["y"]
|
|
or conditional_product(x, y, z) < target
|
|
)
|
|
):
|
|
y *= 2
|
|
while (
|
|
z
|
|
and z < min(size_hints["z"], TRITON_MAX_BLOCK["Z"])
|
|
and (
|
|
z * maxGridSize[2] < size_hints["z"]
|
|
or conditional_product(x, y, z) < target
|
|
)
|
|
):
|
|
z *= 2
|
|
|
|
num_warps = _num_warps(
|
|
conditional_product(x, y, z) // num_elements_per_warp, min_num_warps=1
|
|
)
|
|
    # We are going to arrive at 2 warps only if bs was too small due to
    # numel being too small. However, to work around some ptx bugs we still
    # want at least 4 warps if there are enough elements per thread.
    # Given that this is a rare situation, we don't expect this to affect perf
    # in general.
    # see https://github.com/pytorch/pytorch/pull/97950
|
|
if conditional_product(x, y, z) >= 128 and not torch.version.hip:
|
|
num_warps = max(num_warps, 4)
|
|
xnumel = size_hints["x"]
|
|
ynumel = size_hints.get("y")
|
|
znumel = size_hints.get("z")
|
|
|
|
# Increase x to satisfy min_elem_per_thread requirements.
|
|
block_size = max(
|
|
conditional_product(x, y, z),
|
|
min_elem_per_thread * _NUM_THREADS_PER_WARP * num_warps,
|
|
)
|
|
x *= math.ceil(block_size / conditional_product(x, y, z))
|
|
|
|
x, _num_blocks = _check_max_grid_x(size_hints, x, num_warps)
|
|
x = min(x, size_hints["x"])
|
|
|
|
cfg = {"XBLOCK": x}
|
|
if y:
|
|
cfg["YBLOCK"] = y
|
|
if z:
|
|
cfg["ZBLOCK"] = z
|
|
check_max_block(cfg)
|
|
check_config(cfg, xnumel=xnumel, ynumel=ynumel, znumel=znumel)
|
|
return Config(cfg, num_warps=num_warps, num_stages=num_stages)
|
|
|
|
|
|
def _get_nd_reduction_numels(r: int, size_hints: dict[str, int]) -> dict[str, int]:
|
|
"""
|
|
Converts a linear reduction numel to ND, in row major order.
|
|
This order is often desirable as it presents opportunities to coalesce memory
|
|
accesses.
|
|
    For example, if r = 64 and the reduction size_hints are {"r0_": 32, "r1_": 32},
    this function returns {"r0_": 2, "r1_": 32}, filling the last dimension first.
|
|
This unraveling works because both r and size_hints are powers of 2.
|
|
"""
|
|
# Shrink r to size_hints.
|
|
r = min(r, get_total_reduction_numel(size_hints))
|
|
num_reduction_dims = len(
|
|
[prefix for prefix in size_hints if prefix_is_reduction(prefix)]
|
|
)
|
|
|
|
remaining = r
|
|
rnumels = {}
|
|
for idx in range(num_reduction_dims - 1, -1, -1):
|
|
prefix = f"r{idx}_"
|
|
max_size = min(size_hints[prefix], TRITON_MAX_BLOCK[prefix.upper()])
|
|
dim = min(max_size, remaining)
|
|
assert remaining % dim == 0, (
|
|
f"Expected dimension '{dim}' to divide remaining size '{remaining}'"
|
|
)
|
|
rnumels[prefix] = dim
|
|
remaining //= dim
|
|
|
|
# Sanity check the results.
|
|
final_numel = conditional_product(*rnumels.values())
|
|
assert r == final_numel, (
|
|
f"Expected ND reduction size ({rnumels}) to have {r} elements."
|
|
)
|
|
assert all(rnumels[prefix] <= size_hints[prefix] for prefix in rnumels), (
|
|
f"rnumels exceed size_hints. {rnumels} > {size_hints}"
|
|
)
|
|
|
|
return rnumels
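# Worked example of the unraveling above (a sketch with assumed hints): with r = 64 and
# size_hints = {"x": 8, "r0_": 32, "r1_": 32}, the loop fills the last reduction dim
# first: r1_ = min(32, 64) = 32, leaving remaining = 2, then r0_ = 2, giving
# rnumels = {"r1_": 32, "r0_": 2}.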
|
|
|
|
|
|
def triton_config_reduction(
|
|
size_hints,
|
|
x: int,
|
|
r: int,
|
|
num_stages=1,
|
|
num_warps=None,
|
|
register_intensive=False,
|
|
) -> Config:
|
|
"""
|
|
Construct a reduction triton config with some adjustment heuristics
|
|
    based on size_hints. size_hints is a dict of numels in each tile
|
|
dimension and will be rounded up to the nearest power of 2.
|
|
"""
|
|
# Convert the linear reduction numel into a multi-dimensional block.
|
|
rnumels = _get_nd_reduction_numels(r, size_hints)
|
|
|
|
# shrink sizes to size hints
|
|
x = min(x, size_hints["x"])
|
|
|
|
def total_numel() -> int:
|
|
return conditional_product(x, *rnumels.values())
|
|
|
|
target = total_numel()
|
|
if conditional_product(*size_hints.values()) < target:
|
|
target //= 8
|
|
|
|
# if we are below original block size, scale up where we can
|
|
while x < size_hints["x"] and total_numel() < target:
|
|
x *= 2
|
|
for prefix in sorted(rnumels):
|
|
while rnumels[prefix] < size_hints[prefix] and total_numel() < target:
|
|
rnumels[prefix] *= 2
|
|
|
|
if num_warps is None:
|
|
num_warps = total_numel() // 128
|
|
num_warps = _num_warps(
|
|
num_warps, max_num_warps=16, register_intensive=register_intensive
|
|
)
|
|
|
|
x, _num_blocks = _check_max_grid_x(size_hints, x, num_warps)
|
|
|
|
for prefix in sorted(rnumels):
|
|
while total_numel() > target:
|
|
if rnumels[prefix] == 1:
|
|
break
|
|
rnumels[prefix] //= 2
|
|
|
|
cfg = _get_config({"x": x, **rnumels})
|
|
check_max_block(cfg)
|
|
check_config(cfg, xnumel=size_hints["x"])
|
|
return Config(cfg, num_warps=num_warps, num_stages=num_stages)
|
|
|
|
|
|
def _get_config(numels: dict[str, int]) -> dict[str, int]:
|
|
"""
|
|
Convert numels ("x", "r0_", etc.) to block sizes ("XBLOCK", "R0_BLOCK"), etc.
|
|
"""
|
|
|
|
return {prefix.upper() + "BLOCK": numel for prefix, numel in numels.items()}
|
|
|
|
|
|
def triton_config_tiled_reduction(
|
|
size_hints, x, y, r, num_stages=1, register_intensive=False
|
|
):
|
|
"""
|
|
Construct a tile reduction triton config with some adjustment
|
|
    heuristics based on size_hints. size_hints is a dict of numels in
|
|
each tile dimension and will be rounded up to the nearest power of 2.
|
|
"""
|
|
# Convert the linear reduction numel into a multi-dimensional block.
|
|
rnumels = _get_nd_reduction_numels(r, size_hints)
|
|
|
|
# shrink sizes to size hints
|
|
x = min(x, size_hints["x"])
|
|
y = min(y, size_hints["y"])
|
|
|
|
def total_numel() -> int:
|
|
return conditional_product(x, y, *rnumels.values())
|
|
|
|
target = total_numel()
|
|
if conditional_product(*size_hints.values()) < target:
|
|
target //= 8
|
|
|
|
# if we are below original block size, scale up where we can
|
|
while x < size_hints["x"] and total_numel() < target:
|
|
x *= 2
|
|
for prefix in sorted(rnumels):
|
|
while rnumels[prefix] < size_hints[prefix] and total_numel() < target:
|
|
rnumels[prefix] *= 2
|
|
while y < size_hints["y"] and total_numel() < target:
|
|
y *= 2
|
|
|
|
cfg = _get_config({"x": x, "y": y, **rnumels})
|
|
num_warps = _num_warps(total_numel() // 256, min_num_warps=1)
|
|
num_warps = _num_warps(
|
|
num_warps, max_num_warps=16, register_intensive=register_intensive
|
|
)
|
|
check_config(cfg, xnumel=size_hints["x"], ynumel=size_hints["y"])
|
|
check_max_block(cfg)
|
|
return Config(cfg, num_warps=num_warps, num_stages=num_stages)
|
|
|
|
|
|
def pointwise(
|
|
size_hints,
|
|
triton_meta,
|
|
tile_hint=None,
|
|
filename=None,
|
|
min_elem_per_thread=0,
|
|
inductor_meta=None,
|
|
):
|
|
"""
|
|
Construct @triton.heuristics() based on size_hints.
|
|
"""
|
|
inductor_meta = {} if inductor_meta is None else inductor_meta
|
|
assert not inductor_meta.get("no_x_dim")
|
|
|
|
numel = functools.reduce(operator.mul, size_hints.values())
|
|
bs = max(256, min(numel // 128, 1024))
|
|
|
|
hinted_configs = autotune_hints_to_configs(
|
|
inductor_meta.get("autotune_hints", OrderedSet()),
|
|
size_hints,
|
|
bs,
|
|
triton_meta["device"],
|
|
)
|
|
|
|
triton_config_with_settings = functools.partial(
|
|
triton_config, min_elem_per_thread=min_elem_per_thread
|
|
)
|
|
|
|
configs = None
|
|
if len(size_hints) == 1:
|
|
if disable_pointwise_autotuning(inductor_meta) and not (
|
|
inductor_meta.get("max_autotune")
|
|
or inductor_meta.get("max_autotune_pointwise")
|
|
):
|
|
configs = [triton_config_with_settings(size_hints, bs)]
|
|
else:
|
|
configs = [
|
|
triton_config_with_settings(size_hints, bs, num_elements_per_warp=256),
|
|
triton_config_with_settings(
|
|
size_hints, bs // 2, num_elements_per_warp=64
|
|
),
|
|
*hinted_configs,
|
|
]
|
|
if len(size_hints) == 2:
|
|
if (
|
|
disable_pointwise_autotuning(inductor_meta) or tile_hint == TileHint.SQUARE
|
|
) and not (
|
|
inductor_meta.get("max_autotune")
|
|
or inductor_meta.get("max_autotune_pointwise")
|
|
):
|
|
configs = [triton_config_with_settings(size_hints, 32, 32)]
|
|
else:
|
|
configs = [
|
|
triton_config_with_settings(size_hints, 32, 32),
|
|
triton_config_with_settings(size_hints, 64, 64), # ~8% better for fp16
|
|
triton_config_with_settings(size_hints, 256, 16),
|
|
triton_config_with_settings(size_hints, 16, 256),
|
|
triton_config_with_settings(size_hints, bs, 1),
|
|
triton_config_with_settings(size_hints, 1, bs),
|
|
*hinted_configs,
|
|
]
|
|
if len(size_hints) == 3:
|
|
if disable_pointwise_autotuning(inductor_meta):
|
|
configs = [triton_config_with_settings(size_hints, 16, 16, 16)]
|
|
else:
|
|
configs = [
|
|
triton_config_with_settings(size_hints, 16, 16, 16),
|
|
triton_config_with_settings(size_hints, 64, 8, 8),
|
|
triton_config_with_settings(size_hints, 8, 64, 8),
|
|
triton_config_with_settings(size_hints, 8, 8, 64),
|
|
triton_config_with_settings(size_hints, bs, 1, 1),
|
|
triton_config_with_settings(size_hints, 1, bs, 1),
|
|
triton_config_with_settings(size_hints, 1, 1, bs),
|
|
*hinted_configs,
|
|
]
|
|
|
|
if not configs:
|
|
raise NotImplementedError(f"size_hints: {size_hints}")
|
|
return cached_autotune(
|
|
size_hints,
|
|
configs,
|
|
triton_meta=triton_meta,
|
|
inductor_meta=inductor_meta,
|
|
heuristic_type=HeuristicType.POINTWISE,
|
|
filename=filename,
|
|
)
|
|
|
|
|
|
def _reduction_configs(
|
|
*, size_hints: dict[str, int], inductor_meta: dict[str, Any]
|
|
) -> list[Config]:
|
|
reduction_hint = inductor_meta.get("reduction_hint", None)
|
|
|
|
# Convert reductions to 1D, to simplify heuristics.
|
|
rnumel = get_total_reduction_numel(size_hints)
|
|
|
|
register_intensive = False
|
|
MAX_R0_BLOCK = 2048
|
|
if (
|
|
size_hints["x"] >= 1024
|
|
and inductor_meta.get("num_load", 0) + inductor_meta.get("num_reduction", 0)
|
|
>= 10
|
|
):
|
|
        # A heuristic to reduce R0_BLOCK if a kernel potentially needs many registers.
        # Consider loads and reductions, since loads need to move data into registers
        # and reductions need an accumulator.
        #
        # The magic numbers are a bit arbitrary.
        #
        # We cannot rely on dynamically scaling down R0_BLOCK later, since sometimes
        # triton makes it use fewer registers with worse perf. Check:
        # https://github.com/pytorch/pytorch/issues/126463
        #
        # The heuristic is a very simple one since registers can be reused. But
        # hopefully it can be a good enough indicator.
|
|
MAX_R0_BLOCK = 1024
|
|
register_intensive = True
|
|
|
|
def make_config(x, r, num_warps=None, num_stages=1, register_intensive=False):
|
|
# For 3D case with tiling scores, create an adapted version
|
|
if "y" in size_hints:
|
|
assert "tiling_scores" in inductor_meta
|
|
return adapt_config_for_tiling(
|
|
size_hints,
|
|
inductor_meta["tiling_scores"],
|
|
x,
|
|
r,
|
|
num_warps=num_warps,
|
|
num_stages=num_stages,
|
|
register_intensive=register_intensive,
|
|
)
|
|
else:
|
|
# For other cases, use the original function
|
|
return triton_config_reduction(
|
|
size_hints,
|
|
x,
|
|
r,
|
|
num_warps=num_warps,
|
|
num_stages=num_stages,
|
|
register_intensive=register_intensive,
|
|
)
|
|
|
|
contiguous_config = make_config(
|
|
1,
|
|
min(rnumel, MAX_R0_BLOCK),
|
|
register_intensive=register_intensive,
|
|
)
|
|
outer_config = make_config(64, 8, register_intensive=register_intensive)
|
|
tiny_config = make_config(
|
|
2 * (256 // rnumel) if rnumel <= 256 else 1,
|
|
min(rnumel, MAX_R0_BLOCK),
|
|
register_intensive=register_intensive,
|
|
)
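    # Sketch of tiny_config above (assumed rnumel): for rnumel = 64 the x block is
    # 2 * (256 // 64) = 8 and the whole reduction (R0_BLOCK = 64) fits in one block;
    # for rnumel > 256 the x block collapses to 1.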
|
|
# For 3d tiling, default to more autotuning initially
|
|
if "y" in size_hints:
|
|
pass
|
|
elif inductor_meta.get("max_autotune") or inductor_meta.get(
|
|
"max_autotune_pointwise"
|
|
):
|
|
pass # skip all these cases
|
|
elif reduction_hint == ReductionHint.INNER:
|
|
return [contiguous_config]
|
|
elif reduction_hint == ReductionHint.OUTER:
|
|
return [outer_config]
|
|
elif reduction_hint == ReductionHint.OUTER_TINY:
|
|
return [tiny_config]
|
|
if disable_pointwise_autotuning(inductor_meta):
|
|
return [make_config(32, 128)]
|
|
return [
|
|
contiguous_config,
|
|
outer_config,
|
|
tiny_config,
|
|
make_config(64, 64),
|
|
make_config(8, 512),
|
|
# halve the XBLOCK/Rn_BLOCK compared to outer_config
|
|
# TODO: this may only be beneficial when each iteration of the reduction
|
|
# is quite heavy. E.g. https://gist.github.com/shunting314/189a8ef69f90db9d614a823385147a72
|
|
make_config(64, 4, num_warps=8),
|
|
]
|
|
|
|
|
|
def match_target_block_product(
|
|
size_hints, tiling_scores, target_block_product, min_block_size=1
|
|
):
|
|
"""
|
|
Distribute block sizes across dimensions according to tiling scores,
|
|
aiming to match a target product of block sizes.
|
|
"""
|
|
total_score = sum(tiling_scores.values())
|
|
if total_score == 0:
|
|
# just assume even score with no minimum block size
|
|
min_block_size = 1
|
|
tiling_scores = dict.fromkeys(tiling_scores.keys(), target_block_product)
|
|
|
|
# First, give each coalescing dimension at least min_block_size
|
|
block_sizes = {}
|
|
relative_scores = {}
|
|
curr_block_product = 1
|
|
|
|
for dim, score in tiling_scores.items():
|
|
if score == 0:
|
|
block_sizes[dim] = 1
|
|
continue
|
|
|
|
block_sizes[dim] = min_block_size
|
|
curr_block_product *= min_block_size
|
|
relative_scores[dim] = score / total_score
|
|
|
|
# Scale up dimensions by their relative scores until we reach the target
|
|
while curr_block_product < target_block_product and len(relative_scores):
|
|
dim, score = max(relative_scores.items(), key=lambda item: item[1])
|
|
|
|
# Check if we've hit the max for this dimension
|
|
if (
|
|
block_sizes[dim] >= TRITON_MAX_BLOCK[dim.capitalize()]
|
|
or block_sizes[dim] >= size_hints[dim]
|
|
):
|
|
del relative_scores[dim]
|
|
continue
|
|
|
|
block_sizes[dim] *= 2
|
|
relative_scores[dim] /= 2
|
|
curr_block_product *= 2
|
|
|
|
return block_sizes
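# Worked example (a sketch with assumed inputs): for size_hints = {"x": 1024, "y": 1024},
# tiling_scores = {"x": 3, "y": 1} and target_block_product = 8, both dims start at 1;
# the loop doubles the dim with the highest remaining relative score (halving that score
# each time), so "x" doubles twice, then "y" once, returning {"x": 4, "y": 2}.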
|
|
|
|
|
|
def adapt_config_for_tiling(
|
|
size_hints,
|
|
tiling_scores,
|
|
original_x,
|
|
original_r,
|
|
num_warps=None,
|
|
num_stages=1,
|
|
register_intensive=False,
|
|
persistent_reduction=False,
|
|
) -> Config:
|
|
"""
|
|
Create an adapted configuration based on tiling scores,
|
|
redistributing the same total block size (x * r) according to tiling scores.
|
|
"""
|
|
assert all(s in tiling_scores for s in size_hints)
|
|
target_block_product = original_x * original_r
|
|
block_sizes = match_target_block_product(
|
|
size_hints, tiling_scores, target_block_product
|
|
)
|
|
|
|
return triton_config_tiled_reduction(
|
|
size_hints,
|
|
block_sizes["x"],
|
|
block_sizes["y"],
|
|
block_sizes["r0_"],
|
|
num_stages=num_stages,
|
|
register_intensive=register_intensive,
|
|
)
|
|
|
|
|
|
def reduction(
|
|
size_hints,
|
|
reduction_hint=False,
|
|
triton_meta=None,
|
|
filename=None,
|
|
inductor_meta=None,
|
|
):
|
|
"""args to @triton.heuristics()"""
|
|
inductor_meta = {} if inductor_meta is None else inductor_meta
|
|
inductor_meta["reduction_hint"] = reduction_hint
|
|
if inductor_meta.get("no_x_dim"):
|
|
size_hints["x"] = 1
|
|
|
|
assert triton_meta is not None
|
|
|
|
configs = _reduction_configs(size_hints=size_hints, inductor_meta=inductor_meta)
|
|
return cached_autotune(
|
|
size_hints,
|
|
configs=configs,
|
|
triton_meta=triton_meta,
|
|
inductor_meta=inductor_meta,
|
|
heuristic_type=HeuristicType.REDUCTION,
|
|
filename=filename,
|
|
)
|
|
|
|
|
|
def cooperative_reduction(
|
|
size_hints,
|
|
reduction_hint,
|
|
triton_meta,
|
|
filename,
|
|
inductor_meta,
|
|
):
|
|
inductor_meta = {} if inductor_meta is None else inductor_meta
|
|
inductor_meta["reduction_hint"] = reduction_hint
|
|
if inductor_meta.get("no_x_dim"):
|
|
size_hints["x"] = 1
|
|
|
|
# Cooperative reductions currently only support a single reduction dimension.
|
|
assert len(size_hints) == 2, (
|
|
"Cooperative reductions don't support tiling reduction dims"
|
|
)
|
|
xnumel, rnumel = size_hints["x"], size_hints["r0_"]
|
|
|
|
# TODO(jansel): we should base target on the SM count of the local GPU
|
|
target = 64
|
|
split = max(1, min(target // xnumel, TRITON_MAX_RSPLIT))
|
|
assert rnumel >= split
|
|
assert split <= TRITON_MAX_RSPLIT
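    # Sketch with assumed numbers: for xnumel = 4 and target = 64 (and assuming
    # TRITON_MAX_RSPLIT >= 16), split = 16, so the reduction is split across 16
    # cooperating blocks and the configs below are built for r0_ = rnumel // 16,
    # with RSPLIT = 16 attached to each config.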
|
|
if inductor_meta["persistent_reduction"]:
|
|
configs = _persistent_reduction_configs(
|
|
{"x": xnumel, "r0_": rnumel // split}, reduction_hint, inductor_meta
|
|
)
|
|
else:
|
|
configs = _reduction_configs(
|
|
size_hints={"x": xnumel, "r0_": rnumel // split},
|
|
inductor_meta=inductor_meta,
|
|
)
|
|
for config in configs:
|
|
config.kwargs["RSPLIT"] = split
|
|
# TODO(jansel): add more configs in max_autotune
|
|
|
|
return cached_autotune(
|
|
size_hints,
|
|
configs=configs,
|
|
triton_meta=triton_meta,
|
|
inductor_meta=inductor_meta,
|
|
heuristic_type=HeuristicType.REDUCTION,
|
|
filename=filename,
|
|
)
|
|
|
|
|
|
def _persistent_reduction_configs(
|
|
size_hints,
|
|
reduction_hint=False,
|
|
inductor_meta=None,
|
|
):
|
|
xnumel = size_hints["x"]
|
|
rnumel = get_total_reduction_numel(size_hints)
|
|
|
|
MAX_PERSISTENT_BLOCK_NUMEL = 4096
|
|
|
|
if "y" not in size_hints:
|
|
configs = [
|
|
triton_config_reduction(size_hints, xblock, rnumel, register_intensive=True)
|
|
for xblock in (1, 8, 32, 128)
|
|
if xblock == 1
|
|
or (rnumel * xblock <= MAX_PERSISTENT_BLOCK_NUMEL and xblock <= xnumel)
|
|
]
|
|
else:
|
|
configs = []
|
|
assert "tiling_scores" in inductor_meta
|
|
x_y_scores = {dim: inductor_meta["tiling_scores"][dim] for dim in ("x", "y")}
|
|
for target_block_size in (1, 8, 32, 64, 128):
|
|
if target_block_size * rnumel > MAX_PERSISTENT_BLOCK_NUMEL:
|
|
continue
|
|
|
|
block_sizes = match_target_block_product(
|
|
size_hints, x_y_scores, target_block_size
|
|
)
|
|
configs.append(
|
|
triton_config_tiled_reduction(
|
|
size_hints, block_sizes["x"], block_sizes["y"], rnumel
|
|
)
|
|
)
|
|
|
|
# defer to more autotuning, initially
|
|
if "y" in size_hints:
|
|
pass
|
|
# TODO(jansel): we should be able to improve these heuristics
|
|
elif reduction_hint == ReductionHint.INNER and rnumel >= 256:
|
|
configs = configs[:1]
|
|
elif reduction_hint == ReductionHint.OUTER:
|
|
configs = configs[-1:]
|
|
elif reduction_hint == ReductionHint.OUTER_TINY:
|
|
configs = [
|
|
triton_config_reduction(
|
|
size_hints,
|
|
2 * (256 // rnumel) if rnumel <= 256 else 1,
|
|
rnumel,
|
|
)
|
|
]
|
|
for c in configs:
|
|
# we don't need Rn_BLOCK for persistent reduction
|
|
for prefix in size_hints:
|
|
if prefix_is_reduction(prefix):
|
|
c.kwargs.pop(f"{prefix.upper()}BLOCK")
|
|
|
|
if disable_pointwise_autotuning(inductor_meta):
|
|
configs = configs[:1]
|
|
|
|
return configs
|
|
|
|
|
|
def persistent_reduction(
|
|
size_hints,
|
|
reduction_hint=False,
|
|
triton_meta=None,
|
|
filename=None,
|
|
inductor_meta=None,
|
|
):
|
|
inductor_meta = {} if inductor_meta is None else inductor_meta
|
|
inductor_meta["reduction_hint"] = reduction_hint
|
|
if inductor_meta.get("no_x_dim"):
|
|
size_hints["x"] = 1
|
|
|
|
configs = _persistent_reduction_configs(size_hints, reduction_hint, inductor_meta)
|
|
|
|
return cached_autotune(
|
|
size_hints,
|
|
configs,
|
|
triton_meta=triton_meta,
|
|
inductor_meta=inductor_meta,
|
|
filename=filename,
|
|
heuristic_type=HeuristicType.PERSISTENT_REDUCTION,
|
|
)
|
|
|
|
|
|
def split_scan(
|
|
size_hints,
|
|
reduction_hint=False,
|
|
triton_meta=None,
|
|
filename=None,
|
|
inductor_meta=None,
|
|
):
|
|
"""Heuristic for TritonSplitScanKernel"""
|
|
inductor_meta = {} if inductor_meta is None else inductor_meta
|
|
inductor_meta["reduction_hint"] = reduction_hint
|
|
if inductor_meta.get("no_x_dim"):
|
|
size_hints["x"] = 1
|
|
|
|
assert triton_meta is not None
|
|
if len(size_hints) != 2:
|
|
raise NotImplementedError(f"size_hints: {size_hints}")
|
|
|
|
configs = _reduction_configs(size_hints=size_hints, inductor_meta=inductor_meta)
|
|
|
|
# Fixup configs to enforce the minimum Rn_BLOCK size
|
|
min_rblock = inductor_meta.get("min_split_scan_rblock", 256)
|
|
for cfg in configs:
|
|
for var in list(cfg.kwargs.keys()):
|
|
if var.startswith("R") and cfg.kwargs[var] < min_rblock:
|
|
cfg.kwargs[var] = min_rblock
|
|
|
|
return cached_autotune(
|
|
size_hints,
|
|
configs=configs,
|
|
triton_meta=triton_meta,
|
|
inductor_meta=inductor_meta,
|
|
heuristic_type=HeuristicType.SPLIT_SCAN,
|
|
filename=filename,
|
|
)
|
|
|
|
|
|
def template(
|
|
num_stages,
|
|
num_warps,
|
|
triton_meta,
|
|
num_consumer_groups=0,
|
|
num_buffers_warp_spec=0,
|
|
filename=None,
|
|
inductor_meta=None,
|
|
):
|
|
"""
|
|
Compile a triton template
|
|
"""
|
|
# Prepare the base configuration
|
|
config_args = {
|
|
"num_stages": num_stages,
|
|
"num_warps": num_warps,
|
|
}
|
|
|
|
# Conditionally add arguments based on HAS_WARP_SPEC
|
|
if HAS_WARP_SPEC:
|
|
config_args.update(
|
|
{
|
|
"num_consumer_groups": num_consumer_groups,
|
|
"num_buffers_warp_spec": num_buffers_warp_spec,
|
|
}
|
|
)
|
|
return cached_autotune(
|
|
None,
|
|
[triton.Config({}, **config_args)],
|
|
triton_meta=triton_meta,
|
|
inductor_meta=inductor_meta,
|
|
heuristic_type=HeuristicType.TEMPLATE,
|
|
filename=filename,
|
|
)
|
|
|
|
|
|
def _pop_config_kwargs(config: dict[str, Any]) -> dict[str, Any]:
|
|
"""Extract triton.Config options that should become kwargs"""
|
|
popped = {}
|
|
for key in (
|
|
"num_warps",
|
|
"num_stages",
|
|
"num_ctas",
|
|
"maxnreg",
|
|
"num_consumer_groups",
|
|
"num_buffers_warp_spec",
|
|
):
|
|
val = config.pop(key, None)
|
|
if val is not None:
|
|
popped[key] = val
|
|
return popped
|
|
|
|
|
|
def config_to_dict(config: Config) -> dict[str, Any]:
|
|
config_dict = {
|
|
**config.kwargs,
|
|
"num_warps": config.num_warps,
|
|
"num_stages": config.num_stages,
|
|
}
|
|
if HAS_WARP_SPEC:
|
|
config_dict.update(
|
|
{
|
|
"num_consumer_groups": getattr(config, "num_consumer_groups", 0),
|
|
"num_buffers_warp_spec": getattr(config, "num_buffers_warp_spec", 0),
|
|
}
|
|
)
|
|
return config_dict
|
|
|
|
|
|
def config_from_dict(config: dict[str, Any]) -> Config:
|
|
config = {**config}
|
|
return Config(config, **_pop_config_kwargs(config))
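# Round-trip sketch (assumed values): config_from_dict({"XBLOCK": 64, "num_warps": 4,
# "num_stages": 2}) yields triton.Config({"XBLOCK": 64}, num_warps=4, num_stages=2),
# and config_to_dict maps it back to the same flat dict (plus warp-spec fields when
# HAS_WARP_SPEC is set).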
|
|
|
|
|
|
def fixed_config(config, filename, triton_meta, inductor_meta):
|
|
"""
|
|
Used when the configuration is already decided at compile time
|
|
"""
|
|
config = {**config}
|
|
return cached_autotune(
|
|
None,
|
|
[triton.Config(config, **_pop_config_kwargs(config))],
|
|
triton_meta=triton_meta,
|
|
inductor_meta=inductor_meta,
|
|
heuristic_type=HeuristicType.FIXED,
|
|
filename=filename,
|
|
)
|
|
|
|
|
|
def user_autotune(
|
|
configs, triton_meta, filename=None, inductor_meta=None, custom_kernel=False
|
|
):
|
|
"""
|
|
Compile a user defined triton kernel
|
|
"""
|
|
if len(configs) == 0:
|
|
configs = [triton.Config({})]
|
|
else:
|
|
configs = [*map(config_from_dict, configs)]
|
|
return cached_autotune(
|
|
None,
|
|
configs,
|
|
triton_meta=triton_meta,
|
|
heuristic_type=HeuristicType.USER_AUTOTUNE,
|
|
filename=filename,
|
|
inductor_meta=inductor_meta,
|
|
custom_kernel=custom_kernel,
|
|
)
|
|
|
|
|
|
def foreach(triton_meta, num_warps, filename=None, inductor_meta=None):
|
|
"""
|
|
Compile a triton foreach kernel
|
|
"""
|
|
return cached_autotune(
|
|
None,
|
|
[triton.Config({}, num_stages=1, num_warps=num_warps)],
|
|
triton_meta=triton_meta,
|
|
inductor_meta=inductor_meta,
|
|
heuristic_type=HeuristicType.TEMPLATE,
|
|
filename=filename,
|
|
)
|
|
|
|
|
|
@dataclasses.dataclass
|
|
class GridExpr:
|
|
"""Generate code for grid size expressions in launcher"""
|
|
|
|
inductor_meta: dict[str, Any]
|
|
mode: Literal["python", "cpp"] = "python"
|
|
prefix: list[str] = dataclasses.field(default_factory=list)
|
|
x_grid: Union[str, int] = 1
|
|
y_grid: Union[str, int] = 1
|
|
z_grid: Union[str, int] = 1
|
|
|
|
def __post_init__(self) -> None:
|
|
assert self.mode in ("python", "cpp")
|
|
|
|
def generate(self, meta: dict[str, int]) -> None:
|
|
raise NotImplementedError
|
|
|
|
def ceildiv(
|
|
self, numel: Union[str, int], block: Union[None, int, str]
|
|
) -> Union[str, int]:
|
|
if block is None or block == 1:
|
|
return numel
|
|
if isinstance(numel, int) and isinstance(block, int):
|
|
return ceildiv(numel, block) # constant fold
|
|
if self.mode == "python":
|
|
return f"-(({numel}) // -({block}))"
|
|
# trick above doesn't work in C++ due to rounding differences
|
|
return f"(({numel} + ({block} - 1)) / ({block}))"
|
|
|
|
def maximum(self, seq: list[Union[int, str]]) -> Union[int, str]:
|
|
"""Codegen for max function with constant folding, constants are represented as int"""
|
|
items = self._constant_fold(max, seq)
|
|
if len(items) <= 1:
|
|
return items[0]
|
|
if self.mode == "python":
|
|
return f"max({', '.join(map(str, items))})"
|
|
return functools.reduce(lambda x, y: f"std::max({x}, {y})", items)
|
|
|
|
def summation(self, seq: list[Union[int, str]]) -> Union[int, str]:
|
|
"""Codegen for sum function with constant folding, constants are represented as int"""
|
|
items = self._constant_fold(sum, seq)
|
|
if len(items) <= 1:
|
|
return items[0]
|
|
return " + ".join(map(str, items))
|
|
|
|
def _constant_fold(
|
|
self, fn: Callable[[list[int]], int], seq: list[Union[int, str]]
|
|
) -> list[Union[int, str]]:
|
|
"""Constant fold through a commutative fn where ints are constants"""
|
|
items: list[Union[int, str]] = [x for x in seq if not isinstance(x, int)]
|
|
const_items = [x for x in seq if isinstance(x, int)]
|
|
if const_items:
|
|
items.append(fn(const_items))
|
|
return items
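        # Sketch: maximum(["xnumel", 4, 7]) folds the ints to 7 and emits
        # "max(xnumel, 7)" in python mode (nested std::max in cpp mode), while
        # summation(["a", 2, 3]) emits "a + 5".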
|
|
|
|
def assign_tmp(self, name: str, expr: Union[str, int]) -> str:
|
|
# Grid functions are one per kernel, so name collisions are fine
|
|
if self.mode == "python":
|
|
return f"{name} = {expr}"
|
|
if self.mode == "cpp":
|
|
return f"uint32_t {name} = {expr};"
|
|
raise AssertionError(f"invalid mode {self.mode}")
|
|
|
|
@staticmethod
|
|
def from_meta(
|
|
inductor_meta: dict[str, Any],
|
|
cfg: Union[Config, dict[str, int]],
|
|
mode: Literal["python", "cpp"] = "python",
|
|
) -> GridExpr:
|
|
grid_cls = globals()[inductor_meta["grid_type"]]
|
|
assert issubclass(grid_cls, GridExpr)
|
|
grid = grid_cls(inductor_meta=inductor_meta, mode=mode)
|
|
if isinstance(cfg, Config):
|
|
cfg = config_to_dict(cfg)
|
|
grid.generate(cfg)
|
|
return grid
|
|
|
|
def eval_slow(self, meta: dict[str, int]) -> tuple[int, int, int]:
|
|
scope = {**meta}
|
|
for line in self.prefix:
|
|
exec(line, scope)
|
|
exec(f"grid_0 = {self.x_grid}", scope)
|
|
exec(f"grid_1 = {self.y_grid}", scope)
|
|
exec(f"grid_2 = {self.z_grid}", scope)
|
|
return scope["grid_0"], scope["grid_1"], scope["grid_2"]
|
|
|
|
|
|
class Grid1D(GridExpr):
|
|
def generate(self, meta: dict[str, int]) -> None:
|
|
self.x_grid = self.ceildiv("xnumel", meta.get("XBLOCK"))
|
|
|
|
|
|
class Grid2D(GridExpr):
|
|
def generate(self, meta: dict[str, int]) -> None:
|
|
self.x_grid = self.ceildiv("xnumel", meta.get("XBLOCK"))
|
|
self.y_grid = self.ceildiv("ynumel", meta.get("YBLOCK"))
|
|
|
|
|
|
class Grid3D(GridExpr):
|
|
def generate(self, meta: dict[str, int]) -> None:
|
|
self.x_grid = self.ceildiv("xnumel", meta.get("XBLOCK"))
|
|
self.y_grid = self.ceildiv("ynumel", meta.get("YBLOCK"))
|
|
self.z_grid = self.ceildiv("znumel", meta.get("ZBLOCK"))
|
|
|
|
|
|
class Grid2DWithYZOverflow(GridExpr):
|
|
def generate(self, meta: dict[str, int]) -> None:
|
|
self.x_grid = self.ceildiv("xnumel", meta.get("XBLOCK"))
|
|
self.prefix.extend(
|
|
[
|
|
self.assign_tmp(
|
|
"y_grid_raw_", self.ceildiv("ynumel", meta.get("YBLOCK"))
|
|
),
|
|
self.assign_tmp(
|
|
"y_grid_div_", self.ceildiv("y_grid_raw_", get_max_y_grid())
|
|
),
|
|
]
|
|
)
|
|
self.y_grid = self.ceildiv("y_grid_raw_", "y_grid_div_")
|
|
self.z_grid = "y_grid_div_"
|
|
|
|
|
|
class CooperativeReductionGrid(GridExpr):
|
|
def generate(self, meta: dict[str, int]) -> None:
|
|
self.x_grid = str(meta["RSPLIT"])
|
|
self.y_grid = self.ceildiv("xnumel", meta.get("XBLOCK"))
|
|
|
|
|
|
class SplitScanGrid(GridExpr):
|
|
def generate(self, meta: dict[str, int]) -> None:
|
|
assert meta.get("XBLOCK", 1) == 1
|
|
self.x_grid = self.ceildiv("r0_numel", meta.get("R0_BLOCK"))
|
|
self.y_grid = "xnumel"
|
|
|
|
|
|
class FixedGrid(GridExpr):
|
|
@staticmethod
|
|
def setup_grid_as_args() -> dict[str, Any]:
|
|
"""Inductor meta so the launcher takes three extra grid arguments"""
|
|
return {
|
|
"grid_type": FixedGrid.__name__,
|
|
"fixed_grid": ["_grid_0", "_grid_1", "_grid_2"],
|
|
"extra_launcher_args": ["_grid_0", "_grid_1", "_grid_2"],
|
|
}
|
|
|
|
def generate(self, meta: dict[str, int]) -> None:
|
|
self.x_grid, self.y_grid, self.z_grid = self.inductor_meta["fixed_grid"]
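        # Sketch: with setup_grid_as_args() above, the launcher signature gains
        # _grid_0/_grid_1/_grid_2 as trailing extra_launcher_args (just before `stream`),
        # and generate() simply forwards those names as grid_0/grid_1/grid_2.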
|
|
|
|
|
|
class PrecomputedGrid(GridExpr):
|
|
def generate(self, meta: dict[str, int]) -> None:
|
|
for candidate in self.inductor_meta["precomputed_grids"]:
|
|
if all(meta.get(k) == v for k, v in candidate["config"].items()):
|
|
self.x_grid, self.y_grid, self.z_grid = candidate[self.mode]
|
|
return
|
|
raise AssertionError(
|
|
f"Precomputed grid not found for {meta} in {self.inductor_meta['precomputed_grids']}"
|
|
)
|
|
|
|
|
|
class ComboKernelGrid(GridExpr):
|
|
def generate(self, meta: dict[str, int]):
|
|
combo_meta = self.inductor_meta["combo_grid_meta"]
|
|
if combo_meta["default_config"]:
|
|
meta = {**combo_meta["default_config"], **meta}
|
|
no_x_dims = []
|
|
xnumels = []
|
|
ynumels = []
|
|
znumels = []
|
|
for num in range(combo_meta["num_kernels"]):
|
|
assert (
|
|
combo_meta[f"xnumel_{num}"] is None or combo_meta[f"xnumel_{num}"] > 0
|
|
)
|
|
no_x_dims.append(combo_meta[f"no_x_dim_{num}"])
|
|
xnumels.append(combo_meta[f"xnumel_{num}"] or f"xnumel_{num}")
|
|
if f"ynumel_{num}" in combo_meta:
|
|
ynumels.append(combo_meta[f"ynumel_{num}"] or f"ynumel_{num}")
|
|
if f"znumel_{num}" in combo_meta:
|
|
znumels.append(combo_meta[f"znumel_{num}"] or f"znumel_{num}")
|
|
|
|
self.x_grid = self.combo_x_grid(xnumels, no_x_dims, meta)
|
|
if combo_meta["min_blocks"]:
|
|
self.x_grid = self.maximum([self.x_grid, combo_meta["min_blocks"]])
|
|
if ynumels:
|
|
self.y_grid = self.ceildiv(self.maximum(ynumels), meta.get("YBLOCK"))
|
|
if znumels:
|
|
self.z_grid = self.ceildiv(self.maximum(znumels), meta.get("ZBLOCK"))
|
|
|
|
def combo_x_grid(
|
|
self,
|
|
xnumels: list[Union[int, str]],
|
|
no_x_dims: list[bool],
|
|
meta: dict[str, int],
|
|
) -> Union[str, int]:
|
|
raise NotImplementedError
|
|
|
|
|
|
class SequentialComboKernelGrid(ComboKernelGrid):
|
|
def combo_x_grid(
|
|
self,
|
|
xnumels: list[Union[int, str]],
|
|
no_x_dims: list[bool],
|
|
meta: dict[str, int],
|
|
) -> Union[str, int]:
|
|
assert len(xnumels) == len(no_x_dims)
|
|
return self.summation(
|
|
[
|
|
self.ceildiv(x, 1 if no_x_dim else meta.get("XBLOCK"))
|
|
for x, no_x_dim in zip(xnumels, no_x_dims)
|
|
]
|
|
)
|
|
|
|
|
|
class RoundRobinComboKernelGrid(ComboKernelGrid):
|
|
def combo_x_grid(
|
|
self,
|
|
xnumels: list[Union[int, str]],
|
|
no_x_dims: list[bool],
|
|
meta: dict[str, int],
|
|
) -> str:
|
|
assert len(xnumels) == len(no_x_dims)
|
|
num_kernels = self.inductor_meta["combo_grid_meta"]["num_kernels"]
|
|
exprs = [x for x, no_x_dim in zip(xnumels, no_x_dims) if no_x_dim]
|
|
xnumels_x_dim = [x for x, no_x_dim in zip(xnumels, no_x_dims) if not no_x_dim]
|
|
if xnumels_x_dim:
|
|
exprs.append(self.ceildiv(self.maximum(xnumels_x_dim), meta.get("XBLOCK")))
|
|
return f"({self.maximum(exprs)}) * {num_kernels}"
|