ban fusion of large amounts of reads (#158667)
This is a reland attempt of https://github.com/pytorch/pytorch/pull/157563, but instead of introducing the `realize_acc_reads_size_threshold` config with a default value, we set it to `None` for now to unblock an internal use case. We will deep dive into the issue and harden the logic in later PRs.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/158667
Approved by: https://github.com/yf225
Committed by: PyTorch MergeBot
Parent: bc379aebe2
Commit: 6b0526a2c4
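For orientation before the diff: a minimal usage sketch of the knob this commit adds. The 16 MiB value and the toy function are illustrative only (the PR itself ships the threshold disabled, i.e. `None`), and a CUDA device is assumed:

import torch
from torch._inductor import config

def f(x, y, z):
    return torch.matmul(x, y) + z

x, y, z = (torch.rand(128, 128, device="cuda") for _ in range(3))

# Opt in to the new cap: with the threshold set, Inductor refuses fusions
# whose combined external reads exceed ~16 MiB (arbitrary example value).
with config.patch({"realize_acc_reads_size_threshold": 16 * 1024 * 1024}):
    out = torch.compile(f)(x, y, z)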
test/inductor/test_memory.py
@@ -8,6 +8,7 @@ from torch._dynamo.utils import same
 from torch._inductor import config, memory
 from torch._inductor.test_case import TestCase
 from torch._inductor.utils import run_and_get_triton_code
+from torch.testing._internal.common_utils import serialTest
 from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_GPU
@@ -306,6 +307,58 @@ class TestOperatorReorderForPeakMemory(TestCase):
         expected_bound = a.size(0) * c.size(1) * a.dtype.itemsize * 2
         self.assertLess(peak_mem, expected_bound)

+    @serialTest()
+    def test_fusion_acc_large_reads(self):
+        def f(x, y, z):
+            res = torch.zeros_like(x[0])
+            for i in range(4):
+                temp = torch.matmul(x, y) + z
+                res = res + temp
+            return res
+
+        N = 128
+        x = torch.rand(N, N, dtype=torch.float32, device=GPU_TYPE)
+        y = torch.rand(N, N, dtype=torch.float32, device=GPU_TYPE)
+        z = torch.rand(N, N, dtype=torch.float32, device=GPU_TYPE)
+
+        # CASE 1: no restriction on the amount of accumulation
+        with config.patch({"realize_acc_reads_size_threshold": float("inf")}):
+            f_compiled = torch.compile(f)
+            code = run_and_get_triton_code(f_compiled, x, y, z)
+            (
+                FileCheck()
+                .check("triton_poi_fused_add_0.run(buf4, arg2_1, buf1, buf2, buf3")
+                .run(code)
+            )
+
+        # CASE 2: for tensors with the same size as x (which is 4 * N**2 bytes)
+        # at most 12 / 4 = 3 reads can be accumulated during fusion
+        with config.patch({"realize_acc_reads_size_threshold": 12 * N**2}):
+            f_compiled = torch.compile(f)
+            code = run_and_get_triton_code(f_compiled, x, y, z)
+            (
+                FileCheck()
+                .check("triton_poi_fused_add_0.run(buf3, arg2_1, buf1, buf2,")
+                .check("triton_poi_fused_add_1.run(buf5, buf4, arg2_1,")
+                .run(code)
+            )
+
+        # CASE 3: no such fusion allowed
+        with config.patch({"realize_acc_reads_size_threshold": N**2}):
+            f_compiled = torch.compile(f)
+            code = run_and_get_triton_code(f_compiled, x, y, z)
+            (
+                FileCheck()
+                .check("triton_poi_fused_add_0.run(buf1, arg2_1,")
+                .check("triton_poi_fused_add_0.run(buf3, arg2_1,")
+                .check("triton_poi_fused_add_0.run(buf4, buf3,")
+                .check("triton_poi_fused_add_0.run(buf6, arg2_1,")
+                .check("triton_poi_fused_add_0.run(buf7, buf6,")
+                .check("triton_poi_fused_add_0.run(buf9, arg2_1,")
+                .check("triton_poi_fused_add_0.run(buf10, buf9,")
+                .run(code)
+            )
+

 if __name__ == "__main__":
     from torch._inductor.test_case import run_tests
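To make the CASE 2 comment above concrete, here is its arithmetic as a standalone snippet (plain Python; assumes the float32 inputs used by the test):

N = 128
bytes_per_read = 4 * N**2           # float32 -> 4 bytes per element
threshold = 12 * N**2               # the value patched in CASE 2
print(threshold // bytes_per_read)  # 3: at most three N x N fp32 reads fuse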
torch/_inductor/choices.py
@@ -365,6 +365,17 @@ class InductorChoices:
             WhyNoFuse(node1, node2)("Fusion will increase peak memory")
             return False

+        if (
+            config.realize_acc_reads_size_threshold is not None
+            and scheduler.fusion_accumulate_large_reads(
+                node1,
+                node2,
+                config.realize_acc_reads_size_threshold,
+            )
+        ):
+            WhyNoFuse(node1, node2)("Fusion accumulate large amount of reads")
+            return False
+
         return True

     @staticmethod
torch/_inductor/config.py
@@ -574,6 +574,9 @@ realize_opcount_threshold = 30

 # Threshold to prevent excessive accumulation of ops in one buffer during lowering
 realize_acc_reads_threshold = 8
+realize_acc_reads_size_threshold: Optional[int] = (
+    None  # TODO(xuanzh): harden this to make it non optional
+)

 # fallback to eager for random/dropout, this is slow but useful for debugging
 fallback_random = False
torch/_inductor/graph.py
@@ -123,6 +123,7 @@ if TYPE_CHECKING:
     from torch.fx.graph import Graph

     from .codegen.wrapper import PythonWrapperCodegen
+    from .dependencies import Dep
     from .scheduler import BaseSchedulerNode

 CompiledModule = Union[ModuleType, FileBackedGraphModule]
@@ -485,6 +486,9 @@ class GraphLowering(torch.fx.Interpreter):

         self.bw_donated_idxs = get_donated_idxs()

+        # Cache for dep size hints to avoid expensive recomputation
+        self.dep_size_hint_cache: dict[Dep, int] = {}
+
     def freeze_runtime_asserts(self) -> None:
         self._shape_env.freeze_runtime_asserts()
@@ -570,6 +574,23 @@
         assert isinstance(feature, BackendFeature), feature
         return feature in self.get_backend_features(get_device_type(device))

+    def get_dep_size_hint(self, dep: Dep) -> int:
+        """
+        Get the size hint for a dependency with caching to avoid expensive recomputation.
+        """
+        if dep not in self.dep_size_hint_cache:
+            res = 0
+            try:
+                if not dep.has_unbacked_symbols():
+                    res = dep.numbytes_hint()
+            except KeyError:
+                # In at least one test (test/inductor/test_torchbind.py) we
+                # create a StarDep that doesn't exist in the graph and calling
+                # `has_unbacked_symbols()` throws an error.
+                pass
+            self.dep_size_hint_cache[dep] = res
+        return self.dep_size_hint_cache[dep]
+
     def get_current_device_or_throw(self) -> torch.device:
         if device := self.current_device:
             return device
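The method above centralizes one memoized size computation on GraphLowering that, per the later hunks, the scheduler and the memory planner now both delegate to. A minimal standalone model of the caching pattern; FakeDep is a hypothetical stand-in for Inductor's Dep, used only to mimic the StarDep KeyError edge case:

class FakeDep:
    def __init__(self, name: str, nbytes: int, broken: bool = False):
        self.name, self.nbytes, self.broken = name, nbytes, broken

    def has_unbacked_symbols(self) -> bool:
        if self.broken:
            raise KeyError(self.name)  # mimics the StarDep case noted above
        return False

    def numbytes_hint(self) -> int:
        return self.nbytes

cache: dict[FakeDep, int] = {}

def get_dep_size_hint(dep: FakeDep) -> int:
    # Compute once per dep; later lookups are dictionary hits.
    if dep not in cache:
        res = 0
        try:
            if not dep.has_unbacked_symbols():
                res = dep.numbytes_hint()
        except KeyError:
            pass  # unknown deps fall back to 0 bytes
        cache[dep] = res
    return cache[dep]

print(get_dep_size_hint(FakeDep("buf0", 1024)))            # 1024
print(get_dep_size_hint(FakeDep("star", 0, broken=True)))  # 0 (error swallowed)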
torch/_inductor/ir.py
@@ -7877,6 +7877,10 @@ class TensorBox(MutableBox):


 class StorageBox(MutableBox):
+    """
+    StorageBox allow in-place mutation of Tensors
+    """
+
     def is_input_buffer(self) -> bool:
         if isinstance(self.data, (InputBuffer, ReinterpretView)):
             return self.data.get_name() in V.graph.graph_inputs
@@ -7926,10 +7930,21 @@ class StorageBox(MutableBox):
         ):
             self.realize()

+    def has_accumulated_enough_reads_by_size(self, threshold: int) -> bool:
+        return (
+            sum(V.graph.get_dep_size_hint(dep) for dep in self.get_reads()) > threshold
+        )
+
     def has_exceeded_max_reads(self) -> bool:
         return isinstance(self.data, Pointwise) and (
             self.num_reads() > config.realize_acc_reads_threshold
             or self.has_large_inner_fn()
+            or (
+                config.realize_acc_reads_size_threshold is not None
+                and self.has_accumulated_enough_reads_by_size(
+                    config.realize_acc_reads_size_threshold
+                )
+            )
         )

     def should_realize_on_reuse(self, users: int) -> bool:
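The new lowering-side predicate is just a byte sum over the buffer's reads compared against the cap. A standalone sketch of its semantics with plain ints in place of Dep size hints (hypothetical helper, not Inductor API):

def has_accumulated_enough_reads_by_size(read_sizes: list[int], threshold: int) -> bool:
    # Realize the buffer once its reads accumulate more bytes than allowed.
    return sum(read_sizes) > threshold

# Four N x N fp32 reads at N = 128 total 262144 bytes, above the 196608 cap:
print(has_accumulated_enough_reads_by_size([4 * 128**2] * 4, 12 * 128**2))  # True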
torch/_inductor/memory.py
@@ -78,19 +78,8 @@ def get_freeable_input_buf(
         A dictionary containing all freeble input buffers, keyed by their names.
     """

-    # this function is copied from torch/_inductor/scheduler.py
-    # TODO: would be nice to remove the try/except block for both places
     def _dep_size_hint(dep: Dep) -> int:
-        res = 0
-        try:
-            if not dep.has_unbacked_symbols():
-                res = dep.numbytes_hint()
-        except KeyError:
-            # In at least one test (test/inductor/test_torchbind.py) we
-            # create a StarDep that doesn't exist in the graph and calling
-            # `has_unbacked_symbols()` throws an error.
-            pass
-        return res
+        return V.graph.get_dep_size_hint(dep)

     # get freeable input buffers' successor nodes and their sizes
     # note that different deps can have the same name, so we use name as keys
torch/_inductor/scheduler.py
@@ -2051,15 +2051,12 @@ class Scheduler:
     optimizations such as fusion, reorder, and graph partition.
     """

-    __dep_size_hint_cache: dict[Dep, int]
-
     def __init__(self, nodes: list[ir.Operation]) -> None:
         with dynamo_timed("Scheduler.__init__"):
             self._init(nodes)

     def _init(self, nodes: list[ir.Operation]) -> None:
         super().__init__()
-        self.__dep_size_hint_cache = {}
         V.graph.scheduler = self
         self.backends: dict[torch.device, BaseScheduling] = {}
         self.post_grad_graph_id = next(_post_grad_graph_counter)
@@ -3505,6 +3502,14 @@ class Scheduler:
             return True
         return False

+    def fusion_accumulate_large_reads(
+        self, node1: BaseSchedulerNode, node2: BaseSchedulerNode, threshold: int
+    ) -> bool:
+        all_reads = (node1.read_writes.reads | node2.read_writes.reads) - (
+            node1.read_writes.writes | node2.read_writes.writes
+        )
+        return sum(self.dep_size_hint(dep) for dep in all_reads) > threshold
+
     def are_long_distant_nodes(
         self, node1: BaseSchedulerNode, node2: BaseSchedulerNode
     ) -> bool:
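Note how the set arithmetic above discounts reads of buffers that the candidate pair itself writes: those stay internal to the fused kernel and add no external reads. A standalone sketch with strings in place of Dep objects and a made-up size table:

sizes = {"arg0": 65536, "buf0": 65536, "buf1": 65536}

reads1, writes1 = {"arg0", "buf0"}, {"buf1"}  # node1 produces buf1
reads2, writes2 = {"arg0", "buf1"}, {"buf2"}  # node2 consumes buf1

# buf1 is written inside the fused pair, so it drops out of the read total
all_reads = (reads1 | reads2) - (writes1 | writes2)
total = sum(sizes.get(d, 0) for d in all_reads)
print(sorted(all_reads), total)  # ['arg0', 'buf0'] 131072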
@@ -4010,20 +4015,7 @@ class Scheduler:
         return False

     def dep_size_hint(self, dep: Dep) -> int:
-        res = 0
-        if dep not in self.__dep_size_hint_cache:
-            try:
-                if not dep.has_unbacked_symbols():
-                    res = dep.numbytes_hint()
-            except KeyError:
-                # In at least one test (test/inductor/test_torchbind.py) we
-                # create a StarDep that doesn't exist in the graph and calling
-                # `has_unbacked_symbols()` throws an error.
-                pass
-            self.__dep_size_hint_cache[dep] = res
-        else:
-            res = self.__dep_size_hint_cache[dep]
-        return res
+        return V.graph.get_dep_size_hint(dep)

     def score_fusion_memory(
         self, node1: BaseSchedulerNode, node2: BaseSchedulerNode