mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
fusion of large accumulated reads only at ir level (#161978)
This is to revert some of the changes in https://github.com/pytorch/pytorch/pull/158667 In particular, we only disallow fusion of large accumulate read at IR level and not at scheduler level, as users can create their own custom fusion logics for the scheduler level. Pull Request resolved: https://github.com/pytorch/pytorch/pull/161978 Approved by: https://github.com/yf225
This commit is contained in:
committed by
PyTorch MergeBot
parent
783985e9fe
commit
da669d51bf
@ -353,6 +353,33 @@ class TestOperatorReorderForPeakMemory(TestCase):
|
||||
y = torch.rand(N, N, dtype=torch.float32, device=GPU_TYPE)
|
||||
z = torch.rand(N, N, dtype=torch.float32, device=GPU_TYPE)
|
||||
|
||||
from torch._inductor.choices import InductorChoices
|
||||
from torch._inductor.scheduler import BaseSchedulerNode, Scheduler
|
||||
|
||||
class CustomInductorChoices(InductorChoices):
|
||||
@staticmethod
|
||||
def can_fuse(
|
||||
scheduler: Scheduler,
|
||||
node1: BaseSchedulerNode,
|
||||
node2: BaseSchedulerNode,
|
||||
shared_data_score: int,
|
||||
) -> bool:
|
||||
can_fuse_default = InductorChoices.can_fuse(
|
||||
scheduler, node1, node2, shared_data_score
|
||||
)
|
||||
if (not can_fuse_default) or (
|
||||
not config.realize_acc_reads_size_threshold
|
||||
):
|
||||
return can_fuse_default
|
||||
|
||||
all_reads = (node1.read_writes.reads | node2.read_writes.reads) - (
|
||||
node1.read_writes.writes | node2.read_writes.writes
|
||||
)
|
||||
size_of_reads = [scheduler.dep_size_hint(dep) for dep in all_reads]
|
||||
return sum(size_of_reads) < config.realize_acc_reads_size_threshold
|
||||
|
||||
torch._inductor.virtualized.V.set_choices_handler(CustomInductorChoices())
|
||||
|
||||
# CASE 1: no restriction on the amount of accumulation
|
||||
with config.patch({"realize_acc_reads_size_threshold": float("inf")}):
|
||||
f_compiled = torch.compile(f)
|
||||
|
@ -496,17 +496,6 @@ class InductorChoices:
|
||||
WhyNoFuse(node1, node2)("Fusion will increase peak memory")
|
||||
return False
|
||||
|
||||
if (
|
||||
config.realize_acc_reads_size_threshold is not None
|
||||
and scheduler.fusion_accumulate_large_reads(
|
||||
node1,
|
||||
node2,
|
||||
config.realize_acc_reads_size_threshold,
|
||||
)
|
||||
):
|
||||
WhyNoFuse(node1, node2)("Fusion accumulate large amount of reads")
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
|
@ -8207,9 +8207,12 @@ class StorageBox(MutableBox):
|
||||
self.realize()
|
||||
|
||||
def has_accumulated_enough_reads_by_size(self, threshold: int) -> bool:
|
||||
return (
|
||||
sum(V.graph.get_dep_size_hint(dep) for dep in self.get_reads()) > threshold
|
||||
)
|
||||
size_of_reads = [V.graph.get_dep_size_hint(dep) for dep in self.get_reads()]
|
||||
if not size_of_reads:
|
||||
return False
|
||||
total_size = sum(size_of_reads)
|
||||
max_size = max(size_of_reads)
|
||||
return total_size > threshold and total_size / max_size >= 2
|
||||
|
||||
def has_exceeded_max_reads(self) -> bool:
|
||||
return isinstance(self.data, Pointwise) and (
|
||||
|
@ -3807,14 +3807,6 @@ class Scheduler:
|
||||
return True
|
||||
return False
|
||||
|
||||
def fusion_accumulate_large_reads(
|
||||
self, node1: BaseSchedulerNode, node2: BaseSchedulerNode, threshold: int
|
||||
) -> bool:
|
||||
all_reads = (node1.read_writes.reads | node2.read_writes.reads) - (
|
||||
node1.read_writes.writes | node2.read_writes.writes
|
||||
)
|
||||
return sum(self.dep_size_hint(dep) for dep in all_reads) > threshold
|
||||
|
||||
def are_long_distant_nodes(
|
||||
self, node1: BaseSchedulerNode, node2: BaseSchedulerNode
|
||||
) -> bool:
|
||||
|
Reference in New Issue
Block a user