mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
This will add pyrefly to lint runner as a warning only - and allow us to collect feedback about the tool before switching to pyrefly as the main type checker. References the steps outlined here: : https://github.com/pytorch/pytorch/issues/163283: test plan: `lintrunner init` `lintrunner` confirm when pyrefly errors are present results look like: https://gist.github.com/maggiemoss/e6cb2d015dd1ded560ae1329098cf33f Pull Request resolved: https://github.com/pytorch/pytorch/pull/165179 Approved by: https://github.com/ezyang
969 lines
33 KiB
Python
969 lines
33 KiB
Python
import collections
|
|
import logging
|
|
import operator
|
|
from collections import defaultdict
|
|
from typing import Any, Callable, Literal, TypeAlias
|
|
|
|
import torch
|
|
import torch.distributed as dist
|
|
import torch.utils._pytree as pytree
|
|
from torch._dispatch.python import enable_python_dispatcher
|
|
from torch._dynamo.utils import detect_fake_mode
|
|
from torch._inductor.runtime.runtime_utils import dynamo_timed
|
|
from torch._logging import trace_structured
|
|
from torch.fx.experimental.proxy_tensor import make_fx
|
|
from torch.utils._ordered_set import OrderedSet
|
|
|
|
|
|
logger: logging.Logger = logging.getLogger(__name__)
|
|
logger.setLevel(logging.INFO)
|
|
|
|
BucketMode: TypeAlias = Literal["default", "custom_ops", "custom_ops_multidtype"]
|
|
|
|
|
|
# Helper functions moved to top for better organization
|
|
def _ag_group_key(node: torch.fx.Node) -> tuple[str, torch.dtype]: # type: ignore[name-defined]
|
|
_, group_size, group_name = node.args
|
|
dtype = node.meta["val"].dtype
|
|
assert isinstance(group_name, str)
|
|
return (group_name, dtype)
|
|
|
|
|
|
def _ag_group_key_multidtype(node: torch.fx.Node) -> tuple[str]:
|
|
_, group_size, group_name = node.args
|
|
assert isinstance(group_name, str)
|
|
return (group_name,)
|
|
|
|
|
|
def _rs_group_key(node: torch.fx.Node) -> tuple[str, str, torch.dtype]: # type: ignore[name-defined]
|
|
_, reduce_op, group_size, group_name = node.args
|
|
dtype = node.meta["val"].dtype
|
|
assert isinstance(group_name, str)
|
|
assert isinstance(reduce_op, str)
|
|
return (group_name, reduce_op, dtype)
|
|
|
|
|
|
def _ar_group_key(node: torch.fx.Node) -> tuple[str, str, torch.dtype]:
|
|
_, reduce_op, group_name = node.args
|
|
dtype = node.meta["val"].dtype
|
|
assert isinstance(group_name, str)
|
|
assert isinstance(reduce_op, str)
|
|
return (group_name, reduce_op, dtype)
|
|
|
|
|
|
def bucket_key(node: torch.fx.Node) -> object | None:
|
|
if is_all_gather_into_tensor(node):
|
|
return _ag_group_key(node)
|
|
elif is_reduce_scatter_tensor(node):
|
|
return _rs_group_key(node)
|
|
elif is_all_reduce_tensor(node):
|
|
return _ar_group_key(node)
|
|
else:
|
|
return None
|
|
|
|
|
|
def pick_bucket_dtype(dtypes: list[torch.dtype]) -> torch.dtype: # type: ignore[name-defined]
|
|
assert len(dtypes) > 0
|
|
return min(dtypes, key=operator.attrgetter("itemsize"))
|
|
|
|
|
|
def bucket_cap_mb_by_bucket_idx_default(bucket_id: int) -> float:
|
|
"""
|
|
Determine the size of a bucket based on its ID.
|
|
|
|
Args:
|
|
bucket_id (int): The ID of the bucket.
|
|
|
|
Returns:
|
|
float: The size of the bucket.
|
|
"""
|
|
return 2000.0
|
|
|
|
|
|
def bucket_all_gather(
|
|
gm: torch.fx.GraphModule,
|
|
bucket_cap_mb_by_bucket_idx: Callable[[int], float] | None = None,
|
|
mode: BucketMode = "default",
|
|
) -> None:
|
|
if bucket_cap_mb_by_bucket_idx is None:
|
|
from torch._inductor.fx_passes.bucketing import ( # pyrefly: ignore # missing-module-attribute
|
|
bucket_cap_mb_by_bucket_idx_default,
|
|
)
|
|
|
|
bucket_cap_mb_by_bucket_idx = bucket_cap_mb_by_bucket_idx_default
|
|
ag_buckets = bucket_all_gather_by_mb(gm, bucket_cap_mb_by_bucket_idx, None, mode)
|
|
if len(ag_buckets) == 0:
|
|
return
|
|
merge_all_gather(gm, ag_buckets, mode)
|
|
|
|
|
|
def bucket_reduce_scatter(
|
|
gm: torch.fx.GraphModule,
|
|
bucket_cap_mb_by_bucket_idx: Callable[[int], float] | None = None,
|
|
mode: BucketMode = "default",
|
|
) -> None:
|
|
if bucket_cap_mb_by_bucket_idx is None:
|
|
from torch._inductor.fx_passes.bucketing import ( # pyrefly: ignore # missing-module-attribute
|
|
bucket_cap_mb_by_bucket_idx_default,
|
|
)
|
|
|
|
bucket_cap_mb_by_bucket_idx = bucket_cap_mb_by_bucket_idx_default
|
|
rs_buckets = bucket_reduce_scatter_by_mb(
|
|
gm, bucket_cap_mb_by_bucket_idx, None, mode
|
|
)
|
|
if len(rs_buckets) == 0:
|
|
return
|
|
merge_reduce_scatter(gm, rs_buckets, mode)
|
|
|
|
|
|
def is_all_gather_into_tensor(node: torch.fx.Node) -> bool: # type: ignore[arg-type]
|
|
return (
|
|
node.op == "call_function"
|
|
and node.target == torch.ops._c10d_functional.all_gather_into_tensor.default
|
|
)
|
|
|
|
|
|
def is_reduce_scatter_tensor(node: torch.fx.Node) -> bool:
|
|
return (
|
|
node.op == "call_function"
|
|
and node.target == torch.ops._c10d_functional.reduce_scatter_tensor.default
|
|
)
|
|
|
|
|
|
def is_wait_tensor(node: torch.fx.Node) -> bool:
|
|
return (
|
|
node.op == "call_function"
|
|
and node.target == torch.ops._c10d_functional.wait_tensor.default
|
|
)
|
|
|
|
|
|
def is_all_reduce_tensor(node: torch.fx.Node) -> bool:
|
|
return (
|
|
node.op == "call_function"
|
|
and node.target == torch.ops._c10d_functional.all_reduce.default
|
|
)
|
|
|
|
|
|
def is_wait_tensor_from_all_gather_into_tensor(node: torch.fx.Node) -> bool:
|
|
return is_wait_tensor(node) and is_all_gather_into_tensor(node.args[0]) # type: ignore[arg-type]
|
|
|
|
|
|
def collect_node_descendants(
|
|
graph: torch.fx.Graph,
|
|
) -> dict[torch.fx.Node, OrderedSet[torch.fx.Node]]:
|
|
"""
|
|
Collects the descendants of each node in the graph.
|
|
Args:
|
|
graph (torch.fx.Graph): The graph to collect descendants from.
|
|
Returns:
|
|
dict[torch.fx.Node, OrderedSet[torch.fx.Node]]: A dictionary mapping each node to its descendants.
|
|
"""
|
|
node_descendants: dict[torch.fx.Node, OrderedSet[torch.fx.Node]] = (
|
|
collections.defaultdict(OrderedSet)
|
|
)
|
|
outdegree = collections.defaultdict(int)
|
|
queue = []
|
|
|
|
for node in graph.nodes:
|
|
n_outdegree = len(node.users)
|
|
if n_outdegree == 0:
|
|
queue.append(node)
|
|
else:
|
|
outdegree[node] = len(node.users)
|
|
|
|
while queue:
|
|
node = queue.pop()
|
|
for input_node in node.all_input_nodes:
|
|
node_descendants[input_node] |= node_descendants[node]
|
|
node_descendants[input_node].add(node)
|
|
outdegree[input_node] -= 1
|
|
|
|
if outdegree[input_node] == 0:
|
|
queue.append(input_node)
|
|
|
|
return node_descendants
|
|
|
|
|
|
def greedy_bucket_collective_by_mb(
|
|
gm: torch.fx.GraphModule,
|
|
bucket_cap_mb_by_bucket_idx: Callable[[int], float],
|
|
filter_node: Callable[[torch.fx.Node], bool],
|
|
node_group_key: Callable[[torch.fx.Node], Any],
|
|
filter_wait_node: Callable[[torch.fx.Node], bool] | None = None,
|
|
) -> list[list[torch.fx.Node]]:
|
|
"""
|
|
Bucketing adjacent collectives with equal node_group_key.
|
|
We can not bucket non adjacent collectives,
|
|
as this will effectively change the order of collectives.
|
|
Reordering can lead to different order on different ranks.
|
|
"""
|
|
g = gm.graph
|
|
found_candidates = False
|
|
for node in g.nodes:
|
|
if filter_node(node):
|
|
found_candidates = True
|
|
break
|
|
if not found_candidates:
|
|
return []
|
|
|
|
# TODO: pearce kelly algorithm for detecting cycles
|
|
node_descendents = collect_node_descendants(gm.graph)
|
|
|
|
nodes_groups: list[list[torch.fx.Node]] = []
|
|
cur_group: list[torch.fx.Node] = []
|
|
cur_group_key = None
|
|
|
|
for node in g.nodes:
|
|
if is_wait_tensor(node) and filter_node(node.args[0]):
|
|
if (filter_wait_node is None) or filter_wait_node(node):
|
|
coll_node = node.args[0]
|
|
group_key = node_group_key(coll_node)
|
|
if group_key == cur_group_key:
|
|
cur_group.append(coll_node)
|
|
else:
|
|
if len(cur_group) > 1:
|
|
nodes_groups.append(cur_group)
|
|
cur_group = [coll_node]
|
|
cur_group_key = group_key
|
|
|
|
if len(cur_group) > 1:
|
|
nodes_groups.append(cur_group)
|
|
|
|
buckets: list[list[torch.fx.Node]] = []
|
|
for nodes in nodes_groups:
|
|
cur_bucket: list[torch.fx.Node] = []
|
|
cur_bucket_descendents: OrderedSet[torch.fx.Node] = OrderedSet()
|
|
cur_bucket_size_bytes: int = 0
|
|
cur_bucket_id: int = 0
|
|
bucket_size_bytes = int(
|
|
bucket_cap_mb_by_bucket_idx(cur_bucket_id) * 1024 * 1024
|
|
)
|
|
for node in nodes:
|
|
if node in cur_bucket_descendents:
|
|
# if there is a path from node to the current bucket, we cannot horizontally fuse (bucket)
|
|
continue
|
|
assert "val" in node.meta
|
|
n_val = node.meta["val"]
|
|
out_size_bytes = n_val.numel() * n_val.element_size()
|
|
n_input_val = node.all_input_nodes[0].meta["val"]
|
|
in_size_bytes = n_input_val.numel() * n_input_val.element_size()
|
|
size_bytes = max(out_size_bytes, in_size_bytes)
|
|
if cur_bucket_size_bytes + size_bytes > bucket_size_bytes and cur_bucket:
|
|
# Current bucket is full, create new bucket
|
|
if len(cur_bucket) > 1:
|
|
buckets.append(cur_bucket)
|
|
cur_bucket = []
|
|
cur_bucket_size_bytes = 0
|
|
cur_bucket_id += 1
|
|
cur_bucket_descendents = OrderedSet()
|
|
cur_bucket_size_bytes += size_bytes
|
|
cur_bucket.append(node)
|
|
cur_bucket_descendents |= node_descendents[node]
|
|
if len(cur_bucket) > 1:
|
|
buckets.append(cur_bucket)
|
|
return buckets
|
|
|
|
|
|
def bucket_all_gather_by_mb(
|
|
gm: torch.fx.GraphModule,
|
|
bucket_cap_mb_by_bucket_idx: Callable[[int], float],
|
|
filter_wait_node: Callable[[torch.fx.Node], bool] | None = None,
|
|
mode: BucketMode = "default",
|
|
) -> list[list[torch.fx.Node]]:
|
|
"""
|
|
Identifies all all_gather nodes and groups them into buckets,
|
|
based on size limit `bucket_cap_mb_by_bucket_idx`.
|
|
|
|
Args:
|
|
gm (torch.fx.GraphModule): GraphModule where to bucket all_gathers.
|
|
bucket_cap_mb_by_bucket_idx (Callable[[int], float]): Callable to specify cap of the bucket
|
|
in megabytes by bucket idx. The idea of `bucket_cap_mb_by_bucket_idx` is to allow
|
|
to specify different sizes of the buckets at the start,
|
|
as first all_gather is usually exposed. Interface of bucket_cap_mb_by_bucket_idx
|
|
is `bucket_cap_mb_by_bucket_idx_default` function that is default value for `bucket_cap_mb_by_bucket_idx`.
|
|
filter_wait_node (Callable[[torch.fx.Node], bool] | None): If specified,
|
|
only all_gather nodes with wait_node that satisfy `filter_wait_node` will be bucketed.
|
|
|
|
Returns:
|
|
list[list[torch.fx.Node]]: List of buckets, where each bucket is a list of all_gather nodes.
|
|
"""
|
|
|
|
group_key_fn = (
|
|
_ag_group_key_multidtype if mode and "multidtype" in mode else _ag_group_key
|
|
)
|
|
|
|
return greedy_bucket_collective_by_mb(
|
|
gm,
|
|
bucket_cap_mb_by_bucket_idx,
|
|
is_all_gather_into_tensor,
|
|
group_key_fn,
|
|
filter_wait_node,
|
|
)
|
|
|
|
|
|
def bucket_reduce_scatter_by_mb(
|
|
gm: torch.fx.GraphModule,
|
|
bucket_cap_mb_by_bucket_idx: Callable[[int], float],
|
|
filter_wait_node: Callable[[torch.fx.Node], bool] | None = None,
|
|
mode: BucketMode = "default",
|
|
) -> list[list[torch.fx.Node]]:
|
|
"""
|
|
Identifies all reduce_scatter nodes and groups them into buckets,
|
|
based on size limit `bucket_cap_mb_by_bucket_idx`.
|
|
|
|
Args:
|
|
gm (torch.fx.GraphModule): GraphModule where to bucket reduce_scatters.
|
|
bucket_cap_mb_by_bucket_idx (Callable[[int], float]): Callable to specify cap of the bucket
|
|
in megabytes by bucket idx. The idea of `bucket_cap_mb_by_bucket_idx` is to allow
|
|
to specify different sizes of the buckets.
|
|
filter_wait_node (Callable[[torch.fx.Node], bool] | None): If specified,
|
|
only reduce_scatter nodes with wait_node that satisfy `filter_wait_node` will be bucketed.
|
|
|
|
Returns:
|
|
list[list[torch.fx.Node]]: List of buckets, where each bucket is a list of reduce_scatter nodes.
|
|
"""
|
|
|
|
assert "multidtype" not in mode, (
|
|
"reduce scatter bucketing does not support multidtype"
|
|
)
|
|
|
|
return greedy_bucket_collective_by_mb(
|
|
gm,
|
|
bucket_cap_mb_by_bucket_idx,
|
|
is_reduce_scatter_tensor,
|
|
_rs_group_key,
|
|
filter_wait_node,
|
|
)
|
|
|
|
|
|
def bucket_all_reduce_by_mb(
|
|
gm: torch.fx.GraphModule,
|
|
bucket_cap_mb_by_bucket_idx: Callable[[int], float],
|
|
filter_wait_node: Callable[[torch.fx.Node], bool] | None = None,
|
|
) -> list[list[torch.fx.Node]]:
|
|
return greedy_bucket_collective_by_mb(
|
|
gm,
|
|
bucket_cap_mb_by_bucket_idx,
|
|
is_all_reduce_tensor,
|
|
_ar_group_key,
|
|
filter_wait_node,
|
|
)
|
|
|
|
|
|
def bucket_all_reduce(
|
|
gm: torch.fx.GraphModule,
|
|
bucket_cap_mb_by_bucket_idx: Callable[[int], float] | None = None,
|
|
mode: str | None = None,
|
|
) -> None:
|
|
if bucket_cap_mb_by_bucket_idx is None:
|
|
from torch._inductor.fx_passes.bucketing import (
|
|
bucket_cap_mb_by_bucket_idx_default,
|
|
)
|
|
|
|
bucket_cap_mb_by_bucket_idx = bucket_cap_mb_by_bucket_idx_default
|
|
ar_buckets = bucket_all_reduce_by_mb(gm, bucket_cap_mb_by_bucket_idx)
|
|
if len(ar_buckets) == 0:
|
|
return
|
|
for bucket in ar_buckets:
|
|
merge_all_reduce_bucket(gm.graph, bucket, mode)
|
|
|
|
|
|
@torch.library.custom_op("bucketing::_pre_bucket_reduce_scatter", mutates_args={})
|
|
def _pre_bucket_reduce_scatter(
|
|
rs_ins: list[torch.Tensor],
|
|
group_size: int,
|
|
) -> torch.Tensor:
|
|
rs_ins_flattened = [x.view(group_size, -1) for x in rs_ins]
|
|
new_rs_in = torch.cat(rs_ins_flattened, dim=1).flatten()
|
|
return new_rs_in
|
|
|
|
|
|
def _pre_bucket_reduce_scatter_fake(
|
|
rs_ins: list[torch.Tensor],
|
|
group_size: int,
|
|
) -> torch.Tensor:
|
|
out_numel = sum(rs_in.numel() for rs_in in rs_ins)
|
|
return torch.empty((out_numel,), device=rs_ins[0].device, dtype=rs_ins[0].dtype)
|
|
|
|
|
|
_pre_bucket_reduce_scatter.register_fake(_pre_bucket_reduce_scatter_fake)
|
|
|
|
|
|
def reduce_scatter_merge_fn_to_trace_custom_ops(
|
|
rs_ins: list[torch.Tensor],
|
|
group_size: int,
|
|
group_name: str,
|
|
reduce_op: str,
|
|
reduce_dtype: torch.dtype, # type: ignore[name-defined]
|
|
device: torch.device, # type: ignore[name-defined]
|
|
) -> list[torch.Tensor]: # type: ignore[no-untyped-def]
|
|
new_out_sizes = [(x.shape[0] // group_size,) + x.shape[1:] for x in rs_ins]
|
|
new_out_numels = [x.numel() // group_size for x in rs_ins]
|
|
|
|
new_rs_in = torch.ops.bucketing._pre_bucket_reduce_scatter(rs_ins, group_size)
|
|
|
|
# TODO - either use torch.cat or make sure inductor foreach codegen
|
|
# fires more reliably
|
|
new_rs_out = torch.ops.c10d_functional.wait_tensor(
|
|
torch.ops._c10d_functional.reduce_scatter_tensor.default(
|
|
new_rs_in, reduce_op, group_size, group_name
|
|
)
|
|
)
|
|
new_out_flat = new_rs_out.split(new_out_numels, 0)
|
|
new_outs = [x.view(s) for x, s in zip(new_out_flat, new_out_sizes)]
|
|
return new_outs
|
|
|
|
|
|
def reduce_scatter_merge_fn_to_trace(
|
|
rs_ins: list[torch.Tensor],
|
|
group_size: int,
|
|
group_name: str,
|
|
reduce_op: str,
|
|
reduce_dtype: torch.dtype, # type: ignore[name-defined]
|
|
device: torch.device, # type: ignore[name-defined]
|
|
) -> list[torch.Tensor]: # type: ignore[no-untyped-def]
|
|
rs_ins_flattened = [x.view(group_size, -1) for x in rs_ins]
|
|
|
|
new_out_sizes = [(x.shape[0] // group_size,) + x.shape[1:] for x in rs_ins]
|
|
new_out_numels = [x.numel() // group_size for x in rs_ins]
|
|
|
|
new_rs_in = torch.cat(rs_ins_flattened, dim=1).flatten()
|
|
|
|
new_rs_out = torch.ops.c10d_functional.wait_tensor(
|
|
torch.ops._c10d_functional.reduce_scatter_tensor.default(
|
|
new_rs_in, reduce_op, group_size, group_name
|
|
)
|
|
)
|
|
new_out_flat = new_rs_out.split(new_out_numels, 0)
|
|
new_outs = [x.view(s) for x, s in zip(new_out_flat, new_out_sizes)]
|
|
return new_outs
|
|
|
|
|
|
def all_reduce_merge_fn_to_trace(
|
|
ar_ins: list[torch.Tensor],
|
|
group_name: str,
|
|
reduce_op: str,
|
|
reduce_dtype: torch.dtype, # type: ignore[name-defined]
|
|
device: torch.device, # type: ignore[name-defined]
|
|
) -> list[torch.Tensor]: # type: ignore[no-untyped-def]
|
|
ar_ins_flattened = [x.view(-1) for x in ar_ins]
|
|
new_ar_in = torch.cat(ar_ins_flattened)
|
|
new_ar_out = torch.ops.c10d_functional.wait_tensor(
|
|
torch.ops._c10d_functional.all_reduce.default(new_ar_in, reduce_op, group_name)
|
|
)
|
|
split_sizes = [x.numel() for x in ar_ins]
|
|
new_outs_flat = new_ar_out.split(split_sizes)
|
|
new_outs = [x.view(ar_in.shape) for x, ar_in in zip(new_outs_flat, ar_ins)]
|
|
return new_outs
|
|
|
|
|
|
@torch.library.custom_op("bucketing::_pre_bucket_all_gather", mutates_args={})
|
|
def _pre_bucket_all_gather(
|
|
ag_ins: list[torch.Tensor],
|
|
group_size: int,
|
|
group_name: str,
|
|
dtype: torch.dtype, # type: ignore[name-defined]
|
|
rank: int,
|
|
) -> torch.Tensor:
|
|
ins_split_sizes_bytes = [ag_in.numel() * ag_in.element_size() for ag_in in ag_ins]
|
|
bucket_dtype_size_bytes = dtype.itemsize
|
|
ins_split_sizes = [
|
|
_bytes // bucket_dtype_size_bytes for _bytes in ins_split_sizes_bytes
|
|
]
|
|
ag_input_numel = sum(ins_split_sizes)
|
|
device = ag_ins[0].device
|
|
new_ag_out = torch.empty(ag_input_numel * group_size, dtype=dtype, device=device)
|
|
new_ag_in = new_ag_out.narrow(0, ag_input_numel * rank, ag_input_numel)
|
|
foreach_copy_dsts = torch.split(new_ag_in, ins_split_sizes)
|
|
ag_ins_flattened = [ag_in.reshape(-1).view(dtype) for ag_in in ag_ins]
|
|
torch._foreach_copy_(foreach_copy_dsts, ag_ins_flattened)
|
|
return new_ag_out
|
|
|
|
|
|
def _pre_bucket_all_gather_fake(
|
|
ag_ins: list[torch.Tensor],
|
|
group_size: int,
|
|
group_name: str,
|
|
dtype: torch.dtype, # type: ignore[name-defined]
|
|
rank: int,
|
|
) -> torch.Tensor:
|
|
ins_split_sizes_bytes = [ag_in.numel() * ag_in.element_size() for ag_in in ag_ins]
|
|
bucket_dtype_size_bytes = dtype.itemsize
|
|
ins_split_sizes = [
|
|
_bytes // bucket_dtype_size_bytes for _bytes in ins_split_sizes_bytes
|
|
]
|
|
ag_input_numel = sum(ins_split_sizes)
|
|
device = ag_ins[0].device
|
|
new_ag_out = torch.empty(ag_input_numel * group_size, dtype=dtype, device=device)
|
|
return new_ag_out
|
|
|
|
|
|
_pre_bucket_all_gather.register_fake(_pre_bucket_all_gather_fake)
|
|
|
|
|
|
def all_gather_merge_fn_to_trace_custom_ops(
|
|
_ag_ins: list[torch.Tensor],
|
|
group_size: int,
|
|
group_name: str,
|
|
dtype: torch.dtype, # type: ignore[name-defined]
|
|
out_dtypes: list[torch.dtype], # type: ignore[name-defined]
|
|
rank: int,
|
|
) -> list[torch.Tensor]:
|
|
ag_ins = [
|
|
torch._prims.convert_element_type(_ag_in, out_dtype)
|
|
if _ag_in.dtype != out_dtype
|
|
else _ag_in
|
|
for _ag_in, out_dtype in zip(_ag_ins, out_dtypes)
|
|
]
|
|
ins_sizes = [ag_in.shape for ag_in in ag_ins]
|
|
ins_split_sizes_bytes = [
|
|
ag_in.numel() * out_dtype.itemsize
|
|
for ag_in, out_dtype in zip(ag_ins, out_dtypes)
|
|
]
|
|
bucket_dtype_size_bytes = dtype.itemsize
|
|
ins_split_sizes = [
|
|
_bytes // bucket_dtype_size_bytes for _bytes in ins_split_sizes_bytes
|
|
]
|
|
ag_input_numel = sum(ins_split_sizes)
|
|
new_ag_out = torch.ops.bucketing._pre_bucket_all_gather(
|
|
ag_ins, group_size, group_name, dtype, rank
|
|
)
|
|
new_ag_in = new_ag_out.narrow(0, ag_input_numel * rank, ag_input_numel)
|
|
wait_tensor = torch.ops.c10d_functional.wait_tensor(
|
|
torch.ops._c10d_functional.all_gather_into_tensor_out.default(
|
|
new_ag_in, group_size, group_name, out=new_ag_out
|
|
)
|
|
)
|
|
new_ag_out_reshaped = wait_tensor.reshape(group_size, -1)
|
|
outs_bucket_dtype = torch.split_with_sizes(
|
|
new_ag_out_reshaped,
|
|
ins_split_sizes,
|
|
dim=1,
|
|
)
|
|
outs_reshaped = [
|
|
o.view(out_dtype).reshape((shape[0] * group_size,) + shape[1:])
|
|
for o, shape, out_dtype in zip(outs_bucket_dtype, ins_sizes, out_dtypes)
|
|
]
|
|
return outs_reshaped
|
|
|
|
|
|
def all_gather_merge_fn_to_trace(
|
|
ag_ins: list[torch.Tensor],
|
|
group_size: int,
|
|
group_name: str,
|
|
dtype: torch.dtype, # type: ignore[name-defined]
|
|
out_dtypes: list[torch.dtype], # type: ignore[name-defined]
|
|
rank: int,
|
|
) -> list[torch.Tensor]:
|
|
ins_sizes = [ag_in.shape for ag_in in ag_ins]
|
|
ins_split_sizes = [ag_in.numel() for ag_in in ag_ins]
|
|
ag_input_numel = sum(ins_split_sizes)
|
|
device = ag_ins[0].device
|
|
new_ag_out = torch.empty(ag_input_numel * group_size, dtype=dtype, device=device)
|
|
new_ag_in = new_ag_out.narrow(0, ag_input_numel * rank, ag_input_numel)
|
|
foreach_copy_dsts = torch.split(new_ag_in, ins_split_sizes)
|
|
ag_ins_flattened = [ag_in.reshape(-1) for ag_in in ag_ins]
|
|
torch._foreach_copy_(foreach_copy_dsts, ag_ins_flattened)
|
|
wait_tensor = torch.ops.c10d_functional.wait_tensor(
|
|
torch.ops._c10d_functional.all_gather_into_tensor_out.default(
|
|
new_ag_in, group_size, group_name, out=new_ag_out
|
|
)
|
|
)
|
|
new_ag_out_reshaped = wait_tensor.reshape(group_size, -1)
|
|
outs = torch.split_with_sizes(
|
|
new_ag_out_reshaped,
|
|
ins_split_sizes,
|
|
dim=1,
|
|
)
|
|
outs_reshaped = [
|
|
o.reshape((shape[0] * group_size,) + shape[1:])
|
|
for o, shape in zip(outs, ins_sizes)
|
|
]
|
|
return outs_reshaped
|
|
|
|
|
|
def all_gather_merge_fn_to_trace_functional(
|
|
ag_ins: list[torch.Tensor],
|
|
group_size: int,
|
|
group_name: str,
|
|
dtype: torch.dtype, # type: ignore[name-defined]
|
|
out_dtypes: list[torch.dtype], # type: ignore[name-defined]
|
|
rank: int,
|
|
use_fsdp_ag_copy_in: bool = False,
|
|
) -> list[torch.Tensor]:
|
|
# Implementation that is functional in graph,
|
|
# but uses custom op torch.ops.fsdp.all_gather_copy_in.
|
|
ins_sizes = [ag_in.shape for ag_in in ag_ins]
|
|
ins_split_sizes = [ag_in.numel() for ag_in in ag_ins]
|
|
ag_input_numel = sum(ins_split_sizes)
|
|
device = ag_ins[0].device
|
|
new_ag_out = torch.empty(ag_input_numel * group_size, dtype=dtype, device=device)
|
|
ag_ins_flattened = [ag_in.reshape(-1) for ag_in in ag_ins]
|
|
if use_fsdp_ag_copy_in:
|
|
new_ag_in, new_ag_out = torch.ops.fsdp.all_gather_copy_in(
|
|
ag_ins_flattened, new_ag_out, ins_split_sizes, ag_input_numel, rank
|
|
)
|
|
else:
|
|
new_ag_in = torch.cat(ag_ins_flattened, dim=0)
|
|
wait_tensor = torch.ops.c10d_functional.wait_tensor(
|
|
torch.ops._c10d_functional.all_gather_into_tensor_out.default(
|
|
new_ag_in, group_size, group_name, out=new_ag_out
|
|
)
|
|
)
|
|
new_ag_out_reshaped = wait_tensor.reshape(group_size, -1)
|
|
outs = torch.split_with_sizes(
|
|
new_ag_out_reshaped,
|
|
ins_split_sizes,
|
|
dim=1,
|
|
)
|
|
outs_reshaped = [
|
|
o.reshape((shape[0] * group_size,) + shape[1:])
|
|
for o, shape in zip(outs, ins_sizes)
|
|
]
|
|
return outs_reshaped
|
|
|
|
|
|
def _trace(fn, inps) -> torch.fx.GraphModule: # type: ignore[no-untyped-def]
|
|
with dynamo_timed("fx.bucketing._trace", log_pt2_compile_event=True):
|
|
fake_mode = detect_fake_mode(inps)
|
|
assert fake_mode is not None
|
|
with fake_mode, enable_python_dispatcher():
|
|
out = make_fx(fn)(*inps)
|
|
for node in out.graph.find_nodes(
|
|
op="call_function", target=torch.ops.aten.detach.default
|
|
):
|
|
node.replace_all_uses_with(node.args[0])
|
|
out.graph.erase_node(node)
|
|
return out
|
|
|
|
|
|
def _insert_fn_trace_before_node( # type: ignore[no-untyped-def]
|
|
g: torch.fx.Graph,
|
|
fn_to_trace,
|
|
inps,
|
|
insert_before_node: torch.fx.Node,
|
|
g_fn_inps: list[torch.fx.Node],
|
|
g_fn_outs: list[torch.fx.Node],
|
|
) -> tuple[dict[torch.fx.Node, torch.fx.Node], list[torch.fx.Node]]: # type: ignore[no-untyped-def]
|
|
"""
|
|
Helper function that traces :attr:`fn_to_trace` with inputs
|
|
:attr:`inps`.
|
|
The result function graph will be inserted before :attr:`insert_before_node`,
|
|
using :attr:`g_fn_inps` nodes of original graph as inputs of function graph,
|
|
function graph outputs will replace :attr:`g_fn_outs` in original graph.
|
|
|
|
Returns:
|
|
(replacements, new_nodes): Dictionary mapping old to new nodes, and list of all newly inserted nodes
|
|
"""
|
|
with dynamo_timed(
|
|
"fx.bucketing._insert_fn_trace_before_node", log_pt2_compile_event=True
|
|
):
|
|
fn_gm = _trace(
|
|
fn_to_trace,
|
|
inps,
|
|
)
|
|
fn_g = fn_gm.graph
|
|
fn_g_ins = fn_g.find_nodes(op="placeholder")
|
|
env = {fn_g_ins[idx]: g_fn_inps[idx] for idx in range(len(g_fn_inps))}
|
|
g_fn_new_outs: list[torch.fx.Node] = []
|
|
new_nodes: list[torch.fx.Node] = [] # Track all newly inserted nodes
|
|
|
|
with g.inserting_before(insert_before_node):
|
|
for _n in fn_g.nodes:
|
|
if _n.op == "placeholder":
|
|
continue
|
|
_new_n = g.node_copy(_n, lambda x: env[x])
|
|
env[_n] = _new_n
|
|
if _n.op == "output":
|
|
g_fn_new_outs = _new_n.args[0] # type: ignore[assignment]
|
|
g.erase_node(_new_n)
|
|
else:
|
|
new_nodes.append(_new_n) # Track non-output nodes
|
|
|
|
replacements = { # noqa: C416
|
|
orig_out: new_out for orig_out, new_out in zip(g_fn_outs, g_fn_new_outs)
|
|
}
|
|
for orig_out, new_out in zip(g_fn_outs, g_fn_new_outs):
|
|
orig_out.replace_all_uses_with(new_out)
|
|
|
|
return replacements, new_nodes
|
|
|
|
|
|
def process_collective_bucket(
|
|
g: torch.fx.Graph,
|
|
bucket_nodes: list[torch.fx.Node],
|
|
fn_to_trace: Callable[..., list[torch.Tensor]],
|
|
trace_args_fn: Callable[[list[torch.fx.Node]], tuple[Any, ...]],
|
|
insert_before: torch.fx.Node | None = None,
|
|
wait_insertion_point: torch.fx.Node | None = None,
|
|
) -> tuple[list[torch.fx.Node], dict[torch.fx.Node, torch.fx.Node]]:
|
|
"""
|
|
Process a single bucket of collective operation nodes with flexible insertion control.
|
|
|
|
Args:
|
|
g: The graph to modify
|
|
bucket_nodes: Nodes in the current bucket to process
|
|
fn_to_trace: Function to trace and insert
|
|
trace_args_fn: Function to create trace arguments from inputs
|
|
insert_before: Where to insert the traced function (default: after last bucket node)
|
|
wait_insertion_point: If provided, move all nodes from wait() onwards to before this node
|
|
|
|
Returns:
|
|
new_nodes: List of all newly inserted nodes
|
|
replacements: Dictionary mapping old wait nodes to new output nodes
|
|
"""
|
|
# Collect inputs and waits from current bucket
|
|
bucket_ins: list[torch.fx.Node] = []
|
|
bucket_waits: list[torch.fx.Node] = []
|
|
ag_node_to_pre_nodes: dict[torch.fx.Node, list[torch.fx.Node]] = defaultdict(list)
|
|
|
|
for n in bucket_nodes:
|
|
assert len(n.users) == 1, f"Expected single user for {n}, got {n.users}"
|
|
wait_n = next(iter(n.users))
|
|
|
|
# Handle convert_element_type operations (for all_gather)
|
|
node_in = n.args[0]
|
|
if (
|
|
is_all_gather_into_tensor(n)
|
|
and isinstance(node_in, torch.fx.Node) # Add type check
|
|
and node_in.op == "call_function"
|
|
and node_in.target == torch.ops.prims.convert_element_type.default
|
|
and len(node_in.users) == 1
|
|
):
|
|
ag_node_to_pre_nodes[n].append(node_in)
|
|
node_in = node_in.args[0]
|
|
|
|
assert isinstance(node_in, torch.fx.Node) # Ensure node_in is a Node
|
|
bucket_ins.append(node_in)
|
|
bucket_waits.append(wait_n)
|
|
|
|
# Create trace arguments
|
|
trace_args = trace_args_fn(bucket_ins)
|
|
|
|
# Determine insertion point
|
|
if insert_before is None:
|
|
insert_before = bucket_nodes[-1].next
|
|
|
|
# Insert traced function and get replacements + new nodes
|
|
replacements, new_nodes = _insert_fn_trace_before_node(
|
|
g,
|
|
fn_to_trace,
|
|
trace_args,
|
|
insert_before,
|
|
bucket_ins,
|
|
bucket_waits,
|
|
)
|
|
|
|
# If requested, move wait nodes and everything after to specified location
|
|
if wait_insertion_point is not None:
|
|
# Find the first wait node in new_nodes
|
|
wait_start_idx = None
|
|
for i, node in enumerate(new_nodes):
|
|
if is_wait_tensor(node):
|
|
wait_start_idx = i
|
|
break
|
|
|
|
# Move all nodes from wait onwards (including the wait)
|
|
if wait_start_idx is not None:
|
|
nodes_to_move = new_nodes[wait_start_idx:]
|
|
for node in nodes_to_move:
|
|
wait_insertion_point.prepend(node)
|
|
|
|
# Erase old nodes
|
|
for node, wait_n in zip(bucket_nodes, bucket_waits):
|
|
g.erase_node(wait_n)
|
|
g.erase_node(node)
|
|
# Erase any convert_element_type nodes we tracked
|
|
for pre_node in reversed(ag_node_to_pre_nodes[node]):
|
|
g.erase_node(pre_node)
|
|
|
|
return new_nodes, replacements
|
|
|
|
|
|
def merge_reduce_scatter_bucket(
|
|
g: torch.fx.Graph,
|
|
rs_nodes: list[torch.fx.Node],
|
|
mode: BucketMode = "default",
|
|
insert_before: torch.fx.Node | None = None,
|
|
wait_insertion_point: torch.fx.Node | None = None,
|
|
) -> tuple[list[torch.fx.Node], dict[torch.fx.Node, torch.fx.Node]]:
|
|
# Validate bucket consistency
|
|
rs0 = rs_nodes[0]
|
|
rs0_val = rs0.meta["val"]
|
|
_, reduce_op, group_size, group_name = rs0.args
|
|
reduce_dtype = rs0_val.dtype
|
|
device = rs0_val.device
|
|
|
|
for n in rs_nodes:
|
|
rs_val = n.meta["val"]
|
|
assert (
|
|
n.args[1] == reduce_op
|
|
and n.args[2] == group_size
|
|
and n.args[3] == group_name
|
|
and rs_val.device == device
|
|
and rs_val.dtype == reduce_dtype
|
|
)
|
|
|
|
# Choose merge function based on mode
|
|
rs_merge_fn = reduce_scatter_merge_fn_to_trace
|
|
if mode and "custom_ops" in mode:
|
|
rs_merge_fn = reduce_scatter_merge_fn_to_trace_custom_ops
|
|
|
|
# Process bucket with lazy input collection
|
|
def create_trace_args(bucket_ins: list[torch.fx.Node]) -> tuple[Any, ...]:
|
|
return (
|
|
pytree.tree_map(lambda node: node.meta["val"], bucket_ins),
|
|
group_size,
|
|
group_name,
|
|
reduce_op,
|
|
reduce_dtype,
|
|
device,
|
|
)
|
|
|
|
return process_collective_bucket(
|
|
g,
|
|
rs_nodes,
|
|
rs_merge_fn,
|
|
create_trace_args,
|
|
insert_before=insert_before,
|
|
wait_insertion_point=wait_insertion_point,
|
|
)
|
|
|
|
|
|
def merge_all_reduce_bucket(
|
|
g: torch.fx.Graph,
|
|
ar_nodes: list[torch.fx.Node],
|
|
mode: str | None = None,
|
|
insert_before: torch.fx.Node | None = None,
|
|
wait_insertion_point: torch.fx.Node | None = None,
|
|
) -> tuple[list[torch.fx.Node], dict[torch.fx.Node, torch.fx.Node]]:
|
|
ar0 = ar_nodes[0]
|
|
ar0_val = ar0.meta["val"]
|
|
_, reduce_op, group_name = ar0.args
|
|
reduce_dtype = ar0_val.dtype
|
|
device = ar0_val.device
|
|
|
|
for n in ar_nodes:
|
|
ar_val = n.meta["val"]
|
|
assert (
|
|
n.args[1] == reduce_op
|
|
and n.args[2] == group_name
|
|
and ar_val.device == device
|
|
and ar_val.dtype == reduce_dtype
|
|
)
|
|
|
|
ar_merge_fn = all_reduce_merge_fn_to_trace
|
|
|
|
def create_trace_args(bucket_ins: list[torch.fx.Node]) -> tuple[Any, ...]:
|
|
return (
|
|
pytree.tree_map(lambda node: node.meta["val"], bucket_ins),
|
|
group_name,
|
|
reduce_op,
|
|
reduce_dtype,
|
|
device,
|
|
)
|
|
|
|
return process_collective_bucket(
|
|
g,
|
|
ar_nodes,
|
|
ar_merge_fn,
|
|
create_trace_args,
|
|
insert_before=insert_before,
|
|
wait_insertion_point=wait_insertion_point,
|
|
)
|
|
|
|
|
|
def merge_all_gather_bucket(
|
|
g: torch.fx.Graph,
|
|
ag_nodes: list[torch.fx.Node],
|
|
mode: BucketMode = "default",
|
|
insert_before: torch.fx.Node | None = None,
|
|
wait_insertion_point: torch.fx.Node | None = None,
|
|
) -> tuple[list[torch.fx.Node], dict[torch.fx.Node, torch.fx.Node]]:
|
|
from torch.distributed.distributed_c10d import _resolve_process_group
|
|
|
|
ag0 = ag_nodes[0]
|
|
_, group_size, group_name = ag0.args
|
|
assert isinstance(group_name, str)
|
|
_ag_dtypes: list[torch.dtype] = [] # type: ignore[name-defined]
|
|
|
|
for n in ag_nodes:
|
|
assert n.args[1] == group_size and n.args[2] == group_name
|
|
_ag_dtypes.append(n.meta["val"].dtype)
|
|
|
|
bucket_dtype = pick_bucket_dtype(_ag_dtypes)
|
|
|
|
# Choose merge function based on mode
|
|
ag_merge_fn = all_gather_merge_fn_to_trace
|
|
if mode is not None and "custom_ops" in mode:
|
|
ag_merge_fn = all_gather_merge_fn_to_trace_custom_ops # type: ignore[assignment]
|
|
|
|
# Process bucket with lazy input collection
|
|
rank: int = dist.get_rank(_resolve_process_group(group_name))
|
|
|
|
def create_trace_args(bucket_ins: list[torch.fx.Node]) -> tuple[Any, ...]:
|
|
return (
|
|
pytree.tree_map(lambda node: node.meta["val"], bucket_ins),
|
|
group_size,
|
|
group_name,
|
|
bucket_dtype,
|
|
_ag_dtypes,
|
|
rank,
|
|
)
|
|
|
|
return process_collective_bucket(
|
|
g,
|
|
ag_nodes,
|
|
ag_merge_fn,
|
|
create_trace_args,
|
|
wait_insertion_point=wait_insertion_point,
|
|
)
|
|
|
|
|
|
def merge_reduce_scatter(
|
|
gm: torch.fx.GraphModule,
|
|
rs_buckets: list[list[torch.fx.Node]],
|
|
mode: BucketMode = "default",
|
|
) -> None:
|
|
"""
|
|
Merges specified buckets of reduce_scatter to joint reduce_scatter.
|
|
"""
|
|
with dynamo_timed("fx.bucketing.merge_reduce_scatter", log_pt2_compile_event=True):
|
|
trace_structured(
|
|
"artifact",
|
|
metadata_fn=lambda: {
|
|
"name": "fx_bucketing_passes_reduce_scatter_buckets",
|
|
"encoding": "string",
|
|
},
|
|
payload_fn=lambda: str(rs_buckets),
|
|
)
|
|
|
|
g = gm.graph
|
|
|
|
for rs_nodes in rs_buckets:
|
|
merge_reduce_scatter_bucket(g, rs_nodes, mode)
|
|
|
|
|
|
def merge_all_gather(
|
|
gm: torch.fx.GraphModule,
|
|
ag_buckets: list[list[torch.fx.Node]],
|
|
mode: BucketMode = "default",
|
|
) -> None:
|
|
"""
|
|
Merges specified buckets of all_gather to joint all_gather.
|
|
"""
|
|
with dynamo_timed("fx.bucketing.merge_all_gather", log_pt2_compile_event=True):
|
|
trace_structured(
|
|
"artifact",
|
|
metadata_fn=lambda: {
|
|
"name": "fx_bucketing_passes_all_gather_buckets",
|
|
"encoding": "string",
|
|
},
|
|
payload_fn=lambda: str(ag_buckets),
|
|
)
|
|
|
|
g = gm.graph
|
|
|
|
for ag_nodes in ag_buckets:
|
|
merge_all_gather_bucket(g, ag_nodes, mode)
|