[PT2][fusion] ban fusions with large accumulated reads (#157563)

**Problem:**
Fusion can accumulate a large amount of reads, which leads to a significant increase in peak memory utilization. Imagine we have the following code snippet:
```
total = torch.rand(N, N)
for _ in range(r):
    x = torch.rand(N, N)
    total = total + x
```
The default execution is memory efficient, as only two tensors of size N-by-N are in memory at any given time. With fusion, however, the additions are fused into a single operation and the execution becomes something like:
```
x_1 = torch.rand(N, N)
x_2 = torch.rand(N, N)
...
x_r = torch.rand(N, N)
total = x_1 + x_2 + ... + x_r
```
Though this is runtime efficient, it is not memory efficient for large `N` and/or large `r`: roughly `r` tensors of size N-by-N must now be live at the same time instead of two.
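
To make the effect concrete, here is a minimal sketch of how the peak-memory difference can be observed (illustrative only: the CUDA device, the choice of `N` and `r`, and the use of the `torch.cuda` memory-stats API are assumptions, not part of this PR):
```
# Rough illustration: peak memory of the accumulation loop under torch.compile.
import torch

def accumulate(N, r):
    total = torch.rand(N, N, device="cuda")
    for _ in range(r):
        x = torch.rand(N, N, device="cuda")
        total = total + x
    return total

torch.cuda.reset_peak_memory_stats()
compiled = torch.compile(accumulate)
compiled(4096, 32)
torch.cuda.synchronize()
# If the additions all fuse into one kernel, peak allocation can approach r * N * N * 4 bytes;
# without that fusion it stays near 2 * N * N * 4 bytes.
print(f"peak allocated: {torch.cuda.max_memory_allocated() / 2**20:.1f} MiB")
```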

[internal only] see [post](https://fb.workplace.com/groups/1075192433118967/permalink/1703374333634104/) for additional details

**Solution:**
Our proposed solution is to ban fusions in cases where a large amount of reads is accumulated. This is in addition to the existing logic during torch.compile:
* During lowering (i.e., `ir.py`), the config `realize_acc_reads_threshold`, which defaults to 8, controls _the number of_ buffers that can be accumulated for a single operator. However, this is oblivious to the size of the buffers. Hence, we additionally introduce a config `realize_acc_reads_size_threshold` to control _the total size_ of the buffers that can be accumulated (see the sketch after this list).
* During scheduling (i.e., `scheduler.py`), additional fusions are performed, so we also need to capture this pattern there. The decision is implemented in `choices.py`.
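
For reference, a minimal sketch of tuning the new knob (the threshold value is illustrative; the mechanism mirrors the `config.patch` usage in the test added by this PR):
```
# Illustrative only: cap the total size of reads a fused buffer may accumulate at 256 MiB
# (the default introduced in config.py below is 3 GiB).
import torch
from torch._inductor import config as inductor_config

def f(xs):
    total = torch.zeros_like(xs[0])
    for x in xs:
        total = total + x
    return total

with inductor_config.patch({"realize_acc_reads_size_threshold": 256 * 1024**2}):
    compiled = torch.compile(f)
    out = compiled([torch.rand(1024, 1024) for _ in range(8)])
```
Setting the threshold to `float("inf")` disables the size-based check, as done in the updated online-softmax test below.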

**Results:**
For a small example similar to the one in the test case (but with a larger `N` and a higher number of loop repeats), the memory snapshots before and after are shown below. Note that the snapshot on the right is zoomed out so that the y-axes of the two snapshots match.

<img width="1328" alt="image" src="https://github.com/user-attachments/assets/670b5961-8454-4379-ae0f-62d4e7946c64" />

Pull Request resolved: https://github.com/pytorch/pytorch/pull/157563
Approved by: https://github.com/jansel, https://github.com/mlazos
Xuan Zhang
2025-07-15 13:13:49 -07:00
committed by PyTorch MergeBot
parent 651b4a68f2
commit 8554c8007d
10 changed files with 131 additions and 49 deletions

View File

@@ -1,4 +1,4 @@
add_loop_eager,compile_time_instruction_count,3017000000,0.015
add_loop_eager,compile_time_instruction_count,2994000000,0.015
@@ -6,15 +6,15 @@ add_loop_eager_dynamic,compile_time_instruction_count,4352000000,0.025
add_loop_inductor,compile_time_instruction_count,29490000000,0.015
add_loop_inductor,compile_time_instruction_count,33260000000,0.015
add_loop_inductor_dynamic_gpu,compile_time_instruction_count,38760000000,0.025
add_loop_inductor_dynamic_gpu,compile_time_instruction_count,42900000000,0.025
add_loop_inductor_gpu,compile_time_instruction_count,26000000000,0.015
add_loop_inductor_gpu,compile_time_instruction_count,29880000000,0.015
@@ -22,51 +22,51 @@ basic_modules_ListOfLinears_eager,compile_time_instruction_count,947600000,0.015
basic_modules_ListOfLinears_inductor,compile_time_instruction_count,18490000000,0.015
basic_modules_ListOfLinears_inductor,compile_time_instruction_count,17940000000,0.015
basic_modules_ListOfLinears_inductor_gpu_force_shape_pad,compile_time_instruction_count,17020000000,0.015
basic_modules_ListOfLinears_inductor_gpu_force_shape_pad,compile_time_instruction_count,17210000000,0.015
basic_modules_ListOfLinears_inductor_gpu,compile_time_instruction_count,10297683772,0.2
basic_modules_ListOfLinears_inductor_gpu,compile_time_instruction_count,10980000000,0.2
update_hint_regression,compile_time_instruction_count,1673000000,0.02
update_hint_regression,compile_time_instruction_count,1688000000,0.02
sum_floordiv_regression,compile_time_instruction_count,986800000,0.015
sum_floordiv_regression,compile_time_instruction_count,992700000,0.015
symint_sum,compile_time_instruction_count,3166000000,0.015
symint_sum,compile_time_instruction_count,3187000000,0.015
symint_sum_loop,compile_time_instruction_count,4202000000,0.015
symint_sum_loop,compile_time_instruction_count,4225000000,0.015
aotdispatcher_inference_nosubclass_cpu,compile_time_instruction_count,2103000000,0.015
aotdispatcher_inference_nosubclass_cpu,compile_time_instruction_count,2122000000,0.015
aotdispatcher_inference_subclass_cpu,compile_time_instruction_count,6004000000,0.015
aotdispatcher_inference_subclass_cpu,compile_time_instruction_count,6040000000,0.015
aotdispatcher_partitioner_cpu,compile_time_instruction_count,8783000000,0.015
aotdispatcher_partitioner_cpu,compile_time_instruction_count,8894000000,0.015
aotdispatcher_partitioner_cpu2,compile_time_instruction_count,1940000000,0.015
aotdispatcher_partitioner_cpu2,compile_time_instruction_count,1952000000,0.015
aotdispatcher_training_nosubclass_cpu,compile_time_instruction_count,3885000000,0.015
aotdispatcher_training_nosubclass_cpu,compile_time_instruction_count,3905000000,0.015
@@ -74,15 +74,15 @@ aotdispatcher_training_subclass_cpu,compile_time_instruction_count,10470000000,0
mm_loop_inductor_gpu,compile_time_instruction_count,4324000000,0.015
mm_loop_inductor_gpu,compile_time_instruction_count,4406000000,0.015
mm_loop_inductor_dynamic_gpu,compile_time_instruction_count,8116000000,0.015
mm_loop_inductor_dynamic_gpu,compile_time_instruction_count,8274000000,0.015
basic_NestedModule_eager,compile_time_instruction_count,8152524390,0.015
basic_NestedModule_eager,compile_time_instruction_count,8193000000,0.015


View File

@@ -9,6 +9,7 @@ from torch._dynamo.utils import same
from torch._inductor.test_case import run_tests, TestCase
from torch._inductor.utils import run_and_get_code
from torch.testing import FileCheck
from torch.testing._internal.common_utils import serialTest
from torch.testing._internal.inductor_utils import (
GPU_TYPE,
HAS_GPU,
@@ -209,6 +210,7 @@ class InplacePaddingTest(TestCase):
self.assertEqual(num_inplace_padding(), 0)
@serialTest()
@requires_cuda_with_enough_memory(2e10)
@inductor_config.patch(force_shape_pad=True)
def test_linear_and_cel(self):

View File

@@ -8,6 +8,7 @@ from torch._dynamo.utils import same
from torch._inductor import config, memory
from torch._inductor.test_case import TestCase
from torch._inductor.utils import run_and_get_triton_code
from torch.testing._internal.common_utils import serialTest
from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_GPU
@@ -306,6 +307,58 @@ class TestOperatorReorderForPeakMemory(TestCase):
expected_bound = a.size(0) * c.size(1) * a.dtype.itemsize * 2
self.assertLess(peak_mem, expected_bound)
    @serialTest()
    def test_fusion_acc_large_reads(self):
        def f(x, y, z):
            res = torch.zeros_like(x[0])
            for i in range(4):
                temp = torch.matmul(x, y) + z
                res = res + temp
            return res

        N = 128
        x = torch.rand(N, N, dtype=torch.float32, device=GPU_TYPE)
        y = torch.rand(N, N, dtype=torch.float32, device=GPU_TYPE)
        z = torch.rand(N, N, dtype=torch.float32, device=GPU_TYPE)

        # CASE 1: no restriction on the amount of accumulation
        with config.patch({"realize_acc_reads_size_threshold": float("inf")}):
            f_compiled = torch.compile(f)
            code = run_and_get_triton_code(f_compiled, x, y, z)
            (
                FileCheck()
                .check("triton_poi_fused_add_0.run(buf4, arg2_1, buf1, buf2, buf3")
                .run(code)
            )

        # CASE 2: for tensors with the same size as x (which is 4 * N**2 bytes)
        # at most 12 / 4 = 3 reads can be accumulated during fusion
        with config.patch({"realize_acc_reads_size_threshold": 12 * N**2}):
            f_compiled = torch.compile(f)
            code = run_and_get_triton_code(f_compiled, x, y, z)
            (
                FileCheck()
                .check("triton_poi_fused_add_0.run(buf3, arg2_1, buf1, buf2,")
                .check("triton_poi_fused_add_1.run(buf5, buf4, arg2_1,")
                .run(code)
            )

        # CASE 3: no such fusion allowed
        with config.patch({"realize_acc_reads_size_threshold": N**2}):
            f_compiled = torch.compile(f)
            code = run_and_get_triton_code(f_compiled, x, y, z)
            (
                FileCheck()
                .check("triton_poi_fused_add_0.run(buf1, arg2_1,")
                .check("triton_poi_fused_add_0.run(buf3, arg2_1,")
                .check("triton_poi_fused_add_0.run(buf4, buf3,")
                .check("triton_poi_fused_add_0.run(buf6, arg2_1,")
                .check("triton_poi_fused_add_0.run(buf7, buf6,")
                .check("triton_poi_fused_add_0.run(buf9, arg2_1,")
                .check("triton_poi_fused_add_0.run(buf10, buf9,")
                .run(code)
            )
if __name__ == "__main__":
from torch._inductor.test_case import run_tests

View File

@@ -13,6 +13,7 @@ from torch.testing._internal.common_utils import (
instantiate_parametrized_tests,
IS_LINUX,
parametrize,
serialTest,
)
from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_CUDA
@@ -77,12 +78,17 @@ class TestOnlineSoftmax(TestCase):
out, source_codes = run_and_get_code(f, x)
return source_codes[0]
@serialTest()
def test_codegen_3pass_softmax_due_to_disable(self):
with inductor_config.patch(online_softmax=False):
with inductor_config.patch(
online_softmax=False,
realize_acc_reads_size_threshold=float("inf"),
):
wrapper_code = self.get_softmax_wrapper()
self.assertEqual(wrapper_code.count("for r0_offset in"), 3)
@serialTest()
@parametrize("V", [2048, 50304])
@parametrize("use_log_softmax", [False, True])
def test_codegen_online_softmax(self, use_log_softmax, V):

View File

@@ -365,6 +365,10 @@ class InductorChoices:
WhyNoFuse(node1, node2)("Fusion will increase peak memory")
return False
if scheduler.fusion_accumulate_large_reads(node1, node2):
WhyNoFuse(node1, node2)("Fusion accumulate large amount of reads")
return False
return True
@staticmethod

View File

@@ -574,6 +574,7 @@ realize_opcount_threshold = 30
# Threshold to prevent excessive accumulation of ops in one buffer during lowering
realize_acc_reads_threshold = 8
realize_acc_reads_size_threshold = 3 * (1024**3)
# fallback to eager for random/dropout, this is slow but useful for debugging
fallback_random = False

View File

@@ -123,6 +123,7 @@ if TYPE_CHECKING:
from torch.fx.graph import Graph
from .codegen.wrapper import PythonWrapperCodegen
from .dependencies import Dep
from .scheduler import BaseSchedulerNode
CompiledModule = Union[ModuleType, FileBackedGraphModule]
@@ -485,6 +486,9 @@ class GraphLowering(torch.fx.Interpreter):
self.bw_donated_idxs = get_donated_idxs()
# Cache for dep size hints to avoid expensive recomputation
self.dep_size_hint_cache: dict[Dep, int] = {}
def freeze_runtime_asserts(self) -> None:
self._shape_env.freeze_runtime_asserts()
@@ -570,6 +574,23 @@
assert isinstance(feature, BackendFeature), feature
return feature in self.get_backend_features(get_device_type(device))
    def get_dep_size_hint(self, dep: Dep) -> int:
        """
        Get the size hint for a dependency with caching to avoid expensive recomputation.
        """
        if dep not in self.dep_size_hint_cache:
            res = 0
            try:
                if not dep.has_unbacked_symbols():
                    res = dep.numbytes_hint()
            except KeyError:
                # In at least one test (test/inductor/test_torchbind.py) we
                # create a StarDep that doesn't exist in the graph and calling
                # `has_unbacked_symbols()` throws an error.
                pass
            self.dep_size_hint_cache[dep] = res
        return self.dep_size_hint_cache[dep]
def get_current_device_or_throw(self) -> torch.device:
if device := self.current_device:
return device

View File

@@ -7829,6 +7829,10 @@ class TensorBox(MutableBox):
class StorageBox(MutableBox):
"""
StorageBox allow in-place mutation of Tensors
"""
def is_input_buffer(self) -> bool:
if isinstance(self.data, (InputBuffer, ReinterpretView)):
return self.data.get_name() in V.graph.graph_inputs
@@ -7878,10 +7882,17 @@
):
self.realize()
    def has_accumulated_enough_reads_by_size(self) -> bool:
        return (
            sum(V.graph.get_dep_size_hint(dep) for dep in self.get_reads())
            > config.realize_acc_reads_size_threshold
        )

    def has_exceeded_max_reads(self) -> bool:
        return isinstance(self.data, Pointwise) and (
            self.num_reads() > config.realize_acc_reads_threshold
            or self.has_large_inner_fn()
            or self.has_accumulated_enough_reads_by_size()
        )
def should_realize_on_reuse(self, users: int) -> bool:

View File

@@ -78,19 +78,8 @@ def get_freeable_input_buf(
A dictionary containing all freeble input buffers, keyed by their names.
"""
# this function is copied from torch/_inductor/scheduler.py
# TODO: would be nice to remove the try/except block for both places
def _dep_size_hint(dep: Dep) -> int:
res = 0
try:
if not dep.has_unbacked_symbols():
res = dep.numbytes_hint()
except KeyError:
# In at least one test (test/inductor/test_torchbind.py) we
# create a StarDep that doesn't exist in the graph and calling
# `has_unbacked_symbols()` throws an error.
pass
return res
return V.graph.get_dep_size_hint(dep)
# get freeable input buffers' successor nodes and their sizes
# note that different deps can have the same name, so we use name as keys

View File

@@ -2051,15 +2051,12 @@ class Scheduler:
optimizations such as fusion, reorder, and graph partition.
"""
__dep_size_hint_cache: dict[Dep, int]
def __init__(self, nodes: list[ir.Operation]) -> None:
with dynamo_timed("Scheduler.__init__"):
self._init(nodes)
def _init(self, nodes: list[ir.Operation]) -> None:
super().__init__()
self.__dep_size_hint_cache = {}
V.graph.scheduler = self
self.backends: dict[torch.device, BaseScheduling] = {}
self.post_grad_graph_id = next(_post_grad_graph_counter)
@@ -3505,6 +3502,17 @@
return True
return False
    def fusion_accumulate_large_reads(
        self, node1: BaseSchedulerNode, node2: BaseSchedulerNode
    ) -> bool:
        all_reads = (node1.read_writes.reads | node2.read_writes.reads) - (
            node1.read_writes.writes | node2.read_writes.writes
        )
        return (
            sum(self.dep_size_hint(dep) for dep in all_reads)
            > config.realize_acc_reads_size_threshold
        )
def are_long_distant_nodes(
self, node1: BaseSchedulerNode, node2: BaseSchedulerNode
) -> bool:
@@ -4010,20 +4018,7 @@
return False
def dep_size_hint(self, dep: Dep) -> int:
res = 0
if dep not in self.__dep_size_hint_cache:
try:
if not dep.has_unbacked_symbols():
res = dep.numbytes_hint()
except KeyError:
# In at least one test (test/inductor/test_torchbind.py) we
# create a StarDep that doesn't exist in the graph and calling
# `has_unbacked_symbols()` throws an error.
pass
self.__dep_size_hint_cache[dep] = res
else:
res = self.__dep_size_hint_cache[dep]
return res
return V.graph.get_dep_size_hint(dep)
def score_fusion_memory(
self, node1: BaseSchedulerNode, node2: BaseSchedulerNode