mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
add more restriction to fusion with large accumulate reads (#163163)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/163163
Approved by: https://github.com/yf225
This commit is contained in:
Committed by: PyTorch MergeBot
Parent: 3c9e220f34
Commit: 6e680ae8de
@ -408,13 +408,9 @@ class TestOperatorReorderForPeakMemory(TestCase):
|
||||
code = run_and_get_triton_code(f_compiled, x, y, z)
|
||||
(
|
||||
FileCheck()
|
||||
.check("triton_poi_fused_add_0.run(buf1, arg2_1,")
|
||||
.check("triton_poi_fused_add_0.run(buf3, arg2_1,")
|
||||
.check("triton_poi_fused_add_0.run(buf4, buf3,")
|
||||
.check("triton_poi_fused_add_0.run(buf6, arg2_1,")
|
||||
.check("triton_poi_fused_add_0.run(buf7, buf6,")
|
||||
.check("triton_poi_fused_add_0.run(buf9, arg2_1,")
|
||||
.check("triton_poi_fused_add_0.run(buf10, buf9,")
|
||||
.check("triton_poi_fused_add_0.run(buf2, arg2_1, buf1,")
|
||||
.check("triton_poi_fused_add_1.run(buf4, buf3, arg2_1")
|
||||
.check("triton_poi_fused_add_1.run(buf6, buf5, arg2_1,")
|
||||
.run(code)
|
||||
)
|
||||
|
||||
|
@ -8199,12 +8199,23 @@ class StorageBox(MutableBox):
|
||||
self.realize()
|
||||
|
||||
def has_accumulated_enough_reads_by_size(self, threshold: int) -> bool:
|
||||
size_of_reads = [V.graph.get_dep_size_hint(dep) for dep in self.get_reads()]
|
||||
from torch._inductor.utils import is_nonfreeable_buffers
|
||||
|
||||
size_of_reads = [
|
||||
V.graph.get_dep_size_hint(dep)
|
||||
for dep in self.get_reads()
|
||||
if not is_nonfreeable_buffers(dep)
|
||||
]
|
||||
if not size_of_reads:
|
||||
return False
|
||||
total_size = sum(size_of_reads)
|
||||
max_size = max(size_of_reads)
|
||||
return total_size > threshold and total_size / max_size >= 2
|
||||
min_size = min(size_of_reads)
|
||||
return (
|
||||
total_size >= threshold
|
||||
and total_size / max_size >= 2
|
||||
and max_size == min_size
|
||||
)
|
||||
|
||||
def has_exceeded_max_reads(self) -> bool:
|
||||
return isinstance(self.data, Pointwise) and (
|
||||
|
@ -3749,4 +3749,6 @@ def is_nonfreeable_buffers(dep: Dep) -> bool:
|
||||
# before checking for known strings.
|
||||
if V.graph.name:
|
||||
dep_name = dep_name.removeprefix(V.graph.name + "_")
|
||||
return dep_name.startswith(("primals_", "arg", "fwd_rng_state", "bwd_rng_state"))
|
||||
return dep_name.startswith(
|
||||
("primals_", "arg", "fwd_rng_state", "bwd_rng_state", "tangents")
|
||||
)
|
||||
|
Reference in New Issue
Block a user