[cpp wrapper] add AOTI shim for collective ops (#154492)

Implementation:
1. Move the collective ops into the c10d namespace so that they can be called externally (e.g. from the AOTI shims).
2. Add AOTI shims for the collective ops; a short sketch of the resulting cpp-wrapper behavior follows this list.
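
Point 2 can be observed from Python: with the Inductor cpp wrapper enabled, the generated C++ for a CPU collective references the new aoti_torch_cpu__c10d_functional_* shims. A minimal sketch that closely mirrors the CompileTestCPU test added below (the FakeStore import path is an assumption; the group name "0" and the op pattern come from that test):

# Compile a CPU collective with the cpp wrapper and check that the generated
# C++ calls the new AOTI shims. Uses a single-process fake process group.
import torch
import torch._inductor.config as inductor_config
import torch.distributed as dist
import torch.distributed._functional_collectives as funcol
from torch._inductor.utils import run_and_get_code
from torch.testing._internal.distributed.fake_pg import FakeStore

dist.init_process_group(backend="fake", rank=0, world_size=2, store=FakeStore())

def func(arg: torch.Tensor) -> torch.Tensor:
    buf0 = arg + 42
    ar0 = funcol.all_reduce(buf0, "avg", "0")  # "0" names the default group here
    return funcol.wait_tensor(ar0)

inductor_config.cpp_wrapper = True  # generate C++ wrapper code instead of Python
compiled = torch.compile(func)
_, (code,) = run_and_get_code(compiled, torch.rand(4, 4, device="cpu"))

# The generated C++ should call the new CPU shims rather than the dispatcher ops.
assert "aoti_torch_cpu__c10d_functional_all_reduce_" in code
assert "aoti_torch_cpu__c10d_functional_wait_tensor" in code

dist.destroy_process_group()

With cpp_wrapper left at False, the same check would instead find torch.ops._c10d_functional.all_reduce_.default and torch.ops._c10d_functional.wait_tensor.default in the generated code, as the test below asserts.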

Testing:
1. Add a c10d functional unit test for CPU (CompileTestCPU below); a sketch of the AOTI path it exercises follows this list.
2. Include the above test in the cpp wrapper unit tests.
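
The AOTI coverage in the new test goes through an internal helper (AOTIRunnerUtil.run). Roughly the same path can be sketched with the public export and AOTI packaging APIs; the module wrapper, input shape, and the aoti_compile_and_package / aoti_load_package flow below are illustrative assumptions, not code from this PR:

# Export a CPU collective and run it through an AOTI package, which always
# goes through the C++ wrapper and hence the new shims.
import torch
import torch._inductor
import torch.distributed as dist
import torch.distributed._functional_collectives as funcol
from torch.testing._internal.distributed.fake_pg import FakeStore

if not dist.is_initialized():
    dist.init_process_group(backend="fake", rank=0, world_size=2, store=FakeStore())

class AllReduce(torch.nn.Module):
    def forward(self, x):
        y = funcol.all_reduce(x + 42, "avg", "0")
        return funcol.wait_tensor(y)

example = (torch.rand(4, 4, device="cpu"),)
ep = torch.export.export(AllReduce(), example)
package_path = torch._inductor.aoti_compile_and_package(ep)  # builds a .pt2 package
runner = torch._inductor.aoti_load_package(package_path)
out = runner(*example)  # executes the AOTI-compiled graph

dist.destroy_process_group()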

Pull Request resolved: https://github.com/pytorch/pytorch/pull/154492
Approved by: https://github.com/desertfire
Authored by Valentine233 on 2025-06-25 01:20:02 +00:00; committed by PyTorch MergeBot
parent d797038ea9
commit 02c7ab2f9b
9 changed files with 340 additions and 52 deletions

View File

@@ -9,7 +9,7 @@ import torch
import torch.distributed as dist
import torch.distributed._functional_collectives as funcol
from torch._C import FileCheck
from torch._inductor.utils import fresh_cache, run_and_get_triton_code
from torch._inductor.utils import fresh_cache, run_and_get_code, run_and_get_triton_code
from torch.distributed._functional_collectives import (
all_gather_into_tensor_coalesced,
all_gather_tensor,
@@ -713,6 +713,61 @@ class PyWorkTest(TestCase):
self.assertEqual(pg.dels, 4)
class CompileTestCPU(TestCase):
def setUp(self):
super().setUp()
if not dist.is_initialized():
self.rank = 0
self.world_size = 2
store = FakeStore()
dist.init_process_group(
backend="fake",
world_size=self.world_size,
rank=self.rank,
store=store,
)
def tearDown(self):
dist.destroy_process_group()
@fresh_cache()
def _test_inductor_all_reduce_cpu(self, cpp_wrapper=False):
def func(arg: torch.Tensor) -> torch.Tensor:
buf0 = arg + 42
ar0 = funcol.all_reduce(buf0, "avg", "0")
ar0 = funcol.wait_tensor(ar0)
return ar0
arg = torch.rand(4, 4, device="cpu")
torch._inductor.config.cpp_wrapper = cpp_wrapper
compiled = torch.compile(func)
_, (code,) = run_and_get_code(compiled, arg)
include_ops = (
[
"aoti_torch_cpu__c10d_functional_all_reduce_",
"aoti_torch_cpu__c10d_functional_wait_tensor",
]
if cpp_wrapper
else [
"torch.ops._c10d_functional.all_reduce_.default",
"torch.ops._c10d_functional.wait_tensor.default",
]
)
for op in include_ops:
self.assertIn(op, code)
# Test aoti
AOTIRunnerUtil.run(func, (arg,))
torch.cpu.synchronize()
def test_inductor_all_reduce_cpu(self):
self._test_inductor_all_reduce_cpu(cpp_wrapper=False)
self._test_inductor_all_reduce_cpu(cpp_wrapper=True)
class CompileTest(TestCase):
def setUp(self):
super().setUp()

View File

@@ -67,7 +67,7 @@ def get_collective_input_size_bytes(node: ir.IRNode) -> int:
def get_collective_group_size(node: ir.IRNode) -> int:
if type(node) == ir._CollectiveKernel:
if isinstance(node, ir._CollectiveKernel) and not isinstance(node, ir._WaitKernel):
from torch.distributed.distributed_c10d import _get_group_size_by_name
return _get_group_size_by_name(node.constant_args[-1])

View File

@@ -208,7 +208,7 @@ def register_comm_lowerings():
inp.realize()
V.graph.no_fuse_buffer_names.add(inp.get_name())
inp = ir.ExternKernel.require_contiguous(inp)
ir._CollectiveKernel.create_inplace(
ir._AllReduceKernel.create_inplace(
c10d.all_reduce_.default, inp, reduce_op, group_name
)
return inp
@@ -227,7 +227,7 @@ def register_comm_lowerings():
# Lower as c10d.all_reduce_
inp = ir.ExternKernel.require_contiguous(inp)
ir._CollectiveKernel.create_inplace(
ir._AllReduce_Kernel.create_inplace(
c10d.all_reduce_.default, inp, reduce_op, group_name
)
return inp

View File

@@ -8235,7 +8235,10 @@ class _CollectiveKernel(FallbackKernel):
"Setting cpp kernel needs a valid op_overload"
)
kernel = self.op_overload
self.cpp_kernel_name = kernel._schema.name
if cpp_kernel_name is not None:
self.cpp_kernel_name = cpp_kernel_name
else:
self.cpp_kernel_name = kernel._schema.name
self.ordered_kwargs_for_cpp_kernel = [
x.name for x in kernel._schema.arguments if x.kwarg_only
@@ -8363,7 +8366,98 @@ class _CollectiveKernel(FallbackKernel):
return packed
class _AllReduce_Kernel(_CollectiveKernel):
def __init__( # type: ignore[no-untyped-def]
self,
layout,
kernel,
tensor_args,
nontensor_args,
unflatten_args,
kwargs=None,
*,
unbacked_bindings=None,
) -> None:
super().__init__(
layout,
kernel,
tensor_args,
nontensor_args,
unflatten_args,
kwargs=None,
unbacked_bindings=unbacked_bindings,
)
self.set_cpp_kernel_name("aoti_torch_cpu__c10d_functional_all_reduce_")
def codegen(self, wrapper) -> None: # type: ignore[no-untyped-def]
wrapper.include_extra_header("torch/csrc/inductor/aoti_torch/c/shim_cpu.h")
wrapper.generate_extern_kernel_alloc(self)
if isinstance(self.layout, Layout):
self.codegen_size_asserts(wrapper)
class _AllReduceKernel(_CollectiveKernel):
def __init__( # type: ignore[no-untyped-def]
self,
layout,
kernel,
tensor_args,
nontensor_args,
unflatten_args,
kwargs=None,
*,
unbacked_bindings=None,
) -> None:
super().__init__(
layout,
kernel,
tensor_args,
nontensor_args,
unflatten_args,
kwargs=None,
unbacked_bindings=unbacked_bindings,
)
self.set_cpp_kernel_name("aoti_torch_cpu__c10d_functional_all_reduce")
def codegen(self, wrapper) -> None: # type: ignore[no-untyped-def]
wrapper.include_extra_header("torch/csrc/inductor/aoti_torch/c/shim_cpu.h")
wrapper.generate_extern_kernel_alloc(self)
if isinstance(self.layout, Layout):
self.codegen_size_asserts(wrapper)
class _WaitKernel(_CollectiveKernel):
def __init__( # type: ignore[no-untyped-def]
self,
layout,
kernel,
tensor_args,
nontensor_args,
unflatten_args,
kwargs=None,
*,
unbacked_bindings=None,
) -> None:
super().__init__(
layout,
kernel,
tensor_args,
nontensor_args,
unflatten_args,
kwargs=None,
unbacked_bindings=unbacked_bindings,
)
self.set_cpp_kernel_name("aoti_torch_cpu__c10d_functional_wait_tensor")
def codegen(self, wrapper) -> None: # type: ignore[no-untyped-def]
wrapper.include_extra_header("torch/csrc/inductor/aoti_torch/c/shim_cpu.h")
wrapper.generate_extern_kernel_alloc(self)
if isinstance(self.layout, Layout):
self.codegen_size_asserts(wrapper)
def get_volatile_reads(self): # type: ignore[no-untyped-def]
inp = self.inputs[0]
if isinstance(inp, _CollectiveKernel):

View File

@@ -2323,7 +2323,9 @@ def is_collective(
from . import ir
return (
type(node) == ir._CollectiveKernel and (op is None or node.op_overload is op)
isinstance(node, ir._CollectiveKernel)
and not isinstance(node, ir._WaitKernel)
and (op is None or node.op_overload is op)
) or (
# TODO: this is a temporary solution to ensure that we can identify torchrec's
# communication ops. But in order to allow better communication and computation

View File

@@ -30,6 +30,37 @@ c10d::ReduceOp to_reduce_op(const std::string& reduce_op) {
return it->second;
}
at::Tensor allocate_all_gather_output(
const at::Tensor& input,
int64_t group_size) {
TORCH_CHECK(input.is_contiguous());
auto output_size = input.sizes().vec();
output_size[0] *= group_size;
return at::empty(
output_size,
at::TensorOptions().dtype(input.dtype()).device(input.device()));
}
at::Tensor allocate_reduce_scatter_output(
const at::Tensor& input,
const int64_t group_size) {
TORCH_CHECK(input.is_contiguous());
auto output_size = input.sizes().vec();
if (output_size[0] % group_size != 0) {
LOG(WARNING) << "The first dimension of the reduce_scatter input ("
<< output_size[0] << ") is not divisible by the group size ("
<< group_size << ").";
}
output_size[0] /= group_size;
return at::empty(
output_size,
at::TensorOptions().dtype(input.dtype()).device(input.device()));
}
} // namespace
namespace c10d {
at::Tensor& all_reduce_(
at::Tensor& input,
// NOLINTNEXTLINE(performance-unnecessary-value-param)
@@ -85,17 +116,6 @@ std::vector<at::Tensor> all_reduce_coalesced(
outputs, std::move(reduce_op), std::move(group_name));
}
at::Tensor allocate_all_gather_output(
const at::Tensor& input,
int64_t group_size) {
TORCH_CHECK(input.is_contiguous());
auto output_size = input.sizes().vec();
output_size[0] *= group_size;
return at::empty(
output_size,
at::TensorOptions().dtype(input.dtype()).device(input.device()));
}
std::vector<at::Tensor> all_gather_into_tensor_coalesced(
std::vector<at::Tensor> inputs,
int64_t group_size,
@@ -140,22 +160,6 @@ at::Tensor& all_gather_into_tensor_out(
return output;
}
at::Tensor allocate_reduce_scatter_output(
const at::Tensor& input,
const int64_t group_size) {
TORCH_CHECK(input.is_contiguous());
auto output_size = input.sizes().vec();
if (output_size[0] % group_size != 0) {
LOG(WARNING) << "The first dimension of the reduce_scatter input ("
<< output_size[0] << ") is not divisible by the group size ("
<< group_size << ").";
}
output_size[0] /= group_size;
return at::empty(
output_size,
at::TensorOptions().dtype(input.dtype()).device(input.device()));
}
std::vector<at::Tensor> reduce_scatter_tensor_coalesced(
std::vector<at::Tensor> inputs,
// NOLINTNEXTLINE(performance-unnecessary-value-param)
@@ -234,65 +238,68 @@ at::Tensor broadcast(
return broadcast_(output, src, std::move(group_name));
}
} // namespace
} // namespace c10d
TORCH_LIBRARY(_c10d_functional, m) {
m.def(
"all_reduce(Tensor input, str reduce_op, str group_name) -> Tensor",
torch::dispatch(
c10::DispatchKey::CompositeExplicitAutograd, ::all_reduce),
c10::DispatchKey::CompositeExplicitAutograd, c10d::all_reduce),
{at::Tag::pt2_compliant_tag});
m.def(
"all_reduce_(Tensor(a!) input, str reduce_op, str group_name) -> Tensor(a!)",
torch::dispatch(
c10::DispatchKey::CompositeExplicitAutograd, ::all_reduce_),
c10::DispatchKey::CompositeExplicitAutograd, c10d::all_reduce_),
{at::Tag::pt2_compliant_tag});
m.def(
"all_reduce_coalesced(Tensor[] inputs, str reduce_op, str group_name) -> Tensor[]",
torch::dispatch(
c10::DispatchKey::CompositeExplicitAutograd, ::all_reduce_coalesced),
c10::DispatchKey::CompositeExplicitAutograd,
c10d::all_reduce_coalesced),
{at::Tag::pt2_compliant_tag});
m.def(
"all_reduce_coalesced_(Tensor[](a!) inputs, str reduce_op, str group_name) -> Tensor[](a!)",
torch::dispatch(
c10::DispatchKey::CompositeExplicitAutograd, ::all_reduce_coalesced_),
c10::DispatchKey::CompositeExplicitAutograd,
c10d::all_reduce_coalesced_),
{at::Tag::pt2_compliant_tag});
m.def(
"all_gather_into_tensor_out(Tensor input, int group_size, str group_name, *, Tensor(a!) out) -> Tensor(a!)",
torch::dispatch(
c10::DispatchKey::CompositeExplicitAutograd,
::all_gather_into_tensor_out),
c10d::all_gather_into_tensor_out),
{at::Tag::pt2_compliant_tag, at::Tag::needs_contiguous_strides});
m.def(
"all_gather_into_tensor(Tensor input, int group_size, str group_name) -> Tensor",
torch::dispatch(
c10::DispatchKey::CompositeExplicitAutograd,
::all_gather_into_tensor),
c10d::all_gather_into_tensor),
{at::Tag::pt2_compliant_tag, at::Tag::needs_contiguous_strides});
m.def(
"all_gather_into_tensor_coalesced(Tensor[] inputs, int group_size, str group_name) -> Tensor[]",
torch::dispatch(
c10::DispatchKey::CompositeExplicitAutograd,
::all_gather_into_tensor_coalesced),
c10d::all_gather_into_tensor_coalesced),
{at::Tag::pt2_compliant_tag, at::Tag::needs_contiguous_strides});
m.def(
"reduce_scatter_tensor(Tensor input, str reduce_op, int group_size, str group_name) -> Tensor",
torch::dispatch(
c10::DispatchKey::CompositeExplicitAutograd, ::reduce_scatter_tensor),
c10::DispatchKey::CompositeExplicitAutograd,
c10d::reduce_scatter_tensor),
{at::Tag::pt2_compliant_tag, at::Tag::needs_contiguous_strides});
m.def(
"reduce_scatter_tensor_coalesced(Tensor[] inputs, str reduce_op, int group_size, str group_name) -> Tensor[]",
torch::dispatch(
c10::DispatchKey::CompositeExplicitAutograd,
::reduce_scatter_tensor_coalesced),
c10d::reduce_scatter_tensor_coalesced),
{at::Tag::pt2_compliant_tag, at::Tag::needs_contiguous_strides});
m.def(
@@ -302,18 +309,19 @@ TORCH_LIBRARY(_c10d_functional, m) {
"SymInt[] input_split_sizes, "
"str group_name) -> Tensor",
torch::dispatch(
c10::DispatchKey::CompositeExplicitAutograd, ::all_to_all_single),
c10::DispatchKey::CompositeExplicitAutograd, c10d::all_to_all_single),
{at::Tag::pt2_compliant_tag, at::Tag::needs_contiguous_strides});
m.def(
"broadcast(Tensor input, int src, str group_name) -> Tensor",
torch::dispatch(c10::DispatchKey::CompositeExplicitAutograd, ::broadcast),
torch::dispatch(
c10::DispatchKey::CompositeExplicitAutograd, c10d::broadcast),
{at::Tag::pt2_compliant_tag});
m.def(
"broadcast_(Tensor(a!) input, int src, str group_name) -> Tensor(a!)",
torch::dispatch(
c10::DispatchKey::CompositeExplicitAutograd, ::broadcast_),
c10::DispatchKey::CompositeExplicitAutograd, c10d::broadcast_),
{at::Tag::pt2_compliant_tag});
m.def(
@@ -342,7 +350,7 @@ class AllToAllSingle : public torch::autograd::Function<AllToAllSingle> {
return c10::Dispatcher::singleton()
.findSchemaOrThrow("_c10d_functional::all_to_all_single", "")
.typed<decltype(all_to_all_single)>()
.typed<decltype(c10d::all_to_all_single)>()
.call(input, output_split_sizes, input_split_sizes, group_name);
}
@@ -361,7 +369,7 @@ class AllToAllSingle : public torch::autograd::Function<AllToAllSingle> {
auto out =
c10::Dispatcher::singleton()
.findSchemaOrThrow("_c10d_functional::all_to_all_single", "")
.typed<decltype(all_to_all_single)>()
.typed<decltype(c10d::all_to_all_single)>()
.call(grad_out, output_split_sizes, input_split_sizes, group_name);
// do an explicit wait to avoid cuda stream issues
@@ -400,7 +408,7 @@ class ReduceScatterTensor
return c10::Dispatcher::singleton()
.findSchemaOrThrow("_c10d_functional::reduce_scatter_tensor", "")
.typed<decltype(reduce_scatter_tensor)>()
.typed<decltype(c10d::reduce_scatter_tensor)>()
.call(input, reduce_op, group_size, group_name);
}
@@ -416,7 +424,7 @@ class ReduceScatterTensor
auto out =
c10::Dispatcher::singleton()
.findSchemaOrThrow("_c10d_functional::all_gather_into_tensor", "")
.typed<decltype(all_gather_into_tensor)>()
.typed<decltype(c10d::all_gather_into_tensor)>()
.call(grad_out, group_size, group_name);
// do an explicit wait to avoid cuda stream issues
@@ -456,7 +464,7 @@ class AllGatherIntoTensor
return c10::Dispatcher::singleton()
.findSchemaOrThrow("_c10d_functional::all_gather_into_tensor", "")
.typed<decltype(all_gather_into_tensor)>()
.typed<decltype(c10d::all_gather_into_tensor)>()
.call(input, group_size, group_name);
}
@@ -472,7 +480,7 @@ class AllGatherIntoTensor
auto out =
c10::Dispatcher::singleton()
.findSchemaOrThrow("_c10d_functional::reduce_scatter_tensor", "")
.typed<decltype(reduce_scatter_tensor)>()
.typed<decltype(c10d::reduce_scatter_tensor)>()
.call(grad_out, "sum", group_size, group_name);
// do an explicit wait to avoid cuda stream issues

View File

@@ -1,3 +1,78 @@
#pragma once
#include <torch/csrc/distributed/c10d/ProcessGroup.hpp>
namespace c10d {
C10_EXPORT at::Tensor& all_reduce_(
at::Tensor& input,
std::string reduce_op,
std::string group_name);
C10_EXPORT at::Tensor all_reduce(
const at::Tensor& input,
std::string reduce_op,
std::string group_name);
C10_EXPORT std::vector<at::Tensor> all_reduce_coalesced_(
std::vector<at::Tensor> inputs,
// NOLINTNEXTLINE(performance-unnecessary-value-param)
std::string reduce_op,
// NOLINTNEXTLINE(performance-unnecessary-value-param)
std::string group_name);
C10_EXPORT std::vector<at::Tensor> all_reduce_coalesced(
// NOLINTNEXTLINE(performance-unnecessary-value-param)
std::vector<at::Tensor> inputs,
std::string reduce_op,
std::string group_name);
C10_EXPORT std::vector<at::Tensor> all_gather_into_tensor_coalesced(
std::vector<at::Tensor> inputs,
int64_t group_size,
// NOLINTNEXTLINE(performance-unnecessary-value-param)
std::string group_name);
C10_EXPORT at::Tensor all_gather_into_tensor(
const at::Tensor& input,
int64_t group_size,
std::string group_name);
C10_EXPORT at::Tensor& all_gather_into_tensor_out(
at::Tensor& input,
int64_t group_size,
const std::string& group_name,
at::Tensor& output);
C10_EXPORT std::vector<at::Tensor> reduce_scatter_tensor_coalesced(
std::vector<at::Tensor> inputs,
// NOLINTNEXTLINE(performance-unnecessary-value-param)
std::string reduce_op,
int64_t group_size,
// NOLINTNEXTLINE(performance-unnecessary-value-param)
std::string group_name);
C10_EXPORT at::Tensor reduce_scatter_tensor(
const at::Tensor& input,
std::string reduce_op,
int64_t group_size,
std::string group_name);
C10_EXPORT at::Tensor all_to_all_single(
const at::Tensor& input,
std::vector<int64_t> output_split_sizes,
std::vector<int64_t> input_split_sizes,
// NOLINTNEXTLINE(performance-unnecessary-value-param)
std::string group_name);
C10_EXPORT at::Tensor& broadcast_(
at::Tensor& input,
int64_t src,
std::string group_name);
C10_EXPORT at::Tensor broadcast(
const at::Tensor& input,
int64_t src,
std::string group_name);
} // namespace c10d

View File

@@ -245,6 +245,22 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__weight_int4pack_mm_cpu_tensor(
AtenTensorHandle qScaleAndZeros,
AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__c10d_functional_all_reduce_(
AtenTensorHandle inp,
const char* reduce_op,
const char* group_name,
AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__c10d_functional_all_reduce(
AtenTensorHandle inp,
const char* reduce_op,
const char* group_name,
AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__c10d_functional_wait_tensor(
AtenTensorHandle inp,
AtenTensorHandle* ret0);
#ifdef __cplusplus
} // extern "C"
#endif

View File

@@ -1,4 +1,7 @@
#ifdef USE_DISTRIBUTED
#include <torch/csrc/distributed/c10d/Functional.hpp>
#endif
#include <torch/csrc/inductor/aoti_torch/c/shim_cpu.h>
#include <torch/csrc/inductor/aoti_torch/utils.h>
@@ -539,3 +542,38 @@ AOTITorchError aoti_torch_cpu__weight_int4pack_mm_cpu_tensor(
*ret0 = new_tensor_handle(std::move(tmp_result));
});
}
#ifdef USE_DISTRIBUTED
AOTITorchError aoti_torch_cpu__c10d_functional_all_reduce_(
AtenTensorHandle inp,
const char* reduce_op,
const char* group_name,
AtenTensorHandle* ret0) {
AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
auto tmp_result = c10d::all_reduce_(
*tensor_handle_to_tensor_pointer(inp), reduce_op, group_name);
*ret0 = new_tensor_handle(std::move(tmp_result));
});
}
AOTITorchError aoti_torch_cpu__c10d_functional_all_reduce(
AtenTensorHandle inp,
const char* reduce_op,
const char* group_name,
AtenTensorHandle* ret0) {
AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
auto tmp_result = c10d::all_reduce(
*tensor_handle_to_tensor_pointer(inp), reduce_op, group_name);
*ret0 = new_tensor_handle(std::move(tmp_result));
});
}
AOTITorchError aoti_torch_cpu__c10d_functional_wait_tensor(
AtenTensorHandle inp,
AtenTensorHandle* ret0) {
AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
auto tmp_result = c10d::wait_tensor(*tensor_handle_to_tensor_pointer(inp));
*ret0 = new_tensor_handle(std::move(tmp_result));
});
}
#endif