mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[cpp wrapper] add AOTI shim for collective ops (#154492)
Implementations:
1. Move collective ops to the c10d namespace, so that we can call them externally.
2. Add AOTI shims for collective ops.

Testing:
1. Add a c10d functional UT for CPU.
2. Include the above test in the cpp wrapper UT.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/154492
Approved by: https://github.com/desertfire
committed by: PyTorch MergeBot
parent: d797038ea9
commit: 02c7ab2f9b
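As an end-to-end illustration of what this change enables, here is a minimal sketch condensed from the CompileTestCPU test added in this PR: compiling a function that uses functional collectives on CPU with the inductor cpp wrapper enabled routes the collective through the new AOTI shims. The fake process group, the group name "0", and the shim symbol names come from the test below; the import paths for FakeStore and the inductor config are assumptions based on the existing test file, so treat this as a sketch rather than a separate API.

# Minimal sketch, condensed from the CompileTestCPU test added below.
# Assumes a "fake"-backend process group, so no real communication happens.
import torch
import torch.distributed as dist
import torch.distributed._functional_collectives as funcol
import torch._inductor.config as inductor_config
from torch._inductor.utils import run_and_get_code
from torch.testing._internal.distributed.fake_pg import FakeStore  # assumed import path

dist.init_process_group(backend="fake", rank=0, world_size=2, store=FakeStore())

def func(arg: torch.Tensor) -> torch.Tensor:
    buf0 = arg + 42
    ar0 = funcol.all_reduce(buf0, "avg", "0")
    return funcol.wait_tensor(ar0)

inductor_config.cpp_wrapper = True
_, (code,) = run_and_get_code(torch.compile(func), torch.rand(4, 4, device="cpu"))
# With the cpp wrapper on, the generated code should reference the new CPU AOTI shims.
assert "aoti_torch_cpu__c10d_functional_all_reduce_" in code
assert "aoti_torch_cpu__c10d_functional_wait_tensor" in code

dist.destroy_process_group()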
@@ -9,7 +9,7 @@ import torch
 import torch.distributed as dist
 import torch.distributed._functional_collectives as funcol
 from torch._C import FileCheck
-from torch._inductor.utils import fresh_cache, run_and_get_triton_code
+from torch._inductor.utils import fresh_cache, run_and_get_code, run_and_get_triton_code
 from torch.distributed._functional_collectives import (
     all_gather_into_tensor_coalesced,
     all_gather_tensor,
@@ -713,6 +713,61 @@ class PyWorkTest(TestCase):
         self.assertEqual(pg.dels, 4)


+class CompileTestCPU(TestCase):
+    def setUp(self):
+        super().setUp()
+
+        if not dist.is_initialized():
+            self.rank = 0
+            self.world_size = 2
+
+            store = FakeStore()
+            dist.init_process_group(
+                backend="fake",
+                world_size=self.world_size,
+                rank=self.rank,
+                store=store,
+            )
+
+    def tearDown(self):
+        dist.destroy_process_group()
+
+    @fresh_cache()
+    def _test_inductor_all_reduce_cpu(self, cpp_wrapper=False):
+        def func(arg: torch.Tensor) -> torch.Tensor:
+            buf0 = arg + 42
+            ar0 = funcol.all_reduce(buf0, "avg", "0")
+            ar0 = funcol.wait_tensor(ar0)
+            return ar0
+
+        arg = torch.rand(4, 4, device="cpu")
+        torch._inductor.config.cpp_wrapper = cpp_wrapper
+        compiled = torch.compile(func)
+
+        _, (code,) = run_and_get_code(compiled, arg)
+        include_ops = (
+            [
+                "aoti_torch_cpu__c10d_functional_all_reduce_",
+                "aoti_torch_cpu__c10d_functional_wait_tensor",
+            ]
+            if cpp_wrapper
+            else [
+                "torch.ops._c10d_functional.all_reduce_.default",
+                "torch.ops._c10d_functional.wait_tensor.default",
+            ]
+        )
+        for op in include_ops:
+            self.assertIn(op, code)
+
+        # Test aoti
+        AOTIRunnerUtil.run(func, (arg,))
+        torch.cpu.synchronize()
+
+    def test_inductor_all_reduce_cpu(self):
+        self._test_inductor_all_reduce_cpu(cpp_wrapper=False)
+        self._test_inductor_all_reduce_cpu(cpp_wrapper=True)
+
+
 class CompileTest(TestCase):
     def setUp(self):
         super().setUp()
@@ -67,7 +67,7 @@ def get_collective_input_size_bytes(node: ir.IRNode) -> int:


 def get_collective_group_size(node: ir.IRNode) -> int:
-    if type(node) == ir._CollectiveKernel:
+    if isinstance(node, ir._CollectiveKernel) and not isinstance(node, ir._WaitKernel):
         from torch.distributed.distributed_c10d import _get_group_size_by_name

         return _get_group_size_by_name(node.constant_args[-1])
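Why the check changes: this PR introduces _AllReduceKernel and _AllReduce_Kernel as subclasses of _CollectiveKernel (and _WaitKernel already subclasses it), so an exact type(...) comparison would no longer match collective nodes, while wait nodes need to be excluded explicitly. A small illustrative sketch; the class names below are stand-ins, not the real inductor IR classes:

# Illustrative only; stand-in classes, not torch._inductor.ir.
class CollectiveKernel: ...
class AllReduceKernel(CollectiveKernel): ...
class WaitKernel(CollectiveKernel): ...

node = AllReduceKernel()
assert type(node) != CollectiveKernel       # old exact-type check misses subclasses
assert isinstance(node, CollectiveKernel)   # new isinstance check matches them

wait = WaitKernel()
# Waits also subclass CollectiveKernel, hence the explicit "not isinstance(..., _WaitKernel)".
assert isinstance(wait, CollectiveKernel) and isinstance(wait, WaitKernel)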
@@ -208,7 +208,7 @@ def register_comm_lowerings():
         inp.realize()
         V.graph.no_fuse_buffer_names.add(inp.get_name())
         inp = ir.ExternKernel.require_contiguous(inp)
-        ir._CollectiveKernel.create_inplace(
+        ir._AllReduceKernel.create_inplace(
             c10d.all_reduce_.default, inp, reduce_op, group_name
         )
         return inp
@@ -227,7 +227,7 @@ def register_comm_lowerings():

         # Lower as c10d.all_reduce_
         inp = ir.ExternKernel.require_contiguous(inp)
-        ir._CollectiveKernel.create_inplace(
+        ir._AllReduce_Kernel.create_inplace(
             c10d.all_reduce_.default, inp, reduce_op, group_name
         )
         return inp
@@ -8235,7 +8235,10 @@ class _CollectiveKernel(FallbackKernel):
             "Setting cpp kernel needs a valid op_overload"
         )
         kernel = self.op_overload
-        self.cpp_kernel_name = kernel._schema.name
+        if cpp_kernel_name is not None:
+            self.cpp_kernel_name = cpp_kernel_name
+        else:
+            self.cpp_kernel_name = kernel._schema.name

         self.ordered_kwargs_for_cpp_kernel = [
             x.name for x in kernel._schema.arguments if x.kwarg_only
@@ -8363,7 +8366,98 @@ class _CollectiveKernel(FallbackKernel):
         return packed


+class _AllReduce_Kernel(_CollectiveKernel):
+    def __init__(  # type: ignore[no-untyped-def]
+        self,
+        layout,
+        kernel,
+        tensor_args,
+        nontensor_args,
+        unflatten_args,
+        kwargs=None,
+        *,
+        unbacked_bindings=None,
+    ) -> None:
+        super().__init__(
+            layout,
+            kernel,
+            tensor_args,
+            nontensor_args,
+            unflatten_args,
+            kwargs=None,
+            unbacked_bindings=unbacked_bindings,
+        )
+        self.set_cpp_kernel_name("aoti_torch_cpu__c10d_functional_all_reduce_")
+
+    def codegen(self, wrapper) -> None:  # type: ignore[no-untyped-def]
+        wrapper.include_extra_header("torch/csrc/inductor/aoti_torch/c/shim_cpu.h")
+        wrapper.generate_extern_kernel_alloc(self)
+
+        if isinstance(self.layout, Layout):
+            self.codegen_size_asserts(wrapper)
+
+
+class _AllReduceKernel(_CollectiveKernel):
+    def __init__(  # type: ignore[no-untyped-def]
+        self,
+        layout,
+        kernel,
+        tensor_args,
+        nontensor_args,
+        unflatten_args,
+        kwargs=None,
+        *,
+        unbacked_bindings=None,
+    ) -> None:
+        super().__init__(
+            layout,
+            kernel,
+            tensor_args,
+            nontensor_args,
+            unflatten_args,
+            kwargs=None,
+            unbacked_bindings=unbacked_bindings,
+        )
+        self.set_cpp_kernel_name("aoti_torch_cpu__c10d_functional_all_reduce")
+
+    def codegen(self, wrapper) -> None:  # type: ignore[no-untyped-def]
+        wrapper.include_extra_header("torch/csrc/inductor/aoti_torch/c/shim_cpu.h")
+        wrapper.generate_extern_kernel_alloc(self)
+
+        if isinstance(self.layout, Layout):
+            self.codegen_size_asserts(wrapper)
+
+
 class _WaitKernel(_CollectiveKernel):
+    def __init__(  # type: ignore[no-untyped-def]
+        self,
+        layout,
+        kernel,
+        tensor_args,
+        nontensor_args,
+        unflatten_args,
+        kwargs=None,
+        *,
+        unbacked_bindings=None,
+    ) -> None:
+        super().__init__(
+            layout,
+            kernel,
+            tensor_args,
+            nontensor_args,
+            unflatten_args,
+            kwargs=None,
+            unbacked_bindings=unbacked_bindings,
+        )
+        self.set_cpp_kernel_name("aoti_torch_cpu__c10d_functional_wait_tensor")
+
+    def codegen(self, wrapper) -> None:  # type: ignore[no-untyped-def]
+        wrapper.include_extra_header("torch/csrc/inductor/aoti_torch/c/shim_cpu.h")
+        wrapper.generate_extern_kernel_alloc(self)
+
+        if isinstance(self.layout, Layout):
+            self.codegen_size_asserts(wrapper)
+
     def get_volatile_reads(self):  # type: ignore[no-untyped-def]
         inp = self.inputs[0]
         if isinstance(inp, _CollectiveKernel):
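Recapping the classes above: each one only pins the CPU AOTI shim symbol that the cpp wrapper should call (via set_cpp_kernel_name) and pulls in shim_cpu.h during codegen; everything else is inherited from _CollectiveKernel. An illustrative summary drawn from this diff, not an inductor API:

# Illustrative mapping taken from the classes above; not real inductor code.
CPU_AOTI_SHIM_BY_IR_KERNEL = {
    "_AllReduce_Kernel": "aoti_torch_cpu__c10d_functional_all_reduce_",  # in-place all_reduce_
    "_AllReduceKernel": "aoti_torch_cpu__c10d_functional_all_reduce",    # out-of-place all_reduce
    "_WaitKernel": "aoti_torch_cpu__c10d_functional_wait_tensor",        # wait on the async result
}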
@@ -2323,7 +2323,9 @@ def is_collective(
     from . import ir

     return (
-        type(node) == ir._CollectiveKernel and (op is None or node.op_overload is op)
+        isinstance(node, ir._CollectiveKernel)
+        and not isinstance(node, ir._WaitKernel)
+        and (op is None or node.op_overload is op)
     ) or (
         # TODO: this is a temporary solution to ensure that we can identify torchrec's
         # communication ops. But in order to allow better communication and computation
@@ -30,6 +30,37 @@ c10d::ReduceOp to_reduce_op(const std::string& reduce_op) {
   return it->second;
 }

+at::Tensor allocate_all_gather_output(
+    const at::Tensor& input,
+    int64_t group_size) {
+  TORCH_CHECK(input.is_contiguous());
+  auto output_size = input.sizes().vec();
+  output_size[0] *= group_size;
+  return at::empty(
+      output_size,
+      at::TensorOptions().dtype(input.dtype()).device(input.device()));
+}
+
+at::Tensor allocate_reduce_scatter_output(
+    const at::Tensor& input,
+    const int64_t group_size) {
+  TORCH_CHECK(input.is_contiguous());
+  auto output_size = input.sizes().vec();
+  if (output_size[0] % group_size != 0) {
+    LOG(WARNING) << "The first dimension of the reduce_scatter input ("
+                 << output_size[0] << ") is not divisible by the group size ("
+                 << group_size << ").";
+  }
+  output_size[0] /= group_size;
+  return at::empty(
+      output_size,
+      at::TensorOptions().dtype(input.dtype()).device(input.device()));
+}
+
+} // namespace
+
+namespace c10d {
+
 at::Tensor& all_reduce_(
     at::Tensor& input,
     // NOLINTNEXTLINE(performance-unnecessary-value-param)
@@ -85,17 +116,6 @@ std::vector<at::Tensor> all_reduce_coalesced(
       outputs, std::move(reduce_op), std::move(group_name));
 }

-at::Tensor allocate_all_gather_output(
-    const at::Tensor& input,
-    int64_t group_size) {
-  TORCH_CHECK(input.is_contiguous());
-  auto output_size = input.sizes().vec();
-  output_size[0] *= group_size;
-  return at::empty(
-      output_size,
-      at::TensorOptions().dtype(input.dtype()).device(input.device()));
-}
-
 std::vector<at::Tensor> all_gather_into_tensor_coalesced(
     std::vector<at::Tensor> inputs,
     int64_t group_size,
@@ -140,22 +160,6 @@ at::Tensor& all_gather_into_tensor_out(
   return output;
 }

-at::Tensor allocate_reduce_scatter_output(
-    const at::Tensor& input,
-    const int64_t group_size) {
-  TORCH_CHECK(input.is_contiguous());
-  auto output_size = input.sizes().vec();
-  if (output_size[0] % group_size != 0) {
-    LOG(WARNING) << "The first dimension of the reduce_scatter input ("
-                 << output_size[0] << ") is not divisible by the group size ("
-                 << group_size << ").";
-  }
-  output_size[0] /= group_size;
-  return at::empty(
-      output_size,
-      at::TensorOptions().dtype(input.dtype()).device(input.device()));
-}
-
 std::vector<at::Tensor> reduce_scatter_tensor_coalesced(
     std::vector<at::Tensor> inputs,
     // NOLINTNEXTLINE(performance-unnecessary-value-param)
@@ -234,65 +238,68 @@ at::Tensor broadcast(
   return broadcast_(output, src, std::move(group_name));
 }

-} // namespace
+} // namespace c10d

 TORCH_LIBRARY(_c10d_functional, m) {
   m.def(
       "all_reduce(Tensor input, str reduce_op, str group_name) -> Tensor",
       torch::dispatch(
-          c10::DispatchKey::CompositeExplicitAutograd, ::all_reduce),
+          c10::DispatchKey::CompositeExplicitAutograd, c10d::all_reduce),
       {at::Tag::pt2_compliant_tag});

   m.def(
       "all_reduce_(Tensor(a!) input, str reduce_op, str group_name) -> Tensor(a!)",
       torch::dispatch(
-          c10::DispatchKey::CompositeExplicitAutograd, ::all_reduce_),
+          c10::DispatchKey::CompositeExplicitAutograd, c10d::all_reduce_),
       {at::Tag::pt2_compliant_tag});

   m.def(
       "all_reduce_coalesced(Tensor[] inputs, str reduce_op, str group_name) -> Tensor[]",
       torch::dispatch(
-          c10::DispatchKey::CompositeExplicitAutograd, ::all_reduce_coalesced),
+          c10::DispatchKey::CompositeExplicitAutograd,
+          c10d::all_reduce_coalesced),
       {at::Tag::pt2_compliant_tag});

   m.def(
       "all_reduce_coalesced_(Tensor[](a!) inputs, str reduce_op, str group_name) -> Tensor[](a!)",
       torch::dispatch(
-          c10::DispatchKey::CompositeExplicitAutograd, ::all_reduce_coalesced_),
+          c10::DispatchKey::CompositeExplicitAutograd,
+          c10d::all_reduce_coalesced_),
       {at::Tag::pt2_compliant_tag});

   m.def(
       "all_gather_into_tensor_out(Tensor input, int group_size, str group_name, *, Tensor(a!) out) -> Tensor(a!)",
       torch::dispatch(
           c10::DispatchKey::CompositeExplicitAutograd,
-          ::all_gather_into_tensor_out),
+          c10d::all_gather_into_tensor_out),
       {at::Tag::pt2_compliant_tag, at::Tag::needs_contiguous_strides});

   m.def(
       "all_gather_into_tensor(Tensor input, int group_size, str group_name) -> Tensor",
       torch::dispatch(
           c10::DispatchKey::CompositeExplicitAutograd,
-          ::all_gather_into_tensor),
+          c10d::all_gather_into_tensor),
       {at::Tag::pt2_compliant_tag, at::Tag::needs_contiguous_strides});

   m.def(
       "all_gather_into_tensor_coalesced(Tensor[] inputs, int group_size, str group_name) -> Tensor[]",
       torch::dispatch(
           c10::DispatchKey::CompositeExplicitAutograd,
-          ::all_gather_into_tensor_coalesced),
+          c10d::all_gather_into_tensor_coalesced),
       {at::Tag::pt2_compliant_tag, at::Tag::needs_contiguous_strides});

   m.def(
       "reduce_scatter_tensor(Tensor input, str reduce_op, int group_size, str group_name) -> Tensor",
       torch::dispatch(
-          c10::DispatchKey::CompositeExplicitAutograd, ::reduce_scatter_tensor),
+          c10::DispatchKey::CompositeExplicitAutograd,
+          c10d::reduce_scatter_tensor),
       {at::Tag::pt2_compliant_tag, at::Tag::needs_contiguous_strides});

   m.def(
       "reduce_scatter_tensor_coalesced(Tensor[] inputs, str reduce_op, int group_size, str group_name) -> Tensor[]",
       torch::dispatch(
           c10::DispatchKey::CompositeExplicitAutograd,
-          ::reduce_scatter_tensor_coalesced),
+          c10d::reduce_scatter_tensor_coalesced),
       {at::Tag::pt2_compliant_tag, at::Tag::needs_contiguous_strides});

   m.def(
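The registration changes above only swap the backing C++ symbols from file-local functions to the exported c10d:: ones; the Python-visible op names are unchanged, which is why the non-cpp_wrapper branch of the new test still looks for them. A quick orientation snippet (assumes a distributed-enabled build):

# The Python-facing op objects resolve exactly as before this change.
import torch

all_reduce_ = torch.ops._c10d_functional.all_reduce_.default
wait_tensor = torch.ops._c10d_functional.wait_tensor.default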
@@ -302,18 +309,19 @@ TORCH_LIBRARY(_c10d_functional, m) {
       "SymInt[] input_split_sizes, "
       "str group_name) -> Tensor",
       torch::dispatch(
-          c10::DispatchKey::CompositeExplicitAutograd, ::all_to_all_single),
+          c10::DispatchKey::CompositeExplicitAutograd, c10d::all_to_all_single),
       {at::Tag::pt2_compliant_tag, at::Tag::needs_contiguous_strides});

   m.def(
       "broadcast(Tensor input, int src, str group_name) -> Tensor",
-      torch::dispatch(c10::DispatchKey::CompositeExplicitAutograd, ::broadcast),
+      torch::dispatch(
+          c10::DispatchKey::CompositeExplicitAutograd, c10d::broadcast),
       {at::Tag::pt2_compliant_tag});

   m.def(
       "broadcast_(Tensor(a!) input, int src, str group_name) -> Tensor(a!)",
       torch::dispatch(
-          c10::DispatchKey::CompositeExplicitAutograd, ::broadcast_),
+          c10::DispatchKey::CompositeExplicitAutograd, c10d::broadcast_),
       {at::Tag::pt2_compliant_tag});

   m.def(
@@ -342,7 +350,7 @@ class AllToAllSingle : public torch::autograd::Function<AllToAllSingle> {

     return c10::Dispatcher::singleton()
         .findSchemaOrThrow("_c10d_functional::all_to_all_single", "")
-        .typed<decltype(all_to_all_single)>()
+        .typed<decltype(c10d::all_to_all_single)>()
         .call(input, output_split_sizes, input_split_sizes, group_name);
   }

@@ -361,7 +369,7 @@ class AllToAllSingle : public torch::autograd::Function<AllToAllSingle> {
     auto out =
         c10::Dispatcher::singleton()
             .findSchemaOrThrow("_c10d_functional::all_to_all_single", "")
-            .typed<decltype(all_to_all_single)>()
+            .typed<decltype(c10d::all_to_all_single)>()
             .call(grad_out, output_split_sizes, input_split_sizes, group_name);

     // do an explicit wait to avoid cuda stream issues
@@ -400,7 +408,7 @@ class ReduceScatterTensor

     return c10::Dispatcher::singleton()
         .findSchemaOrThrow("_c10d_functional::reduce_scatter_tensor", "")
-        .typed<decltype(reduce_scatter_tensor)>()
+        .typed<decltype(c10d::reduce_scatter_tensor)>()
         .call(input, reduce_op, group_size, group_name);
   }

@@ -416,7 +424,7 @@ class ReduceScatterTensor
     auto out =
         c10::Dispatcher::singleton()
             .findSchemaOrThrow("_c10d_functional::all_gather_into_tensor", "")
-            .typed<decltype(all_gather_into_tensor)>()
+            .typed<decltype(c10d::all_gather_into_tensor)>()
             .call(grad_out, group_size, group_name);

     // do an explicit wait to avoid cuda stream issues
@@ -456,7 +464,7 @@ class AllGatherIntoTensor

     return c10::Dispatcher::singleton()
         .findSchemaOrThrow("_c10d_functional::all_gather_into_tensor", "")
-        .typed<decltype(all_gather_into_tensor)>()
+        .typed<decltype(c10d::all_gather_into_tensor)>()
         .call(input, group_size, group_name);
   }

@@ -472,7 +480,7 @@ class AllGatherIntoTensor
     auto out =
         c10::Dispatcher::singleton()
             .findSchemaOrThrow("_c10d_functional::reduce_scatter_tensor", "")
-            .typed<decltype(reduce_scatter_tensor)>()
+            .typed<decltype(c10d::reduce_scatter_tensor)>()
             .call(grad_out, "sum", group_size, group_name);

     // do an explicit wait to avoid cuda stream issues
@@ -1,3 +1,78 @@
 #pragma once

 #include <torch/csrc/distributed/c10d/ProcessGroup.hpp>
+
+namespace c10d {
+
+C10_EXPORT at::Tensor& all_reduce_(
+    at::Tensor& input,
+    std::string reduce_op,
+    std::string group_name);
+
+C10_EXPORT at::Tensor all_reduce(
+    const at::Tensor& input,
+    std::string reduce_op,
+    std::string group_name);
+
+C10_EXPORT std::vector<at::Tensor> all_reduce_coalesced_(
+    std::vector<at::Tensor> inputs,
+    // NOLINTNEXTLINE(performance-unnecessary-value-param)
+    std::string reduce_op,
+    // NOLINTNEXTLINE(performance-unnecessary-value-param)
+    std::string group_name);
+
+C10_EXPORT std::vector<at::Tensor> all_reduce_coalesced(
+    // NOLINTNEXTLINE(performance-unnecessary-value-param)
+    std::vector<at::Tensor> inputs,
+    std::string reduce_op,
+    std::string group_name);
+
+C10_EXPORT std::vector<at::Tensor> all_gather_into_tensor_coalesced(
+    std::vector<at::Tensor> inputs,
+    int64_t group_size,
+    // NOLINTNEXTLINE(performance-unnecessary-value-param)
+    std::string group_name);
+
+C10_EXPORT at::Tensor all_gather_into_tensor(
+    const at::Tensor& input,
+    int64_t group_size,
+    std::string group_name);
+
+C10_EXPORT at::Tensor& all_gather_into_tensor_out(
+    at::Tensor& input,
+    int64_t group_size,
+    const std::string& group_name,
+    at::Tensor& output);
+
+C10_EXPORT std::vector<at::Tensor> reduce_scatter_tensor_coalesced(
+    std::vector<at::Tensor> inputs,
+    // NOLINTNEXTLINE(performance-unnecessary-value-param)
+    std::string reduce_op,
+    int64_t group_size,
+    // NOLINTNEXTLINE(performance-unnecessary-value-param)
+    std::string group_name);
+
+C10_EXPORT at::Tensor reduce_scatter_tensor(
+    const at::Tensor& input,
+    std::string reduce_op,
+    int64_t group_size,
+    std::string group_name);
+
+C10_EXPORT at::Tensor all_to_all_single(
+    const at::Tensor& input,
+    std::vector<int64_t> output_split_sizes,
+    std::vector<int64_t> input_split_sizes,
+    // NOLINTNEXTLINE(performance-unnecessary-value-param)
+    std::string group_name);
+
+C10_EXPORT at::Tensor& broadcast_(
+    at::Tensor& input,
+    int64_t src,
+    std::string group_name);
+
+C10_EXPORT at::Tensor broadcast(
+    const at::Tensor& input,
+    int64_t src,
+    std::string group_name);
+
+} // namespace c10d
@@ -245,6 +245,22 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__weight_int4pack_mm_cpu_tensor(
     AtenTensorHandle qScaleAndZeros,
     AtenTensorHandle* ret0);

+AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__c10d_functional_all_reduce_(
+    AtenTensorHandle inp,
+    const char* reduce_op,
+    const char* group_name,
+    AtenTensorHandle* ret0);
+
+AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__c10d_functional_all_reduce(
+    AtenTensorHandle inp,
+    const char* reduce_op,
+    const char* group_name,
+    AtenTensorHandle* ret0);
+
+AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__c10d_functional_wait_tensor(
+    AtenTensorHandle inp,
+    AtenTensorHandle* ret0);
+
 #ifdef __cplusplus
 } // extern "C"
 #endif
@@ -1,4 +1,7 @@

+#ifdef USE_DISTRIBUTED
+#include <torch/csrc/distributed/c10d/Functional.hpp>
+#endif
 #include <torch/csrc/inductor/aoti_torch/c/shim_cpu.h>
 #include <torch/csrc/inductor/aoti_torch/utils.h>

@@ -539,3 +542,38 @@ AOTITorchError aoti_torch_cpu__weight_int4pack_mm_cpu_tensor(
     *ret0 = new_tensor_handle(std::move(tmp_result));
   });
 }
+
+#ifdef USE_DISTRIBUTED
+AOTITorchError aoti_torch_cpu__c10d_functional_all_reduce_(
+    AtenTensorHandle inp,
+    const char* reduce_op,
+    const char* group_name,
+    AtenTensorHandle* ret0) {
+  AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
+    auto tmp_result = c10d::all_reduce_(
+        *tensor_handle_to_tensor_pointer(inp), reduce_op, group_name);
+    *ret0 = new_tensor_handle(std::move(tmp_result));
+  });
+}
+
+AOTITorchError aoti_torch_cpu__c10d_functional_all_reduce(
+    AtenTensorHandle inp,
+    const char* reduce_op,
+    const char* group_name,
+    AtenTensorHandle* ret0) {
+  AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
+    auto tmp_result = c10d::all_reduce(
+        *tensor_handle_to_tensor_pointer(inp), reduce_op, group_name);
+    *ret0 = new_tensor_handle(std::move(tmp_result));
+  });
+}
+
+AOTITorchError aoti_torch_cpu__c10d_functional_wait_tensor(
+    AtenTensorHandle inp,
+    AtenTensorHandle* ret0) {
+  AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
+    auto tmp_result = c10d::wait_tensor(*tensor_handle_to_tensor_pointer(inp));
+    *ret0 = new_tensor_handle(std::move(tmp_result));
+  });
+}
+#endif