Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
[Graph Partition] reorder for minimal number of partitions (#151968)
This PR adds an optimal reordering for minimizing the number of partitions.

## Optimal reordering for minimizing #partitions

A BFS can minimize the number of partitions (ignoring peak memory for now):

1. For each node, compute `node_to_indegree: dict[node, int]`.
2. Maintain 2 queues: `cudagraphable_nodes` and `non_cudagraphable_nodes`. Iterate through all nodes and add each node to one of these 2 queues if `node_to_indegree[node] == 0`.
3. While `non_cudagraphable_nodes` is not empty: pop 1 node, schedule it, update the indegree of all its successors, and add each successor to one of the queues if `node_to_indegree[successor] == 0`.
4. While `cudagraphable_nodes` is not empty: pop 1 node, schedule it, update the indegree of all its successors, and add each successor to one of the queues if `node_to_indegree[successor] == 0`.
5. Repeat steps 3 & 4 until all nodes have been scheduled.

We call this strategy `reorder_for_minimizing_partition`.

**Q: Why is this optimal?**

Suppose it is not optimal. Then there is a counterexample with 2 non-cudagraphable regions:

```
[non_cudagraphable1, cudagraphable2, non_cudagraphable3]
```

which could be reordered into only 1 non-cudagraphable region:

```
[non_cudagraphable1, non_cudagraphable3, cudagraphable2]
```

This reorder means non_cudagraphable3 does not depend on cudagraphable2. So after non_cudagraphable1 is scheduled, both non_cudagraphable3 and cudagraphable2 have indegree 0. But then step 3 would already have scheduled non_cudagraphable3 before cudagraphable2, so such a counterexample cannot exist. This shows the BFS is optimal for minimizing the number of partitions.

## Minimize peak memory

`reorder_for_peak_memory` currently uses `topological_sort_dfs`, `topological_sort_lpmf`, and `topological_sort_bfs`, where the latter 2 are BFS-based. ILP brings small benefits and can hardly scale beyond 100 nodes, according to @xuanzhang816, so ILP is not used for peak-memory reordering in Inductor.

Heuristic strategy:

- Run `reorder_for_peak_memory` to get the default order.
- Run `reorder_for_minimizing_partition` and get the result as `list[tuple[partition, bool]]`, where `partition: list[BaseSchedulerNode]` and the bool marks whether the partition is cudagraphable.
- If the reorder increases peak memory too much, use the default order.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/151968
Approved by: https://github.com/eellison
Committed by: PyTorch MergeBot
Parent: a77a44761b
Commit: 797768cd90
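Before the diff, here is a minimal, self-contained sketch of the two-queue BFS described in the commit message, run on a toy dependency graph. All names here (`reorder_for_minimal_partitions`, the node dict, `push_ready`, `drain`) are illustrative assumptions and not part of the patch; the actual implementation is `Scheduler.reorder_for_minimizing_partition` in the diff below, which operates on `BaseSchedulerNode` objects via `mpi_node.pred_nodes` / `mpi_node.succ_nodes`.

```python
from collections import deque

# Toy dependency graph (structure is illustrative only): `deps` lists
# predecessor names, `cudagraphable` marks whether the node could be captured
# inside a CUDA graph partition.
nodes = {
    "cpu_a": {"deps": [], "cudagraphable": False},
    "gpu_b": {"deps": [], "cudagraphable": True},
    "gpu_e": {"deps": [], "cudagraphable": True},
    "gpu_c": {"deps": ["gpu_b"], "cudagraphable": True},
    "cpu_d": {"deps": ["gpu_c"], "cudagraphable": False},
}


def reorder_for_minimal_partitions(nodes):
    # Step 1: indegree of every node, plus a successor map for indegree updates.
    indegree = {name: len(meta["deps"]) for name, meta in nodes.items()}
    successors = {name: [] for name in nodes}
    for name, meta in nodes.items():
        for dep in meta["deps"]:
            successors[dep].append(name)

    cudagraphable_q, non_cudagraphable_q = deque(), deque()

    def push_ready(name):
        # Step 2: ready nodes (indegree 0) go into one of the two queues.
        queue = cudagraphable_q if nodes[name]["cudagraphable"] else non_cudagraphable_q
        queue.append(name)

    for name in nodes:
        if indegree[name] == 0:
            push_ready(name)

    schedule = []

    def drain(queue):
        while queue:
            name = queue.popleft()
            schedule.append(name)
            for succ in successors[name]:
                indegree[succ] -= 1
                if indegree[succ] == 0:
                    push_ready(succ)

    # Steps 3-5: alternately exhaust the ready non-cudagraphable nodes and the
    # ready cudagraphable nodes until everything is scheduled.
    while len(schedule) < len(nodes):
        drain(non_cudagraphable_q)
        drain(cudagraphable_q)
    return schedule


print(reorder_for_minimal_partitions(nodes))
# ['cpu_a', 'gpu_b', 'gpu_e', 'gpu_c', 'cpu_d']
```

On this toy graph the three cudagraphable nodes come out contiguous, so the schedule contains one cudagraphable region instead of two, which is exactly the property the optimality argument above relies on.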
@@ -2686,8 +2686,8 @@ if HAS_CUDA:
                loss.backward()
                optimizer.step()

            # 2 graph partitions lead to 2 fwd cudagraphs and 2 bwd cudagraphs
            self.assertEqual(self.get_manager().new_graph_id().id, 4)
            # 2 graph partitions lead to 2 fwd cudagraphs and 1 bwd cudagraphs
            self.assertEqual(self.get_manager().new_graph_id().id, 3)

        @torch._inductor.config.patch("graph_partition", True)
        def test_graph_partition_cpu_only(self):
@@ -3088,6 +3088,89 @@ if HAS_CUDA:

            self.assertEqual(self.get_manager().new_graph_id().id, 3)

        @torch._inductor.config.patch("graph_partition", True)
        def test_graph_partition_reorder_cpu_and_gpu(self):
            def f(x_cuda, y_cpu, z_cuda, weight_cuda, weight_cpu):
                x_cuda0 = x_cuda + 1
                x_cuda1 = x_cuda0 @ weight_cuda
                x_cuda2 = 2 * (x_cuda1 + x_cuda)

                y_cpu0 = y_cpu + 1
                y_cpu1 = y_cpu0 @ weight_cpu

                z_cuda0 = z_cuda + 1
                z_cuda1 = z_cuda0 @ weight_cuda
                z_cuda2 = 2 * (z_cuda1 + z_cuda)

                return x_cuda2, y_cpu1, z_cuda2

            x_cuda = torch.randn(3, 3, device="cuda")
            y_cpu = torch.randn(3, 3, device="cpu")
            z_cuda = torch.randn(3, 3, device="cuda")
            weight_cuda = torch.randn(3, 3, device="cuda")
            weight_cpu = torch.randn(3, 3, device="cpu")

            eager_out = f(x_cuda, y_cpu, z_cuda, weight_cuda, weight_cpu)

            compiled_f = torch.compile(f, mode="reduce-overhead")
            for _ in range(3):
                compiled_out = compiled_f(
                    x_cuda, y_cpu, z_cuda, weight_cuda, weight_cpu
                )
                self.assertEqual(eager_out, compiled_out)

            # reorder merges ops on cuda into 1 graph partition
            self.assertEqual(self.get_manager().new_graph_id().id, 1)

        @torch._inductor.config.patch("graph_partition", True)
        def test_graph_partition_reorder_cpu_and_gpu_interleave(self):
            def f(x_cuda, y_cpu, z_cuda, weight_cuda, weight_cpu):
                # partition 1 on cuda, no dependency
                x_cuda0 = x_cuda + 1
                x_cuda1 = x_cuda0 @ weight_cuda
                x_cuda2 = 2 * (x_cuda1 + x_cuda)

                # partition 2 on cpu w/ dependency on partition 1
                y_cpu0 = y_cpu + 1
                x_cuda2_cpu = x_cuda2.cpu()  # adds dependency on gpu computations
                y_cpu1 = y_cpu0 @ weight_cpu + x_cuda2_cpu

                # partition 3 on cuda w/o dependency
                z_cuda0 = z_cuda + 1
                z_cuda1 = z_cuda0 @ weight_cuda
                z_cuda2 = 2 * (z_cuda1 + z_cuda)

                # partition 4 on cpu w/o dependency
                y_cpu2 = y_cpu + 5
                y_cpu3 = y_cpu2 @ weight_cpu

                # partition 5 on cuda w/o dependency
                u_cuda0 = z_cuda + 3
                u_cuda1 = u_cuda0 @ weight_cuda
                u_cuda2 = 2 * (u_cuda0 + u_cuda1)

                return x_cuda2, y_cpu1, z_cuda2, y_cpu3, u_cuda2

            x_cuda = torch.randn(3, 3, device="cuda")
            y_cpu = torch.randn(3, 3, device="cpu")
            z_cuda = torch.randn(3, 3, device="cuda")
            weight_cuda = torch.randn(3, 3, device="cuda")
            weight_cpu = torch.randn(3, 3, device="cpu")

            eager_out = f(x_cuda, y_cpu, z_cuda, weight_cuda, weight_cpu)

            compiled_f = torch.compile(f, mode="reduce-overhead")
            for _ in range(3):
                compiled_out = compiled_f(
                    x_cuda, y_cpu, z_cuda, weight_cuda, weight_cpu
                )
                self.assertEqual(eager_out, compiled_out)

            # the optimal order is
            # [[partition 4 on cpu], [partition 1,3,5 on cuda], [partition 2 on cpu]]
            # since partition2 depends on partition1. So we have 1 cudagraph in total.
            self.assertEqual(self.get_manager().new_graph_id().id, 1)

        @config.patch(implicit_fallbacks=True)
        @torch._inductor.config.patch("graph_partition", True)
        def test_graph_partition_reorder_custom_op_with_no_dependency(self):
@@ -22,6 +22,13 @@ if TYPE_CHECKING:
torch_log = logging.getLogger(__name__)


@dataclasses.dataclass
class PeakMemoryResult:
    order: list[BaseSchedulerNode]
    peak_memory: int
    method: str


@dataclasses.dataclass
class MemoryPlanningInfoForBuffer:
    size_alloc: int = 0
@@ -578,6 +585,35 @@ def topological_sort_dfs(nodes: list[BaseSchedulerNode]) -> list[BaseSchedulerNo
    return result


def prepare_planning_info(
    nodes: list[BaseSchedulerNode],
    name_to_buf: dict[str, SchedulerBuffer],
    name_to_fused_node: dict[str, BaseSchedulerNode],
    graph_inputs: OrderedSet[str],
    graph_outputs: OrderedSet[str],
) -> tuple[int, dict[str, FreeableInputBuffer]]:
    """
    Prepare planning info. As nodes are scheduled one at a time, these help
    keep track of when a buffer can be freed, and when a node can be scheduled

    Returns:
        int: peak memory estimation
        dict[str, FreeableInputBuffer]: name to freeable input buffer
    """
    name_to_freeable_input_buf = get_freeable_input_buf(nodes, graph_inputs)
    assign_memory_planning_info_for_scheduler_buffers(nodes, name_to_buf)
    assign_memory_planning_info_for_scheduler_nodes(
        nodes, name_to_fused_node, name_to_buf, name_to_freeable_input_buf
    )

    # the default
    estimated_peak_memory, _ = estimate_peak_memory(
        nodes, name_to_freeable_input_buf, graph_outputs
    )

    return estimated_peak_memory, name_to_freeable_input_buf


def reorder_for_peak_memory(
    nodes: list[BaseSchedulerNode],
    name_to_buf: dict[str, SchedulerBuffer],
@@ -597,29 +633,16 @@ def reorder_for_peak_memory(

    torch_log.info("Reordering for peak memory -- %d nodes", len(nodes))

    @dataclasses.dataclass
    class PeakMemoryResult:
        order: list[BaseSchedulerNode]
        peak_memory: int
        method: str

    # preparation -- as nodes are scheduled one at a time, these help
    # keep track of when a buffer can be freed, and when a node can be scheduled
    name_to_freeable_input_buf: dict[str, FreeableInputBuffer] = get_freeable_input_buf(
        nodes, graph_inputs
    )
    assign_memory_planning_info_for_scheduler_buffers(nodes, name_to_buf)
    assign_memory_planning_info_for_scheduler_nodes(
        nodes, name_to_fused_node, name_to_buf, name_to_freeable_input_buf
    estimated_peak_memory, name_to_freeable_input_buf = prepare_planning_info(
        nodes,
        name_to_buf,
        name_to_fused_node,
        graph_inputs,
        graph_outputs,
    )

    # keep track of the peak memory estimates of different methods
    peak_memory_diff_methods: list[PeakMemoryResult] = []

    # the default
    estimated_peak_memory, _ = estimate_peak_memory(
        nodes, name_to_freeable_input_buf, graph_outputs
    )
    peak_memory_diff_methods.append(
        PeakMemoryResult(nodes, estimated_peak_memory, "baseline")
    )
@@ -2104,6 +2104,7 @@ class Scheduler:
        self.process_grouped_nodes()

        if torch._inductor.config.graph_partition:
            self.nodes = self.maybe_reorder_for_minimizing_partition(self.nodes)
            self.nodes = self.reorder_for_partition_with_simple_dependency(self.nodes)

        self.compute_last_usage()
@@ -4282,6 +4283,100 @@ class Scheduler:

        return signatures[::-1]

    def reorder_for_minimizing_partition(
        self,
        nodes: list[BaseSchedulerNode],
    ) -> list[BaseSchedulerNode]:
        """
        Reorder nodes to minimize the number of partitions via a bfs
        topological sort. This is the optimal reordering such that the
        number of partitions cannot be reduced further. This may be
        sub-optimal for other metrics such as peak memory. This does not
        change relative orders of two cudagraphable nodes, nor the
        relative order of two non_cudagraphable nodes.
        """
        node_to_indegree: dict[BaseSchedulerNode, int] = dict()
        cudagraphable_nodes: collections.deque[BaseSchedulerNode] = collections.deque()
        non_cudagraphable_nodes: collections.deque[BaseSchedulerNode] = (
            collections.deque()
        )

        def insert_pending_nodes(node: BaseSchedulerNode) -> None:
            if self.should_partition(node):
                non_cudagraphable_nodes.append(node)
            else:
                cudagraphable_nodes.append(node)

        def update_indegree(node: BaseSchedulerNode) -> None:
            for succ_node in node.mpi_node.succ_nodes:
                assert node_to_indegree[succ_node] > 0
                node_to_indegree[succ_node] -= 1
                if node_to_indegree[succ_node] == 0:
                    insert_pending_nodes(succ_node)

        for node in nodes:
            node_to_indegree[node] = len(node.mpi_node.pred_nodes)
            if node_to_indegree[node] == 0:
                insert_pending_nodes(node)

        schedule: list[BaseSchedulerNode] = []
        num_iters: int = 0
        while num_iters < len(nodes) and (
            non_cudagraphable_nodes or cudagraphable_nodes
        ):
            while non_cudagraphable_nodes:
                node = non_cudagraphable_nodes.popleft()
                schedule.append(node)
                update_indegree(node)

            while cudagraphable_nodes:
                node = cudagraphable_nodes.popleft()
                schedule.append(node)
                update_indegree(node)

            num_iters += 1

        if num_iters > len(nodes):
            raise RuntimeError(
                """
                Failed to schedule, while loop ran too long when
                reordering for minimizing the num of partitions
                """
            )

        return schedule

    def maybe_reorder_for_minimizing_partition(
        self,
        nodes: list[BaseSchedulerNode],
    ) -> list[BaseSchedulerNode]:
        """
        Reorder nodes to minimize the number of partitions if this only slightly
        increases peak memory.
        """
        from .memory import estimate_peak_memory, prepare_planning_info

        graph_outputs = OrderedSet(V.graph.get_output_names())

        default_peak_memory, name_to_freeable_input_buf = prepare_planning_info(
            nodes,
            self.name_to_buf,
            self.name_to_fused_node,
            OrderedSet(V.graph.graph_inputs.keys()),
            graph_outputs,
        )

        reordered_nodes = self.reorder_for_minimizing_partition(nodes)
        reorder_peak_memory, _ = estimate_peak_memory(
            reordered_nodes, name_to_freeable_input_buf, graph_outputs
        )

        # 1.1 here means 10% extra peak memory budget which is quite arbitrary
        if reorder_peak_memory < default_peak_memory * 1.1:
            return reordered_nodes

        return nodes

    def reorder_for_partition_with_simple_dependency(
        self, nodes: list[BaseSchedulerNode]
    ) -> list[BaseSchedulerNode]: