mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[inductor] bugfix: keep WeakDeps (WAR deps) during fusion (#162316)
fixes #159855, was not triggered in other tests since it took more than one round of fusion to get to the problematic code which prunes WeakDeps. The WeakDeps are important to inhibit fusion of kernels that read/write data into mutated buffers with different indexing. We modify the code to a) always prune before fusion, rather than after, which improves its coverage and makes our basic vertical fusion tests surface this issue as well and b) check whether the weak dep is fusable before eliminating it (which basically means checking that the producing code and the consuming code are sufficiently compatible). The tests that trigger this with change (a) is: test_fusing_write_into_disjoint_read introduced in #118210. Pull Request resolved: https://github.com/pytorch/pytorch/pull/162316 Approved by: https://github.com/eellison, https://github.com/mlazos, https://github.com/shunting314
This commit is contained in:
committed by
PyTorch MergeBot
parent
5d8a226e23
commit
f34744d2a5
@ -5869,6 +5869,16 @@ class CommonTemplate:
|
||||
a = torch.rand((1, 1000000), device=self.device)
|
||||
self.common(f, (a,))
|
||||
|
||||
def test_inplace_flip(self):
|
||||
def f(x, y):
|
||||
x.copy_(x.flip(1))
|
||||
y = y.sum(dim=1, keepdim=True) + y
|
||||
return x + y
|
||||
|
||||
x = torch.randn(20, 1024 * 1024)
|
||||
y = torch.randn(20, 1024 * 1024)
|
||||
self.common(f, (x, y), atol=1e-3, rtol=1e-3)
|
||||
|
||||
def test_gather_scatter(self):
|
||||
def fn(node_feat, edge_index):
|
||||
src_node_feat = node_feat[edge_index[0]]
|
||||
|
@ -1112,7 +1112,11 @@ def _prune_redundant_deps(
|
||||
def should_prune(dep: Dep) -> bool:
|
||||
if isinstance(dep, WeakDep):
|
||||
op_name = name_to_buf[dep.name].defining_op_name()
|
||||
is_redundant = name_to_dep_count[name_to_fused_node[op_name].get_name()] > 0
|
||||
is_redundant = name_to_dep_count[
|
||||
name_to_fused_node[op_name].get_name()
|
||||
] > 0 and node.scheduler.fusable_weak_dep(
|
||||
dep, name_to_fused_node[op_name], node
|
||||
)
|
||||
# These can occur because fused nodes always gather deps from their snodes
|
||||
# If B has a weakdep on A
|
||||
# B gets fused with C, then any time BC is fused, the weakdep will reappear
|
||||
@ -3535,6 +3539,7 @@ class Scheduler:
|
||||
- self.can_fuse(): checks if a fusion is legal
|
||||
- self.score_fusion(): assigns priority to a given fusion
|
||||
"""
|
||||
self.prune_redundant_deps(nodes)
|
||||
fused_nodes = OrderedSet(nodes)
|
||||
if fusion_log.isEnabledFor(logging.DEBUG):
|
||||
fusion_log.debug("fuse_nodes_once, candidates:")
|
||||
@ -3628,7 +3633,6 @@ class Scheduler:
|
||||
|
||||
nodes = sorted(fused_nodes, key=lambda x: x.min_order)
|
||||
nodes = self.topological_sort_schedule(nodes)
|
||||
self.prune_redundant_deps(nodes)
|
||||
return nodes
|
||||
|
||||
def create_combo_kernel_nodes(self, num_ck_nodes: Optional[int] = None) -> None:
|
||||
@ -4408,22 +4412,36 @@ class Scheduler:
|
||||
if len(mutating_writes) != 1:
|
||||
return False
|
||||
write = mutating_writes[0]
|
||||
if isinstance(write, StarDep):
|
||||
return False
|
||||
assert isinstance(write, MemoryDep)
|
||||
|
||||
if free_symbol_is_type(write.index, SymT.TMP):
|
||||
return False
|
||||
|
||||
real_name = self.mutation_real_name[weak_dep.mutating_buf]
|
||||
relevant_reads = [
|
||||
read for read in node1.read_writes.reads if read.name == real_name
|
||||
]
|
||||
return all(
|
||||
isinstance(read, MemoryDep)
|
||||
and not free_symbol_is_type(read.index, SymT.TMP)
|
||||
and read.index == write.index
|
||||
and read.size == write.size
|
||||
for read in relevant_reads
|
||||
)
|
||||
relevant_reading_nodes = [node1]
|
||||
if isinstance(node1, ForeachKernelSchedulerNode):
|
||||
relevant_reading_nodes = node1.snodes
|
||||
num_concurrent_reads = 0
|
||||
for reading_node in relevant_reading_nodes:
|
||||
relevant_reads = [
|
||||
read
|
||||
for read in reading_node.read_writes.reads
|
||||
if read.name == real_name
|
||||
]
|
||||
if not relevant_reads:
|
||||
continue
|
||||
num_concurrent_reads += 1
|
||||
if not all(
|
||||
isinstance(read, MemoryDep)
|
||||
and not free_symbol_is_type(read.index, SymT.TMP)
|
||||
and read.index == write.index
|
||||
and read.size == write.size
|
||||
for read in relevant_reads
|
||||
):
|
||||
return False
|
||||
return num_concurrent_reads <= 1
|
||||
|
||||
# StarDep doesn't match MemoryDep, different indices don't match
|
||||
# However, broadcasting sometimes strips dimensions, and if that's the case
|
||||
|
Reference in New Issue
Block a user