mirror of https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
refactor bucketing (#163754)
Preparatory refactor.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163754
Approved by: https://github.com/IvanKobzarev
ghstack dependencies: #163215
committed by PyTorch MergeBot
parent c9b5af9a38
commit e1bd5b60cf
@@ -18,6 +18,22 @@ logger: logging.Logger = logging.getLogger(__name__)
 logger.setLevel(logging.INFO)
 
 
+# Helper functions moved to top for better organization
+def _ag_group_key(node: torch.fx.Node) -> tuple[str, torch.dtype]:
+    _, group_size, group_name = node.args
+    dtype = node.meta["val"].dtype
+    assert isinstance(group_name, str)
+    return (group_name, dtype)
+
+
+def _rs_group_key(node: torch.fx.Node) -> tuple[str, str, torch.dtype]:
+    _, reduce_op, group_size, group_name = node.args
+    dtype = node.meta["val"].dtype
+    assert isinstance(group_name, str)
+    assert isinstance(reduce_op, str)
+    return (group_name, reduce_op, dtype)
+
+
 def bucket_cap_mb_by_bucket_idx_default(bucket_id: int) -> float:
     """
     Determine the size of a bucket based on its ID.
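The two hoisted helpers decide which collectives may share a bucket: all_gathers are keyed by (process group name, dtype), reduce_scatters additionally by reduce op. For context, a hedged sketch of how such key functions are typically consumed; the toy group_by_key below is illustrative and is not the actual greedy_bucket_collective_by_mb implementation, which also applies a per-bucket size cap:

from collections import defaultdict
from typing import Callable, Hashable

import torch

def group_by_key(
    nodes: list[torch.fx.Node],
    key_fn: Callable[[torch.fx.Node], Hashable],
) -> list[list[torch.fx.Node]]:
    # Only collectives with an identical key (same group/dtype/op) can be
    # merged into one joint collective.
    groups: dict[Hashable, list[torch.fx.Node]] = defaultdict(list)
    for n in nodes:
        groups[key_fn(n)].append(n)
    return list(groups.values())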
@@ -229,12 +245,6 @@ def bucket_all_gather_by_mb(
         list[list[torch.fx.Node]]: List of buckets, where each bucket is a list of all_gather nodes.
     """
 
-    def _ag_group_key(node: torch.fx.Node) -> tuple[str, torch.dtype]:
-        _, group_size, group_name = node.args
-        dtype = node.meta["val"].dtype
-        assert isinstance(group_name, str)
-        return (group_name, dtype)
-
     return greedy_bucket_collective_by_mb(
         gm,
         bucket_cap_mb_by_bucket_idx,
@@ -265,13 +275,6 @@ def bucket_reduce_scatter_by_mb(
         list[list[torch.fx.Node]]: List of buckets, where each bucket is a list of reduce_scatter nodes.
     """
 
-    def _rs_group_key(node: torch.fx.Node) -> tuple[str, str, torch.dtype]:
-        _, reduce_op, group_size, group_name = node.args
-        dtype = node.meta["val"].dtype
-        assert isinstance(group_name, str)
-        assert isinstance(reduce_op, str)
-        return (group_name, reduce_op, dtype)
-
     return greedy_bucket_collective_by_mb(
         gm,
         bucket_cap_mb_by_bucket_idx,
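With the nested duplicates removed, both entry points delegate to greedy_bucket_collective_by_mb with a bucket_cap_mb_by_bucket_idx callback mapping a bucket index to its size cap in MB (see bucket_cap_mb_by_bucket_idx_default above). A hedged sketch of a custom policy a caller might pass instead; the name and the numbers are illustrative only:

def small_first_bucket_cap(bucket_idx: int) -> float:
    # Assumption: a smaller first bucket lets the first joint collective
    # be issued earlier, improving compute/communication overlap.
    return 25.0 if bucket_idx == 0 else 100.0

# usage sketch: buckets = bucket_all_gather_by_mb(gm, small_first_bucket_cap)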
@@ -515,13 +518,16 @@ def _insert_fn_trace_before_node(  # type: ignore[no-untyped-def]
     insert_before_node: torch.fx.Node,
     g_fn_inps: list[torch.fx.Node],
     g_fn_outs: list[torch.fx.Node],
-) -> dict[torch.fx.Node, torch.fx.Node]:  # type: ignore[no-untyped-def]
+) -> tuple[dict[torch.fx.Node, torch.fx.Node], list[torch.fx.Node]]:  # type: ignore[no-untyped-def]
     """
     Helper function that traces :attr:`fn_to_trace` with inputs
     :attr:`inps`.
     The result function graph will be inserted before :attr:`insert_before_node`,
-    using :attr:`g_fn_inps` nodes of original graphas inputs of function graph,
+    using :attr:`g_fn_inps` nodes of original graph as inputs of function graph,
     function graph outputs will replace :attr:`g_fn_outs` in original graph.
+
+    Returns:
+        (replacements, new_nodes): Dictionary mapping old to new nodes, and list of all newly inserted nodes
     """
     with dynamo_timed(
         "fx.bucketing._insert_fn_trace_before_node", log_pt2_compile_event=True
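The trace-and-splice mechanism this helper implements can be shown in isolation: trace a callable to an FX graph, then copy its nodes into the target graph with an env dict that maps the traced placeholders onto existing nodes. A minimal self-contained sketch; the scale and splice_before names are hypothetical and not part of this file:

import torch
import torch.fx as fx

def scale(x: torch.Tensor) -> torch.Tensor:
    return x * 2.0

def splice_before(g: fx.Graph, insert_before: fx.Node, inp: fx.Node) -> fx.Node:
    fn_g = fx.symbolic_trace(scale).graph
    # map the traced placeholder onto an existing node of g
    env: dict[fx.Node, fx.Node] = {p: inp for p in fn_g.find_nodes(op="placeholder")}
    out = inp
    with g.inserting_before(insert_before):
        for n in fn_g.nodes:
            if n.op == "placeholder":
                continue
            if n.op == "output":
                out = env[n.args[0]]  # the value the traced graph returns
            else:
                env[n] = g.node_copy(n, lambda x: env[x])
    return out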
@@ -534,6 +540,8 @@ def _insert_fn_trace_before_node(  # type: ignore[no-untyped-def]
         fn_g_ins = fn_g.find_nodes(op="placeholder")
         env = {fn_g_ins[idx]: g_fn_inps[idx] for idx in range(len(g_fn_inps))}
         g_fn_new_outs: list[torch.fx.Node] = []
+        new_nodes: list[torch.fx.Node] = []  # Track all newly inserted nodes
+
         with g.inserting_before(insert_before_node):
             for _n in fn_g.nodes:
                 if _n.op == "placeholder":
@@ -543,12 +551,201 @@ def _insert_fn_trace_before_node(  # type: ignore[no-untyped-def]
                 if _n.op == "output":
                     g_fn_new_outs = _new_n.args[0]  # type: ignore[assignment]
                     g.erase_node(_new_n)
+                else:
+                    new_nodes.append(_new_n)  # Track non-output nodes
+
         replacements = {  # noqa: C416
             orig_out: new_out for orig_out, new_out in zip(g_fn_outs, g_fn_new_outs)
         }
         for orig_out, new_out in zip(g_fn_outs, g_fn_new_outs):
             orig_out.replace_all_uses_with(new_out)
-        return replacements
+
+        return replacements, new_nodes
 
 
+def process_collective_bucket(
+    g: torch.fx.Graph,
+    bucket_nodes: list[torch.fx.Node],
+    fn_to_trace: Callable[..., list[torch.Tensor]],
+    trace_args_fn: Callable[[list[torch.fx.Node]], tuple[Any, ...]],
+    insert_before: Optional[torch.fx.Node] = None,
+    wait_insertion_point: Optional[torch.fx.Node] = None,
+) -> dict[torch.fx.Node, torch.fx.Node]:
+    """
+    Process a single bucket of collective operation nodes with flexible insertion control.
+
+    Args:
+        g: The graph to modify
+        bucket_nodes: Nodes in the current bucket to process
+        fn_to_trace: Function to trace and insert
+        trace_args_fn: Function to create trace arguments from inputs
+        insert_before: Where to insert the traced function (default: after last bucket node)
+        wait_insertion_point: If provided, move all nodes from wait() onwards to before this node
+
+    Returns:
+        replacements: Dictionary mapping old wait nodes to new output nodes
+    """
+    # Collect inputs and waits from current bucket
+    bucket_ins: list[torch.fx.Node] = []
+    bucket_waits: list[torch.fx.Node] = []
+    ag_node_to_pre_nodes: dict[torch.fx.Node, list[torch.fx.Node]] = defaultdict(list)
+
+    for n in bucket_nodes:
+        assert len(n.users) == 1, f"Expected single user for {n}, got {n.users}"
+        wait_n = next(iter(n.users))
+
+        # Handle convert_element_type operations (for all_gather)
+        node_in = n.args[0]
+        if (
+            is_all_gather_into_tensor(n)
+            and isinstance(node_in, torch.fx.Node)  # Add type check
+            and node_in.op == "call_function"
+            and node_in.target == torch.ops.prims.convert_element_type.default
+            and len(node_in.users) == 1
+        ):
+            ag_node_to_pre_nodes[n].append(node_in)
+            node_in = node_in.args[0]
+
+        assert isinstance(node_in, torch.fx.Node)  # Ensure node_in is a Node
+        bucket_ins.append(node_in)
+        bucket_waits.append(wait_n)
+
+    # Create trace arguments
+    trace_args = trace_args_fn(bucket_ins)
+
+    # Determine insertion point
+    if insert_before is None:
+        insert_before = bucket_nodes[-1].next
+
+    # Insert traced function and get replacements + new nodes
+    replacements, new_nodes = _insert_fn_trace_before_node(
+        g,
+        fn_to_trace,
+        trace_args,
+        insert_before,
+        bucket_ins,
+        bucket_waits,
+    )
+
+    # If requested, move wait nodes and everything after to specified location
+    if wait_insertion_point is not None:
+        # Find the first wait node in new_nodes
+        wait_start_idx = None
+        for i, node in enumerate(new_nodes):
+            if is_wait_tensor(node):
+                wait_start_idx = i
+                break
+
+        # Move all nodes from wait onwards (including the wait)
+        if wait_start_idx is not None:
+            nodes_to_move = new_nodes[wait_start_idx:]
+            for node in nodes_to_move:
+                wait_insertion_point.prepend(node)
+
+    # Erase old nodes
+    for node, wait_n in zip(bucket_nodes, bucket_waits):
+        g.erase_node(wait_n)
+        g.erase_node(node)
+        # Erase any convert_element_type nodes we tracked
+        for pre_node in reversed(ag_node_to_pre_nodes[node]):
+            g.erase_node(pre_node)
+
+    return replacements
+
+
+def merge_reduce_scatter_bucket(
+    g: torch.fx.Graph,
+    rs_nodes: list[torch.fx.Node],
+    mode: Optional[str] = None,
+    wait_insertion_point: Optional[torch.fx.Node] = None,
+) -> None:
+    # Validate bucket consistency
+    rs0 = rs_nodes[0]
+    rs0_val = rs0.meta["val"]
+    _, reduce_op, group_size, group_name = rs0.args
+    reduce_dtype = rs0_val.dtype
+    device = rs0_val.device
+
+    for n in rs_nodes:
+        rs_val = n.meta["val"]
+        assert (
+            n.args[1] == reduce_op
+            and n.args[2] == group_size
+            and n.args[3] == group_name
+            and rs_val.device == device
+            and rs_val.dtype == reduce_dtype
+        )
+
+    # Choose merge function based on mode
+    rs_merge_fn = reduce_scatter_merge_fn_to_trace
+    if mode and "custom_ops" in mode:
+        rs_merge_fn = reduce_scatter_merge_fn_to_trace_custom_ops
+
+    # Process bucket with lazy input collection
+    def create_trace_args(bucket_ins: list[torch.fx.Node]) -> tuple[Any, ...]:
+        return (
+            pytree.tree_map(lambda node: node.meta["val"], bucket_ins),
+            group_size,
+            group_name,
+            reduce_op,
+            reduce_dtype,
+            device,
+        )
+
+    process_collective_bucket(
+        g,
+        rs_nodes,
+        rs_merge_fn,
+        create_trace_args,
+        wait_insertion_point=wait_insertion_point,
+    )
+
+
+def merge_all_gather_bucket(
+    g: torch.fx.Graph,
+    ag_nodes: list[torch.fx.Node],
+    mode: Optional[str] = None,
+    wait_insertion_point: Optional[torch.fx.Node] = None,
+) -> None:
+    from torch.distributed.distributed_c10d import _resolve_process_group
+
+    ag0 = ag_nodes[0]
+    ag0_val = ag0.meta["val"]
+    _, group_size, group_name = ag0.args
+    dtype = ag0_val.dtype
+    assert isinstance(group_name, str)
+
+    for n in ag_nodes:
+        assert (
+            n.args[1] == group_size
+            and n.args[2] == group_name
+            and n.meta["val"].dtype == dtype
+        )
+
+    # Choose merge function based on mode
+    ag_merge_fn = all_gather_merge_fn_to_trace
+    if mode and "custom_ops" in mode:
+        ag_merge_fn = all_gather_merge_fn_to_trace_custom_ops
+
+    # Process bucket with lazy input collection
+    rank: int = dist.get_rank(_resolve_process_group(group_name))
+
+    def create_trace_args(bucket_ins: list[torch.fx.Node]) -> tuple[Any, ...]:
+        return (
+            pytree.tree_map(lambda node: node.meta["val"], bucket_ins),
+            group_size,
+            group_name,
+            dtype,
+            rank,
+        )
+
+    process_collective_bucket(
+        g,
+        ag_nodes,
+        ag_merge_fn,
+        create_trace_args,
+        wait_insertion_point=wait_insertion_point,
+    )
+
+
 def merge_reduce_scatter(
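Two graph-surgery details in process_collective_bucket are easy to miss: Node.prepend(n) moves an already-inserted node so it sits immediately before self (used above to hoist the wait and everything after it), and erasure must proceed users-first (wait before collective, collective before its convert_element_type prologue), since Graph.erase_node raises if the node still has users. A toy, self-contained illustration on a plain FX graph; the names here are hypothetical:

import operator

import torch
import torch.fx as fx

def f(x: torch.Tensor) -> torch.Tensor:
    a = x + 1
    return a * 2

gm = fx.symbolic_trace(f)
g = gm.graph
add_n = next(n for n in g.nodes if n.op == "call_function" and n.target is operator.add)
mul_n = next(iter(add_n.users))
out_n = next(n for n in g.nodes if n.op == "output")

# Node.prepend moves an existing node to sit immediately before `self`
# (mul_n already precedes the output here, so this is a no-op)
out_n.prepend(mul_n)

# users-first removal: rewire the user before erasing the producer
add_n.replace_all_uses_with(add_n.args[0])  # 'a * 2' now reads 'x * 2'
g.erase_node(add_n)  # legal only because add_n no longer has users
g.lint()
gm.recompile()
assert torch.equal(gm(torch.ones(2)), torch.ones(2) * 2)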
@@ -560,9 +757,6 @@ def merge_reduce_scatter(
     Merges specified buckets of reduce_scatter to joint reduce_scatter.
     """
     with dynamo_timed("fx.bucketing.merge_reduce_scatter", log_pt2_compile_event=True):
-        rs_merge_fn = reduce_scatter_merge_fn_to_trace
-        if mode and "custom_ops" in mode:
-            rs_merge_fn = reduce_scatter_merge_fn_to_trace_custom_ops
         trace_structured(
             "artifact",
             metadata_fn=lambda: {
@@ -571,90 +765,22 @@ def merge_reduce_scatter(
             },
             payload_fn=lambda: str(rs_buckets),
         )
-        n_buckets = len(rs_buckets)
 
         g = gm.graph
-        rs_ins: list[list[torch.fx.Node]] = [[] for _ in range(n_buckets)]
-        rs_waits: list[list[torch.fx.Node]] = [[] for _ in range(n_buckets)]
-
-        for bucket_idx, rs_nodes in enumerate(rs_buckets):
-            rs0 = rs_nodes[0]
-            rs0_val = rs0.meta["val"]
-            _, reduce_op, group_size, group_name = rs0.args
-            reduce_dtype = rs0_val.dtype
-            device = rs0_val.device
-            for n in rs_nodes:
-                rs_val = n.meta["val"]
-                assert (
-                    n.args[1] == reduce_op
-                    and n.args[2] == group_size
-                    and n.args[3] == group_name
-                    and rs_val.device == device
-                    and rs_val.dtype == reduce_dtype
-                )
-                assert len(n.users) == 1
-                wait_n = next(iter(n.users))
-                rs_ins[bucket_idx].append(n.args[0])  # type: ignore[arg-type]
-                rs_waits[bucket_idx].append(wait_n)
-
-        for bucket_idx in range(n_buckets):
-            _rs_ins = rs_ins[bucket_idx]
-            _rs_waits = rs_waits[bucket_idx]
-            _rs_ns = rs_buckets[bucket_idx]
-
-            rs0 = _rs_ns[0]
-            rs0_val = rs0.meta["val"]
-            _, reduce_op, group_size, group_name = rs0.args
-            reduce_dtype = rs0_val.dtype
-            device = rs0_val.device
-
-            replacements = _insert_fn_trace_before_node(
-                g,
-                rs_merge_fn,
-                (
-                    pytree.tree_map(lambda node: node.meta["val"], _rs_ins),
-                    group_size,
-                    group_name,
-                    reduce_op,
-                    reduce_dtype,
-                    device,
-                ),
-                _rs_ns[-1].next,
-                _rs_ins,
-                _rs_waits,
-            )
-            # [Note: Replacement in bucketing passes]
-            # After bucketing _rs_waits will be replaced with output nodes of
-            # fn_to_trace graph that will be inserted in the graph g.
-            # By this time we already prepared rs_ins, rs_waits.
-            # rs_ins for following buckets can be replaced _rs_waits with new nodes.
-            # We apply replacements to rs_ins.
-
-            def _replace(x: torch.fx.Node) -> torch.fx.Node:
-                return replacements.get(x, x)
-
-            for j in range(bucket_idx + 1, n_buckets):
-                rs_ins[j] = pytree.tree_map(_replace, rs_ins[j])
-
-            for rs_n, wait_n in zip(_rs_ns, _rs_waits):
-                g.erase_node(wait_n)
-                g.erase_node(rs_n)
+        for rs_nodes in rs_buckets:
+            merge_reduce_scatter_bucket(g, rs_nodes, mode)
 
 
 def merge_all_gather(
     gm: torch.fx.GraphModule,
     ag_buckets: list[list[torch.fx.Node]],
     mode: Optional[str] = None,
-) -> None:  # type: ignore[union-attr]
+) -> None:
     """
     Merges specified buckets of all_gather to joint all_gather.
     """
     with dynamo_timed("fx.bucketing.merge_all_gather", log_pt2_compile_event=True):
-        from torch.distributed.distributed_c10d import _resolve_process_group
-
-        ag_merge_fn = all_gather_merge_fn_to_trace
-        if mode and "custom_ops" in mode:
-            ag_merge_fn = all_gather_merge_fn_to_trace_custom_ops
-
         trace_structured(
             "artifact",
             metadata_fn=lambda: {
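Why the explicit cross-bucket replacement step (the removed _replace/pytree.tree_map propagation) is no longer needed: process_collective_bucket collects each bucket's inputs lazily, only when that bucket is processed, by which point earlier buckets' wait nodes have already been rewritten via replace_all_uses_with. Schematically, with a hypothetical helper name:

def merge_all_buckets(g, buckets, merge_one_bucket):
    for bucket in buckets:
        # merge_one_bucket reads this bucket's inputs from g only now,
        # after earlier iterations have already replaced their wait nodes,
        # so updated inputs are observed without an explicit replacement map
        merge_one_bucket(g, bucket)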
@@ -663,80 +789,8 @@ def merge_all_gather(
             },
             payload_fn=lambda: str(ag_buckets),
         )
-        n_buckets = len(ag_buckets)
-
-        ag_node_to_pre_nodes = defaultdict(list)
-
-        ag_ins: list[list[torch.fx.Node]] = [[] for _ in range(n_buckets)]
-        ag_waits: list[list[torch.fx.Node]] = [[] for _ in range(n_buckets)]
-        for bucket_idx, ag_bucket in enumerate(ag_buckets):
-            _, group_size, group_name = ag_bucket[0].args
-            assert isinstance(group_name, str)
-            dtype = ag_bucket[0].meta["val"].dtype
-
-            for ag_node in ag_bucket:
-                assert len(ag_node.users) == 1, (
-                    f"Expect only one user for {ag_node}, but got {ag_node.users}"
-                )
-                wait_node = next(iter(ag_node.users))
-                assert (
-                    ag_node.args[1] == group_size
-                    and ag_node.args[2] == group_name
-                    and ag_node.meta["val"].dtype == dtype
-                )
-                ag_node_in = ag_node.args[0]
-                if (
-                    ag_node_in.op == "call_function"  # type: ignore[union-attr]
-                    and ag_node_in.target  # type: ignore[union-attr]
-                    == torch.ops.prims.convert_element_type.default
-                    and len(ag_node_in.users) == 1  # type: ignore[union-attr]
-                ):
-                    ag_node_to_pre_nodes[ag_node].append(ag_node_in)
-                    ag_node_in = ag_node_in.args[0]  # type: ignore[union-attr]
-
-                ag_ins[bucket_idx].append(ag_node_in)  # type: ignore[union-attr, arg-type]
-                ag_waits[bucket_idx].append(wait_node)
-
         g = gm.graph
 
-        for bucket_idx in range(n_buckets):
-            _ag_ins = ag_ins[bucket_idx]
-            _ag_waits = ag_waits[bucket_idx]
-            _ag_ns = ag_buckets[bucket_idx]
-
-            ag0 = _ag_ns[0]
-            ag0_val = ag0.meta["val"]
-            _, group_size, group_name = ag0.args
-            dtype = ag0_val.dtype
-            assert isinstance(group_name, str)
-
-            rank: int = dist.get_rank(_resolve_process_group(group_name))
-
-            replacements = _insert_fn_trace_before_node(
-                g,
-                ag_merge_fn,
-                (
-                    pytree.tree_map(lambda node: node.meta["val"], _ag_ins),
-                    group_size,
-                    group_name,
-                    dtype,
-                    rank,
-                ),
-                ag0.next,
-                _ag_ins,
-                _ag_waits,
-            )
-
-            # See Note: [Replacement in bucketing passes]
-            def _replace(x: torch.fx.Node) -> torch.fx.Node:
-                return replacements.get(x, x)
-
-            for j in range(bucket_idx + 1, n_buckets):
-                ag_ins[j] = pytree.tree_map(_replace, ag_ins[j])
-
-            # Erasing old nodes in reverse order
-            for ag_n, wait_n in zip(ag_buckets[bucket_idx], _ag_waits):
-                g.erase_node(wait_n)
-                g.erase_node(ag_n)
-                for n in reversed(ag_node_to_pre_nodes[ag_n]):
-                    g.erase_node(n)  # type: ignore[arg-type]
+        for ag_nodes in ag_buckets:
+            merge_all_gather_bucket(g, ag_nodes, mode)
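Taken together, a hedged end-to-end usage sketch of this module's entry points; the pairing of the bucketing and merging passes below is an assumption, while the names and signatures are the ones visible in this diff:

# bucket by size cap, then merge each bucket into one joint collective
ag_buckets = bucket_all_gather_by_mb(gm, bucket_cap_mb_by_bucket_idx_default)
merge_all_gather(gm, ag_buckets, mode=None)

rs_buckets = bucket_reduce_scatter_by_mb(gm, bucket_cap_mb_by_bucket_idx_default)
merge_reduce_scatter(gm, rs_buckets, mode=None)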