Inductor support for native c10d_functional (#112439)

This PR adds Inductor support for [native c10d_functional ops](https://github.com/pytorch/pytorch/pull/110570).

The Inductor IRs introduced in this PR will replace the existing `CollectiveKernel` IR hierarchy. Compared to the existing collective IRs, the new IRs:
- Are target-language agnostic and support AOTInductor.
- Express the constraints solely with read/write deps, which maximizes the potential for buffer reuse.
- Address an issue where an out-of-place collective's input buffers could be mutated while still subject to volatile reads (a usage sketch follows below).

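For context, here is a minimal, hedged sketch of how the ops these IRs lower compose with `torch.compile`. The group name `"default"`, the device, and the shapes mirror the tests in this PR and assume a process group has already been initialized and registered under that name; it is an illustrative sketch only.

```python
import torch

# Hedged sketch: assumes torch is built with USE_DISTRIBUTED=1, an NCCL-capable
# GPU setup, and a process group registered under the name "default".
def func(arg: torch.Tensor) -> torch.Tensor:
    buf = arg + 42
    # Out-of-place functional collective; Inductor can lower this to the
    # in-place variant on its own freshly allocated buffer.
    out = torch.ops._c10d_functional.all_reduce(buf, "avg", "default")
    # Completion is expressed as a mutation of the buffer, so scheduling
    # constraints are carried purely by read/write deps.
    return torch.ops._c10d_functional.wait_tensor(out)

compiled = torch.compile(func)
# result = compiled(torch.rand(4, 4, device="cuda"))
```
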
Pull Request resolved: https://github.com/pytorch/pytorch/pull/112439
Approved by: https://github.com/Chillee
Yifu Wang
2023-11-08 10:59:10 -08:00
committed by PyTorch MergeBot
parent 297c26bb8e
commit 625958d8bc
6 changed files with 589 additions and 25 deletions

View File

@@ -1,14 +1,19 @@
# Owner(s): ["module: c10d"]
import unittest
from typing import List
import torch
import torch.distributed as dist
from torch._C import FileCheck
from torch._dynamo.utils import same
from torch._inductor.utils import run_and_get_triton_code
from torch.testing._internal.common_distributed import (
MultiProcessTestCase,
requires_nccl,
skip_if_lt_x_gpu,
)
from torch.testing._internal.common_utils import run_tests
from torch.utils._triton import has_triton
if not dist.is_available():
@@ -182,6 +187,264 @@ class C10DFunctionalNativeTest(MultiProcessTestCase):
output = torch.ops._c10d_functional.wait_tensor(output)
assert output.eq(self.rank * i).all()
@unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
def test_inductor_all_reduce_single(self):
torch._inductor.config.debug = True
self._init_process_group()
def func(arg: torch.Tensor) -> torch.Tensor:
buf0 = arg + 42
# Expect in-place with inductor allocated buf
ar0 = torch.ops._c10d_functional.all_reduce(buf0, "avg", "default")
ar0 = torch.ops._c10d_functional.wait_tensor(ar0)
# Expect no in-place with graph input
ar1 = torch.ops._c10d_functional.all_reduce(arg, "avg", "default")
ar1 = torch.ops._c10d_functional.wait_tensor(ar1)
return ar0, ar1
arg = torch.rand(4, 4, device=self.device)
compiled = torch.compile(func)
code = run_and_get_triton_code(compiled, arg)
(
FileCheck()
.check("buf0 = empty(")
.check("buf5 = empty(")
# Expect in-place with inductor allocated buf
.check("torch.ops._c10d_functional.all_reduce_.default(buf0")
.check("torch.ops._c10d_functional.wait_tensor.default(buf0")
# Expect no in-place with graph input (buf5 is a clone)
.check("torch.ops._c10d_functional.all_reduce_.default(buf5")
.check("torch.ops._c10d_functional.wait_tensor.default(buf5")
# Expect no extra copy on return
.check("return (buf0, buf5, )")
.run(code)
)
out = compiled(arg)
correct = func(arg)
assert same(out, correct), f"{out} va {correct}"
@unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
def test_inductor_all_reduce_coalesced(self):
torch._inductor.config.debug = True
self._init_process_group()
def func(args: List[torch.Tensor]) -> torch.Tensor:
bufs = [arg + 42 for arg in args]
# Expect in-place with inductor allocated buf
ar0 = torch.ops._c10d_functional.all_reduce_coalesced(
bufs, "avg", "default"
)
ar0 = [torch.ops._c10d_functional.wait_tensor(out) for out in ar0]
# Expect no in-place with graph input
ar1 = torch.ops._c10d_functional.all_reduce_coalesced(
args, "avg", "default"
)
ar1 = [torch.ops._c10d_functional.wait_tensor(out) for out in ar1]
return ar0, ar1
args = [torch.rand(4, 4, device=self.device) for _ in range(2)]
compiled = torch.compile(func)
code = run_and_get_triton_code(compiled, args)
(
FileCheck()
.check("buf0 = empty(")
.check("buf5 = empty(")
.check("buf1 = empty(")
.check("buf6 = empty(")
# Expect in-place with inductor allocated buf
.check(
"torch.ops._c10d_functional.all_reduce_coalesced_"
".default([buf0, buf1]"
)
# Expect no in-place with graph input (buf5, buf6 are clones)
.check(
"torch.ops._c10d_functional.all_reduce_coalesced_"
".default([buf5, buf6]"
)
.check("torch.ops._c10d_functional.wait_tensor.default(buf0")
.check("torch.ops._c10d_functional.wait_tensor.default(buf1")
.check("torch.ops._c10d_functional.wait_tensor.default(buf5")
.check("torch.ops._c10d_functional.wait_tensor.default(buf6")
# Expect no extra copy on return
.check("return (buf0, buf1, buf5, buf6, )")
.run(code)
)
out = compiled(args)
correct = func(args)
assert same(out, correct), f"{out} va {correct}"
@unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
def test_inductor_reuse_buffer_after_inplace_collective(self):
torch._inductor.config.debug = True
self._init_process_group()
def func(arg: torch.Tensor) -> torch.Tensor:
# Expect allocation
buf0 = arg + 42
ar0 = torch.ops._c10d_functional.all_reduce(buf0, "avg", "default")
ar0 = torch.ops._c10d_functional.wait_tensor(ar0)
# Expect allocation
buf1 = torch.mm(arg, ar0)
# Expect buf0 to be reused
buf2 = torch.mm(arg, buf1)
return buf1, buf2
arg = torch.rand(4, 4, device=self.device)
compiled = torch.compile(func)
code = run_and_get_triton_code(compiled, arg)
(
FileCheck()
# Expect allocation
.check("buf0 = empty(")
.check("torch.ops._c10d_functional.all_reduce_.default(buf0")
.check("torch.ops._c10d_functional.wait_tensor.default(buf0")
# Expect allocation
.check("buf5 = empty(")
.check("extern_kernels.mm(arg0_1, buf0, out=buf5")
# Expect buf0 to be reused
.check("buf6 = buf0; del buf0 # reuse")
.check("extern_kernels.mm(arg0_1, buf5, out=buf6")
# Expect no extra copy on return
.check("return (buf5, buf6, )")
.run(code)
)
out = compiled(arg)
correct = func(arg)
assert same(out, correct), f"{out} va {correct}"
@unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
def test_inductor_all_gather_into_tensor_single(self):
torch._inductor.config.debug = True
self._init_process_group()
def func(arg: torch.Tensor) -> torch.Tensor:
ag0 = torch.ops._c10d_functional.all_gather_into_tensor(
arg, self.world_size, "default"
)
ag0 = torch.ops._c10d_functional.wait_tensor(ag0)
return ag0
arg = torch.rand(4, 4, device=self.device)
compiled = torch.compile(func)
code = run_and_get_triton_code(compiled, arg)
(
FileCheck()
.check(
"buf0 = torch.ops._c10d_functional.all_gather_into_tensor.default(arg0_1"
)
.check("torch.ops._c10d_functional.wait_tensor.default(buf0")
# Expect no extra copy on return
.check("return (buf0, )")
.run(code)
)
out = compiled(arg)
correct = func(arg)
assert same(out, correct), f"{out} va {correct}"
@unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
def test_inductor_all_gather_into_tensor_coalesced(self):
torch._inductor.config.debug = True
self._init_process_group()
def func(args: List[torch.Tensor]) -> torch.Tensor:
ag0 = torch.ops._c10d_functional.all_gather_into_tensor_coalesced(
args, self.world_size, "default"
)
ag0 = [torch.ops._c10d_functional.wait_tensor(out) for out in ag0]
return ag0
args = [torch.rand(4, 4, device=self.device) for _ in range(4)]
compiled = torch.compile(func)
code = run_and_get_triton_code(compiled, args)
(
FileCheck()
.check(
"buf0 = torch.ops._c10d_functional.all_gather_into_tensor_coalesced"
".default([arg0_1, arg1_1, arg2_1, arg3_1]"
)
.check("buf1 = buf0[0]")
.check("buf2 = buf0[1]")
.check("buf3 = buf0[2]")
.check("buf4 = buf0[3]")
.check("torch.ops._c10d_functional.wait_tensor.default(buf1")
.check("torch.ops._c10d_functional.wait_tensor.default(buf2")
.check("torch.ops._c10d_functional.wait_tensor.default(buf3")
.check("torch.ops._c10d_functional.wait_tensor.default(buf4")
# Expect no extra copy on return
.check("return (buf1, buf2, buf3, buf4, )")
.run(code)
)
out = compiled(args)
correct = func(args)
assert same(out, correct), f"{out} va {correct}"
@unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
def test_inductor_reduce_scatter_tensor_single(self):
torch._inductor.config.debug = True
self._init_process_group()
def func(arg: torch.Tensor) -> torch.Tensor:
rs0 = torch.ops._c10d_functional.reduce_scatter_tensor(
arg, "avg", self.world_size, "default"
)
rs0 = torch.ops._c10d_functional.wait_tensor(rs0)
return rs0
arg = torch.rand(4, 4, device=self.device)
compiled = torch.compile(func)
code = run_and_get_triton_code(compiled, arg)
(
FileCheck()
.check(
"buf0 = torch.ops._c10d_functional.reduce_scatter_tensor.default(arg0_1"
)
.check("torch.ops._c10d_functional.wait_tensor.default(buf0")
# Expect no extra copy on return
.check("return (buf0, )")
.run(code)
)
out = compiled(arg)
correct = func(arg)
assert same(out, correct), f"{out} va {correct}"
@unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
def test_inductor_reduce_scatter_tensor_coalesced(self):
torch._inductor.config.debug = True
self._init_process_group()
def func(args: List[torch.Tensor]) -> torch.Tensor:
rs0 = torch.ops._c10d_functional.reduce_scatter_tensor_coalesced(
args, "avg", self.world_size, "default"
)
rs0 = [torch.ops._c10d_functional.wait_tensor(out) for out in rs0]
return rs0
args = [torch.rand(4, 4, device=self.device) for _ in range(4)]
compiled = torch.compile(func)
code = run_and_get_triton_code(compiled, args)
(
FileCheck()
.check(
"buf0 = torch.ops._c10d_functional.reduce_scatter_tensor_coalesced"
".default([arg0_1, arg1_1, arg2_1, arg3_1]"
)
.check("buf1 = buf0[0]")
.check("buf2 = buf0[1]")
.check("buf3 = buf0[2]")
.check("buf4 = buf0[3]")
.check("torch.ops._c10d_functional.wait_tensor.default(buf1")
.check("torch.ops._c10d_functional.wait_tensor.default(buf2")
.check("torch.ops._c10d_functional.wait_tensor.default(buf3")
.check("torch.ops._c10d_functional.wait_tensor.default(buf4")
# Expect no extra copy on return
.check("return (buf1, buf2, buf3, buf4, )")
.run(code)
)
out = compiled(args)
correct = func(args)
assert same(out, correct), f"{out} va {correct}"
if __name__ == "__main__":
run_tests()

View File

@@ -35,4 +35,4 @@ inputs and outputs have any aliasing, it suffices to check whether the
storages of the input and the storages of the output have any overlap. See
`remove_noop_ops` for an example of how to do this.
Additionally, we do have one pass that *does* introduce mutation - `reinplace_scatters`. This pass must run *just before Inductor lowering*, as otherwise this breaks our invariant.
Additionally, we do have one pass that *does* introduce mutation - `reinplace_inplaceable_ops`. This pass must run *just before Inductor lowering*, as otherwise this breaks our invariant.
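For intuition, here is a hedged, eager-mode sketch of the equivalence the reinplacing pass relies on. It uses `aten.index_put` / `aten.index_put_` only as a stand-in functional/in-place pair; the actual set of rewritable ops (including the `_c10d_functional` collectives added in this PR) is the `inplaceable_ops` table in `post_grad.py`, and the real pass rewrites post-grad FX graph nodes rather than eager tensors.

```python
import torch

# Schematic, eager-mode illustration (index_put serves only as a
# representative functional/in-place op pair).
x = torch.zeros(4)
indices = [torch.tensor([1, 3])]
values = torch.ones(2)

# Functional form plus copy epilogue, as produced by functionalization.
a = x.clone()
out = torch.ops.aten.index_put(a, indices, values)
a.copy_(out)

# In-place form, which the pass substitutes when no later node reads a view
# of the mutated argument (the copy_ node is then erased).
b = x.clone()
torch.ops.aten.index_put_(b, indices, values)

assert torch.equal(a, b)
```
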

View File

@@ -93,7 +93,7 @@ def post_grad_passes(gm: torch.fx.GraphModule, is_inference: bool):
# Keep this last, since it introduces mutation. Look at
# ./fx_passes/README.md for a discussion of mutation invariants.
reinplace_scatters(gm.graph)
reinplace_inplaceable_ops(gm.graph)
gm.recompile()
gm.graph.lint()
@@ -625,9 +625,9 @@ def remove_noop_ops(graph: torch.fx.Graph):
InplaceableOp = namedtuple("InplaceableOp", ["inplace_op", "mutated_arg"])
def reinplace_scatters(graph):
def reinplace_inplaceable_ops(graph):
"""
Reinplaces scatter operations.
Reinplaces in-placeable operations.
If there are no uses of a view of the mutated arg after the current node,
it is possible to inplace the op.
This above algorithm could be justified by observing side effects. While
@@ -676,6 +676,9 @@ def reinplace_scatters(graph):
return False
def can_inplace(node, mutated_arg):
if isinstance(mutated_arg, (list, tuple)):
return all(can_inplace(node, arg) for arg in mutated_arg)
if get_node_storage(mutated_arg) is None:
return False
shared_view_nodes = storage_to_nodes[get_node_storage(mutated_arg)]
@@ -690,7 +693,6 @@ def reinplace_scatters(graph):
):
return False
graph.erase_node(copy_node)
return True
elif any(view.op == "placeholder" for view in shared_view_nodes):
# If mutated arg is view of any of the inputs of the graph,
@@ -708,11 +710,35 @@ def reinplace_scatters(graph):
inductor_prims._unsafe_index_put_, 0
),
}
try:
c10d_functional = torch.ops._c10d_functional
inplaceable_collective_ops = {
c10d_functional.all_reduce.default: InplaceableOp(
c10d_functional.all_reduce_.default, 0
),
c10d_functional.all_reduce_coalesced.default: InplaceableOp(
c10d_functional.all_reduce_coalesced_.default, 0
),
}
inplaceable_ops.update(inplaceable_collective_ops)
except AttributeError:
# _c10d_functional ops are only available when torch
# is built with USE_DISTRIBUTED=1.
pass
inplaceable_triton_ops = {triton_kernel_wrapper_functional}
for node in graph.nodes:
if (inplaceable_op := inplaceable_ops.get(node.target, None)) is not None:
if can_inplace(node, node.args[inplaceable_op.mutated_arg]):
mutated_arg = node.args[inplaceable_op.mutated_arg]
if can_inplace(node, mutated_arg):
# TODO(yifu): this doesn't properly remove copy epilogues for
# ops that mutate multiple inputs. Need to revise the copy
# node tracking logic to support the case.
copy_node = copy_args_to_copy_nodes.get((mutated_arg, node))
if copy_node is not None:
graph.erase_node(copy_node)
node.target = inplaceable_op.inplace_op
elif node.target in inplaceable_triton_ops:
# inplaceable_triton_ops take an additional argument called
@@ -722,7 +748,12 @@ def reinplace_scatters(graph):
tensors_to_clone = []
for arg in node.kwargs["tensors_to_clone"]:
assert arg in node.kwargs["kwargs"]
if not can_inplace(node, node.kwargs["kwargs"][arg]):
mutated_arg = node.kwargs["kwargs"][arg]
if can_inplace(node, mutated_arg):
copy_node = copy_args_to_copy_nodes.get((mutated_arg, node))
if copy_node is not None:
graph.erase_node(copy_node)
else:
tensors_to_clone.append(arg)
kwargs = dict(node.kwargs)
kwargs["tensors_to_clone"] = tensors_to_clone

View File

@@ -2693,7 +2693,7 @@ class Buffer(IRNode):
# dynamic; it has i0 as one of the arguments. We cannot tell this
# directly from MultiOutput, we have to look at the input buffer's
# uses to work this out. No big deal.
if isinstance(self.layout, MultiOutputLayout):
if isinstance(self.layout, (NoneLayout, MultiOutputLayout)):
return set()
# This kernel defines all unbacked symbols... that it didn't get in as
@@ -4496,6 +4496,15 @@ class FallbackKernel(ExternKernelAlloc):
else:
super().codegen(wrapper)
@staticmethod
def tensor_to_layout(output: torch.Tensor):
return FixedLayout(
output.device,
output.dtype,
convert_shape_to_inductor(output.size()),
convert_shape_to_inductor(output.stride()),
)
@classmethod
def create(cls, kernel, *args, **kwargs):
fake_incorrect_kernels = (aten._fused_moving_avg_obs_fq_helper_functional,)
@@ -4511,18 +4520,10 @@ class FallbackKernel(ExternKernelAlloc):
schema,
) = cls.process_kernel(kernel, *args, **kwargs)
device = FallbackKernel.find_device(tensor_args, example_output)
device = cls.find_device(tensor_args, example_output)
assert device, "Not sure where to find device info"
def tensor_to_layout(output: torch.Tensor):
return FixedLayout(
output.device,
output.dtype,
convert_shape_to_inductor(output.size()),
convert_shape_to_inductor(output.stride()),
)
packed = FallbackKernel(
packed = cls(
MultiOutputLayout(device),
kernel,
tensor_args,
@@ -4544,7 +4545,7 @@ class FallbackKernel(ExternKernelAlloc):
}
elif isinstance(output, torch.Tensor):
return MultiOutput(
tensor_to_layout(output),
cls.tensor_to_layout(output),
packed,
indices,
)
@@ -6930,6 +6931,176 @@ class ReduceScatterTensorCoalesced(OutOfPlaceCollectiveKernel):
)
# TODO(yifu): replace the CollectiveKernel IR hierarchy with _CollectiveKernel.
class _CollectiveKernel(FallbackKernel):
def should_allocate(self):
return False
def has_side_effects(self):
return True
# This is identical to FallbackKernel.set_cpp_kernel(), minus the
# part that checks against input aliasing and mutation.
def set_cpp_kernel(self, kernel):
from .codegen.wrapper import get_cpp_op_schema
self.kernel = kernel._schema.name
self.cpp_kernel_overlad_name = kernel._schema.overload_name
self.cpp_kernel_key = (
f"{self.kernel.replace('::', '_')}_{self.cpp_kernel_overlad_name}"
)
self.cpp_op_schema = get_cpp_op_schema(kernel)
self.ordered_kwargs_for_cpp_kernel = [
x.name for x in kernel._schema.arguments if x.kwarg_only
]
# NOTE: [In-Place Collective Safety]
# Between the initiation and completion of an in-place collective, the
# input buffers are subject to both volatile reads and volatile writes.
# They must not be read, written to, or reused by another kernel. To enforce
# these constraints, we model collective -> wait_tensor as a two-step
# mutation of the input buffers.
@classmethod
def create_inplace(
cls, kernel, inputs: Union[TensorBox, List[TensorBox]], *args, **kwargs
) -> None:
with V.graph.fake_mode:
(
example_output,
tensor_args,
non_tensor_args,
unflatten_args,
schema,
) = cls.process_kernel(kernel, inputs, *args, **kwargs)
for tensor_arg in tensor_args:
tensor_arg.realize()
packed = cls(
NoneLayout(tensor_args[0].get_device()),
kernel,
tensor_args,
non_tensor_args,
unflatten_args,
schema=schema,
)
pytree.tree_map(lambda x: MutationOutput(x.layout, x, packed), inputs)
# NOTE: [Out-of-Place Collective Safety]
# Between the initiation and completion of an out-of-place collective:
#
# Input buffers:
# - Are subject to volatile reads
# - Can be read by another kernel
# - Must not be written to or reused by another kernel
#
# Output buffers:
# - Are subject to volatile writes
# - Must not be read, written to or reused by another kernel
#
# To ensure the safety of input buffers without sacrificing read
# availability, we add input buffers as read deps of wait_tensor kernels.
#
# To ensure the safety of output buffers, we model wait_tensor as a
# mutation to the output buffer. Note that we also assume the user program is
# correct and that the output buffer is not consumed by kernels other than
# wait_tensor.
#
# TODO(yifu): add a pre-grad pass to validate the correctness of collective
# usage in the user program.
@classmethod
def create_out_of_place(
cls, kernel, inputs: Union[TensorBox, List[TensorBox]], *args, **kwargs
):
with V.graph.fake_mode:
(
example_output,
tensor_args,
non_tensor_args,
unflatten_args,
schema,
) = cls.process_kernel(kernel, inputs, *args, **kwargs)
for tensor_arg in tensor_args:
tensor_arg.realize()
if isinstance(example_output, list):
device = cls.find_device(tensor_args, example_output)
packed = cls(
MultiOutputLayout(device),
kernel,
tensor_args,
non_tensor_args,
unflatten_args,
schema=schema,
)
packed.outputs = [
MultiOutput(
cls.tensor_to_layout(tensor),
packed,
[(list, i)],
)
for i, tensor in enumerate(example_output)
]
return packed.outputs
else:
packed = cls(
cls.tensor_to_layout(example_output),
kernel,
tensor_args,
non_tensor_args,
unflatten_args,
schema=schema,
)
packed.outputs = [packed]
return packed
class _WaitKernel(_CollectiveKernel):
def get_volatile_reads(self):
inp = self.inputs[0]
if isinstance(inp, _CollectiveKernel):
# Out-of-place single-output
return [inp.inputs[0]]
elif isinstance(inp, MultiOutput):
# Out-of-place multi-output
coll = inp.inputs[0]
assert isinstance(coll, _CollectiveKernel)
_, idx = inp.indices[0]
return [coll.inputs[idx]]
else:
# In-place requires no additional deps handling for volatile
# reads since the inputs are mutated.
return []
@classmethod
def create_wait(cls, kernel, inp: TensorBox) -> None:
with V.graph.fake_mode:
(
example_output,
tensor_args,
non_tensor_args,
unflatten_args,
schema,
) = cls.process_kernel(kernel, inp)
packed = cls(
NoneLayout(inp.get_device()),
kernel,
tensor_args,
non_tensor_args,
unflatten_args,
schema=schema,
)
MutationOutput(inp.layout, inp, packed)
def get_read_writes(self):
read_writes = super().get_read_writes()
# See [Out-of-Place Collective Safety].
volatile_reads = self.get_volatile_reads()
for vr in volatile_reads:
read_writes.reads.add(dependencies.StarDep(vr.get_name()))
return read_writes
# NB: recursive structure here reflects val_to_arg_str, avoid
# calling free_unbacked_symbols on "exotic" types that don't get pexpr
# treatment

View File

@@ -5132,6 +5132,97 @@ try:
)
)
_c10d_functional = torch.ops._c10d_functional
@register_lowering(_c10d_functional.all_reduce)
def _all_reduce(inp, reduce_op, group_name):
inp = clone(inp)
ir._CollectiveKernel.create_inplace(
_c10d_functional.all_reduce_.default, inp, reduce_op, group_name
)
return inp
@register_lowering(_c10d_functional.all_reduce_)
def _all_reduce_(inp, reduce_op, group_name):
ir._CollectiveKernel.create_inplace(
_c10d_functional.all_reduce_.default, inp, reduce_op, group_name
)
return inp
@register_lowering(_c10d_functional.all_reduce_coalesced)
def _all_reduce_coalesced(inputs, reduce_op, group_name):
inputs = [clone(inp) for inp in inputs]
ir._CollectiveKernel.create_inplace(
_c10d_functional.all_reduce_coalesced_.default,
inputs,
reduce_op,
group_name,
)
return inputs
@register_lowering(_c10d_functional.all_reduce_coalesced_)
def _all_reduce_coalesced_(inputs, reduce_op, group_name):
ir._CollectiveKernel.create_inplace(
_c10d_functional.all_reduce_coalesced_.default,
inputs,
reduce_op,
group_name,
)
return inputs
@register_lowering(_c10d_functional.all_gather_into_tensor)
def _all_gather_into_tensor(inp, group_size, group_name):
return ir.TensorBox.create(
ir._CollectiveKernel.create_out_of_place(
_c10d_functional.all_gather_into_tensor.default,
inp,
group_size,
group_name,
)
)
@register_lowering(_c10d_functional.all_gather_into_tensor_coalesced)
def _all_gather_into_tensor_coalesced(inputs, group_size, group_name):
return pytree.tree_map(
ir.TensorBox.create,
ir._CollectiveKernel.create_out_of_place(
_c10d_functional.all_gather_into_tensor_coalesced.default,
inputs,
group_size,
group_name,
),
)
@register_lowering(_c10d_functional.reduce_scatter_tensor)
def _reduce_scatter_tensor(inp, reduce_op, group_size, group_name):
return ir.TensorBox.create(
ir._CollectiveKernel.create_out_of_place(
_c10d_functional.reduce_scatter_tensor.default,
inp,
reduce_op,
group_size,
group_name,
)
)
@register_lowering(_c10d_functional.reduce_scatter_tensor_coalesced)
def _reduce_scatter_tensor_coalesced(inputs, reduce_op, group_size, group_name):
return pytree.tree_map(
ir.TensorBox.create,
ir._CollectiveKernel.create_out_of_place(
_c10d_functional.reduce_scatter_tensor_coalesced.default,
inputs,
reduce_op,
group_size,
group_name,
),
)
@register_lowering(_c10d_functional.wait_tensor)
def _wait_tensor(inp):
ir._WaitKernel.create_wait(_c10d_functional.wait_tensor.default, inp)
return inp
except ImportError:
log.info(
"Inductor support for distributed collectives depends on building torch.distributed"

View File

@@ -542,6 +542,12 @@ def _reduce_scatter_tensor_meta(input, reduce_op, tag, rankset, group_size):
def _all_reduce_coalesced_meta(self, *args):
return [torch.empty_like(t) for t in self]
def _all_reduce__meta(inp, *args):
return inp
def _all_reduce_coalesced__meta(inputs, *args):
return inputs
def _reduce_scatter_tensor_coalesced_meta(inputs, reduceOp, tag, rankset, group_size):
def mk_out_tensor(input):
out_size = list(input.size())
@@ -577,15 +583,15 @@ def _all_gather_into_tensor_coalesced_native_meta(inputs, group_size, group_name
for input in inputs
]
def _reduce_scatter_tensor_native_meta(input, group_size, group_name):
shape = list(input.size())
def _reduce_scatter_tensor_native_meta(inp, reduce_op, group_size, group_name):
shape = list(inp.size())
shape[0] //= group_size
return input.new_empty(shape)
return inp.new_empty(shape)
def _reduce_scatter_tensor_coalesced_native_meta(inputs, group_size, group_name):
def _reduce_scatter_tensor_coalesced_native_meta(inputs, reduce_op, group_size, group_name):
return [
_reduce_scatter_tensor_native_meta(input, group_size, group_name)
for input in inputs
_reduce_scatter_tensor_native_meta(inp, reduce_op, group_size, group_name)
for inp in inputs
]
def _register_ops():
@@ -620,7 +626,9 @@ if not torch._running_with_deploy():
_c10_lib_impl = torch.library.Library("_c10d_functional", "IMPL")
_c10_lib_impl.impl("all_reduce", _all_reduce_meta, "Meta")
_c10_lib_impl.impl("all_reduce_", _all_reduce__meta, "Meta")
_c10_lib_impl.impl("all_reduce_coalesced", _all_reduce_coalesced_meta, "Meta")
_c10_lib_impl.impl("all_reduce_coalesced_", _all_reduce_coalesced__meta, "Meta")
_c10_lib_impl.impl("wait_tensor", _wait_tensor_meta, "Meta")
_c10_lib_impl.impl("all_gather_into_tensor", _all_gather_into_tensor_native_meta, "Meta")
_c10_lib_impl.impl("all_gather_into_tensor_coalesced", _all_gather_into_tensor_coalesced_native_meta, "Meta")