Revert "[PT2][fusion] ban fusions with large accumulated reads (#157563)"

This reverts commit c062550a3598d27c2d6572db7c0f4ff90a84cc84.

Reverted https://github.com/pytorch/pytorch/pull/157563 on behalf of https://github.com/clee2000 because it broke test_linear_and_cel on main (c062550a35), possibly an OOM. The test was also broken on the PR itself; Dr. CI's classification is wrong (it claims the test is disabled by an issue, but that issue is for a different test). Also, the expected-results file is supposed to have a ton of empty lines (it's to prevent merge conflicts); I will add a lint for that ([comment](https://github.com/pytorch/pytorch/pull/157563#issuecomment-3074355331))
PyTorch MergeBot
2025-07-15 16:35:55 +00:00
parent 4f36743f5e
commit 26807dcf27
9 changed files with 106 additions and 118 deletions

View File

@@ -1,23 +1,89 @@
-add_loop_eager,compile_time_instruction_count,2996000000,0.015
+add_loop_eager,compile_time_instruction_count,3017000000,0.015
 add_loop_eager_dynamic,compile_time_instruction_count,4352000000,0.025
-add_loop_inductor,compile_time_instruction_count,33090000000,0.015
-add_loop_inductor_dynamic_gpu,compile_time_instruction_count,42660000000,0.025
-add_loop_inductor_gpu,compile_time_instruction_count,29690000000,0.015
+add_loop_inductor,compile_time_instruction_count,29490000000,0.015
+add_loop_inductor_dynamic_gpu,compile_time_instruction_count,38760000000,0.025
+add_loop_inductor_gpu,compile_time_instruction_count,26000000000,0.015
 basic_modules_ListOfLinears_eager,compile_time_instruction_count,947600000,0.015
-basic_modules_ListOfLinears_inductor,compile_time_instruction_count,18830000000,0.015
-basic_modules_ListOfLinears_inductor_gpu_force_shape_pad,compile_time_instruction_count,17460000000,0.015
-basic_modules_ListOfLinears_inductor_gpu,compile_time_instruction_count,11020000000,0.2
+basic_modules_ListOfLinears_inductor,compile_time_instruction_count,18490000000,0.015
+basic_modules_ListOfLinears_inductor_gpu_force_shape_pad,compile_time_instruction_count,17020000000,0.015
+basic_modules_ListOfLinears_inductor_gpu,compile_time_instruction_count,10297683772,0.2
 update_hint_regression,compile_time_instruction_count,1673000000,0.02
 sum_floordiv_regression,compile_time_instruction_count,986800000,0.015
-symint_sum,compile_time_instruction_count,3184000000,0.015
+symint_sum,compile_time_instruction_count,3166000000,0.015
 symint_sum_loop,compile_time_instruction_count,4202000000,0.015
 aotdispatcher_inference_nosubclass_cpu,compile_time_instruction_count,2103000000,0.015
 aotdispatcher_inference_subclass_cpu,compile_time_instruction_count,6004000000,0.015
 aotdispatcher_partitioner_cpu,compile_time_instruction_count,8783000000,0.015
 aotdispatcher_partitioner_cpu2,compile_time_instruction_count,1940000000,0.015
 aotdispatcher_training_nosubclass_cpu,compile_time_instruction_count,3885000000,0.015
 aotdispatcher_training_subclass_cpu,compile_time_instruction_count,10470000000,0.015
-mm_loop_inductor_gpu,compile_time_instruction_count,4365000000,0.015
-mm_loop_inductor_dynamic_gpu,compile_time_instruction_count,8184000000,0.015
+mm_loop_inductor_gpu,compile_time_instruction_count,4324000000,0.015
+mm_loop_inductor_dynamic_gpu,compile_time_instruction_count,8116000000,0.015
 basic_NestedModule_eager,compile_time_instruction_count,8152524390,0.015
 basic_InlineMod_eager,compile_time_instruction_count,7255000000,0.015
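(Each row of this expectations file is `benchmark_name,metric,expected_value,noise_tolerance`, where the last column appears to be a relative tolerance, e.g. 0.015 ≈ 1.5%. The hunk grows from 23 to 89 lines because, per the revert comment above, the file is supposed to keep many empty spacer lines between entries to prevent merge conflicts; those restored blank lines are not shown here.)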

View File

@@ -306,57 +306,6 @@ class TestOperatorReorderForPeakMemory(TestCase):
         expected_bound = a.size(0) * c.size(1) * a.dtype.itemsize * 2
         self.assertLess(peak_mem, expected_bound)
 
-    def test_fusion_acc_large_reads(self):
-        def f(x, y, z):
-            res = torch.zeros_like(x[0])
-            for i in range(4):
-                temp = torch.matmul(x, y) + z
-                res = res + temp
-            return res
-
-        N = 128
-        x = torch.rand(N, N, dtype=torch.float32, device=GPU_TYPE)
-        y = torch.rand(N, N, dtype=torch.float32, device=GPU_TYPE)
-        z = torch.rand(N, N, dtype=torch.float32, device=GPU_TYPE)
-
-        # CASE 1: no restriction on the amount of accumulation
-        with config.patch({"realize_acc_reads_size_threshold": float("inf")}):
-            f_compiled = torch.compile(f)
-            code = run_and_get_triton_code(f_compiled, x, y, z)
-            (
-                FileCheck()
-                .check("triton_poi_fused_add_0.run(buf4, arg2_1, buf1, buf2, buf3")
-                .run(code)
-            )
-
-        # CASE 2: for tensors with the same size as x (which is 4 * N**2 bytes)
-        # at most 12 / 4 = 3 reads can be accumulated during fusion
-        with config.patch({"realize_acc_reads_size_threshold": 12 * N**2}):
-            f_compiled = torch.compile(f)
-            code = run_and_get_triton_code(f_compiled, x, y, z)
-            (
-                FileCheck()
-                .check("triton_poi_fused_add_0.run(buf3, arg2_1, buf1, buf2,")
-                .check("triton_poi_fused_add_1.run(buf5, buf4, arg2_1,")
-                .run(code)
-            )
-
-        # CASE 3: no such fusion allowed
-        with config.patch({"realize_acc_reads_size_threshold": N**2}):
-            f_compiled = torch.compile(f)
-            code = run_and_get_triton_code(f_compiled, x, y, z)
-            (
-                FileCheck()
-                .check("triton_poi_fused_add_0.run(buf1, arg2_1,")
-                .check("triton_poi_fused_add_0.run(buf3, arg2_1,")
-                .check("triton_poi_fused_add_0.run(buf4, buf3,")
-                .check("triton_poi_fused_add_0.run(buf6, arg2_1,")
-                .check("triton_poi_fused_add_0.run(buf7, buf6,")
-                .check("triton_poi_fused_add_0.run(buf9, arg2_1,")
-                .check("triton_poi_fused_add_0.run(buf10, buf9,")
-                .run(code)
-            )
-
 if __name__ == "__main__":
     from torch._inductor.test_case import run_tests
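The CASE 2 comment in the deleted test above hinges on a bit of byte arithmetic; here is a minimal sketch of that math in plain Python (not part of the diff):

```python
# Each of x, y, z is an N x N float32 tensor, 4 bytes per element,
# so one tensor costs 4 * N**2 bytes (as the CASE 2 comment notes).
N = 128
bytes_per_tensor = 4 * N * N

# With realize_acc_reads_size_threshold = 12 * N**2 bytes, at most
# (12 * N**2) / (4 * N**2) = 3 tensor-sized reads can accumulate
# before the fusion ban kicks in.
threshold = 12 * N**2
print(threshold // bytes_per_tensor)  # -> 3
```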

View File

@@ -13,7 +13,6 @@ from torch.testing._internal.common_utils import (
     instantiate_parametrized_tests,
     IS_LINUX,
     parametrize,
-    serialTest,
 )
 from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_CUDA
@@ -78,17 +77,12 @@ class TestOnlineSoftmax(TestCase):
         out, source_codes = run_and_get_code(f, x)
         return source_codes[0]
 
-    @serialTest()
     def test_codegen_3pass_softmax_due_to_disable(self):
-        with inductor_config.patch(
-            online_softmax=False,
-            realize_acc_reads_size_threshold=float("inf"),
-        ):
+        with inductor_config.patch(online_softmax=False):
             wrapper_code = self.get_softmax_wrapper()
         self.assertEqual(wrapper_code.count("for r0_offset in"), 3)
 
-    @serialTest()
     @parametrize("V", [2048, 50304])
     @parametrize("use_log_softmax", [False, True])
     def test_codegen_online_softmax(self, use_log_softmax, V):

View File

@@ -365,10 +365,6 @@ class InductorChoices:
             WhyNoFuse(node1, node2)("Fusion will increase peak memory")
             return False
 
-        if scheduler.fusion_accumulate_large_reads(node1, node2):
-            WhyNoFuse(node1, node2)("Fusion accumulate large amount of reads")
-            return False
-
         return True
 
     @staticmethod

View File

@@ -574,7 +574,6 @@ realize_opcount_threshold = 30
 # Threshold to prevent excessive accumulation of ops in one buffer during lowering
 realize_acc_reads_threshold = 8
-realize_acc_reads_size_threshold = 3 * (1024**3)
 
 # fallback to eager for random/dropout, this is slow but useful for debugging
 fallback_random = False
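For context on the knob this revert deletes: the removed test above drove it through `torch._inductor.config.patch`, the standard way to override inductor config at runtime. A minimal sketch of the same pattern (it assumes a pre-revert build where `realize_acc_reads_size_threshold` still exists; on a post-revert build the patch raises for the unknown key):

```python
import torch
from torch._inductor import config

def f(x, y, z):
    return torch.matmul(x, y) + z

# realize_acc_reads_threshold (kept by this revert) caps the *count* of
# accumulated reads; realize_acc_reads_size_threshold (removed) capped
# their total *bytes*. Patching it to infinity disabled the size-based ban.
with config.patch({"realize_acc_reads_size_threshold": float("inf")}):
    compiled = torch.compile(f)
    out = compiled(torch.rand(64, 64), torch.rand(64, 64), torch.rand(64, 64))
```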

View File

@@ -123,7 +123,6 @@ if TYPE_CHECKING:
     from torch.fx.graph import Graph
 
     from .codegen.wrapper import PythonWrapperCodegen
-    from .dependencies import Dep
     from .scheduler import BaseSchedulerNode
 
 CompiledModule = Union[ModuleType, FileBackedGraphModule]
@@ -486,9 +485,6 @@ class GraphLowering(torch.fx.Interpreter):
         self.bw_donated_idxs = get_donated_idxs()
 
-        # Cache for dep size hints to avoid expensive recomputation
-        self.dep_size_hint_cache: dict[Dep, int] = {}
-
     def freeze_runtime_asserts(self) -> None:
         self._shape_env.freeze_runtime_asserts()
@@ -574,23 +570,6 @@ class GraphLowering(torch.fx.Interpreter):
         assert isinstance(feature, BackendFeature), feature
         return feature in self.get_backend_features(get_device_type(device))
 
-    def get_dep_size_hint(self, dep: Dep) -> int:
-        """
-        Get the size hint for a dependency with caching to avoid expensive recomputation.
-        """
-        if dep not in self.dep_size_hint_cache:
-            res = 0
-            try:
-                if not dep.has_unbacked_symbols():
-                    res = dep.numbytes_hint()
-            except KeyError:
-                # In at least one test (test/inductor/test_torchbind.py) we
-                # create a StarDep that doesn't exist in the graph and calling
-                # `has_unbacked_symbols()` throws an error.
-                pass
-            self.dep_size_hint_cache[dep] = res
-        return self.dep_size_hint_cache[dep]
-
     def get_current_device_or_throw(self) -> torch.device:
         if device := self.current_device:
             return device

View File

@@ -7829,10 +7829,6 @@ class TensorBox(MutableBox):
 
 class StorageBox(MutableBox):
-    """
-    StorageBox allow in-place mutation of Tensors
-    """
-
     def is_input_buffer(self) -> bool:
         if isinstance(self.data, (InputBuffer, ReinterpretView)):
             return self.data.get_name() in V.graph.graph_inputs
@@ -7882,17 +7878,10 @@ class StorageBox(MutableBox):
         ):
             self.realize()
 
-    def has_accumulated_enough_reads_by_size(self) -> bool:
-        return (
-            sum(V.graph.get_dep_size_hint(dep) for dep in self.get_reads())
-            > config.realize_acc_reads_size_threshold
-        )
-
    def has_exceeded_max_reads(self) -> bool:
         return isinstance(self.data, Pointwise) and (
             self.num_reads() > config.realize_acc_reads_threshold
             or self.has_large_inner_fn()
-            or self.has_accumulated_enough_reads_by_size()
         )
 
     def should_realize_on_reuse(self, users: int) -> bool:

View File

@@ -78,8 +78,19 @@ def get_freeable_input_buf(
         A dictionary containing all freeble input buffers, keyed by their names.
     """
 
-    # this function is copied from torch/_inductor/scheduler.py
-    # TODO: would be nice to remove the try/except block for both places
     def _dep_size_hint(dep: Dep) -> int:
-        return V.graph.get_dep_size_hint(dep)
+        res = 0
+        try:
+            if not dep.has_unbacked_symbols():
+                res = dep.numbytes_hint()
+        except KeyError:
+            # In at least one test (test/inductor/test_torchbind.py) we
+            # create a StarDep that doesn't exist in the graph and calling
+            # `has_unbacked_symbols()` throws an error.
+            pass
+        return res
 
     # get freeable input buffers' successor nodes and their sizes
     # note that different deps can have the same name, so we use name as keys
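Design note: with this revert, the dependency size-hint logic lives in two places again, the inline `_dep_size_hint` above and `Scheduler.dep_size_hint` below, each carrying its own try/except guard for StarDeps missing from the graph; the reverted PR had deduplicated both behind `V.graph.get_dep_size_hint` with a shared per-graph cache.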

View File

@@ -2051,12 +2051,15 @@
     optimizations such as fusion, reorder, and graph partition.
     """
 
+    __dep_size_hint_cache: dict[Dep, int]
+
     def __init__(self, nodes: list[ir.Operation]) -> None:
         with dynamo_timed("Scheduler.__init__"):
             self._init(nodes)
 
     def _init(self, nodes: list[ir.Operation]) -> None:
         super().__init__()
+        self.__dep_size_hint_cache = {}
         V.graph.scheduler = self
         self.backends: dict[torch.device, BaseScheduling] = {}
         self.post_grad_graph_id = next(_post_grad_graph_counter)
@@ -3502,17 +3505,6 @@
             return True
         return False
 
-    def fusion_accumulate_large_reads(
-        self, node1: BaseSchedulerNode, node2: BaseSchedulerNode
-    ) -> bool:
-        all_reads = (node1.read_writes.reads | node2.read_writes.reads) - (
-            node1.read_writes.writes | node2.read_writes.writes
-        )
-        return (
-            sum(self.dep_size_hint(dep) for dep in all_reads)
-            > config.realize_acc_reads_size_threshold
-        )
-
     def are_long_distant_nodes(
         self, node1: BaseSchedulerNode, node2: BaseSchedulerNode
     ) -> bool:
@@ -4018,7 +4010,20 @@
         return False
 
     def dep_size_hint(self, dep: Dep) -> int:
-        return V.graph.get_dep_size_hint(dep)
+        res = 0
+        if dep not in self.__dep_size_hint_cache:
+            try:
+                if not dep.has_unbacked_symbols():
+                    res = dep.numbytes_hint()
+            except KeyError:
+                # In at least one test (test/inductor/test_torchbind.py) we
+                # create a StarDep that doesn't exist in the graph and calling
+                # `has_unbacked_symbols()` throws an error.
+                pass
+            self.__dep_size_hint_cache[dep] = res
+        else:
+            res = self.__dep_size_hint_cache[dep]
+        return res
 
     def score_fusion_memory(
         self, node1: BaseSchedulerNode, node2: BaseSchedulerNode