[inductor] bugfix: keep WeakDeps (WAR deps) during fusion (#162316)

Fixes #159855. The bug was not triggered by other tests because it
takes more than one round of fusion to reach the problematic code,
which prunes WeakDeps. The WeakDeps are needed to inhibit fusion of
kernels that read from and write to mutated buffers with different
indexing.

We modify the code to (a) always prune before fusion rather than
after, which improves the coverage of the pruning code and makes our
basic vertical fusion tests surface this issue as well, and (b)
check whether the weak dep is fusable before eliminating it
(which basically means checking that the producing code and
the consuming code are sufficiently compatible).

The test that triggers this with change (a) is
test_fusing_write_into_disjoint_read, introduced in #118210.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162316
Approved by: https://github.com/eellison, https://github.com/mlazos, https://github.com/shunting314
This commit is contained in:
Markus Hoehnerbach
2025-09-19 10:52:19 -07:00
committed by PyTorch MergeBot
parent 5d8a226e23
commit f34744d2a5
2 changed files with 40 additions and 12 deletions

View File

@ -5869,6 +5869,16 @@ class CommonTemplate:
a = torch.rand((1, 1000000), device=self.device)
self.common(f, (a,))
def test_inplace_flip(self):
# Checks an in-place update whose source is a reversed view of its own
# destination, combined with a reduction over a second input.
def f(x, y):
# Writes x in place from x.flip(1): the same buffer is read and written
# with different indexing along dim 1.
x.copy_(x.flip(1))
# Sum over dim 1 (kept as a size-1 dim) added back onto y.
y = y.sum(dim=1, keepdim=True) + y
return x + y
# Two random inputs of shape 20 x (1024 * 1024).
x = torch.randn(20, 1024 * 1024)
y = torch.randn(20, 1024 * 1024)
# NOTE(review): self.common is defined outside this excerpt; presumably it
# checks f's result against a reference run — confirm. Tolerances are
# passed explicitly as 1e-3.
self.common(f, (x, y), atol=1e-3, rtol=1e-3)
def test_gather_scatter(self):
def fn(node_feat, edge_index):
src_node_feat = node_feat[edge_index[0]]

View File

@ -1112,7 +1112,11 @@ def _prune_redundant_deps(
def should_prune(dep: Dep) -> bool:
if isinstance(dep, WeakDep):
op_name = name_to_buf[dep.name].defining_op_name()
is_redundant = name_to_dep_count[name_to_fused_node[op_name].get_name()] > 0
is_redundant = name_to_dep_count[
name_to_fused_node[op_name].get_name()
] > 0 and node.scheduler.fusable_weak_dep(
dep, name_to_fused_node[op_name], node
)
# These can occur because fused nodes always gather deps from their snodes
# If B has a weakdep on A
# B gets fused with C, then any time BC is fused, the weakdep will reappear
@ -3535,6 +3539,7 @@ class Scheduler:
- self.can_fuse(): checks if a fusion is legal
- self.score_fusion(): assigns priority to a given fusion
"""
self.prune_redundant_deps(nodes)
fused_nodes = OrderedSet(nodes)
if fusion_log.isEnabledFor(logging.DEBUG):
fusion_log.debug("fuse_nodes_once, candidates:")
@ -3628,7 +3633,6 @@ class Scheduler:
nodes = sorted(fused_nodes, key=lambda x: x.min_order)
nodes = self.topological_sort_schedule(nodes)
self.prune_redundant_deps(nodes)
return nodes
def create_combo_kernel_nodes(self, num_ck_nodes: Optional[int] = None) -> None:
@ -4408,22 +4412,36 @@ class Scheduler:
if len(mutating_writes) != 1:
return False
write = mutating_writes[0]
if isinstance(write, StarDep):
return False
assert isinstance(write, MemoryDep)
if free_symbol_is_type(write.index, SymT.TMP):
return False
real_name = self.mutation_real_name[weak_dep.mutating_buf]
relevant_reads = [
read for read in node1.read_writes.reads if read.name == real_name
]
return all(
isinstance(read, MemoryDep)
and not free_symbol_is_type(read.index, SymT.TMP)
and read.index == write.index
and read.size == write.size
for read in relevant_reads
)
relevant_reading_nodes = [node1]
if isinstance(node1, ForeachKernelSchedulerNode):
relevant_reading_nodes = node1.snodes
num_concurrent_reads = 0
for reading_node in relevant_reading_nodes:
relevant_reads = [
read
for read in reading_node.read_writes.reads
if read.name == real_name
]
if not relevant_reads:
continue
num_concurrent_reads += 1
if not all(
isinstance(read, MemoryDep)
and not free_symbol_is_type(read.index, SymT.TMP)
and read.index == write.index
and read.size == write.size
for read in relevant_reads
):
return False
return num_concurrent_reads <= 1
# StarDep doesn't match MemoryDep, different indices don't match
# However, broadcasting sometimes strips dimensions, and if that's the case