Revert "[PT2][fusion] ban fusions with large accumulated reads (#157563)"
This reverts commit c062550a3598d27c2d6572db7c0f4ff90a84cc84.
Reverted https://github.com/pytorch/pytorch/pull/157563 on behalf of https://github.com/clee2000 because it broke test_linear_and_cel on main (c062550a35), possibly by causing an OOM. It is also broken on the PR itself, and the Dr. CI classification is wrong (it claims the test is disabled by an issue, but that issue is for a different test). Also, the expected-results file is supposed to have a ton of empty lines; they are there to prevent merge conflicts, and I will add a check for this to the linter ([comment](https://github.com/pytorch/pytorch/pull/157563#issuecomment-3074355331))
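For context, the reverted change made Inductor's scheduler refuse fusions whose combined reads exceed a byte threshold (`realize_acc_reads_size_threshold`, defaulting to 3 GiB). The sketch below is a self-contained reconstruction of that gate from the removed hunks further down; the `(name, nbytes)` pairs are an illustrative stand-in for Inductor's `Dep` objects, not the real API.

```python
# Minimal sketch of the reverted fusion gate; (name, nbytes) pairs stand in
# for Inductor's Dep objects, whose sizes come from dep.numbytes_hint().
REALIZE_ACC_READS_SIZE_THRESHOLD = 3 * (1024**3)  # reverted default: 3 GiB

def fusion_accumulate_large_reads(
    reads1: set[tuple[str, int]],
    writes1: set[tuple[str, int]],
    reads2: set[tuple[str, int]],
    writes2: set[tuple[str, int]],
    threshold: int = REALIZE_ACC_READS_SIZE_THRESHOLD,
) -> bool:
    # Reads the fused kernel would still perform: everything either node
    # reads, minus buffers the pair itself writes.
    all_reads = (reads1 | reads2) - (writes1 | writes2)
    return sum(nbytes for _, nbytes in all_reads) > threshold

# Two nodes that together read just over 4 GiB of distinct buffers
# would have been banned from fusing under the default threshold.
r1 = {("buf0", 2 * 1024**3)}
r2 = {("buf1", 2 * 1024**3 + 1)}
print(fusion_accumulate_large_reads(r1, set(), r2, set()))  # True
```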
@@ -1,23 +1,89 @@
-add_loop_eager,compile_time_instruction_count,2996000000,0.015
+add_loop_eager,compile_time_instruction_count,3017000000,0.015
+
+
+
 add_loop_eager_dynamic,compile_time_instruction_count,4352000000,0.025
-add_loop_inductor,compile_time_instruction_count,33090000000,0.015
-add_loop_inductor_dynamic_gpu,compile_time_instruction_count,42660000000,0.025
-add_loop_inductor_gpu,compile_time_instruction_count,29690000000,0.015
+
+
+
+add_loop_inductor,compile_time_instruction_count,29490000000,0.015
+
+
+
+add_loop_inductor_dynamic_gpu,compile_time_instruction_count,38760000000,0.025
+
+
+
+add_loop_inductor_gpu,compile_time_instruction_count,26000000000,0.015
+
+
+
 basic_modules_ListOfLinears_eager,compile_time_instruction_count,947600000,0.015
-basic_modules_ListOfLinears_inductor,compile_time_instruction_count,18830000000,0.015
-basic_modules_ListOfLinears_inductor_gpu_force_shape_pad,compile_time_instruction_count,17460000000,0.015
-basic_modules_ListOfLinears_inductor_gpu,compile_time_instruction_count,11020000000,0.2
+
+
+
+basic_modules_ListOfLinears_inductor,compile_time_instruction_count,18490000000,0.015
+
+
+
+basic_modules_ListOfLinears_inductor_gpu_force_shape_pad,compile_time_instruction_count,17020000000,0.015
+
+
+
+basic_modules_ListOfLinears_inductor_gpu,compile_time_instruction_count,10297683772,0.2
+
+
+
 update_hint_regression,compile_time_instruction_count,1673000000,0.02
+
+
+
 sum_floordiv_regression,compile_time_instruction_count,986800000,0.015
-symint_sum,compile_time_instruction_count,3184000000,0.015
+
+
+
+symint_sum,compile_time_instruction_count,3166000000,0.015
+
+
+
 symint_sum_loop,compile_time_instruction_count,4202000000,0.015
+
+
+
 aotdispatcher_inference_nosubclass_cpu,compile_time_instruction_count,2103000000,0.015
+
+
+
 aotdispatcher_inference_subclass_cpu,compile_time_instruction_count,6004000000,0.015
+
+
+
 aotdispatcher_partitioner_cpu,compile_time_instruction_count,8783000000,0.015
+
+
+
 aotdispatcher_partitioner_cpu2,compile_time_instruction_count,1940000000,0.015
+
+
+
 aotdispatcher_training_nosubclass_cpu,compile_time_instruction_count,3885000000,0.015
+
+
+
 aotdispatcher_training_subclass_cpu,compile_time_instruction_count,10470000000,0.015
-mm_loop_inductor_gpu,compile_time_instruction_count,4365000000,0.015
-mm_loop_inductor_dynamic_gpu,compile_time_instruction_count,8184000000,0.015
+
+
+
+mm_loop_inductor_gpu,compile_time_instruction_count,4324000000,0.015
+
+
+
+mm_loop_inductor_dynamic_gpu,compile_time_instruction_count,8116000000,0.015
+
+
+
 basic_NestedModule_eager,compile_time_instruction_count,8152524390,0.015
+
+
+
 basic_InlineMod_eager,compile_time_instruction_count,7255000000,0.015
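Each row of the expected-results file above reads `benchmark,metric,expected_value,relative_tolerance`. A hedged sketch of the pass criterion implied by those columns (the exact check lives in the benchmark harness; this is an assumption, not its code):

```python
def within_tolerance(measured: int, expected: int, rel_tol: float) -> bool:
    # e.g. add_loop_eager: expected 3017000000 instructions, rel_tol 0.015 (1.5%)
    return abs(measured - expected) <= rel_tol * expected

print(within_tolerance(3_050_000_000, 3_017_000_000, 0.015))  # True, ~1.1% off
```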
@@ -306,57 +306,6 @@ class TestOperatorReorderForPeakMemory(TestCase):
         expected_bound = a.size(0) * c.size(1) * a.dtype.itemsize * 2
         self.assertLess(peak_mem, expected_bound)
 
-    def test_fusion_acc_large_reads(self):
-        def f(x, y, z):
-            res = torch.zeros_like(x[0])
-            for i in range(4):
-                temp = torch.matmul(x, y) + z
-                res = res + temp
-            return res
-
-        N = 128
-        x = torch.rand(N, N, dtype=torch.float32, device=GPU_TYPE)
-        y = torch.rand(N, N, dtype=torch.float32, device=GPU_TYPE)
-        z = torch.rand(N, N, dtype=torch.float32, device=GPU_TYPE)
-
-        # CASE 1: no restriction on the amount of accumulation
-        with config.patch({"realize_acc_reads_size_threshold": float("inf")}):
-            f_compiled = torch.compile(f)
-            code = run_and_get_triton_code(f_compiled, x, y, z)
-            (
-                FileCheck()
-                .check("triton_poi_fused_add_0.run(buf4, arg2_1, buf1, buf2, buf3")
-                .run(code)
-            )
-
-        # CASE 2: for tensors with the same size as x (which is 4 * N**2 bytes)
-        # at most 12 / 4 = 3 reads can be accumulated during fusion
-        with config.patch({"realize_acc_reads_size_threshold": 12 * N**2}):
-            f_compiled = torch.compile(f)
-            code = run_and_get_triton_code(f_compiled, x, y, z)
-            (
-                FileCheck()
-                .check("triton_poi_fused_add_0.run(buf3, arg2_1, buf1, buf2,")
-                .check("triton_poi_fused_add_1.run(buf5, buf4, arg2_1,")
-                .run(code)
-            )
-
-        # CASE 3: no such fusion allowed
-        with config.patch({"realize_acc_reads_size_threshold": N**2}):
-            f_compiled = torch.compile(f)
-            code = run_and_get_triton_code(f_compiled, x, y, z)
-            (
-                FileCheck()
-                .check("triton_poi_fused_add_0.run(buf1, arg2_1,")
-                .check("triton_poi_fused_add_0.run(buf3, arg2_1,")
-                .check("triton_poi_fused_add_0.run(buf4, buf3,")
-                .check("triton_poi_fused_add_0.run(buf6, arg2_1,")
-                .check("triton_poi_fused_add_0.run(buf7, buf6,")
-                .check("triton_poi_fused_add_0.run(buf9, arg2_1,")
-                .check("triton_poi_fused_add_0.run(buf10, buf9,")
-                .run(code)
-            )
-
-
 if __name__ == "__main__":
     from torch._inductor.test_case import run_tests
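As a sanity check on the CASE 2 arithmetic in the removed test: a float32 N x N tensor occupies 4 * N**2 bytes, so a threshold of 12 * N**2 bytes admits at most three such reads per fused kernel. A quick verification, assuming only that torch is importable:

```python
import torch

N = 128
t = torch.rand(N, N, dtype=torch.float32)
per_tensor_bytes = t.numel() * t.element_size()  # 4 * N**2 == 65536
threshold = 12 * N**2                            # as patched in CASE 2
print(threshold // per_tensor_bytes)             # 3 -> at most 3 reads fuse
```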
@@ -13,7 +13,6 @@ from torch.testing._internal.common_utils import (
     instantiate_parametrized_tests,
     IS_LINUX,
     parametrize,
-    serialTest,
 )
 from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_CUDA
 
@@ -78,17 +77,12 @@ class TestOnlineSoftmax(TestCase):
         out, source_codes = run_and_get_code(f, x)
         return source_codes[0]
 
-    @serialTest()
     def test_codegen_3pass_softmax_due_to_disable(self):
-        with inductor_config.patch(
-            online_softmax=False,
-            realize_acc_reads_size_threshold=float("inf"),
-        ):
+        with inductor_config.patch(online_softmax=False):
             wrapper_code = self.get_softmax_wrapper()
 
         self.assertEqual(wrapper_code.count("for r0_offset in"), 3)
 
-    @serialTest()
     @parametrize("V", [2048, 50304])
     @parametrize("use_log_softmax", [False, True])
     def test_codegen_online_softmax(self, use_log_softmax, V):
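Both patching styles visible in these hunks are valid: `torch._inductor.config.patch` is a context manager accepting either keyword arguments or a dict, and it restores the previous values on exit. A brief usage sketch (the bodies are placeholders):

```python
import torch._inductor.config as inductor_config

# Keyword form, as in the restored test above:
with inductor_config.patch(online_softmax=False):
    pass  # compile and inspect generated code here

# Dict form, as in the removed test_fusion_acc_large_reads:
with inductor_config.patch({"online_softmax": False}):
    pass
```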
@@ -365,10 +365,6 @@ class InductorChoices:
             WhyNoFuse(node1, node2)("Fusion will increase peak memory")
             return False
 
-        if scheduler.fusion_accumulate_large_reads(node1, node2):
-            WhyNoFuse(node1, node2)("Fusion accumulate large amount of reads")
-            return False
-
         return True
 
     @staticmethod
@@ -574,7 +574,6 @@ realize_opcount_threshold = 30
 
 # Threshold to prevent excessive accumulation of ops in one buffer during lowering
 realize_acc_reads_threshold = 8
-realize_acc_reads_size_threshold = 3 * (1024**3)
 
 # fallback to eager for random/dropout, this is slow but useful for debugging
 fallback_random = False
@@ -123,7 +123,6 @@ if TYPE_CHECKING:
     from torch.fx.graph import Graph
 
     from .codegen.wrapper import PythonWrapperCodegen
-    from .dependencies import Dep
     from .scheduler import BaseSchedulerNode
 
 CompiledModule = Union[ModuleType, FileBackedGraphModule]
@@ -486,9 +485,6 @@ class GraphLowering(torch.fx.Interpreter):
 
         self.bw_donated_idxs = get_donated_idxs()
 
-        # Cache for dep size hints to avoid expensive recomputation
-        self.dep_size_hint_cache: dict[Dep, int] = {}
-
     def freeze_runtime_asserts(self) -> None:
         self._shape_env.freeze_runtime_asserts()
 
@@ -574,23 +570,6 @@ class GraphLowering(torch.fx.Interpreter):
         assert isinstance(feature, BackendFeature), feature
         return feature in self.get_backend_features(get_device_type(device))
 
-    def get_dep_size_hint(self, dep: Dep) -> int:
-        """
-        Get the size hint for a dependency with caching to avoid expensive recomputation.
-        """
-        if dep not in self.dep_size_hint_cache:
-            res = 0
-            try:
-                if not dep.has_unbacked_symbols():
-                    res = dep.numbytes_hint()
-            except KeyError:
-                # In at least one test (test/inductor/test_torchbind.py) we
-                # create a StarDep that doesn't exist in the graph and calling
-                # `has_unbacked_symbols()` throws an error.
-                pass
-            self.dep_size_hint_cache[dep] = res
-        return self.dep_size_hint_cache[dep]
-
     def get_current_device_or_throw(self) -> torch.device:
         if device := self.current_device:
             return device
@@ -7829,10 +7829,6 @@ class TensorBox(MutableBox):
 
 
 class StorageBox(MutableBox):
-    """
-    StorageBox allow in-place mutation of Tensors
-    """
-
     def is_input_buffer(self) -> bool:
         if isinstance(self.data, (InputBuffer, ReinterpretView)):
             return self.data.get_name() in V.graph.graph_inputs
@@ -7882,17 +7878,10 @@ class StorageBox(MutableBox):
         ):
             self.realize()
 
-    def has_accumulated_enough_reads_by_size(self) -> bool:
-        return (
-            sum(V.graph.get_dep_size_hint(dep) for dep in self.get_reads())
-            > config.realize_acc_reads_size_threshold
-        )
-
     def has_exceeded_max_reads(self) -> bool:
         return isinstance(self.data, Pointwise) and (
             self.num_reads() > config.realize_acc_reads_threshold
             or self.has_large_inner_fn()
-            or self.has_accumulated_enough_reads_by_size()
         )
 
     def should_realize_on_reuse(self, users: int) -> bool:
@@ -78,8 +78,19 @@ def get_freeable_input_buf(
         A dictionary containing all freeable input buffers, keyed by their names.
     """
 
+    # this function is copied from torch/_inductor/scheduler.py
+    # TODO: would be nice to remove the try/except block for both places
     def _dep_size_hint(dep: Dep) -> int:
-        return V.graph.get_dep_size_hint(dep)
+        res = 0
+        try:
+            if not dep.has_unbacked_symbols():
+                res = dep.numbytes_hint()
+        except KeyError:
+            # In at least one test (test/inductor/test_torchbind.py) we
+            # create a StarDep that doesn't exist in the graph and calling
+            # `has_unbacked_symbols()` throws an error.
+            pass
+        return res
 
     # get freeable input buffers' successor nodes and their sizes
     # note that different deps can have the same name, so we use name as keys
@@ -2051,12 +2051,15 @@ class Scheduler:
     optimizations such as fusion, reorder, and graph partition.
     """
 
+    __dep_size_hint_cache: dict[Dep, int]
+
     def __init__(self, nodes: list[ir.Operation]) -> None:
         with dynamo_timed("Scheduler.__init__"):
             self._init(nodes)
 
     def _init(self, nodes: list[ir.Operation]) -> None:
         super().__init__()
+        self.__dep_size_hint_cache = {}
         V.graph.scheduler = self
         self.backends: dict[torch.device, BaseScheduling] = {}
         self.post_grad_graph_id = next(_post_grad_graph_counter)
@@ -3502,17 +3505,6 @@ class Scheduler:
             return True
         return False
 
-    def fusion_accumulate_large_reads(
-        self, node1: BaseSchedulerNode, node2: BaseSchedulerNode
-    ) -> bool:
-        all_reads = (node1.read_writes.reads | node2.read_writes.reads) - (
-            node1.read_writes.writes | node2.read_writes.writes
-        )
-        return (
-            sum(self.dep_size_hint(dep) for dep in all_reads)
-            > config.realize_acc_reads_size_threshold
-        )
-
     def are_long_distant_nodes(
         self, node1: BaseSchedulerNode, node2: BaseSchedulerNode
     ) -> bool:
@@ -4018,7 +4010,20 @@ class Scheduler:
         return False
 
     def dep_size_hint(self, dep: Dep) -> int:
-        return V.graph.get_dep_size_hint(dep)
+        res = 0
+        if dep not in self.__dep_size_hint_cache:
+            try:
+                if not dep.has_unbacked_symbols():
+                    res = dep.numbytes_hint()
+            except KeyError:
+                # In at least one test (test/inductor/test_torchbind.py) we
+                # create a StarDep that doesn't exist in the graph and calling
+                # `has_unbacked_symbols()` throws an error.
+                pass
+            self.__dep_size_hint_cache[dep] = res
+        else:
+            res = self.__dep_size_hint_cache[dep]
+        return res
 
     def score_fusion_memory(
         self, node1: BaseSchedulerNode, node2: BaseSchedulerNode
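The restored `dep_size_hint` memoizes dependency byte sizes per `Scheduler` and conservatively reports 0 for deps with unbacked (symbolic) sizes. A minimal standalone sketch of that pattern; `SizedDep` is a hypothetical stand-in for Inductor's `Dep`:

```python
class SizedDep:
    """Hypothetical stand-in for Inductor's Dep: hashable, with a size hint."""

    def __init__(self, name: str, nbytes: int, unbacked: bool = False) -> None:
        self.name, self.nbytes, self.unbacked = name, nbytes, unbacked

    def has_unbacked_symbols(self) -> bool:
        return self.unbacked

    def numbytes_hint(self) -> int:
        return self.nbytes  # the real hint can be expensive to compute

    def __hash__(self) -> int:
        return hash(self.name)

    def __eq__(self, other: object) -> bool:
        return isinstance(other, SizedDep) and other.name == self.name


class MiniScheduler:
    def __init__(self) -> None:
        self._dep_size_hint_cache: dict[SizedDep, int] = {}

    def dep_size_hint(self, dep: SizedDep) -> int:
        # Memoize per dep; unbacked sizes fall back to a hint of 0 bytes.
        if dep not in self._dep_size_hint_cache:
            res = 0
            if not dep.has_unbacked_symbols():
                res = dep.numbytes_hint()
            self._dep_size_hint_cache[dep] = res
        return self._dep_size_hint_cache[dep]


s = MiniScheduler()
print(s.dep_size_hint(SizedDep("buf0", 4096)))  # 4096 (computed, then cached)
print(s.dep_size_hint(SizedDep("buf0", 4096)))  # 4096 (served from the cache)
```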