Partitioner: Fix to align partition node order with original graph (#157892)

Fixes #157891

Pull Request resolved: https://github.com/pytorch/pytorch/pull/157892
Approved by: https://github.com/ezyang
This commit is contained in:
Xiaochang Wu
2025-08-06 22:12:47 +00:00
committed by PyTorch MergeBot
parent 40c4d61f9a
commit 2507ae63f2
3 changed files with 49 additions and 14 deletions

View File

@ -24,6 +24,7 @@ class DummyPartitioner(CapabilityBasedPartitioner):
)
# original graph node order is: ['x', 'add', 'add_1', 'output']
class AddModule(torch.nn.Module):
def forward(self, x):
y = torch.add(x, x)
@ -32,8 +33,18 @@ class AddModule(torch.nn.Module):
class TestPartitionerOrder(TestCase):
# partitoner test to check graph node order
def test_partitioner_order(self):
# partitoner test to check graph node order remains the same with the original graph after partitioning
def test_partitioner_graph_node_order(self):
m = AddModule()
traced_m = torch.fx.symbolic_trace(m)
origin_node_order = [n.name for n in traced_m.graph.nodes]
partions = DummyPartitioner(traced_m).propose_partitions()
partion_nodes = [list(partition.nodes) for partition in partions]
partition_node_order = [n.name for n in partion_nodes[0]]
self.assertTrue(partition_node_order == origin_node_order)
# partitoner test to check graph node order remains the same during multiple runs
def test_partitioner_multiple_runs_order(self):
m = AddModule()
traced_m = torch.fx.symbolic_trace(m)
partitions = DummyPartitioner(traced_m).propose_partitions()

View File

@ -18,16 +18,29 @@ logger.setLevel(logging.WARNING)
class Partition:
def __init__(
self, id: Optional[int] = None, nodes: Optional[Iterable[Node]] = None
self,
id: Optional[int] = None,
nodes: Optional[Iterable[Node]] = None,
node_orders: Optional[Iterable[int]] = None,
):
self.id = id
self.nodes = dict.fromkeys(nodes) if nodes is not None else {}
self.nodes: dict[Node, Optional[int]] = {}
if nodes is not None:
if node_orders is None:
self.nodes = dict.fromkeys(nodes, None)
else:
nodes_list = list(nodes)
node_orders_list = list(node_orders)
assert len(nodes_list) == len(node_orders_list), (
"nodes and node_orders must have the same length"
)
self.nodes = dict(zip(nodes_list, node_orders_list))
def __repr__(self) -> str:
return str(self.nodes)
def add_node(self, node: Node):
self.nodes.update({node: None})
def add_node(self, node: Node, node_order: Optional[int] = None):
self.nodes.update({node: node_order})
def remove_node(self, node: Node):
del self.nodes[node]
@ -172,7 +185,7 @@ class CapabilityBasedPartitioner:
return merge_id, True
def merge_single_node(node: Node, id: Optional[int]):
def merge_single_node(node: Node, node_order: Optional[int], id: Optional[int]):
def _update_partition_map(node: Node, id: int):
# Iterate through all the users of this node and update the partition map to indicate
# that there is a path from the partition id of this node to the target partition id.
@ -189,16 +202,19 @@ class CapabilityBasedPartitioner:
assignment.pop(node)
elif id not in partitions_by_id:
assignment[node] = id
partitions_by_id[id] = Partition(id=id, nodes=[node])
assert node_order is not None
partitions_by_id[id] = Partition(
id=id, nodes=[node], node_orders=[node_order]
)
partition_users[id] = set(node.users)
_update_partition_map(node, id)
else:
assignment[node] = id
partitions_by_id[id].add_node(node)
partitions_by_id[id].add_node(node, node_order)
logger.debug("Proposing partitions...")
for node in reversed(self.graph_module.graph.nodes):
for node_order, node in enumerate(reversed(self.graph_module.graph.nodes)):
# use Dict as an ordered set to ensure deterministic partitioning result, don't care value
merge_candidates: dict[int, None] = {}
@ -211,7 +227,7 @@ class CapabilityBasedPartitioner:
partition_id = next(new_partition_id)
nodes_order[node] = partition_id
partitions_order[partition_id] = partition_id
merge_single_node(node, partition_id)
merge_single_node(node, node_order, partition_id)
merge_candidates[partition_id] = None
# merge all possible partitions
@ -228,6 +244,14 @@ class CapabilityBasedPartitioner:
# in the graph, otherwise, this is a no-op
self_id, _ = maybe_merge_partition(self_id, other_id)
# sort partition nodes based on descending node order
for partition in partitions_by_id.values():
partition.nodes = dict(
sorted(
partition.nodes.items(), key=operator.itemgetter(1), reverse=True
)
)
# post processing to re-assign "getitem" nodes into upstream partition
logger.debug("Reassigning getitem nodes to its producer node's partition...")
nodes_reassignment: dict[Node, int] = {}
@ -248,7 +272,7 @@ class CapabilityBasedPartitioner:
if assignment.get(user, None) != id: # type: ignore[arg-type]
nodes_reassignment[user] = id # type: ignore[assignment]
for node, id in nodes_reassignment.items():
merge_single_node(node, id)
merge_single_node(node, None, id)
# filter out single node partitions
if not self.allows_single_node_partition:

View File

@ -96,7 +96,7 @@ def fuse_as_graphmodule(
gm: GraphModule,
nodes: NodeList,
module_name: str,
partition_lookup_table: _Optional[dict[Node, None]] = None,
partition_lookup_table: _Optional[dict[Node, _Optional[int]]] = None,
*,
always_return_tuple: bool = False,
) -> tuple[GraphModule, tuple[Node, ...], tuple[Node, ...]]:
@ -249,7 +249,7 @@ def erase_nodes(gm: GraphModule, nodes: NodeList) -> None:
@compatibility(is_backward_compatible=False)
def fuse_by_partitions(
gm: GraphModule,
partitions: list[dict[Node, None]],
partitions: list[dict[Node, _Optional[int]]],
prefix: str = "fused_",
always_return_tuple: bool = False,
) -> GraphModule: