[PT2][fusion] ban fusions with large accumulated reads (#157563)

**Problem:**
Fusion can accumulate a large amount of reads, which leads to a significant increase in peak memory utilization. Imagine we have the following code snippet:
```
total = torch.rand(N, N)
for _ in range(r):
    x = torch.rand(N, N)
    total = total + x
```
The default execution is memory efficient, as only two tensors of size N-by-N are in memory at any given time. However, with fusion, the additions are fused into a single operation and the execution becomes something like:
```
x_1 = torch.rand(N, N)
x_2 = torch.rand(N, N)
...
x_r = torch.rand(N, N)
total = x_1 + x_2 + ... + x_r
```
Though this is runtime efficient, for large `N` and/or large `r` it is not memory efficient, since all `r` random tensors must be alive at the same time.
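
To make the scale concrete, here is a back-of-the-envelope sketch (the sizes are made-up illustrative numbers, not taken from this PR): with `float32`, each `N`-by-`N` tensor occupies `4 * N**2` bytes, so the fused form keeps roughly `r` buffers live instead of two.
```
N, r = 2**14, 16  # hypothetical sizes, for illustration only
bytes_per_tensor = 4 * N**2           # float32: 4 bytes per element
unfused_peak = 2 * bytes_per_tensor   # running total + current x
fused_peak = r * bytes_per_tensor     # all r inputs live until the fused add

print(f"unfused ~{unfused_peak / 2**30:.0f} GiB, fused ~{fused_peak / 2**30:.0f} GiB")
# unfused ~2 GiB, fused ~16 GiB
```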

[internal only] see [post](https://fb.workplace.com/groups/1075192433118967/permalink/1703374333634104/) for additional details

**Solution:**
Our proposed solution is to ban fusions in cases where a large amount of reads would be accumulated. This is in addition to some existing logic during torch compile:
* During lowering (i.e., `ir.py`), the config `realize_acc_reads_threshold`, which defaults to 8, controls _the number of_ buffers that can be accumulated for a single operator. However, this is oblivious to the size of the buffers. Hence, we additionally introduce a config `realize_acc_reads_size_threshold` to control _the total size_ of the buffers that can be accumulated.
* During scheduling (i.e., `scheduler.py`), additional fusions are performed, so we also need to capture this pattern there. The decisions are implemented in `choices.py`; a usage sketch follows this list.
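
As a usage sketch (the config name and the `config.patch` API appear in the diffs below; the 256 MiB value and the toy function are illustrative assumptions), the new threshold can be tightened to trade fusion opportunities for lower peak memory:
```
import torch
from torch._inductor import config as inductor_config

def f(x):
    # placeholder workload; any compiled function is subject to the new check
    return x.relu() + 1

# The default threshold is 3 GiB (3 * 1024**3, see the config.py diff below);
# 256 MiB here is an arbitrary illustrative value.
with inductor_config.patch({"realize_acc_reads_size_threshold": 256 * 1024**2}):
    compiled = torch.compile(f)
    out = compiled(torch.randn(1024, 1024))
```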

**Results:**
For a small example similar to the one in the test case (but with a larger `N` and more loop repetitions), the memory snapshots before and after are shown below. Note that the snapshot on the right is zoomed out so that the y-axes of the two snapshots match.

<img width="1328" alt="image" src="https://github.com/user-attachments/assets/670b5961-8454-4379-ae0f-62d4e7946c64" />

Pull Request resolved: https://github.com/pytorch/pytorch/pull/157563
Approved by: https://github.com/jansel, https://github.com/mlazos
Author: Xuan Zhang
Date:   2025-07-14 10:59:01 -07:00
Committed by: PyTorch MergeBot
parent 9345279c6e
commit c062550a35
9 changed files with 118 additions and 106 deletions

@@ -1,89 +1,23 @@
-add_loop_eager,compile_time_instruction_count,3017000000,0.015
+add_loop_eager,compile_time_instruction_count,2996000000,0.015
 add_loop_eager_dynamic,compile_time_instruction_count,4352000000,0.025
-add_loop_inductor,compile_time_instruction_count,29490000000,0.015
-add_loop_inductor_dynamic_gpu,compile_time_instruction_count,38760000000,0.025
-add_loop_inductor_gpu,compile_time_instruction_count,26000000000,0.015
+add_loop_inductor,compile_time_instruction_count,33090000000,0.015
+add_loop_inductor_dynamic_gpu,compile_time_instruction_count,42660000000,0.025
+add_loop_inductor_gpu,compile_time_instruction_count,29690000000,0.015
 basic_modules_ListOfLinears_eager,compile_time_instruction_count,947600000,0.015
-basic_modules_ListOfLinears_inductor,compile_time_instruction_count,18490000000,0.015
-basic_modules_ListOfLinears_inductor_gpu_force_shape_pad,compile_time_instruction_count,17020000000,0.015
-basic_modules_ListOfLinears_inductor_gpu,compile_time_instruction_count,10297683772,0.2
+basic_modules_ListOfLinears_inductor,compile_time_instruction_count,18830000000,0.015
+basic_modules_ListOfLinears_inductor_gpu_force_shape_pad,compile_time_instruction_count,17460000000,0.015
+basic_modules_ListOfLinears_inductor_gpu,compile_time_instruction_count,11020000000,0.2
 update_hint_regression,compile_time_instruction_count,1673000000,0.02
 sum_floordiv_regression,compile_time_instruction_count,986800000,0.015
-symint_sum,compile_time_instruction_count,3166000000,0.015
+symint_sum,compile_time_instruction_count,3184000000,0.015
 symint_sum_loop,compile_time_instruction_count,4202000000,0.015
 aotdispatcher_inference_nosubclass_cpu,compile_time_instruction_count,2103000000,0.015
 aotdispatcher_inference_subclass_cpu,compile_time_instruction_count,6004000000,0.015
 aotdispatcher_partitioner_cpu,compile_time_instruction_count,8783000000,0.015
 aotdispatcher_partitioner_cpu2,compile_time_instruction_count,1940000000,0.015
 aotdispatcher_training_nosubclass_cpu,compile_time_instruction_count,3885000000,0.015
 aotdispatcher_training_subclass_cpu,compile_time_instruction_count,10470000000,0.015
-mm_loop_inductor_gpu,compile_time_instruction_count,4324000000,0.015
-mm_loop_inductor_dynamic_gpu,compile_time_instruction_count,8116000000,0.015
+mm_loop_inductor_gpu,compile_time_instruction_count,4365000000,0.015
+mm_loop_inductor_dynamic_gpu,compile_time_instruction_count,8184000000,0.015
 basic_NestedModule_eager,compile_time_instruction_count,8152524390,0.015
 basic_InlineMod_eager,compile_time_instruction_count,7255000000,0.015

@@ -306,6 +306,57 @@ class TestOperatorReorderForPeakMemory(TestCase):
         expected_bound = a.size(0) * c.size(1) * a.dtype.itemsize * 2
         self.assertLess(peak_mem, expected_bound)

+    def test_fusion_acc_large_reads(self):
+        def f(x, y, z):
+            res = torch.zeros_like(x[0])
+            for i in range(4):
+                temp = torch.matmul(x, y) + z
+                res = res + temp
+            return res
+
+        N = 128
+        x = torch.rand(N, N, dtype=torch.float32, device=GPU_TYPE)
+        y = torch.rand(N, N, dtype=torch.float32, device=GPU_TYPE)
+        z = torch.rand(N, N, dtype=torch.float32, device=GPU_TYPE)
+
+        # CASE 1: no restriction on the amount of accumulation
+        with config.patch({"realize_acc_reads_size_threshold": float("inf")}):
+            f_compiled = torch.compile(f)
+            code = run_and_get_triton_code(f_compiled, x, y, z)
+            (
+                FileCheck()
+                .check("triton_poi_fused_add_0.run(buf4, arg2_1, buf1, buf2, buf3")
+                .run(code)
+            )
+
+        # CASE 2: for tensors with the same size as x (which is 4 * N**2 bytes)
+        # at most 12 / 4 = 3 reads can be accumulated during fusion
+        with config.patch({"realize_acc_reads_size_threshold": 12 * N**2}):
+            f_compiled = torch.compile(f)
+            code = run_and_get_triton_code(f_compiled, x, y, z)
+            (
+                FileCheck()
+                .check("triton_poi_fused_add_0.run(buf3, arg2_1, buf1, buf2,")
+                .check("triton_poi_fused_add_1.run(buf5, buf4, arg2_1,")
+                .run(code)
+            )
+
+        # CASE 3: no such fusion allowed
+        with config.patch({"realize_acc_reads_size_threshold": N**2}):
+            f_compiled = torch.compile(f)
+            code = run_and_get_triton_code(f_compiled, x, y, z)
+            (
+                FileCheck()
+                .check("triton_poi_fused_add_0.run(buf1, arg2_1,")
+                .check("triton_poi_fused_add_0.run(buf3, arg2_1,")
+                .check("triton_poi_fused_add_0.run(buf4, buf3,")
+                .check("triton_poi_fused_add_0.run(buf6, arg2_1,")
+                .check("triton_poi_fused_add_0.run(buf7, buf6,")
+                .check("triton_poi_fused_add_0.run(buf9, arg2_1,")
+                .check("triton_poi_fused_add_0.run(buf10, buf9,")
+                .run(code)
+            )
+

 if __name__ == "__main__":
     from torch._inductor.test_case import run_tests

@@ -13,6 +13,7 @@ from torch.testing._internal.common_utils import (
     instantiate_parametrized_tests,
     IS_LINUX,
     parametrize,
+    serialTest,
 )
 from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_CUDA
@@ -77,12 +78,17 @@ class TestOnlineSoftmax(TestCase):
         out, source_codes = run_and_get_code(f, x)
         return source_codes[0]

+    @serialTest()
     def test_codegen_3pass_softmax_due_to_disable(self):
-        with inductor_config.patch(online_softmax=False):
+        with inductor_config.patch(
+            online_softmax=False,
+            realize_acc_reads_size_threshold=float("inf"),
+        ):
             wrapper_code = self.get_softmax_wrapper()
         self.assertEqual(wrapper_code.count("for r0_offset in"), 3)

+    @serialTest()
     @parametrize("V", [2048, 50304])
     @parametrize("use_log_softmax", [False, True])
     def test_codegen_online_softmax(self, use_log_softmax, V):

@@ -365,6 +365,10 @@ class InductorChoices:
             WhyNoFuse(node1, node2)("Fusion will increase peak memory")
             return False

+        if scheduler.fusion_accumulate_large_reads(node1, node2):
+            WhyNoFuse(node1, node2)("Fusion accumulate large amount of reads")
+            return False
+
         return True

     @staticmethod

@@ -574,6 +574,7 @@ realize_opcount_threshold = 30

 # Threshold to prevent excessive accumulation of ops in one buffer during lowering
 realize_acc_reads_threshold = 8
+realize_acc_reads_size_threshold = 3 * (1024**3)

 # fallback to eager for random/dropout, this is slow but useful for debugging
 fallback_random = False

@@ -123,6 +123,7 @@ if TYPE_CHECKING:
     from torch.fx.graph import Graph

     from .codegen.wrapper import PythonWrapperCodegen
+    from .dependencies import Dep
     from .scheduler import BaseSchedulerNode

 CompiledModule = Union[ModuleType, FileBackedGraphModule]
@@ -485,6 +486,9 @@ class GraphLowering(torch.fx.Interpreter):
         self.bw_donated_idxs = get_donated_idxs()

+        # Cache for dep size hints to avoid expensive recomputation
+        self.dep_size_hint_cache: dict[Dep, int] = {}
+
     def freeze_runtime_asserts(self) -> None:
         self._shape_env.freeze_runtime_asserts()
@@ -570,6 +574,23 @@
         assert isinstance(feature, BackendFeature), feature
         return feature in self.get_backend_features(get_device_type(device))

+    def get_dep_size_hint(self, dep: Dep) -> int:
+        """
+        Get the size hint for a dependency with caching to avoid expensive recomputation.
+        """
+        if dep not in self.dep_size_hint_cache:
+            res = 0
+            try:
+                if not dep.has_unbacked_symbols():
+                    res = dep.numbytes_hint()
+            except KeyError:
+                # In at least one test (test/inductor/test_torchbind.py) we
+                # create a StarDep that doesn't exist in the graph and calling
+                # `has_unbacked_symbols()` throws an error.
+                pass
+            self.dep_size_hint_cache[dep] = res
+        return self.dep_size_hint_cache[dep]
+
     def get_current_device_or_throw(self) -> torch.device:
         if device := self.current_device:
             return device

@@ -7829,6 +7829,10 @@ class TensorBox(MutableBox):
 class StorageBox(MutableBox):
     """
     StorageBox allow in-place mutation of Tensors
     """

     def is_input_buffer(self) -> bool:
         if isinstance(self.data, (InputBuffer, ReinterpretView)):
             return self.data.get_name() in V.graph.graph_inputs
@@ -7878,10 +7882,17 @@
         ):
             self.realize()

+    def has_accumulated_enough_reads_by_size(self) -> bool:
+        return (
+            sum(V.graph.get_dep_size_hint(dep) for dep in self.get_reads())
+            > config.realize_acc_reads_size_threshold
+        )
+
     def has_exceeded_max_reads(self) -> bool:
         return isinstance(self.data, Pointwise) and (
             self.num_reads() > config.realize_acc_reads_threshold
             or self.has_large_inner_fn()
+            or self.has_accumulated_enough_reads_by_size()
         )

     def should_realize_on_reuse(self, users: int) -> bool:

@@ -78,19 +78,8 @@ def get_freeable_input_buf(
     A dictionary containing all freeable input buffers, keyed by their names.
     """

-    # this function is copied from torch/_inductor/scheduler.py
-    # TODO: would be nice to remove the try/except block for both places
     def _dep_size_hint(dep: Dep) -> int:
-        res = 0
-        try:
-            if not dep.has_unbacked_symbols():
-                res = dep.numbytes_hint()
-        except KeyError:
-            # In at least one test (test/inductor/test_torchbind.py) we
-            # create a StarDep that doesn't exist in the graph and calling
-            # `has_unbacked_symbols()` throws an error.
-            pass
-        return res
+        return V.graph.get_dep_size_hint(dep)

     # get freeable input buffers' successor nodes and their sizes
     # note that different deps can have the same name, so we use name as keys

@@ -2051,15 +2051,12 @@
     optimizations such as fusion, reorder, and graph partition.
     """

-    __dep_size_hint_cache: dict[Dep, int]
-
     def __init__(self, nodes: list[ir.Operation]) -> None:
         with dynamo_timed("Scheduler.__init__"):
             self._init(nodes)

     def _init(self, nodes: list[ir.Operation]) -> None:
         super().__init__()
-        self.__dep_size_hint_cache = {}
         V.graph.scheduler = self
         self.backends: dict[torch.device, BaseScheduling] = {}
         self.post_grad_graph_id = next(_post_grad_graph_counter)
@@ -3505,6 +3502,17 @@
             return True
         return False

+    def fusion_accumulate_large_reads(
+        self, node1: BaseSchedulerNode, node2: BaseSchedulerNode
+    ) -> bool:
+        all_reads = (node1.read_writes.reads | node2.read_writes.reads) - (
+            node1.read_writes.writes | node2.read_writes.writes
+        )
+        return (
+            sum(self.dep_size_hint(dep) for dep in all_reads)
+            > config.realize_acc_reads_size_threshold
+        )
+
     def are_long_distant_nodes(
         self, node1: BaseSchedulerNode, node2: BaseSchedulerNode
     ) -> bool:
@@ -4010,20 +4018,7 @@
             return False

     def dep_size_hint(self, dep: Dep) -> int:
-        res = 0
-        if dep not in self.__dep_size_hint_cache:
-            try:
-                if not dep.has_unbacked_symbols():
-                    res = dep.numbytes_hint()
-            except KeyError:
-                # In at least one test (test/inductor/test_torchbind.py) we
-                # create a StarDep that doesn't exist in the graph and calling
-                # `has_unbacked_symbols()` throws an error.
-                pass
-            self.__dep_size_hint_cache[dep] = res
-        else:
-            res = self.__dep_size_hint_cache[dep]
-        return res
+        return V.graph.get_dep_size_hint(dep)

     def score_fusion_memory(
         self, node1: BaseSchedulerNode, node2: BaseSchedulerNode