Update on "[inductor] track reduction before splitting"

Keep track of the reduction metadata before splitting.

In the mix-order reduction context, if one of the reductions is split, it becomes much harder to fuse it with the other reduction. Track the metadata of the reduction before splitting to make the fusion possible.


cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben

[ghstack-poisoned]
This commit is contained in:
Shunting Zhang
2025-10-27 17:32:24 -07:00
3 changed files with 33 additions and 4 deletions

View File

@ -131,6 +131,20 @@ class MixOrderReductionTest(TestBase):
metrics.codegen_mix_order_reduction,
)
@inductor_config.patch(split_reductions=False)
def test_non_contiguous_input(self):
def f(x):
return x.sum(dim=-1), x.sum(dim=[0, 1])
x = torch.randn(1024, 32, 768, dtype=torch.float, device=GPU_TYPE).permute(
1, 0, 2
)
self.check_numeric(f, (x,))
self.assertEqual(
inductor_config.triton.mix_order_reduction,
metrics.codegen_mix_order_reduction,
)
@inductor_config.patch(split_reductions=False)
def test_multi_workspace_allocation(self):
def f(x, y):

View File

@ -4547,14 +4547,16 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]):
)
accumname2var[name] = self.cse.namedvar(name, dtype=torch.float)
self.body.writeline("split_size = min(RSPLIT_SIZE, xnumel - xoffset)")
self.body.writeline("for suboff in range(0, split_size, XBLOCK):")
self.body.writeline("for _ in range(0, split_size, XBLOCK):")
with self.body.indent(offset=1):
self.body.splice(self.indexing_code)
self.body.writelines(
[
"x0 = xindex + suboff",
"xindex += XBLOCK",
# TODO we force XBLOCK==1 for now so there is
# no need to update the xmask
]
)
self.body.splice(self.indexing_code)
self.body.splice(self.loads)
self.body.splice(self.compute)
self.body.splice(self.stores)
@ -5345,7 +5347,10 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]):
def codegen_iteration_ranges_entry(self, entry: IterationRangesEntry):
line = f"{entry.name} = {self.kexpr(self.rename_indexing(entry.expr))}"
if entry.root.is_loop:
# mix order reduction introduces an extra loop across the x
# dimension
if entry.root.is_loop or (self.mix_order_reduction and entry.prefix == "x"):
self.indexing_code.writeline(line)
else:
# lift non-reduction stores outside loop

View File

@ -243,6 +243,7 @@ class MixOrderReduction:
def can_fuse(cls, node1: BaseSchedulerNode, node2: BaseSchedulerNode) -> bool:
if not config.triton.mix_order_reduction:
return False
if not node1.is_gpu() or not node2.is_gpu():
return False
if node1.get_device().type != "cuda" or config.cuda_backend != "triton": # type: ignore[union-attr]
@ -4516,6 +4517,15 @@ class Scheduler:
if node1 is node2:
return False
# We don't further fuse with FusedMixOrderReductions for now.
# It's not a big deal since the score for fusion with
# mix order reduction is low. When we do this kind of fusion,
# the participants should have already been well fused.
if isinstance(node1, FusedMixOrderReductions) or isinstance(
node2, FusedMixOrderReductions
):
return False
why = WhyNoFuse(node1, node2)
if node1.is_template() and self.get_backend(