[Inductor] do loop reordering in a separate final round (#162355)

Previous LOAF after fusion algorithm is not guaranteed to create more fusion opportunities even if loop reordering happens. I can not find an example that LOAF reduce the amount of fusion, but here is an example that reordering loops does not add more fusions:

a1f7639922/test/inductor/test_loop_ordering.py (L612-L641)

Move LOAF to a separate final round of fusion so that we are guaranteed to not reducing the amount of fusions. Hopefully this also helps compilation time since LOAF kicks in when there are less nodes.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162355
Approved by: https://github.com/eellison, https://github.com/jansel
ghstack dependencies: #162101, #162126
This commit is contained in:
Shunting Zhang
2025-09-10 18:10:42 -07:00
committed by PyTorch MergeBot
parent e88460f453
commit 248156ed06
2 changed files with 55 additions and 11 deletions

View File

@ -606,6 +606,37 @@ class LoopOrderingTest(TestCase):
out, code = run_and_get_code(f, x)
FileCheck().check_count("@triton.jit", 1, exactly=True).run(code[0])
def test_3dred_pw_2d_outer_red(self):
"""
Test a pattern as follows. We have a 3d contiguous tensor [m, n, k] as input.
1. do reduction on the k dimension and get a [m, n] tensor
2. do a pointwise operation on this [m, n] tensor (and realize the computation)
3. do a outer reduction on the output of step 2 on the m dimension.
Each of these step generate a kernel before fusion.
Without any loop reorder, kernel 1 and kernel 2 will get fused. And kernel 3 will be separeate.
But if we reorder the loop for kernel 2, then kernel 2 will get fused with kernel 3.
And the fused kernel-2-3 can not be fused with kernel 1.
The older version of LOAF algorithm will do reorder in this case. But there is no real
benefits. There are even some slight downsides
1. the original fusion without loop reordering is more natural
2. fusion kernel 1 with kernel 2 may help precision when the output of kernel 1 is in low precision.
By fusion kernel 1 and kernel 2, the pointwise operation will operate on fp32 precision thanks
to fusion.
"""
M, N, K = 64, 64, 64
def f(x):
x = x.sum(dim=-1)
x = x + 1 # can be more complex like sigmoid or other ops
return x, x.sum(dim=0)
x = torch.randn(M, N, K, device=GPU_TYPE)
self.do_acc_test(f, x)
self.assertEqual(0, metrics.num_loop_reordering)
@inductor_config.patch(
{

View File

@ -2978,7 +2978,7 @@ class Scheduler:
i + 1,
old_len,
)
nodes = self.fuse_nodes_once(nodes)
nodes = self.fuse_nodes_once(nodes, is_reorder_round=False)
new_len = len(nodes)
fusion_log.debug(
"completed fusion round (%d/10): fused %d nodes into %d nodes\n",
@ -2991,6 +2991,9 @@ class Scheduler:
"===== fusion complete (%d iterations) =====", i + 1
)
break
if config.loop_ordering_after_fusion:
nodes = self.fuse_nodes_once(nodes, is_reorder_round=True)
return nodes
def process_grouped_nodes(self) -> None:
@ -3497,7 +3500,9 @@ class Scheduler:
return self.name_to_fused_node[node.get_first_name()]
def fuse_nodes_once(
self, nodes: list[BaseSchedulerNode]
self,
nodes: list[BaseSchedulerNode],
is_reorder_round: bool,
) -> list[BaseSchedulerNode]:
"""
Combine eligible nodes into FusedSchedulerNodes.
@ -3560,7 +3565,7 @@ class Scheduler:
fuse_two_nodes(node_key1, node_key2)
for node1, node2 in self.get_possible_fusions(nodes):
for node1, node2 in self.get_possible_fusions(nodes, is_reorder_round):
# if either node is in a pending fusion, resolve it.
# since we iterate on potential fusions based on profitability
# the first potential fusion should take precedence.
@ -3568,9 +3573,9 @@ class Scheduler:
node1 = self.get_fused_node(node1)
node2 = self.get_fused_node(node2)
if self.can_fuse(node1, node2) and not self.will_fusion_create_cycle(
node1, node2
):
if self.can_fuse(
node1, node2, is_reorder_round
) and not self.will_fusion_create_cycle(node1, node2):
speedup = self.speedup_by_fusion(node1, node2)
if callable(speedup):
pending_fusions[node1] = (speedup, node1, node2)
@ -3655,7 +3660,9 @@ class Scheduler:
node.prune_redundant_deps(self.name_to_fused_node)
def get_possible_fusions(
self, nodes: list[BaseSchedulerNode]
self,
nodes: list[BaseSchedulerNode],
is_reorder_round: bool,
) -> list[tuple[BaseSchedulerNode, BaseSchedulerNode]]:
"""
Helper to find all legal fusion opportunities, sorted by self.score_fusion()
@ -3675,10 +3682,10 @@ class Scheduler:
continue
seen.add(key)
if self.can_fuse(node1, node2):
if self.can_fuse(node1, node2, is_reorder_round):
possible_fusions.append(key)
elif (node2.is_template() or node2.is_foreach()) and self.can_fuse(
node2, node1
node2, node1, is_reorder_round
):
# foreach fusions and epilogue fusions are order dependent
possible_fusions.append((node2, node1))
@ -4148,7 +4155,12 @@ class Scheduler:
else:
return None
def can_fuse(self, node1: BaseSchedulerNode, node2: BaseSchedulerNode) -> bool:
def can_fuse(
self,
node1: BaseSchedulerNode,
node2: BaseSchedulerNode,
can_reorder: bool = False,
) -> bool:
"""
Determine if it is possible to combine node1 and node2 into a
single fused node.
@ -4265,7 +4277,8 @@ class Scheduler:
shared_data_score = self.score_fusion_memory(node1, node2)
if (
shared_data_score < config.score_fusion_memory_threshold
can_reorder
and shared_data_score < config.score_fusion_memory_threshold
and config.loop_ordering_after_fusion
):
new_shared_data_score = self.shared_data_after_reordering_loop(node1, node2)