Native c10d_functional ops (#110570)

This PR introduces a native implementation of the c10d_functional ops. The main goal is to add collective support to AOTInductor and to allow collective ops to work in multi-threaded native runtimes.

The native version also incorporates API improvements we had wanted to make in Python c10d_functional:

- Removed `ranks` and `group_size` from the collective op signatures, since they proved to be redundant.
- Resolve in-flight work from the output tensor's storage rather than from a `void*` (see the sketch below).
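
For illustration, here is a hedged before/after sketch of the call-site change. The legacy signature is inferred from the Python c10d_functional meta registrations further down in this diff; the `"default"` group name and the reduce-op string are illustrative values borrowed from the new test, not canonical usage.

```python
# Sketch only: assumes a NCCL process group is already initialized and, for the
# native ops, registered under the name "default" (as the new test does).
import torch

t = torch.ones(8, 8, device="cuda")

# Python c10d_functional: the tag, ranks, and group_size travel together even
# though ranks/group_size are redundant with the tag.
out = torch.ops.c10d_functional.all_reduce(t, "avg", "", [0, 1], 2)
out = torch.ops.c10d_functional.wait_tensor(out)

# Native _c10d_functional: the group is identified by name alone, and the
# pending work is resolved from the output tensor's storage in wait_tensor.
out = torch.ops._c10d_functional.all_reduce(t, "avg", "default")
out = torch.ops._c10d_functional.wait_tensor(out)
```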

The native process group registration/resolution mechanism is used only by native c10d_functional in this PR. It will become the single source of truth in upcoming PRs. A sketch of the registration flow follows.
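
A minimal sketch of the per-rank registration flow, modeled on the new test added in this PR (the helper name, file-store setup, and `"default"` group name are illustrative, not a prescribed API):

```python
# Sketch of per-rank setup, following test_c10d_functional_native.py below.
import torch
import torch.distributed as dist


def init_and_register(rank: int, world_size: int, file_name: str) -> None:
    store = dist.FileStore(file_name, world_size)
    dist.init_process_group(
        backend="nccl", world_size=world_size, rank=rank, store=store
    )
    # Make the default group resolvable by name from the native ops
    # (C++ side: c10d::register_process_group / c10d::resolve_process_group).
    torch._C._distributed_c10d._register_process_group("default", dist.group.WORLD)
    # The reverse lookup is exposed to Python as well.
    torch._C._distributed_c10d._resolve_process_group("default")
```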

The upcoming PRs will implement Inductor/AOTInductor support for c10d_functional, after which native c10d_functional will replace Python c10d_functional.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/110570
Approved by: https://github.com/wanchaol
Author: Yifu Wang
Date: 2023-10-25 10:49:20 -07:00
Committed by: PyTorch MergeBot
Parent: 7fe51e3e9b
Commit: ec18ef62f4
11 changed files with 668 additions and 2 deletions


@@ -520,7 +520,9 @@ libtorch_core_sources = sorted(
libtorch_distributed_base_sources = [
    "torch/csrc/distributed/c10d/Backend.cpp",
    "torch/csrc/distributed/c10d/FileStore.cpp",
    "torch/csrc/distributed/c10d/Functional.cpp",
    "torch/csrc/distributed/c10d/GlooDeviceFactory.cpp",
    "torch/csrc/distributed/c10d/GroupRegistry.cpp",
    "torch/csrc/distributed/c10d/Ops.cpp",
    "torch/csrc/distributed/c10d/ParamCommsUtils.cpp",
    "torch/csrc/distributed/c10d/PrefixStore.cpp",


@@ -0,0 +1,187 @@
# Owner(s): ["module: c10d"]
import sys
from typing import List

import torch
import torch.distributed as dist
from torch.testing._internal.common_distributed import (
    MultiProcessTestCase,
    requires_nccl,
    skip_if_lt_x_gpu,
)
from torch.testing._internal.common_utils import run_tests

if not dist.is_available():
    print("distributed package not available, skipping tests", file=sys.stderr)
    sys.exit(0)


@requires_nccl()
class C10DFunctionalNativeTest(MultiProcessTestCase):
    def setUp(self) -> None:
        super().setUp()
        self._spawn_processes()

    @property
    def world_size(self) -> int:
        return 2

    @property
    def ranks(self) -> List[int]:
        return list(range(self.world_size))

    @property
    def device(self) -> torch.device:
        return torch.device(f"cuda:{self.rank}")

    def _init_process_group(self) -> None:
        store = dist.FileStore(self.file_name, self.world_size)
        dist.init_process_group(
            backend="nccl",
            world_size=self.world_size,
            rank=self.rank,
            store=store,
        )
        torch._C._distributed_c10d._register_process_group("default", dist.group.WORLD)

    @skip_if_lt_x_gpu(2)
    def test_all_reduce(self) -> None:
        self._init_process_group()
        input = torch.full((10, 10), float(self.rank), device=self.device)
        output = torch.ops._c10d_functional.all_reduce(
            input,
            "avg",
            "default",
        )
        output = torch.ops._c10d_functional.wait_tensor(output)
        assert id(output) != id(input)
        expect = sum(self.ranks) / self.world_size
        assert output.eq(expect).all()

    @skip_if_lt_x_gpu(2)
    def test_all_reduce_(self) -> None:
        self._init_process_group()
        input = torch.full((10, 10), float(self.rank), device=self.device)
        output = torch.ops._c10d_functional.all_reduce_(
            input,
            "avg",
            "default",
        )
        output = torch.ops._c10d_functional.wait_tensor(output)
        assert id(output) == id(input)
        expect = sum(self.ranks) / self.world_size
        assert output.eq(expect).all()

    @skip_if_lt_x_gpu(2)
    def test_all_reduce_coalesced(self) -> None:
        self._init_process_group()
        inputs = [
            torch.full((i, i), float(self.rank * i), device=self.device)
            for i in range(10)
        ]
        outputs = torch.ops._c10d_functional.all_reduce_coalesced(
            inputs,
            "avg",
            "default",
        )
        for i, (output, input) in enumerate(zip(outputs, inputs)):
            output = torch.ops._c10d_functional.wait_tensor(output)
            assert id(output) != id(input)
            assert output.eq(sum(self.ranks) / self.world_size * i).all()

    @skip_if_lt_x_gpu(2)
    def test_all_reduce_coalesced_(self) -> None:
        self._init_process_group()
        inputs = [
            torch.full((i, i), float(self.rank * i), device=self.device)
            for i in range(10)
        ]
        outputs = torch.ops._c10d_functional.all_reduce_coalesced_(
            inputs,
            "avg",
            "default",
        )
        for i, (output, input) in enumerate(zip(outputs, inputs)):
            output = torch.ops._c10d_functional.wait_tensor(output)
            assert id(output) == id(input)
            assert output.eq(sum(self.ranks) / self.world_size * i).all()

    @skip_if_lt_x_gpu(2)
    def test_all_gather_into_tensor(self) -> None:
        self._init_process_group()
        input = torch.full((10, 10), float(self.rank), device=self.device)
        output = torch.ops._c10d_functional.all_gather_into_tensor(
            input,
            self.world_size,
            "default",
        )
        output = torch.ops._c10d_functional.wait_tensor(output)
        expect = torch.cat(
            [
                torch.full((10, 10), float(rank), device=self.device)
                for rank in self.ranks
            ]
        )
        assert torch.allclose(output, expect)
        assert output.eq(expect).all()

    @skip_if_lt_x_gpu(2)
    def test_all_gather_into_tensor_coalesced(self) -> None:
        self._init_process_group()
        inputs = [
            torch.full((10, 10), float(self.rank * i), device=self.device)
            for i in range(10)
        ]
        outputs = torch.ops._c10d_functional.all_gather_into_tensor_coalesced(
            inputs,
            self.world_size,
            "default",
        )
        for i, output in enumerate(outputs):
            output = torch.ops._c10d_functional.wait_tensor(output)
            expect = torch.cat(
                [
                    torch.full((10, 10), float(rank) * i, device=self.device)
                    for rank in self.ranks
                ]
            )
            assert output.eq(expect).all()

    @skip_if_lt_x_gpu(2)
    def test_reduce_scatter_tensor(self) -> None:
        self._init_process_group()
        input = torch.tensor(self.ranks, device=self.device)
        output = torch.ops._c10d_functional.reduce_scatter_tensor(
            input,
            "avg",
            self.world_size,
            "default",
        )
        output = torch.ops._c10d_functional.wait_tensor(output)
        assert output.eq(self.rank).all()

    @skip_if_lt_x_gpu(2)
    def test_reduce_scatter_tensor_coalesced(self) -> None:
        self._init_process_group()
        inputs = [torch.tensor(self.ranks, device=self.device) * i for i in range(10)]
        outputs = torch.ops._c10d_functional.reduce_scatter_tensor_coalesced(
            inputs,
            "avg",
            self.world_size,
            "default",
        )
        for i, output in enumerate(outputs):
            output = torch.ops._c10d_functional.wait_tensor(output)
            assert output.eq(self.rank * i).all()


if __name__ == "__main__":
    run_tests()


@@ -0,0 +1,258 @@
#include <torch/csrc/distributed/c10d/Functional.hpp>

#include <shared_mutex>

#include <ATen/ATen.h>
#include <ATen/core/op_registration/op_registration.h>
#include <c10/core/DispatchKey.h>
#include <torch/csrc/autograd/function.h>
#include <torch/csrc/distributed/c10d/GroupRegistry.hpp>
#include <torch/csrc/distributed/c10d/ProcessGroup.hpp>
#include <torch/csrc/distributed/c10d/RankLocal.hpp>

namespace {

class WorkRegistry {
 public:
  void register_work(
      const at::Tensor& tensor,
      c10::intrusive_ptr<c10d::Work> work) {
    const auto storage = tensor.storage().getWeakStorageImpl();
    std::unique_lock lock(lock_);
    auto [it, inserted] = registry_.emplace(storage, work);
    TORCH_CHECK(
        inserted || it->second != work,
        "The tensor storage is already associated with another work.");
  }

  c10::intrusive_ptr<c10d::Work> pop_work(const at::Tensor& tensor) {
    const auto storage = tensor.storage().getWeakStorageImpl();
    std::unique_lock lock(lock_);
    auto it = registry_.find(storage);
    TORCH_CHECK(
        it != registry_.end(),
        "No pending collective is associated with the tensor storage. "
        "This typically means that the tensor is not a collective output, "
        "or the tensor has already been waited on.");
    auto work = it->second;
    registry_.erase(it);
    return work;
  }

 private:
  std::unordered_map<
      c10::weak_intrusive_ptr<c10::StorageImpl>,
      c10::intrusive_ptr<c10d::Work>>
      registry_;
  std::mutex lock_;
};

const std::unordered_map<std::string, c10d::ReduceOp> str_to_reduce_op = {
    {"sum", c10d::ReduceOp(c10d::ReduceOp::RedOpType::SUM)},
    {"avg", c10d::ReduceOp(c10d::ReduceOp::RedOpType::AVG)},
    {"product", c10d::ReduceOp(c10d::ReduceOp::RedOpType::PRODUCT)},
    {"min", c10d::ReduceOp(c10d::ReduceOp::RedOpType::MIN)},
    {"max", c10d::ReduceOp(c10d::ReduceOp::RedOpType::MAX)},
    {"band", c10d::ReduceOp(c10d::ReduceOp::RedOpType::BAND)},
    {"bor", c10d::ReduceOp(c10d::ReduceOp::RedOpType::BOR)},
    {"bxor", c10d::ReduceOp(c10d::ReduceOp::RedOpType::BXOR)},
    // TODO: support premul_sum
    // {"premul_sum", c10d::ReduceOp(c10d::ReduceOp::RedOpType::PREMUL_SUM)},
    {"unused", c10d::ReduceOp(c10d::ReduceOp::RedOpType::UNUSED)}};

c10d::ReduceOp to_reduce_op(const std::string& reduce_op) {
  auto it = str_to_reduce_op.find(reduce_op);
  TORCH_CHECK(
      it != str_to_reduce_op.end(), "Unrecognized reduce_op: ", reduce_op);
  return it->second;
}

at::Tensor all_reduce_(
    at::Tensor input,
    const std::string& reduce_op,
    const std::string& group_name) {
  c10d::AllreduceOptions opts;
  opts.reduceOp = to_reduce_op(reduce_op);
  std::vector<at::Tensor> inputs{input};
  auto group = c10d::resolve_process_group(group_name);
  auto work = group->allreduce(inputs, opts);
  c10d::RankLocal<WorkRegistry>::get().register_work(input, work);
  return input;
}

at::Tensor all_reduce(
    const at::Tensor& input,
    const std::string& reduce_op,
    const std::string& group_name) {
  auto output = input.clone();
  return all_reduce_(output, reduce_op, group_name);
}

std::vector<at::Tensor> all_reduce_coalesced_(
    std::vector<at::Tensor> inputs,
    const std::string& reduce_op,
    const std::string& group_name) {
  c10d::AllreduceCoalescedOptions opts;
  opts.reduceOp = to_reduce_op(reduce_op);
  auto group = c10d::resolve_process_group(group_name);
  auto work = group->allreduce_coalesced(inputs, opts);
  for (const auto& tensor : inputs) {
    c10d::RankLocal<WorkRegistry>::get().register_work(tensor, work);
  }
  return inputs;
}

std::vector<at::Tensor> all_reduce_coalesced(
    const std::vector<at::Tensor>& inputs,
    const std::string& reduce_op,
    const std::string& group_name) {
  std::vector<at::Tensor> outputs;
  for (const auto& tensor : inputs) {
    outputs.push_back(tensor.clone());
  }
  return all_reduce_coalesced_(outputs, reduce_op, group_name);
}

at::Tensor allocate_all_gather_output(
    const at::Tensor& input,
    int64_t group_size) {
  auto output_size = input.sizes().vec();
  output_size[0] *= group_size;
  return at::empty(
      output_size,
      at::TensorOptions().dtype(input.dtype()).device(input.device()));
}

std::vector<at::Tensor> all_gather_into_tensor_coalesced(
    const std::vector<at::Tensor>& inputs,
    const int64_t group_size,
    const std::string& group_name) {
  std::vector<at::Tensor> outputs;
  for (const auto& tensor : inputs) {
    outputs.push_back(allocate_all_gather_output(tensor, group_size));
  }
  auto group = c10d::resolve_process_group(group_name);
  auto work = group->allgather_into_tensor_coalesced(
      outputs, const_cast<std::vector<at::Tensor>&>(inputs));
  for (const auto& tensor : outputs) {
    c10d::RankLocal<WorkRegistry>::get().register_work(tensor, work);
  }
  return outputs;
}

at::Tensor all_gather_into_tensor(
    const at::Tensor& input,
    const int64_t group_size,
    const std::string& group_name) {
  std::vector<at::Tensor> inputs{input};
  return all_gather_into_tensor_coalesced(inputs, group_size, group_name)[0];
}

at::Tensor allocate_reduce_scatter_output(
    const at::Tensor& input,
    const int64_t group_size) {
  auto output_size = input.sizes().vec();
  if (output_size[0] % group_size != 0) {
    LOG(WARNING) << "The first dimension of the reduce_scatter input ("
                 << output_size[0] << ") is not divisible by the group size ("
                 << group_size << ").";
  }
  output_size[0] /= group_size;
  return at::empty(
      output_size,
      at::TensorOptions().dtype(input.dtype()).device(input.device()));
}

std::vector<at::Tensor> reduce_scatter_tensor_coalesced(
    const std::vector<at::Tensor>& inputs,
    const std::string& reduce_op,
    const int64_t group_size,
    const std::string& group_name) {
  c10d::ReduceScatterOptions opts;
  opts.reduceOp = to_reduce_op(reduce_op);
  std::vector<at::Tensor> outputs;
  for (const auto& tensor : inputs) {
    outputs.push_back(allocate_reduce_scatter_output(tensor, group_size));
  }
  auto group = c10d::resolve_process_group(group_name);
  auto work = group->reduce_scatter_tensor_coalesced(
      outputs, const_cast<std::vector<at::Tensor>&>(inputs), opts);
  for (const auto& tensor : outputs) {
    c10d::RankLocal<WorkRegistry>::get().register_work(tensor, work);
  }
  return outputs;
}

at::Tensor reduce_scatter_tensor(
    at::Tensor input,
    const std::string& reduce_op,
    const int64_t group_size,
    const std::string& group_name) {
  std::vector<at::Tensor> inputs{input};
  return reduce_scatter_tensor_coalesced(
      inputs, reduce_op, group_size, group_name)[0];
}

at::Tensor wait_tensor(const at::Tensor& tensor) {
  auto work = c10d::RankLocal<WorkRegistry>::get().pop_work(tensor);
  work->wait();
  return tensor;
}

} // namespace

TORCH_LIBRARY(_c10d_functional, m) {
  m.def(
      "all_reduce(Tensor input, str reduce_op, str group_name) -> Tensor",
      torch::dispatch(
          c10::DispatchKey::CompositeExplicitAutograd, ::all_reduce));
  m.def(
      "all_reduce_(Tensor(a!) input, str reduce_op, str group_name) -> Tensor(a!)",
      torch::dispatch(
          c10::DispatchKey::CompositeExplicitAutograd, ::all_reduce_));
  m.def(
      "all_reduce_coalesced(Tensor[] inputs, str reduce_op, str group_name) -> Tensor[]",
      torch::dispatch(
          c10::DispatchKey::CompositeExplicitAutograd, ::all_reduce_coalesced));
  m.def(
      "all_reduce_coalesced_(Tensor[](a!) inputs, str reduce_op, str group_name) -> Tensor[](a!)",
      torch::dispatch(
          c10::DispatchKey::CompositeExplicitAutograd,
          ::all_reduce_coalesced_));
  m.def(
      "all_gather_into_tensor(Tensor input, int group_size, str group_name) -> Tensor",
      torch::dispatch(
          c10::DispatchKey::CompositeExplicitAutograd,
          ::all_gather_into_tensor));
  m.def(
      "all_gather_into_tensor_coalesced(Tensor[] inputs, int group_size, str group_name) -> Tensor[]",
      torch::dispatch(
          c10::DispatchKey::CompositeExplicitAutograd,
          ::all_gather_into_tensor_coalesced));
  m.def(
      "reduce_scatter_tensor(Tensor input, str reduce_op, int group_size, str group_name) -> Tensor",
      torch::dispatch(
          c10::DispatchKey::CompositeExplicitAutograd,
          ::reduce_scatter_tensor));
  m.def(
      "reduce_scatter_tensor_coalesced(Tensor[] inputs, str reduce_op, int group_size, str group_name) -> Tensor[]",
      torch::dispatch(
          c10::DispatchKey::CompositeExplicitAutograd,
          ::reduce_scatter_tensor_coalesced));
  m.def(
      "wait_tensor(Tensor tensor) -> Tensor",
      torch::dispatch(
          c10::DispatchKey::CompositeExplicitAutograd, ::wait_tensor));
}


@@ -0,0 +1,12 @@
#include <torch/csrc/distributed/c10d/ProcessGroup.hpp>

namespace c10d_functional {

void register_process_group(
    const std::string& tag,
    c10::intrusive_ptr<c10d::ProcessGroup> pg);

c10::intrusive_ptr<c10d::ProcessGroup> resolve_process_group(
    const std::string& tag);

} // namespace c10d_functional


@@ -0,0 +1,61 @@
#include <torch/csrc/distributed/c10d/GroupRegistry.hpp>

#include <torch/csrc/distributed/c10d/RankLocal.hpp>

namespace {

// Each rank operates on a different `c10d::ProcessGroup` instance for the same
// logical process group. Use `RankLocal<GroupRegistry>::get()` to ensure each
// rank gets a unique registry.
class GroupRegistry {
 public:
  void register_group(
      const std::string& group_name,
      c10::intrusive_ptr<c10d::ProcessGroup> group) {
    std::unique_lock write_lock(lock_);
    auto [_, inserted] = registry_.emplace(group_name, group);
    TORCH_CHECK(
        inserted,
        "A process group is already registered under the name ",
        group_name);
  }

  c10::intrusive_ptr<c10d::ProcessGroup> resolve_group(
      const std::string& group_name) {
    std::shared_lock read_lock(lock_);
    auto it = registry_.find(group_name);
    TORCH_CHECK(
        it != registry_.end(),
        "Could not resolve the process group registered under the name ",
        group_name);
    auto group = it->second.lock();
    TORCH_CHECK(
        group != nullptr,
        "Process group registered under the name ",
        group_name,
        " has already been destroyed.");
    return group;
  }

 private:
  std::map<std::string, c10::weak_intrusive_ptr<c10d::ProcessGroup>> registry_;
  std::shared_mutex lock_;
};

} // namespace

namespace c10d {

void register_process_group(
    const std::string& group_name,
    c10::intrusive_ptr<c10d::ProcessGroup> group) {
  RankLocal<::GroupRegistry>::get().register_group(group_name, group);
}

c10::intrusive_ptr<c10d::ProcessGroup> resolve_process_group(
    const std::string& group_name) {
  return RankLocal<::GroupRegistry>::get().resolve_group(group_name);
}

} // namespace c10d


@@ -0,0 +1,14 @@
#pragma once

#include <torch/csrc/distributed/c10d/ProcessGroup.hpp>

namespace c10d {

C10_EXPORT void register_process_group(
    const std::string& group_name,
    c10::intrusive_ptr<c10d::ProcessGroup> group);

C10_EXPORT c10::intrusive_ptr<c10d::ProcessGroup> resolve_process_group(
    const std::string& group_name);

} // namespace c10d


@@ -171,4 +171,10 @@ void ProcessGroup::enableCollectivesTiming() {
  }
}

void ProcessGroup::release_resources() {
  store_.reset();
  deviceTypeToBackend_.clear();
  backendTypeToBackend_.clear();
}

} // namespace c10d


@@ -688,12 +688,14 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
  void setGroupName(const std::string& name);
  void enableCollectivesTiming();
  void release_resources() override;

 protected:
  // Implementations of this interface need to call this to setup
  // appropriate logging etc.
  void init();

  const c10::intrusive_ptr<c10d::Store> store_;
  c10::intrusive_ptr<c10d::Store> store_;
  const int rank_;
  const int size_;
  const c10::intrusive_ptr<Options> options_;


@@ -0,0 +1,73 @@
#pragma once

#include <shared_mutex>

#include <torch/csrc/autograd/function.h>

namespace c10d {

// `RankLocal` maintains a unique instance of T for each non-autograd thread.
// For non-autograd threads, `RankLocal<T>::get()` functions similarly to
// thread_local. For autograd threads, `RankLocal<T>::get()` returns the
// instance of T corresponding to the enqueuing non-autograd thread. The
// mechanism allows for rank-specific context shared between forward and
// backward. It works for both the one-rank-per-process and one-rank-per-thread
// scenarios.
//
// NOTE: RankLocal doesn't make the underlying objects thread-safe.
template <typename T>
class RankLocal {
 public:
  RankLocal(const RankLocal&) = delete;
  RankLocal& operator=(const RankLocal&) = delete;

  static T& get() {
    // Fast path: non-autograd threads can simply return
    // the object reference cached in TLS.
    if (cached_ != nullptr) {
      return *cached_;
    }
    const auto node = torch::autograd::get_current_node();
    auto fwd_thread_id = node == nullptr ? at::RecordFunction::currentThreadId()
                                         : node->thread_id();
    // Optimistically acquire the read lock first, since most likely we are in
    // an autograd thread and the object has already been constructed.
    {
      std::shared_lock read_lock(lock_);
      auto it = thread_id_to_rank_local_.find(fwd_thread_id);
      if (it != thread_id_to_rank_local_.end()) {
        // Cache for non-autograd threads
        if (node == nullptr) {
          cached_ = &it->second;
        }
        return it->second;
      }
    }
    std::unique_lock write_lock(lock_);
    auto [it, _] = thread_id_to_rank_local_.try_emplace(fwd_thread_id);
    // Cache for non-autograd threads
    if (node == nullptr) {
      cached_ = &it->second;
    }
    return it->second;
  }

 private:
  RankLocal(){};
  thread_local static T* cached_;
  static std::unordered_map<uint64_t, T> thread_id_to_rank_local_;
  static std::shared_mutex lock_;
};

template <typename T>
thread_local T* RankLocal<T>::cached_ = nullptr;

template <typename T>
std::unordered_map<uint64_t, T> RankLocal<T>::thread_id_to_rank_local_;

template <typename T>
std::shared_mutex RankLocal<T>::lock_;

} // namespace c10d


@@ -3,6 +3,7 @@
#include <c10/util/intrusive_ptr.h>
#include <c10/util/string_view.h>
#include <torch/csrc/distributed/c10d/FileStore.hpp>
#include <torch/csrc/distributed/c10d/GroupRegistry.hpp>
#include <torch/csrc/distributed/c10d/TCPStore.hpp>
#include <torch/csrc/distributed/c10d/Utils.hpp>
#ifndef _WIN32
@@ -800,6 +801,26 @@ This class does not support ``__members__`` property.)");
          py::return_value_policy::copy, // seems safest
          py::call_guard<py::gil_scoped_release>());

  // TODO(yifu): _{register, resolve}_process_group currently only work for
  // c10d_functional. Later, we'll unify the name -> group mapping across
  // Python and C++, spanning both functional and non-functional
  // collectives.
  module.def(
      "_register_process_group",
      [](const std::string& group_name,
         c10::intrusive_ptr<::c10d::ProcessGroup> group) {
        ::c10d::register_process_group(group_name, group);
      },
      py::arg("group_name"),
      py::arg("group"));

  module.def(
      "_resolve_process_group",
      [](const std::string& group_name) {
        return ::c10d::resolve_process_group(group_name);
      },
      py::arg("group_name"));

  py::class_<::c10d::BroadcastOptions>(module, "BroadcastOptions")
      .def(py::init<>())
      .def_readwrite("rootRank", &::c10d::BroadcastOptions::rootRank)


@@ -539,7 +539,7 @@ def _reduce_scatter_tensor_meta(input, reduce_op, tag, rankset, group_size):
    out_size[0] //= group_size
    return input.new_empty(out_size)


def _all_reduce_coalesced_meta(self, reduceOp, tag, rankset, group_size):
def _all_reduce_coalesced_meta(self, *args):
    return [torch.empty_like(t) for t in self]


def _reduce_scatter_tensor_coalesced_meta(inputs, reduceOp, tag, rankset, group_size):
@@ -566,6 +566,27 @@ def _all_to_all_single_meta(input, output_split_sizes, input_split_sizes, tag, r
    out_size[0] = sum(output_split_sizes)
    return input.new_empty(out_size)


def _all_gather_into_tensor_native_meta(input, group_size, group_name):
    shape = list(input.size())
    shape[0] *= group_size
    return input.new_empty(shape)


def _all_gather_into_tensor_coalesced_native_meta(inputs, group_size, group_name):
    return [
        _all_gather_into_tensor_native_meta(input, group_size, group_name)
        for input in inputs
    ]


def _reduce_scatter_tensor_native_meta(input, group_size, group_name):
    shape = list(input.size())
    shape[0] //= group_size
    return input.new_empty(shape)


def _reduce_scatter_tensor_coalesced_native_meta(inputs, group_size, group_name):
    return [
        _reduce_scatter_tensor_native_meta(input, group_size, group_name)
        for input in inputs
    ]


def _register_ops():
    ops_defs = [
@@ -596,6 +617,15 @@ if not torch._running_with_deploy():
    c10_lib = torch.library.Library("c10d_functional", "DEF")
    c10_lib_impl = torch.library.Library("c10d_functional", "IMPL")
    _register_ops()

    _c10_lib_impl = torch.library.Library("_c10d_functional", "IMPL")
    _c10_lib_impl.impl("all_reduce", _all_reduce_meta, "Meta")
    _c10_lib_impl.impl("all_reduce_coalesced", _all_reduce_coalesced_meta, "Meta")
    _c10_lib_impl.impl("wait_tensor", _wait_tensor_meta, "Meta")
    _c10_lib_impl.impl("all_gather_into_tensor", _all_gather_into_tensor_native_meta, "Meta")
    _c10_lib_impl.impl("all_gather_into_tensor_coalesced", _all_gather_into_tensor_coalesced_native_meta, "Meta")
    _c10_lib_impl.impl("reduce_scatter_tensor", _reduce_scatter_tensor_native_meta, "Meta")
    _c10_lib_impl.impl("reduce_scatter_tensor_coalesced", _reduce_scatter_tensor_coalesced_native_meta, "Meta")
else:
    warnings.warn("PyTorch Distributed functional collectives do not work with torch::deploy.")