Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 12:54:11 +08:00
[inductor][bucketing] Fx collectives bucketing of multiple dtypes (#162470)
Adds bucketing of multiple dtypes so they can be processed in one bucketed collective. The first target is bucketing bf16 and f32 together, but it can already be used with other dtypes. For now, multi-dtype bucketing is only supported in "custom_ops" mode; the non-custom_ops path needs additional work on the inductor side.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162470
Approved by: https://github.com/eellison
committed by PyTorch MergeBot
parent 1a34ff4e04
commit 7d87d7052e
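
The core trick behind multi-dtype bucketing is byte-level reinterpretation: every input is viewed as a common "bucket dtype" before being packed into one flat buffer, and each output is viewed back as its original dtype after the collective. A minimal standalone sketch of that mechanism (illustrative, not code from this PR):

    import torch

    # torch.Tensor.view(dtype) reinterprets the underlying bytes, so tensors
    # of different dtypes can share one buffer in a common bucket dtype.
    bf16 = torch.randn(4, dtype=torch.bfloat16)  # 4 elements * 2 bytes = 8 bytes
    f32 = torch.randn(4, dtype=torch.float32)    # 4 elements * 4 bytes = 16 bytes

    # Viewing as the narrowest dtype (bf16) keeps every byte count an exact
    # multiple of the bucket itemsize.
    as_bucket = [t.view(torch.bfloat16) for t in (bf16, f32)]
    print([t.numel() for t in as_bucket])  # [4, 8] -- the f32 tensor doubles in numel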
@@ -1,7 +1,8 @@
 import collections
 import logging
+import operator
 from collections import defaultdict
-from typing import Any, Callable
+from typing import Any, Callable, Literal, TypeAlias
 
 import torch
 import torch.distributed as dist
@@ -17,16 +18,24 @@ from torch.utils._ordered_set import OrderedSet
 logger: logging.Logger = logging.getLogger(__name__)
 logger.setLevel(logging.INFO)
 
+BucketMode: TypeAlias = Literal["default", "custom_ops", "custom_ops_multidtype"]
+
+
 # Helper functions moved to top for better organization
-def _ag_group_key(node: torch.fx.Node) -> tuple[str, torch.dtype]:
+def _ag_group_key(node: torch.fx.Node) -> tuple[str, torch.dtype]:  # type: ignore[name-defined]
     _, group_size, group_name = node.args
     dtype = node.meta["val"].dtype
     assert isinstance(group_name, str)
     return (group_name, dtype)
 
 
-def _rs_group_key(node: torch.fx.Node) -> tuple[str, str, torch.dtype]:
+def _ag_group_key_multidtype(node: torch.fx.Node) -> tuple[str]:
+    _, group_size, group_name = node.args
+    assert isinstance(group_name, str)
+    return (group_name,)
+
+
+def _rs_group_key(node: torch.fx.Node) -> tuple[str, str, torch.dtype]:  # type: ignore[name-defined]
     _, reduce_op, group_size, group_name = node.args
     dtype = node.meta["val"].dtype
     assert isinstance(group_name, str)
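
The two key functions above control which collectives may share a bucket. A hedged illustration with plain tuples standing in for fx.Node metadata: the single-dtype key separates bf16 from f32 within a process group, while the multidtype key drops the dtype component.

    import collections
    import torch

    # (group_name, dtype) pairs standing in for all_gather nodes.
    nodes = [("pg0", torch.bfloat16), ("pg0", torch.float32), ("pg1", torch.float32)]

    by_single = collections.defaultdict(list)  # keyed like _ag_group_key
    by_multi = collections.defaultdict(list)   # keyed like _ag_group_key_multidtype
    for group_name, dtype in nodes:
        by_single[(group_name, dtype)].append(dtype)
        by_multi[(group_name,)].append(dtype)

    print(len(by_single))  # 3 buckets: dtype splits pg0's collectives
    print(len(by_multi))   # 2 buckets: pg0's bf16 and f32 can merge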
@@ -53,6 +62,11 @@ def bucket_key(node: torch.fx.Node) -> object | None:
         return None
 
 
+def pick_bucket_dtype(dtypes: list[torch.dtype]) -> torch.dtype:  # type: ignore[name-defined]
+    assert len(dtypes) > 0
+    return min(dtypes, key=operator.attrgetter("itemsize"))
+
+
 def bucket_cap_mb_by_bucket_idx_default(bucket_id: int) -> float:
     """
     Determine the size of a bucket based on its ID.
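
pick_bucket_dtype selects the narrowest dtype in the bucket, which guarantees every input's byte size is an integer multiple of the bucket itemsize. A quick check of that behavior:

    import operator
    import torch

    # min by itemsize: bf16 (2 bytes) wins over f32 (4 bytes).
    print(min([torch.float32, torch.bfloat16], key=operator.attrgetter("itemsize")))
    # torch.bfloat16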
@@ -69,7 +83,7 @@ def bucket_cap_mb_by_bucket_idx_default(bucket_id: int) -> float:
 def bucket_all_gather(
     gm: torch.fx.GraphModule,
     bucket_cap_mb_by_bucket_idx: Callable[[int], float] | None = None,
-    mode: str | None = None,
+    mode: BucketMode = "default",
 ) -> None:
     if bucket_cap_mb_by_bucket_idx is None:
         from torch._inductor.fx_passes.bucketing import (
@@ -77,7 +91,7 @@ def bucket_all_gather(
         )
 
         bucket_cap_mb_by_bucket_idx = bucket_cap_mb_by_bucket_idx_default
-    ag_buckets = bucket_all_gather_by_mb(gm, bucket_cap_mb_by_bucket_idx)
+    ag_buckets = bucket_all_gather_by_mb(gm, bucket_cap_mb_by_bucket_idx, None, mode)
     if len(ag_buckets) == 0:
         return
     merge_all_gather(gm, ag_buckets, mode)
@@ -86,7 +100,7 @@ def bucket_all_gather(
 def bucket_reduce_scatter(
     gm: torch.fx.GraphModule,
     bucket_cap_mb_by_bucket_idx: Callable[[int], float] | None = None,
-    mode: str | None = None,
+    mode: BucketMode = "default",
 ) -> None:
     if bucket_cap_mb_by_bucket_idx is None:
         from torch._inductor.fx_passes.bucketing import (
@@ -94,7 +108,9 @@ def bucket_reduce_scatter(
         )
 
         bucket_cap_mb_by_bucket_idx = bucket_cap_mb_by_bucket_idx_default
-    rs_buckets = bucket_reduce_scatter_by_mb(gm, bucket_cap_mb_by_bucket_idx)
+    rs_buckets = bucket_reduce_scatter_by_mb(
+        gm, bucket_cap_mb_by_bucket_idx, None, mode
+    )
     if len(rs_buckets) == 0:
         return
     merge_reduce_scatter(gm, rs_buckets, mode)
@@ -252,6 +268,7 @@ def bucket_all_gather_by_mb(
     gm: torch.fx.GraphModule,
     bucket_cap_mb_by_bucket_idx: Callable[[int], float],
     filter_wait_node: Callable[[torch.fx.Node], bool] | None = None,
+    mode: BucketMode = "default",
 ) -> list[list[torch.fx.Node]]:
     """
     Identifies all all_gather nodes and groups them into buckets,
@@ -271,11 +288,15 @@ def bucket_all_gather_by_mb(
         list[list[torch.fx.Node]]: List of buckets, where each bucket is a list of all_gather nodes.
     """
 
+    group_key_fn = (
+        _ag_group_key_multidtype if mode and "multidtype" in mode else _ag_group_key
+    )
+
     return greedy_bucket_collective_by_mb(
         gm,
         bucket_cap_mb_by_bucket_idx,
         is_all_gather_into_tensor,
-        _ag_group_key,
+        group_key_fn,
         filter_wait_node,
     )
 
@@ -284,6 +305,7 @@ def bucket_reduce_scatter_by_mb(
     gm: torch.fx.GraphModule,
     bucket_cap_mb_by_bucket_idx: Callable[[int], float],
     filter_wait_node: Callable[[torch.fx.Node], bool] | None = None,
+    mode: BucketMode = "default",
 ) -> list[list[torch.fx.Node]]:
     """
     Identifies all reduce_scatter nodes and groups them into buckets,
@@ -301,6 +323,10 @@ def bucket_reduce_scatter_by_mb(
         list[list[torch.fx.Node]]: List of buckets, where each bucket is a list of reduce_scatter nodes.
     """
 
+    assert "multidtype" not in mode, (
+        "reduce scatter bucketing does not support multidtype"
+    )
+
     return greedy_bucket_collective_by_mb(
         gm,
         bucket_cap_mb_by_bucket_idx,
@@ -439,13 +465,17 @@ def _pre_bucket_all_gather(
     dtype: torch.dtype,  # type: ignore[name-defined]
     rank: int,
 ) -> torch.Tensor:
-    ins_split_sizes = [ag_in.numel() for ag_in in ag_ins]
+    ins_split_sizes_bytes = [ag_in.numel() * ag_in.element_size() for ag_in in ag_ins]
+    bucket_dtype_size_bytes = dtype.itemsize
+    ins_split_sizes = [
+        _bytes // bucket_dtype_size_bytes for _bytes in ins_split_sizes_bytes
+    ]
     ag_input_numel = sum(ins_split_sizes)
     device = ag_ins[0].device
     new_ag_out = torch.empty(ag_input_numel * group_size, dtype=dtype, device=device)
     new_ag_in = new_ag_out.narrow(0, ag_input_numel * rank, ag_input_numel)
     foreach_copy_dsts = torch.split(new_ag_in, ins_split_sizes)
-    ag_ins_flattened = [ag_in.reshape(-1) for ag_in in ag_ins]
+    ag_ins_flattened = [ag_in.reshape(-1).view(dtype) for ag_in in ag_ins]
     torch._foreach_copy_(foreach_copy_dsts, ag_ins_flattened)
     return new_ag_out
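
The split sizes are now computed in bytes and converted to bucket-dtype element counts, so mixed-dtype inputs line up in the shared staging buffer. A worked example of that arithmetic under assumed shapes (one bf16 and one f32 input, bf16 bucket):

    import torch

    dtype = torch.bfloat16  # bucket dtype, itemsize 2
    ag_ins = [torch.randn(4, dtype=torch.bfloat16),
              torch.randn(3, dtype=torch.float32)]

    ins_split_sizes_bytes = [t.numel() * t.element_size() for t in ag_ins]  # [8, 12]
    ins_split_sizes = [b // dtype.itemsize for b in ins_split_sizes_bytes]  # [4, 6]

    # Flattening and viewing as the bucket dtype yields exactly those sizes.
    flat = [t.reshape(-1).view(dtype) for t in ag_ins]
    assert [t.numel() for t in flat] == ins_split_sizes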
@@ -457,7 +487,11 @@ def _pre_bucket_all_gather_fake(
     dtype: torch.dtype,  # type: ignore[name-defined]
     rank: int,
 ) -> torch.Tensor:
-    ins_split_sizes = [ag_in.numel() for ag_in in ag_ins]
+    ins_split_sizes_bytes = [ag_in.numel() * ag_in.element_size() for ag_in in ag_ins]
+    bucket_dtype_size_bytes = dtype.itemsize
+    ins_split_sizes = [
+        _bytes // bucket_dtype_size_bytes for _bytes in ins_split_sizes_bytes
+    ]
     ag_input_numel = sum(ins_split_sizes)
     device = ag_ins[0].device
     new_ag_out = torch.empty(ag_input_numel * group_size, dtype=dtype, device=device)
@@ -468,14 +502,28 @@ _pre_bucket_all_gather.register_fake(_pre_bucket_all_gather_fake)
 
 
 def all_gather_merge_fn_to_trace_custom_ops(
-    ag_ins: list[torch.Tensor],
+    _ag_ins: list[torch.Tensor],
     group_size: int,
     group_name: str,
     dtype: torch.dtype,  # type: ignore[name-defined]
+    out_dtypes: list[torch.dtype],  # type: ignore[name-defined]
     rank: int,
 ) -> list[torch.Tensor]:
+    ag_ins = [
+        torch._prims.convert_element_type(_ag_in, out_dtype)
+        if _ag_in.dtype != out_dtype
+        else _ag_in
+        for _ag_in, out_dtype in zip(_ag_ins, out_dtypes)
+    ]
     ins_sizes = [ag_in.shape for ag_in in ag_ins]
-    ins_split_sizes = [ag_in.numel() for ag_in in ag_ins]
+    ins_split_sizes_bytes = [
+        ag_in.numel() * out_dtype.itemsize
+        for ag_in, out_dtype in zip(ag_ins, out_dtypes)
+    ]
+    bucket_dtype_size_bytes = dtype.itemsize
+    ins_split_sizes = [
+        _bytes // bucket_dtype_size_bytes for _bytes in ins_split_sizes_bytes
+    ]
     ag_input_numel = sum(ins_split_sizes)
     new_ag_out = torch.ops.bucketing._pre_bucket_all_gather(
         ag_ins, group_size, group_name, dtype, rank
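
Before packing, each input whose dtype differs from its declared output dtype is cast with convert_element_type (the else branch is a no-op), so the bytes placed in the bucket are already in the dtype the unbucketed output must carry. A small sketch of that branch in isolation:

    import torch

    _ag_ins = [torch.randn(2, dtype=torch.float32)]
    out_dtypes = [torch.bfloat16]  # illustrative: output expected in bf16
    ag_ins = [
        torch._prims.convert_element_type(t, d) if t.dtype != d else t
        for t, d in zip(_ag_ins, out_dtypes)
    ]
    print(ag_ins[0].dtype)  # torch.bfloat16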
@@ -487,14 +535,14 @@ def all_gather_merge_fn_to_trace_custom_ops(
         )
     )
     new_ag_out_reshaped = wait_tensor.reshape(group_size, -1)
-    outs = torch.split_with_sizes(
+    outs_bucket_dtype = torch.split_with_sizes(
         new_ag_out_reshaped,
         ins_split_sizes,
         dim=1,
     )
     outs_reshaped = [
-        o.reshape((shape[0] * group_size,) + shape[1:])
-        for o, shape in zip(outs, ins_sizes)
+        o.view(out_dtype).reshape((shape[0] * group_size,) + shape[1:])
+        for o, shape, out_dtype in zip(outs_bucket_dtype, ins_sizes, out_dtypes)
     ]
     return outs_reshaped
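
Unbucketing mirrors the packing: the gathered buffer is split by bucket-dtype element counts, and each slice is viewed back as its own output dtype before the final reshape. A round-trip sketch with assumed sizes (group_size=2, a bf16[4] and an f32[3] input, bf16 bucket):

    import torch

    group_size = 2
    ins_sizes = [(4,), (3,)]
    out_dtypes = [torch.bfloat16, torch.float32]
    ins_split_sizes = [4, 6]  # element counts in the bf16 bucket dtype

    gathered = torch.empty(group_size, sum(ins_split_sizes), dtype=torch.bfloat16)
    outs_bucket_dtype = torch.split_with_sizes(gathered, ins_split_sizes, dim=1)
    outs = [
        o.view(d).reshape((s[0] * group_size,) + s[1:])
        for o, s, d in zip(outs_bucket_dtype, ins_sizes, out_dtypes)
    ]
    print([(tuple(o.shape), o.dtype) for o in outs])
    # [((8,), torch.bfloat16), ((6,), torch.float32)]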
@@ -504,6 +552,7 @@ def all_gather_merge_fn_to_trace(
     group_size: int,
     group_name: str,
     dtype: torch.dtype,  # type: ignore[name-defined]
+    out_dtypes: list[torch.dtype],  # type: ignore[name-defined]
     rank: int,
 ) -> list[torch.Tensor]:
     ins_sizes = [ag_in.shape for ag_in in ag_ins]
@@ -538,6 +587,7 @@ def all_gather_merge_fn_to_trace_functional(
     group_size: int,
     group_name: str,
     dtype: torch.dtype,  # type: ignore[name-defined]
+    out_dtypes: list[torch.dtype],  # type: ignore[name-defined]
     rank: int,
     use_fsdp_ag_copy_in: bool = False,
 ) -> list[torch.Tensor]:
@@ -733,7 +783,7 @@ def process_collective_bucket(
 def merge_reduce_scatter_bucket(
     g: torch.fx.Graph,
     rs_nodes: list[torch.fx.Node],
-    mode: str | None = None,
+    mode: BucketMode = "default",
     insert_before: torch.fx.Node | None = None,
     wait_insertion_point: torch.fx.Node | None = None,
 ) -> tuple[list[torch.fx.Node], dict[torch.fx.Node, torch.fx.Node]]:
@@ -826,29 +876,27 @@ def merge_all_reduce_bucket(
 def merge_all_gather_bucket(
     g: torch.fx.Graph,
     ag_nodes: list[torch.fx.Node],
-    mode: str | None = None,
+    mode: BucketMode = "default",
     insert_before: torch.fx.Node | None = None,
     wait_insertion_point: torch.fx.Node | None = None,
 ) -> tuple[list[torch.fx.Node], dict[torch.fx.Node, torch.fx.Node]]:
     from torch.distributed.distributed_c10d import _resolve_process_group
 
     ag0 = ag_nodes[0]
     ag0_val = ag0.meta["val"]
     _, group_size, group_name = ag0.args
-    dtype = ag0_val.dtype
     assert isinstance(group_name, str)
+    _ag_dtypes: list[torch.dtype] = []  # type: ignore[name-defined]
 
     for n in ag_nodes:
-        assert (
-            n.args[1] == group_size
-            and n.args[2] == group_name
-            and n.meta["val"].dtype == dtype
-        )
+        assert n.args[1] == group_size and n.args[2] == group_name
+        _ag_dtypes.append(n.meta["val"].dtype)
+
+    bucket_dtype = pick_bucket_dtype(_ag_dtypes)
 
     # Choose merge function based on mode
     ag_merge_fn = all_gather_merge_fn_to_trace
-    if mode and "custom_ops" in mode:
-        ag_merge_fn = all_gather_merge_fn_to_trace_custom_ops
+    if mode is not None and "custom_ops" in mode:
+        ag_merge_fn = all_gather_merge_fn_to_trace_custom_ops  # type: ignore[assignment]
 
     # Process bucket with lazy input collection
     rank: int = dist.get_rank(_resolve_process_group(group_name))
@@ -858,7 +906,8 @@ def merge_all_gather_bucket(
         pytree.tree_map(lambda node: node.meta["val"], bucket_ins),
         group_size,
         group_name,
-        dtype,
+        bucket_dtype,
+        _ag_dtypes,
         rank,
     )
 
@@ -874,7 +923,7 @@ def merge_all_gather_bucket(
 def merge_reduce_scatter(
     gm: torch.fx.GraphModule,
     rs_buckets: list[list[torch.fx.Node]],
-    mode: str | None = None,
+    mode: BucketMode = "default",
 ) -> None:
     """
     Merges specified buckets of reduce_scatter to joint reduce_scatter.
@@ -898,7 +947,7 @@ def merge_reduce_scatter(
 def merge_all_gather(
     gm: torch.fx.GraphModule,
     ag_buckets: list[list[torch.fx.Node]],
-    mode: str | None = None,
+    mode: BucketMode = "default",
 ) -> None:
     """
     Merges specified buckets of all_gather to joint all_gather.