Fx collectives bucketing: add bucket all_reduce (#165351)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165351 Approved by: https://github.com/eellison
Committed by: PyTorch MergeBot
Parent: f06e669f6c
Commit: 9272437cde
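For context, the new pass can be invoked directly on a traced graph. A minimal usage sketch, with assumptions stated: `gm` is a hypothetical torch.fx.GraphModule whose nodes carry fake tensors in meta["val"] (as in Inductor's FX pipeline), and the trailing recompile() is this sketch's choice, not part of the PR:

import torch
from torch._inductor.fx_passes.bucketing import bucket_all_reduce

def apply_all_reduce_bucketing(gm: torch.fx.GraphModule) -> None:
    # Passing None makes the pass fall back to
    # bucket_cap_mb_by_bucket_idx_default (see bucket_all_reduce in the diff).
    bucket_all_reduce(gm, bucket_cap_mb_by_bucket_idx=None)
    gm.recompile()  # assumption: regenerate code after the in-place rewrite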
@@ -34,11 +34,21 @@ def _rs_group_key(node: torch.fx.Node) -> tuple[str, str, torch.dtype]:
    return (group_name, reduce_op, dtype)


def _ar_group_key(node: torch.fx.Node) -> tuple[str, str, torch.dtype]:
    _, reduce_op, group_name = node.args
    dtype = node.meta["val"].dtype
    assert isinstance(group_name, str)
    assert isinstance(reduce_op, str)
    return (group_name, reduce_op, dtype)


def bucket_key(node: torch.fx.Node) -> object | None:
    if is_all_gather_into_tensor(node):
        return _ag_group_key(node)
    elif is_reduce_scatter_tensor(node):
        return _rs_group_key(node)
    elif is_all_reduce_tensor(node):
        return _ar_group_key(node)
    else:
        return None
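_ar_group_key mirrors the existing all_gather/reduce_scatter key helpers: two all_reduce nodes can share a bucket only if they agree on process group, reduce op, and dtype. A purely illustrative sketch of the resulting grouping behavior (the literal group names and dtypes below are made up):

import torch

# Shape of the keys _ar_group_key produces: (group_name, reduce_op, dtype).
key_a = ("pg0", "sum", torch.float32)
key_b = ("pg0", "sum", torch.float32)
key_c = ("pg0", "sum", torch.bfloat16)

assert key_a == key_b  # same group/op/dtype: candidates for one bucket
assert key_a != key_c  # dtype differs: never bucketed together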
@@ -111,6 +121,13 @@ def is_wait_tensor(node: torch.fx.Node) -> bool:
    )


def is_all_reduce_tensor(node: torch.fx.Node) -> bool:
    return (
        node.op == "call_function"
        and node.target == torch.ops._c10d_functional.all_reduce.default
    )


def is_wait_tensor_from_all_gather_into_tensor(node: torch.fx.Node) -> bool:
    return is_wait_tensor(node) and is_all_gather_into_tensor(node.args[0])  # type: ignore[arg-type]
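The new predicate makes candidate discovery a one-liner. A hedged sketch of how a caller might gather all_reduce nodes the same way the bucketing pass filters them (assumes a GraphModule `gm`; the helper name is this sketch's, not the PR's):

import torch
from torch._inductor.fx_passes.bucketing import is_all_reduce_tensor

def collect_all_reduce_nodes(gm: torch.fx.GraphModule) -> list[torch.fx.Node]:
    # greedy_bucket_collective_by_mb receives is_all_reduce_tensor as its
    # node filter; this reproduces that selection step in isolation.
    return [n for n in gm.graph.nodes if is_all_reduce_tensor(n)]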
@@ -293,6 +310,38 @@ def bucket_reduce_scatter_by_mb(
    )


def bucket_all_reduce_by_mb(
    gm: torch.fx.GraphModule,
    bucket_cap_mb_by_bucket_idx: Callable[[int], float],
    filter_wait_node: Callable[[torch.fx.Node], bool] | None = None,
) -> list[list[torch.fx.Node]]:
    return greedy_bucket_collective_by_mb(
        gm,
        bucket_cap_mb_by_bucket_idx,
        is_all_reduce_tensor,
        _ar_group_key,
        filter_wait_node,
    )


def bucket_all_reduce(
    gm: torch.fx.GraphModule,
    bucket_cap_mb_by_bucket_idx: Callable[[int], float] | None = None,
    mode: str | None = None,
) -> None:
    if bucket_cap_mb_by_bucket_idx is None:
        from torch._inductor.fx_passes.bucketing import (
            bucket_cap_mb_by_bucket_idx_default,
        )

        bucket_cap_mb_by_bucket_idx = bucket_cap_mb_by_bucket_idx_default
    ar_buckets = bucket_all_reduce_by_mb(gm, bucket_cap_mb_by_bucket_idx)
    if len(ar_buckets) == 0:
        return
    for bucket in ar_buckets:
        merge_all_reduce_bucket(gm.graph, bucket, mode)


@torch.library.custom_op("bucketing::_pre_bucket_reduce_scatter", mutates_args={})
def _pre_bucket_reduce_scatter(
    rs_ins: list[torch.Tensor],
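bucket_cap_mb_by_bucket_idx maps a bucket index to a size cap in megabytes, so callers can size buckets per position (one common policy is a smaller first bucket so the first fused collective can be issued early). A hedged sketch of passing a custom cap; the 25/100 MB values are illustrative, not defaults from this PR:

from torch._inductor.fx_passes.bucketing import bucket_all_reduce_by_mb

def custom_cap_mb(bucket_idx: int) -> float:
    # Illustrative policy: small first bucket, larger caps afterwards.
    return 25.0 if bucket_idx == 0 else 100.0

ar_buckets = bucket_all_reduce_by_mb(gm, custom_cap_mb)  # gm: GraphModule with meta["val"]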
@@ -364,6 +413,24 @@ def reduce_scatter_merge_fn_to_trace(
    return new_outs


def all_reduce_merge_fn_to_trace(
    ar_ins: list[torch.Tensor],
    group_name: str,
    reduce_op: str,
    reduce_dtype: torch.dtype,  # type: ignore[name-defined]
    device: torch.device,  # type: ignore[name-defined]
) -> list[torch.Tensor]:  # type: ignore[no-untyped-def]
    ar_ins_flattened = [x.view(-1) for x in ar_ins]
    new_ar_in = torch.cat(ar_ins_flattened)
    new_ar_out = torch.ops.c10d_functional.wait_tensor(
        torch.ops._c10d_functional.all_reduce.default(new_ar_in, reduce_op, group_name)
    )
    split_sizes = [x.numel() for x in ar_ins]
    new_outs_flat = new_ar_out.split(split_sizes)
    new_outs = [x.view(ar_in.shape) for x, ar_in in zip(new_outs_flat, ar_ins)]
    return new_outs


@torch.library.custom_op("bucketing::_pre_bucket_all_gather", mutates_args={})
def _pre_bucket_all_gather(
    ag_ins: list[torch.Tensor],
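The merge function works because sum-style reductions are elementwise: reducing the concatenation of flattened inputs equals reducing each input separately. A self-contained, single-process sketch of that identity, with a plain tensor addition standing in for the collective (illustrative only; no process group involved):

import torch

# Two simulated "ranks", each contributing the same two bucket inputs by shape.
rank_inputs = [
    [torch.randn(2, 3), torch.randn(4)],  # rank 0
    [torch.randn(2, 3), torch.randn(4)],  # rank 1
]

# Unbucketed path: reduce each tensor separately.
per_tensor = [a + b for a, b in zip(*rank_inputs)]

# Bucketed path, mirroring all_reduce_merge_fn_to_trace: flatten, concat,
# reduce once, split by numel, view back to the original shapes.
flat = [torch.cat([x.view(-1) for x in ins]) for ins in rank_inputs]
reduced = flat[0] + flat[1]
split_sizes = [x.numel() for x in rank_inputs[0]]
bucketed = [
    out.view(ref.shape)
    for out, ref in zip(reduced.split(split_sizes), rank_inputs[0])
]

for a, b in zip(per_tensor, bucketed):
    torch.testing.assert_close(a, b)  # both paths produce identical results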
@@ -713,6 +780,49 @@ def merge_reduce_scatter_bucket(
    )


def merge_all_reduce_bucket(
    g: torch.fx.Graph,
    ar_nodes: list[torch.fx.Node],
    mode: str | None = None,
    insert_before: torch.fx.Node | None = None,
    wait_insertion_point: torch.fx.Node | None = None,
) -> tuple[list[torch.fx.Node], dict[torch.fx.Node, torch.fx.Node]]:
    ar0 = ar_nodes[0]
    ar0_val = ar0.meta["val"]
    _, reduce_op, group_name = ar0.args
    reduce_dtype = ar0_val.dtype
    device = ar0_val.device

    for n in ar_nodes:
        ar_val = n.meta["val"]
        assert (
            n.args[1] == reduce_op
            and n.args[2] == group_name
            and ar_val.device == device
            and ar_val.dtype == reduce_dtype
        )

    ar_merge_fn = all_reduce_merge_fn_to_trace

    def create_trace_args(bucket_ins: list[torch.fx.Node]) -> tuple[Any, ...]:
        return (
            pytree.tree_map(lambda node: node.meta["val"], bucket_ins),
            group_name,
            reduce_op,
            reduce_dtype,
            device,
        )

    return process_collective_bucket(
        g,
        ar_nodes,
        ar_merge_fn,
        create_trace_args,
        insert_before=insert_before,
        wait_insertion_point=wait_insertion_point,
    )


def merge_all_gather_bucket(
    g: torch.fx.Graph,
    ag_nodes: list[torch.fx.Node],
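Putting the pieces together, a hedged end-to-end sketch: trace a toy function with two functional all_reduce calls, the kind of graph this pass rewrites into one fused all_reduce over the concatenated inputs. The "default" group name is an assumption; executing the traced module would require an initialized process group, and the pass itself additionally needs fake-tensor meta["val"] on the nodes:

import torch
import torch.fx

def two_all_reduces(x: torch.Tensor, y: torch.Tensor):
    gx = torch.ops._c10d_functional.all_reduce.default(x, "sum", "default")
    gy = torch.ops._c10d_functional.all_reduce.default(y, "sum", "default")
    return (
        torch.ops.c10d_functional.wait_tensor(gx),
        torch.ops.c10d_functional.wait_tensor(gy),
    )

gm = torch.fx.symbolic_trace(two_all_reduces)
# After metadata propagation, bucket_all_reduce(gm) would merge the two
# all_reduce nodes into a single bucketed call (see merge_all_reduce_bucket).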