mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Partitioner: Fix to align partition node order with original graph (#157892)
Fixes #157891 Pull Request resolved: https://github.com/pytorch/pytorch/pull/157892 Approved by: https://github.com/ezyang
This commit is contained in:
committed by
PyTorch MergeBot
parent
40c4d61f9a
commit
2507ae63f2
@ -24,6 +24,7 @@ class DummyPartitioner(CapabilityBasedPartitioner):
|
||||
)
|
||||
|
||||
|
||||
# original graph node order is: ['x', 'add', 'add_1', 'output']
|
||||
class AddModule(torch.nn.Module):
|
||||
def forward(self, x):
|
||||
y = torch.add(x, x)
|
||||
@ -32,8 +33,18 @@ class AddModule(torch.nn.Module):
|
||||
|
||||
|
||||
class TestPartitionerOrder(TestCase):
|
||||
# partitoner test to check graph node order
|
||||
def test_partitioner_order(self):
|
||||
# partitoner test to check graph node order remains the same with the original graph after partitioning
|
||||
def test_partitioner_graph_node_order(self):
|
||||
m = AddModule()
|
||||
traced_m = torch.fx.symbolic_trace(m)
|
||||
origin_node_order = [n.name for n in traced_m.graph.nodes]
|
||||
partions = DummyPartitioner(traced_m).propose_partitions()
|
||||
partion_nodes = [list(partition.nodes) for partition in partions]
|
||||
partition_node_order = [n.name for n in partion_nodes[0]]
|
||||
self.assertTrue(partition_node_order == origin_node_order)
|
||||
|
||||
# partitoner test to check graph node order remains the same during multiple runs
|
||||
def test_partitioner_multiple_runs_order(self):
|
||||
m = AddModule()
|
||||
traced_m = torch.fx.symbolic_trace(m)
|
||||
partitions = DummyPartitioner(traced_m).propose_partitions()
|
||||
|
@ -18,16 +18,29 @@ logger.setLevel(logging.WARNING)
|
||||
|
||||
class Partition:
|
||||
def __init__(
|
||||
self, id: Optional[int] = None, nodes: Optional[Iterable[Node]] = None
|
||||
self,
|
||||
id: Optional[int] = None,
|
||||
nodes: Optional[Iterable[Node]] = None,
|
||||
node_orders: Optional[Iterable[int]] = None,
|
||||
):
|
||||
self.id = id
|
||||
self.nodes = dict.fromkeys(nodes) if nodes is not None else {}
|
||||
self.nodes: dict[Node, Optional[int]] = {}
|
||||
if nodes is not None:
|
||||
if node_orders is None:
|
||||
self.nodes = dict.fromkeys(nodes, None)
|
||||
else:
|
||||
nodes_list = list(nodes)
|
||||
node_orders_list = list(node_orders)
|
||||
assert len(nodes_list) == len(node_orders_list), (
|
||||
"nodes and node_orders must have the same length"
|
||||
)
|
||||
self.nodes = dict(zip(nodes_list, node_orders_list))
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return str(self.nodes)
|
||||
|
||||
def add_node(self, node: Node):
|
||||
self.nodes.update({node: None})
|
||||
def add_node(self, node: Node, node_order: Optional[int] = None):
|
||||
self.nodes.update({node: node_order})
|
||||
|
||||
def remove_node(self, node: Node):
|
||||
del self.nodes[node]
|
||||
@ -172,7 +185,7 @@ class CapabilityBasedPartitioner:
|
||||
|
||||
return merge_id, True
|
||||
|
||||
def merge_single_node(node: Node, id: Optional[int]):
|
||||
def merge_single_node(node: Node, node_order: Optional[int], id: Optional[int]):
|
||||
def _update_partition_map(node: Node, id: int):
|
||||
# Iterate through all the users of this node and update the partition map to indicate
|
||||
# that there is a path from the partition id of this node to the target partition id.
|
||||
@ -189,16 +202,19 @@ class CapabilityBasedPartitioner:
|
||||
assignment.pop(node)
|
||||
elif id not in partitions_by_id:
|
||||
assignment[node] = id
|
||||
partitions_by_id[id] = Partition(id=id, nodes=[node])
|
||||
assert node_order is not None
|
||||
partitions_by_id[id] = Partition(
|
||||
id=id, nodes=[node], node_orders=[node_order]
|
||||
)
|
||||
partition_users[id] = set(node.users)
|
||||
_update_partition_map(node, id)
|
||||
else:
|
||||
assignment[node] = id
|
||||
partitions_by_id[id].add_node(node)
|
||||
partitions_by_id[id].add_node(node, node_order)
|
||||
|
||||
logger.debug("Proposing partitions...")
|
||||
|
||||
for node in reversed(self.graph_module.graph.nodes):
|
||||
for node_order, node in enumerate(reversed(self.graph_module.graph.nodes)):
|
||||
# use Dict as an ordered set to ensure deterministic partitioning result, don't care value
|
||||
merge_candidates: dict[int, None] = {}
|
||||
|
||||
@ -211,7 +227,7 @@ class CapabilityBasedPartitioner:
|
||||
partition_id = next(new_partition_id)
|
||||
nodes_order[node] = partition_id
|
||||
partitions_order[partition_id] = partition_id
|
||||
merge_single_node(node, partition_id)
|
||||
merge_single_node(node, node_order, partition_id)
|
||||
merge_candidates[partition_id] = None
|
||||
|
||||
# merge all possible partitions
|
||||
@ -228,6 +244,14 @@ class CapabilityBasedPartitioner:
|
||||
# in the graph, otherwise, this is a no-op
|
||||
self_id, _ = maybe_merge_partition(self_id, other_id)
|
||||
|
||||
# sort partition nodes based on descending node order
|
||||
for partition in partitions_by_id.values():
|
||||
partition.nodes = dict(
|
||||
sorted(
|
||||
partition.nodes.items(), key=operator.itemgetter(1), reverse=True
|
||||
)
|
||||
)
|
||||
|
||||
# post processing to re-assign "getitem" nodes into upstream partition
|
||||
logger.debug("Reassigning getitem nodes to its producer node's partition...")
|
||||
nodes_reassignment: dict[Node, int] = {}
|
||||
@ -248,7 +272,7 @@ class CapabilityBasedPartitioner:
|
||||
if assignment.get(user, None) != id: # type: ignore[arg-type]
|
||||
nodes_reassignment[user] = id # type: ignore[assignment]
|
||||
for node, id in nodes_reassignment.items():
|
||||
merge_single_node(node, id)
|
||||
merge_single_node(node, None, id)
|
||||
|
||||
# filter out single node partitions
|
||||
if not self.allows_single_node_partition:
|
||||
|
@ -96,7 +96,7 @@ def fuse_as_graphmodule(
|
||||
gm: GraphModule,
|
||||
nodes: NodeList,
|
||||
module_name: str,
|
||||
partition_lookup_table: _Optional[dict[Node, None]] = None,
|
||||
partition_lookup_table: _Optional[dict[Node, _Optional[int]]] = None,
|
||||
*,
|
||||
always_return_tuple: bool = False,
|
||||
) -> tuple[GraphModule, tuple[Node, ...], tuple[Node, ...]]:
|
||||
@ -249,7 +249,7 @@ def erase_nodes(gm: GraphModule, nodes: NodeList) -> None:
|
||||
@compatibility(is_backward_compatible=False)
|
||||
def fuse_by_partitions(
|
||||
gm: GraphModule,
|
||||
partitions: list[dict[Node, None]],
|
||||
partitions: list[dict[Node, _Optional[int]]],
|
||||
prefix: str = "fused_",
|
||||
always_return_tuple: bool = False,
|
||||
) -> GraphModule:
|
||||
|
Reference in New Issue
Block a user