mirror of https://github.com/pytorch/pytorch.git
synced 2025-11-01 22:14:53 +08:00
Update (base update)
[ghstack-poisoned]
@@ -20,6 +20,30 @@ from torch.testing._internal.common_utils import (
 from torch.testing._internal.inductor_utils import HAS_CUDA


+class TestingHeuristics(InductorChoices):
+    def __init__(self, *, cooperative: bool, persistent: bool, cfg: dict[str, int]):
+        super().__init__()
+        self.cooperative = cooperative
+        self.persistent = persistent
+        self.cfg = cfg
+        self.call_count = 0
+
+    def triton_kernel_kwargs(
+        self,
+        kernel_cls: type[TritonKernel],
+        features: SIMDKernelFeatures,
+        groups: list[sympy.Expr],
+        kernel_kwargs: dict[str, Any],
+    ) -> dict[str, Any]:
+        self.call_count += 1
+        return {
+            **kernel_kwargs,
+            "override_cooperative_reduction": self.cooperative,
+            "override_persistent_reduction": self.persistent,
+            "fixed_config": FixedTritonConfig(self.cfg),
+        }
+
+
 @config.patch(
     {
         "triton.cooperative_reductions": True,
@@ -146,6 +170,19 @@ class MultiKernelCooperativeReductionTests(CooperativeReductionTests):
     )
 @instantiate_parametrized_tests
 class TestFixedConfigs(TestCase):
+    def _check(self, fn, args, *, persistent=False, cooperative=True, cfg):
+        expected = fn(*args)
+        heuristic = TestingHeuristics(
+            persistent=persistent, cooperative=cooperative, cfg=cfg
+        )
+        with torch._inductor.virtualized.V.set_choices_handler(heuristic):
+            result, (source_code,) = run_and_get_code(
+                torch.compile(fn, fullgraph=True), *args
+            )
+        self.assertEqual(result, expected)
+        self.assertEqual(heuristic.call_count, 1)
+        self.assertIn("@triton_heuristics.fixed_config(", source_code)
+
     @parametrize(
         "persistent,cooperative,cfg",
         [
@@ -157,70 +194,61 @@ class TestFixedConfigs(TestCase):
             (False, True, {"XBLOCK": 2, "R0_BLOCK": 128, "RSPLIT": 16}),
             (True, True, {"XBLOCK": 1, "RSPLIT": 16}),
             (True, True, {"XBLOCK": 2, "RSPLIT": 16}),
+            (False, True, {"XBLOCK": 1, "R0_BLOCK": 128, "RSPLIT": 17}),
+            (False, True, {"XBLOCK": 2, "R0_BLOCK": 128, "RSPLIT": 17}),
+            (True, True, {"XBLOCK": 1, "RSPLIT": 17}),
+            (True, True, {"XBLOCK": 2, "RSPLIT": 17}),
         ],
     )
     def test_fixed_configs(self, persistent, cooperative, cfg):
-        class MyHeuristics(InductorChoices):
-            def triton_kernel_kwargs(
-                self,
-                kernel_cls: type[TritonKernel],
-                features: SIMDKernelFeatures,
-                groups: list[sympy.Expr],
-                kernel_kwargs: dict[str, Any],
-            ) -> dict[str, Any]:
-                return {
-                    **kernel_kwargs,
-                    "override_cooperative_reduction": cooperative,
-                    "override_persistent_reduction": persistent,
-                    "fixed_config": FixedTritonConfig(cfg),
-                }
-
         def fn(x):
             return torch.softmax(x + 1, dim=-1) + x

         args = [torch.randn(8, 8000, device="cuda")]
-        with torch._inductor.virtualized.V.set_choices_handler(MyHeuristics()):
-            expected = fn(*args)
-            fn = torch.compile(fn, fullgraph=True)
-            result, (source_code,) = run_and_get_code(fn, *args)
-        self.assertEqual(result, expected)
-        self.assertIn("@triton_heuristics.fixed_config(", source_code)
+        self._check(fn, args, persistent=persistent, cooperative=cooperative, cfg=cfg)

-    def test_fixed_config_with_larger_xblock_than_xnumel(self):
-        class MyHeuristics(InductorChoices):
-            def triton_kernel_kwargs(
-                self,
-                kernel_cls: type[TritonKernel],
-                features: SIMDKernelFeatures,
-                groups: list[sympy.Expr],
-                kernel_kwargs: dict[str, Any],
-            ) -> dict[str, Any]:
-                return {
-                    **kernel_kwargs,
-                    "override_cooperative_reduction": True,
-                    "override_persistent_reduction": True,
-                    "fixed_config": FixedTritonConfig(
-                        {"XBLOCK": 128, "RSPLIT": 32, "num_warps": 16, "num_stages": 1}
-                    ),
-                }
+    @parametrize(
+        "persistent,x,r,rsplit",
+        [
+            (False, 1, 8000, 17),
+            (False, 4, 8123, 33),
+            (False, 9, 8000, 17),
+            (False, 1, 8192, 33),
+            (False, 3, 8192, 17),
+            (True, 1, 7567, 17),
+            (True, 4, 8000, 17),
+            (True, 9, 8000, 37),
+            (True, 1, 8192, 17),
+            (True, 3, 8192, 40),
+        ],
+    )
+    def test_welford_non_power_of_2_rsplit(self, persistent, x, r, rsplit):
+        def fn(x):
+            return torch.var_mean(x, dim=-1)
+
+        cfg = {"XBLOCK": 64, "RSPLIT": rsplit, "num_warps": 8}
+        if not persistent:
+            cfg["R0_BLOCK"] = 64
+        args = [torch.randn(x, r, device="cuda")]
+        self._check(fn, args, persistent=persistent, cfg=cfg)
+
+    @parametrize("persistent", [False, True])
+    @parametrize("rsplit", [32, 33])
+    def test_fixed_config_with_larger_xblock_than_xnumel(self, persistent, rsplit):
         def fn(x, y):
             return [
                 torch.any(x == y),
                 torch.all(x == y),
                 torch.any(x != y),
                 torch.all(x != y),
                 torch.any(x < y),
                 torch.all(x > y),
                 torch.mean(x + y),
             ]

+        cfg = {"XBLOCK": 128, "RSPLIT": rsplit, "num_warps": 16, "num_stages": 1}
+        if not persistent:
+            cfg["R0_BLOCK"] = 64
         args = [torch.randn(1024, device="cuda") for _ in range(2)]
-        with torch._inductor.virtualized.V.set_choices_handler(MyHeuristics()):
-            expected = fn(*args)
-            fn = torch.compile(fn, fullgraph=True)
-            result, (source_code,) = run_and_get_code(fn, *args)
-        self.assertEqual(result, expected)
-        self.assertIn("@triton_heuristics.fixed_config(", source_code)
+        self._check(fn, args, persistent=persistent, cfg=cfg)


 if __name__ == "__main__":
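Read end to end, the pieces above compose as follows. A condensed, hedged sketch (values taken from the (False, 1, 8000, 17) welford parametrization; assumes a CUDA device and the test file's imports):

# Condensed usage sketch: force one exact Triton config instead of autotuning.
import torch

heuristic = TestingHeuristics(
    persistent=False,
    cooperative=True,
    cfg={"XBLOCK": 64, "R0_BLOCK": 64, "RSPLIT": 17, "num_warps": 8},
)
with torch._inductor.virtualized.V.set_choices_handler(heuristic):
    compiled = torch.compile(lambda x: torch.var_mean(x, dim=-1), fullgraph=True)
    out = compiled(torch.randn(1, 8000, device="cuda"))
assert heuristic.call_count == 1  # the hook supplied kwargs for exactly one kernel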
@@ -93,9 +93,12 @@ from .triton_utils import (

 if TYPE_CHECKING:
     from types import ModuleType
+    from typing import TypeVar

     from ..ir import IRNode

+    _T = TypeVar("_T")
+
 log = logging.getLogger(__name__)
 perf_hint_log = torch._logging.getArtifactLogger(__name__, "perf_hints")
 schedule_log = torch._logging.getArtifactLogger(__name__, "schedule")
@@ -112,14 +115,14 @@ class OpDtypeSupport:
     convert_outputs: dict[str, bool] = {}

     @classmethod
-    def register_upcast(cls, func: Callable[..., str], convert_output: bool):
+    def register_upcast(cls, func: Callable[..., str], convert_output: bool) -> None:
         op_name = func.__name__
         cls.supported_dtypes[op_name] = OrderedSet([torch.float32, torch.float64])
         cls.convert_outputs[op_name] = convert_output


 @lru_cache(None)
-def gen_attr_descriptor_import():
+def gen_attr_descriptor_import() -> str:
     """
     import AttrsDescriptor if the triton version is new enough to have this
     class defined.
@@ -138,7 +141,7 @@ def gen_attr_descriptor_import():


 @lru_cache(None)
-def gen_common_triton_imports():
+def gen_common_triton_imports() -> str:
     imports = IndentedBuffer()
     imports.splice(
         """
@@ -197,19 +200,19 @@ class IndexingOptions:
     _has_rindex: bool
     index: sympy.Expr

-    def has_mask(self):
+    def has_mask(self) -> bool:
         return bool(self.mask_vars)

-    def has_indirect(self):
+    def has_indirect(self) -> bool:
         return free_symbol_is_type(self.index, SymT.TMP)

-    def has_rindex(self):
+    def has_rindex(self) -> bool:
         return self._has_rindex

-    def has_tmpmask(self):
+    def has_tmpmask(self) -> bool:
         return "tmp" in self.mask_str

-    def has_rmask(self):
+    def has_rmask(self) -> bool:
         return any(str(mask).startswith("r") for mask in self.mask_vars)
@@ -445,7 +448,7 @@ class BlockPtrOptions:
             )
         ]

-    def boundary_check(self):
+    def boundary_check(self) -> list[int]:
         assert self._boundary_check is not None
         return self._boundary_check
@@ -468,7 +471,7 @@ class BlockPtrOptions:
         ]
         return advance

-    def has_indirect(self):
+    def has_indirect(self) -> bool:
         return False  # block_ptr can't do indirect indexing

     def has_rindex(self) -> bool:
@@ -477,19 +480,19 @@ class BlockPtrOptions:
             for expr in self.block_shape
         )

-    def has_rmask(self):
+    def has_rmask(self) -> bool:
         return self.has_rindex()

-    def has_tmpmask(self):
+    def has_tmpmask(self) -> bool:
         return False  # block_ptr can't do indirect indexing

-    def has_mask(self):
+    def has_mask(self) -> bool:
         return bool(self.boundary_check())


 def triton_reshape(
     value: str, old_shape: Sequence[sympy.Expr], new_shape: Sequence[sympy.Expr]
-):
+) -> str:
     """Workaround https://github.com/openai/triton/issues/2836"""
     assert isinstance(old_shape, list) and isinstance(new_shape, list)
@@ -519,25 +522,25 @@ def triton_reshape(
 # inconsistent with Python semantics (and consistent with C semantics). We
 # must override all of these, or it is potential silent correctness problem
 class TritonPrinter(PythonPrinter):
-    def _print_TruncToInt(self, expr):
+    def _print_TruncToInt(self, expr: sympy.Expr) -> str:
         assert len(expr.args) == 1
         return (
             f"libdevice.trunc({self._print(expr.args[0])}).to({V.kernel.index_dtype})"
         )

-    def _print_Float(self, expr):
+    def _print_Float(self, expr: sympy.Expr) -> str:
         if config.is_fbcode() and torch.version.hip:
             ret = f"{expr}"
         else:
             ret = f"tl.full([], {expr}, tl.float64)"
         return ret

-    def _print_ToFloat(self, expr):
+    def _print_ToFloat(self, expr: sympy.Expr) -> str:
         assert len(expr.args) == 1
         s = self.parenthesize(expr.args[0], PRECEDENCE["Atom"] - 0.5)
         return f"{s}.to(tl.float64)"

-    def _print_PythonMod(self, expr):
+    def _print_PythonMod(self, expr: sympy.Expr) -> str:
         quot, div = expr.args
         if quot.is_nonnegative and div.is_nonnegative:
             return self.stringify(expr.args, " % ", PRECEDENCE["Atom"] - 0.5)
@@ -545,7 +548,7 @@ class TritonPrinter(PythonPrinter):
         div_s = self._print(div)
         return f"triton_helpers.remainder_integer({quot_s}, {div_s})"

-    def _print_FloorDiv(self, expr):
+    def _print_FloorDiv(self, expr: sympy.Expr) -> str:
         assert expr.is_integer
         quot, div = expr.args
         if quot.is_nonnegative and div.is_nonnegative:
@@ -556,42 +559,42 @@ class TritonPrinter(PythonPrinter):

     # TODO: This is wrong, when lhs, rhs > 2**53, Python does a higher
     # precision algorithm, which we would need to replicate here
-    def _print_IntTrueDiv(self, expr):
+    def _print_IntTrueDiv(self, expr: sympy.Expr) -> str:
         return self.stringify(expr.args, " / ", PRECEDENCE["Atom"] - 0.5)

     # NB: sympy.floor/ceiling produce integers, so we have to do the
     # conversion to index dtype
-    def _print_floor(self, expr):
+    def _print_floor(self, expr: sympy.Expr) -> str:
         assert len(expr.args) == 1
         return (
             f"libdevice.floor({self._print(expr.args[0])}).to({V.kernel.index_dtype})"
         )

-    def _print_FloorToInt(self, expr):
+    def _print_FloorToInt(self, expr: sympy.Expr) -> str:
         assert len(expr.args) == 1
         return (
             f"libdevice.floor({self._print(expr.args[0])}).to({V.kernel.index_dtype})"
         )

-    def _print_ceiling(self, expr):
+    def _print_ceiling(self, expr: sympy.Expr) -> str:
         assert len(expr.args) == 1
         return f"libdevice.ceil({self._print(expr.args[0])}).to({V.kernel.index_dtype})"

-    def _print_CeilToInt(self, expr):
+    def _print_CeilToInt(self, expr: sympy.Expr) -> str:
         assert len(expr.args) == 1
         return f"libdevice.ceil({self._print(expr.args[0])}).to({V.kernel.index_dtype})"

-    def _helper_sqrt(self, expr):
+    def _helper_sqrt(self, expr: sympy.Expr) -> str:
         return f"libdevice.sqrt(({self._print(expr)}).to(tl.float32))"

-    def _print_FloatPow(self, expr):
+    def _print_FloatPow(self, expr: sympy.Expr) -> str:
         return (
             f"libdevice.pow({self._print(expr.args[0])}, {self._print(expr.args[1])})"
         )

     _print_PowByNatural = _print_FloatPow

-    def _print_Where(self, expr):
+    def _print_Where(self, expr: sympy.Expr) -> str:
         c = self.doprint(expr.args[0])
         p = self.doprint(expr.args[1])
         q = self.doprint(expr.args[2])
@@ -616,59 +619,59 @@ class TritonPrinter(PythonPrinter):
         assert cmp in (">", "<"), f"Unexpected comparator: '{cmp}'"
         return f"({a} * ({a} {cmp}= {b}) + {b} * ({b} {cmp} {a}))"

-    def _print_Min(self, expr):
+    def _print_Min(self, expr: sympy.Expr) -> str:
         return self._print_min_max_helper(expr, "<")

-    def _print_Max(self, expr):
+    def _print_Max(self, expr: sympy.Expr) -> str:
         return self._print_min_max_helper(expr, ">")

-    def _print_Abs(self, expr):
+    def _print_Abs(self, expr: sympy.Expr) -> str:
         assert len(expr.args) == 1
         return f"tl_math.abs({self._print(expr.args[0])})"

-    def _print_OpaqueUnaryFn_cos(self, expr):
+    def _print_OpaqueUnaryFn_cos(self, expr: sympy.Expr) -> str:
         assert len(expr.args) == 1
         return f"libdevice.cos(({self._print(expr.args[0])}).to(tl.float32))"

-    def _print_OpaqueUnaryFn_cosh(self, expr):
+    def _print_OpaqueUnaryFn_cosh(self, expr: sympy.Expr) -> str:
         assert len(expr.args) == 1
         return f"libdevice.cosh(({self._print(expr.args[0])}).to(tl.float32))"

-    def _print_OpaqueUnaryFn_acos(self, expr):
+    def _print_OpaqueUnaryFn_acos(self, expr: sympy.Expr) -> str:
         assert len(expr.args) == 1
         return f"libdevice.acos(({self._print(expr.args[0])}).to(tl.float32))"

-    def _print_OpaqueUnaryFn_sin(self, expr):
+    def _print_OpaqueUnaryFn_sin(self, expr: sympy.Expr) -> str:
         assert len(expr.args) == 1
         return f"libdevice.sin(({self._print(expr.args[0])}).to(tl.float32))"

-    def _print_OpaqueUnaryFn_sinh(self, expr):
+    def _print_OpaqueUnaryFn_sinh(self, expr: sympy.Expr) -> str:
         assert len(expr.args) == 1
         return f"libdevice.sinh(({self._print(expr.args[0])}).to(tl.float32))"

-    def _print_OpaqueUnaryFn_asin(self, expr):
+    def _print_OpaqueUnaryFn_asin(self, expr: sympy.Expr) -> str:
         assert len(expr.args) == 1
         return f"libdevice.asin(({self._print(expr.args[0])}).to(tl.float32))"

-    def _print_OpaqueUnaryFn_tan(self, expr):
+    def _print_OpaqueUnaryFn_tan(self, expr: sympy.Expr) -> str:
         assert len(expr.args) == 1
         return f"libdevice.tan(({self._print(expr.args[0])}).to(tl.float32))"

-    def _print_OpaqueUnaryFn_tanh(self, expr):
+    def _print_OpaqueUnaryFn_tanh(self, expr: sympy.Expr) -> str:
         assert len(expr.args) == 1
         return f"libdevice.tanh(({self._print(expr.args[0])}).to(tl.float32))"

-    def _print_OpaqueUnaryFn_atan(self, expr):
+    def _print_OpaqueUnaryFn_atan(self, expr: sympy.Expr) -> str:
         assert len(expr.args) == 1
         return f"libdevice.atan(({self._print(expr.args[0])}).to(tl.float32))"

-    def _print_RoundToInt(self, expr):
+    def _print_RoundToInt(self, expr: sympy.Expr) -> str:
         assert len(expr.args) == 1
         return (
             f"libdevice.llrint({self._print(expr.args[0])}).to({V.kernel.index_dtype})"
         )

-    def _print_RoundDecimal(self, expr):
+    def _print_RoundDecimal(self, expr: sympy.Expr) -> str:
         assert len(expr.args) == 2
         number, ndigits = expr.args
         if number.is_integer:
@@ -740,7 +743,7 @@ class TritonCSEVariable(CSEVariable):
                 break


-def maybe_upcast_float32(convert_output: bool = True):
+def maybe_upcast_float32(convert_output: bool = True) -> Callable[[_T], _T]:
     """
     Codegen helper to upcast arguments to float32, depending on the config and dtype.
     This decorates tl.math/libdevice codegen functions.
@@ -796,7 +799,7 @@ def maybe_upcast_float32(convert_output: bool = True):

         return wrapped

-    return decorator
+    return decorator  # type: ignore[return-value]


 class TritonOverrides(OpOverrides):
@@ -1568,6 +1571,9 @@ class TritonKernel(SIMDKernel):

         self.codegen_range_tree()

+        if self.cooperative_reduction:
+            self.init_cooperative_reduction_mask()
+
     def dtype_to_str(self, dtype: torch.dtype) -> str:
         return triton_type(dtype)

@@ -1593,14 +1599,16 @@ class TritonKernel(SIMDKernel):
             self.args
         )
         self.body.splice(
-            """
+            """\
+            RSPLIT_NEXT_POWER_OF_2: tl.constexpr = triton_helpers.constexpr_next_power_of_2(RSPLIT)
+            RSPLIT_IS_POWER_OF_2: tl.constexpr = RSPLIT == RSPLIT_NEXT_POWER_OF_2
+            HAS_RSPLIT: tl.constexpr = RSPLIT > 1
             rsplit_id = tl.program_id(0)
             num_rblocks = (rnumel + RBLOCK - 1) // RBLOCK
             rsplit_chunk = (num_rblocks + RSPLIT - 1) // RSPLIT * RBLOCK
             rsplit_start = rsplit_chunk * rsplit_id
             rsplit_end = rsplit_chunk * (rsplit_id + 1)
             """,
             strip=True,
         )
         if any(
             not self._has_constant_mask(tree)
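The chunking arithmetic in that splice is easy to sanity-check outside the kernel. A plain-Python model with illustrative values (rnumel=8000 reduction elements, RBLOCK=64, and the non-power-of-2 RSPLIT=17 from the new tests), including the rsplit_end clamp emitted just below:

# Model of the emitted rsplit chunking; every program gets a whole number of
# RBLOCK tiles, and the last programs may fall off the end (hence the clamp).
rnumel, RBLOCK, RSPLIT = 8000, 64, 17
num_rblocks = (rnumel + RBLOCK - 1) // RBLOCK                  # 125 tiles
rsplit_chunk = (num_rblocks + RSPLIT - 1) // RSPLIT * RBLOCK   # 512 elements/program

covered = []
for rsplit_id in range(RSPLIT):
    rsplit_start = rsplit_chunk * rsplit_id
    rsplit_end = min(rsplit_chunk * (rsplit_id + 1), rnumel)   # the tl.where clamp
    covered.extend(range(min(rsplit_start, rnumel), rsplit_end))

assert covered == list(range(rnumel))  # each element visited exactly once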
@@ -1611,6 +1619,27 @@ class TritonKernel(SIMDKernel):
                 "rsplit_end = tl.where(rsplit_end < rnumel, rsplit_end, rnumel)"
             )

+    def init_cooperative_reduction_mask(self):
+        rsplit_arange = "tl.arange(0, RSPLIT_NEXT_POWER_OF_2)"
+        if not self.no_x_dim:
+            rsplit_arange = f"{rsplit_arange}[None, :]"
+        self.body.writeline(f"rsplit_arange = {rsplit_arange}")
+
+        if self._has_constant_xmask():
+            self.body.splice(
+                """\
+                if RSPLIT_IS_POWER_OF_2:
+                    rsplit_mask: tl.constexpr = None
+                else:
+                    rsplit_mask = rsplit_arange < RSPLIT
+                """
+            )
+        else:
+            assert not self.no_x_dim
+            self.body.writeline(
+                "rsplit_mask = xmask if RSPLIT_IS_POWER_OF_2 else ((rsplit_arange < RSPLIT) & xmask)"
+            )
+
     def codegen_range_tree(self):
         for tree in self.range_trees:
             # reduction indexing goes inside a loop
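tl.arange only accepts power-of-2 extents, which is why this helper pads the split index range up and masks off the tail. A small model of the generated logic, using the illustrative RSPLIT=17 from the tests:

RSPLIT = 17
RSPLIT_NEXT_POWER_OF_2 = 1 << (RSPLIT - 1).bit_length()  # 32; tl.arange needs a power of 2
rsplit_arange = list(range(RSPLIT_NEXT_POWER_OF_2))
# A power-of-2 RSPLIT needs no mask (rsplit_mask stays tl.constexpr None, so the
# branch compiles away); otherwise lanes RSPLIT..31 are masked off.
rsplit_mask = (None if RSPLIT == RSPLIT_NEXT_POWER_OF_2
               else [i < RSPLIT for i in rsplit_arange])
assert rsplit_mask is not None and sum(rsplit_mask) == RSPLIT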
@@ -1934,9 +1963,10 @@ class TritonKernel(SIMDKernel):
         if isinstance(index, sympy.Integer):
             expand_str = f"{copy_shape}.shape" if copy_shape else self.dense_size_str()
             index_str = f"tl.full({expand_str}, {index_str}, tl.int32)"
-            mask_vars = OrderedSet()
+            mask_vars = dense_mask_vars
             if self._load_mask:
                 mask_vars.add(self._load_mask)
+            self.filter_masks(mask_vars)
             mask_str = " & ".join(sorted(map(str, mask_vars))) if mask_vars else "None"
             return IndexingOptions(
                 index_str, mask_vars, mask_str, expand_str, has_rindex, index
@@ -2544,7 +2574,7 @@ class TritonKernel(SIMDKernel):
         exit_stack = contextlib.ExitStack()
         for buf in (self.post_loop_combine, self.post_loop_store):
             # only do cooperative reduction combines if we have more than one thread block
-            buf.writeline("if RSPLIT > 1:")
+            buf.writeline("if HAS_RSPLIT:")
             exit_stack.enter_context(buf.indent())

         if reduction_type in ("argmax", "argmin"):
@@ -2728,7 +2758,6 @@ class TritonKernel(SIMDKernel):
         """
         xnumel = self.numels["x"]
-        mask = "xindex < xnumel" if not self._has_constant_xmask() else None
         expand = "" if self.no_x_dim else "[None,:]"

         nbytes = xnumel * dtype.itemsize * self.max_rsplit()
         ws_name, ws_offset = self.cooperative_reduction_workspace_cache.allocate(nbytes)
@@ -2741,8 +2770,8 @@ class TritonKernel(SIMDKernel):
             strip=True,
         )
         self.post_loop_store.writeline(
-            f"{result_var}_peers = tl.load({result_var}_ws + (xindex * RSPLIT + tl.arange(0, RSPLIT){expand}), "
-            f"{mask}, eviction_policy='evict_first')"
+            f"{result_var}_peers = tl.load({result_var}_ws + (xindex * RSPLIT + rsplit_arange), "
+            f"rsplit_mask, eviction_policy='evict_first', other=triton_helpers.zero_other({result_var}_ws, rsplit_mask))"
         )
         return f"{result_var}_peers"

@@ -3095,7 +3124,7 @@ class TritonKernel(SIMDKernel):
         sem_ptr = f"{self.semaphores_name} + tl.program_id(1)"
         self.body.splice(
             f"""
-            if RSPLIT > 1:
+            if HAS_RSPLIT:
                 triton_helpers.x_grid_barrier({sem_ptr})
             """,
             strip=True,
@@ -3539,9 +3568,11 @@ class TritonKernel(SIMDKernel):
                 code.writeline(f"{tree.prefix}numel = {int(simplified_tree_numel)}")

         if tree.is_reduction and self.persistent_reduction:
-            val = self._get_persistent_RBLOCK(tree.numel)
             if self.cooperative_reduction:
-                val = f"{val} // RSPLIT"
+                numel = self.kexpr(self.rename_indexing(tree.numel))
+                val = f"triton_helpers.constexpr_next_power_of_2(({numel} + RSPLIT - 1) // RSPLIT)"
+            else:
+                val = self._get_persistent_RBLOCK(tree.numel)
             code.writeline(f"{tree.prefix.upper()}BLOCK: tl.constexpr = {val}")

         if tree.prefix == "x" and self.no_x_dim:
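For a persistent cooperative reduction, RBLOCK is now derived by ceil-dividing the reduction length across RSPLIT programs and rounding up to a power of 2, rather than assuming RSPLIT divides the old power-of-2 RBLOCK evenly. A worked example with illustrative numbers from the tests:

# Models the emitted constexpr expression for rnumel=8000, RSPLIT=17.
def next_power_of_2(n: int) -> int:  # plain-Python stand-in for triton.next_power_of_2
    return 1 << (n - 1).bit_length()

rnumel, RSPLIT = 8000, 17
per_program = (rnumel + RSPLIT - 1) // RSPLIT  # 471 elements per program
RBLOCK = next_power_of_2(per_program)          # 512, a valid tl.arange extent
assert RBLOCK * RSPLIT >= rnumel               # the split still covers everything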
@@ -3598,7 +3629,7 @@ class TritonKernel(SIMDKernel):
         for ws in reversed(self.args.workspace_args):
             wrapper.generate_workspace_deallocation(ws)

-    def codegen_nan_check(self):
+    def codegen_nan_check(self) -> None:
         wrapper = V.graph.wrapper_code
         _, call_args, arg_signatures, _ = self.args.python_argdefs()
         for arg, arg_signature in zip(call_args, arg_signatures):
@@ -3613,7 +3644,7 @@ class TritonKernel(SIMDKernel):
                 line = f"assert not {arg}.isinf().any().item()"
                 wrapper.writeline(line)

-    def create_cse_var(self, *args, **kwargs):
+    def create_cse_var(self, *args, **kwargs) -> TritonCSEVariable:
         return TritonCSEVariable(*args, **kwargs)

     def codegen_iteration_ranges_entry(self, entry: IterationRangesEntry):
@@ -3624,7 +3655,7 @@ class TritonKernel(SIMDKernel):
             # lift non-reduction stores outside loop
             self.body.writeline(line)

-    def iteration_ranges_ranges_code(self, entry):
+    def iteration_ranges_ranges_code(self, entry: IterationRangesRoot) -> str:
         assert entry.tensor_dim is not None
         size = self.indexing_size_str(entry.tensor_dim)
         index_dtype = self.index_dtype
@@ -3637,13 +3668,15 @@ class TritonKernel(SIMDKernel):
             suffix = f"{suffix} + rsplit_start"
         return f"tl.arange(0, {entry.prefix.upper()}BLOCK){size}{suffix}"

-    def iteration_ranges_scalar_code(self, entry, value):
+    def iteration_ranges_scalar_code(
+        self, entry: IterationRangesRoot, value: Any
+    ) -> str:
         index_dtype = self.index_dtype
         ndim = self.triton_tensor_ndim()
         size = [1] * ndim
         return f"tl.full({size}, {value}, {index_dtype})"

-    def iteration_ranges_get_pid(self, entry):
+    def iteration_ranges_get_pid(self, entry: IterationRangesRoot) -> str:
         assert entry.grid_dim is not None
         key = f"tl.program_id({entry.grid_dim})"
         # y_grid has a limit, so express it in terms of y and z in case of overflow.
@@ -3663,12 +3696,12 @@ class TritonKernel(SIMDKernel):
             return f"{pid}.to({self.index_dtype})"
         return pid

-    def max_block(self, prefix):
+    def max_block(self, prefix: str) -> int:
         if self.fixed_config:
             return self.fixed_config[f"{prefix.upper()}BLOCK"]
         return TRITON_MAX_BLOCK[prefix.upper()]

-    def _has_constant_mask(self, tree: IterationRangesRoot):
+    def _has_constant_mask(self, tree: IterationRangesRoot) -> bool:
         if not self.optimize_mask:
             return False

@@ -3710,12 +3743,12 @@ class TritonKernel(SIMDKernel):

         return False

-    def _has_constant_xmask(self):
+    def _has_constant_xmask(self) -> bool:
         xtree = self.range_trees[0]
         assert xtree.prefix == "x"
         return self._has_constant_mask(xtree)

-    def filter_masks(self, mask_vars):
+    def filter_masks(self, mask_vars: OrderedSet[str]) -> None:
         for tree in self.range_trees:
             if self._has_constant_mask(tree):
                 mask_vars.discard(f"{tree.prefix}mask")
@@ -3730,7 +3763,7 @@ class TritonKernel(SIMDKernel):
             for symt in list(TritonSymbols.reduction_types)[: self.num_reduction_dims]
         ]

-    def codegen_reduction_numels(self, buffer) -> None:
+    def codegen_reduction_numels(self, buffer: IndentedBuffer) -> None:
         """
         Generates code that flattens ND reduction numels, block sizes, etc. into 1D.
         """
@@ -3775,7 +3808,7 @@ class TritonKernel(SIMDKernel):
         coeffs = self._get_reduction_index_coeffs()
         return sympy_dot(coeffs, multi_inds)

-    def codegen_reduction_indices(self, buffer) -> None:
+    def codegen_reduction_indices(self, buffer: IndentedBuffer) -> None:
         """
         Generates code that converts ND reduction indices into linear indices.
         """
@@ -3791,7 +3824,9 @@ class TritonKernel(SIMDKernel):
         rindex = self._flatten_reduction_indices(rn_inds)
         buffer.splice(f"rindex = {self.index_to_str(rindex)}")

-    def iteration_ranges_codegen_header(self, entry, code):
+    def iteration_ranges_codegen_header(
+        self, entry: IterationRangesRoot, code: IndentedBuffer
+    ) -> None:
         x = entry.prefix
         if entry.is_loop:
             code.writeline(f"{entry.name} = {x}offset + {x}base")

@@ -1,10 +1,14 @@
+# mypy: allow-untyped-decorators
 # mypy: allow-untyped-defs
 import warnings
+from typing import Any, TypeVar

 from .triton_compat import _log2, libdevice, math, tl, triton  # noqa: F401


+_T = TypeVar("_T")
+
+
 def set_driver_to_cpu():
     driver = triton.runtime.driver
     if backend := triton.backends.backends.get("cpu", None):
@@ -627,3 +631,42 @@ def x_grid_barrier(sem):

     # TODO(jansel): is this needed?
     tl.debug_barrier()
+
+
+def triton_builtin(f: _T) -> _T:
+    """
+    Decorator to mark a function as a Triton built-in function. These functions
+    are evaluated at compile time.
+
+    Args:
+        f (function): The function to be marked as a Triton built-in.
+
+    Returns:
+        function: The same function, marked as a Triton built-in.
+    """
+    f.__triton_builtin__ = True  # type: ignore[attr-defined]
+    return f
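To illustrate the decorator's contract, a hedged sketch of another compile-time helper in the same style (hypothetical, not part of this commit; requires the triton runtime):

@triton_builtin
def constexpr_double(n: tl.constexpr, *, _builder: object = None) -> tl.constexpr:
    # Hypothetical example: runs in the Python frontend during kernel
    # compilation, never on the device. `_builder` is injected by Triton's compiler.
    return tl.constexpr(n.value * 2)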
+
+
+@triton_builtin
+def constexpr_next_power_of_2(
+    n: tl.constexpr, *, _builder: object = None
+) -> tl.constexpr:
+    """
+    A version of triton.next_power_of_2 that can be used within a kernel on constants.
+    """
+    assert isinstance(n, tl.constexpr)
+    return tl.constexpr(triton.next_power_of_2(n.value))
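For the RSPLIT values the new tests exercise, the helper evaluates as follows (plain-Python model of triton.next_power_of_2):

def next_power_of_2(n: int) -> int:  # model of triton.next_power_of_2
    return 1 << (n - 1).bit_length()

assert [next_power_of_2(n) for n in (16, 17, 33, 40)] == [16, 32, 64, 64]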
+
+
+@triton_builtin
+def zero_other(ptr: Any, mask: Any, *, _builder: object = None) -> tl.constexpr:
+    """
+    Work around the triton compile error: `ValueError: `other` cannot be provided without `mask``
+    A compile-time check that returns either 0 or None depending on the value of mask.
+    """
+    if isinstance(mask, tl.constexpr) and mask.value is None:
+        return tl.constexpr(None)
+    if isinstance(ptr, tl.core.tensor) and ptr.dtype.is_floating():
+        return tl.constexpr(0.0)
+    return tl.constexpr(0)
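A plain-Python model of the dispatch shows why each branch exists (FakeConstexpr and zero_other_model are illustrative stand-ins, not part of the commit):

class FakeConstexpr:  # stand-in for tl.constexpr
    def __init__(self, value):
        self.value = value

def zero_other_model(dtype_is_floating: bool, mask):
    if isinstance(mask, FakeConstexpr) and mask.value is None:
        return None  # no mask at all -> `other` must be omitted entirely
    return 0.0 if dtype_is_floating else 0  # type-appropriate neutral fill

assert zero_other_model(True, FakeConstexpr(None)) is None  # power-of-2 RSPLIT path
assert zero_other_model(True, object()) == 0.0              # masked float load
assert zero_other_model(False, object()) == 0               # masked int load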