mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[inductor] Support fixed triton configs defined at compile time (#140217)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/140217 Approved by: https://github.com/shunting314 ghstack dependencies: #139585
This commit is contained in:
committed by
PyTorch MergeBot
parent
318eaa2be7
commit
2c6bd9f6f6
@ -1,7 +1,14 @@
|
||||
# Owner(s): ["module: inductor"]
|
||||
from typing import Any, Dict, List, Type
|
||||
|
||||
import sympy
|
||||
|
||||
import torch
|
||||
import torch._inductor
|
||||
from torch._inductor import config
|
||||
from torch._inductor.choices import InductorChoices
|
||||
from torch._inductor.codegen.simd_kernel_features import SIMDKernelFeatures
|
||||
from torch._inductor.codegen.triton import FixedTritonConfig, TritonKernel
|
||||
from torch._inductor.test_case import TestCase
|
||||
from torch._inductor.utils import run_and_get_code
|
||||
from torch.testing._internal.common_utils import (
|
||||
@ -124,6 +131,54 @@ class MultiKernelCooperativeReductionTests(CooperativeReductionTests):
|
||||
pass
|
||||
|
||||
|
||||
@config.patch(
|
||||
{
|
||||
"triton.cooperative_reductions": True,
|
||||
}
|
||||
)
|
||||
@instantiate_parametrized_tests
|
||||
class TestFixedConfigs(TestCase):
|
||||
@parametrize(
|
||||
"persistent,cooperative,cfg",
|
||||
[
|
||||
(False, False, {"XBLOCK": 1, "RBLOCK": 128}),
|
||||
(False, False, {"XBLOCK": 2, "RBLOCK": 128}),
|
||||
(True, False, {"XBLOCK": 1}),
|
||||
(True, False, {"XBLOCK": 2}),
|
||||
(False, True, {"XBLOCK": 1, "RBLOCK": 128, "RSPLIT": 16}),
|
||||
(False, True, {"XBLOCK": 2, "RBLOCK": 128, "RSPLIT": 16}),
|
||||
(True, True, {"XBLOCK": 1, "RSPLIT": 16}),
|
||||
(True, True, {"XBLOCK": 2, "RSPLIT": 16}),
|
||||
],
|
||||
)
|
||||
def test_fixed_configs(self, persistent, cooperative, cfg):
|
||||
class MyHeuristics(InductorChoices):
|
||||
def triton_kernel_kwargs(
|
||||
self,
|
||||
kernel_cls: Type[TritonKernel],
|
||||
features: SIMDKernelFeatures,
|
||||
groups: List[sympy.Expr],
|
||||
kernel_kwargs: Dict[str, Any],
|
||||
) -> Dict[str, Any]:
|
||||
return {
|
||||
**kernel_kwargs,
|
||||
"override_cooperative_reduction": cooperative,
|
||||
"override_persistent_reduction": persistent,
|
||||
"fixed_config": FixedTritonConfig(cfg),
|
||||
}
|
||||
|
||||
def fn(x):
|
||||
return torch.softmax(x + 1, dim=-1) + x
|
||||
|
||||
args = [torch.randn(8, 8000, device="cuda")]
|
||||
with torch._inductor.virtualized.V.set_choices_handler(MyHeuristics()):
|
||||
expected = fn(*args)
|
||||
fn = torch.compile(fn, fullgraph=True)
|
||||
result, (source_code,) = run_and_get_code(fn, *args)
|
||||
self.assertEqual(result, expected)
|
||||
self.assertIn("@triton_heuristics.fixed_config(", source_code)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from torch._dynamo.test_case import run_tests
|
||||
|
||||
|
@ -1,6 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import Any, Dict, List, Type, TYPE_CHECKING
|
||||
|
||||
from . import config
|
||||
from .runtime.hints import ReductionHint
|
||||
@ -8,7 +8,10 @@ from .virtualized import V
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import sympy
|
||||
|
||||
from .codegen.simd_kernel_features import SIMDKernelFeatures
|
||||
from .codegen.triton import TritonKernel
|
||||
|
||||
|
||||
class InductorChoices:
|
||||
@ -24,6 +27,16 @@ class InductorChoices:
|
||||
torch._inductor.virtualized.V.set_choices_handler(MyHeuristics())
|
||||
"""
|
||||
|
||||
def triton_kernel_kwargs(
|
||||
self,
|
||||
kernel_cls: Type[TritonKernel],
|
||||
features: SIMDKernelFeatures,
|
||||
groups: List[sympy.Expr],
|
||||
kernel_kwargs: Dict[str, Any],
|
||||
) -> Dict[str, Any]:
|
||||
"""Hook to change the kwargs passed to TritonKernel, used to apply fixed configurations"""
|
||||
return kernel_kwargs
|
||||
|
||||
@staticmethod
|
||||
def should_use_cooperative_reduction(features: SIMDKernelFeatures) -> bool:
|
||||
"""Heuristic to decide if a cooperative reduction should be used."""
|
||||
|
@ -29,18 +29,8 @@ from typing import (
|
||||
import sympy
|
||||
|
||||
import torch
|
||||
import torch._inductor.metrics as metrics
|
||||
import torch._logging
|
||||
from torch._dynamo.utils import identity, preserve_rng_state
|
||||
from torch._inductor.runtime.hints import (
|
||||
AutotuneHint,
|
||||
DeviceProperties,
|
||||
TRITON_MAX_RSPLIT,
|
||||
)
|
||||
from torch._inductor.runtime.triton_heuristics import (
|
||||
cooperative_reduction_grid,
|
||||
grid as default_grid_fn,
|
||||
)
|
||||
from torch._prims_common import is_integer_dtype
|
||||
from torch.utils._ordered_set import OrderedSet
|
||||
from torch.utils._sympy.functions import CeilDiv, FloorDiv, ModularIndexing
|
||||
@ -48,14 +38,22 @@ from torch.utils._triton import has_triton_package
|
||||
|
||||
from ...utils._sympy.symbol import free_symbol_is_type, prefix_str, symbol_is_type, SymT
|
||||
from ...utils._sympy.value_ranges import ValueRanges
|
||||
from .. import config, ir
|
||||
from .. import config, ir, metrics
|
||||
from ..codecache import code_hash, get_path, PyCodeCache
|
||||
from ..runtime.benchmarking import benchmarker
|
||||
from ..runtime.hints import TRITON_MAX_BLOCK
|
||||
from ..runtime.hints import (
|
||||
AutotuneHint,
|
||||
DeviceProperties,
|
||||
TRITON_MAX_BLOCK,
|
||||
TRITON_MAX_RSPLIT,
|
||||
)
|
||||
from ..runtime.runtime_utils import get_max_y_grid, next_power_of_2
|
||||
from ..runtime.triton_heuristics import (
|
||||
cooperative_reduction_grid,
|
||||
grid as default_grid_fn,
|
||||
)
|
||||
from ..scheduler import BaseSchedulerNode, FusedSchedulerNode, Scheduler, SchedulerNode
|
||||
from ..utils import (
|
||||
cache_on_self,
|
||||
DelayReplaceLine,
|
||||
get_bounds_index_expr,
|
||||
get_fused_kernel_name,
|
||||
@ -172,10 +170,6 @@ class TritonSymbols:
|
||||
def get_block_offset(cls, tree: IterationRanges) -> sympy.Symbol:
|
||||
return cls.block_offsets[tree.symt]
|
||||
|
||||
@classmethod
|
||||
def max_block_size(cls, tree: IterationRanges) -> int:
|
||||
return TRITON_MAX_BLOCK[tree.prefix.upper()]
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class IndexingOptions:
|
||||
@ -211,6 +205,7 @@ class BlockPtrOptions:
|
||||
broadcast_shape: Sequence[sympy.Expr]
|
||||
broadcasting_dims: List[bool]
|
||||
final_shape: Sequence[sympy.Expr]
|
||||
_boundary_check: Optional[List[int]] = None
|
||||
|
||||
@property
|
||||
def shape(self) -> List[sympy.Expr]:
|
||||
@ -279,6 +274,7 @@ class BlockPtrOptions:
|
||||
constant_offset: sympy.Expr,
|
||||
range_trees: List[IterationRangesEntry],
|
||||
mask_vars: OrderedSet[str],
|
||||
get_max_block: Callable[[str], int],
|
||||
) -> BlockPtrOptions:
|
||||
"""Helper to create a BlockPtrOptions instance"""
|
||||
|
||||
@ -345,7 +341,7 @@ class BlockPtrOptions:
|
||||
# Need to expand rank by 1 to match rank when self.inside_reduction=True
|
||||
final_shape.append(sympy.S.One)
|
||||
|
||||
return BlockPtrOptions(
|
||||
result = BlockPtrOptions(
|
||||
params=params,
|
||||
constant_offset=V.graph.sizevars.lookup_precomputed_size(constant_offset),
|
||||
order=list(reversed(range(len(params.shape)))),
|
||||
@ -354,6 +350,8 @@ class BlockPtrOptions:
|
||||
broadcast_shape=broadcast_shape,
|
||||
broadcasting_dims=broadcasting_dims,
|
||||
)
|
||||
result.compute_boundary_check(get_max_block)
|
||||
return result
|
||||
|
||||
def replace_roffset(self, expr: sympy.Expr, replacement: sympy.Expr) -> sympy.Expr:
|
||||
"""
|
||||
@ -391,19 +389,18 @@ class BlockPtrOptions:
|
||||
]
|
||||
return f"tl.make_block_ptr({', '.join(args)})"
|
||||
|
||||
@cache_on_self
|
||||
def boundary_check(self) -> List[int]:
|
||||
def compute_boundary_check(self, get_max_block: Callable[[str], int]) -> None:
|
||||
"""List of indices to pass to tl.load(boundary_check=...)"""
|
||||
sizevars = V.graph.sizevars
|
||||
|
||||
# Substitute maximum block sizes in shape expressions.
|
||||
# This works in multiple_of checks because block sizes are powers of 2.
|
||||
block_to_max: Dict[sympy.Expr, Any] = {
|
||||
block_size: TRITON_MAX_BLOCK[prefix_str[symt].upper()]
|
||||
block_size: get_max_block(prefix_str[symt])
|
||||
for symt, block_size in TritonSymbols.block_sizes.items()
|
||||
}
|
||||
|
||||
return [
|
||||
self._boundary_check = [
|
||||
idx
|
||||
for idx in range(len(self.shape))
|
||||
if (
|
||||
@ -421,6 +418,10 @@ class BlockPtrOptions:
|
||||
)
|
||||
]
|
||||
|
||||
def boundary_check(self):
|
||||
assert self._boundary_check is not None
|
||||
return self._boundary_check
|
||||
|
||||
def advance_roffset(self):
|
||||
"""
|
||||
Codegen string to pass to tl.advance(name, ...).
|
||||
@ -1361,6 +1362,14 @@ class CooperativeReductionWorkspaceCache:
|
||||
return prior
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class FixedTritonConfig:
|
||||
config: Dict[str, int]
|
||||
|
||||
def __getitem__(self, item):
|
||||
return self.config[item]
|
||||
|
||||
|
||||
class TritonKernel(SIMDKernel):
|
||||
overrides = TritonKernelOverrides # type: ignore[assignment]
|
||||
helper_functions: HelperFunctions
|
||||
@ -1372,9 +1381,11 @@ class TritonKernel(SIMDKernel):
|
||||
*groups,
|
||||
min_elem_per_thread=0,
|
||||
optimize_mask=True,
|
||||
fixed_config: Optional[FixedTritonConfig] = None,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
self.optimize_mask: bool = optimize_mask
|
||||
self.fixed_config = fixed_config
|
||||
super().__init__(*groups, **kwargs)
|
||||
self.post_loop_combine: IndentedBuffer = IndentedBuffer()
|
||||
self.post_loop_store: IndentedBuffer = IndentedBuffer()
|
||||
@ -1410,8 +1421,10 @@ class TritonKernel(SIMDKernel):
|
||||
if tree.grid_dim is not None:
|
||||
tree.grid_dim += 1
|
||||
|
||||
xnumel, rnumel = self.numels
|
||||
self.semaphores_name = self.args.semaphores(xnumel)
|
||||
sem_count, _ = self.numels
|
||||
if self.fixed_config:
|
||||
sem_count = CeilDiv(sem_count, self.fixed_config["XBLOCK"])
|
||||
self.semaphores_name = self.args.semaphores(sem_count)
|
||||
self.cooperative_reduction_workspace_cache = CooperativeReductionWorkspaceCache(
|
||||
self.args
|
||||
)
|
||||
@ -1458,11 +1471,11 @@ class TritonKernel(SIMDKernel):
|
||||
)
|
||||
|
||||
def want_no_x_dim(self):
|
||||
return (
|
||||
self.persistent_reduction
|
||||
and len(self.numels) == 2
|
||||
and V.choices.want_no_x_dim(self.features)
|
||||
)
|
||||
if self.persistent_reduction and len(self.numels) == 2:
|
||||
if self.fixed_config:
|
||||
return self.fixed_config["XBLOCK"] == 1
|
||||
return V.choices.want_no_x_dim(self.features)
|
||||
return False
|
||||
|
||||
@property
|
||||
def assert_function(self) -> str:
|
||||
@ -1673,7 +1686,7 @@ class TritonKernel(SIMDKernel):
|
||||
# with n and m integers, then either numel is a multiple of XBLOCK, or numel
|
||||
# is less than XBLOCK. (If numel is less than XBLOCK, we round up to 1 below.)
|
||||
# 2. Numels are multiples of the maximum possible block size.
|
||||
max_block = TritonSymbols.max_block_size(range_tree)
|
||||
max_block = self.max_block(range_tree.prefix)
|
||||
if any(
|
||||
not sizevars.statically_known_multiple_of(numel, max_block)
|
||||
and not sizevars.statically_known_power_of_2(numel)
|
||||
@ -1769,6 +1782,7 @@ class TritonKernel(SIMDKernel):
|
||||
constant_offset=offset,
|
||||
range_trees=range_trees,
|
||||
mask_vars=mask_vars,
|
||||
get_max_block=self.max_block,
|
||||
)
|
||||
|
||||
# Return a block pointer, if indexing matches the pattern.
|
||||
@ -1869,9 +1883,6 @@ class TritonKernel(SIMDKernel):
|
||||
index_str, "0" if lower else None, size_str, mask_str
|
||||
)
|
||||
|
||||
indirect = self.is_indirect_indexing(expr) or any(
|
||||
isinstance(m, TritonCSEVariable) for m in indexing.mask_vars
|
||||
)
|
||||
buffer = self.get_load_buffer(indexing)
|
||||
self.cse.generate(buffer, line, assignment=False, dtype=torch.int32)
|
||||
|
||||
@ -2430,6 +2441,11 @@ class TritonKernel(SIMDKernel):
|
||||
)
|
||||
return result_mean, result_m2, result_weight
|
||||
|
||||
def max_rsplit(self):
|
||||
if self.fixed_config:
|
||||
return self.fixed_config["RSPLIT"]
|
||||
return TRITON_MAX_RSPLIT
|
||||
|
||||
def codegen_cooperative_reduction_peer_combine(self, result_var, dtype):
|
||||
"""
|
||||
Generate code to save a [XBLOCK, RSPLIT] temporary workspace, where each thread block writes a different
|
||||
@ -2440,7 +2456,7 @@ class TritonKernel(SIMDKernel):
|
||||
mask = "xindex < xnumel" if xnumel != 1 and not self.no_x_dim else None
|
||||
expand = "" if self.no_x_dim else "[None,:]"
|
||||
|
||||
nbytes = xnumel * dtype.itemsize * TRITON_MAX_RSPLIT
|
||||
nbytes = xnumel * dtype.itemsize * self.max_rsplit()
|
||||
ws_name, ws_offset = self.cooperative_reduction_workspace_cache.allocate(nbytes)
|
||||
|
||||
self.post_loop_combine.splice(
|
||||
@ -2887,7 +2903,9 @@ class TritonKernel(SIMDKernel):
|
||||
)
|
||||
|
||||
def _get_heuristic(self):
|
||||
if self.cooperative_reduction:
|
||||
if self.fixed_config:
|
||||
return "fixed_config"
|
||||
elif self.cooperative_reduction:
|
||||
return "cooperative_reduction"
|
||||
elif self.persistent_reduction:
|
||||
assert self.inside_reduction
|
||||
@ -2962,8 +2980,6 @@ class TritonKernel(SIMDKernel):
|
||||
if not self.inside_reduction:
|
||||
size_hints.pop()
|
||||
|
||||
heuristics = self._get_heuristic()
|
||||
|
||||
if name is None:
|
||||
code.splice(gen_common_triton_imports())
|
||||
device_type = V.graph.get_current_device_or_throw().type
|
||||
@ -3094,10 +3110,20 @@ class TritonKernel(SIMDKernel):
|
||||
code.writeline("")
|
||||
code.splice(helper)
|
||||
|
||||
if self.inside_reduction:
|
||||
if self.fixed_config:
|
||||
heuristics_line = f"""
|
||||
@triton_heuristics.{self._get_heuristic()}(
|
||||
config={self.fixed_config.config!r},
|
||||
filename=__file__,
|
||||
triton_meta={triton_meta!r},
|
||||
inductor_meta={inductor_meta!r}
|
||||
)
|
||||
@triton.jit
|
||||
"""
|
||||
elif self.inside_reduction:
|
||||
reduction_hint = self.features.get_reduction_hint()
|
||||
heuristics_line = f"""
|
||||
@triton_heuristics.{heuristics}(
|
||||
@triton_heuristics.{self._get_heuristic()}(
|
||||
size_hints={size_hints!r},
|
||||
reduction_hint={reduction_hint},
|
||||
filename=__file__,
|
||||
@ -3114,7 +3140,7 @@ class TritonKernel(SIMDKernel):
|
||||
else:
|
||||
tile_hint = "tile_hint=TileHint.DEFAULT,"
|
||||
heuristics_line = f"""
|
||||
@triton_heuristics.{heuristics}(
|
||||
@triton_heuristics.{self._get_heuristic()}(
|
||||
size_hints={size_hints!r}, {tile_hint}
|
||||
filename=__file__,
|
||||
triton_meta={triton_meta!r},
|
||||
@ -3298,6 +3324,11 @@ class TritonKernel(SIMDKernel):
|
||||
return f"{pid}.to({self.index_dtype})"
|
||||
return pid
|
||||
|
||||
def max_block(self, prefix):
|
||||
if self.fixed_config:
|
||||
return self.fixed_config[f"{prefix.upper()}BLOCK"]
|
||||
return TRITON_MAX_BLOCK[prefix.upper()]
|
||||
|
||||
def _has_constant_mask(self, tree: IterationRangesRoot):
|
||||
if not self.optimize_mask:
|
||||
return False
|
||||
@ -3311,12 +3342,10 @@ class TritonKernel(SIMDKernel):
|
||||
elif tree.prefix == "x" and self.no_x_dim:
|
||||
max_block = 1
|
||||
else:
|
||||
if tree.prefix.upper() not in TRITON_MAX_BLOCK:
|
||||
return False
|
||||
max_block = TRITON_MAX_BLOCK[tree.prefix.upper()]
|
||||
max_block = self.max_block(tree.prefix)
|
||||
|
||||
if tree.prefix == "r" and self.cooperative_reduction:
|
||||
max_block = max_block * TRITON_MAX_RSPLIT
|
||||
max_block = max_block * self.max_rsplit()
|
||||
|
||||
# Optional optimization: if block divides numel exactly, we will
|
||||
# never need to do a masked load to handle stragglers at the end.
|
||||
@ -3577,7 +3606,11 @@ class TritonScheduling(SIMDScheduling):
|
||||
# so taking the hit of non-coalesced loads is okay
|
||||
if kernel_features.contains_op("sort"):
|
||||
kernel_kwargs["override_persistent_reduction"] = True
|
||||
kernel_kwargs["override_cooperative_reduction"] = False
|
||||
|
||||
kernel_kwargs = V.choices.triton_kernel_kwargs(
|
||||
kernel_type, kernel_features, kernel_args, kernel_kwargs
|
||||
)
|
||||
kernel = kernel_type(*kernel_args, **kernel_kwargs)
|
||||
return self.add_multi_kernel_choices(kernel, kernel_args, kernel_kwargs)
|
||||
|
||||
|
@ -29,9 +29,11 @@ class TritonSplitScanKernel(TritonKernel):
|
||||
self,
|
||||
*groups,
|
||||
pid_cache=None,
|
||||
fixed_config=None,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
assert pid_cache is None, "not supported"
|
||||
assert fixed_config is None, "not supported"
|
||||
super().__init__(
|
||||
*groups,
|
||||
**kwargs,
|
||||
|
@ -5,6 +5,7 @@ from enum import auto, Enum
|
||||
from typing import Dict, List, Optional, Union
|
||||
|
||||
|
||||
# The following maximums only apply to runtime autotuning, when using FixedTritonConfig one may see larger values
|
||||
# NOTE: if these fail asserts submit a PR to increase them
|
||||
TRITON_MAX_BLOCK = {
|
||||
"X": 4096,
|
||||
@ -94,6 +95,7 @@ class HeuristicType(Enum):
|
||||
SPLIT_SCAN = auto()
|
||||
TEMPLATE = auto()
|
||||
USER_AUTOTUNE = auto()
|
||||
FIXED = auto()
|
||||
|
||||
|
||||
class AutotuneHint(Enum):
|
||||
|
@ -1854,29 +1854,42 @@ def template(num_stages, num_warps, triton_meta, filename=None, inductor_meta=No
|
||||
)
|
||||
|
||||
|
||||
def _pop_config_kwargs(config: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Extract triton.Config options that should become kwargs"""
|
||||
popped = {}
|
||||
for key in ("num_warps", "num_stages", "num_ctas", "maxnreg"):
|
||||
val = config.pop(key, None)
|
||||
if val is not None:
|
||||
popped[key] = val
|
||||
return popped
|
||||
|
||||
|
||||
def fixed_config(config, filename, triton_meta, inductor_meta):
|
||||
"""
|
||||
Used when the configuration is already decided at compile time
|
||||
"""
|
||||
config = {**config}
|
||||
return cached_autotune(
|
||||
None,
|
||||
[triton.Config(config, **_pop_config_kwargs(config))],
|
||||
triton_meta=triton_meta,
|
||||
inductor_meta=inductor_meta,
|
||||
heuristic_type=HeuristicType.FIXED,
|
||||
filename=filename,
|
||||
)
|
||||
|
||||
|
||||
def user_autotune(
|
||||
configs, triton_meta, filename=None, inductor_meta=None, custom_kernel=False
|
||||
):
|
||||
"""
|
||||
Compile a user defined triton kernel
|
||||
"""
|
||||
defaults = inspect.signature(triton.Config).parameters
|
||||
default_num_stages = defaults["num_stages"].default
|
||||
default_num_warps = defaults["num_warps"].default
|
||||
|
||||
if len(configs) == 0:
|
||||
configs = [
|
||||
triton.Config(
|
||||
{}, num_stages=default_num_stages, num_warps=default_num_warps
|
||||
)
|
||||
]
|
||||
configs = [triton.Config({})]
|
||||
else:
|
||||
configs = [
|
||||
triton.Config(
|
||||
c.get("kwargs", {}),
|
||||
num_stages=c.get("num_stages", default_num_stages),
|
||||
num_warps=c.get("num_warps", default_num_warps),
|
||||
)
|
||||
triton.Config(c.get("kwargs", {}), **_pop_config_kwargs({**c}))
|
||||
for c in configs
|
||||
]
|
||||
return cached_autotune(
|
||||
|
Reference in New Issue
Block a user