fusion of large accumulated reads only at IR level (#161978)

This reverts some of the changes in https://github.com/pytorch/pytorch/pull/158667.

In particular, we now disallow fusion of large accumulated reads only at the IR level and not at the scheduler level, since users can implement their own custom fusion logic at the scheduler level.
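Scheduler-level fusion decisions go through the InductorChoices handler, so the removed check can be reinstated in user code. A minimal sketch, condensed from the test added in this PR (the class name LimitAccumulatedReads is illustrative, not part of the PR):

    from torch._inductor import config
    from torch._inductor.choices import InductorChoices
    from torch._inductor.scheduler import BaseSchedulerNode, Scheduler
    from torch._inductor.virtualized import V

    class LimitAccumulatedReads(InductorChoices):
        # Illustrative subclass mirroring CustomInductorChoices from the new test.
        @staticmethod
        def can_fuse(
            scheduler: Scheduler,
            node1: BaseSchedulerNode,
            node2: BaseSchedulerNode,
            shared_data_score: int,
        ) -> bool:
            # Defer to the default fusion decision first.
            if not InductorChoices.can_fuse(scheduler, node1, node2, shared_data_score):
                return False
            threshold = config.realize_acc_reads_size_threshold
            if not threshold:
                return True
            # Reads the fused node would accumulate: everything the pair
            # reads minus what the pair itself produces.
            all_reads = (node1.read_writes.reads | node2.read_writes.reads) - (
                node1.read_writes.writes | node2.read_writes.writes
            )
            return sum(scheduler.dep_size_hint(dep) for dep in all_reads) < threshold

    V.set_choices_handler(LimitAccumulatedReads())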

Pull Request resolved: https://github.com/pytorch/pytorch/pull/161978
Approved by: https://github.com/yf225
Xuan Zhang
2025-09-12 14:19:24 -07:00
committed by PyTorch MergeBot
parent 783985e9fe
commit da669d51bf
4 changed files with 33 additions and 22 deletions

test/inductor/test_memory.py

@@ -353,6 +353,33 @@ class TestOperatorReorderForPeakMemory(TestCase):
         y = torch.rand(N, N, dtype=torch.float32, device=GPU_TYPE)
         z = torch.rand(N, N, dtype=torch.float32, device=GPU_TYPE)
 
+        from torch._inductor.choices import InductorChoices
+        from torch._inductor.scheduler import BaseSchedulerNode, Scheduler
+
+        class CustomInductorChoices(InductorChoices):
+            @staticmethod
+            def can_fuse(
+                scheduler: Scheduler,
+                node1: BaseSchedulerNode,
+                node2: BaseSchedulerNode,
+                shared_data_score: int,
+            ) -> bool:
+                can_fuse_default = InductorChoices.can_fuse(
+                    scheduler, node1, node2, shared_data_score
+                )
+                if (not can_fuse_default) or (
+                    not config.realize_acc_reads_size_threshold
+                ):
+                    return can_fuse_default
+
+                all_reads = (node1.read_writes.reads | node2.read_writes.reads) - (
+                    node1.read_writes.writes | node2.read_writes.writes
+                )
+                size_of_reads = [scheduler.dep_size_hint(dep) for dep in all_reads]
+                return sum(size_of_reads) < config.realize_acc_reads_size_threshold
+
+        torch._inductor.virtualized.V.set_choices_handler(CustomInductorChoices())
+
         # CASE 1: no restriction on the amount of accumulation
         with config.patch({"realize_acc_reads_size_threshold": float("inf")}):
             f_compiled = torch.compile(f)

torch/_inductor/choices.py

@@ -496,17 +496,6 @@ class InductorChoices:
             WhyNoFuse(node1, node2)("Fusion will increase peak memory")
             return False
 
-        if (
-            config.realize_acc_reads_size_threshold is not None
-            and scheduler.fusion_accumulate_large_reads(
-                node1,
-                node2,
-                config.realize_acc_reads_size_threshold,
-            )
-        ):
-            WhyNoFuse(node1, node2)("Fusion accumulate large amount of reads")
-            return False
-
         return True
 
     @staticmethod

torch/_inductor/ir.py

@@ -8207,9 +8207,12 @@ class StorageBox(MutableBox):
             self.realize()
 
     def has_accumulated_enough_reads_by_size(self, threshold: int) -> bool:
-        return (
-            sum(V.graph.get_dep_size_hint(dep) for dep in self.get_reads()) > threshold
-        )
+        size_of_reads = [V.graph.get_dep_size_hint(dep) for dep in self.get_reads()]
+        if not size_of_reads:
+            return False
+        total_size = sum(size_of_reads)
+        max_size = max(size_of_reads)
+        return total_size > threshold and total_size / max_size >= 2
 
     def has_exceeded_max_reads(self) -> bool:
         return isinstance(self.data, Pointwise) and (
torch/_inductor/scheduler.py

@@ -3807,14 +3807,6 @@ class Scheduler:
             return True
         return False
 
-    def fusion_accumulate_large_reads(
-        self, node1: BaseSchedulerNode, node2: BaseSchedulerNode, threshold: int
-    ) -> bool:
-        all_reads = (node1.read_writes.reads | node2.read_writes.reads) - (
-            node1.read_writes.writes | node2.read_writes.writes
-        )
-        return sum(self.dep_size_hint(dep) for dep in all_reads) > threshold
-
     def are_long_distant_nodes(
         self, node1: BaseSchedulerNode, node2: BaseSchedulerNode
     ) -> bool: