Bucketing of multiple dtypes to be processed in one bucketed collective. The first target is to bucket bf16 and f32, but it can already be used with other dtypes. For now, multi-dtype bucketing is only supported in "custom_ops" mode; the non-custom_ops path needs additional work on the Inductor side. Pull Request resolved: https://github.com/pytorch/pytorch/pull/162470 Approved by: https://github.com/eellison
import collections
import logging
import operator
from collections import defaultdict
from typing import Any, Callable, Literal, TypeAlias

import torch
import torch.distributed as dist
import torch.utils._pytree as pytree
from torch._dispatch.python import enable_python_dispatcher
from torch._dynamo.utils import detect_fake_mode
from torch._inductor.runtime.runtime_utils import dynamo_timed
from torch._logging import trace_structured
from torch.fx.experimental.proxy_tensor import make_fx
from torch.utils._ordered_set import OrderedSet


logger: logging.Logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

BucketMode: TypeAlias = Literal["default", "custom_ops", "custom_ops_multidtype"]


# Helper functions moved to top for better organization
def _ag_group_key(node: torch.fx.Node) -> tuple[str, torch.dtype]:  # type: ignore[name-defined]
    _, group_size, group_name = node.args
    dtype = node.meta["val"].dtype
    assert isinstance(group_name, str)
    return (group_name, dtype)


def _ag_group_key_multidtype(node: torch.fx.Node) -> tuple[str]:
    _, group_size, group_name = node.args
    assert isinstance(group_name, str)
    return (group_name,)


def _rs_group_key(node: torch.fx.Node) -> tuple[str, str, torch.dtype]:  # type: ignore[name-defined]
    _, reduce_op, group_size, group_name = node.args
    dtype = node.meta["val"].dtype
    assert isinstance(group_name, str)
    assert isinstance(reduce_op, str)
    return (group_name, reduce_op, dtype)


def _ar_group_key(node: torch.fx.Node) -> tuple[str, str, torch.dtype]:
    _, reduce_op, group_name = node.args
    dtype = node.meta["val"].dtype
    assert isinstance(group_name, str)
    assert isinstance(reduce_op, str)
    return (group_name, reduce_op, dtype)


def bucket_key(node: torch.fx.Node) -> object | None:
    if is_all_gather_into_tensor(node):
        return _ag_group_key(node)
    elif is_reduce_scatter_tensor(node):
        return _rs_group_key(node)
    elif is_all_reduce_tensor(node):
        return _ar_group_key(node)
    else:
        return None


def pick_bucket_dtype(dtypes: list[torch.dtype]) -> torch.dtype:  # type: ignore[name-defined]
    assert len(dtypes) > 0
    return min(dtypes, key=operator.attrgetter("itemsize"))


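# Illustration (added commentary, not part of the original pass): with a mixed
# bf16/f32 bucket, pick_bucket_dtype returns the dtype with the smallest
# itemsize, so the flat bucket buffer is allocated in that dtype and wider
# inputs are reinterpreted into it byte-for-byte:
#
#   pick_bucket_dtype([torch.float32, torch.bfloat16])  # -> torch.bfloat16

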
def bucket_cap_mb_by_bucket_idx_default(bucket_id: int) -> float:
    """
    Determine the capacity (in megabytes) of a bucket based on its index.

    Args:
        bucket_id (int): The index of the bucket.

    Returns:
        float: The capacity of the bucket in megabytes.
    """
    return 2000.0


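# Hedged example (hypothetical helper, not part of the original pass): callers
# may supply their own cap schedule, e.g. a smaller first bucket because the
# first collective is usually exposed rather than overlapped with compute.
def _example_bucket_cap_mb_small_first_bucket(bucket_id: int) -> float:
    return 250.0 if bucket_id == 0 else 2000.0

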
def bucket_all_gather(
    gm: torch.fx.GraphModule,
    bucket_cap_mb_by_bucket_idx: Callable[[int], float] | None = None,
    mode: BucketMode = "default",
) -> None:
    if bucket_cap_mb_by_bucket_idx is None:
        from torch._inductor.fx_passes.bucketing import (
            bucket_cap_mb_by_bucket_idx_default,
        )

        bucket_cap_mb_by_bucket_idx = bucket_cap_mb_by_bucket_idx_default
    ag_buckets = bucket_all_gather_by_mb(gm, bucket_cap_mb_by_bucket_idx, None, mode)
    if len(ag_buckets) == 0:
        return
    merge_all_gather(gm, ag_buckets, mode)


def bucket_reduce_scatter(
    gm: torch.fx.GraphModule,
    bucket_cap_mb_by_bucket_idx: Callable[[int], float] | None = None,
    mode: BucketMode = "default",
) -> None:
    if bucket_cap_mb_by_bucket_idx is None:
        from torch._inductor.fx_passes.bucketing import (
            bucket_cap_mb_by_bucket_idx_default,
        )

        bucket_cap_mb_by_bucket_idx = bucket_cap_mb_by_bucket_idx_default
    rs_buckets = bucket_reduce_scatter_by_mb(
        gm, bucket_cap_mb_by_bucket_idx, None, mode
    )
    if len(rs_buckets) == 0:
        return
    merge_reduce_scatter(gm, rs_buckets, mode)


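# Hedged usage sketch (assumptions: `gm` already contains functional collectives
# such as _c10d_functional.all_gather_into_tensor plus their wait_tensor nodes,
# e.g. a graph captured by torch.compile with an initialized process group).
# Multi-dtype bucketing is currently only wired up for the custom-ops path,
# hence the modes below; this is an illustration, not the canonical entry point.
def _example_run_bucketing_passes(gm: torch.fx.GraphModule) -> None:
    # Bucket bf16 and f32 all_gathers of the same group into one bf16 bucket.
    bucket_all_gather(gm, mode="custom_ops_multidtype")
    # reduce_scatter bucketing remains per-dtype ("multidtype" is asserted out).
    bucket_reduce_scatter(gm, mode="custom_ops")
    gm.recompile()

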
def is_all_gather_into_tensor(node: torch.fx.Node) -> bool:  # type: ignore[arg-type]
    return (
        node.op == "call_function"
        and node.target == torch.ops._c10d_functional.all_gather_into_tensor.default
    )


def is_reduce_scatter_tensor(node: torch.fx.Node) -> bool:
    return (
        node.op == "call_function"
        and node.target == torch.ops._c10d_functional.reduce_scatter_tensor.default
    )


def is_wait_tensor(node: torch.fx.Node) -> bool:
    return (
        node.op == "call_function"
        and node.target == torch.ops._c10d_functional.wait_tensor.default
    )


def is_all_reduce_tensor(node: torch.fx.Node) -> bool:
    return (
        node.op == "call_function"
        and node.target == torch.ops._c10d_functional.all_reduce.default
    )


def is_wait_tensor_from_all_gather_into_tensor(node: torch.fx.Node) -> bool:
    return is_wait_tensor(node) and is_all_gather_into_tensor(node.args[0])  # type: ignore[arg-type]


def collect_node_descendants(
    graph: torch.fx.Graph,
) -> dict[torch.fx.Node, OrderedSet[torch.fx.Node]]:
    """
    Collects the descendants of each node in the graph.

    Args:
        graph (torch.fx.Graph): The graph to collect descendants from.

    Returns:
        dict[torch.fx.Node, OrderedSet[torch.fx.Node]]: A dictionary mapping each node to its descendants.
    """
    node_descendants: dict[torch.fx.Node, OrderedSet[torch.fx.Node]] = (
        collections.defaultdict(OrderedSet)
    )
    outdegree = collections.defaultdict(int)
    queue = []

    for node in graph.nodes:
        n_outdegree = len(node.users)
        if n_outdegree == 0:
            queue.append(node)
        else:
            outdegree[node] = len(node.users)

    while queue:
        node = queue.pop()
        for input_node in node.all_input_nodes:
            node_descendants[input_node] |= node_descendants[node]
            node_descendants[input_node].add(node)
            outdegree[input_node] -= 1

            if outdegree[input_node] == 0:
                queue.append(input_node)

    return node_descendants


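# Minimal sketch (toy graph, added for illustration): descendants are the nodes
# reachable from a node by following use edges. For y = x + 1; z = y * 2, the
# descendants of the placeholder `x` are the add, the mul, and the output node.
def _example_collect_descendants() -> None:
    def f(x: torch.Tensor) -> torch.Tensor:
        return (x + 1) * 2

    graph = torch.fx.symbolic_trace(f).graph
    descendants = collect_node_descendants(graph)
    x = next(iter(graph.find_nodes(op="placeholder")))
    print([n.name for n in descendants[x]])

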
def greedy_bucket_collective_by_mb(
    gm: torch.fx.GraphModule,
    bucket_cap_mb_by_bucket_idx: Callable[[int], float],
    filter_node: Callable[[torch.fx.Node], bool],
    node_group_key: Callable[[torch.fx.Node], Any],
    filter_wait_node: Callable[[torch.fx.Node], bool] | None = None,
) -> list[list[torch.fx.Node]]:
    """
    Buckets adjacent collectives that share the same node_group_key.

    We cannot bucket non-adjacent collectives, as that would effectively
    reorder the collectives, and reordering can lead to a different order
    on different ranks.
    """
    g = gm.graph
    found_candidates = False
    for node in g.nodes:
        if filter_node(node):
            found_candidates = True
            break
    if not found_candidates:
        return []

    # TODO: pearce kelly algorithm for detecting cycles
    node_descendents = collect_node_descendants(gm.graph)

    nodes_groups: list[list[torch.fx.Node]] = []
    cur_group: list[torch.fx.Node] = []
    cur_group_key = None

    for node in g.nodes:
        if is_wait_tensor(node) and filter_node(node.args[0]):
            if (filter_wait_node is None) or filter_wait_node(node):
                coll_node = node.args[0]
                group_key = node_group_key(coll_node)
                if group_key == cur_group_key:
                    cur_group.append(coll_node)
                else:
                    if len(cur_group) > 1:
                        nodes_groups.append(cur_group)
                    cur_group = [coll_node]
                    cur_group_key = group_key

    if len(cur_group) > 1:
        nodes_groups.append(cur_group)

    buckets: list[list[torch.fx.Node]] = []
    for nodes in nodes_groups:
        cur_bucket: list[torch.fx.Node] = []
        cur_bucket_descendents: OrderedSet[torch.fx.Node] = OrderedSet()
        cur_bucket_size_bytes: int = 0
        cur_bucket_id: int = 0
        bucket_size_bytes = int(
            bucket_cap_mb_by_bucket_idx(cur_bucket_id) * 1024 * 1024
        )
        for node in nodes:
            if node in cur_bucket_descendents:
                # if there is a path from the node to the current bucket,
                # we cannot horizontally fuse (bucket) them
                continue
            assert "val" in node.meta
            n_val = node.meta["val"]
            out_size_bytes = n_val.numel() * n_val.element_size()
            n_input_val = node.all_input_nodes[0].meta["val"]
            in_size_bytes = n_input_val.numel() * n_input_val.element_size()
            size_bytes = max(out_size_bytes, in_size_bytes)
            if cur_bucket_size_bytes + size_bytes > bucket_size_bytes and cur_bucket:
                # Current bucket is full, create new bucket
                if len(cur_bucket) > 1:
                    buckets.append(cur_bucket)
                cur_bucket = []
                cur_bucket_size_bytes = 0
                cur_bucket_id += 1
                cur_bucket_descendents = OrderedSet()
            cur_bucket_size_bytes += size_bytes
            cur_bucket.append(node)
            cur_bucket_descendents |= node_descendents[node]
        if len(cur_bucket) > 1:
            buckets.append(cur_bucket)
    return buckets


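# Walkthrough (illustrative numbers, added commentary): with the default
# 2000 MB cap, adjacent same-key collectives of 600 MB, 900 MB and 700 MB are
# packed greedily; the first two fit (1500 MB), the third would exceed the cap
# and starts bucket 1. A collective that is already a descendant of the current
# bucket is skipped, since fusing it would create a cycle in the graph.

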
def bucket_all_gather_by_mb(
    gm: torch.fx.GraphModule,
    bucket_cap_mb_by_bucket_idx: Callable[[int], float],
    filter_wait_node: Callable[[torch.fx.Node], bool] | None = None,
    mode: BucketMode = "default",
) -> list[list[torch.fx.Node]]:
    """
    Identifies all all_gather nodes and groups them into buckets
    based on the size limit `bucket_cap_mb_by_bucket_idx`.

    Args:
        gm (torch.fx.GraphModule): GraphModule whose all_gathers will be bucketed.
        bucket_cap_mb_by_bucket_idx (Callable[[int], float]): Callable that returns the bucket
            capacity in megabytes for a given bucket index. This allows different sizes for
            different buckets, e.g. a smaller first bucket, as the first all_gather is usually
            exposed. `bucket_cap_mb_by_bucket_idx_default` is the default value and
            illustrates the expected interface.
        filter_wait_node (Callable[[torch.fx.Node], bool] | None): If specified,
            only all_gather nodes whose wait_node satisfies `filter_wait_node` will be bucketed.

    Returns:
        list[list[torch.fx.Node]]: List of buckets, where each bucket is a list of all_gather nodes.
    """

    group_key_fn = (
        _ag_group_key_multidtype if mode and "multidtype" in mode else _ag_group_key
    )

    return greedy_bucket_collective_by_mb(
        gm,
        bucket_cap_mb_by_bucket_idx,
        is_all_gather_into_tensor,
        group_key_fn,
        filter_wait_node,
    )


def bucket_reduce_scatter_by_mb(
    gm: torch.fx.GraphModule,
    bucket_cap_mb_by_bucket_idx: Callable[[int], float],
    filter_wait_node: Callable[[torch.fx.Node], bool] | None = None,
    mode: BucketMode = "default",
) -> list[list[torch.fx.Node]]:
    """
    Identifies all reduce_scatter nodes and groups them into buckets
    based on the size limit `bucket_cap_mb_by_bucket_idx`.

    Args:
        gm (torch.fx.GraphModule): GraphModule whose reduce_scatters will be bucketed.
        bucket_cap_mb_by_bucket_idx (Callable[[int], float]): Callable that returns the bucket
            capacity in megabytes for a given bucket index, allowing different sizes for
            different buckets.
        filter_wait_node (Callable[[torch.fx.Node], bool] | None): If specified,
            only reduce_scatter nodes whose wait_node satisfies `filter_wait_node` will be bucketed.

    Returns:
        list[list[torch.fx.Node]]: List of buckets, where each bucket is a list of reduce_scatter nodes.
    """

    assert "multidtype" not in mode, (
        "reduce_scatter bucketing does not support multidtype"
    )

    return greedy_bucket_collective_by_mb(
        gm,
        bucket_cap_mb_by_bucket_idx,
        is_reduce_scatter_tensor,
        _rs_group_key,
        filter_wait_node,
    )


def bucket_all_reduce_by_mb(
    gm: torch.fx.GraphModule,
    bucket_cap_mb_by_bucket_idx: Callable[[int], float],
    filter_wait_node: Callable[[torch.fx.Node], bool] | None = None,
) -> list[list[torch.fx.Node]]:
    return greedy_bucket_collective_by_mb(
        gm,
        bucket_cap_mb_by_bucket_idx,
        is_all_reduce_tensor,
        _ar_group_key,
        filter_wait_node,
    )


def bucket_all_reduce(
    gm: torch.fx.GraphModule,
    bucket_cap_mb_by_bucket_idx: Callable[[int], float] | None = None,
    mode: str | None = None,
) -> None:
    if bucket_cap_mb_by_bucket_idx is None:
        from torch._inductor.fx_passes.bucketing import (
            bucket_cap_mb_by_bucket_idx_default,
        )

        bucket_cap_mb_by_bucket_idx = bucket_cap_mb_by_bucket_idx_default
    ar_buckets = bucket_all_reduce_by_mb(gm, bucket_cap_mb_by_bucket_idx)
    if len(ar_buckets) == 0:
        return
    for bucket in ar_buckets:
        merge_all_reduce_bucket(gm.graph, bucket, mode)


@torch.library.custom_op("bucketing::_pre_bucket_reduce_scatter", mutates_args={})
def _pre_bucket_reduce_scatter(
    rs_ins: list[torch.Tensor],
    group_size: int,
) -> torch.Tensor:
    rs_ins_flattened = [x.view(group_size, -1) for x in rs_ins]
    new_rs_in = torch.cat(rs_ins_flattened, dim=1).flatten()
    return new_rs_in


def _pre_bucket_reduce_scatter_fake(
    rs_ins: list[torch.Tensor],
    group_size: int,
) -> torch.Tensor:
    out_numel = sum(rs_in.numel() for rs_in in rs_ins)
    return torch.empty((out_numel,), device=rs_ins[0].device, dtype=rs_ins[0].dtype)


_pre_bucket_reduce_scatter.register_fake(_pre_bucket_reduce_scatter_fake)


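# Note (added commentary): the copy-in is kept behind a single custom op so the
# "custom_ops" bucketing mode stays opaque to Inductor, and the registered fake
# implementation lets the op be traced under FakeTensorMode, where only shapes,
# dtypes and devices are propagated.

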
def reduce_scatter_merge_fn_to_trace_custom_ops(
    rs_ins: list[torch.Tensor],
    group_size: int,
    group_name: str,
    reduce_op: str,
    reduce_dtype: torch.dtype,  # type: ignore[name-defined]
    device: torch.device,  # type: ignore[name-defined]
) -> list[torch.Tensor]:  # type: ignore[no-untyped-def]
    new_out_sizes = [(x.shape[0] // group_size,) + x.shape[1:] for x in rs_ins]
    new_out_numels = [x.numel() // group_size for x in rs_ins]

    new_rs_in = torch.ops.bucketing._pre_bucket_reduce_scatter(rs_ins, group_size)

    # TODO - either use torch.cat or make sure inductor foreach codegen
    # fires more reliably
    new_rs_out = torch.ops.c10d_functional.wait_tensor(
        torch.ops._c10d_functional.reduce_scatter_tensor.default(
            new_rs_in, reduce_op, group_size, group_name
        )
    )
    new_out_flat = new_rs_out.split(new_out_numels, 0)
    new_outs = [x.view(s) for x, s in zip(new_out_flat, new_out_sizes)]
    return new_outs


def reduce_scatter_merge_fn_to_trace(
    rs_ins: list[torch.Tensor],
    group_size: int,
    group_name: str,
    reduce_op: str,
    reduce_dtype: torch.dtype,  # type: ignore[name-defined]
    device: torch.device,  # type: ignore[name-defined]
) -> list[torch.Tensor]:  # type: ignore[no-untyped-def]
    rs_ins_flattened = [x.view(group_size, -1) for x in rs_ins]

    new_out_sizes = [(x.shape[0] // group_size,) + x.shape[1:] for x in rs_ins]
    new_out_numels = [x.numel() // group_size for x in rs_ins]

    new_rs_in = torch.cat(rs_ins_flattened, dim=1).flatten()

    new_rs_out = torch.ops.c10d_functional.wait_tensor(
        torch.ops._c10d_functional.reduce_scatter_tensor.default(
            new_rs_in, reduce_op, group_size, group_name
        )
    )
    new_out_flat = new_rs_out.split(new_out_numels, 0)
    new_outs = [x.view(s) for x, s in zip(new_out_flat, new_out_sizes)]
    return new_outs


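# Size bookkeeping above, with made-up shapes (added commentary): for
# group_size=4, an input of shape (8, 16) reduce_scatters to an output of
# shape (2, 16), so its slice of the fused result holds 8 * 16 / 4 = 32
# elements before the final view back to (2, 16).

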
def all_reduce_merge_fn_to_trace(
    ar_ins: list[torch.Tensor],
    group_name: str,
    reduce_op: str,
    reduce_dtype: torch.dtype,  # type: ignore[name-defined]
    device: torch.device,  # type: ignore[name-defined]
) -> list[torch.Tensor]:  # type: ignore[no-untyped-def]
    ar_ins_flattened = [x.view(-1) for x in ar_ins]
    new_ar_in = torch.cat(ar_ins_flattened)
    new_ar_out = torch.ops.c10d_functional.wait_tensor(
        torch.ops._c10d_functional.all_reduce.default(new_ar_in, reduce_op, group_name)
    )
    split_sizes = [x.numel() for x in ar_ins]
    new_outs_flat = new_ar_out.split(split_sizes)
    new_outs = [x.view(ar_in.shape) for x, ar_in in zip(new_outs_flat, ar_ins)]
    return new_outs


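# Shape round-trip above, with made-up shapes (added commentary): inputs of
# shape (4, 8) and (16,) flatten to 32 and 16 elements, are concatenated into a
# single 48-element buffer for one all_reduce, then split back with sizes
# [32, 16] and viewed as the original shapes.

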
@torch.library.custom_op("bucketing::_pre_bucket_all_gather", mutates_args={})
def _pre_bucket_all_gather(
    ag_ins: list[torch.Tensor],
    group_size: int,
    group_name: str,
    dtype: torch.dtype,  # type: ignore[name-defined]
    rank: int,
) -> torch.Tensor:
    ins_split_sizes_bytes = [ag_in.numel() * ag_in.element_size() for ag_in in ag_ins]
    bucket_dtype_size_bytes = dtype.itemsize
    ins_split_sizes = [
        _bytes // bucket_dtype_size_bytes for _bytes in ins_split_sizes_bytes
    ]
    ag_input_numel = sum(ins_split_sizes)
    device = ag_ins[0].device
    new_ag_out = torch.empty(ag_input_numel * group_size, dtype=dtype, device=device)
    new_ag_in = new_ag_out.narrow(0, ag_input_numel * rank, ag_input_numel)
    foreach_copy_dsts = torch.split(new_ag_in, ins_split_sizes)
    ag_ins_flattened = [ag_in.reshape(-1).view(dtype) for ag_in in ag_ins]
    torch._foreach_copy_(foreach_copy_dsts, ag_ins_flattened)
    return new_ag_out


def _pre_bucket_all_gather_fake(
    ag_ins: list[torch.Tensor],
    group_size: int,
    group_name: str,
    dtype: torch.dtype,  # type: ignore[name-defined]
    rank: int,
) -> torch.Tensor:
    ins_split_sizes_bytes = [ag_in.numel() * ag_in.element_size() for ag_in in ag_ins]
    bucket_dtype_size_bytes = dtype.itemsize
    ins_split_sizes = [
        _bytes // bucket_dtype_size_bytes for _bytes in ins_split_sizes_bytes
    ]
    ag_input_numel = sum(ins_split_sizes)
    device = ag_ins[0].device
    new_ag_out = torch.empty(ag_input_numel * group_size, dtype=dtype, device=device)
    return new_ag_out


_pre_bucket_all_gather.register_fake(_pre_bucket_all_gather_fake)


def all_gather_merge_fn_to_trace_custom_ops(
    _ag_ins: list[torch.Tensor],
    group_size: int,
    group_name: str,
    dtype: torch.dtype,  # type: ignore[name-defined]
    out_dtypes: list[torch.dtype],  # type: ignore[name-defined]
    rank: int,
) -> list[torch.Tensor]:
    ag_ins = [
        torch._prims.convert_element_type(_ag_in, out_dtype)
        if _ag_in.dtype != out_dtype
        else _ag_in
        for _ag_in, out_dtype in zip(_ag_ins, out_dtypes)
    ]
    ins_sizes = [ag_in.shape for ag_in in ag_ins]
    ins_split_sizes_bytes = [
        ag_in.numel() * out_dtype.itemsize
        for ag_in, out_dtype in zip(ag_ins, out_dtypes)
    ]
    bucket_dtype_size_bytes = dtype.itemsize
    ins_split_sizes = [
        _bytes // bucket_dtype_size_bytes for _bytes in ins_split_sizes_bytes
    ]
    ag_input_numel = sum(ins_split_sizes)
    new_ag_out = torch.ops.bucketing._pre_bucket_all_gather(
        ag_ins, group_size, group_name, dtype, rank
    )
    new_ag_in = new_ag_out.narrow(0, ag_input_numel * rank, ag_input_numel)
    wait_tensor = torch.ops.c10d_functional.wait_tensor(
        torch.ops._c10d_functional.all_gather_into_tensor_out.default(
            new_ag_in, group_size, group_name, out=new_ag_out
        )
    )
    new_ag_out_reshaped = wait_tensor.reshape(group_size, -1)
    outs_bucket_dtype = torch.split_with_sizes(
        new_ag_out_reshaped,
        ins_split_sizes,
        dim=1,
    )
    outs_reshaped = [
        o.view(out_dtype).reshape((shape[0] * group_size,) + shape[1:])
        for o, shape, out_dtype in zip(outs_bucket_dtype, ins_sizes, out_dtypes)
    ]
    return outs_reshaped


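# Hedged sketch of the byte accounting above (hypothetical helper, illustrative
# sizes): with a bf16 bucket dtype (2 bytes), a float32 input of 10 elements
# occupies 10 * 4 = 40 bytes, i.e. 20 bf16 slots in the flat bucket, while a
# bfloat16 input of 10 elements occupies 10 slots; the gathered shards are then
# split with these counts and viewed back as their out_dtypes.
def _example_multidtype_split_sizes() -> list[int]:
    bucket_dtype = torch.bfloat16
    ins = [
        torch.empty(10, dtype=torch.float32),
        torch.empty(10, dtype=torch.bfloat16),
    ]
    return [t.numel() * t.element_size() // bucket_dtype.itemsize for t in ins]

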
def all_gather_merge_fn_to_trace(
    ag_ins: list[torch.Tensor],
    group_size: int,
    group_name: str,
    dtype: torch.dtype,  # type: ignore[name-defined]
    out_dtypes: list[torch.dtype],  # type: ignore[name-defined]
    rank: int,
) -> list[torch.Tensor]:
    ins_sizes = [ag_in.shape for ag_in in ag_ins]
    ins_split_sizes = [ag_in.numel() for ag_in in ag_ins]
    ag_input_numel = sum(ins_split_sizes)
    device = ag_ins[0].device
    new_ag_out = torch.empty(ag_input_numel * group_size, dtype=dtype, device=device)
    new_ag_in = new_ag_out.narrow(0, ag_input_numel * rank, ag_input_numel)
    foreach_copy_dsts = torch.split(new_ag_in, ins_split_sizes)
    ag_ins_flattened = [ag_in.reshape(-1) for ag_in in ag_ins]
    torch._foreach_copy_(foreach_copy_dsts, ag_ins_flattened)
    wait_tensor = torch.ops.c10d_functional.wait_tensor(
        torch.ops._c10d_functional.all_gather_into_tensor_out.default(
            new_ag_in, group_size, group_name, out=new_ag_out
        )
    )
    new_ag_out_reshaped = wait_tensor.reshape(group_size, -1)
    outs = torch.split_with_sizes(
        new_ag_out_reshaped,
        ins_split_sizes,
        dim=1,
    )
    outs_reshaped = [
        o.reshape((shape[0] * group_size,) + shape[1:])
        for o, shape in zip(outs, ins_sizes)
    ]
    return outs_reshaped


def all_gather_merge_fn_to_trace_functional(
    ag_ins: list[torch.Tensor],
    group_size: int,
    group_name: str,
    dtype: torch.dtype,  # type: ignore[name-defined]
    out_dtypes: list[torch.dtype],  # type: ignore[name-defined]
    rank: int,
    use_fsdp_ag_copy_in: bool = False,
) -> list[torch.Tensor]:
    # Implementation that is functional in graph,
    # but uses custom op torch.ops.fsdp.all_gather_copy_in.
    ins_sizes = [ag_in.shape for ag_in in ag_ins]
    ins_split_sizes = [ag_in.numel() for ag_in in ag_ins]
    ag_input_numel = sum(ins_split_sizes)
    device = ag_ins[0].device
    new_ag_out = torch.empty(ag_input_numel * group_size, dtype=dtype, device=device)
    ag_ins_flattened = [ag_in.reshape(-1) for ag_in in ag_ins]
    if use_fsdp_ag_copy_in:
        new_ag_in, new_ag_out = torch.ops.fsdp.all_gather_copy_in(
            ag_ins_flattened, new_ag_out, ins_split_sizes, ag_input_numel, rank
        )
    else:
        new_ag_in = torch.cat(ag_ins_flattened, dim=0)
    wait_tensor = torch.ops.c10d_functional.wait_tensor(
        torch.ops._c10d_functional.all_gather_into_tensor_out.default(
            new_ag_in, group_size, group_name, out=new_ag_out
        )
    )
    new_ag_out_reshaped = wait_tensor.reshape(group_size, -1)
    outs = torch.split_with_sizes(
        new_ag_out_reshaped,
        ins_split_sizes,
        dim=1,
    )
    outs_reshaped = [
        o.reshape((shape[0] * group_size,) + shape[1:])
        for o, shape in zip(outs, ins_sizes)
    ]
    return outs_reshaped


def _trace(fn, inps) -> torch.fx.GraphModule:  # type: ignore[no-untyped-def]
    with dynamo_timed("fx.bucketing._trace", log_pt2_compile_event=True):
        fake_mode = detect_fake_mode(inps)
        assert fake_mode is not None
        with fake_mode, enable_python_dispatcher():
            out = make_fx(fn)(*inps)
            for node in out.graph.find_nodes(
                op="call_function", target=torch.ops.aten.detach.default
            ):
                node.replace_all_uses_with(node.args[0])
                out.graph.erase_node(node)
            return out


def _insert_fn_trace_before_node(  # type: ignore[no-untyped-def]
    g: torch.fx.Graph,
    fn_to_trace,
    inps,
    insert_before_node: torch.fx.Node,
    g_fn_inps: list[torch.fx.Node],
    g_fn_outs: list[torch.fx.Node],
) -> tuple[dict[torch.fx.Node, torch.fx.Node], list[torch.fx.Node]]:  # type: ignore[no-untyped-def]
    """
    Helper function that traces :attr:`fn_to_trace` with inputs :attr:`inps`.

    The resulting function graph is inserted before :attr:`insert_before_node`,
    using the :attr:`g_fn_inps` nodes of the original graph as inputs of the function
    graph; the function graph outputs replace :attr:`g_fn_outs` in the original graph.

    Returns:
        (replacements, new_nodes): Dictionary mapping old to new nodes, and list of all newly inserted nodes
    """
    with dynamo_timed(
        "fx.bucketing._insert_fn_trace_before_node", log_pt2_compile_event=True
    ):
        fn_gm = _trace(
            fn_to_trace,
            inps,
        )
        fn_g = fn_gm.graph
        fn_g_ins = fn_g.find_nodes(op="placeholder")
        env = {fn_g_ins[idx]: g_fn_inps[idx] for idx in range(len(g_fn_inps))}
        g_fn_new_outs: list[torch.fx.Node] = []
        new_nodes: list[torch.fx.Node] = []  # Track all newly inserted nodes

        with g.inserting_before(insert_before_node):
            for _n in fn_g.nodes:
                if _n.op == "placeholder":
                    continue
                _new_n = g.node_copy(_n, lambda x: env[x])
                env[_n] = _new_n
                if _n.op == "output":
                    g_fn_new_outs = _new_n.args[0]  # type: ignore[assignment]
                    g.erase_node(_new_n)
                else:
                    new_nodes.append(_new_n)  # Track non-output nodes

        replacements = {  # noqa: C416
            orig_out: new_out for orig_out, new_out in zip(g_fn_outs, g_fn_new_outs)
        }
        for orig_out, new_out in zip(g_fn_outs, g_fn_new_outs):
            orig_out.replace_all_uses_with(new_out)

        return replacements, new_nodes


def process_collective_bucket(
    g: torch.fx.Graph,
    bucket_nodes: list[torch.fx.Node],
    fn_to_trace: Callable[..., list[torch.Tensor]],
    trace_args_fn: Callable[[list[torch.fx.Node]], tuple[Any, ...]],
    insert_before: torch.fx.Node | None = None,
    wait_insertion_point: torch.fx.Node | None = None,
) -> tuple[list[torch.fx.Node], dict[torch.fx.Node, torch.fx.Node]]:
    """
    Process a single bucket of collective operation nodes with flexible insertion control.

    Args:
        g: The graph to modify
        bucket_nodes: Nodes in the current bucket to process
        fn_to_trace: Function to trace and insert
        trace_args_fn: Function to create trace arguments from inputs
        insert_before: Where to insert the traced function (default: after last bucket node)
        wait_insertion_point: If provided, move all nodes from wait() onwards to before this node

    Returns:
        new_nodes: List of all newly inserted nodes
        replacements: Dictionary mapping old wait nodes to new output nodes
    """
    # Collect inputs and waits from current bucket
    bucket_ins: list[torch.fx.Node] = []
    bucket_waits: list[torch.fx.Node] = []
    ag_node_to_pre_nodes: dict[torch.fx.Node, list[torch.fx.Node]] = defaultdict(list)

    for n in bucket_nodes:
        assert len(n.users) == 1, f"Expected single user for {n}, got {n.users}"
        wait_n = next(iter(n.users))

        # Handle convert_element_type operations (for all_gather)
        node_in = n.args[0]
        if (
            is_all_gather_into_tensor(n)
            and isinstance(node_in, torch.fx.Node)  # Add type check
            and node_in.op == "call_function"
            and node_in.target == torch.ops.prims.convert_element_type.default
            and len(node_in.users) == 1
        ):
            ag_node_to_pre_nodes[n].append(node_in)
            node_in = node_in.args[0]

        assert isinstance(node_in, torch.fx.Node)  # Ensure node_in is a Node
        bucket_ins.append(node_in)
        bucket_waits.append(wait_n)

    # Create trace arguments
    trace_args = trace_args_fn(bucket_ins)

    # Determine insertion point
    if insert_before is None:
        insert_before = bucket_nodes[-1].next

    # Insert traced function and get replacements + new nodes
    replacements, new_nodes = _insert_fn_trace_before_node(
        g,
        fn_to_trace,
        trace_args,
        insert_before,
        bucket_ins,
        bucket_waits,
    )

    # If requested, move wait nodes and everything after to specified location
    if wait_insertion_point is not None:
        # Find the first wait node in new_nodes
        wait_start_idx = None
        for i, node in enumerate(new_nodes):
            if is_wait_tensor(node):
                wait_start_idx = i
                break

        # Move all nodes from wait onwards (including the wait)
        if wait_start_idx is not None:
            nodes_to_move = new_nodes[wait_start_idx:]
            for node in nodes_to_move:
                wait_insertion_point.prepend(node)

    # Erase old nodes
    for node, wait_n in zip(bucket_nodes, bucket_waits):
        g.erase_node(wait_n)
        g.erase_node(node)
        # Erase any convert_element_type nodes we tracked
        for pre_node in reversed(ag_node_to_pre_nodes[node]):
            g.erase_node(pre_node)

    return new_nodes, replacements


def merge_reduce_scatter_bucket(
    g: torch.fx.Graph,
    rs_nodes: list[torch.fx.Node],
    mode: BucketMode = "default",
    insert_before: torch.fx.Node | None = None,
    wait_insertion_point: torch.fx.Node | None = None,
) -> tuple[list[torch.fx.Node], dict[torch.fx.Node, torch.fx.Node]]:
    # Validate bucket consistency
    rs0 = rs_nodes[0]
    rs0_val = rs0.meta["val"]
    _, reduce_op, group_size, group_name = rs0.args
    reduce_dtype = rs0_val.dtype
    device = rs0_val.device

    for n in rs_nodes:
        rs_val = n.meta["val"]
        assert (
            n.args[1] == reduce_op
            and n.args[2] == group_size
            and n.args[3] == group_name
            and rs_val.device == device
            and rs_val.dtype == reduce_dtype
        )

    # Choose merge function based on mode
    rs_merge_fn = reduce_scatter_merge_fn_to_trace
    if mode and "custom_ops" in mode:
        rs_merge_fn = reduce_scatter_merge_fn_to_trace_custom_ops

    # Process bucket with lazy input collection
    def create_trace_args(bucket_ins: list[torch.fx.Node]) -> tuple[Any, ...]:
        return (
            pytree.tree_map(lambda node: node.meta["val"], bucket_ins),
            group_size,
            group_name,
            reduce_op,
            reduce_dtype,
            device,
        )

    return process_collective_bucket(
        g,
        rs_nodes,
        rs_merge_fn,
        create_trace_args,
        insert_before=insert_before,
        wait_insertion_point=wait_insertion_point,
    )


def merge_all_reduce_bucket(
    g: torch.fx.Graph,
    ar_nodes: list[torch.fx.Node],
    mode: str | None = None,
    insert_before: torch.fx.Node | None = None,
    wait_insertion_point: torch.fx.Node | None = None,
) -> tuple[list[torch.fx.Node], dict[torch.fx.Node, torch.fx.Node]]:
    ar0 = ar_nodes[0]
    ar0_val = ar0.meta["val"]
    _, reduce_op, group_name = ar0.args
    reduce_dtype = ar0_val.dtype
    device = ar0_val.device

    for n in ar_nodes:
        ar_val = n.meta["val"]
        assert (
            n.args[1] == reduce_op
            and n.args[2] == group_name
            and ar_val.device == device
            and ar_val.dtype == reduce_dtype
        )

    ar_merge_fn = all_reduce_merge_fn_to_trace

    def create_trace_args(bucket_ins: list[torch.fx.Node]) -> tuple[Any, ...]:
        return (
            pytree.tree_map(lambda node: node.meta["val"], bucket_ins),
            group_name,
            reduce_op,
            reduce_dtype,
            device,
        )

    return process_collective_bucket(
        g,
        ar_nodes,
        ar_merge_fn,
        create_trace_args,
        insert_before=insert_before,
        wait_insertion_point=wait_insertion_point,
    )


def merge_all_gather_bucket(
    g: torch.fx.Graph,
    ag_nodes: list[torch.fx.Node],
    mode: BucketMode = "default",
    insert_before: torch.fx.Node | None = None,
    wait_insertion_point: torch.fx.Node | None = None,
) -> tuple[list[torch.fx.Node], dict[torch.fx.Node, torch.fx.Node]]:
    from torch.distributed.distributed_c10d import _resolve_process_group

    ag0 = ag_nodes[0]
    _, group_size, group_name = ag0.args
    assert isinstance(group_name, str)
    _ag_dtypes: list[torch.dtype] = []  # type: ignore[name-defined]

    for n in ag_nodes:
        assert n.args[1] == group_size and n.args[2] == group_name
        _ag_dtypes.append(n.meta["val"].dtype)

    bucket_dtype = pick_bucket_dtype(_ag_dtypes)

    # Choose merge function based on mode
    ag_merge_fn = all_gather_merge_fn_to_trace
    if mode is not None and "custom_ops" in mode:
        ag_merge_fn = all_gather_merge_fn_to_trace_custom_ops  # type: ignore[assignment]

    # Process bucket with lazy input collection
    rank: int = dist.get_rank(_resolve_process_group(group_name))

    def create_trace_args(bucket_ins: list[torch.fx.Node]) -> tuple[Any, ...]:
        return (
            pytree.tree_map(lambda node: node.meta["val"], bucket_ins),
            group_size,
            group_name,
            bucket_dtype,
            _ag_dtypes,
            rank,
        )

    # Forward insert_before so the parameter is not silently ignored.
    return process_collective_bucket(
        g,
        ag_nodes,
        ag_merge_fn,
        create_trace_args,
        insert_before=insert_before,
        wait_insertion_point=wait_insertion_point,
    )


def merge_reduce_scatter(
    gm: torch.fx.GraphModule,
    rs_buckets: list[list[torch.fx.Node]],
    mode: BucketMode = "default",
) -> None:
    """
    Merges the specified buckets of reduce_scatter nodes into joint reduce_scatters.
    """
    with dynamo_timed("fx.bucketing.merge_reduce_scatter", log_pt2_compile_event=True):
        trace_structured(
            "artifact",
            metadata_fn=lambda: {
                "name": "fx_bucketing_passes_reduce_scatter_buckets",
                "encoding": "string",
            },
            payload_fn=lambda: str(rs_buckets),
        )

        g = gm.graph

        for rs_nodes in rs_buckets:
            merge_reduce_scatter_bucket(g, rs_nodes, mode)


def merge_all_gather(
    gm: torch.fx.GraphModule,
    ag_buckets: list[list[torch.fx.Node]],
    mode: BucketMode = "default",
) -> None:
    """
    Merges the specified buckets of all_gather nodes into joint all_gathers.
    """
    with dynamo_timed("fx.bucketing.merge_all_gather", log_pt2_compile_event=True):
        trace_structured(
            "artifact",
            metadata_fn=lambda: {
                "name": "fx_bucketing_passes_all_gather_buckets",
                "encoding": "string",
            },
            payload_fn=lambda: str(ag_buckets),
        )

        g = gm.graph

        for ag_nodes in ag_buckets:
            merge_all_gather_bucket(g, ag_nodes, mode)
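

# Hedged end-to-end sketch (assumed inputs, added for illustration): if the
# caller has already chosen which all_gather nodes belong together, the merge
# step can be driven directly with explicit buckets.
def _example_merge_explicit_buckets(
    gm: torch.fx.GraphModule, ag_buckets: list[list[torch.fx.Node]]
) -> None:
    merge_all_gather(gm, ag_buckets, mode="custom_ops_multidtype")
    gm.recompile()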