Revert "[inductor] fuse for scalar shared data (#162311)"

This reverts commit 2a45837e98c63cae9d1a2e2133a727b829e549d5.

Reverted https://github.com/pytorch/pytorch/pull/162311 on behalf of https://github.com/huydhn due to Sorry for reverting your change, but it is breaking lint ([comment](https://github.com/pytorch/pytorch/pull/162311#issuecomment-3263511162))
This commit is contained in:
PyTorch MergeBot
2025-09-07 05:57:43 +00:00
parent fea20775ad
commit eac3d6f04c
2 changed files with 16 additions and 47 deletions

View File

@ -304,8 +304,8 @@ class BaseSchedulerNode:
def reorder_loops_by_dep_pair(
self, self_dep: MemoryDep, other_dep: MemoryDep
) -> bool:
return False
) -> None:
return
def update_mutated_names(self, renames: dict[str, str]) -> None:
self.mutation_renames = {
@ -1149,7 +1149,7 @@ class SchedulerNode(BaseSchedulerNode):
def reorder_loops_by_dep_pair(
self, self_dep: MemoryDep, other_dep: MemoryDep
) -> bool:
) -> None:
new_order = None
self_sizes = self._sizes[0]
if len(self_sizes) == self_dep.num_vars == other_dep.num_vars:
@ -1161,13 +1161,11 @@ class SchedulerNode(BaseSchedulerNode):
"Reorder loops for %s with order %s", self.get_name(), new_order
)
self.apply_new_loop_order(new_order)
return True
else:
loop_ordering_log.debug(
"Don't reordering %s because we can not decide the suitable loop order",
self.get_name(),
)
return False
def debug_str_extra(self) -> str:
name = self.get_name()
@ -1424,13 +1422,10 @@ class FusedSchedulerNode(BaseSchedulerNode):
def reorder_loops_by_dep_pair(
self, self_dep: MemoryDep, other_dep: MemoryDep
) -> bool:
"""
Return true if a loop reordering is performed.
"""
) -> None:
if self.is_template():
# We can not really reorder loops for a triton template
return False
return
self_sizes = None
for snode in self.snodes:
assert isinstance(snode, SchedulerNode)
@ -1438,7 +1433,7 @@ class FusedSchedulerNode(BaseSchedulerNode):
loop_ordering_log.debug(
"Can not reorder fused node due to different sizes"
)
return False
return
self_sizes = snode._sizes[0]
new_order = None
@ -1451,7 +1446,7 @@ class FusedSchedulerNode(BaseSchedulerNode):
"Dont reordering fused node %s because we can not decide the suitable loop order",
self.get_name(),
)
return False
return
metrics.num_loop_reordering += 1
loop_ordering_log.debug(
"Reorder loops for fused node %s with order %s", self.get_name(), new_order
@ -1461,7 +1456,6 @@ class FusedSchedulerNode(BaseSchedulerNode):
snode.apply_new_loop_order(new_order)
refresh_group_node_dependencies(self)
return True
def __init__(self, scheduler: Scheduler, snodes: list[BaseSchedulerNode]) -> None:
super().__init__(scheduler)
@ -3786,11 +3780,6 @@ class Scheduler:
Right now just greedily reorder the loop of node1 to be compatible with node2,
but ideally we should have some heuristics to reorder the loop for node2
to be compatible with node1 if that's more efficient.
Return the amount of shared data re-computed in this method.
If no such recomputation happens, return -1 (not return 0 since 0 is a valid
amount of shared data).
"""
# TODO Don't do loop reordering for CPU for now.
@ -3798,14 +3787,14 @@ class Scheduler:
if not config.loop_ordering_after_fusion or any(
n.is_cpu() for n in [node1, node2]
):
return -1
return 0
node1_buffer_names = node1.read_writes.buffer_names()
node2_buffer_names = node2.read_writes.buffer_names()
# Fast path: no common buffers.
common_buffer_names = node1_buffer_names & node2_buffer_names
if not common_buffer_names:
return -1
return 0
node1_name2dep = {dep.name: dep for dep in node1.read_writes.reads_and_writes()}
node2_name2dep = {dep.name: dep for dep in node2.read_writes.reads_and_writes()}
@ -3828,13 +3817,13 @@ class Scheduler:
)
if len(candidates) == 0:
return -1
return 0
# Pick the largest buffer to guide the loop reordering
_numel, lhs_dep, rhs_dep = max(candidates, key=operator.itemgetter(0))
if not isinstance(lhs_dep, MemoryDep) or not isinstance(rhs_dep, MemoryDep):
return -1
return 0
if lhs_dep.num_vars != rhs_dep.num_vars:
# this can happen due to we don't merge loops.
@ -3843,14 +3832,13 @@ class Scheduler:
# normalization (merging loops)
if lhs_dep.normalize() == rhs_dep.normalize():
return self.dep_size_hint(lhs_dep)
return -1
return 0
reordered = False
# Only reorder loops for pointwise for now
if not node1.is_reduction():
reordered = node1.reorder_loops_by_dep_pair(lhs_dep, rhs_dep)
node1.reorder_loops_by_dep_pair(lhs_dep, rhs_dep)
elif not node2.is_reduction():
reordered = node2.reorder_loops_by_dep_pair(rhs_dep, lhs_dep)
node2.reorder_loops_by_dep_pair(rhs_dep, lhs_dep)
else:
loop_ordering_log.debug(
"Don't reorder loops since both nodes are reductions: %s v.s. %s",
@ -3858,7 +3846,7 @@ class Scheduler:
node2.get_name(),
)
return self.score_fusion_memory(node1, node2) if reordered else -1
return self.score_fusion_memory(node1, node2)
def unfusable_node(self, node: BaseSchedulerNode) -> bool:
"""
@ -4147,9 +4135,7 @@ class Scheduler:
shared_data_score < config.score_fusion_memory_threshold
and config.loop_ordering_after_fusion
):
new_shared_data_score = self.shared_data_after_reordering_loop(node1, node2)
if new_shared_data_score >= 0:
shared_data_score = new_shared_data_score
shared_data_score = self.shared_data_after_reordering_loop(node1, node2)
if config.expand_dimension_for_pointwise_nodes and (
expand_analysis := self.get_expand_dim_for_pointwise_nodes(node1, node2)