Compare commits

...

8 Commits

Author SHA1 Message Date
80f16e90b8 Update on "inductor: enable pdl by default"
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben

[ghstack-poisoned]
2025-11-13 14:59:34 -08:00
2cc921f0ae Update base for Update on "inductor: enable pdl by default"
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben

[ghstack-poisoned]
2025-11-13 14:59:34 -08:00
593de1a81f Update on "inductor: enable pdl by default"
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben

[ghstack-poisoned]
2025-11-13 14:03:56 -08:00
00f8d9fea2 inductor: enable pdl by default
[ghstack-poisoned]
2025-11-12 15:34:27 -08:00
308e5ac650 Update on "inductor pdl: guard writes, too. do not wait if the tensor does not come from preceding kernel"
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben

[ghstack-poisoned]
2025-11-12 14:45:56 -08:00
bff2f0d675 Update base for Update on "inductor pdl: guard writes, too. do not wait if the tensor does not come from preceding kernel"
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben

[ghstack-poisoned]
2025-11-12 14:45:56 -08:00
014d6290cf Update on "inductor pdl: guard writes, too. do not wait if the tensor does not come from preceding kernel"
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben

[ghstack-poisoned]
2025-11-12 14:41:35 -08:00
35d883816b inductor pdl: guard writes, too. do not wait if the tensor does not come from preceding kernel
[ghstack-poisoned]
2025-10-31 11:05:26 -07:00
4 changed files with 171 additions and 29 deletions

View File

@ -14658,6 +14658,78 @@ def forward(self, arg0_1: "Sym(s77)", arg1_1: "Sym(s27)", arg2_1: "Sym(s53)", ar
),
)
@requires_cuda_and_triton
@skipCUDAIf(not SM90OrLater, "Requires sm90")
@config.patch({"triton.enable_pdl": True})
def test_pdl_mutation(self):
def fn(a, b, c):
b.copy_(c) # second kernel
return a**2 + b # first kernel
a, b, c = [
torch.randn(s, device=GPU_TYPE) for s in [(1024, 1024), (1024,), (1024,)]
]
self.common(fn, (a, b, c))
code = run_and_get_triton_code(torch.compile(fn), a, b, c)
(
FileCheck()
# first kernel
.check("'launch_pdl': True")
.check("gdc_wait")
.check("load")
.check("load")
.check("gdc_launch")
.check("store")
# second kernel, no need to wait before load
.check("'launch_pdl': True")
.check("load")
.check("gdc_wait")
.check("gdc_launch")
.check("store")
).run(code)
@requires_cuda_and_triton
@skipCUDAIf(not SM90OrLater, "Requires sm90")
@config.patch(
{
"triton.enable_pdl": True,
"max_autotune": True,
"coordinate_descent_tuning": True,
}
)
def test_pdl_template_and_delay(self):
def fn(a, b):
a = (a / (a**2).sum(-1, keepdim=True)) ** 2 # first kernel
b = (b / (b**2).sum(-1, keepdim=True)) ** 2 # second kernel
c = a @ b # fused-epilogue template
c = c**2
return c
a, b = [torch.randn(s, device=GPU_TYPE) for s in [(1024, 512), (512, 1024)]]
self.common(fn, (a, b))
code = run_and_get_triton_code(torch.compile(fn, mode="max-autotune"), a, b)
(
FileCheck()
# first kernel
.check("'launch_pdl': True")
.check("gdc_wait")
.check("load")
.check("gdc_launch")
.check("store")
# second kernel, no need to wait before load
.check("'launch_pdl': True")
.check("load")
.check("gdc_wait")
.check("gdc_launch")
.check("store")
# matmul template
.check_not("'launch_pdl': True")
.check_not("gdc_wait")
.check_not("gdc_launch")
.check("store")
).run(code)
# end of class CommonTemplate - add new tests here

View File

@ -44,7 +44,13 @@ from ..runtime.hints import (
TRITON_MAX_RSPLIT,
)
from ..runtime.runtime_utils import get_max_y_grid, next_power_of_2
from ..scheduler import BaseSchedulerNode, FusedSchedulerNode, Scheduler, SchedulerNode
from ..scheduler import (
BaseSchedulerNode,
FusedSchedulerNode,
NopKernelSchedulerNode,
Scheduler,
SchedulerNode,
)
from ..shape_propagation import get_broadcasted_shape
from ..utils import (
cache_on_self,
@ -726,13 +732,6 @@ def triton_reshape(
return f"{value}[{', '.join(expand)}]"
def enable_pdl_codegen():
if not torch._inductor.config.triton.enable_pdl:
return False
major, _ = torch.cuda.get_device_capability(torch.cuda.current_device())
return major >= 9
# NB: Inheriting from PythonPrinter is somewhat dangerous, because there are a
# number of operators which Triton "implements", but in a way that is
# inconsistent with Python semantics (and consistent with C semantics). We
@ -2296,7 +2295,8 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]):
self.tma_min_block_sizes = dict[str, int]()
self.hint_override = hint_override
self._load_counts: collections.Counter[str] = collections.Counter()
self._load_index = 0
self._pdl_load_index = 0
self._pdl_has_wait = False
# A set of autotuning hints to pass as part of triton_meta
self.autotune_hints = OrderedSet[AutotuneHint]()
@ -3094,26 +3094,86 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]):
else:
return self.loads
def _handle_pdl_before_load(self, wait_buffer):
def _enable_pdl_codegen(self):
if not torch._inductor.config.triton.enable_pdl:
return False
if isinstance(V.kernel, torch._inductor.select_algorithm.TritonTemplateKernel):
return False
major, _ = torch.cuda.get_device_capability(torch.cuda.current_device())
return major >= 9
def _handle_pdl_before_access(
self, wait_buffer, *dependencies, consider_reads=False
):
if not self._enable_pdl_codegen():
return
current_node = V.kernel.current_node
prev_node = (
V.graph.scheduler.previous_node if V.graph.scheduler is not None else None
)
def matching_dep(dep):
assert prev_node is not None
prev_deps = prev_node.read_writes.writes
if consider_reads:
prev_deps = itertools.chain(prev_deps, prev_node.read_writes.reads)
return any(
dep == current_node.mutation_renames.get(w.name, w.name)
for w in prev_deps
)
assert dependencies
need_wait = (
prev_node is None
or isinstance(prev_node, NopKernelSchedulerNode)
or any(matching_dep(d) for d in dependencies)
)
if not need_wait:
return
GDC_WAIT = "tl.extra.cuda.gdc_wait()"
self._load_index += 1
if self.inside_reduction:
# hoist before the loop
if self.inside_reduction and self.range_trees[-1].is_loop:
wait_buffer = self.body
if enable_pdl_codegen():
if self._load_index == 1:
wait_buffer.writeline(GDC_WAIT)
if self._pdl_has_wait:
return
self._pdl_has_wait = True
wait_buffer.writeline(GDC_WAIT)
def _handle_pdl_after_load(self, launch_buffer, result_var):
GDC_LAUNCH = "tl.extra.cuda.gdc_launch_dependents()"
if self.inside_reduction:
if not self._enable_pdl_codegen():
return
if result_var.use_count > 1: # we already went through this
return
# hoist after the loop
if self.inside_reduction and self.range_trees[-1].is_loop:
launch_buffer = self.post_loop_combine
if enable_pdl_codegen():
current_load_index = self._load_index
launch_if_last_load = DelayMaybeLine(
lambda: current_load_index == self._load_index,
f"0; {GDC_LAUNCH} # gdc launch for {result_var}",
)
self.cse.generate(launch_buffer, launch_if_last_load, dtype=torch.int32)
# always gdc_wait before gdc_launch
GDC_WAIT = "tl.extra.cuda.gdc_wait()"
def check_if_no_wait():
result = self._pdl_has_wait
self._pdl_has_wait = True
return not result
wait_if_no_load = DelayMaybeLine(
check_if_no_wait,
f"0; {GDC_WAIT} # gdc launch for {result_var}",
)
self.cse.generate(launch_buffer, wait_if_no_load, dtype=torch.int32)
self._pdl_load_index += 1
current_pdl_load_index = self._pdl_load_index
def check_if_last_load():
return self._pdl_load_index == current_pdl_load_index
GDC_LAUNCH = "tl.extra.cuda.gdc_launch_dependents()"
launch_if_last_load = DelayMaybeLine(
check_if_last_load,
f"0; {GDC_LAUNCH} # gdc launch for {result_var}",
)
self.cse.generate(launch_buffer, launch_if_last_load, dtype=torch.int32)
def partial_accumulate(
self, name: str, reduction_type, val, extra_meta: dict[str, Any]
@ -3269,7 +3329,7 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]):
dtype = torch.bool
load_buffer = self.get_load_buffer(indexing)
self._handle_pdl_before_load(load_buffer)
self._handle_pdl_before_access(load_buffer, name)
result_var = self.cse.generate(
load_buffer, make_line(line), dtype=dtype, shape=shape
)
@ -3388,6 +3448,7 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]):
if not self.inside_reduction and self.cooperative_reduction:
exit_stack.enter_context(self.guard_cooperative_store(name, self.stores))
self._handle_pdl_before_access(self.stores, name, consider_reads=True)
self.stores.writeline(DeferredLine(name, line))
if not self.inside_reduction:
@ -3455,7 +3516,9 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]):
"Bucketize only supports indexing with int32 and int64"
)
self._handle_pdl_before_load(self.compute)
self._handle_pdl_before_access(
self.compute, boundaries[0], *([sorter[0]] if sorter else [])
)
result = self.cse.generate(
self.compute,
f"triton_helpers.bucketize_binary_search({values}, "
@ -4225,6 +4288,8 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]):
self.guard_cooperative_store(name, self.post_loop_store)
)
self._handle_pdl_before_access(self.post_loop_store, var)
if isinstance(indexing, (BlockPtrOptions, TensorDescriptorOptions)):
self.post_loop_store.writeline(
DeferredLine(
@ -5163,8 +5228,7 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]):
triton_meta["configs"] = [config_of(signature)]
if enable_pdl_codegen():
triton_meta["launch_pdl"] = True
triton_meta["launch_pdl"] = self._enable_pdl_codegen()
# Triton compiler includes equal_to_1 args into constants even
# when they are not constexpr. otherwise there may be a segfault

View File

@ -1608,7 +1608,7 @@ class triton:
# Programmatic Dependent Launch improves launch latency on Nvidia Hopper+ devices
# If set to true, will generate PDL code on devices that support it.
# If set to false, will never generate PDL code.
enable_pdl = False
enable_pdl = True
mix_order_reduction = (
os.environ.get("TORCHINDUCTOR_MIX_ORDER_REDUCTION", "0" if is_fbcode() else "1")

View File

@ -2604,6 +2604,7 @@ class Scheduler:
]
)
self.nodes = [self.create_scheduler_node(n) for n in nodes]
self.previous_node: Optional[BaseSchedulerNode] = None
self.current_node: Optional[BaseSchedulerNode] = None
self.update_zero_dim_cpu_tensor()
# some new constants could have been created above
@ -2620,6 +2621,7 @@ class Scheduler:
self.name_to_node: dict[str, BaseSchedulerNode] = {
n.get_name(): n for n in self.nodes
}
self.name_to_buf: dict[str, SchedulerBuffer] = {
buf.get_name(): buf for node in self.nodes for buf in node.get_outputs()
}
@ -5985,6 +5987,7 @@ class Scheduler:
seen.add(key)
self.current_device = self.default_device_context
assert self.previous_node is None
# pyrefly: ignore [unbound-name]
if self.default_device_context and config.triton.autotune_at_compile_time:
@ -6076,6 +6079,8 @@ class Scheduler:
):
self.flush()
self.previous_node = node
if self.current_device != self.default_device_context:
# when default_device_context is not None, we are codegen
# for graph partitions and all nodes must be on
@ -6086,6 +6091,7 @@ class Scheduler:
# important for nested indentation codegen-ing.
V.graph.wrapper_code.codegen_device_guard_exit()
self.previous_node = None
self.flush()
def benchmark_combo_kernel(