diff --git a/test/distributed/test_inductor_collectives.py b/test/distributed/test_inductor_collectives.py
index c9e4cbaa7558..62e5143d0622 100644
--- a/test/distributed/test_inductor_collectives.py
+++ b/test/distributed/test_inductor_collectives.py
@@ -1804,6 +1804,63 @@ class TestCollectivesInductor(DynamoDistributedSingleProcTestCase):
         correct = f(*inputs, **self.get_world_trs())
         assert same(out, correct), f"{out} va {correct}"
 
+    @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
+    @unittest.skipIf(not SM80OrLater, "bfloat16")
+    @parametrize("bucket_mode", ["all_custom_ops_multidtype"])
+    def test_all_gather_bucket_multidtype(self, bucket_mode):
+        def func(x, w, ag_0, ag_1, *, tag, ranks, group_size):
+            # do some unrelated matmuls
+            y = torch.mm(x, w)
+
+            group_name = (
+                torch.distributed.distributed_c10d._get_default_group().group_name
+            )
+
+            ag_0_w = torch.ops._c10d_functional.all_gather_into_tensor(
+                ag_0, group_size, group_name
+            )
+            ag_0_out = torch.ops.c10d_functional.wait_tensor(ag_0_w)
+            ag_0_out = ag_0_out * 2
+
+            ag_1_w = torch.ops._c10d_functional.all_gather_into_tensor(
+                ag_1, group_size, group_name
+            )
+
+            ag_1_out = torch.ops.c10d_functional.wait_tensor(ag_1_w)
+
+            return y, ag_0_out, ag_1_out
+
+        x = torch.ones(4, 384, device="cuda", dtype=torch.float32)
+        w = torch.ones(384, 512, device="cuda", dtype=torch.float32)
+        ag_0 = torch.ones(384, 512, device="cuda", dtype=torch.bfloat16)
+        ag_1 = torch.ones(384, 512, device="cuda", dtype=torch.float32)
+        inputs = [x, w, ag_0, ag_1]
+        correct = func(*inputs, **self.get_world_trs())
+
+        with torch._inductor.config.patch(
+            {
+                "bucket_all_gathers_fx": bucket_mode,
+                "reorder_for_compute_comm_overlap": False,
+            }
+        ):
+            compiled = torch.compile(func)
+            code = run_and_get_triton_code(compiled, *inputs, **self.get_world_trs())
+            (
+                FileCheck()
+                .check_count(
+                    "torch.ops._c10d_functional.all_gather_into_tensor_out.default(",
+                    count=1,
+                    exactly=True,
+                )
+                .run(code)
+            )
+            out = compiled(*inputs, **self.get_world_trs())
+            _, y_ag0, y_ag1 = out
+            assert y_ag0.dtype == ag_0.dtype
+            assert y_ag1.dtype == ag_1.dtype
+
+            assert same(out, correct), f"{out} vs {correct}"
+
     @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
     @unittest.skipIf(not SM80OrLater, "bfloat16")
     @parametrize("bucket_mode", ["all", "all_custom_ops"])
diff --git a/torch/_inductor/fx_passes/bucketing.py b/torch/_inductor/fx_passes/bucketing.py
index 7260a6dc203b..84d6bc5a1950 100644
--- a/torch/_inductor/fx_passes/bucketing.py
+++ b/torch/_inductor/fx_passes/bucketing.py
@@ -1,7 +1,8 @@
 import collections
 import logging
+import operator
 from collections import defaultdict
-from typing import Any, Callable
+from typing import Any, Callable, Literal, TypeAlias
 
 import torch
 import torch.distributed as dist
@@ -17,16 +18,24 @@ from torch.utils._ordered_set import OrderedSet
 logger: logging.Logger = logging.getLogger(__name__)
 logger.setLevel(logging.INFO)
 
+BucketMode: TypeAlias = Literal["default", "custom_ops", "custom_ops_multidtype"]
+
 
 # Helper functions moved to top for better organization
-def _ag_group_key(node: torch.fx.Node) -> tuple[str, torch.dtype]:
+def _ag_group_key(node: torch.fx.Node) -> tuple[str, torch.dtype]:  # type: ignore[name-defined]
     _, group_size, group_name = node.args
     dtype = node.meta["val"].dtype
     assert isinstance(group_name, str)
     return (group_name, dtype)
 
 
-def _rs_group_key(node: torch.fx.Node) -> tuple[str, str, torch.dtype]:
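+# Multidtype bucketing groups all_gathers by process group only, leaving dtype
+# out of the key so that inputs of different dtypes can share one bucket.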
+def _ag_group_key_multidtype(node: torch.fx.Node) -> tuple[str]:
+    _, group_size, group_name = node.args
+    assert isinstance(group_name, str)
+    return (group_name,)
+
+
+def _rs_group_key(node: torch.fx.Node) -> tuple[str, str, torch.dtype]:  # type: ignore[name-defined]
     _, reduce_op, group_size, group_name = node.args
     dtype = node.meta["val"].dtype
     assert isinstance(group_name, str)
@@ -53,6 +62,11 @@ def bucket_key(node: torch.fx.Node) -> object | None:
         return None
 
 
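+# The bucket uses the smallest itemsize among the input dtypes, so every
+# input's byte size stays an exact multiple of the bucket dtype's itemsize
+# and byte counts convert evenly into bucket-dtype element counts.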
""" + assert "multidtype" not in mode, ( + "reduce scatter bucketing does not support multidtype" + ) + return greedy_bucket_collective_by_mb( gm, bucket_cap_mb_by_bucket_idx, @@ -439,13 +465,17 @@ def _pre_bucket_all_gather( dtype: torch.dtype, # type: ignore[name-defined] rank: int, ) -> torch.Tensor: - ins_split_sizes = [ag_in.numel() for ag_in in ag_ins] + ins_split_sizes_bytes = [ag_in.numel() * ag_in.element_size() for ag_in in ag_ins] + bucket_dtype_size_bytes = dtype.itemsize + ins_split_sizes = [ + _bytes // bucket_dtype_size_bytes for _bytes in ins_split_sizes_bytes + ] ag_input_numel = sum(ins_split_sizes) device = ag_ins[0].device new_ag_out = torch.empty(ag_input_numel * group_size, dtype=dtype, device=device) new_ag_in = new_ag_out.narrow(0, ag_input_numel * rank, ag_input_numel) foreach_copy_dsts = torch.split(new_ag_in, ins_split_sizes) - ag_ins_flattened = [ag_in.reshape(-1) for ag_in in ag_ins] + ag_ins_flattened = [ag_in.reshape(-1).view(dtype) for ag_in in ag_ins] torch._foreach_copy_(foreach_copy_dsts, ag_ins_flattened) return new_ag_out @@ -457,7 +487,11 @@ def _pre_bucket_all_gather_fake( dtype: torch.dtype, # type: ignore[name-defined] rank: int, ) -> torch.Tensor: - ins_split_sizes = [ag_in.numel() for ag_in in ag_ins] + ins_split_sizes_bytes = [ag_in.numel() * ag_in.element_size() for ag_in in ag_ins] + bucket_dtype_size_bytes = dtype.itemsize + ins_split_sizes = [ + _bytes // bucket_dtype_size_bytes for _bytes in ins_split_sizes_bytes + ] ag_input_numel = sum(ins_split_sizes) device = ag_ins[0].device new_ag_out = torch.empty(ag_input_numel * group_size, dtype=dtype, device=device) @@ -468,14 +502,28 @@ _pre_bucket_all_gather.register_fake(_pre_bucket_all_gather_fake) def all_gather_merge_fn_to_trace_custom_ops( - ag_ins: list[torch.Tensor], + _ag_ins: list[torch.Tensor], group_size: int, group_name: str, dtype: torch.dtype, # type: ignore[name-defined] + out_dtypes: list[torch.dtype], # type: ignore[name-defined] rank: int, ) -> list[torch.Tensor]: + ag_ins = [ + torch._prims.convert_element_type(_ag_in, out_dtype) + if _ag_in.dtype != out_dtype + else _ag_in + for _ag_in, out_dtype in zip(_ag_ins, out_dtypes) + ] ins_sizes = [ag_in.shape for ag_in in ag_ins] - ins_split_sizes = [ag_in.numel() for ag_in in ag_ins] + ins_split_sizes_bytes = [ + ag_in.numel() * out_dtype.itemsize + for ag_in, out_dtype in zip(ag_ins, out_dtypes) + ] + bucket_dtype_size_bytes = dtype.itemsize + ins_split_sizes = [ + _bytes // bucket_dtype_size_bytes for _bytes in ins_split_sizes_bytes + ] ag_input_numel = sum(ins_split_sizes) new_ag_out = torch.ops.bucketing._pre_bucket_all_gather( ag_ins, group_size, group_name, dtype, rank @@ -487,14 +535,14 @@ def all_gather_merge_fn_to_trace_custom_ops( ) ) new_ag_out_reshaped = wait_tensor.reshape(group_size, -1) - outs = torch.split_with_sizes( + outs_bucket_dtype = torch.split_with_sizes( new_ag_out_reshaped, ins_split_sizes, dim=1, ) outs_reshaped = [ - o.reshape((shape[0] * group_size,) + shape[1:]) - for o, shape in zip(outs, ins_sizes) + o.view(out_dtype).reshape((shape[0] * group_size,) + shape[1:]) + for o, shape, out_dtype in zip(outs_bucket_dtype, ins_sizes, out_dtypes) ] return outs_reshaped @@ -504,6 +552,7 @@ def all_gather_merge_fn_to_trace( group_size: int, group_name: str, dtype: torch.dtype, # type: ignore[name-defined] + out_dtypes: list[torch.dtype], # type: ignore[name-defined] rank: int, ) -> list[torch.Tensor]: ins_sizes = [ag_in.shape for ag_in in ag_ins] @@ -538,6 +587,7 @@ def all_gather_merge_fn_to_trace_functional( 
+    ag_ins = [
+        torch._prims.convert_element_type(_ag_in, out_dtype)
+        if _ag_in.dtype != out_dtype
+        else _ag_in
+        for _ag_in, out_dtype in zip(_ag_ins, out_dtypes)
+    ]
     ins_sizes = [ag_in.shape for ag_in in ag_ins]
-    ins_split_sizes = [ag_in.numel() for ag_in in ag_ins]
+    ins_split_sizes_bytes = [
+        ag_in.numel() * out_dtype.itemsize
+        for ag_in, out_dtype in zip(ag_ins, out_dtypes)
+    ]
+    bucket_dtype_size_bytes = dtype.itemsize
+    ins_split_sizes = [
+        _bytes // bucket_dtype_size_bytes for _bytes in ins_split_sizes_bytes
+    ]
     ag_input_numel = sum(ins_split_sizes)
     new_ag_out = torch.ops.bucketing._pre_bucket_all_gather(
         ag_ins, group_size, group_name, dtype, rank
@@ -487,14 +535,14 @@
         )
     )
     new_ag_out_reshaped = wait_tensor.reshape(group_size, -1)
-    outs = torch.split_with_sizes(
+    outs_bucket_dtype = torch.split_with_sizes(
         new_ag_out_reshaped,
         ins_split_sizes,
         dim=1,
     )
     outs_reshaped = [
-        o.reshape((shape[0] * group_size,) + shape[1:])
-        for o, shape in zip(outs, ins_sizes)
+        o.view(out_dtype).reshape((shape[0] * group_size,) + shape[1:])
+        for o, shape, out_dtype in zip(outs_bucket_dtype, ins_sizes, out_dtypes)
     ]
     return outs_reshaped
@@ -504,6 +552,7 @@ def all_gather_merge_fn_to_trace(
     group_size: int,
     group_name: str,
     dtype: torch.dtype,  # type: ignore[name-defined]
+    out_dtypes: list[torch.dtype],  # type: ignore[name-defined]
     rank: int,
 ) -> list[torch.Tensor]:
     ins_sizes = [ag_in.shape for ag_in in ag_ins]
@@ -538,6 +587,7 @@ def all_gather_merge_fn_to_trace_functional(
     group_size: int,
     group_name: str,
     dtype: torch.dtype,  # type: ignore[name-defined]
+    out_dtypes: list[torch.dtype],  # type: ignore[name-defined]
     rank: int,
     use_fsdp_ag_copy_in: bool = False,
 ) -> list[torch.Tensor]:
@@ -733,7 +783,7 @@ def process_collective_bucket(
 def merge_reduce_scatter_bucket(
     g: torch.fx.Graph,
     rs_nodes: list[torch.fx.Node],
-    mode: str | None = None,
+    mode: BucketMode = "default",
     insert_before: torch.fx.Node | None = None,
     wait_insertion_point: torch.fx.Node | None = None,
 ) -> tuple[list[torch.fx.Node], dict[torch.fx.Node, torch.fx.Node]]:
@@ -826,29 +876,27 @@ def merge_all_reduce_bucket(
 def merge_all_gather_bucket(
     g: torch.fx.Graph,
     ag_nodes: list[torch.fx.Node],
-    mode: str | None = None,
+    mode: BucketMode = "default",
     insert_before: torch.fx.Node | None = None,
     wait_insertion_point: torch.fx.Node | None = None,
 ) -> tuple[list[torch.fx.Node], dict[torch.fx.Node, torch.fx.Node]]:
     from torch.distributed.distributed_c10d import _resolve_process_group
 
     ag0 = ag_nodes[0]
-    ag0_val = ag0.meta["val"]
     _, group_size, group_name = ag0.args
-    dtype = ag0_val.dtype
     assert isinstance(group_name, str)
 
+    _ag_dtypes: list[torch.dtype] = []  # type: ignore[name-defined]
     for n in ag_nodes:
-        assert (
-            n.args[1] == group_size
-            and n.args[2] == group_name
-            and n.meta["val"].dtype == dtype
-        )
+        assert n.args[1] == group_size and n.args[2] == group_name
+        _ag_dtypes.append(n.meta["val"].dtype)
+
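+    # In a multidtype bucket the nodes may carry different dtypes: the bucket
+    # is laid out in the smallest-itemsize dtype and each output is viewed
+    # back as its original dtype after the merged all_gather.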
+    bucket_dtype = pick_bucket_dtype(_ag_dtypes)
 
     # Choose merge function based on mode
     ag_merge_fn = all_gather_merge_fn_to_trace
-    if mode and "custom_ops" in mode:
-        ag_merge_fn = all_gather_merge_fn_to_trace_custom_ops
+    if mode is not None and "custom_ops" in mode:
+        ag_merge_fn = all_gather_merge_fn_to_trace_custom_ops  # type: ignore[assignment]
 
     # Process bucket with lazy input collection
     rank: int = dist.get_rank(_resolve_process_group(group_name))
@@ -858,7 +906,8 @@
             pytree.tree_map(lambda node: node.meta["val"], bucket_ins),
             group_size,
             group_name,
-            dtype,
+            bucket_dtype,
+            _ag_dtypes,
             rank,
         )
@@ -874,7 +923,7 @@
 def merge_reduce_scatter(
     gm: torch.fx.GraphModule,
     rs_buckets: list[list[torch.fx.Node]],
-    mode: str | None = None,
+    mode: BucketMode = "default",
 ) -> None:
     """
     Merges specified buckets of reduce_scatter to joint reduce_scatter.
@@ -898,7 +947,7 @@
 def merge_all_gather(
     gm: torch.fx.GraphModule,
     ag_buckets: list[list[torch.fx.Node]],
-    mode: str | None = None,
+    mode: BucketMode = "default",
 ) -> None:
     """
     Merges specified buckets of all_gather to joint all_gather.
diff --git a/torch/_inductor/fx_passes/fsdp.py b/torch/_inductor/fx_passes/fsdp.py
index 73787bd928a5..6a1a2d227de1 100644
--- a/torch/_inductor/fx_passes/fsdp.py
+++ b/torch/_inductor/fx_passes/fsdp.py
@@ -5,6 +5,7 @@ import torch
 from torch._inductor.fx_passes.bucketing import (
     bucket_all_gather_by_mb,
     bucket_reduce_scatter_by_mb,
+    BucketMode,
     merge_all_gather,
     merge_reduce_scatter,
 )
@@ -56,7 +57,7 @@ def is_fsdp_reduce_scatter_wait(wait: torch.fx.Node) -> bool:
 def bucket_fsdp_all_gather(
     gm: torch.fx.GraphModule,
     bucket_cap_mb_by_bucket_idx: Callable[[int], float] | None = None,
-    mode: str | None = None,
+    mode: BucketMode = "default",
 ) -> None:
     """
     Bucketing pass for SimpleFSDP all_gather ops.
@@ -86,7 +87,7 @@
 def bucket_fsdp_reduce_scatter(
     gm: torch.fx.GraphModule,
     bucket_cap_mb_by_bucket_idx: Callable[[int], float] | None = None,
-    mode: str | None = None,
+    mode: BucketMode = "default",
 ) -> None:
     """
     Bucketing pass for SimpleFSDP reduce_scatter ops.
diff --git a/torch/_inductor/fx_passes/post_grad.py b/torch/_inductor/fx_passes/post_grad.py
index 938e15deedb2..c9a83000d215 100644
--- a/torch/_inductor/fx_passes/post_grad.py
+++ b/torch/_inductor/fx_passes/post_grad.py
@@ -216,7 +216,7 @@ def post_grad_passes(gm: torch.fx.GraphModule, is_inference: bool):
             lambda graph: p(
                 graph.owning_module,
                 config.bucket_reduce_scatters_fx_bucket_size_determinator,
-                config.bucket_reduce_scatters_fx,
+                config.bucket_reduce_scatters_fx,  # type: ignore[arg-type]
             )
         )
         collectives_bucketing = True
@@ -236,7 +236,7 @@ def post_grad_passes(gm: torch.fx.GraphModule, is_inference: bool):
             lambda graph: p(
                 graph.owning_module,
                 config.bucket_all_gathers_fx_bucket_size_determinator,
-                config.bucket_all_gathers_fx,
+                config.bucket_all_gathers_fx,  # type: ignore[arg-type]
             )
         )
         collectives_bucketing = True