[inductor] fuse for scalar shared data (#162311)

LOAF (loop ordering after fusion) could previously skip these fusion opportunities and cause some tests to fail.

Test:
- TORCHINDUCTOR_LOOP_ORDERING_AFTER_FUSION=1 python test/inductor/test_torchinductor_strided_blocks.py TritonBlockPointerTestGPU.test_2d_reduction_odd_shapes_view_size4_num_block_pointers_1_num_triton_kernels_1_reduction_op4_cuda
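
For context, a minimal repro sketch of the pattern being exercised (a hypothetical standalone script, not part of this PR; assumes a CUDA device and the same environment variable as above):

    # Hypothetical repro sketch; run with TORCHINDUCTOR_LOOP_ORDERING_AFTER_FUSION=1.
    import torch

    @torch.compile
    def f(x):
        # torch.mean should lower to a reduction plus a consumer of its scalar
        # result; with LOAF enabled the two nodes should still fuse into a
        # single Triton kernel (this mirrors the test added in the diff below).
        return torch.mean(x)

    x = torch.randn(5, 5, device="cuda")
    print(f(x))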

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162311
Approved by: https://github.com/jansel
ghstack dependencies: #162028, #162221, #162303
Author: Shunting Zhang
Date: 2025-09-06 00:10:26 -07:00
Committed by: PyTorch MergeBot
Parent: b919560c4a
Commit: 2a45837e98
2 changed files with 47 additions and 16 deletions

@@ -589,6 +589,23 @@ class LoopOrderingTest(TestCase):
".run(", 1 + int(inductor_config.benchmark_kernel), exactly=True
).run(code[0])
def test_fuse_with_scalar_shared_memory(self):
"""
Make sure that if we could fuse two nodes sharing a scalar before,
we can still do it with LOAF applied.
This is not really a big deal, but some tests rely on it and
emitting fewer kernels has some small benefits.
"""
@torch.compile
def f(x):
return torch.mean(x)
x = torch.randn([5, 5], device=GPU_TYPE)
out, code = run_and_get_code(f, x)
FileCheck().check_count("@triton.jit", 1, exactly=True).run(code[0])
@inductor_config.patch(
{

@@ -304,8 +304,8 @@ class BaseSchedulerNode:
def reorder_loops_by_dep_pair(
self, self_dep: MemoryDep, other_dep: MemoryDep
) -> None:
return
) -> bool:
return False
def update_mutated_names(self, renames: dict[str, str]) -> None:
self.mutation_renames = {
@@ -1149,7 +1149,7 @@ class SchedulerNode(BaseSchedulerNode):
def reorder_loops_by_dep_pair(
self, self_dep: MemoryDep, other_dep: MemoryDep
) -> None:
) -> bool:
new_order = None
self_sizes = self._sizes[0]
if len(self_sizes) == self_dep.num_vars == other_dep.num_vars:
@@ -1161,11 +1161,13 @@ class SchedulerNode(BaseSchedulerNode):
"Reorder loops for %s with order %s", self.get_name(), new_order
)
self.apply_new_loop_order(new_order)
return True
else:
loop_ordering_log.debug(
"Don't reordering %s because we can not decide the suitable loop order",
self.get_name(),
)
return False
def debug_str_extra(self) -> str:
name = self.get_name()
@@ -1422,10 +1424,13 @@ class FusedSchedulerNode(BaseSchedulerNode):
def reorder_loops_by_dep_pair(
self, self_dep: MemoryDep, other_dep: MemoryDep
) -> None:
) -> bool:
"""
Return true if a loop reordering is performed.
"""
if self.is_template():
# We cannot really reorder loops for a Triton template
return
return False
self_sizes = None
for snode in self.snodes:
assert isinstance(snode, SchedulerNode)
@@ -1433,7 +1438,7 @@ class FusedSchedulerNode(BaseSchedulerNode):
loop_ordering_log.debug(
"Can not reorder fused node due to different sizes"
)
return
return False
self_sizes = snode._sizes[0]
new_order = None
@@ -1446,7 +1451,7 @@ class FusedSchedulerNode(BaseSchedulerNode):
"Dont reordering fused node %s because we can not decide the suitable loop order",
self.get_name(),
)
return
return False
metrics.num_loop_reordering += 1
loop_ordering_log.debug(
"Reorder loops for fused node %s with order %s", self.get_name(), new_order
@@ -1456,6 +1461,7 @@ class FusedSchedulerNode(BaseSchedulerNode):
snode.apply_new_loop_order(new_order)
refresh_group_node_dependencies(self)
return True
def __init__(self, scheduler: Scheduler, snodes: list[BaseSchedulerNode]) -> None:
super().__init__(scheduler)
@@ -3780,6 +3786,11 @@ class Scheduler:
Right now just greedily reorder the loop of node1 to be compatible with node2,
but ideally we should have some heuristics to reorder the loop for node2
to be compatible with node1 if that's more efficient.
Return the amount of shared data re-computed in this method.
If no such recomputation happens, return -1 (not 0, since 0 is a valid
amount of shared data).
"""
# TODO Don't do loop reordering for CPU for now.
@@ -3787,14 +3798,14 @@ class Scheduler:
if not config.loop_ordering_after_fusion or any(
n.is_cpu() for n in [node1, node2]
):
return 0
return -1
node1_buffer_names = node1.read_writes.buffer_names()
node2_buffer_names = node2.read_writes.buffer_names()
# Fast path: no common buffers.
common_buffer_names = node1_buffer_names & node2_buffer_names
if not common_buffer_names:
return 0
return -1
node1_name2dep = {dep.name: dep for dep in node1.read_writes.reads_and_writes()}
node2_name2dep = {dep.name: dep for dep in node2.read_writes.reads_and_writes()}
@@ -3817,13 +3828,13 @@ class Scheduler:
)
if len(candidates) == 0:
return 0
return -1
# Pick the largest buffer to guide the loop reordering
_numel, lhs_dep, rhs_dep = max(candidates, key=operator.itemgetter(0))
if not isinstance(lhs_dep, MemoryDep) or not isinstance(rhs_dep, MemoryDep):
return 0
return -1
if lhs_dep.num_vars != rhs_dep.num_vars:
# this can happen because we don't merge loops.
@@ -3832,13 +3843,14 @@ class Scheduler:
# normalization (merging loops)
if lhs_dep.normalize() == rhs_dep.normalize():
return self.dep_size_hint(lhs_dep)
return 0
return -1
reordered = False
# Only reorder loops for pointwise for now
if not node1.is_reduction():
node1.reorder_loops_by_dep_pair(lhs_dep, rhs_dep)
reordered = node1.reorder_loops_by_dep_pair(lhs_dep, rhs_dep)
elif not node2.is_reduction():
node2.reorder_loops_by_dep_pair(rhs_dep, lhs_dep)
reordered = node2.reorder_loops_by_dep_pair(rhs_dep, lhs_dep)
else:
loop_ordering_log.debug(
"Don't reorder loops since both nodes are reductions: %s v.s. %s",
@@ -3846,7 +3858,7 @@ class Scheduler:
node2.get_name(),
)
return self.score_fusion_memory(node1, node2)
return self.score_fusion_memory(node1, node2) if reordered else -1
def unfusable_node(self, node: BaseSchedulerNode) -> bool:
"""
@@ -4135,7 +4147,9 @@ class Scheduler:
shared_data_score < config.score_fusion_memory_threshold
and config.loop_ordering_after_fusion
):
shared_data_score = self.shared_data_after_reordering_loop(node1, node2)
new_shared_data_score = self.shared_data_after_reordering_loop(node1, node2)
if new_shared_data_score >= 0:
shared_data_score = new_shared_data_score
if config.expand_dimension_for_pointwise_nodes and (
expand_analysis := self.get_expand_dim_for_pointwise_nodes(node1, node2)