Update (base update)

[ghstack-poisoned]
Jason Ansel
2025-01-25 15:37:32 -08:00
parent c1f3dc86f9
commit 690c47a8ee
3 changed files with 217 additions and 111 deletions

test/inductor/test_cooperative_reductions.py

@@ -20,6 +20,30 @@ from torch.testing._internal.common_utils import (
from torch.testing._internal.inductor_utils import HAS_CUDA
class TestingHeuristics(InductorChoices):
def __init__(self, *, cooperative: bool, persistent: bool, cfg: dict[str, int]):
super().__init__()
self.cooperative = cooperative
self.persistent = persistent
self.cfg = cfg
self.call_count = 0
def triton_kernel_kwargs(
self,
kernel_cls: type[TritonKernel],
features: SIMDKernelFeatures,
groups: list[sympy.Expr],
kernel_kwargs: dict[str, Any],
) -> dict[str, Any]:
self.call_count += 1
return {
**kernel_kwargs,
"override_cooperative_reduction": self.cooperative,
"override_persistent_reduction": self.persistent,
"fixed_config": FixedTritonConfig(self.cfg),
}
@config.patch(
{
"triton.cooperative_reductions": True,
@@ -146,6 +170,19 @@ class MultiKernelCooperativeReductionTests(CooperativeReductionTests):
)
@instantiate_parametrized_tests
class TestFixedConfigs(TestCase):
def _check(self, fn, args, *, persistent=False, cooperative=True, cfg):
expected = fn(*args)
heuristic = TestingHeuristics(
persistent=persistent, cooperative=cooperative, cfg=cfg
)
with torch._inductor.virtualized.V.set_choices_handler(heuristic):
result, (source_code,) = run_and_get_code(
torch.compile(fn, fullgraph=True), *args
)
self.assertEqual(result, expected)
self.assertEqual(heuristic.call_count, 1)
self.assertIn("@triton_heuristics.fixed_config(", source_code)
@parametrize(
"persistent,cooperative,cfg",
[
@@ -157,70 +194,61 @@ class TestFixedConfigs(TestCase):
(False, True, {"XBLOCK": 2, "R0_BLOCK": 128, "RSPLIT": 16}),
(True, True, {"XBLOCK": 1, "RSPLIT": 16}),
(True, True, {"XBLOCK": 2, "RSPLIT": 16}),
(False, True, {"XBLOCK": 1, "R0_BLOCK": 128, "RSPLIT": 17}),
(False, True, {"XBLOCK": 2, "R0_BLOCK": 128, "RSPLIT": 17}),
(True, True, {"XBLOCK": 1, "RSPLIT": 17}),
(True, True, {"XBLOCK": 2, "RSPLIT": 17}),
],
)
def test_fixed_configs(self, persistent, cooperative, cfg):
class MyHeuristics(InductorChoices):
def triton_kernel_kwargs(
self,
kernel_cls: type[TritonKernel],
features: SIMDKernelFeatures,
groups: list[sympy.Expr],
kernel_kwargs: dict[str, Any],
) -> dict[str, Any]:
return {
**kernel_kwargs,
"override_cooperative_reduction": cooperative,
"override_persistent_reduction": persistent,
"fixed_config": FixedTritonConfig(cfg),
}
def fn(x):
return torch.softmax(x + 1, dim=-1) + x
args = [torch.randn(8, 8000, device="cuda")]
with torch._inductor.virtualized.V.set_choices_handler(MyHeuristics()):
expected = fn(*args)
fn = torch.compile(fn, fullgraph=True)
result, (source_code,) = run_and_get_code(fn, *args)
self.assertEqual(result, expected)
self.assertIn("@triton_heuristics.fixed_config(", source_code)
self._check(fn, args, persistent=persistent, cooperative=cooperative, cfg=cfg)
def test_fixed_config_with_larger_xblock_than_xnumel(self):
class MyHeuristics(InductorChoices):
def triton_kernel_kwargs(
self,
kernel_cls: type[TritonKernel],
features: SIMDKernelFeatures,
groups: list[sympy.Expr],
kernel_kwargs: dict[str, Any],
) -> dict[str, Any]:
return {
**kernel_kwargs,
"override_cooperative_reduction": True,
"override_persistent_reduction": True,
"fixed_config": FixedTritonConfig(
{"XBLOCK": 128, "RSPLIT": 32, "num_warps": 16, "num_stages": 1}
),
}
@parametrize(
"persistent,x,r,rsplit",
[
(False, 1, 8000, 17),
(False, 4, 8123, 33),
(False, 9, 8000, 17),
(False, 1, 8192, 33),
(False, 3, 8192, 17),
(True, 1, 7567, 17),
(True, 4, 8000, 17),
(True, 9, 8000, 37),
(True, 1, 8192, 17),
(True, 3, 8192, 40),
],
)
def test_welford_non_power_of_2_rsplit(self, persistent, x, r, rsplit):
def fn(x):
return torch.var_mean(x, dim=-1)
cfg = {"XBLOCK": 64, "RSPLIT": rsplit, "num_warps": 8}
if not persistent:
cfg["R0_BLOCK"] = 64
args = [torch.randn(x, r, device="cuda")]
self._check(fn, args, persistent=persistent, cfg=cfg)
@parametrize("persistent", [False, True])
@parametrize("rsplit", [32, 33])
def test_fixed_config_with_larger_xblock_than_xnumel(self, persistent, rsplit):
def fn(x, y):
return [
torch.any(x == y),
torch.all(x == y),
torch.any(x != y),
torch.all(x != y),
torch.any(x < y),
torch.all(x > y),
torch.mean(x + y),
]
cfg = {"XBLOCK": 128, "RSPLIT": rsplit, "num_warps": 16, "num_stages": 1}
if not persistent:
cfg["R0_BLOCK"] = 64
args = [torch.randn(1024, device="cuda") for _ in range(2)]
with torch._inductor.virtualized.V.set_choices_handler(MyHeuristics()):
expected = fn(*args)
fn = torch.compile(fn, fullgraph=True)
result, (source_code,) = run_and_get_code(fn, *args)
self.assertEqual(result, expected)
self.assertIn("@triton_heuristics.fixed_config(", source_code)
self._check(fn, args, persistent=persistent, cfg=cfg)
if __name__ == "__main__":
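
For reference, the shared handler pattern these tests now exercise can be used standalone. A minimal sketch (import paths and the example function are illustrative assumptions; the hook signature and the set_choices_handler context manager are the ones shown in the diff above):

import torch
from torch._inductor.choices import InductorChoices  # assumed import path
from torch._inductor.runtime.triton_heuristics import FixedTritonConfig  # assumed import path
from torch._inductor.virtualized import V

class ForceConfig(InductorChoices):
    def triton_kernel_kwargs(self, kernel_cls, features, groups, kernel_kwargs):
        # Bypass autotuning: pin the generated kernel to one hand-picked config.
        return {
            **kernel_kwargs,
            "override_cooperative_reduction": True,
            "override_persistent_reduction": True,
            "fixed_config": FixedTritonConfig({"XBLOCK": 1, "RSPLIT": 16}),
        }

with V.set_choices_handler(ForceConfig()):
    fn = torch.compile(lambda x: torch.softmax(x, -1), fullgraph=True)
    out = fn(torch.randn(8, 8000, device="cuda"))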

torch/_inductor/codegen/triton.py

@@ -93,9 +93,12 @@ from .triton_utils import (
if TYPE_CHECKING:
from types import ModuleType
from typing import TypeVar
from ..ir import IRNode
_T = TypeVar("_T")
log = logging.getLogger(__name__)
perf_hint_log = torch._logging.getArtifactLogger(__name__, "perf_hints")
schedule_log = torch._logging.getArtifactLogger(__name__, "schedule")
@@ -112,14 +115,14 @@ class OpDtypeSupport:
convert_outputs: dict[str, bool] = {}
@classmethod
def register_upcast(cls, func: Callable[..., str], convert_output: bool):
def register_upcast(cls, func: Callable[..., str], convert_output: bool) -> None:
op_name = func.__name__
cls.supported_dtypes[op_name] = OrderedSet([torch.float32, torch.float64])
cls.convert_outputs[op_name] = convert_output
@lru_cache(None)
def gen_attr_descriptor_import():
def gen_attr_descriptor_import() -> str:
"""
Import AttrsDescriptor if the Triton version is new enough to have this
class defined.
@@ -138,7 +141,7 @@ def gen_attr_descriptor_import():
@lru_cache(None)
def gen_common_triton_imports():
def gen_common_triton_imports() -> str:
imports = IndentedBuffer()
imports.splice(
"""
@@ -197,19 +200,19 @@ class IndexingOptions:
_has_rindex: bool
index: sympy.Expr
def has_mask(self):
def has_mask(self) -> bool:
return bool(self.mask_vars)
def has_indirect(self):
def has_indirect(self) -> bool:
return free_symbol_is_type(self.index, SymT.TMP)
def has_rindex(self):
def has_rindex(self) -> bool:
return self._has_rindex
def has_tmpmask(self):
def has_tmpmask(self) -> bool:
return "tmp" in self.mask_str
def has_rmask(self):
def has_rmask(self) -> bool:
return any(str(mask).startswith("r") for mask in self.mask_vars)
@@ -445,7 +448,7 @@ class BlockPtrOptions:
)
]
def boundary_check(self):
def boundary_check(self) -> list[int]:
assert self._boundary_check is not None
return self._boundary_check
@@ -468,7 +471,7 @@ class BlockPtrOptions:
]
return advance
def has_indirect(self):
def has_indirect(self) -> bool:
return False # block_ptr can't do indirect indexing
def has_rindex(self) -> bool:
@@ -477,19 +480,19 @@ class BlockPtrOptions:
for expr in self.block_shape
)
def has_rmask(self):
def has_rmask(self) -> bool:
return self.has_rindex()
def has_tmpmask(self):
def has_tmpmask(self) -> bool:
return False # block_ptr can't do indirect indexing
def has_mask(self):
def has_mask(self) -> bool:
return bool(self.boundary_check())
def triton_reshape(
value: str, old_shape: Sequence[sympy.Expr], new_shape: Sequence[sympy.Expr]
):
) -> str:
"""Workaround https://github.com/openai/triton/issues/2836"""
assert isinstance(old_shape, list) and isinstance(new_shape, list)
@@ -519,25 +522,25 @@ def triton_reshape(
# inconsistent with Python semantics (and consistent with C semantics). We
# must override all of these, or it is a potential silent correctness problem
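# e.g. Python gives (-7) % 3 == 2 and (-7) // 3 == -3, while C-style
# truncation gives -1 and -2; the overrides below emit code with Python
# semantics.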
class TritonPrinter(PythonPrinter):
def _print_TruncToInt(self, expr):
def _print_TruncToInt(self, expr: sympy.Expr) -> str:
assert len(expr.args) == 1
return (
f"libdevice.trunc({self._print(expr.args[0])}).to({V.kernel.index_dtype})"
)
def _print_Float(self, expr):
def _print_Float(self, expr: sympy.Expr) -> str:
if config.is_fbcode() and torch.version.hip:
ret = f"{expr}"
else:
ret = f"tl.full([], {expr}, tl.float64)"
return ret
def _print_ToFloat(self, expr):
def _print_ToFloat(self, expr: sympy.Expr) -> str:
assert len(expr.args) == 1
s = self.parenthesize(expr.args[0], PRECEDENCE["Atom"] - 0.5)
return f"{s}.to(tl.float64)"
def _print_PythonMod(self, expr):
def _print_PythonMod(self, expr: sympy.Expr) -> str:
quot, div = expr.args
if quot.is_nonnegative and div.is_nonnegative:
return self.stringify(expr.args, " % ", PRECEDENCE["Atom"] - 0.5)
@@ -545,7 +548,7 @@ class TritonPrinter(PythonPrinter):
div_s = self._print(div)
return f"triton_helpers.remainder_integer({quot_s}, {div_s})"
def _print_FloorDiv(self, expr):
def _print_FloorDiv(self, expr: sympy.Expr) -> str:
assert expr.is_integer
quot, div = expr.args
if quot.is_nonnegative and div.is_nonnegative:
@@ -556,42 +559,42 @@ class TritonPrinter(PythonPrinter):
# TODO: This is wrong, when lhs, rhs > 2**53, Python does a higher
# precision algorithm, which we would need to replicate here
def _print_IntTrueDiv(self, expr):
def _print_IntTrueDiv(self, expr: sympy.Expr) -> str:
return self.stringify(expr.args, " / ", PRECEDENCE["Atom"] - 0.5)
# NB: sympy.floor/ceiling produce integers, so we have to do the
# conversion to index dtype
def _print_floor(self, expr):
def _print_floor(self, expr: sympy.Expr) -> str:
assert len(expr.args) == 1
return (
f"libdevice.floor({self._print(expr.args[0])}).to({V.kernel.index_dtype})"
)
def _print_FloorToInt(self, expr):
def _print_FloorToInt(self, expr: sympy.Expr) -> str:
assert len(expr.args) == 1
return (
f"libdevice.floor({self._print(expr.args[0])}).to({V.kernel.index_dtype})"
)
def _print_ceiling(self, expr):
def _print_ceiling(self, expr: sympy.Expr) -> str:
assert len(expr.args) == 1
return f"libdevice.ceil({self._print(expr.args[0])}).to({V.kernel.index_dtype})"
def _print_CeilToInt(self, expr):
def _print_CeilToInt(self, expr: sympy.Expr) -> str:
assert len(expr.args) == 1
return f"libdevice.ceil({self._print(expr.args[0])}).to({V.kernel.index_dtype})"
def _helper_sqrt(self, expr):
def _helper_sqrt(self, expr: sympy.Expr) -> str:
return f"libdevice.sqrt(({self._print(expr)}).to(tl.float32))"
def _print_FloatPow(self, expr):
def _print_FloatPow(self, expr: sympy.Expr) -> str:
return (
f"libdevice.pow({self._print(expr.args[0])}, {self._print(expr.args[1])})"
)
_print_PowByNatural = _print_FloatPow
def _print_Where(self, expr):
def _print_Where(self, expr: sympy.Expr) -> str:
c = self.doprint(expr.args[0])
p = self.doprint(expr.args[1])
q = self.doprint(expr.args[2])
@@ -616,59 +619,59 @@ class TritonPrinter(PythonPrinter):
assert cmp in (">", "<"), f"Unexpected comparator: '{cmp}'"
return f"({a} * ({a} {cmp}= {b}) + {b} * ({b} {cmp} {a}))"
def _print_Min(self, expr):
def _print_Min(self, expr: sympy.Expr) -> str:
return self._print_min_max_helper(expr, "<")
def _print_Max(self, expr):
def _print_Max(self, expr: sympy.Expr) -> str:
return self._print_min_max_helper(expr, ">")
def _print_Abs(self, expr):
def _print_Abs(self, expr: sympy.Expr) -> str:
assert len(expr.args) == 1
return f"tl_math.abs({self._print(expr.args[0])})"
def _print_OpaqueUnaryFn_cos(self, expr):
def _print_OpaqueUnaryFn_cos(self, expr: sympy.Expr) -> str:
assert len(expr.args) == 1
return f"libdevice.cos(({self._print(expr.args[0])}).to(tl.float32))"
def _print_OpaqueUnaryFn_cosh(self, expr):
def _print_OpaqueUnaryFn_cosh(self, expr: sympy.Expr) -> str:
assert len(expr.args) == 1
return f"libdevice.cosh(({self._print(expr.args[0])}).to(tl.float32))"
def _print_OpaqueUnaryFn_acos(self, expr):
def _print_OpaqueUnaryFn_acos(self, expr: sympy.Expr) -> str:
assert len(expr.args) == 1
return f"libdevice.acos(({self._print(expr.args[0])}).to(tl.float32))"
def _print_OpaqueUnaryFn_sin(self, expr):
def _print_OpaqueUnaryFn_sin(self, expr: sympy.Expr) -> str:
assert len(expr.args) == 1
return f"libdevice.sin(({self._print(expr.args[0])}).to(tl.float32))"
def _print_OpaqueUnaryFn_sinh(self, expr):
def _print_OpaqueUnaryFn_sinh(self, expr: sympy.Expr) -> str:
assert len(expr.args) == 1
return f"libdevice.sinh(({self._print(expr.args[0])}).to(tl.float32))"
def _print_OpaqueUnaryFn_asin(self, expr):
def _print_OpaqueUnaryFn_asin(self, expr: sympy.Expr) -> str:
assert len(expr.args) == 1
return f"libdevice.asin(({self._print(expr.args[0])}).to(tl.float32))"
def _print_OpaqueUnaryFn_tan(self, expr):
def _print_OpaqueUnaryFn_tan(self, expr: sympy.Expr) -> str:
assert len(expr.args) == 1
return f"libdevice.tan(({self._print(expr.args[0])}).to(tl.float32))"
def _print_OpaqueUnaryFn_tanh(self, expr):
def _print_OpaqueUnaryFn_tanh(self, expr: sympy.Expr) -> str:
assert len(expr.args) == 1
return f"libdevice.tanh(({self._print(expr.args[0])}).to(tl.float32))"
def _print_OpaqueUnaryFn_atan(self, expr):
def _print_OpaqueUnaryFn_atan(self, expr: sympy.Expr) -> str:
assert len(expr.args) == 1
return f"libdevice.atan(({self._print(expr.args[0])}).to(tl.float32))"
def _print_RoundToInt(self, expr):
def _print_RoundToInt(self, expr: sympy.Expr) -> str:
assert len(expr.args) == 1
return (
f"libdevice.llrint({self._print(expr.args[0])}).to({V.kernel.index_dtype})"
)
def _print_RoundDecimal(self, expr):
def _print_RoundDecimal(self, expr: sympy.Expr) -> str:
assert len(expr.args) == 2
number, ndigits = expr.args
if number.is_integer:
@@ -740,7 +743,7 @@ class TritonCSEVariable(CSEVariable):
break
def maybe_upcast_float32(convert_output: bool = True):
def maybe_upcast_float32(convert_output: bool = True) -> Callable[[_T], _T]:
"""
Codegen helper to upcast arguments to float32, depending on the config and dtype.
This decorates tl.math/libdevice codegen functions.
@@ -796,7 +799,7 @@ def maybe_upcast_float32(convert_output: bool = True):
return wrapped
return decorator
return decorator # type: ignore[return-value]
class TritonOverrides(OpOverrides):
@@ -1568,6 +1571,9 @@ class TritonKernel(SIMDKernel):
self.codegen_range_tree()
if self.cooperative_reduction:
self.init_cooperative_reduction_mask()
def dtype_to_str(self, dtype: torch.dtype) -> str:
return triton_type(dtype)
@@ -1593,14 +1599,16 @@ class TritonKernel(SIMDKernel):
self.args
)
self.body.splice(
"""
"""\
RSPLIT_NEXT_POWER_OF_2: tl.constexpr = triton_helpers.constexpr_next_power_of_2(RSPLIT)
RSPLIT_IS_POWER_OF_2: tl.constexpr = RSPLIT == RSPLIT_NEXT_POWER_OF_2
HAS_RSPLIT: tl.constexpr = RSPLIT > 1
rsplit_id = tl.program_id(0)
num_rblocks = (rnumel + RBLOCK - 1) // RBLOCK
rsplit_chunk = (num_rblocks + RSPLIT - 1) // RSPLIT * RBLOCK
rsplit_start = rsplit_chunk * rsplit_id
rsplit_end = rsplit_chunk * (rsplit_id + 1)
""",
strip=True,
)
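# Worked example of the chunking above (hypothetical sizes): rnumel=8000,
# RBLOCK=64, RSPLIT=17 gives num_rblocks = 125 and
# rsplit_chunk = ceil(125/17) * 64 = 512, so program 15 covers [7680, 8192)
# (clamped to rnumel by the tl.where emitted below) and program 16 sees an
# empty range.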
if any(
not self._has_constant_mask(tree)
@@ -1611,6 +1619,27 @@ class TritonKernel(SIMDKernel):
"rsplit_end = tl.where(rsplit_end < rnumel, rsplit_end, rnumel)"
)
def init_cooperative_reduction_mask(self):
rsplit_arange = "tl.arange(0, RSPLIT_NEXT_POWER_OF_2)"
if not self.no_x_dim:
rsplit_arange = f"{rsplit_arange}[None, :]"
self.body.writeline(f"rsplit_arange = {rsplit_arange}")
if self._has_constant_xmask():
self.body.splice(
"""\
if RSPLIT_IS_POWER_OF_2:
rsplit_mask: tl.constexpr = None
else:
rsplit_mask = rsplit_arange < RSPLIT
"""
)
else:
assert not self.no_x_dim
self.body.writeline(
"rsplit_mask = xmask if RSPLIT_IS_POWER_OF_2 else ((rsplit_arange < RSPLIT) & xmask)"
)
def codegen_range_tree(self):
for tree in self.range_trees:
# reduction indexing goes inside a loop
@@ -1934,9 +1963,10 @@ class TritonKernel(SIMDKernel):
if isinstance(index, sympy.Integer):
expand_str = f"{copy_shape}.shape" if copy_shape else self.dense_size_str()
index_str = f"tl.full({expand_str}, {index_str}, tl.int32)"
mask_vars = OrderedSet()
mask_vars = dense_mask_vars
if self._load_mask:
mask_vars.add(self._load_mask)
self.filter_masks(mask_vars)
mask_str = " & ".join(sorted(map(str, mask_vars))) if mask_vars else "None"
return IndexingOptions(
index_str, mask_vars, mask_str, expand_str, has_rindex, index
@@ -2544,7 +2574,7 @@ class TritonKernel(SIMDKernel):
exit_stack = contextlib.ExitStack()
for buf in (self.post_loop_combine, self.post_loop_store):
# only do cooperative reduction combines if we have more than one thread block
buf.writeline("if RSPLIT > 1:")
buf.writeline("if HAS_RSPLIT:")
exit_stack.enter_context(buf.indent())
if reduction_type in ("argmax", "argmin"):
@@ -2728,7 +2758,6 @@ class TritonKernel(SIMDKernel):
"""
xnumel = self.numels["x"]
mask = "xindex < xnumel" if not self._has_constant_xmask() else None
expand = "" if self.no_x_dim else "[None,:]"
nbytes = xnumel * dtype.itemsize * self.max_rsplit()
ws_name, ws_offset = self.cooperative_reduction_workspace_cache.allocate(nbytes)
@@ -2741,8 +2770,8 @@ class TritonKernel(SIMDKernel):
strip=True,
)
self.post_loop_store.writeline(
f"{result_var}_peers = tl.load({result_var}_ws + (xindex * RSPLIT + tl.arange(0, RSPLIT){expand}), "
f"{mask}, eviction_policy='evict_first')"
f"{result_var}_peers = tl.load({result_var}_ws + (xindex * RSPLIT + rsplit_arange), "
f"rsplit_mask, eviction_policy='evict_first', other=triton_helpers.zero_other({result_var}_ws, rsplit_mask))"
)
return f"{result_var}_peers"
@@ -3095,7 +3124,7 @@ class TritonKernel(SIMDKernel):
sem_ptr = f"{self.semaphores_name} + tl.program_id(1)"
self.body.splice(
f"""
if RSPLIT > 1:
if HAS_RSPLIT:
triton_helpers.x_grid_barrier({sem_ptr})
""",
strip=True,
@@ -3539,9 +3568,11 @@ class TritonKernel(SIMDKernel):
code.writeline(f"{tree.prefix}numel = {int(simplified_tree_numel)}")
if tree.is_reduction and self.persistent_reduction:
val = self._get_persistent_RBLOCK(tree.numel)
if self.cooperative_reduction:
val = f"{val} // RSPLIT"
numel = self.kexpr(self.rename_indexing(tree.numel))
val = f"triton_helpers.constexpr_next_power_of_2(({numel} + RSPLIT - 1) // RSPLIT)"
else:
val = self._get_persistent_RBLOCK(tree.numel)
code.writeline(f"{tree.prefix.upper()}BLOCK: tl.constexpr = {val}")
if tree.prefix == "x" and self.no_x_dim:
@@ -3598,7 +3629,7 @@ class TritonKernel(SIMDKernel):
for ws in reversed(self.args.workspace_args):
wrapper.generate_workspace_deallocation(ws)
def codegen_nan_check(self):
def codegen_nan_check(self) -> None:
wrapper = V.graph.wrapper_code
_, call_args, arg_signatures, _ = self.args.python_argdefs()
for arg, arg_signature in zip(call_args, arg_signatures):
@@ -3613,7 +3644,7 @@ class TritonKernel(SIMDKernel):
line = f"assert not {arg}.isinf().any().item()"
wrapper.writeline(line)
def create_cse_var(self, *args, **kwargs):
def create_cse_var(self, *args, **kwargs) -> TritonCSEVariable:
return TritonCSEVariable(*args, **kwargs)
def codegen_iteration_ranges_entry(self, entry: IterationRangesEntry):
@ -3624,7 +3655,7 @@ class TritonKernel(SIMDKernel):
# lift non-reduction stores outside loop
self.body.writeline(line)
def iteration_ranges_ranges_code(self, entry):
def iteration_ranges_ranges_code(self, entry: IterationRangesRoot) -> str:
assert entry.tensor_dim is not None
size = self.indexing_size_str(entry.tensor_dim)
index_dtype = self.index_dtype
@@ -3637,13 +3668,15 @@ class TritonKernel(SIMDKernel):
suffix = f"{suffix} + rsplit_start"
return f"tl.arange(0, {entry.prefix.upper()}BLOCK){size}{suffix}"
def iteration_ranges_scalar_code(self, entry, value):
def iteration_ranges_scalar_code(
self, entry: IterationRangesRoot, value: Any
) -> str:
index_dtype = self.index_dtype
ndim = self.triton_tensor_ndim()
size = [1] * ndim
return f"tl.full({size}, {value}, {index_dtype})"
def iteration_ranges_get_pid(self, entry):
def iteration_ranges_get_pid(self, entry: IterationRangesRoot) -> str:
assert entry.grid_dim is not None
key = f"tl.program_id({entry.grid_dim})"
# y_grid has a limit, so express it in terms of y and z in case of overflow.
@@ -3663,12 +3696,12 @@ class TritonKernel(SIMDKernel):
return f"{pid}.to({self.index_dtype})"
return pid
def max_block(self, prefix):
def max_block(self, prefix: str) -> int:
if self.fixed_config:
return self.fixed_config[f"{prefix.upper()}BLOCK"]
return TRITON_MAX_BLOCK[prefix.upper()]
def _has_constant_mask(self, tree: IterationRangesRoot):
def _has_constant_mask(self, tree: IterationRangesRoot) -> bool:
if not self.optimize_mask:
return False
@@ -3710,12 +3743,12 @@ class TritonKernel(SIMDKernel):
return False
def _has_constant_xmask(self):
def _has_constant_xmask(self) -> bool:
xtree = self.range_trees[0]
assert xtree.prefix == "x"
return self._has_constant_mask(xtree)
def filter_masks(self, mask_vars):
def filter_masks(self, mask_vars: OrderedSet[str]) -> None:
for tree in self.range_trees:
if self._has_constant_mask(tree):
mask_vars.discard(f"{tree.prefix}mask")
@@ -3730,7 +3763,7 @@ class TritonKernel(SIMDKernel):
for symt in list(TritonSymbols.reduction_types)[: self.num_reduction_dims]
]
def codegen_reduction_numels(self, buffer) -> None:
def codegen_reduction_numels(self, buffer: IndentedBuffer) -> None:
"""
Generates code that flattens ND reduction numels, block sizes, etc. into 1D.
"""
@@ -3775,7 +3808,7 @@ class TritonKernel(SIMDKernel):
coeffs = self._get_reduction_index_coeffs()
return sympy_dot(coeffs, multi_inds)
def codegen_reduction_indices(self, buffer) -> None:
def codegen_reduction_indices(self, buffer: IndentedBuffer) -> None:
"""
Generates code that converts ND reduction indices into linear indices.
"""
@@ -3791,7 +3824,9 @@ class TritonKernel(SIMDKernel):
rindex = self._flatten_reduction_indices(rn_inds)
buffer.splice(f"rindex = {self.index_to_str(rindex)}")
def iteration_ranges_codegen_header(self, entry, code):
def iteration_ranges_codegen_header(
self, entry: IterationRangesRoot, code: IndentedBuffer
) -> None:
x = entry.prefix
if entry.is_loop:
code.writeline(f"{entry.name} = {x}offset + {x}base")

torch/_inductor/runtime/triton_helpers.py

@@ -1,10 +1,14 @@
# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
import warnings
from typing import Any, TypeVar
from .triton_compat import _log2, libdevice, math, tl, triton # noqa: F401
_T = TypeVar("_T")
def set_driver_to_cpu():
driver = triton.runtime.driver
if backend := triton.backends.backends.get("cpu", None):
@@ -627,3 +631,42 @@ def x_grid_barrier(sem):
# TODO(jansel): is this needed?
tl.debug_barrier()
def triton_builtin(f: _T) -> _T:
"""
Decorator to mark a function as a Triton built-in function. These functions
are evaluated at compile time.
Args:
f (function): The function to be marked as a Triton built-in.
Returns:
function: The same function, marked as a Triton built-in.
"""
f.__triton_builtin__ = True # type: ignore[attr-defined]
return f
@triton_builtin
def constexpr_next_power_of_2(
n: tl.constexpr, *, _builder: object = None
) -> tl.constexpr:
"""
A version of triton.next_power_of_2 that can be used within a kernel on constants.
"""
assert isinstance(n, tl.constexpr)
return tl.constexpr(triton.next_power_of_2(n.value))
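# e.g. constexpr_next_power_of_2(tl.constexpr(17)) evaluates to
# tl.constexpr(32) at kernel compile time; the generated preamble in
# triton.py uses this to derive RSPLIT_NEXT_POWER_OF_2 from a
# non-power-of-2 RSPLIT.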
@triton_builtin
def zero_other(ptr: Any, mask: Any, *, _builder: object = None) -> tl.constexpr:
"""
Work around triton compile error: `ValueError: `other` cannot be provided without `mask``
A compile-time check that returns either 0 or None depending on the value of mask.
"""
if isinstance(mask, tl.constexpr) and mask.value is None:
return tl.constexpr(None)
if isinstance(ptr, tl.core.tensor) and ptr.dtype.is_floating():
return tl.constexpr(0.0)
return tl.constexpr(0)
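# Example of the intended call site (mirrors the generated load in triton.py
# above). When rsplit_mask is the constexpr None, zero_other also returns
# None, so `other` is effectively not provided:
#   tl.load(ws_ptr + offs, rsplit_mask, eviction_policy='evict_first',
#           other=zero_other(ws_ptr, rsplit_mask))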