[bucketing] allow convert_element_type after fsdp reduce_scatter (#161159)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/161159
Approved by: https://github.com/eellison

Committed by PyTorch MergeBot
Commit: 595987d28d
Parent: c4670e40c9
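Before this change, the "fsdp" reduce-scatter bucketing mode only recognized a reduce_scatter whose wait_tensor fed a graph output directly. FSDP-style setups commonly reduce-scatter gradients in a low-precision dtype and cast the sharded result back before it leaves the graph, so the wait's sole user is a prims.convert_element_type node instead. A minimal sketch of that eager pattern (illustrative, not from the PR; shapes and group wiring are assumed):

# Illustrative sketch, not from the PR: an FSDP-style gradient shard that
# reduce-scatters in bf16 and casts back to fp32. Under torch.compile the
# final cast shows up as prims.convert_element_type after the wait_tensor.
import torch
import torch.distributed._functional_collectives as funcol

def shard_grad(grad_bf16: torch.Tensor, group) -> torch.Tensor:
    shard = funcol.reduce_scatter_tensor(grad_bf16, "avg", scatter_dim=0, group=group)
    return shard.to(torch.float32)  # the convert_element_type after the wait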
@@ -1642,7 +1642,7 @@ class TestCollectivesInductor(DynamoDistributedSingleProcTestCase):
     @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
     @unittest.skipIf(not SM80OrLater, "bfloat16")
     def test_reduce_scatter_bucket(self):
-        def func(x, w, rs_0, rs_1, *, tag, ranks, group_size):
+        def func(x, w, rs_0, rs_1, tag, ranks, group_size):
             # do some unrelated matmuls
             y = torch.mm(x, w)

@@ -1667,35 +1667,44 @@ class TestCollectivesInductor(DynamoDistributedSingleProcTestCase):
             return y, rs_0_out, rs_1_out

-        x = torch.ones(4, 384, device="cuda", dtype=torch.float32)
-        w = torch.ones(384, 512, device="cuda", dtype=torch.float32)
-        rs_0 = torch.ones(384, 512, device="cuda", dtype=torch.float32)
-        rs_1 = torch.ones(384, 256, device="cuda", dtype=torch.float32)
-        inputs = [x, w, rs_0, rs_1]
-        func(*inputs, **self.get_world_trs())
+        # test "fsdp" mode to allow convert_element_type after wait
+        def func2(x, w, rs_0, rs_1, tag, ranks, group_size):
+            y, rs_0_out, rs_1_out = func(x, w, rs_0, rs_1, tag, ranks, group_size)
+            return y, rs_0_out.to(torch.float32), rs_1_out.to(torch.float32)

-        with torch._inductor.config.patch(
-            {
-                "bucket_reduce_scatters_fx": "all",
-                "reorder_for_compute_comm_overlap": False,
-            }
-        ):
-            compiled = torch.compile(func)
-            code = run_and_get_triton_code(compiled, *inputs, **self.get_world_trs())
-        # NOTE: The first return value should be the output of the first wait_tensor.
-        # We want to make sure no unnecessary copy is made.
-        (
-            FileCheck()
-            .check_count(
-                "torch.ops._c10d_functional.reduce_scatter_tensor.default(",
-                count=1,
-                exactly=True,
-            )
-            .run(code)
-        )
-        out = compiled(*inputs, **self.get_world_trs())
-        correct = func(*inputs, **self.get_world_trs())
-        assert same(out, correct), f"{out} va {correct}"
+        for f in [func, func2]:
+            x = torch.ones(4, 384, device="cuda", dtype=torch.float32)
+            w = torch.ones(384, 512, device="cuda", dtype=torch.float32)
+            rs_0 = torch.ones(384, 512, device="cuda", dtype=torch.float32)
+            rs_1 = torch.ones(384, 256, device="cuda", dtype=torch.float32)
+            inputs = [x, w, rs_0, rs_1]
+            f(*inputs, **self.get_world_trs())
+
+            with torch._inductor.config.patch(
+                {
+                    "bucket_reduce_scatters_fx": "fsdp",
+                    "reorder_for_compute_comm_overlap": False,
+                }
+            ):
+                compiled = torch.compile(f)
+                compiled(*inputs, **self.get_world_trs())
+                code = run_and_get_triton_code(
+                    compiled, *inputs, **self.get_world_trs()
+                )
+            # NOTE: The first return value should be the output of the first wait_tensor.
+            # We want to make sure no unnecessary copy is made.
+            (
+                FileCheck()
+                .check_count(
+                    "torch.ops._c10d_functional.reduce_scatter_tensor.default(",
+                    count=1,
+                    exactly=True,
+                )
+                .run(code)
+            )
+            out = compiled(*inputs, **self.get_world_trs())
+            correct = f(*inputs, **self.get_world_trs())
+            assert same(out, correct), f"{out} va {correct}"

     @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
     @unittest.skipIf(not SM80OrLater, "bfloat16")
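The updated test drives both func (wait output returned directly) and func2 (wait output passed through .to(torch.float32)) under the "fsdp" bucketing mode, and in both cases expects the two reduce-scatters to be fused into a single collective. For readers unfamiliar with the FileCheck helper the test relies on, a standalone example (the `code` string here is a stand-in for Inductor's generated output):

# Standalone FileCheck example; `code` stands in for run_and_get_triton_code output.
from torch.testing import FileCheck

code = "buf0 = torch.ops._c10d_functional.reduce_scatter_tensor.default(bucket)"
FileCheck().check_count(
    "torch.ops._c10d_functional.reduce_scatter_tensor.default(",
    count=1,
    exactly=True,  # fails if the pattern appears zero times or more than once
).run(code)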
@@ -7,6 +7,7 @@ import torch.distributed as dist
 import torch.utils._pytree as pytree
 from torch._dispatch.python import enable_python_dispatcher
 from torch._dynamo.utils import detect_fake_mode
+from torch._inductor.runtime.runtime_utils import dynamo_timed
 from torch._logging import trace_structured
 from torch.fx.experimental.proxy_tensor import make_fx
 from torch.utils._ordered_set import OrderedSet

@@ -362,16 +363,17 @@ def all_gather_merge_fn_to_trace_functional(


 def _trace(fn, inps) -> torch.fx.GraphModule:  # type: ignore[no-untyped-def]
-    fake_mode = detect_fake_mode(inps)
-    assert fake_mode is not None
-    with fake_mode, enable_python_dispatcher():
-        out = make_fx(fn)(*inps)
-        for node in out.graph.find_nodes(
-            op="call_function", target=torch.ops.aten.detach.default
-        ):
-            node.replace_all_uses_with(node.args[0])
-            out.graph.erase_node(node)
-        return out
+    with dynamo_timed("fx.bucketing._trace", log_pt2_compile_event=True):
+        fake_mode = detect_fake_mode(inps)
+        assert fake_mode is not None
+        with fake_mode, enable_python_dispatcher():
+            out = make_fx(fn)(*inps)
+            for node in out.graph.find_nodes(
+                op="call_function", target=torch.ops.aten.detach.default
+            ):
+                node.replace_all_uses_with(node.args[0])
+                out.graph.erase_node(node)
+            return out


 def _insert_fn_trace_before_node(  # type: ignore[no-untyped-def]
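Alongside the semantic change, the PR wraps each bucketing entry point in dynamo_timed so the passes are attributed in PT2 compile-time profiles. A minimal sketch of that instrumentation pattern (the pass and its name here are made up for illustration; the import and keyword are taken from the diff):

# Minimal sketch of the instrumentation pattern applied throughout this diff:
# the wrapped region is logged as a pt2 compile event. "fx.bucketing.my_pass"
# is a hypothetical region name.
from torch._inductor.runtime.runtime_utils import dynamo_timed

def my_pass(gm):
    with dynamo_timed("fx.bucketing.my_pass", log_pt2_compile_event=True):
        return gm  # real passes mutate gm.graph here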
@@ -389,109 +391,113 @@ def _insert_fn_trace_before_node(  # type: ignore[no-untyped-def]
     using :attr:`g_fn_inps` nodes of original graphas inputs of function graph,
     function graph outputs will replace :attr:`g_fn_outs` in original graph.
     """
-    fn_gm = _trace(
-        fn_to_trace,
-        inps,
-    )
-    fn_g = fn_gm.graph
-    fn_g_ins = fn_g.find_nodes(op="placeholder")
-    env = {fn_g_ins[idx]: g_fn_inps[idx] for idx in range(len(g_fn_inps))}
-    g_fn_new_outs: list[torch.fx.Node] = []
-    with g.inserting_before(insert_before_node):
-        for _n in fn_g.nodes:
-            if _n.op == "placeholder":
-                continue
-            _new_n = g.node_copy(_n, lambda x: env[x])
-            env[_n] = _new_n
-            if _n.op == "output":
-                g_fn_new_outs = _new_n.args[0]  # type: ignore[assignment]
-                g.erase_node(_new_n)
-    replacements = {  # noqa: C416
-        orig_out: new_out for orig_out, new_out in zip(g_fn_outs, g_fn_new_outs)
-    }
-    for orig_out, new_out in zip(g_fn_outs, g_fn_new_outs):
-        orig_out.replace_all_uses_with(new_out)
-    return replacements
+    with dynamo_timed(
+        "fx.bucketing._insert_fn_trace_before_node", log_pt2_compile_event=True
+    ):
+        fn_gm = _trace(
+            fn_to_trace,
+            inps,
+        )
+        fn_g = fn_gm.graph
+        fn_g_ins = fn_g.find_nodes(op="placeholder")
+        env = {fn_g_ins[idx]: g_fn_inps[idx] for idx in range(len(g_fn_inps))}
+        g_fn_new_outs: list[torch.fx.Node] = []
+        with g.inserting_before(insert_before_node):
+            for _n in fn_g.nodes:
+                if _n.op == "placeholder":
+                    continue
+                _new_n = g.node_copy(_n, lambda x: env[x])
+                env[_n] = _new_n
+                if _n.op == "output":
+                    g_fn_new_outs = _new_n.args[0]  # type: ignore[assignment]
+                    g.erase_node(_new_n)
+        replacements = {  # noqa: C416
+            orig_out: new_out for orig_out, new_out in zip(g_fn_outs, g_fn_new_outs)
+        }
+        for orig_out, new_out in zip(g_fn_outs, g_fn_new_outs):
+            orig_out.replace_all_uses_with(new_out)
+        return replacements


 def merge_reduce_scatter(
     gm: torch.fx.GraphModule, rs_buckets: list[list[torch.fx.Node]]
 ) -> None:
-    trace_structured(
-        "artifact",
-        metadata_fn=lambda: {
-            "name": "fx_bucketing_passes_reduce_scatter_buckets",
-            "encoding": "string",
-        },
-        payload_fn=lambda: str(rs_buckets),
-    )
-    n_buckets = len(rs_buckets)
-    g = gm.graph
-    rs_ins: list[list[torch.fx.Node]] = [[] for _ in range(n_buckets)]
-    rs_waits: list[list[torch.fx.Node]] = [[] for _ in range(n_buckets)]
-
-    for bucket_idx, rs_nodes in enumerate(rs_buckets):
-        rs0 = rs_nodes[0]
-        rs0_val = rs0.meta["val"]
-        _, reduce_op, group_size, group_name = rs0.args
-        reduce_dtype = rs0_val.dtype
-        device = rs0_val.device
-        for n in rs_nodes:
-            rs_val = n.meta["val"]
-            assert (
-                n.args[1] == reduce_op
-                and n.args[2] == group_size
-                and n.args[3] == group_name
-                and rs_val.device == device
-                and rs_val.dtype == reduce_dtype
-            )
-            assert len(n.users) == 1
-            wait_n = next(iter(n.users))
-            rs_ins[bucket_idx].append(n.args[0])  # type: ignore[arg-type]
-            rs_waits[bucket_idx].append(wait_n)
-
-    for bucket_idx in range(n_buckets):
-        _rs_ins = rs_ins[bucket_idx]
-        _rs_waits = rs_waits[bucket_idx]
-        _rs_ns = rs_buckets[bucket_idx]
-
-        rs0 = _rs_ns[0]
-        rs0_val = rs0.meta["val"]
-        _, reduce_op, group_size, group_name = rs0.args
-        reduce_dtype = rs0_val.dtype
-        device = rs0_val.device
-
-        replacements = _insert_fn_trace_before_node(
-            g,
-            reduce_scatter_merge_fn_to_trace,
-            (
-                pytree.tree_map(lambda node: node.meta["val"], _rs_ins),
-                group_size,
-                group_name,
-                reduce_op,
-                reduce_dtype,
-                device,
-            ),
-            _rs_ns[-1].next,
-            _rs_ins,
-            _rs_waits,
-        )
-        # [Note: Replacement in bucketing passes]
-        # After bucketing _rs_waits will be replaced with output nodes of
-        # fn_to_trace graph that will be inserted in the graph g.
-        # By this time we already prepared rs_ins, rs_waits.
-        # rs_ins for following buckets can be replaced _rs_waits with new nodes.
-        # We apply replacements to rs_ins.
-
-        def _replace(x: torch.fx.Node) -> torch.fx.Node:
-            return replacements.get(x, x)
-
-        for j in range(bucket_idx + 1, n_buckets):
-            rs_ins[j] = pytree.tree_map(_replace, rs_ins[j])
-
-        for rs_n, wait_n in zip(_rs_ns, _rs_waits):
-            g.erase_node(wait_n)
-            g.erase_node(rs_n)
+    with dynamo_timed("fx.bucketing.merge_reduce_scatter", log_pt2_compile_event=True):
+        trace_structured(
+            "artifact",
+            metadata_fn=lambda: {
+                "name": "fx_bucketing_passes_reduce_scatter_buckets",
+                "encoding": "string",
+            },
+            payload_fn=lambda: str(rs_buckets),
+        )
+        n_buckets = len(rs_buckets)
+        g = gm.graph
+        rs_ins: list[list[torch.fx.Node]] = [[] for _ in range(n_buckets)]
+        rs_waits: list[list[torch.fx.Node]] = [[] for _ in range(n_buckets)]
+
+        for bucket_idx, rs_nodes in enumerate(rs_buckets):
+            rs0 = rs_nodes[0]
+            rs0_val = rs0.meta["val"]
+            _, reduce_op, group_size, group_name = rs0.args
+            reduce_dtype = rs0_val.dtype
+            device = rs0_val.device
+            for n in rs_nodes:
+                rs_val = n.meta["val"]
+                assert (
+                    n.args[1] == reduce_op
+                    and n.args[2] == group_size
+                    and n.args[3] == group_name
+                    and rs_val.device == device
+                    and rs_val.dtype == reduce_dtype
+                )
+                assert len(n.users) == 1
+                wait_n = next(iter(n.users))
+                rs_ins[bucket_idx].append(n.args[0])  # type: ignore[arg-type]
+                rs_waits[bucket_idx].append(wait_n)
+
+        for bucket_idx in range(n_buckets):
+            _rs_ins = rs_ins[bucket_idx]
+            _rs_waits = rs_waits[bucket_idx]
+            _rs_ns = rs_buckets[bucket_idx]
+
+            rs0 = _rs_ns[0]
+            rs0_val = rs0.meta["val"]
+            _, reduce_op, group_size, group_name = rs0.args
+            reduce_dtype = rs0_val.dtype
+            device = rs0_val.device
+
+            replacements = _insert_fn_trace_before_node(
+                g,
+                reduce_scatter_merge_fn_to_trace,
+                (
+                    pytree.tree_map(lambda node: node.meta["val"], _rs_ins),
+                    group_size,
+                    group_name,
+                    reduce_op,
+                    reduce_dtype,
+                    device,
+                ),
+                _rs_ns[-1].next,
+                _rs_ins,
+                _rs_waits,
+            )
+            # [Note: Replacement in bucketing passes]
+            # After bucketing _rs_waits will be replaced with output nodes of
+            # fn_to_trace graph that will be inserted in the graph g.
+            # By this time we already prepared rs_ins, rs_waits.
+            # rs_ins for following buckets can be replaced _rs_waits with new nodes.
+            # We apply replacements to rs_ins.
+
+            def _replace(x: torch.fx.Node) -> torch.fx.Node:
+                return replacements.get(x, x)
+
+            for j in range(bucket_idx + 1, n_buckets):
+                rs_ins[j] = pytree.tree_map(_replace, rs_ins[j])
+
+            for rs_n, wait_n in zip(_rs_ns, _rs_waits):
+                g.erase_node(wait_n)
+                g.erase_node(rs_n)


 def merge_all_gather(
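The [Note: Replacement in bucketing passes] comment above marks the subtle part of the pass: merging bucket i replaces its wait nodes with new output nodes, so the inputs of buckets i+1..n that referenced those waits must be remapped before they are merged in turn. A toy illustration of that remapping step (strings stand in for torch.fx.Node objects):

# Toy illustration of the remapping in [Note: Replacement in bucketing passes];
# strings stand in for torch.fx.Node. After merging bucket 0, the old wait
# node feeding bucket 1 is rewritten to the merged graph's new output node.
import torch.utils._pytree as pytree

replacements = {"wait_0_old": "bucket0_new_out"}
rs_ins = [["grad_a"], ["wait_0_old", "grad_b"]]  # bucket 1 consumes bucket 0's wait

rs_ins[1] = pytree.tree_map(lambda x: replacements.get(x, x), rs_ins[1])
assert rs_ins[1] == ["bucket0_new_out", "grad_b"]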
@@ -500,78 +506,79 @@ def merge_all_gather(
     """
     Merges specified buckets of all_gather to joint all_gather.
     """
-    from torch.distributed.distributed_c10d import _resolve_process_group
+    with dynamo_timed("fx.bucketing.merge_all_gather", log_pt2_compile_event=True):
+        from torch.distributed.distributed_c10d import _resolve_process_group

-    trace_structured(
-        "artifact",
-        metadata_fn=lambda: {
-            "name": "fx_bucketing_passes_all_gather_buckets",
-            "encoding": "string",
-        },
-        payload_fn=lambda: str(ag_buckets),
-    )
-    n_buckets = len(ag_buckets)
-
-    ag_ins: list[list[torch.fx.Node]] = [[] for _ in range(n_buckets)]
-    ag_waits: list[list[torch.fx.Node]] = [[] for _ in range(n_buckets)]
-    for bucket_idx, ag_bucket in enumerate(ag_buckets):
-        _, group_size, group_name = ag_bucket[0].args
-        assert isinstance(group_name, str)
-        dtype = ag_bucket[0].meta["val"].dtype
-
-        for ag_node in ag_bucket:
-            assert len(ag_node.users) == 1, (
-                f"Expect only one user for {ag_node}, but got {ag_node.users}"
-            )
-            wait_node = next(iter(ag_node.users))
-            assert (
-                ag_node.args[1] == group_size
-                and ag_node.args[2] == group_name
-                and ag_node.meta["val"].dtype == dtype
-            )
-            ag_node_in = ag_node.args[0]
-
-            ag_ins[bucket_idx].append(ag_node_in)  # type: ignore[union-attr, arg-type]
-            ag_waits[bucket_idx].append(wait_node)
-
-    g = gm.graph
-
-    for bucket_idx in range(n_buckets):
-        _ag_ins = ag_ins[bucket_idx]
-        _ag_waits = ag_waits[bucket_idx]
-        _ag_ns = ag_buckets[bucket_idx]
-
-        ag0 = _ag_ns[0]
-        ag0_val = ag0.meta["val"]
-        _, group_size, group_name = ag0.args
-        dtype = ag0_val.dtype
-        assert isinstance(group_name, str)
-
-        rank: int = dist.get_rank(_resolve_process_group(group_name))
-
-        replacements = _insert_fn_trace_before_node(
-            g,
-            all_gather_merge_fn_to_trace,
-            (
-                pytree.tree_map(lambda node: node.meta["val"], _ag_ins),
-                group_size,
-                group_name,
-                dtype,
-                rank,
-            ),
-            ag0.next,
-            _ag_ins,
-            _ag_waits,
-        )
-
-        # See Note: [Replacement in bucketing passes]
-        def _replace(x: torch.fx.Node) -> torch.fx.Node:
-            return replacements.get(x, x)
-
-        for j in range(bucket_idx + 1, n_buckets):
-            ag_ins[j] = pytree.tree_map(_replace, ag_ins[j])
-
-        # Erasing old nodes in reverse order
-        for ag_n, wait_n in zip(ag_buckets[bucket_idx], _ag_waits):
-            g.erase_node(wait_n)
-            g.erase_node(ag_n)
+        trace_structured(
+            "artifact",
+            metadata_fn=lambda: {
+                "name": "fx_bucketing_passes_all_gather_buckets",
+                "encoding": "string",
+            },
+            payload_fn=lambda: str(ag_buckets),
+        )
+        n_buckets = len(ag_buckets)
+
+        ag_ins: list[list[torch.fx.Node]] = [[] for _ in range(n_buckets)]
+        ag_waits: list[list[torch.fx.Node]] = [[] for _ in range(n_buckets)]
+        for bucket_idx, ag_bucket in enumerate(ag_buckets):
+            _, group_size, group_name = ag_bucket[0].args
+            assert isinstance(group_name, str)
+            dtype = ag_bucket[0].meta["val"].dtype
+
+            for ag_node in ag_bucket:
+                assert len(ag_node.users) == 1, (
+                    f"Expect only one user for {ag_node}, but got {ag_node.users}"
+                )
+                wait_node = next(iter(ag_node.users))
+                assert (
+                    ag_node.args[1] == group_size
+                    and ag_node.args[2] == group_name
+                    and ag_node.meta["val"].dtype == dtype
+                )
+                ag_node_in = ag_node.args[0]
+
+                ag_ins[bucket_idx].append(ag_node_in)  # type: ignore[union-attr, arg-type]
+                ag_waits[bucket_idx].append(wait_node)
+
+        g = gm.graph
+
+        for bucket_idx in range(n_buckets):
+            _ag_ins = ag_ins[bucket_idx]
+            _ag_waits = ag_waits[bucket_idx]
+            _ag_ns = ag_buckets[bucket_idx]
+
+            ag0 = _ag_ns[0]
+            ag0_val = ag0.meta["val"]
+            _, group_size, group_name = ag0.args
+            dtype = ag0_val.dtype
+            assert isinstance(group_name, str)
+
+            rank: int = dist.get_rank(_resolve_process_group(group_name))
+
+            replacements = _insert_fn_trace_before_node(
+                g,
+                all_gather_merge_fn_to_trace,
+                (
+                    pytree.tree_map(lambda node: node.meta["val"], _ag_ins),
+                    group_size,
+                    group_name,
+                    dtype,
+                    rank,
+                ),
+                ag0.next,
+                _ag_ins,
+                _ag_waits,
+            )
+
+            # See Note: [Replacement in bucketing passes]
+            def _replace(x: torch.fx.Node) -> torch.fx.Node:
+                return replacements.get(x, x)
+
+            for j in range(bucket_idx + 1, n_buckets):
+                ag_ins[j] = pytree.tree_map(_replace, ag_ins[j])
+
+            # Erasing old nodes in reverse order
+            for ag_n, wait_n in zip(ag_buckets[bucket_idx], _ag_waits):
+                g.erase_node(wait_n)
+                g.erase_node(ag_n)
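merge_all_gather gets the same dynamo_timed treatment. The one non-mechanical detail is that it resolves the caller's rank from the functional-collectives group name, which all_gather_merge_fn_to_trace takes as an argument. Sketch of that lookup (assumes torch.distributed is initialized):

# Sketch: resolving a rank from a functional-collectives group name, as
# merge_all_gather does. Assumes an initialized process group.
import torch.distributed as dist
from torch.distributed.distributed_c10d import _resolve_process_group

def rank_for(group_name: str) -> int:
    return dist.get_rank(_resolve_process_group(group_name))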
@@ -38,7 +38,19 @@ def is_graph_output(node: torch.fx.Node) -> bool:


 def is_fsdp_reduce_scatter_wait(wait: torch.fx.Node) -> bool:
-    return is_graph_output(wait)
+    if is_graph_output(wait):
+        return True
+
+    if len(wait.users) == 1:
+        user = next(iter(wait.users))
+        assert user is not None
+        return (
+            is_graph_output(user)
+            and user.op == "call_function"
+            and user.target == torch.ops.prims.convert_element_type.default
+        )
+
+    return False


 def bucket_fsdp_all_gather(
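This final hunk is the semantic core of the PR: is_fsdp_reduce_scatter_wait previously required the wait node itself to be a graph output, and now also accepts a wait whose single user is a prims.convert_element_type that is a graph output. A hand-built toy FX graph showing the newly accepted shape (illustrative; a real graph would come from tracing, and the c10d/prims ops must be registered in your build):

# Hand-built toy FX graph showing the newly accepted pattern:
# wait_tensor -> convert_element_type -> output. The wait node is not a
# graph output itself, but its sole user is a cast that is.
import torch
import torch.fx as fx

g = fx.Graph()
rs_out = g.placeholder("rs_out")  # stands in for the reduce_scatter result
wait = g.call_function(torch.ops._c10d_functional.wait_tensor.default, (rs_out,))
cast = g.call_function(
    torch.ops.prims.convert_element_type.default, (wait, torch.float32)
)
g.output((cast,))

assert len(wait.users) == 1 and next(iter(wait.users)) is cast
# Old predicate: False (wait does not feed the output directly).
# New predicate: True (the sole user is a call_function to
# convert_element_type that feeds the graph output).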