Revert "refactor bucketing (#163754)"

This reverts commit e1bd5b60cf243d3a026a6c89733488a6d9d4b33d.

Reverted https://github.com/pytorch/pytorch/pull/163754 on behalf of https://github.com/yangw-dev because it appears to fail inductor/test_aten_comm_compute_reordering on the macOS test; see c9b5af9a38 (51526707590-box) ([comment](https://github.com/pytorch/pytorch/pull/163215#issuecomment-3349177940))
Author: PyTorch MergeBot
Date:   2025-09-29 21:53:42 +00:00
Parent: 84dc54ae5e
Commit: b28e4f1f87


@@ -18,22 +18,6 @@ logger: logging.Logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
# Helper functions moved to top for better organization
def _ag_group_key(node: torch.fx.Node) -> tuple[str, torch.dtype]:
_, group_size, group_name = node.args
dtype = node.meta["val"].dtype
assert isinstance(group_name, str)
return (group_name, dtype)
def _rs_group_key(node: torch.fx.Node) -> tuple[str, str, torch.dtype]:
_, reduce_op, group_size, group_name = node.args
dtype = node.meta["val"].dtype
assert isinstance(group_name, str)
assert isinstance(reduce_op, str)
return (group_name, reduce_op, dtype)
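These group keys are what make bucketing safe: two collectives may share a bucket only when every field of the key matches. A minimal standalone sketch of that partitioning (plain tuples stand in for fx nodes, so it runs without torch; names here are illustrative, not from the diff):

from collections import defaultdict

def partition(nodes, key_fn):
    """Group candidate collectives by their group key."""
    groups = defaultdict(list)
    for n in nodes:
        groups[key_fn(n)].append(n)
    return dict(groups)

# (group_name, reduce_op, dtype) triples, mirroring _rs_group_key
rs_nodes = [
    ("pg0", "sum", "bf16"),
    ("pg0", "sum", "bf16"),
    ("pg0", "avg", "bf16"),  # different reduce op -> its own group
    ("pg1", "sum", "bf16"),  # different process group -> its own group
]
print(partition(rs_nodes, key_fn=lambda n: n))  # three groups, first has two members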
def bucket_cap_mb_by_bucket_idx_default(bucket_id: int) -> float:
"""
Determine the size of a bucket based on its ID.
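The rest of this docstring is cut off at the hunk boundary. As a purely hypothetical illustration of a cap that depends on the bucket ID (made-up values, not the actual default):

def example_bucket_cap_mb_by_bucket_idx(bucket_idx: int) -> float:
    # Keep the first bucket small so its collective is issued early,
    # then switch to a larger steady-state cap (numbers are invented).
    return 25.0 if bucket_idx == 0 else 250.0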
@@ -245,6 +229,12 @@ def bucket_all_gather_by_mb(
list[list[torch.fx.Node]]: List of buckets, where each bucket is a list of all_gather nodes.
"""
def _ag_group_key(node: torch.fx.Node) -> tuple[str, torch.dtype]:
_, group_size, group_name = node.args
dtype = node.meta["val"].dtype
assert isinstance(group_name, str)
return (group_name, dtype)
return greedy_bucket_collective_by_mb(
gm,
bucket_cap_mb_by_bucket_idx,
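greedy_bucket_collective_by_mb itself is not shown in this diff; a hedged standalone sketch of the greedy idea it implements (within one group, accumulate nodes until the per-index MB cap would be exceeded, then start a new bucket):

def greedy_buckets(sizes_mb, cap_mb_for_idx):
    buckets, cur, cur_mb = [], [], 0.0
    for s in sizes_mb:
        cap = cap_mb_for_idx(len(buckets))
        if cur and cur_mb + s > cap:
            buckets.append(cur)  # cap hit: close the current bucket
            cur, cur_mb = [], 0.0
        cur.append(s)
        cur_mb += s
    if cur:
        buckets.append(cur)
    return buckets

print(greedy_buckets([10, 20, 90, 5, 5], lambda i: 100))  # [[10, 20], [90, 5, 5]]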
@@ -275,6 +265,13 @@ def bucket_reduce_scatter_by_mb(
list[list[torch.fx.Node]]: List of buckets, where each bucket is a list of reduce_scatter nodes.
"""
def _rs_group_key(node: torch.fx.Node) -> tuple[str, str, torch.dtype]:
_, reduce_op, group_size, group_name = node.args
dtype = node.meta["val"].dtype
assert isinstance(group_name, str)
assert isinstance(reduce_op, str)
return (group_name, reduce_op, dtype)
return greedy_bucket_collective_by_mb(
gm,
bucket_cap_mb_by_bucket_idx,
@@ -518,16 +515,13 @@ def _insert_fn_trace_before_node( # type: ignore[no-untyped-def]
insert_before_node: torch.fx.Node,
g_fn_inps: list[torch.fx.Node],
g_fn_outs: list[torch.fx.Node],
-) -> tuple[dict[torch.fx.Node, torch.fx.Node], list[torch.fx.Node]]: # type: ignore[no-untyped-def]
+) -> dict[torch.fx.Node, torch.fx.Node]: # type: ignore[no-untyped-def]
"""
Helper function that traces :attr:`fn_to_trace` with inputs
:attr:`inps`.
The result function graph will be inserted before :attr:`insert_before_node`,
-using :attr:`g_fn_inps` nodes of original graph as inputs of function graph,
+using :attr:`g_fn_inps` nodes of original graphas inputs of function graph,
function graph outputs will replace :attr:`g_fn_outs` in original graph.
Returns:
(replacements, new_nodes): Dictionary mapping old to new nodes, and list of all newly inserted nodes
"""
with dynamo_timed(
"fx.bucketing._insert_fn_trace_before_node", log_pt2_compile_event=True
@@ -540,8 +534,6 @@ def _insert_fn_trace_before_node( # type: ignore[no-untyped-def]
fn_g_ins = fn_g.find_nodes(op="placeholder")
env = {fn_g_ins[idx]: g_fn_inps[idx] for idx in range(len(g_fn_inps))}
g_fn_new_outs: list[torch.fx.Node] = []
new_nodes: list[torch.fx.Node] = [] # Track all newly inserted nodes
with g.inserting_before(insert_before_node):
for _n in fn_g.nodes:
if _n.op == "placeholder":
@@ -551,203 +543,14 @@ def _insert_fn_trace_before_node( # type: ignore[no-untyped-def]
if _n.op == "output":
g_fn_new_outs = _new_n.args[0] # type: ignore[assignment]
g.erase_node(_new_n)
-else:
-new_nodes.append(_new_n) # Track non-output nodes
replacements = { # noqa: C416
orig_out: new_out for orig_out, new_out in zip(g_fn_outs, g_fn_new_outs)
}
for orig_out, new_out in zip(g_fn_outs, g_fn_new_outs):
orig_out.replace_all_uses_with(new_out)
-return replacements, new_nodes
+return replacements
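The pattern this helper implements (trace a callable, copy its nodes in before an anchor, wire its placeholders to existing nodes, redirect uses of the old outputs) can be shown on a toy graph. A self-contained sketch, not the pass itself; all names below are illustrative:

import operator

import torch
import torch.fx as fx

def splice_before(g, fn, inps, insert_before, old_outs):
    """Trace fn and insert its graph before insert_before, rewiring I/O."""
    sub = fx.symbolic_trace(fn).graph
    env = dict(zip(sub.find_nodes(op="placeholder"), inps))
    new_outs = []
    with g.inserting_before(insert_before):
        for n in sub.nodes:
            if n.op == "placeholder":
                continue
            new_n = g.node_copy(n, lambda x: env[x])
            if n.op == "output":
                outs = new_n.args[0]
                new_outs = list(outs) if isinstance(outs, (list, tuple)) else [outs]
                g.erase_node(new_n)  # keep the graph's single original output
            else:
                env[n] = new_n
    replacements = dict(zip(old_outs, new_outs))
    for old, new in replacements.items():
        old.replace_all_uses_with(new)
    return replacements

def double(x):
    return x * 2

class M(torch.nn.Module):
    def forward(self, x):
        y = x + x
        return y.relu()

gm = fx.symbolic_trace(M())
add_n = next(n for n in gm.graph.nodes if n.target is operator.add)
splice_before(gm.graph, double, [add_n.args[0]], add_n, [add_n])
gm.graph.erase_node(add_n)  # the old node has no users left
gm.recompile()
assert torch.equal(gm(torch.ones(3)), torch.full((3,), 2.0))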
def process_collective_bucket(
g: torch.fx.Graph,
bucket_nodes: list[torch.fx.Node],
fn_to_trace: Callable[..., list[torch.Tensor]],
trace_args_fn: Callable[[list[torch.fx.Node]], tuple[Any, ...]],
insert_before: Optional[torch.fx.Node] = None,
wait_insertion_point: Optional[torch.fx.Node] = None,
) -> dict[torch.fx.Node, torch.fx.Node]:
"""
Process a single bucket of collective operation nodes with flexible insertion control.
Args:
g: The graph to modify
bucket_nodes: Nodes in the current bucket to process
fn_to_trace: Function to trace and insert
trace_args_fn: Function to create trace arguments from inputs
insert_before: Where to insert the traced function (default: after last bucket node)
wait_insertion_point: If provided, move all nodes from wait() onwards to before this node
Returns:
replacements: Dictionary mapping old wait nodes to new output nodes
"""
# Collect inputs and waits from current bucket
bucket_ins: list[torch.fx.Node] = []
bucket_waits: list[torch.fx.Node] = []
ag_node_to_pre_nodes: dict[torch.fx.Node, list[torch.fx.Node]] = defaultdict(list)
for n in bucket_nodes:
assert len(n.users) == 1, f"Expected single user for {n}, got {n.users}"
wait_n = next(iter(n.users))
# Handle convert_element_type operations (for all_gather)
node_in = n.args[0]
if (
is_all_gather_into_tensor(n)
and isinstance(node_in, torch.fx.Node) # Add type check
and node_in.op == "call_function"
and node_in.target == torch.ops.prims.convert_element_type.default
and len(node_in.users) == 1
):
ag_node_to_pre_nodes[n].append(node_in)
node_in = node_in.args[0]
assert isinstance(node_in, torch.fx.Node) # Ensure node_in is a Node
bucket_ins.append(node_in)
bucket_waits.append(wait_n)
# Create trace arguments
trace_args = trace_args_fn(bucket_ins)
# Determine insertion point
if insert_before is None:
insert_before = bucket_nodes[-1].next
# Insert traced function and get replacements + new nodes
replacements, new_nodes = _insert_fn_trace_before_node(
g,
fn_to_trace,
trace_args,
insert_before,
bucket_ins,
bucket_waits,
)
# If requested, move wait nodes and everything after to specified location
if wait_insertion_point is not None:
# Find the first wait node in new_nodes
wait_start_idx = None
for i, node in enumerate(new_nodes):
if is_wait_tensor(node):
wait_start_idx = i
break
# Move all nodes from wait onwards (including the wait)
if wait_start_idx is not None:
nodes_to_move = new_nodes[wait_start_idx:]
for node in nodes_to_move:
wait_insertion_point.prepend(node)
# Erase old nodes
for node, wait_n in zip(bucket_nodes, bucket_waits):
g.erase_node(wait_n)
g.erase_node(node)
# Erase any convert_element_type nodes we tracked
for pre_node in reversed(ag_node_to_pre_nodes[node]):
g.erase_node(pre_node)
return replacements
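The wait_insertion_point handling relies on fx.Node.prepend, which relocates an already-inserted node directly before an anchor node. A tiny standalone illustration (toy function, not from the diff):

import operator

import torch.fx as fx

def f(x):
    a = x + 1
    b = a * 2
    c = x - 3  # depends only on x, safe to hoist
    return b + c

g = fx.symbolic_trace(f).graph
sub_n = next(n for n in g.nodes if n.target is operator.sub)
first_add = next(n for n in g.nodes if n.target is operator.add)  # a = x + 1
first_add.prepend(sub_n)  # move c above a; order remains topological
g.lint()
print([n.name for n in g.nodes])  # x, sub, add, mul, add_1, output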
def merge_reduce_scatter_bucket(
g: torch.fx.Graph,
rs_nodes: list[torch.fx.Node],
mode: Optional[str] = None,
wait_insertion_point: Optional[torch.fx.Node] = None,
) -> None:
# Validate bucket consistency
rs0 = rs_nodes[0]
rs0_val = rs0.meta["val"]
_, reduce_op, group_size, group_name = rs0.args
reduce_dtype = rs0_val.dtype
device = rs0_val.device
for n in rs_nodes:
rs_val = n.meta["val"]
assert (
n.args[1] == reduce_op
and n.args[2] == group_size
and n.args[3] == group_name
and rs_val.device == device
and rs_val.dtype == reduce_dtype
)
# Choose merge function based on mode
rs_merge_fn = reduce_scatter_merge_fn_to_trace
if mode and "custom_ops" in mode:
rs_merge_fn = reduce_scatter_merge_fn_to_trace_custom_ops
# Process bucket with lazy input collection
def create_trace_args(bucket_ins: list[torch.fx.Node]) -> tuple[Any, ...]:
return (
pytree.tree_map(lambda node: node.meta["val"], bucket_ins),
group_size,
group_name,
reduce_op,
reduce_dtype,
device,
)
process_collective_bucket(
g,
rs_nodes,
rs_merge_fn,
create_trace_args,
wait_insertion_point=wait_insertion_point,
)
def merge_all_gather_bucket(
g: torch.fx.Graph,
ag_nodes: list[torch.fx.Node],
mode: Optional[str] = None,
wait_insertion_point: Optional[torch.fx.Node] = None,
) -> None:
from torch.distributed.distributed_c10d import _resolve_process_group
ag0 = ag_nodes[0]
ag0_val = ag0.meta["val"]
_, group_size, group_name = ag0.args
dtype = ag0_val.dtype
assert isinstance(group_name, str)
for n in ag_nodes:
assert (
n.args[1] == group_size
and n.args[2] == group_name
and n.meta["val"].dtype == dtype
)
# Choose merge function based on mode
ag_merge_fn = all_gather_merge_fn_to_trace
if mode and "custom_ops" in mode:
ag_merge_fn = all_gather_merge_fn_to_trace_custom_ops
# Process bucket with lazy input collection
rank: int = dist.get_rank(_resolve_process_group(group_name))
def create_trace_args(bucket_ins: list[torch.fx.Node]) -> tuple[Any, ...]:
return (
pytree.tree_map(lambda node: node.meta["val"], bucket_ins),
group_size,
group_name,
dtype,
rank,
)
process_collective_bucket(
g,
ag_nodes,
ag_merge_fn,
create_trace_args,
wait_insertion_point=wait_insertion_point,
)
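Both merge helpers build trace arguments from node metadata rather than real tensors: pytree.tree_map swaps each fx node for its meta["val"], so the merged function is traced with the right shapes and dtypes. A runnable sketch with hand-stashed metadata (normally populated by fake-tensor propagation):

import torch
import torch.fx as fx
import torch.utils._pytree as pytree

def f(x, y):
    return x + y

gm = fx.symbolic_trace(f)
placeholders = list(gm.graph.find_nodes(op="placeholder"))
for n in placeholders:
    n.meta["val"] = torch.empty(8, dtype=torch.bfloat16)  # stand-in metadata

vals = pytree.tree_map(lambda node: node.meta["val"], placeholders)
print([(v.shape, v.dtype) for v in vals])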
def merge_reduce_scatter(
gm: torch.fx.GraphModule,
rs_buckets: list[list[torch.fx.Node]],
@@ -757,6 +560,9 @@ def merge_reduce_scatter(
Merges specified buckets of reduce_scatter to joint reduce_scatter.
"""
with dynamo_timed("fx.bucketing.merge_reduce_scatter", log_pt2_compile_event=True):
rs_merge_fn = reduce_scatter_merge_fn_to_trace
if mode and "custom_ops" in mode:
rs_merge_fn = reduce_scatter_merge_fn_to_trace_custom_ops
trace_structured(
"artifact",
metadata_fn=lambda: {
@@ -765,22 +571,90 @@ def merge_reduce_scatter(
},
payload_fn=lambda: str(rs_buckets),
)
n_buckets = len(rs_buckets)
g = gm.graph
rs_ins: list[list[torch.fx.Node]] = [[] for _ in range(n_buckets)]
rs_waits: list[list[torch.fx.Node]] = [[] for _ in range(n_buckets)]
-for rs_nodes in rs_buckets:
-merge_reduce_scatter_bucket(g, rs_nodes, mode)
+for bucket_idx, rs_nodes in enumerate(rs_buckets):
rs0 = rs_nodes[0]
rs0_val = rs0.meta["val"]
_, reduce_op, group_size, group_name = rs0.args
reduce_dtype = rs0_val.dtype
device = rs0_val.device
for n in rs_nodes:
rs_val = n.meta["val"]
assert (
n.args[1] == reduce_op
and n.args[2] == group_size
and n.args[3] == group_name
and rs_val.device == device
and rs_val.dtype == reduce_dtype
)
assert len(n.users) == 1
wait_n = next(iter(n.users))
rs_ins[bucket_idx].append(n.args[0]) # type: ignore[arg-type]
rs_waits[bucket_idx].append(wait_n)
for bucket_idx in range(n_buckets):
_rs_ins = rs_ins[bucket_idx]
_rs_waits = rs_waits[bucket_idx]
_rs_ns = rs_buckets[bucket_idx]
rs0 = _rs_ns[0]
rs0_val = rs0.meta["val"]
_, reduce_op, group_size, group_name = rs0.args
reduce_dtype = rs0_val.dtype
device = rs0_val.device
replacements = _insert_fn_trace_before_node(
g,
rs_merge_fn,
(
pytree.tree_map(lambda node: node.meta["val"], _rs_ins),
group_size,
group_name,
reduce_op,
reduce_dtype,
device,
),
_rs_ns[-1].next,
_rs_ins,
_rs_waits,
)
# [Note: Replacement in bucketing passes]
# After bucketing, the nodes in _rs_waits are replaced by the output nodes
# of the fn_to_trace graph that was just inserted into graph g.
# By this point rs_ins and rs_waits have already been collected, so the
# rs_ins of subsequent buckets may still reference nodes from _rs_waits;
# we therefore map those references through `replacements`.
def _replace(x: torch.fx.Node) -> torch.fx.Node:
return replacements.get(x, x)
for j in range(bucket_idx + 1, n_buckets):
rs_ins[j] = pytree.tree_map(_replace, rs_ins[j])
for rs_n, wait_n in zip(_rs_ns, _rs_waits):
g.erase_node(wait_n)
g.erase_node(rs_n)
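The propagation described in the note above is just a dictionary lookup mapped over the inputs of not-yet-processed buckets, with identity for untouched nodes. Strings stand in for fx nodes in this sketch:

import torch.utils._pytree as pytree

replacements = {"wait_0": "merged_out_0"}  # old wait -> new bucketed output
rs_ins = [["wait_0", "grad_3"], ["grad_7"]]  # inputs of the remaining buckets
rs_ins = [pytree.tree_map(lambda x: replacements.get(x, x), b) for b in rs_ins]
assert rs_ins == [["merged_out_0", "grad_3"], ["grad_7"]]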
def merge_all_gather(
gm: torch.fx.GraphModule,
ag_buckets: list[list[torch.fx.Node]],
mode: Optional[str] = None,
-) -> None:
+) -> None: # type: ignore[union-attr]
"""
Merges specified buckets of all_gather to joint all_gather.
"""
with dynamo_timed("fx.bucketing.merge_all_gather", log_pt2_compile_event=True):
from torch.distributed.distributed_c10d import _resolve_process_group
ag_merge_fn = all_gather_merge_fn_to_trace
if mode and "custom_ops" in mode:
ag_merge_fn = all_gather_merge_fn_to_trace_custom_ops
trace_structured(
"artifact",
metadata_fn=lambda: {
@@ -789,8 +663,80 @@ def merge_all_gather(
},
payload_fn=lambda: str(ag_buckets),
)
n_buckets = len(ag_buckets)
ag_node_to_pre_nodes = defaultdict(list)
ag_ins: list[list[torch.fx.Node]] = [[] for _ in range(n_buckets)]
ag_waits: list[list[torch.fx.Node]] = [[] for _ in range(n_buckets)]
for bucket_idx, ag_bucket in enumerate(ag_buckets):
_, group_size, group_name = ag_bucket[0].args
assert isinstance(group_name, str)
dtype = ag_bucket[0].meta["val"].dtype
for ag_node in ag_bucket:
assert len(ag_node.users) == 1, (
f"Expect only one user for {ag_node}, but got {ag_node.users}"
)
wait_node = next(iter(ag_node.users))
assert (
ag_node.args[1] == group_size
and ag_node.args[2] == group_name
and ag_node.meta["val"].dtype == dtype
)
ag_node_in = ag_node.args[0]
if (
ag_node_in.op == "call_function" # type: ignore[union-attr]
and ag_node_in.target # type: ignore[union-attr]
== torch.ops.prims.convert_element_type.default # type: ignore[union-attr]
and len(ag_node_in.users) == 1 # type: ignore[union-attr]
):
ag_node_to_pre_nodes[ag_node].append(ag_node_in)
ag_node_in = ag_node_in.args[0] # type: ignore[union-attr]
ag_ins[bucket_idx].append(ag_node_in) # type: ignore[union-attr, arg-type]
ag_waits[bucket_idx].append(wait_node)
g = gm.graph
-for ag_nodes in ag_buckets:
-merge_all_gather_bucket(g, ag_nodes, mode)
+for bucket_idx in range(n_buckets):
_ag_ins = ag_ins[bucket_idx]
_ag_waits = ag_waits[bucket_idx]
_ag_ns = ag_buckets[bucket_idx]
ag0 = _ag_ns[0]
ag0_val = ag0.meta["val"]
_, group_size, group_name = ag0.args
dtype = ag0_val.dtype
assert isinstance(group_name, str)
rank: int = dist.get_rank(_resolve_process_group(group_name))
replacements = _insert_fn_trace_before_node(
g,
ag_merge_fn,
(
pytree.tree_map(lambda node: node.meta["val"], _ag_ins),
group_size,
group_name,
dtype,
rank,
),
ag0.next,
_ag_ins,
_ag_waits,
)
# See Note: [Replacement in bucketing passes]
def _replace(x: torch.fx.Node) -> torch.fx.Node:
return replacements.get(x, x)
for j in range(bucket_idx + 1, n_buckets):
ag_ins[j] = pytree.tree_map(_replace, ag_ins[j])
# Erasing old nodes in reverse order
for ag_n, wait_n in zip(ag_buckets[bucket_idx], _ag_waits):
g.erase_node(wait_n)
g.erase_node(ag_n)
for n in reversed(ag_node_to_pre_nodes[ag_n]):
g.erase_node(n) # type: ignore[arg-type]
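Taken together, the pass composes as bucket-then-merge. A hedged sketch of that composition (function names are from this diff; the module path and exact signatures are assumptions):

import torch

def bucket_and_merge_all_gathers(gm: torch.fx.GraphModule) -> None:
    # Module path assumed from context; signatures as they appear above.
    from torch._inductor.fx_passes.bucketing import (
        bucket_all_gather_by_mb,
        bucket_cap_mb_by_bucket_idx_default,
        merge_all_gather,
    )
    ag_buckets = bucket_all_gather_by_mb(gm, bucket_cap_mb_by_bucket_idx_default)
    merge_all_gather(gm, ag_buckets)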