diff --git a/test/distributed/test_c10d_functional_native.py b/test/distributed/test_c10d_functional_native.py
index 8306e58bda6d..5ae56745fa6a 100644
--- a/test/distributed/test_c10d_functional_native.py
+++ b/test/distributed/test_c10d_functional_native.py
@@ -1,14 +1,19 @@
 # Owner(s): ["module: c10d"]
+import unittest
 from typing import List
 
 import torch
 import torch.distributed as dist
+from torch._C import FileCheck
+from torch._dynamo.utils import same
+from torch._inductor.utils import run_and_get_triton_code
 from torch.testing._internal.common_distributed import (
     MultiProcessTestCase,
     requires_nccl,
     skip_if_lt_x_gpu,
 )
 from torch.testing._internal.common_utils import run_tests
+from torch.utils._triton import has_triton
 
 
 if not dist.is_available():
@@ -182,6 +187,264 @@ class C10DFunctionalNativeTest(MultiProcessTestCase):
             output = torch.ops._c10d_functional.wait_tensor(output)
             assert output.eq(self.rank * i).all()
 
+    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
+    def test_inductor_all_reduce_single(self):
+        torch._inductor.config.debug = True
+        self._init_process_group()
+
+        def func(arg: torch.Tensor) -> torch.Tensor:
+            buf0 = arg + 42
+            # Expect in-place with inductor allocated buf
+            ar0 = torch.ops._c10d_functional.all_reduce(buf0, "avg", "default")
+            ar0 = torch.ops._c10d_functional.wait_tensor(ar0)
+            # Expect no in-place with graph input
+            ar1 = torch.ops._c10d_functional.all_reduce(arg, "avg", "default")
+            ar1 = torch.ops._c10d_functional.wait_tensor(ar1)
+            return ar0, ar1
+
+        arg = torch.rand(4, 4, device=self.device)
+        compiled = torch.compile(func)
+        code = run_and_get_triton_code(compiled, arg)
+        (
+            FileCheck()
+            .check("buf0 = empty(")
+            .check("buf5 = empty(")
+            # Expect in-place with inductor allocated buf
+            .check("torch.ops._c10d_functional.all_reduce_.default(buf0")
+            .check("torch.ops._c10d_functional.wait_tensor.default(buf0")
+            # Expect no in-place with graph input (buf5 is a clone)
+            .check("torch.ops._c10d_functional.all_reduce_.default(buf5")
+            .check("torch.ops._c10d_functional.wait_tensor.default(buf5")
+            # Expect no extra copy on return
+            .check("return (buf0, buf5, )")
+            .run(code)
+        )
+        out = compiled(arg)
+        correct = func(arg)
+        assert same(out, correct), f"{out} vs {correct}"
+
+    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
+    def test_inductor_all_reduce_coalesced(self):
+        torch._inductor.config.debug = True
+        self._init_process_group()
+
+        def func(args: List[torch.Tensor]) -> torch.Tensor:
+            bufs = [arg + 42 for arg in args]
+            # Expect in-place with inductor allocated buf
+            ar0 = torch.ops._c10d_functional.all_reduce_coalesced(
+                bufs, "avg", "default"
+            )
+            ar0 = [torch.ops._c10d_functional.wait_tensor(out) for out in ar0]
+            # Expect no in-place with graph input
+            ar1 = torch.ops._c10d_functional.all_reduce_coalesced(
+                args, "avg", "default"
+            )
+            ar1 = [torch.ops._c10d_functional.wait_tensor(out) for out in ar1]
+            return ar0, ar1
+
+        args = [torch.rand(4, 4, device=self.device) for _ in range(2)]
+        compiled = torch.compile(func)
+        code = run_and_get_triton_code(compiled, args)
+        (
+            FileCheck()
+            .check("buf0 = empty(")
+            .check("buf5 = empty(")
+            .check("buf1 = empty(")
+            .check("buf6 = empty(")
+            # Expect in-place with inductor allocated buf
+            .check(
+                "torch.ops._c10d_functional.all_reduce_coalesced_"
+                ".default([buf0, buf1]"
+            )
+            # Expect no in-place with graph input (buf5, buf6 are clones)
+            .check(
+                "torch.ops._c10d_functional.all_reduce_coalesced_"
+                ".default([buf5, buf6]"
+            )
+            .check("torch.ops._c10d_functional.wait_tensor.default(buf0")
+            .check("torch.ops._c10d_functional.wait_tensor.default(buf1")
+            .check("torch.ops._c10d_functional.wait_tensor.default(buf5")
+            .check("torch.ops._c10d_functional.wait_tensor.default(buf6")
+            # Expect no extra copy on return
+            .check("return (buf0, buf1, buf5, buf6, )")
+            .run(code)
+        )
+        out = compiled(args)
+        correct = func(args)
+        assert same(out, correct), f"{out} vs {correct}"
+
+    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
+    def test_inductor_reuse_buffer_after_inplace_collective(self):
+        torch._inductor.config.debug = True
+        self._init_process_group()
+
+        def func(arg: torch.Tensor) -> torch.Tensor:
+            # Expect allocation
+            buf0 = arg + 42
+            ar0 = torch.ops._c10d_functional.all_reduce(buf0, "avg", "default")
+            ar0 = torch.ops._c10d_functional.wait_tensor(ar0)
+            # Expect allocation
+            buf1 = torch.mm(arg, ar0)
+            # Expect buf0 to be reused
+            buf2 = torch.mm(arg, buf1)
+            return buf1, buf2
+
+        arg = torch.rand(4, 4, device=self.device)
+        compiled = torch.compile(func)
+        code = run_and_get_triton_code(compiled, arg)
+        (
+            FileCheck()
+            # Expect allocation
+            .check("buf0 = empty(")
+            .check("torch.ops._c10d_functional.all_reduce_.default(buf0")
+            .check("torch.ops._c10d_functional.wait_tensor.default(buf0")
+            # Expect allocation
+            .check("buf5 = empty(")
+            .check("extern_kernels.mm(arg0_1, buf0, out=buf5")
+            # Expect buf0 to be reused
+            .check("buf6 = buf0; del buf0 # reuse")
+            .check("extern_kernels.mm(arg0_1, buf5, out=buf6")
+            # Expect no extra copy on return
+            .check("return (buf5, buf6, )")
+            .run(code)
+        )
+        out = compiled(arg)
+        correct = func(arg)
+        assert same(out, correct), f"{out} vs {correct}"
+
+    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
+    def test_inductor_all_gather_into_tensor_single(self):
+        torch._inductor.config.debug = True
+        self._init_process_group()
+
+        def func(arg: torch.Tensor) -> torch.Tensor:
+            ag0 = torch.ops._c10d_functional.all_gather_into_tensor(
+                arg, self.world_size, "default"
+            )
+            ag0 = torch.ops._c10d_functional.wait_tensor(ag0)
+            return ag0
+
+        arg = torch.rand(4, 4, device=self.device)
+        compiled = torch.compile(func)
+        code = run_and_get_triton_code(compiled, arg)
+        (
+            FileCheck()
+            .check(
+                "buf0 = torch.ops._c10d_functional.all_gather_into_tensor.default(arg0_1"
+            )
+            .check("torch.ops._c10d_functional.wait_tensor.default(buf0")
+            # Expect no extra copy on return
+            .check("return (buf0, )")
+            .run(code)
+        )
+        out = compiled(arg)
+        correct = func(arg)
+        assert same(out, correct), f"{out} vs {correct}"
+
+    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
+    def test_inductor_all_gather_into_tensor_coalesced(self):
+        torch._inductor.config.debug = True
+        self._init_process_group()
+
+        def func(args: List[torch.Tensor]) -> torch.Tensor:
+            ag0 = torch.ops._c10d_functional.all_gather_into_tensor_coalesced(
+                args, self.world_size, "default"
+            )
+            ag0 = [torch.ops._c10d_functional.wait_tensor(out) for out in ag0]
+            return ag0
+
+        args = [torch.rand(4, 4, device=self.device) for _ in range(4)]
+        compiled = torch.compile(func)
+        code = run_and_get_triton_code(compiled, args)
+        print(code)
+        (
+            FileCheck()
+            .check(
+                "buf0 = torch.ops._c10d_functional.all_gather_into_tensor_coalesced"
+                ".default([arg0_1, arg1_1, arg2_1, arg3_1]"
+            )
+            .check("buf1 = buf0[0]")
+            .check("buf2 = buf0[1]")
+            .check("buf3 = buf0[2]")
+            .check("buf4 = buf0[3]")
.check("torch.ops._c10d_functional.wait_tensor.default(buf1") + .check("torch.ops._c10d_functional.wait_tensor.default(buf2") + .check("torch.ops._c10d_functional.wait_tensor.default(buf3") + .check("torch.ops._c10d_functional.wait_tensor.default(buf4") + # Expect no extra copy on return + .check("return (buf1, buf2, buf3, buf4, )") + .run(code) + ) + out = compiled(args) + correct = func(args) + assert same(out, correct), f"{out} va {correct}" + + @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") + def test_inductor_reduce_scatter_tensor_single(self): + torch._inductor.config.debug = True + self._init_process_group() + + def func(arg: torch.Tensor) -> torch.Tensor: + rs0 = torch.ops._c10d_functional.reduce_scatter_tensor( + arg, "avg", self.world_size, "default" + ) + rs0 = torch.ops._c10d_functional.wait_tensor(rs0) + return rs0 + + arg = torch.rand(4, 4, device=self.device) + compiled = torch.compile(func) + code = run_and_get_triton_code(compiled, arg) + ( + FileCheck() + .check( + "buf0 = torch.ops._c10d_functional.reduce_scatter_tensor.default(arg0_1" + ) + .check("torch.ops._c10d_functional.wait_tensor.default(buf0") + # Expect no extra copy on return + .check("return (buf0, )") + .run(code) + ) + out = compiled(arg) + correct = func(arg) + assert same(out, correct), f"{out} va {correct}" + + @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") + def test_inductor_reduce_scatter_tensor_coalesced(self): + torch._inductor.config.debug = True + self._init_process_group() + + def func(args: List[torch.Tensor]) -> torch.Tensor: + rs0 = torch.ops._c10d_functional.reduce_scatter_tensor_coalesced( + args, "avg", self.world_size, "default" + ) + rs0 = [torch.ops._c10d_functional.wait_tensor(out) for out in rs0] + return rs0 + + args = [torch.rand(4, 4, device=self.device) for _ in range(4)] + compiled = torch.compile(func) + code = run_and_get_triton_code(compiled, args) + ( + FileCheck() + .check( + "buf0 = torch.ops._c10d_functional.reduce_scatter_tensor_coalesced" + ".default([arg0_1, arg1_1, arg2_1, arg3_1]" + ) + .check("buf1 = buf0[0]") + .check("buf2 = buf0[1]") + .check("buf3 = buf0[2]") + .check("buf4 = buf0[3]") + .check("torch.ops._c10d_functional.wait_tensor.default(buf1") + .check("torch.ops._c10d_functional.wait_tensor.default(buf2") + .check("torch.ops._c10d_functional.wait_tensor.default(buf3") + .check("torch.ops._c10d_functional.wait_tensor.default(buf4") + # Expect no extra copy on return + .check("return (buf1, buf2, buf3, buf4, )") + .run(code) + ) + out = compiled(args) + correct = func(args) + assert same(out, correct), f"{out} va {correct}" + if __name__ == "__main__": run_tests() diff --git a/torch/_inductor/fx_passes/README.md b/torch/_inductor/fx_passes/README.md index 65e06ff2c809..ae59ca46619a 100644 --- a/torch/_inductor/fx_passes/README.md +++ b/torch/_inductor/fx_passes/README.md @@ -35,4 +35,4 @@ inputs and outputs have any aliasing, it suffices to check whether the storages of the input and the storages of the output have any overlap. See `remove_noop_ops` for an example of how to do this. -Additionally, we do have one pass that *does* introduce mutation - `reinplace_scatters`. This pass must run *just before Inductor lowering*, as otherwise this breaks our invariant. +Additionally, we do have one pass that *does* introduce mutation - `reinplace_inplaceable_ops`. This pass must run *just before Inductor lowering*, as otherwise this breaks our invariant. 
diff --git a/torch/_inductor/fx_passes/post_grad.py b/torch/_inductor/fx_passes/post_grad.py
index bf1910678dcb..f2a70e0e5a32 100644
--- a/torch/_inductor/fx_passes/post_grad.py
+++ b/torch/_inductor/fx_passes/post_grad.py
@@ -93,7 +93,7 @@ def post_grad_passes(gm: torch.fx.GraphModule, is_inference: bool):
 
     # Keep this last, since it introduces mutation. Look at
     # ./fx_passes/README.md for a discussion of mutation invariants.
-    reinplace_scatters(gm.graph)
+    reinplace_inplaceable_ops(gm.graph)
     gm.recompile()
     gm.graph.lint()
 
@@ -625,9 +625,9 @@ def remove_noop_ops(graph: torch.fx.Graph):
 InplaceableOp = namedtuple("InplaceableOp", ["inplace_op", "mutated_arg"])
 
 
-def reinplace_scatters(graph):
+def reinplace_inplaceable_ops(graph):
     """
-    Reinplaces scatter operations.
+    Reinplaces in-placeable operations.
     If there are no uses of a view of the mutated arg after the current node,
     it is possible to inplace the op. This above algorithm could be justified by
     observing side effects. While
@@ -676,6 +676,9 @@ def reinplace_scatters(graph):
        return False
 
     def can_inplace(node, mutated_arg):
+        if isinstance(mutated_arg, (list, tuple)):
+            return all(can_inplace(node, arg) for arg in mutated_arg)
+
        if get_node_storage(mutated_arg) is None:
            return False
        shared_view_nodes = storage_to_nodes[get_node_storage(mutated_arg)]
@@ -690,7 +693,6 @@ def reinplace_scatters(graph):
            ):
                return False
 
-            graph.erase_node(copy_node)
            return True
        elif any(view.op == "placeholder" for view in shared_view_nodes):
            # If mutated arg is view of any of the inputs of the graph,
@@ -708,11 +710,35 @@ def reinplace_scatters(graph):
            inductor_prims._unsafe_index_put_, 0
        ),
    }
+
+    try:
+        c10d_functional = torch.ops._c10d_functional
+        inplaceable_collective_ops = {
+            c10d_functional.all_reduce.default: InplaceableOp(
+                c10d_functional.all_reduce_.default, 0
+            ),
+            c10d_functional.all_reduce_coalesced.default: InplaceableOp(
+                c10d_functional.all_reduce_coalesced_.default, 0
+            ),
+        }
+        inplaceable_ops.update(inplaceable_collective_ops)
+    except AttributeError:
+        # _c10d_functional ops are only available when torch
+        # is built with USE_DISTRIBUTED=1.
+        pass
+
    inplaceable_triton_ops = {triton_kernel_wrapper_functional}
 
    for node in graph.nodes:
        if (inplaceable_op := inplaceable_ops.get(node.target, None)) is not None:
-            if can_inplace(node, node.args[inplaceable_op.mutated_arg]):
+            mutated_arg = node.args[inplaceable_op.mutated_arg]
+            if can_inplace(node, mutated_arg):
+                # TODO(yifu): this doesn't properly remove copy epilogues for
+                # ops that mutate multiple inputs. Need to revise the copy
+                # node tracking logic to support the case.
+                copy_node = copy_args_to_copy_nodes.get((mutated_arg, node))
+                if copy_node is not None:
+                    graph.erase_node(copy_node)
                node.target = inplaceable_op.inplace_op
        elif node.target in inplaceable_triton_ops:
            # inplaceable_triton_ops take an additional argument called
@@ -722,7 +748,12 @@ def reinplace_scatters(graph):
            tensors_to_clone = []
            for arg in node.kwargs["tensors_to_clone"]:
                assert arg in node.kwargs["kwargs"]
-                if not can_inplace(node, node.kwargs["kwargs"][arg]):
+                mutated_arg = node.kwargs["kwargs"][arg]
+                if can_inplace(node, mutated_arg):
+                    copy_node = copy_args_to_copy_nodes.get((mutated_arg, node))
+                    if copy_node is not None:
+                        graph.erase_node(copy_node)
+                else:
                    tensors_to_clone.append(arg)
            kwargs = dict(node.kwargs)
            kwargs["tensors_to_clone"] = tensors_to_clone
diff --git a/torch/_inductor/ir.py b/torch/_inductor/ir.py
index d1c368ecbfc6..43a729e68872 100644
--- a/torch/_inductor/ir.py
+++ b/torch/_inductor/ir.py
@@ -2693,7 +2693,7 @@ class Buffer(IRNode):
        # dynamic; it has i0 as one of the arguments. We cannot tell this
        # directly from MultiOutput, we have to look at the input buffer's
        # uses to work this out. No big deal.
-        if isinstance(self.layout, MultiOutputLayout):
+        if isinstance(self.layout, (NoneLayout, MultiOutputLayout)):
            return set()
 
        # This kernel defines all unbacked symbols... that it didn't get in as
@@ -4496,6 +4496,15 @@ class FallbackKernel(ExternKernelAlloc):
        else:
            super().codegen(wrapper)
 
+    @staticmethod
+    def tensor_to_layout(output: torch.Tensor):
+        return FixedLayout(
+            output.device,
+            output.dtype,
+            convert_shape_to_inductor(output.size()),
+            convert_shape_to_inductor(output.stride()),
+        )
+
    @classmethod
    def create(cls, kernel, *args, **kwargs):
        fake_incorrect_kernels = (aten._fused_moving_avg_obs_fq_helper_functional,)
@@ -4511,18 +4520,10 @@ class FallbackKernel(ExternKernelAlloc):
            schema,
        ) = cls.process_kernel(kernel, *args, **kwargs)
 
-        device = FallbackKernel.find_device(tensor_args, example_output)
+        device = cls.find_device(tensor_args, example_output)
        assert device, "Not sure where to find device info"
 
-        def tensor_to_layout(output: torch.Tensor):
-            return FixedLayout(
-                output.device,
-                output.dtype,
-                convert_shape_to_inductor(output.size()),
-                convert_shape_to_inductor(output.stride()),
-            )
-
-        packed = FallbackKernel(
+        packed = cls(
            MultiOutputLayout(device),
            kernel,
            tensor_args,
@@ -4544,7 +4545,7 @@ class FallbackKernel(ExternKernelAlloc):
                }
            elif isinstance(output, torch.Tensor):
                return MultiOutput(
-                    tensor_to_layout(output),
+                    cls.tensor_to_layout(output),
                    packed,
                    indices,
                )
@@ -6930,6 +6931,176 @@ class ReduceScatterTensorCoalesced(OutOfPlaceCollectiveKernel):
        )
 
 
+# TODO(yifu): replace the CollectiveKernel IR hierarchy with _CollectiveKernel.
+class _CollectiveKernel(FallbackKernel):
+    def should_allocate(self):
+        return False
+
+    def has_side_effects(self):
+        return True
+
+    # This is identical to FallbackKernel.set_cpp_kernel(), minus the
+    # part that checks against input aliasing and mutation.
+    def set_cpp_kernel(self, kernel):
+        from .codegen.wrapper import get_cpp_op_schema
+
+        self.kernel = kernel._schema.name
+        self.cpp_kernel_overlad_name = kernel._schema.overload_name
+        self.cpp_kernel_key = (
+            f"{self.kernel.replace('::', '_')}_{self.cpp_kernel_overlad_name}"
+        )
+
+        self.cpp_op_schema = get_cpp_op_schema(kernel)
+        self.ordered_kwargs_for_cpp_kernel = [
+            x.name for x in kernel._schema.arguments if x.kwarg_only
+        ]
+
+    # NOTE: [In-Place Collective Safety]
+    # Between the initiation and completion of an in-place collective, the
+    # input buffers are subject to both volatile reads and volatile writes.
+    # They must not be read, written to or reused by another kernel. To enforce
+    # these constraints, we model collective -> wait_tensor as a two-step
+    # mutation of the input buffers.
+    @classmethod
+    def create_inplace(
+        cls, kernel, inputs: Union[TensorBox, List[TensorBox]], *args, **kwargs
+    ) -> None:
+        with V.graph.fake_mode:
+            (
+                example_output,
+                tensor_args,
+                non_tensor_args,
+                unflatten_args,
+                schema,
+            ) = cls.process_kernel(kernel, inputs, *args, **kwargs)
+            for tensor_arg in tensor_args:
+                tensor_arg.realize()
+
+        packed = cls(
+            NoneLayout(tensor_args[0].get_device()),
+            kernel,
+            tensor_args,
+            non_tensor_args,
+            unflatten_args,
+            schema=schema,
+        )
+        pytree.tree_map(lambda x: MutationOutput(x.layout, x, packed), inputs)
+
+    # NOTE: [Out-of-Place Collective Safety]
+    # Between the initiation and completion of an out-of-place collective:
+    #
+    # Input buffers:
+    # - Are subject to volatile reads
+    # - Can be read by another kernel
+    # - Must not be written to or reused by another kernel
+    #
+    # Output buffers:
+    # - Are subject to volatile writes
+    # - Must not be read, written to or reused by another kernel
+    #
+    # To ensure the safety of input buffers without sacrificing read
+    # availability, we add input buffers as read deps of wait_tensor kernels.
+    #
+    # To ensure the safety of output buffers, we model wait_tensor as a
+    # mutation to the output buffer. Note that we also assume the user program
+    # is correct and that the output buffer is not consumed by kernels other
+    # than wait_tensor.
+    #
+    # TODO(yifu): add a pre-grad pass to validate the correctness of collective
+    # usage in the user program.
+    @classmethod
+    def create_out_of_place(
+        cls, kernel, inputs: Union[TensorBox, List[TensorBox]], *args, **kwargs
+    ):
+        with V.graph.fake_mode:
+            (
+                example_output,
+                tensor_args,
+                non_tensor_args,
+                unflatten_args,
+                schema,
+            ) = cls.process_kernel(kernel, inputs, *args, **kwargs)
+            for tensor_arg in tensor_args:
+                tensor_arg.realize()
+
+        if isinstance(example_output, list):
+            device = cls.find_device(tensor_args, example_output)
+            packed = cls(
+                MultiOutputLayout(device),
+                kernel,
+                tensor_args,
+                non_tensor_args,
+                unflatten_args,
+                schema=schema,
+            )
+            packed.outputs = [
+                MultiOutput(
+                    cls.tensor_to_layout(tensor),
+                    packed,
+                    [(list, i)],
+                )
+                for i, tensor in enumerate(example_output)
+            ]
+            return packed.outputs
+        else:
+            packed = cls(
+                cls.tensor_to_layout(example_output),
+                kernel,
+                tensor_args,
+                non_tensor_args,
+                unflatten_args,
+                schema=schema,
+            )
+            packed.outputs = [packed]
+            return packed
+
+
+class _WaitKernel(_CollectiveKernel):
+    def get_volatile_reads(self):
+        inp = self.inputs[0]
+        if isinstance(inp, _CollectiveKernel):
+            # Out-of-place single-output
+            return [inp.inputs[0]]
+        elif isinstance(inp, MultiOutput):
+            # Out-of-place multi-output
+            coll = inp.inputs[0]
+            assert isinstance(coll, _CollectiveKernel)
+            _, idx = inp.indices[0]
+            return [coll.inputs[idx]]
+        else:
+            # In-place requires no additional deps handling for volatile
+            # reads since the inputs are mutated.
+            return []
+
+    @classmethod
+    def create_wait(cls, kernel, inp: TensorBox) -> None:
+        with V.graph.fake_mode:
+            (
+                example_output,
+                tensor_args,
+                non_tensor_args,
+                unflatten_args,
+                schema,
+            ) = cls.process_kernel(kernel, inp)
+        packed = cls(
+            NoneLayout(inp.get_device()),
+            kernel,
+            tensor_args,
+            non_tensor_args,
+            unflatten_args,
+            schema=schema,
+        )
+        MutationOutput(inp.layout, inp, packed)
+
+    def get_read_writes(self):
+        read_writes = super().get_read_writes()
+        # See [Out-of-Place Collective Safety].
+        volatile_reads = self.get_volatile_reads()
+        for vr in volatile_reads:
+            read_writes.reads.add(dependencies.StarDep(vr.get_name()))
+        return read_writes
+
+
 # NB: recursive structure here reflects val_to_arg_str, avoid
 # calling free_unbacked_symbols on "exotic" types that don't get pexpr
 # treatment
diff --git a/torch/_inductor/lowering.py b/torch/_inductor/lowering.py
index fbd9eaed6103..ef6d9718ed9a 100644
--- a/torch/_inductor/lowering.py
+++ b/torch/_inductor/lowering.py
@@ -5132,6 +5132,97 @@ try:
            )
        )
 
+    _c10d_functional = torch.ops._c10d_functional
+
+    @register_lowering(_c10d_functional.all_reduce)
+    def _all_reduce(inp, reduce_op, group_name):
+        inp = clone(inp)
+        ir._CollectiveKernel.create_inplace(
+            _c10d_functional.all_reduce_.default, inp, reduce_op, group_name
+        )
+        return inp
+
+    @register_lowering(_c10d_functional.all_reduce_)
+    def _all_reduce_(inp, reduce_op, group_name):
+        ir._CollectiveKernel.create_inplace(
+            _c10d_functional.all_reduce_.default, inp, reduce_op, group_name
+        )
+        return inp
+
+    @register_lowering(_c10d_functional.all_reduce_coalesced)
+    def _all_reduce_coalesced(inputs, reduce_op, group_name):
+        inputs = [clone(inp) for inp in inputs]
+        ir._CollectiveKernel.create_inplace(
+            _c10d_functional.all_reduce_coalesced_.default,
+            inputs,
+            reduce_op,
+            group_name,
+        )
+        return inputs
+
+    @register_lowering(_c10d_functional.all_reduce_coalesced_)
+    def _all_reduce_coalesced_(inputs, reduce_op, group_name):
+        ir._CollectiveKernel.create_inplace(
+            _c10d_functional.all_reduce_coalesced_.default,
+            inputs,
+            reduce_op,
+            group_name,
+        )
+        return inputs
+
+    @register_lowering(_c10d_functional.all_gather_into_tensor)
+    def _all_gather_into_tensor(inp, group_size, group_name):
+        return ir.TensorBox.create(
+            ir._CollectiveKernel.create_out_of_place(
+                _c10d_functional.all_gather_into_tensor.default,
+                inp,
+                group_size,
+                group_name,
+            )
+        )
+
+    @register_lowering(_c10d_functional.all_gather_into_tensor_coalesced)
+    def _all_gather_into_tensor_coalesced(inputs, group_size, group_name):
+        return pytree.tree_map(
+            ir.TensorBox.create,
+            ir._CollectiveKernel.create_out_of_place(
+                _c10d_functional.all_gather_into_tensor_coalesced.default,
+                inputs,
+                group_size,
+                group_name,
+            ),
+        )
+
+    @register_lowering(_c10d_functional.reduce_scatter_tensor)
+    def _reduce_scatter_tensor(inp, reduce_op, group_size, group_name):
+        return ir.TensorBox.create(
+            ir._CollectiveKernel.create_out_of_place(
+                _c10d_functional.reduce_scatter_tensor.default,
+                inp,
+                reduce_op,
+                group_size,
+                group_name,
+            )
+        )
+
+    @register_lowering(_c10d_functional.reduce_scatter_tensor_coalesced)
+    def _reduce_scatter_tensor_coalesced(inputs, reduce_op, group_size, group_name):
+        return pytree.tree_map(
+            ir.TensorBox.create,
+            ir._CollectiveKernel.create_out_of_place(
+                _c10d_functional.reduce_scatter_tensor_coalesced.default,
+                inputs,
+                reduce_op,
+                group_size,
+                group_name,
+            ),
+        )
+
+    @register_lowering(_c10d_functional.wait_tensor)
+    def _wait_tensor(inp):
+        ir._WaitKernel.create_wait(_c10d_functional.wait_tensor.default, inp)
+        return inp
+
 except ImportError:
    log.info(
        "Inductor support for distributed collectives depends on building torch.distributed"
    )
diff --git a/torch/distributed/_functional_collectives.py b/torch/distributed/_functional_collectives.py
index f989e468cd86..ab6dce6a7f3e 100644
--- a/torch/distributed/_functional_collectives.py
+++ b/torch/distributed/_functional_collectives.py
@@ -542,6 +542,12 @@ def _reduce_scatter_tensor_meta(input, reduce_op, tag, rankset, group_size):
 def _all_reduce_coalesced_meta(self, *args):
     return [torch.empty_like(t) for t in self]
 
+def _all_reduce__meta(inp, *args):
+    return inp
+
+def _all_reduce_coalesced__meta(inputs, *args):
+    return inputs
+
 def _reduce_scatter_tensor_coalesced_meta(inputs, reduceOp, tag, rankset, group_size):
     def mk_out_tensor(input):
         out_size = list(input.size())
@@ -577,15 +583,15 @@ def _all_gather_into_tensor_coalesced_native_meta(inputs, group_size, group_name
         for input in inputs
     ]
 
-def _reduce_scatter_tensor_native_meta(input, group_size, group_name):
-    shape = list(input.size())
+def _reduce_scatter_tensor_native_meta(inp, reduce_op, group_size, group_name):
+    shape = list(inp.size())
     shape[0] //= group_size
-    return input.new_empty(shape)
+    return inp.new_empty(shape)
 
-def _reduce_scatter_tensor_coalesced_native_meta(inputs, group_size, group_name):
+def _reduce_scatter_tensor_coalesced_native_meta(inputs, reduce_op, group_size, group_name):
     return [
-        _reduce_scatter_tensor_native_meta(input, group_size, group_name)
-        for input in inputs
+        _reduce_scatter_tensor_native_meta(inp, reduce_op, group_size, group_name)
+        for inp in inputs
     ]
 
 def _register_ops():
@@ -620,7 +626,9 @@ if not torch._running_with_deploy():
 
     _c10_lib_impl = torch.library.Library("_c10d_functional", "IMPL")
     _c10_lib_impl.impl("all_reduce", _all_reduce_meta, "Meta")
+    _c10_lib_impl.impl("all_reduce_", _all_reduce__meta, "Meta")
     _c10_lib_impl.impl("all_reduce_coalesced", _all_reduce_coalesced_meta, "Meta")
+    _c10_lib_impl.impl("all_reduce_coalesced_", _all_reduce_coalesced__meta, "Meta")
     _c10_lib_impl.impl("wait_tensor", _wait_tensor_meta, "Meta")
     _c10_lib_impl.impl("all_gather_into_tensor", _all_gather_into_tensor_native_meta, "Meta")
     _c10_lib_impl.impl("all_gather_into_tensor_coalesced", _all_gather_into_tensor_coalesced_native_meta, "Meta")
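As a small standalone illustration of the shape contract that the corrected `_reduce_scatter_tensor_native_meta` signature encodes (the helper name below is a local stand-in, not the registered meta impl): reduce_scatter splits dim 0 across the group, so an `(8, 16)` input with `group_size=4` yields a `(2, 16)` output per rank, while all_gather_into_tensor is the inverse and grows dim 0 by `group_size`.

import torch

def reduce_scatter_out_shape(inp: torch.Tensor, group_size: int) -> torch.Tensor:
    # Mirrors the meta function: only dim 0 changes; dtype and device carry over.
    shape = list(inp.size())
    shape[0] //= group_size
    return inp.new_empty(shape)

x = torch.empty(8, 16, device="meta")  # meta tensors carry only shape/dtype, no data
print(reduce_scatter_out_shape(x, group_size=4).shape)  # torch.Size([2, 16])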