Compare commits

...

7 Commits

| SHA1 | Message | Date |
| --- | --- | --- |
| ff32d9851a | Update autotune decision logic correctly | 2025-10-29 11:56:44 +00:00 |
| 3569d4b5ea | Lint | 2025-10-25 20:33:22 +00:00 |
| 7af620c692 | Remove outdated api | 2025-10-23 14:59:26 +00:00 |
| aa7f6e4cb1 | Linting | 2025-10-23 14:59:26 +00:00 |
| 39f2a2d67a | Update triton_heuristics.py | 2025-10-23 14:59:26 +00:00 |
| d9ad880f90 | Update triton_heuristics.py | 2025-10-23 14:59:26 +00:00 |
| da55a7d513 | Naive foreach autotune support | 2025-10-23 14:59:26 +00:00 |
2 changed files with 14 additions and 3 deletions


```diff
@@ -628,7 +628,7 @@ class ComboKernel(Kernel):
         if heuristics == "foreach":
             heuristics_line = f"""
 @triton_heuristics.foreach(
-    num_warps={self.num_warps},
+    filename=__file__,
     triton_meta={triton_meta!r},
     inductor_meta={inductor_meta!r},
 )
```
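This hunk swaps the hard-coded num_warps argument in the generated decorator for `filename=__file__`, deferring warp selection to the `foreach` heuristic itself. Since the decorator line is assembled from an f-string, the effect is easiest to see by rendering it. A minimal standalone sketch of what the codegen produces after the change (the metadata dicts here are illustrative placeholders, not values from this PR):

```python
# Toy rendering of the generated decorator text. triton_meta and
# inductor_meta are hypothetical placeholders standing in for the
# real values the inductor codegen would supply.
triton_meta = {"device": 0}
inductor_meta = {"max_autotune": True}

heuristics_line = f"""
@triton_heuristics.foreach(
    filename=__file__,
    triton_meta={triton_meta!r},
    inductor_meta={inductor_meta!r},
)
"""
print(heuristics_line)
```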


```diff
@@ -3530,13 +3530,24 @@ def user_autotune(
     )
 
 
-def foreach(triton_meta, num_warps, filename=None, inductor_meta=None):
+def foreach(triton_meta, filename=None, inductor_meta=None):
     """
     Compile a triton foreach kernel
     """
+    configs = []
+    # Naive autotuning path for num_warps
+    if not (
+        inductor_meta.get("max_autotune") or inductor_meta.get("max_autotune_pointwise")
+    ):
+        configs.append(triton.Config({}, num_stages=1, num_warps=8))
+    else:
+        for warps in [1, 2, 4, 8]:
+            configs.append(triton.Config({}, num_stages=1, num_warps=warps))
     return cached_autotune(
         None,
-        [triton.Config({}, num_stages=1, num_warps=num_warps)],
+        configs,
         triton_meta=triton_meta,
         inductor_meta=inductor_meta,
         heuristic_type=HeuristicType.TEMPLATE,
```
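The new body picks configs by a simple rule: unless `max_autotune` (or `max_autotune_pointwise`) is set in `inductor_meta`, it keeps a single config pinned at `num_warps=8`; otherwise it sweeps warps over 1, 2, 4, 8 and lets autotuning choose. A dependency-free sketch of that decision logic, with a stand-in `Config` class replacing `triton.Config` so it runs without Triton installed:

```python
# Stand-in for triton.Config, just to make the selection logic runnable here.
from dataclasses import dataclass

@dataclass
class Config:
    num_stages: int
    num_warps: int

def foreach_configs(inductor_meta):
    # Mirror of the new foreach() decision: no autotune flag -> one pinned
    # config; autotune enabled -> sweep a small power-of-two warp range.
    if not (
        inductor_meta.get("max_autotune")
        or inductor_meta.get("max_autotune_pointwise")
    ):
        return [Config(num_stages=1, num_warps=8)]
    return [Config(num_stages=1, num_warps=w) for w in (1, 2, 4, 8)]

print(foreach_configs({}))                      # -> single num_warps=8 config
print(foreach_configs({"max_autotune": True}))  # -> configs for warps 1, 2, 4, 8
```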