[inductor] fuse for scalar shared data (#162311)
LOAF (loop ordering after fusion) could previously skip these fusion opportunities and cause some tests to fail.

Test:
- TORCHINDUCTOR_LOOP_ORDERING_AFTER_FUSION=1 python test/inductor/test_torchinductor_strided_blocks.py TritonBlockPointerTestGPU.test_2d_reduction_odd_shapes_view_size4_num_block_pointers_1_num_triton_kernels_1_reduction_op4_cuda

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162311
Approved by: https://github.com/jansel
Commit ebd29a13fe (parent 5793dd7875), committed by PyTorch MergeBot.
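For context, the scenario this change targets: a reduction that produces a scalar (e.g. torch.mean) whose result is shared with a following node should still fuse into a single Triton kernel when loop ordering after fusion (LOAF) is enabled. The sketch below mirrors the test added in this PR; it assumes a CUDA device (the test itself uses the GPU_TYPE helper) and relies on run_and_get_code and FileCheck from PyTorch's testing utilities.

import torch
from torch._inductor import config as inductor_config
from torch._inductor.utils import run_and_get_code
from torch.testing import FileCheck

# Enable loop ordering after fusion, equivalent to setting
# TORCHINDUCTOR_LOOP_ORDERING_AFTER_FUSION=1 in the environment.
with inductor_config.patch({"loop_ordering_after_fusion": True}):

    @torch.compile
    def f(x):
        # per the PR description, the Inductor nodes for mean() share a scalar
        return torch.mean(x)

    x = torch.randn([5, 5], device="cuda")  # device assumed; the test uses GPU_TYPE
    out, code = run_and_get_code(f, x)
    # With this fix, LOAF no longer breaks the fusion: exactly one Triton kernel.
    FileCheck().check_count("@triton.jit", 1, exactly=True).run(code[0])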
@@ -592,6 +592,23 @@ class LoopOrderingTest(TestCase):
             ".run(", 1 + int(inductor_config.benchmark_kernel), exactly=True
         ).run(code[0])
 
+    def test_fuse_with_scalar_shared_memory(self):
+        """
+        Make sure if we can fuse two nodes sharing a scalar before,
+        we can still do it with LOAF applied.
+
+        This is not really a big deal. But some tests rely on this and
+        less number of kernels has some small benefits.
+        """
+
+        @torch.compile
+        def f(x):
+            return torch.mean(x)
+
+        x = torch.randn([5, 5], device=GPU_TYPE)
+        out, code = run_and_get_code(f, x)
+        FileCheck().check_count("@triton.jit", 1, exactly=True).run(code[0])
+
 
     @inductor_config.patch(
         {
@@ -309,8 +309,8 @@ class BaseSchedulerNode:
 
     def reorder_loops_by_dep_pair(
         self, self_dep: MemoryDep, other_dep: MemoryDep
-    ) -> None:
-        return
+    ) -> bool:
+        return False
 
     def update_mutated_names(self, renames: dict[str, str]) -> None:
         self.mutation_renames = {
@@ -1130,6 +1130,11 @@ class NopKernelSchedulerNode(BaseSchedulerNode):
 
 
 class SchedulerNode(BaseSchedulerNode):
+    """
+    A SchedulerNode is a node for scheduling that encapsulates either
+    a ComputedBuffer or a TemplateBuffer.
+    """
+
     _sizes: tuple[Sequence[sympy.Expr], ...]
     _body: LoopBody
 
@@ -1254,7 +1259,7 @@ class SchedulerNode(BaseSchedulerNode):
 
     def reorder_loops_by_dep_pair(
         self, self_dep: MemoryDep, other_dep: MemoryDep
-    ) -> None:
+    ) -> bool:
         new_order = None
         self_sizes = self._sizes[0]
         if len(self_sizes) == self_dep.num_vars == other_dep.num_vars:
@@ -1266,11 +1271,13 @@ class SchedulerNode(BaseSchedulerNode):
                 "Reorder loops for %s with order %s", self.get_name(), new_order
             )
             self.apply_new_loop_order(new_order)
+            return True
         else:
             loop_ordering_log.debug(
                 "Don't reordering %s because we can not decide the suitable loop order",
                 self.get_name(),
             )
+            return False
 
     def debug_str_extra(self) -> str:
         name = self.get_name()
@@ -1527,10 +1534,13 @@ class FusedSchedulerNode(BaseSchedulerNode):
 
     def reorder_loops_by_dep_pair(
         self, self_dep: MemoryDep, other_dep: MemoryDep
-    ) -> None:
+    ) -> bool:
+        """
+        Return true if a loop reordering is performed.
+        """
         if self.is_template():
             # We can not really reorder loops for a triton template
-            return
+            return False
         self_sizes = None
         for snode in self.snodes:
             assert isinstance(snode, SchedulerNode)
@@ -1538,7 +1548,7 @@ class FusedSchedulerNode(BaseSchedulerNode):
                 loop_ordering_log.debug(
                     "Can not reorder fused node due to different sizes"
                 )
-                return
+                return False
             self_sizes = snode._sizes[0]
         new_order = None
 
@@ -1551,7 +1561,7 @@ class FusedSchedulerNode(BaseSchedulerNode):
                 "Dont reordering fused node %s because we can not decide the suitable loop order",
                 self.get_name(),
             )
-            return
+            return False
         metrics.num_loop_reordering += 1
         loop_ordering_log.debug(
             "Reorder loops for fused node %s with order %s", self.get_name(), new_order
@@ -1561,6 +1571,7 @@ class FusedSchedulerNode(BaseSchedulerNode):
             snode.apply_new_loop_order(new_order)
 
         refresh_group_node_dependencies(self)
+        return True
 
     def __init__(self, scheduler: Scheduler, snodes: list[BaseSchedulerNode]) -> None:
         super().__init__(scheduler)
@@ -3900,6 +3911,11 @@ class Scheduler:
         Right now just greedily reorder the loop of node1 to be compatible with node2,
         but ideally we should have some heuristics to reorder the loop for node2
         to be compatible with node1 if that's more efficient.
+
+        Return the amount of shared data re-computed in this method.
+        If no such recomputation happens, return -1 (not return 0 since 0 is a valid
+        amount of shared data).
+
         """
 
         # TODO Don't do loop reordering for CPU for now.
@@ -3907,14 +3923,14 @@ class Scheduler:
         if not config.loop_ordering_after_fusion or any(
             n.is_cpu() for n in [node1, node2]
         ):
-            return 0
+            return -1
 
         node1_buffer_names = node1.read_writes.buffer_names()
         node2_buffer_names = node2.read_writes.buffer_names()
         # Fast path: no common buffers.
         common_buffer_names = node1_buffer_names & node2_buffer_names
         if not common_buffer_names:
-            return 0
+            return -1
 
         node1_name2dep = {dep.name: dep for dep in node1.read_writes.reads_and_writes()}
         node2_name2dep = {dep.name: dep for dep in node2.read_writes.reads_and_writes()}
@@ -3937,13 +3953,13 @@ class Scheduler:
             )
 
         if len(candidates) == 0:
-            return 0
+            return -1
 
         # Pick the largest buffer to guide the loop reordering
         _numel, lhs_dep, rhs_dep = max(candidates, key=operator.itemgetter(0))
 
         if not isinstance(lhs_dep, MemoryDep) or not isinstance(rhs_dep, MemoryDep):
-            return 0
+            return -1
 
         if lhs_dep.num_vars != rhs_dep.num_vars:
             # this can happen due to we don't merge loops.
@@ -3952,13 +3968,14 @@ class Scheduler:
             # normalization (merging loops)
             if lhs_dep.normalize() == rhs_dep.normalize():
                 return self.dep_size_hint(lhs_dep)
-            return 0
+            return -1
 
+        reordered = False
         # Only reorder loops for pointwise for now
         if not node1.is_reduction():
-            node1.reorder_loops_by_dep_pair(lhs_dep, rhs_dep)
+            reordered = node1.reorder_loops_by_dep_pair(lhs_dep, rhs_dep)
         elif not node2.is_reduction():
-            node2.reorder_loops_by_dep_pair(rhs_dep, lhs_dep)
+            reordered = node2.reorder_loops_by_dep_pair(rhs_dep, lhs_dep)
         else:
             loop_ordering_log.debug(
                 "Don't reorder loops since both nodes are reductions: %s v.s. %s",
@@ -3966,7 +3983,7 @@ class Scheduler:
                 node2.get_name(),
             )
 
-        return self.score_fusion_memory(node1, node2)
+        return self.score_fusion_memory(node1, node2) if reordered else -1
 
     def unfusable_node(self, node: BaseSchedulerNode) -> bool:
         """
@@ -4255,7 +4272,9 @@ class Scheduler:
             shared_data_score < config.score_fusion_memory_threshold
             and config.loop_ordering_after_fusion
         ):
-            shared_data_score = self.shared_data_after_reordering_loop(node1, node2)
+            new_shared_data_score = self.shared_data_after_reordering_loop(node1, node2)
+            if new_shared_data_score >= 0:
+                shared_data_score = new_shared_data_score
 
         if config.expand_dimension_for_pointwise_nodes and (
             expand_analysis := self.get_expand_dim_for_pointwise_nodes(node1, node2)
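The behavioral core of the scheduler change: shared_data_after_reordering_loop() now returns -1 when it bails out, since 0 is a legitimate amount of shared data, and reorder_loops_by_dep_pair() reports whether a reordering actually happened. A minimal, hypothetical sketch of the caller-side logic (maybe_update_score is an illustrative helper, not an Inductor API):

def maybe_update_score(shared_data_score: int, recomputed: int) -> int:
    # recomputed == -1 is the sentinel for "no recomputation happened"
    # (CPU node, no common buffers, no loop reordering performed, ...);
    # in that case the previously computed score is kept.
    if recomputed >= 0:
        return recomputed
    return shared_data_score

assert maybe_update_score(10, -1) == 10  # bail-out keeps the original score
assert maybe_update_score(10, 0) == 0    # a genuine result of 0 still replaces it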