Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
Revert "[inductor] fuse for scalar shared data (#162311)"
This reverts commit 2a45837e98c63cae9d1a2e2133a727b829e549d5. Reverted https://github.com/pytorch/pytorch/pull/162311 on behalf of https://github.com/huydhn due to Sorry for reverting your change, but it is breaking lint ([comment](https://github.com/pytorch/pytorch/pull/162311#issuecomment-3263511162))
@@ -304,8 +304,8 @@ class BaseSchedulerNode:
 
     def reorder_loops_by_dep_pair(
         self, self_dep: MemoryDep, other_dep: MemoryDep
-    ) -> bool:
-        return False
+    ) -> None:
+        return
 
     def update_mutated_names(self, renames: dict[str, str]) -> None:
         self.mutation_renames = {
@@ -1149,7 +1149,7 @@ class SchedulerNode(BaseSchedulerNode):
 
     def reorder_loops_by_dep_pair(
         self, self_dep: MemoryDep, other_dep: MemoryDep
-    ) -> bool:
+    ) -> None:
         new_order = None
         self_sizes = self._sizes[0]
         if len(self_sizes) == self_dep.num_vars == other_dep.num_vars:
@@ -1161,13 +1161,11 @@ class SchedulerNode(BaseSchedulerNode):
                 "Reorder loops for %s with order %s", self.get_name(), new_order
             )
             self.apply_new_loop_order(new_order)
-            return True
         else:
             loop_ordering_log.debug(
                 "Don't reordering %s because we can not decide the suitable loop order",
                 self.get_name(),
             )
-            return False
 
     def debug_str_extra(self) -> str:
         name = self.get_name()
@@ -1424,13 +1422,10 @@ class FusedSchedulerNode(BaseSchedulerNode):
 
     def reorder_loops_by_dep_pair(
         self, self_dep: MemoryDep, other_dep: MemoryDep
-    ) -> bool:
-        """
-        Return true if a loop reordering is performed.
-        """
+    ) -> None:
         if self.is_template():
             # We can not really reorder loops for a triton template
-            return False
+            return
         self_sizes = None
         for snode in self.snodes:
             assert isinstance(snode, SchedulerNode)
@@ -1438,7 +1433,7 @@ class FusedSchedulerNode(BaseSchedulerNode):
                 loop_ordering_log.debug(
                     "Can not reorder fused node due to different sizes"
                 )
-                return False
+                return
             self_sizes = snode._sizes[0]
         new_order = None
 
@@ -1451,7 +1446,7 @@ class FusedSchedulerNode(BaseSchedulerNode):
                 "Dont reordering fused node %s because we can not decide the suitable loop order",
                 self.get_name(),
             )
-            return False
+            return
         metrics.num_loop_reordering += 1
         loop_ordering_log.debug(
             "Reorder loops for fused node %s with order %s", self.get_name(), new_order
@@ -1461,7 +1456,6 @@ class FusedSchedulerNode(BaseSchedulerNode):
             snode.apply_new_loop_order(new_order)
 
         refresh_group_node_dependencies(self)
-        return True
 
     def __init__(self, scheduler: Scheduler, snodes: list[BaseSchedulerNode]) -> None:
         super().__init__(scheduler)
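
Taken together, the hunks above restore reorder_loops_by_dep_pair to a fire-and-forget method: it reorders loops in place and returns None, rather than reporting whether a reorder happened as the reverted PR did. A minimal sketch of the two conventions, assuming a hypothetical try_reorder helper standing in for the scheduler's actual reordering logic (this is not PyTorch's real code):

    # Sketch only: the reverted PR's convention vs. the restored one.
    # `try_reorder` is a hypothetical stand-in for the real logic.
    def try_reorder(self_dep, other_dep) -> bool:
        return self_dep != other_dep  # stub decision

    # Convention added by PR #162311: report whether a reorder happened,
    # so the caller can decide whether to recompute the shared-data score.
    def reorder_loops_by_dep_pair_reporting(self_dep, other_dep) -> bool:
        return try_reorder(self_dep, other_dep)

    # Convention restored by this revert: reorder in place, report nothing.
    def reorder_loops_by_dep_pair_restored(self_dep, other_dep) -> None:
        try_reorder(self_dep, other_dep)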
@@ -3786,11 +3780,6 @@ class Scheduler:
         Right now just greedily reorder the loop of node1 to be compatible with node2,
         but ideally we should have some heuristics to reorder the loop for node2
         to be compatible with node1 if that's more efficient.
-
-        Return the amount of shared data re-computed in this method.
-        If no such recomputation happens, return -1 (not return 0 since 0 is a valid
-        amount of shared data).
-
         """
 
         # TODO Don't do loop reordering for CPU for now.
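
The deleted docstring lines record why the reverted PR used -1 rather than 0 as its "nothing happened" sentinel: 0 is a legitimate amount of shared data, so a plain truthiness check would conflate the two. A small illustration of the pitfall the docstring guards against, with made-up scores (only the sentinel convention itself comes from the diff):

    # -1 means "no shared-data recomputation happened"; 0 is a real score.
    def apply_score(old_score: int, new_score: int) -> int:
        if new_score >= 0:    # correct: keeps 0, rejects only the -1 sentinel
            return new_score
        return old_score      # sentinel: fall back to the previous score

    assert apply_score(5, 0) == 0    # zero shared data is a valid result
    assert apply_score(5, -1) == 5   # -1 means "keep the old score"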
@@ -3798,14 +3787,14 @@ class Scheduler:
         if not config.loop_ordering_after_fusion or any(
             n.is_cpu() for n in [node1, node2]
         ):
-            return -1
+            return 0
 
         node1_buffer_names = node1.read_writes.buffer_names()
         node2_buffer_names = node2.read_writes.buffer_names()
         # Fast path: no common buffers.
         common_buffer_names = node1_buffer_names & node2_buffer_names
         if not common_buffer_names:
-            return -1
+            return 0
 
         node1_name2dep = {dep.name: dep for dep in node1.read_writes.reads_and_writes()}
         node2_name2dep = {dep.name: dep for dep in node2.read_writes.reads_and_writes()}
@@ -3828,13 +3817,13 @@ class Scheduler:
         )
 
         if len(candidates) == 0:
-            return -1
+            return 0
 
         # Pick the largest buffer to guide the loop reordering
         _numel, lhs_dep, rhs_dep = max(candidates, key=operator.itemgetter(0))
 
         if not isinstance(lhs_dep, MemoryDep) or not isinstance(rhs_dep, MemoryDep):
-            return -1
+            return 0
 
         if lhs_dep.num_vars != rhs_dep.num_vars:
             # this can happen due to we don't merge loops.
@@ -3843,14 +3832,13 @@ class Scheduler:
             # normalization (merging loops)
             if lhs_dep.normalize() == rhs_dep.normalize():
                 return self.dep_size_hint(lhs_dep)
-            return -1
+            return 0
 
-        reordered = False
         # Only reorder loops for pointwise for now
         if not node1.is_reduction():
-            reordered = node1.reorder_loops_by_dep_pair(lhs_dep, rhs_dep)
+            node1.reorder_loops_by_dep_pair(lhs_dep, rhs_dep)
         elif not node2.is_reduction():
-            reordered = node2.reorder_loops_by_dep_pair(rhs_dep, lhs_dep)
+            node2.reorder_loops_by_dep_pair(rhs_dep, lhs_dep)
         else:
             loop_ordering_log.debug(
                 "Don't reorder loops since both nodes are reductions: %s v.s. %s",
@@ -3858,7 +3846,7 @@ class Scheduler:
             node2.get_name(),
         )
 
-        return self.score_fusion_memory(node1, node2) if reordered else -1
+        return self.score_fusion_memory(node1, node2)
 
     def unfusable_node(self, node: BaseSchedulerNode) -> bool:
         """
@@ -4147,9 +4135,7 @@ class Scheduler:
             shared_data_score < config.score_fusion_memory_threshold
             and config.loop_ordering_after_fusion
         ):
-            new_shared_data_score = self.shared_data_after_reordering_loop(node1, node2)
-            if new_shared_data_score >= 0:
-                shared_data_score = new_shared_data_score
+            shared_data_score = self.shared_data_after_reordering_loop(node1, node2)
 
         if config.expand_dimension_for_pointwise_nodes and (
             expand_analysis := self.get_expand_dim_for_pointwise_nodes(node1, node2)
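
This last hunk is the caller side of the same sentinel convention. With the PR, shared_data_score was only overwritten when shared_data_after_reordering_loop returned something other than -1; after the revert the method returns 0 on failure and its result is assigned unconditionally. A condensed sketch of the two call patterns (shared_data_score comes from the diff; reorder_score is a hypothetical stand-in for the self.shared_data_after_reordering_loop(node1, node2) call):

    # Condensed caller patterns from the hunk above.
    def updated_score_before_revert(shared_data_score: int, reorder_score) -> int:
        new_shared_data_score = reorder_score()
        if new_shared_data_score >= 0:   # skip the -1 "no change" sentinel
            shared_data_score = new_shared_data_score
        return shared_data_score

    def updated_score_after_revert(shared_data_score: int, reorder_score) -> int:
        return reorder_score()           # 0 on failure, assigned directly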