[bucketing] custom_ops mode to hide inductor copies overhead (#161499)
Adding "_custom_ops" bucketing to temporary fallback to eager execution of for_each, to workaround too many generated kernels on inductor side. This PR also reverts parts of bucketing changes for cycles detection that resulted in accuracy problems. Differential Revision: [D81152293](https://our.internmc.facebook.com/intern/diff/D81152293) Pull Request resolved: https://github.com/pytorch/pytorch/pull/161499 Approved by: https://github.com/eellison
Committed by: PyTorch MergeBot
Parent: 9c991b63ff
Commit: 8ec01f34e9
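
As a usage sketch (not part of the diff): the new mode is selected through the existing Inductor bucketing config keys exercised by the tests below. This toy model has no collectives, so the bucketing passes are no-ops here; in a real run the compiled graph would contain all_gather/reduce_scatter ops under an initialized process group.

    # Hedged sketch: opting a compiled model into the "_custom_ops"
    # bucketing mode via the config keys touched in this diff.
    # `model` and `inputs` are trivial placeholders.
    import torch

    model = torch.nn.Linear(8, 8)
    inputs = (torch.randn(2, 8),)

    with torch._inductor.config.patch(
        {
            # "all_custom_ops" buckets collectives and routes the
            # pre-bucket copies through the eager custom ops added here.
            "bucket_all_gathers_fx": "all_custom_ops",
            "bucket_reduce_scatters_fx": "all_custom_ops",
        }
    ):
        compiled = torch.compile(model)
        out = compiled(*inputs)
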
@@ -1528,7 +1528,8 @@ class TestCollectivesInductor(DynamoDistributedSingleProcTestCase):
 
     @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
     @unittest.skipIf(not SM80OrLater, "bfloat16")
-    def test_all_gather_bucket(self):
+    @parametrize("bucket_mode", ["all", "all_custom_ops"])
+    def test_all_gather_bucket(self, bucket_mode):
         def func(x, w, ag_0, ag_1, ag_2, ag_3, *, tag, ranks, group_size):
             # do some unrelated matmuls
             y = torch.mm(x, w)
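
For reference, `parametrize` here is PyTorch's internal test-infra decorator: each listed value generates a separate test case. A minimal, self-contained sketch of the pattern (class and test names are illustrative, not from the diff):

    # Sketch of the parametrization used above; requires
    # instantiate_parametrized_tests to expand the variants.
    from torch.testing._internal.common_utils import (
        TestCase,
        instantiate_parametrized_tests,
        parametrize,
        run_tests,
    )


    class BucketModeTests(TestCase):
        @parametrize("bucket_mode", ["all", "all_custom_ops"])
        def test_mode_is_wired_through(self, bucket_mode):
            # The real tests patch torch._inductor.config with this value.
            self.assertIn(bucket_mode, ("all", "all_custom_ops"))


    instantiate_parametrized_tests(BucketModeTests)

    if __name__ == "__main__":
        run_tests()
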
@@ -1576,7 +1577,7 @@ class TestCollectivesInductor(DynamoDistributedSingleProcTestCase):
         with (
             torch._inductor.config.patch(
                 {
-                    "bucket_all_gathers_fx": "all",
+                    "bucket_all_gathers_fx": bucket_mode,
                     "reorder_for_compute_comm_overlap": False,
                     "runtime_estimations_mms_benchmark": True,
                 }
@@ -1595,7 +1596,9 @@ class TestCollectivesInductor(DynamoDistributedSingleProcTestCase):
         # We want to make sure no unnecessary copy is made.
         (
             FileCheck()
-            .check_count(".all_gather_into_tensor_out.default(", 2, exactly=True)
+            .check("= torch.ops._c10d_functional.all_gather_into_tensor")
+            .check("torch.ops._c10d_functional.all_gather_into_tensor_out.default(")
+            .check("= torch.ops._c10d_functional.all_gather_into_tensor")
             .run(code)
         )
         out = compiled(*inputs, **self.get_world_trs())
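
The assertion changes from an exact count of `_out` calls to an ordered sequence of patterns, since the custom-ops mode changes which all_gather variants appear in the generated code. A small standalone sketch of the FileCheck API used here (the `code` string is illustrative):

    # FileCheck patterns must appear in the checked string in order.
    from torch.testing import FileCheck

    code = """
    buf0 = torch.ops._c10d_functional.all_gather_into_tensor(...)
    torch.ops._c10d_functional.all_gather_into_tensor_out.default(...)
    buf1 = torch.ops._c10d_functional.all_gather_into_tensor(...)
    """

    (
        FileCheck()
        .check("= torch.ops._c10d_functional.all_gather_into_tensor")
        .check("torch.ops._c10d_functional.all_gather_into_tensor_out.default(")
        .check("= torch.ops._c10d_functional.all_gather_into_tensor")
        .run(code)
    )
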
@@ -1656,7 +1659,8 @@ class TestCollectivesInductor(DynamoDistributedSingleProcTestCase):
 
     @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
     @unittest.skipIf(not SM80OrLater, "bfloat16")
-    def test_reduce_scatter_bucket(self):
+    @parametrize("bucket_mode", ["all", "all_custom_ops"])
+    def test_reduce_scatter_bucket(self, bucket_mode):
         def func(x, w, rs_0, rs_1, tag, ranks, group_size):
             # do some unrelated matmuls
             y = torch.mm(x, w)
@@ -1697,7 +1701,7 @@ class TestCollectivesInductor(DynamoDistributedSingleProcTestCase):
 
         with torch._inductor.config.patch(
             {
-                "bucket_reduce_scatters_fx": "fsdp",
+                "bucket_reduce_scatters_fx": bucket_mode,
                 "reorder_for_compute_comm_overlap": False,
             }
         ):
@@ -1723,7 +1727,8 @@ class TestCollectivesInductor(DynamoDistributedSingleProcTestCase):
 
     @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
     @unittest.skipIf(not SM80OrLater, "bfloat16")
-    def test_reorder_peak_memory_bucketed(self):
+    @parametrize("bucket_mode", ["all", "all_custom_ops"])
+    def test_reorder_peak_memory_bucketed(self, bucket_mode):
         """
         Simulate the case where a bucketing pass ran and grouped several inputs into one bucketed allgather.
         Ensure the whole bucketed group including copy-ops get moved together rather than the copy ops preventing the
@@ -1837,9 +1842,9 @@ class TestCollectivesInductor(DynamoDistributedSingleProcTestCase):
         with (
             torch._inductor.config.patch(
                 {
-                    "bucket_all_gathers_fx": "all",
+                    "bucket_all_gathers_fx": bucket_mode,
                     "bucket_all_gathers_fx_bucket_size_determinator": lambda _: 2,
-                    "bucket_reduce_scatters_fx": "all",
+                    "bucket_reduce_scatters_fx": bucket_mode,
                     "bucket_reduce_scatters_fx_bucket_size_determinator": lambda _: 2,
                     "reorder_for_compute_comm_overlap": True,
                     "reorder_for_compute_comm_overlap_passes": [
@@ -1,5 +1,6 @@
 import collections
 import logging
+from collections import defaultdict
 from typing import Any, Callable, Optional
 
 import torch
@@ -33,6 +34,7 @@ def bucket_cap_mb_by_bucket_idx_default(bucket_id: int) -> float:
 def bucket_all_gather(
     gm: torch.fx.GraphModule,
     bucket_cap_mb_by_bucket_idx: Optional[Callable[[int], float]] = None,
+    mode: Optional[str] = None,
 ) -> None:
     if bucket_cap_mb_by_bucket_idx is None:
         from torch._inductor.fx_passes.bucketing import (
@@ -43,13 +45,13 @@ def bucket_all_gather(
     ag_buckets = bucket_all_gather_by_mb(gm, bucket_cap_mb_by_bucket_idx)
     if len(ag_buckets) == 0:
         return
-    merge_all_gather(gm, ag_buckets)
+    merge_all_gather(gm, ag_buckets, mode)
 
 
 def bucket_reduce_scatter(
     gm: torch.fx.GraphModule,
     bucket_cap_mb_by_bucket_idx: Optional[Callable[[int], float]] = None,
+    mode: Optional[str] = None,
 ) -> None:
     if bucket_cap_mb_by_bucket_idx is None:
         from torch._inductor.fx_passes.bucketing import (
@@ -60,7 +62,7 @@ def bucket_reduce_scatter(
     rs_buckets = bucket_reduce_scatter_by_mb(gm, bucket_cap_mb_by_bucket_idx)
     if len(rs_buckets) == 0:
         return
-    merge_reduce_scatter(gm, rs_buckets)
+    merge_reduce_scatter(gm, rs_buckets, mode)
 
 
 def is_all_gather_into_tensor(node: torch.fx.Node) -> bool:  # type: ignore[arg-type]
@@ -131,28 +133,46 @@ def greedy_bucket_collective_by_mb(
     node_group_key: Callable[[torch.fx.Node], Any],
     filter_wait_node: Optional[Callable[[torch.fx.Node], bool]] = None,
 ) -> list[list[torch.fx.Node]]:
-    if not gm.graph.find_nodes(
-        op="call_function", target=torch.ops._c10d_functional.wait_tensor.default
-    ):
-        return []
+    """
+    Bucketing adjacent collectives with equal node_group_key.
+    We can not bucket non adjacent collectives,
+    as this will effectively change the order of collectives.
+    Reordering can lead to different order on different ranks.
+    """
     g = gm.graph
+    found_candidates = False
+    for node in g.nodes:
+        if filter_node(node):
+            found_candidates = True
+            break
+    if not found_candidates:
+        return []
+
     # TODO: pearce kelly algorithm for detecting cycles
     node_descendents = collect_node_descendants(gm.graph)
 
-    node_groups: dict[Any, list[torch.fx.Node]] = collections.defaultdict(list)
+    nodes_groups: list[list[torch.fx.Node]] = []
+    cur_group: list[torch.fx.Node] = []
+    cur_group_key = None
+
     for node in g.nodes:
         if is_wait_tensor(node) and filter_node(node.args[0]):
             if (filter_wait_node is None) or filter_wait_node(node):
                 coll_node = node.args[0]
                 group_key = node_group_key(coll_node)
-                node_groups[group_key].append(coll_node)
+                if group_key == cur_group_key:
+                    cur_group.append(coll_node)
+                else:
+                    if len(cur_group) > 1:
+                        nodes_groups.append(cur_group)
+                    cur_group = [coll_node]
+                    cur_group_key = group_key
+
+    if len(cur_group) > 1:
+        nodes_groups.append(cur_group)
 
     buckets: list[list[torch.fx.Node]] = []
-    for nodes in node_groups.values():
+    for nodes in nodes_groups:
         cur_bucket: list[torch.fx.Node] = []
         cur_bucket_descendents: OrderedSet[torch.fx.Node] = OrderedSet()
         cur_bucket_size_bytes: int = 0
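
Per the new docstring, only runs of adjacent collectives with the same key are merged, because bucketing non-adjacent collectives would reorder them, and reordering can diverge across ranks. A minimal standalone sketch of that run-grouping logic, with illustrative names and plain values standing in for FX nodes:

    # Consecutive items with an equal key form a group; only groups with
    # more than one element are kept as bucketing candidates.
    from typing import Any


    def group_adjacent_by_key(items: list[Any], key) -> list[list[Any]]:
        groups: list[list[Any]] = []
        cur_group: list[Any] = []
        cur_key = object()  # sentinel that never equals a real key
        for item in items:
            k = key(item)
            if k == cur_key:
                cur_group.append(item)
            else:
                if len(cur_group) > 1:
                    groups.append(cur_group)
                cur_group = [item]
                cur_key = k
        if len(cur_group) > 1:
            groups.append(cur_group)
        return groups


    # Items with equal keys separated by a different key are not merged.
    print(group_adjacent_by_key(
        ["a0", "a1", "b0", "a2", "a3", "a4"], key=lambda s: s[0]
    ))  # [['a0', 'a1'], ['a2', 'a3', 'a4']]
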
@@ -261,6 +281,52 @@ def bucket_reduce_scatter_by_mb(
     )
 
 
+@torch.library.custom_op("bucketing::_pre_bucket_reduce_scatter", mutates_args={})
+def _pre_bucket_reduce_scatter(
+    rs_ins: list[torch.Tensor],
+    group_size: int,
+) -> torch.Tensor:
+    rs_ins_flattened = [x.view(group_size, -1) for x in rs_ins]
+    new_rs_in = torch.cat(rs_ins_flattened, dim=1).flatten()
+    return new_rs_in
+
+
+def _pre_bucket_reduce_scatter_fake(
+    rs_ins: list[torch.Tensor],
+    group_size: int,
+) -> torch.Tensor:
+    out_numel = sum(rs_in.numel() for rs_in in rs_ins)
+    return torch.empty((out_numel,), device=rs_ins[0].device, dtype=rs_ins[0].dtype)
+
+
+_pre_bucket_reduce_scatter.register_fake(_pre_bucket_reduce_scatter_fake)
+
+
+def reduce_scatter_merge_fn_to_trace_custom_ops(
+    rs_ins: list[torch.Tensor],
+    group_size: int,
+    group_name: str,
+    reduce_op: str,
+    reduce_dtype: torch.dtype,  # type: ignore[name-defined]
+    device: torch.device,  # type: ignore[name-defined]
+) -> list[torch.Tensor]:  # type: ignore[no-untyped-def]
+    new_out_sizes = [(x.shape[0] // group_size,) + x.shape[1:] for x in rs_ins]
+    new_out_numels = [x.numel() // group_size for x in rs_ins]
+
+    new_rs_in = torch.ops.bucketing._pre_bucket_reduce_scatter(rs_ins, group_size)
+
+    # TODO - either use torch.cat or make sure inductor foreach codegen
+    # fires more reliably
+    new_rs_out = torch.ops.c10d_functional.wait_tensor(
+        torch.ops._c10d_functional.reduce_scatter_tensor.default(
+            new_rs_in, reduce_op, group_size, group_name
+        )
+    )
+    new_out_flat = new_rs_out.split(new_out_numels, 0)
+    new_outs = [x.view(s) for x, s in zip(new_out_flat, new_out_sizes)]
+    return new_outs
+
+
 def reduce_scatter_merge_fn_to_trace(
     rs_ins: list[torch.Tensor],
     group_size: int,
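
The pre-bucket copy is wrapped in torch.library.custom_op so Inductor treats it as an opaque eager call instead of generating one kernel per copy; the fake implementation only reports output metadata so the op can be traced. A self-contained sketch of the same registration pattern, using a hypothetical "mylib::concat_flatten" op rather than the ones in this diff:

    # Real impl runs eagerly; the fake impl gives shapes for tracing.
    import torch


    @torch.library.custom_op("mylib::concat_flatten", mutates_args=())
    def concat_flatten(xs: list[torch.Tensor]) -> torch.Tensor:
        return torch.cat([x.reshape(-1) for x in xs])


    @concat_flatten.register_fake
    def _(xs: list[torch.Tensor]) -> torch.Tensor:
        numel = sum(x.numel() for x in xs)
        return torch.empty((numel,), device=xs[0].device, dtype=xs[0].dtype)


    xs = [torch.randn(2, 3), torch.randn(4)]
    print(torch.ops.mylib.concat_flatten(xs).shape)  # torch.Size([10])
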
@@ -276,8 +342,6 @@ def reduce_scatter_merge_fn_to_trace(
 
     new_rs_in = torch.cat(rs_ins_flattened, dim=1).flatten()
 
-    # TODO - either use torch.cat or make sure inductor foreach codegen
-    # fires more reliably
     new_rs_out = torch.ops.c10d_functional.wait_tensor(
         torch.ops._c10d_functional.reduce_scatter_tensor.default(
             new_rs_in, reduce_op, group_size, group_name
@@ -288,6 +352,74 @@ def reduce_scatter_merge_fn_to_trace(
     return new_outs
 
 
+@torch.library.custom_op("bucketing::_pre_bucket_all_gather", mutates_args={})
+def _pre_bucket_all_gather(
+    ag_ins: list[torch.Tensor],
+    group_size: int,
+    group_name: str,
+    dtype: torch.dtype,  # type: ignore[name-defined]
+    rank: int,
+) -> torch.Tensor:
+    ins_split_sizes = [ag_in.numel() for ag_in in ag_ins]
+    ag_input_numel = sum(ins_split_sizes)
+    device = ag_ins[0].device
+    new_ag_out = torch.empty(ag_input_numel * group_size, dtype=dtype, device=device)
+    new_ag_in = new_ag_out.narrow(0, ag_input_numel * rank, ag_input_numel)
+    foreach_copy_dsts = torch.split(new_ag_in, ins_split_sizes)
+    ag_ins_flattened = [ag_in.reshape(-1) for ag_in in ag_ins]
+    torch._foreach_copy_(foreach_copy_dsts, ag_ins_flattened)
+    return new_ag_out
+
+
+def _pre_bucket_all_gather_fake(
+    ag_ins: list[torch.Tensor],
+    group_size: int,
+    group_name: str,
+    dtype: torch.dtype,  # type: ignore[name-defined]
+    rank: int,
+) -> torch.Tensor:
+    ins_split_sizes = [ag_in.numel() for ag_in in ag_ins]
+    ag_input_numel = sum(ins_split_sizes)
+    device = ag_ins[0].device
+    new_ag_out = torch.empty(ag_input_numel * group_size, dtype=dtype, device=device)
+    return new_ag_out
+
+
+_pre_bucket_all_gather.register_fake(_pre_bucket_all_gather_fake)
+
+
+def all_gather_merge_fn_to_trace_custom_ops(
+    ag_ins: list[torch.Tensor],
+    group_size: int,
+    group_name: str,
+    dtype: torch.dtype,  # type: ignore[name-defined]
+    rank: int,
+) -> list[torch.Tensor]:
+    ins_sizes = [ag_in.shape for ag_in in ag_ins]
+    ins_split_sizes = [ag_in.numel() for ag_in in ag_ins]
+    ag_input_numel = sum(ins_split_sizes)
+    new_ag_out = torch.ops.bucketing._pre_bucket_all_gather(
+        ag_ins, group_size, group_name, dtype, rank
+    )
+    new_ag_in = new_ag_out.narrow(0, ag_input_numel * rank, ag_input_numel)
+    wait_tensor = torch.ops.c10d_functional.wait_tensor(
+        torch.ops._c10d_functional.all_gather_into_tensor_out.default(
+            new_ag_in, group_size, group_name, out=new_ag_out
+        )
+    )
+    new_ag_out_reshaped = wait_tensor.reshape(group_size, -1)
+    outs = torch.split_with_sizes(
+        new_ag_out_reshaped,
+        ins_split_sizes,
+        dim=1,
+    )
+    outs_reshaped = [
+        o.reshape((shape[0] * group_size,) + shape[1:])
+        for o, shape in zip(outs, ins_sizes)
+    ]
+    return outs_reshaped
+
+
 def all_gather_merge_fn_to_trace(
     ag_ins: list[torch.Tensor],
     group_size: int,
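
The all-gather pre-bucket op avoids an extra copy by making this rank's input shard a narrow() view into the final output buffer, so all_gather_into_tensor_out writes every rank's data in place. A runnable sketch of that buffer layout, with illustrative sizes (two bucketed inputs of 6 and 4 elements, group_size=4, rank=1):

    # One flat output holds every rank's copy of both inputs.
    import torch

    group_size, rank = 4, 1
    ins_split_sizes = [6, 4]               # numels of the bucketed inputs
    ag_input_numel = sum(ins_split_sizes)  # 10 elements per rank

    new_ag_out = torch.empty(ag_input_numel * group_size)  # 40 elements
    # This rank's shard is a view into the output buffer, so the
    # collective can write with all_gather_into_tensor_out directly.
    new_ag_in = new_ag_out.narrow(0, ag_input_numel * rank, ag_input_numel)
    dsts = torch.split(new_ag_in, ins_split_sizes)  # views of sizes 6 and 4
    print([d.numel() for d in dsts], new_ag_in.storage_offset())  # [6, 4] 10
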
@@ -420,9 +552,17 @@ def _insert_fn_trace_before_node(  # type: ignore[no-untyped-def]
 
 
 def merge_reduce_scatter(
-    gm: torch.fx.GraphModule, rs_buckets: list[list[torch.fx.Node]]
+    gm: torch.fx.GraphModule,
+    rs_buckets: list[list[torch.fx.Node]],
+    mode: Optional[str] = None,
 ) -> None:
+    """
+    Merges specified buckets of reduce_scatter to joint reduce_scatter.
+    """
     with dynamo_timed("fx.bucketing.merge_reduce_scatter", log_pt2_compile_event=True):
+        rs_merge_fn = reduce_scatter_merge_fn_to_trace
+        if mode and "custom_ops" in mode:
+            rs_merge_fn = reduce_scatter_merge_fn_to_trace_custom_ops
         trace_structured(
             "artifact",
             metadata_fn=lambda: {
@@ -469,7 +609,7 @@ def merge_reduce_scatter(
 
         replacements = _insert_fn_trace_before_node(
             g,
-            reduce_scatter_merge_fn_to_trace,
+            rs_merge_fn,
             (
                 pytree.tree_map(lambda node: node.meta["val"], _rs_ins),
                 group_size,
@@ -501,7 +641,9 @@ def merge_reduce_scatter(
 
 
 def merge_all_gather(
-    gm: torch.fx.GraphModule, ag_buckets: list[list[torch.fx.Node]]
+    gm: torch.fx.GraphModule,
+    ag_buckets: list[list[torch.fx.Node]],
+    mode: Optional[str] = None,
 ) -> None:  # type: ignore[union-attr]
     """
     Merges specified buckets of all_gather to joint all_gather.
@@ -509,6 +651,10 @@ def merge_all_gather(
     with dynamo_timed("fx.bucketing.merge_all_gather", log_pt2_compile_event=True):
         from torch.distributed.distributed_c10d import _resolve_process_group
 
+        ag_merge_fn = all_gather_merge_fn_to_trace
+        if mode and "custom_ops" in mode:
+            ag_merge_fn = all_gather_merge_fn_to_trace_custom_ops
+
         trace_structured(
             "artifact",
             metadata_fn=lambda: {
@@ -519,6 +665,8 @@ def merge_all_gather(
         )
         n_buckets = len(ag_buckets)
 
+        ag_node_to_pre_nodes = defaultdict(list)
+
         ag_ins: list[list[torch.fx.Node]] = [[] for _ in range(n_buckets)]
         ag_waits: list[list[torch.fx.Node]] = [[] for _ in range(n_buckets)]
         for bucket_idx, ag_bucket in enumerate(ag_buckets):
@@ -537,6 +685,14 @@ def merge_all_gather(
                     and ag_node.meta["val"].dtype == dtype
                 )
                 ag_node_in = ag_node.args[0]
+                if (
+                    ag_node_in.op == "call_function"  # type: ignore[union-attr]
+                    and ag_node_in.target  # type: ignore[union-attr]
+                    == torch.ops.prims.convert_element_type.default  # type: ignore[union-attr]
+                    and len(ag_node_in.users) == 1  # type: ignore[union-attr]
+                ):
+                    ag_node_to_pre_nodes[ag_node].append(ag_node_in)
+                    ag_node_in = ag_node_in.args[0]  # type: ignore[union-attr]
                 ag_ins[bucket_idx].append(ag_node_in)  # type: ignore[union-attr, arg-type]
                 ag_waits[bucket_idx].append(wait_node)
@@ -558,7 +714,7 @@ def merge_all_gather(
 
         replacements = _insert_fn_trace_before_node(
             g,
-            all_gather_merge_fn_to_trace,
+            ag_merge_fn,
             (
                 pytree.tree_map(lambda node: node.meta["val"], _ag_ins),
                 group_size,
@@ -582,3 +738,5 @@ def merge_all_gather(
             for ag_n, wait_n in zip(ag_buckets[bucket_idx], _ag_waits):
                 g.erase_node(wait_n)
                 g.erase_node(ag_n)
+                for n in reversed(ag_node_to_pre_nodes[ag_n]):
+                    g.erase_node(n)  # type: ignore[arg-type]
@@ -56,6 +56,7 @@ def is_fsdp_reduce_scatter_wait(wait: torch.fx.Node) -> bool:
 def bucket_fsdp_all_gather(
     gm: torch.fx.GraphModule,
     bucket_cap_mb_by_bucket_idx: Optional[Callable[[int], float]] = None,
+    mode: Optional[str] = None,
 ) -> None:
     """
     Bucketing pass for SimpleFSDP all_gather ops.
@@ -79,12 +80,13 @@ def bucket_fsdp_all_gather(
     )
     if len(ag_buckets) == 0:
         return
-    merge_all_gather(gm, ag_buckets)
+    merge_all_gather(gm, ag_buckets, mode)
 
 
 def bucket_fsdp_reduce_scatter(
     gm: torch.fx.GraphModule,
     bucket_cap_mb_by_bucket_idx: Optional[Callable[[int], float]] = None,
+    mode: Optional[str] = None,
 ) -> None:
     """
     Bucketing pass for SimpleFSDP reduce_scatter ops.
@@ -109,4 +111,4 @@ def bucket_fsdp_reduce_scatter(
     )
     if len(rs_buckets) == 0:
         return
-    merge_reduce_scatter(gm, rs_buckets)
+    merge_reduce_scatter(gm, rs_buckets, mode)
@@ -204,13 +204,14 @@ def post_grad_passes(gm: torch.fx.GraphModule, is_inference: bool):
 
         p = (
             bucket_fsdp_reduce_scatter
-            if config.bucket_reduce_scatters_fx == "fsdp"
+            if "fsdp" in config.bucket_reduce_scatters_fx
            else bucket_reduce_scatter
         )
         GraphTransformObserver(gm, "bucket_reduce_scatters").apply_graph_pass(
             lambda graph: p(
                 graph.owning_module,
                 config.bucket_reduce_scatters_fx_bucket_size_determinator,
+                config.bucket_reduce_scatters_fx,
             )
         )
         collectives_bucketing = True
@@ -223,13 +224,14 @@ def post_grad_passes(gm: torch.fx.GraphModule, is_inference: bool):
 
         p = (
             bucket_fsdp_all_gather  # type: ignore[assignment]
-            if config.bucket_all_gathers_fx == "fsdp"
+            if "fsdp" in config.bucket_all_gathers_fx
             else bucket_all_gather
         )
         GraphTransformObserver(gm, "bucket_all_gathers").apply_graph_pass(
             lambda graph: p(
                 graph.owning_module,
                 config.bucket_all_gathers_fx_bucket_size_determinator,
+                config.bucket_all_gathers_fx,
             )
         )
         collectives_bucketing = True
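
The pass selection now uses substring checks, so the full mode string is forwarded and interpreted twice: "fsdp" anywhere in it routes to the FSDP pass, and "custom_ops" anywhere in it selects the eager custom-op merge functions. A small sketch of the implied dispatch; mode strings beyond the "all"/"all_custom_ops" values in the tests and the "fsdp" prefix in the config check are assumptions:

    # Illustrative dispatch table, not code from the diff.
    def describe(mode: str) -> str:
        pass_name = "bucket_fsdp_*" if "fsdp" in mode else "bucket_*"
        merge_fn = "*_custom_ops merge" if "custom_ops" in mode else "traced merge"
        return f"{mode}: {pass_name} with {merge_fn}"

    for mode in ("all", "all_custom_ops", "fsdp", "fsdp_custom_ops"):
        print(describe(mode))
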