add more restriction to fusion with large accumulate reads (#163163)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163163
Approved by: https://github.com/yf225
This commit is contained in:
Xuan Zhang
2025-09-17 08:40:15 -07:00
committed by PyTorch MergeBot
parent 3c9e220f34
commit 6e680ae8de
3 changed files with 19 additions and 10 deletions

View File

@ -408,13 +408,9 @@ class TestOperatorReorderForPeakMemory(TestCase):
code = run_and_get_triton_code(f_compiled, x, y, z)
(
FileCheck()
.check("triton_poi_fused_add_0.run(buf1, arg2_1,")
.check("triton_poi_fused_add_0.run(buf3, arg2_1,")
.check("triton_poi_fused_add_0.run(buf4, buf3,")
.check("triton_poi_fused_add_0.run(buf6, arg2_1,")
.check("triton_poi_fused_add_0.run(buf7, buf6,")
.check("triton_poi_fused_add_0.run(buf9, arg2_1,")
.check("triton_poi_fused_add_0.run(buf10, buf9,")
.check("triton_poi_fused_add_0.run(buf2, arg2_1, buf1,")
.check("triton_poi_fused_add_1.run(buf4, buf3, arg2_1")
.check("triton_poi_fused_add_1.run(buf6, buf5, arg2_1,")
.run(code)
)

View File

@ -8199,12 +8199,23 @@ class StorageBox(MutableBox):
self.realize()
def has_accumulated_enough_reads_by_size(self, threshold: int) -> bool:
size_of_reads = [V.graph.get_dep_size_hint(dep) for dep in self.get_reads()]
from torch._inductor.utils import is_nonfreeable_buffers
size_of_reads = [
V.graph.get_dep_size_hint(dep)
for dep in self.get_reads()
if not is_nonfreeable_buffers(dep)
]
if not size_of_reads:
return False
total_size = sum(size_of_reads)
max_size = max(size_of_reads)
return total_size > threshold and total_size / max_size >= 2
min_size = min(size_of_reads)
return (
total_size >= threshold
and total_size / max_size >= 2
and max_size == min_size
)
def has_exceeded_max_reads(self) -> bool:
return isinstance(self.data, Pointwise) and (

View File

@ -3749,4 +3749,6 @@ def is_nonfreeable_buffers(dep: Dep) -> bool:
# before checking for known strings.
if V.graph.name:
dep_name = dep_name.removeprefix(V.graph.name + "_")
return dep_name.startswith(("primals_", "arg", "fwd_rng_state", "bwd_rng_state"))
return dep_name.startswith(
("primals_", "arg", "fwd_rng_state", "bwd_rng_state", "tangents")
)