[inductor] Support fixed triton configs defined at compile time (#140217)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/140217
Approved by: https://github.com/shunting314
ghstack dependencies: #139585
commit 2c6bd9f6f6
parent 318eaa2be7
Author: Jason Ansel
Date: 2024-11-16 10:23:32 -08:00
Committed by: PyTorch MergeBot

6 changed files with 177 additions and 59 deletions
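In short: the new InductorChoices.triton_kernel_kwargs hook, together with the FixedTritonConfig wrapper, lets a compile-time-chosen Triton config bypass runtime autotuning entirely. A minimal usage sketch, condensed from the test added below (the subclass name, tensor shape, and config values are illustrative):

import torch
from torch._inductor.choices import InductorChoices
from torch._inductor.codegen.triton import FixedTritonConfig
from torch._inductor.virtualized import V

class MyHeuristics(InductorChoices):
    def triton_kernel_kwargs(self, kernel_cls, features, groups, kernel_kwargs):
        # Pin a single launch config; the generated kernel skips autotuning.
        return {
            **kernel_kwargs,
            "fixed_config": FixedTritonConfig({"XBLOCK": 2, "RBLOCK": 128}),
        }

with V.set_choices_handler(MyHeuristics()):
    fn = torch.compile(lambda x: torch.softmax(x + 1, dim=-1) + x, fullgraph=True)
    fn(torch.randn(8, 8000, device="cuda"))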

test/inductor/test_cooperative_reductions.py

@@ -1,7 +1,14 @@
# Owner(s): ["module: inductor"]
from typing import Any, Dict, List, Type
import sympy
import torch
import torch._inductor
from torch._inductor import config
from torch._inductor.choices import InductorChoices
from torch._inductor.codegen.simd_kernel_features import SIMDKernelFeatures
from torch._inductor.codegen.triton import FixedTritonConfig, TritonKernel
from torch._inductor.test_case import TestCase
from torch._inductor.utils import run_and_get_code
from torch.testing._internal.common_utils import (
@@ -124,6 +131,54 @@ class MultiKernelCooperativeReductionTests(CooperativeReductionTests):
pass
@config.patch(
{
"triton.cooperative_reductions": True,
}
)
@instantiate_parametrized_tests
class TestFixedConfigs(TestCase):
@parametrize(
"persistent,cooperative,cfg",
[
(False, False, {"XBLOCK": 1, "RBLOCK": 128}),
(False, False, {"XBLOCK": 2, "RBLOCK": 128}),
(True, False, {"XBLOCK": 1}),
(True, False, {"XBLOCK": 2}),
(False, True, {"XBLOCK": 1, "RBLOCK": 128, "RSPLIT": 16}),
(False, True, {"XBLOCK": 2, "RBLOCK": 128, "RSPLIT": 16}),
(True, True, {"XBLOCK": 1, "RSPLIT": 16}),
(True, True, {"XBLOCK": 2, "RSPLIT": 16}),
],
)
def test_fixed_configs(self, persistent, cooperative, cfg):
class MyHeuristics(InductorChoices):
def triton_kernel_kwargs(
self,
kernel_cls: Type[TritonKernel],
features: SIMDKernelFeatures,
groups: List[sympy.Expr],
kernel_kwargs: Dict[str, Any],
) -> Dict[str, Any]:
return {
**kernel_kwargs,
"override_cooperative_reduction": cooperative,
"override_persistent_reduction": persistent,
"fixed_config": FixedTritonConfig(cfg),
}
def fn(x):
return torch.softmax(x + 1, dim=-1) + x
args = [torch.randn(8, 8000, device="cuda")]
with torch._inductor.virtualized.V.set_choices_handler(MyHeuristics()):
expected = fn(*args)
fn = torch.compile(fn, fullgraph=True)
result, (source_code,) = run_and_get_code(fn, *args)
self.assertEqual(result, expected)
self.assertIn("@triton_heuristics.fixed_config(", source_code)
if __name__ == "__main__":
from torch._dynamo.test_case import run_tests

torch/_inductor/choices.py

@@ -1,6 +1,6 @@
from __future__ import annotations
from typing import TYPE_CHECKING
from typing import Any, Dict, List, Type, TYPE_CHECKING
from . import config
from .runtime.hints import ReductionHint
@@ -8,7 +8,10 @@ from .virtualized import V
if TYPE_CHECKING:
import sympy
from .codegen.simd_kernel_features import SIMDKernelFeatures
from .codegen.triton import TritonKernel
class InductorChoices:
@@ -24,6 +27,16 @@ class InductorChoices:
torch._inductor.virtualized.V.set_choices_handler(MyHeuristics())
"""
def triton_kernel_kwargs(
self,
kernel_cls: Type[TritonKernel],
features: SIMDKernelFeatures,
groups: List[sympy.Expr],
kernel_kwargs: Dict[str, Any],
) -> Dict[str, Any]:
"""Hook to change the kwargs passed to TritonKernel, used to apply fixed configurations"""
return kernel_kwargs
@staticmethod
def should_use_cooperative_reduction(features: SIMDKernelFeatures) -> bool:
"""Heuristic to decide if a cooperative reduction should be used."""

torch/_inductor/codegen/triton.py

@@ -29,18 +29,8 @@ from typing import (
import sympy
import torch
import torch._inductor.metrics as metrics
import torch._logging
from torch._dynamo.utils import identity, preserve_rng_state
from torch._inductor.runtime.hints import (
AutotuneHint,
DeviceProperties,
TRITON_MAX_RSPLIT,
)
from torch._inductor.runtime.triton_heuristics import (
cooperative_reduction_grid,
grid as default_grid_fn,
)
from torch._prims_common import is_integer_dtype
from torch.utils._ordered_set import OrderedSet
from torch.utils._sympy.functions import CeilDiv, FloorDiv, ModularIndexing
@@ -48,14 +38,22 @@ from torch.utils._triton import has_triton_package
from ...utils._sympy.symbol import free_symbol_is_type, prefix_str, symbol_is_type, SymT
from ...utils._sympy.value_ranges import ValueRanges
from .. import config, ir
from .. import config, ir, metrics
from ..codecache import code_hash, get_path, PyCodeCache
from ..runtime.benchmarking import benchmarker
from ..runtime.hints import TRITON_MAX_BLOCK
from ..runtime.hints import (
AutotuneHint,
DeviceProperties,
TRITON_MAX_BLOCK,
TRITON_MAX_RSPLIT,
)
from ..runtime.runtime_utils import get_max_y_grid, next_power_of_2
from ..runtime.triton_heuristics import (
cooperative_reduction_grid,
grid as default_grid_fn,
)
from ..scheduler import BaseSchedulerNode, FusedSchedulerNode, Scheduler, SchedulerNode
from ..utils import (
cache_on_self,
DelayReplaceLine,
get_bounds_index_expr,
get_fused_kernel_name,
@@ -172,10 +170,6 @@ class TritonSymbols:
def get_block_offset(cls, tree: IterationRanges) -> sympy.Symbol:
return cls.block_offsets[tree.symt]
@classmethod
def max_block_size(cls, tree: IterationRanges) -> int:
return TRITON_MAX_BLOCK[tree.prefix.upper()]
@dataclasses.dataclass
class IndexingOptions:
@@ -211,6 +205,7 @@ class BlockPtrOptions:
broadcast_shape: Sequence[sympy.Expr]
broadcasting_dims: List[bool]
final_shape: Sequence[sympy.Expr]
_boundary_check: Optional[List[int]] = None
@property
def shape(self) -> List[sympy.Expr]:
@@ -279,6 +274,7 @@ class BlockPtrOptions:
constant_offset: sympy.Expr,
range_trees: List[IterationRangesEntry],
mask_vars: OrderedSet[str],
get_max_block: Callable[[str], int],
) -> BlockPtrOptions:
"""Helper to create a BlockPtrOptions instance"""
@@ -345,7 +341,7 @@ class BlockPtrOptions:
# Need to expand rank by 1 to match rank when self.inside_reduction=True
final_shape.append(sympy.S.One)
return BlockPtrOptions(
result = BlockPtrOptions(
params=params,
constant_offset=V.graph.sizevars.lookup_precomputed_size(constant_offset),
order=list(reversed(range(len(params.shape)))),
@@ -354,6 +350,8 @@ class BlockPtrOptions:
broadcast_shape=broadcast_shape,
broadcasting_dims=broadcasting_dims,
)
result.compute_boundary_check(get_max_block)
return result
def replace_roffset(self, expr: sympy.Expr, replacement: sympy.Expr) -> sympy.Expr:
"""
@@ -391,19 +389,18 @@ class BlockPtrOptions:
]
return f"tl.make_block_ptr({', '.join(args)})"
@cache_on_self
def boundary_check(self) -> List[int]:
def compute_boundary_check(self, get_max_block: Callable[[str], int]) -> None:
"""List of indices to pass to tl.load(boundary_check=...)"""
sizevars = V.graph.sizevars
# Substitute maximum block sizes in shape expressions.
# This works in multiple_of checks because block sizes are powers of 2.
block_to_max: Dict[sympy.Expr, Any] = {
block_size: TRITON_MAX_BLOCK[prefix_str[symt].upper()]
block_size: get_max_block(prefix_str[symt])
for symt, block_size in TritonSymbols.block_sizes.items()
}
return [
self._boundary_check = [
idx
for idx in range(len(self.shape))
if (
@@ -421,6 +418,10 @@ class BlockPtrOptions:
)
]
def boundary_check(self):
assert self._boundary_check is not None
return self._boundary_check
def advance_roffset(self):
"""
Codegen string to pass to tl.advance(name, ...).
@@ -1361,6 +1362,14 @@ class CooperativeReductionWorkspaceCache:
return prior
@dataclasses.dataclass
class FixedTritonConfig:
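    # Wraps a launch config chosen at compile time, e.g. {"XBLOCK": 2, "RBLOCK": 128};
    # keys mirror what autotuning would otherwise pick (block sizes, num_warps, ...).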
config: Dict[str, int]
def __getitem__(self, item):
return self.config[item]
class TritonKernel(SIMDKernel):
overrides = TritonKernelOverrides # type: ignore[assignment]
helper_functions: HelperFunctions
@@ -1372,9 +1381,11 @@ class TritonKernel(SIMDKernel):
*groups,
min_elem_per_thread=0,
optimize_mask=True,
fixed_config: Optional[FixedTritonConfig] = None,
**kwargs,
) -> None:
self.optimize_mask: bool = optimize_mask
self.fixed_config = fixed_config
super().__init__(*groups, **kwargs)
self.post_loop_combine: IndentedBuffer = IndentedBuffer()
self.post_loop_store: IndentedBuffer = IndentedBuffer()
@@ -1410,8 +1421,10 @@
if tree.grid_dim is not None:
tree.grid_dim += 1
xnumel, rnumel = self.numels
self.semaphores_name = self.args.semaphores(xnumel)
sem_count, _ = self.numels
if self.fixed_config:
sem_count = CeilDiv(sem_count, self.fixed_config["XBLOCK"])
self.semaphores_name = self.args.semaphores(sem_count)
self.cooperative_reduction_workspace_cache = CooperativeReductionWorkspaceCache(
self.args
)
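This semaphore sizing is one concrete payoff of a fixed config: with XBLOCK known at compile time, one semaphore per x-dimension program suffices instead of the worst-case one per x index. A worked example with assumed sizes:

# Illustrative only; CeilDiv is the same helper this file imports.
import sympy
from torch.utils._sympy.functions import CeilDiv

xnumel = sympy.Integer(8192)  # assumed size
print(CeilDiv(xnumel, 2))     # XBLOCK=2 -> 4096 semaphore slots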
@@ -1458,11 +1471,11 @@
)
def want_no_x_dim(self):
return (
self.persistent_reduction
and len(self.numels) == 2
and V.choices.want_no_x_dim(self.features)
)
if self.persistent_reduction and len(self.numels) == 2:
if self.fixed_config:
return self.fixed_config["XBLOCK"] == 1
return V.choices.want_no_x_dim(self.features)
return False
@property
def assert_function(self) -> str:
@@ -1673,7 +1686,7 @@
# with n and m integers, then either numel is a multiple of XBLOCK, or numel
# is less than XBLOCK. (If numel is less than XBLOCK, we round up to 1 below.)
# 2. Numels are multiples of the maximum possible block size.
max_block = TritonSymbols.max_block_size(range_tree)
max_block = self.max_block(range_tree.prefix)
if any(
not sizevars.statically_known_multiple_of(numel, max_block)
and not sizevars.statically_known_power_of_2(numel)
@@ -1769,6 +1782,7 @@
constant_offset=offset,
range_trees=range_trees,
mask_vars=mask_vars,
get_max_block=self.max_block,
)
# Return a block pointer, if indexing matches the pattern.
@@ -1869,9 +1883,6 @@
index_str, "0" if lower else None, size_str, mask_str
)
indirect = self.is_indirect_indexing(expr) or any(
isinstance(m, TritonCSEVariable) for m in indexing.mask_vars
)
buffer = self.get_load_buffer(indexing)
self.cse.generate(buffer, line, assignment=False, dtype=torch.int32)
@@ -2430,6 +2441,11 @@
)
return result_mean, result_m2, result_weight
def max_rsplit(self):
if self.fixed_config:
return self.fixed_config["RSPLIT"]
return TRITON_MAX_RSPLIT
def codegen_cooperative_reduction_peer_combine(self, result_var, dtype):
"""
Generate code to save a [XBLOCK, RSPLIT] temporary workspace, where each thread block writes a different
@@ -2440,7 +2456,7 @@
mask = "xindex < xnumel" if xnumel != 1 and not self.no_x_dim else None
expand = "" if self.no_x_dim else "[None,:]"
nbytes = xnumel * dtype.itemsize * TRITON_MAX_RSPLIT
nbytes = xnumel * dtype.itemsize * self.max_rsplit()
ws_name, ws_offset = self.cooperative_reduction_workspace_cache.allocate(nbytes)
self.post_loop_combine.splice(
@@ -2887,7 +2903,9 @@
)
def _get_heuristic(self):
if self.cooperative_reduction:
if self.fixed_config:
return "fixed_config"
elif self.cooperative_reduction:
return "cooperative_reduction"
elif self.persistent_reduction:
assert self.inside_reduction
@@ -2962,8 +2980,6 @@
if not self.inside_reduction:
size_hints.pop()
heuristics = self._get_heuristic()
if name is None:
code.splice(gen_common_triton_imports())
device_type = V.graph.get_current_device_or_throw().type
@@ -3094,10 +3110,20 @@
code.writeline("")
code.splice(helper)
if self.inside_reduction:
if self.fixed_config:
heuristics_line = f"""
@triton_heuristics.{self._get_heuristic()}(
config={self.fixed_config.config!r},
filename=__file__,
triton_meta={triton_meta!r},
inductor_meta={inductor_meta!r}
)
@triton.jit
"""
elif self.inside_reduction:
reduction_hint = self.features.get_reduction_hint()
heuristics_line = f"""
@triton_heuristics.{heuristics}(
@triton_heuristics.{self._get_heuristic()}(
size_hints={size_hints!r},
reduction_hint={reduction_hint},
filename=__file__,
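For reference, with a fixed config of {"XBLOCK": 2, "RBLOCK": 128} the template above renders a preamble along these lines in the generated source (a sketch; triton_meta and inductor_meta abbreviated to ... here):

@triton_heuristics.fixed_config(
    config={'XBLOCK': 2, 'RBLOCK': 128},
    filename=__file__,
    triton_meta={...},
    inductor_meta={...}
)
@triton.jit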
@@ -3114,7 +3140,7 @@
else:
tile_hint = "tile_hint=TileHint.DEFAULT,"
heuristics_line = f"""
@triton_heuristics.{heuristics}(
@triton_heuristics.{self._get_heuristic()}(
size_hints={size_hints!r}, {tile_hint}
filename=__file__,
triton_meta={triton_meta!r},
@@ -3298,6 +3324,11 @@
return f"{pid}.to({self.index_dtype})"
return pid
def max_block(self, prefix):
if self.fixed_config:
return self.fixed_config[f"{prefix.upper()}BLOCK"]
return TRITON_MAX_BLOCK[prefix.upper()]
def _has_constant_mask(self, tree: IterationRangesRoot):
if not self.optimize_mask:
return False
@@ -3311,12 +3342,10 @@
elif tree.prefix == "x" and self.no_x_dim:
max_block = 1
else:
if tree.prefix.upper() not in TRITON_MAX_BLOCK:
return False
max_block = TRITON_MAX_BLOCK[tree.prefix.upper()]
max_block = self.max_block(tree.prefix)
if tree.prefix == "r" and self.cooperative_reduction:
max_block = max_block * TRITON_MAX_RSPLIT
max_block = max_block * self.max_rsplit()
# Optional optimization: if block divides numel exactly, we will
# never need to do a masked load to handle stragglers at the end.
@@ -3577,7 +3606,11 @@
# so taking the hit of non-coalesced loads is okay
if kernel_features.contains_op("sort"):
kernel_kwargs["override_persistent_reduction"] = True
kernel_kwargs["override_cooperative_reduction"] = False
kernel_kwargs = V.choices.triton_kernel_kwargs(
kernel_type, kernel_features, kernel_args, kernel_kwargs
)
kernel = kernel_type(*kernel_args, **kernel_kwargs)
return self.add_multi_kernel_choices(kernel, kernel_args, kernel_kwargs)

torch/_inductor/codegen/triton_split_scan.py

@@ -29,9 +29,11 @@ class TritonSplitScanKernel(TritonKernel):
self,
*groups,
pid_cache=None,
fixed_config=None,
**kwargs,
) -> None:
assert pid_cache is None, "not supported"
assert fixed_config is None, "not supported"
super().__init__(
*groups,
**kwargs,

torch/_inductor/runtime/hints.py

@@ -5,6 +5,7 @@ from enum import auto, Enum
from typing import Dict, List, Optional, Union
# The following maximums only apply to runtime autotuning; when using FixedTritonConfig one may see larger values
# NOTE: if these limits trip asserts, submit a PR to increase them
TRITON_MAX_BLOCK = {
"X": 4096,
@@ -94,6 +95,7 @@ class HeuristicType(Enum):
SPLIT_SCAN = auto()
TEMPLATE = auto()
USER_AUTOTUNE = auto()
FIXED = auto()
class AutotuneHint(Enum):

torch/_inductor/runtime/triton_heuristics.py

@@ -1854,29 +1854,42 @@ def template(num_stages, num_warps, triton_meta, filename=None, inductor_meta=No
)
def _pop_config_kwargs(config: Dict[str, Any]) -> Dict[str, Any]:
"""Extract triton.Config options that should become kwargs"""
popped = {}
for key in ("num_warps", "num_stages", "num_ctas", "maxnreg"):
val = config.pop(key, None)
if val is not None:
popped[key] = val
return popped
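# Illustrative example (not part of the diff): the helper splits launch
# options out of a mixed config dict, mutating its argument in place:
#   cfg = {"XBLOCK": 2, "RBLOCK": 128, "num_warps": 8, "num_stages": 3}
#   _pop_config_kwargs(cfg)  # -> {"num_warps": 8, "num_stages": 3}
#   cfg                      # -> {"XBLOCK": 2, "RBLOCK": 128}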
def fixed_config(config, filename, triton_meta, inductor_meta):
"""
Used when the configuration is already decided at compile time
"""
config = {**config}
return cached_autotune(
None,
[triton.Config(config, **_pop_config_kwargs(config))],
triton_meta=triton_meta,
inductor_meta=inductor_meta,
heuristic_type=HeuristicType.FIXED,
filename=filename,
)
def user_autotune(
configs, triton_meta, filename=None, inductor_meta=None, custom_kernel=False
):
"""
Compile a user defined triton kernel
"""
defaults = inspect.signature(triton.Config).parameters
default_num_stages = defaults["num_stages"].default
default_num_warps = defaults["num_warps"].default
if len(configs) == 0:
configs = [
triton.Config(
{}, num_stages=default_num_stages, num_warps=default_num_warps
)
]
configs = [triton.Config({})]
else:
configs = [
triton.Config(
c.get("kwargs", {}),
num_stages=c.get("num_stages", default_num_stages),
num_warps=c.get("num_warps", default_num_warps),
)
triton.Config(c.get("kwargs", {}), **_pop_config_kwargs({**c}))
for c in configs
]
return cached_autotune(