[bucketing] allow convert_element_type after fsdp reduce_scatter (#161159)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/161159
Approved by: https://github.com/eellison
This commit is contained in:
IvanKobzarev
2025-08-21 04:22:11 -07:00
committed by PyTorch MergeBot
parent c4670e40c9
commit 595987d28d
3 changed files with 231 additions and 203 deletions

View File

@ -7,6 +7,7 @@ import torch.distributed as dist
import torch.utils._pytree as pytree
from torch._dispatch.python import enable_python_dispatcher
from torch._dynamo.utils import detect_fake_mode
from torch._inductor.runtime.runtime_utils import dynamo_timed
from torch._logging import trace_structured
from torch.fx.experimental.proxy_tensor import make_fx
from torch.utils._ordered_set import OrderedSet
@ -362,16 +363,17 @@ def all_gather_merge_fn_to_trace_functional(
def _trace(fn, inps) -> torch.fx.GraphModule: # type: ignore[no-untyped-def]
fake_mode = detect_fake_mode(inps)
assert fake_mode is not None
with fake_mode, enable_python_dispatcher():
out = make_fx(fn)(*inps)
for node in out.graph.find_nodes(
op="call_function", target=torch.ops.aten.detach.default
):
node.replace_all_uses_with(node.args[0])
out.graph.erase_node(node)
return out
with dynamo_timed("fx.bucketing._trace", log_pt2_compile_event=True):
fake_mode = detect_fake_mode(inps)
assert fake_mode is not None
with fake_mode, enable_python_dispatcher():
out = make_fx(fn)(*inps)
for node in out.graph.find_nodes(
op="call_function", target=torch.ops.aten.detach.default
):
node.replace_all_uses_with(node.args[0])
out.graph.erase_node(node)
return out
# NOTE(review): the parameter list below was cut off by the diff-hunk header in
# the reviewed excerpt; it is reconstructed from the call sites in
# merge_reduce_scatter / merge_all_gather — confirm against upstream.
def _insert_fn_trace_before_node(  # type: ignore[no-untyped-def]
    g: torch.fx.Graph,
    fn_to_trace,
    inps,
    insert_before_node: torch.fx.Node,
    g_fn_inps: list[torch.fx.Node],
    g_fn_outs: list[torch.fx.Node],
) -> dict[torch.fx.Node, torch.fx.Node]:
    """Splice the traced graph of ``fn_to_trace`` into ``g``.

    The function graph is inserted before :attr:`insert_before_node`,
    using :attr:`g_fn_inps` nodes of the original graph as inputs of the
    function graph; the function graph outputs replace :attr:`g_fn_outs`
    in the original graph.

    Returns:
        Mapping from each original output node in ``g_fn_outs`` to the new
        node that replaced it.
    """
    with dynamo_timed(
        "fx.bucketing._insert_fn_trace_before_node", log_pt2_compile_event=True
    ):
        fn_gm = _trace(
            fn_to_trace,
            inps,
        )
        fn_g = fn_gm.graph
        fn_g_ins = fn_g.find_nodes(op="placeholder")
        # Map the traced graph's placeholders onto the caller-supplied
        # input nodes of the outer graph.
        env = {fn_g_ins[idx]: g_fn_inps[idx] for idx in range(len(g_fn_inps))}
        g_fn_new_outs: list[torch.fx.Node] = []
        with g.inserting_before(insert_before_node):
            for _n in fn_g.nodes:
                if _n.op == "placeholder":
                    continue
                _new_n = g.node_copy(_n, lambda x: env[x])
                env[_n] = _new_n
                if _n.op == "output":
                    # Capture the outputs, then drop the copied output node:
                    # the outer graph keeps its own output node.
                    g_fn_new_outs = _new_n.args[0]  # type: ignore[assignment]
                    g.erase_node(_new_n)
        replacements = {  # noqa: C416
            orig_out: new_out for orig_out, new_out in zip(g_fn_outs, g_fn_new_outs)
        }
        for orig_out, new_out in zip(g_fn_outs, g_fn_new_outs):
            orig_out.replace_all_uses_with(new_out)
        return replacements
def merge_reduce_scatter(
    gm: torch.fx.GraphModule, rs_buckets: list[list[torch.fx.Node]]
) -> None:
    """Merge each bucket of reduce_scatter nodes into one joint reduce_scatter.

    Args:
        gm: Graph module to rewrite in place.
        rs_buckets: Buckets of reduce_scatter_tensor nodes; all nodes within a
            bucket must share reduce_op, group_size, group_name, device and
            dtype (asserted below), and each must have exactly one user (its
            wait_tensor).
    """
    with dynamo_timed("fx.bucketing.merge_reduce_scatter", log_pt2_compile_event=True):
        trace_structured(
            "artifact",
            metadata_fn=lambda: {
                "name": "fx_bucketing_passes_reduce_scatter_buckets",
                "encoding": "string",
            },
            payload_fn=lambda: str(rs_buckets),
        )
        n_buckets = len(rs_buckets)
        g = gm.graph
        rs_ins: list[list[torch.fx.Node]] = [[] for _ in range(n_buckets)]
        rs_waits: list[list[torch.fx.Node]] = [[] for _ in range(n_buckets)]

        # First pass: collect per-bucket inputs and wait nodes, validating
        # that every node in a bucket agrees on the collective's parameters.
        for bucket_idx, rs_nodes in enumerate(rs_buckets):
            rs0 = rs_nodes[0]
            rs0_val = rs0.meta["val"]
            _, reduce_op, group_size, group_name = rs0.args
            reduce_dtype = rs0_val.dtype
            device = rs0_val.device
            for n in rs_nodes:
                rs_val = n.meta["val"]
                assert (
                    n.args[1] == reduce_op
                    and n.args[2] == group_size
                    and n.args[3] == group_name
                    and rs_val.device == device
                    and rs_val.dtype == reduce_dtype
                )
                assert len(n.users) == 1
                wait_n = next(iter(n.users))
                rs_ins[bucket_idx].append(n.args[0])  # type: ignore[arg-type]
                rs_waits[bucket_idx].append(wait_n)

        # Second pass: splice in the merged collective for each bucket and
        # erase the original per-tensor reduce_scatter/wait pairs.
        for bucket_idx in range(n_buckets):
            _rs_ins = rs_ins[bucket_idx]
            _rs_waits = rs_waits[bucket_idx]
            _rs_ns = rs_buckets[bucket_idx]

            rs0 = _rs_ns[0]
            rs0_val = rs0.meta["val"]
            _, reduce_op, group_size, group_name = rs0.args
            reduce_dtype = rs0_val.dtype
            device = rs0_val.device

            replacements = _insert_fn_trace_before_node(
                g,
                reduce_scatter_merge_fn_to_trace,
                (
                    pytree.tree_map(lambda node: node.meta["val"], _rs_ins),
                    group_size,
                    group_name,
                    reduce_op,
                    reduce_dtype,
                    device,
                ),
                _rs_ns[-1].next,
                _rs_ins,
                _rs_waits,
            )

            # [Note: Replacement in bucketing passes]
            # After bucketing, _rs_waits are replaced with output nodes of the
            # fn_to_trace graph that was inserted into graph g. rs_ins of the
            # following buckets may reference the replaced _rs_waits, so apply
            # the replacements to them before those buckets are processed.
            def _replace(x: torch.fx.Node) -> torch.fx.Node:
                return replacements.get(x, x)

            for j in range(bucket_idx + 1, n_buckets):
                rs_ins[j] = pytree.tree_map(_replace, rs_ins[j])

            # Erase the now-unused originals: each wait before its collective.
            for rs_n, wait_n in zip(_rs_ns, _rs_waits):
                g.erase_node(wait_n)
                g.erase_node(rs_n)
# NOTE(review): the parameter list below was cut off by the diff-hunk header in
# the reviewed excerpt; it is reconstructed from body usage — confirm upstream.
def merge_all_gather(
    gm: torch.fx.GraphModule, ag_buckets: list[list[torch.fx.Node]]
) -> None:
    """
    Merges specified buckets of all_gather to joint all_gather.

    Args:
        gm: Graph module to rewrite in place.
        ag_buckets: Buckets of all_gather_into_tensor nodes; all nodes within
            a bucket must share group_size, group_name and dtype (asserted
            below), and each must have exactly one user (its wait_tensor).
    """
    with dynamo_timed("fx.bucketing.merge_all_gather", log_pt2_compile_event=True):
        from torch.distributed.distributed_c10d import _resolve_process_group

        trace_structured(
            "artifact",
            metadata_fn=lambda: {
                "name": "fx_bucketing_passes_all_gather_buckets",
                "encoding": "string",
            },
            payload_fn=lambda: str(ag_buckets),
        )
        n_buckets = len(ag_buckets)
        ag_ins: list[list[torch.fx.Node]] = [[] for _ in range(n_buckets)]
        ag_waits: list[list[torch.fx.Node]] = [[] for _ in range(n_buckets)]

        # First pass: collect per-bucket inputs and wait nodes, validating
        # that every node in a bucket agrees on the collective's parameters.
        for bucket_idx, ag_bucket in enumerate(ag_buckets):
            _, group_size, group_name = ag_bucket[0].args
            assert isinstance(group_name, str)
            dtype = ag_bucket[0].meta["val"].dtype
            for ag_node in ag_bucket:
                assert len(ag_node.users) == 1, (
                    f"Expect only one user for {ag_node}, but got {ag_node.users}"
                )
                wait_node = next(iter(ag_node.users))
                assert (
                    ag_node.args[1] == group_size
                    and ag_node.args[2] == group_name
                    and ag_node.meta["val"].dtype == dtype
                )
                ag_node_in = ag_node.args[0]
                ag_ins[bucket_idx].append(ag_node_in)  # type: ignore[union-attr, arg-type]
                ag_waits[bucket_idx].append(wait_node)

        g = gm.graph
        # Second pass: splice in the merged collective for each bucket and
        # erase the original per-tensor all_gather/wait pairs.
        for bucket_idx in range(n_buckets):
            _ag_ins = ag_ins[bucket_idx]
            _ag_waits = ag_waits[bucket_idx]
            _ag_ns = ag_buckets[bucket_idx]

            ag0 = _ag_ns[0]
            ag0_val = ag0.meta["val"]
            _, group_size, group_name = ag0.args
            dtype = ag0_val.dtype
            assert isinstance(group_name, str)

            rank: int = dist.get_rank(_resolve_process_group(group_name))

            replacements = _insert_fn_trace_before_node(
                g,
                all_gather_merge_fn_to_trace,
                (
                    pytree.tree_map(lambda node: node.meta["val"], _ag_ins),
                    group_size,
                    group_name,
                    dtype,
                    rank,
                ),
                ag0.next,
                _ag_ins,
                _ag_waits,
            )

            # See Note: [Replacement in bucketing passes] — ag_ins of later
            # buckets may reference the just-replaced wait nodes; rewrite them
            # to the new outputs before those buckets are processed.
            def _replace(x: torch.fx.Node) -> torch.fx.Node:
                return replacements.get(x, x)

            for j in range(bucket_idx + 1, n_buckets):
                ag_ins[j] = pytree.tree_map(_replace, ag_ins[j])

            # Erasing old nodes in reverse order: each wait before its
            # collective, since the wait is the collective's only user.
            for ag_n, wait_n in zip(ag_buckets[bucket_idx], _ag_waits):
                g.erase_node(wait_n)
                g.erase_node(ag_n)