Native c10d_functional ops (#110570)
This PR introduces a native version of the c10d_functional ops. The main goal is to add collective support in AOTInductor and to allow collective ops to work in multi-threaded native runtimes. The native version also incorporates API improvements we wished to make in Python c10d_functional:
- Removed `ranks` and `group_size` from collective op signatures, which had proven redundant.
- Use tensor storage, as opposed to `void*`, to resolve in-flight work.

The native process group registration/resolution mechanism is only used for native c10d_functional in this PR. It will become the single source of truth in upcoming PRs, which will implement Inductor/AOTInductor support for c10d_functional; after that, native c10d_functional will replace Python c10d_functional.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/110570
Approved by: https://github.com/wanchaol
Committed by: PyTorch MergeBot
Parent: 7fe51e3e9b
Commit: ec18ef62f4
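For context, this is roughly what the new surface looks like from Python. It is a minimal sketch distilled from the test file and op registrations in this diff (the group name "default" and the tensor shape are arbitrary); it assumes `dist.init_process_group` has already been called on each rank with a CUDA-capable backend.

    import torch
    import torch.distributed as dist

    # Make the default group resolvable by name from the native ops.
    torch._C._distributed_c10d._register_process_group("default", dist.group.WORLD)

    t = torch.full((10, 10), float(dist.get_rank()), device="cuda")

    # Native ops address the group by name only -- no ranks/group_size arguments.
    out = torch.ops._c10d_functional.all_reduce(t, "avg", "default")
    out = torch.ops._c10d_functional.wait_tensor(out)  # resolves the pending work via out's storage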
@@ -520,7 +520,9 @@ libtorch_core_sources = sorted(
 libtorch_distributed_base_sources = [
     "torch/csrc/distributed/c10d/Backend.cpp",
     "torch/csrc/distributed/c10d/FileStore.cpp",
+    "torch/csrc/distributed/c10d/Functional.cpp",
     "torch/csrc/distributed/c10d/GlooDeviceFactory.cpp",
+    "torch/csrc/distributed/c10d/GroupRegistry.cpp",
     "torch/csrc/distributed/c10d/Ops.cpp",
     "torch/csrc/distributed/c10d/ParamCommsUtils.cpp",
     "torch/csrc/distributed/c10d/PrefixStore.cpp",
test/distributed/test_c10d_functional_native.py (new file, 187 lines)
@@ -0,0 +1,187 @@
# Owner(s): ["module: c10d"]

import sys
from typing import List

import torch
import torch.distributed as dist
from torch.testing._internal.common_distributed import (
    MultiProcessTestCase,
    requires_nccl,
    skip_if_lt_x_gpu,
)
from torch.testing._internal.common_utils import run_tests


if not dist.is_available():
    print("distributed package not available, skipping tests", file=sys.stderr)
    sys.exit(0)


@requires_nccl()
class C10DFunctionalNativeTest(MultiProcessTestCase):
    def setUp(self) -> None:
        super().setUp()
        self._spawn_processes()

    @property
    def world_size(self) -> int:
        return 2

    @property
    def ranks(self) -> List[int]:
        return list(range(self.world_size))

    @property
    def device(self) -> torch.device:
        return torch.device(f"cuda:{self.rank}")

    def _init_process_group(self) -> None:
        store = dist.FileStore(self.file_name, self.world_size)
        dist.init_process_group(
            backend="nccl",
            world_size=self.world_size,
            rank=self.rank,
            store=store,
        )
        torch._C._distributed_c10d._register_process_group("default", dist.group.WORLD)

    @skip_if_lt_x_gpu(2)
    def test_all_reduce(self) -> None:
        self._init_process_group()

        input = torch.full((10, 10), float(self.rank), device=self.device)
        output = torch.ops._c10d_functional.all_reduce(
            input,
            "avg",
            "default",
        )
        output = torch.ops._c10d_functional.wait_tensor(output)
        assert id(output) != id(input)
        expect = sum(self.ranks) / self.world_size
        assert output.eq(expect).all()

    @skip_if_lt_x_gpu(2)
    def test_all_reduce_(self) -> None:
        self._init_process_group()

        input = torch.full((10, 10), float(self.rank), device=self.device)
        output = torch.ops._c10d_functional.all_reduce_(
            input,
            "avg",
            "default",
        )
        output = torch.ops._c10d_functional.wait_tensor(output)
        assert id(output) == id(input)
        expect = sum(self.ranks) / self.world_size
        assert output.eq(expect).all()

    @skip_if_lt_x_gpu(2)
    def test_all_reduce_coalesced(self) -> None:
        self._init_process_group()

        inputs = [
            torch.full((i, i), float(self.rank * i), device=self.device)
            for i in range(10)
        ]
        outputs = torch.ops._c10d_functional.all_reduce_coalesced(
            inputs,
            "avg",
            "default",
        )
        for i, (output, input) in enumerate(zip(outputs, inputs)):
            output = torch.ops._c10d_functional.wait_tensor(output)
            assert id(output) != id(input)
            assert output.eq(sum(self.ranks) / self.world_size * i).all()

    @skip_if_lt_x_gpu(2)
    def test_all_reduce_coalesced_(self) -> None:
        self._init_process_group()

        inputs = [
            torch.full((i, i), float(self.rank * i), device=self.device)
            for i in range(10)
        ]
        outputs = torch.ops._c10d_functional.all_reduce_coalesced_(
            inputs,
            "avg",
            "default",
        )
        for i, (output, input) in enumerate(zip(outputs, inputs)):
            output = torch.ops._c10d_functional.wait_tensor(output)
            assert id(output) == id(input)
            assert output.eq(sum(self.ranks) / self.world_size * i).all()

    @skip_if_lt_x_gpu(2)
    def test_all_gather_into_tensor(self) -> None:
        self._init_process_group()

        input = torch.full((10, 10), float(self.rank), device=self.device)
        output = torch.ops._c10d_functional.all_gather_into_tensor(
            input,
            self.world_size,
            "default",
        )
        output = torch.ops._c10d_functional.wait_tensor(output)
        expect = torch.cat(
            [
                torch.full((10, 10), float(rank), device=self.device)
                for rank in self.ranks
            ]
        )
        assert torch.allclose(output, expect)
        assert output.eq(expect).all()

    @skip_if_lt_x_gpu(2)
    def test_all_gather_into_tensor_coalesced(self) -> None:
        self._init_process_group()

        inputs = [
            torch.full((10, 10), float(self.rank * i), device=self.device)
            for i in range(10)
        ]
        outputs = torch.ops._c10d_functional.all_gather_into_tensor_coalesced(
            inputs,
            self.world_size,
            "default",
        )
        for i, output in enumerate(outputs):
            output = torch.ops._c10d_functional.wait_tensor(output)
            expect = torch.cat(
                [
                    torch.full((10, 10), float(rank) * i, device=self.device)
                    for rank in self.ranks
                ]
            )
            assert output.eq(expect).all()

    @skip_if_lt_x_gpu(2)
    def test_reduce_scatter_tensor(self) -> None:
        self._init_process_group()

        input = torch.tensor(self.ranks, device=self.device)
        output = torch.ops._c10d_functional.reduce_scatter_tensor(
            input,
            "avg",
            self.world_size,
            "default",
        )
        output = torch.ops._c10d_functional.wait_tensor(output)
        assert output.eq(self.rank).all()

    @skip_if_lt_x_gpu(2)
    def test_reduce_scatter_tensor_coalesced(self) -> None:
        self._init_process_group()

        inputs = [torch.tensor(self.ranks, device=self.device) * i for i in range(10)]
        outputs = torch.ops._c10d_functional.reduce_scatter_tensor_coalesced(
            inputs,
            "avg",
            self.world_size,
            "default",
        )
        for i, output in enumerate(outputs):
            output = torch.ops._c10d_functional.wait_tensor(output)
            assert output.eq(self.rank * i).all()


if __name__ == "__main__":
    run_tests()
torch/csrc/distributed/c10d/Functional.cpp (new file, 258 lines)
@@ -0,0 +1,258 @@
#include <torch/csrc/distributed/c10d/Functional.hpp>

#include <shared_mutex>

#include <ATen/ATen.h>
#include <ATen/core/op_registration/op_registration.h>
#include <c10/core/DispatchKey.h>
#include <torch/csrc/autograd/function.h>
#include <torch/csrc/distributed/c10d/GroupRegistry.hpp>
#include <torch/csrc/distributed/c10d/ProcessGroup.hpp>
#include <torch/csrc/distributed/c10d/RankLocal.hpp>

namespace {

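// WorkRegistry tracks the in-flight c10d::Work produced by each collective,
// keyed by the output tensor's storage (rather than a raw void*, per the PR
// description above). wait_tensor() pops the entry and blocks on the work.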
class WorkRegistry {
 public:
  void register_work(
      const at::Tensor& tensor,
      c10::intrusive_ptr<c10d::Work> work) {
    const auto storage = tensor.storage().getWeakStorageImpl();
    std::unique_lock lock(lock_);
    auto [it, inserted] = registry_.emplace(storage, work);
    TORCH_CHECK(
        inserted || it->second != work,
        "The tensor storage is already associated with another work.");
  }

  c10::intrusive_ptr<c10d::Work> pop_work(const at::Tensor& tensor) {
    const auto storage = tensor.storage().getWeakStorageImpl();
    std::unique_lock lock(lock_);
    auto it = registry_.find(storage);
    TORCH_CHECK(
        it != registry_.end(),
        "No pending collective is associated with the tensor storage. "
        "This typically means that the tensor is not a collective output, "
        "or the tensor has already been waited on.");
    auto work = it->second;
    registry_.erase(it);
    return work;
  }

 private:
  std::unordered_map<
      c10::weak_intrusive_ptr<c10::StorageImpl>,
      c10::intrusive_ptr<c10d::Work>>
      registry_;
  std::mutex lock_;
};

const std::unordered_map<std::string, c10d::ReduceOp> str_to_reduce_op = {
    {"sum", c10d::ReduceOp(c10d::ReduceOp::RedOpType::SUM)},
    {"avg", c10d::ReduceOp(c10d::ReduceOp::RedOpType::AVG)},
    {"product", c10d::ReduceOp(c10d::ReduceOp::RedOpType::PRODUCT)},
    {"min", c10d::ReduceOp(c10d::ReduceOp::RedOpType::MIN)},
    {"max", c10d::ReduceOp(c10d::ReduceOp::RedOpType::MAX)},
    {"band", c10d::ReduceOp(c10d::ReduceOp::RedOpType::BAND)},
    {"bor", c10d::ReduceOp(c10d::ReduceOp::RedOpType::BOR)},
    {"bxor", c10d::ReduceOp(c10d::ReduceOp::RedOpType::BXOR)},
    // TODO: support premul_sum
    // {"premul_sum", c10d::ReduceOp(c10d::ReduceOp::RedOpType::PREMUL_SUM)},
    {"unused", c10d::ReduceOp(c10d::ReduceOp::RedOpType::UNUSED)}};

c10d::ReduceOp to_reduce_op(const std::string& reduce_op) {
  auto it = str_to_reduce_op.find(reduce_op);
  TORCH_CHECK(
      it != str_to_reduce_op.end(), "Unrecognized reduce_op: ", reduce_op);
  return it->second;
}

at::Tensor all_reduce_(
    at::Tensor input,
    const std::string& reduce_op,
    const std::string& group_name) {
  c10d::AllreduceOptions opts;
  opts.reduceOp = to_reduce_op(reduce_op);

  std::vector<at::Tensor> inputs{input};
  auto group = c10d::resolve_process_group(group_name);
  auto work = group->allreduce(inputs, opts);
  c10d::RankLocal<WorkRegistry>::get().register_work(input, work);
  return input;
}

at::Tensor all_reduce(
    const at::Tensor& input,
    const std::string& reduce_op,
    const std::string& group_name) {
  auto output = input.clone();
  return all_reduce_(output, reduce_op, group_name);
}

std::vector<at::Tensor> all_reduce_coalesced_(
    std::vector<at::Tensor> inputs,
    const std::string& reduce_op,
    const std::string& group_name) {
  c10d::AllreduceCoalescedOptions opts;
  opts.reduceOp = to_reduce_op(reduce_op);

  auto group = c10d::resolve_process_group(group_name);
  auto work = group->allreduce_coalesced(inputs, opts);
  for (const auto& tensor : inputs) {
    c10d::RankLocal<WorkRegistry>::get().register_work(tensor, work);
  }
  return inputs;
}

std::vector<at::Tensor> all_reduce_coalesced(
    const std::vector<at::Tensor>& inputs,
    const std::string& reduce_op,
    const std::string& group_name) {
  std::vector<at::Tensor> outputs;
  for (const auto& tensor : inputs) {
    outputs.push_back(tensor.clone());
  }
  return all_reduce_coalesced_(outputs, reduce_op, group_name);
}

at::Tensor allocate_all_gather_output(
    const at::Tensor& input,
    int64_t group_size) {
  auto output_size = input.sizes().vec();
  output_size[0] *= group_size;
  return at::empty(
      output_size,
      at::TensorOptions().dtype(input.dtype()).device(input.device()));
}

std::vector<at::Tensor> all_gather_into_tensor_coalesced(
    const std::vector<at::Tensor>& inputs,
    const int64_t group_size,
    const std::string& group_name) {
  std::vector<at::Tensor> outputs;
  for (const auto& tensor : inputs) {
    outputs.push_back(allocate_all_gather_output(tensor, group_size));
  }

  auto group = c10d::resolve_process_group(group_name);
  auto work = group->allgather_into_tensor_coalesced(
      outputs, const_cast<std::vector<at::Tensor>&>(inputs));
  for (const auto& tensor : outputs) {
    c10d::RankLocal<WorkRegistry>::get().register_work(tensor, work);
  }
  return outputs;
}

at::Tensor all_gather_into_tensor(
    const at::Tensor& input,
    const int64_t group_size,
    const std::string& group_name) {
  std::vector<at::Tensor> inputs{input};
  return all_gather_into_tensor_coalesced(inputs, group_size, group_name)[0];
}

at::Tensor allocate_reduce_scatter_output(
    const at::Tensor& input,
    const int64_t group_size) {
  auto output_size = input.sizes().vec();
  if (output_size[0] % group_size != 0) {
    LOG(WARNING) << "The first dimension of the reduce_scatter input ("
                 << output_size[0] << ") is not divisible by the group size ("
                 << group_size << ").";
  }
  output_size[0] /= group_size;
  return at::empty(
      output_size,
      at::TensorOptions().dtype(input.dtype()).device(input.device()));
}

std::vector<at::Tensor> reduce_scatter_tensor_coalesced(
    const std::vector<at::Tensor>& inputs,
    const std::string& reduce_op,
    const int64_t group_size,
    const std::string& group_name) {
  c10d::ReduceScatterOptions opts;
  opts.reduceOp = to_reduce_op(reduce_op);
  std::vector<at::Tensor> outputs;
  for (const auto& tensor : inputs) {
    outputs.push_back(allocate_reduce_scatter_output(tensor, group_size));
  }

  auto group = c10d::resolve_process_group(group_name);
  auto work = group->reduce_scatter_tensor_coalesced(
      outputs, const_cast<std::vector<at::Tensor>&>(inputs), opts);
  for (const auto& tensor : outputs) {
    c10d::RankLocal<WorkRegistry>::get().register_work(tensor, work);
  }
  return outputs;
}

at::Tensor reduce_scatter_tensor(
    at::Tensor input,
    const std::string& reduce_op,
    const int64_t group_size,
    const std::string& group_name) {
  std::vector<at::Tensor> inputs{input};
  return reduce_scatter_tensor_coalesced(
      inputs, reduce_op, group_size, group_name)[0];
}

at::Tensor wait_tensor(const at::Tensor& tensor) {
  auto work = c10d::RankLocal<WorkRegistry>::get().pop_work(tensor);
  work->wait();
  return tensor;
}

} // namespace

TORCH_LIBRARY(_c10d_functional, m) {
  m.def(
      "all_reduce(Tensor input, str reduce_op, str group_name) -> Tensor",
      torch::dispatch(
          c10::DispatchKey::CompositeExplicitAutograd, ::all_reduce));

  m.def(
      "all_reduce_(Tensor(a!) input, str reduce_op, str group_name) -> Tensor(a!)",
      torch::dispatch(
          c10::DispatchKey::CompositeExplicitAutograd, ::all_reduce_));

  m.def(
      "all_reduce_coalesced(Tensor[] inputs, str reduce_op, str group_name) -> Tensor[]",
      torch::dispatch(
          c10::DispatchKey::CompositeExplicitAutograd, ::all_reduce_coalesced));

  m.def(
      "all_reduce_coalesced_(Tensor[](a!) inputs, str reduce_op, str group_name) -> Tensor[](a!)",
      torch::dispatch(
          c10::DispatchKey::CompositeExplicitAutograd,
          ::all_reduce_coalesced_));

  m.def(
      "all_gather_into_tensor(Tensor input, int group_size, str group_name) -> Tensor",
      torch::dispatch(
          c10::DispatchKey::CompositeExplicitAutograd,
          ::all_gather_into_tensor));

  m.def(
      "all_gather_into_tensor_coalesced(Tensor[] inputs, int group_size, str group_name) -> Tensor[]",
      torch::dispatch(
          c10::DispatchKey::CompositeExplicitAutograd,
          ::all_gather_into_tensor_coalesced));

  m.def(
      "reduce_scatter_tensor(Tensor input, str reduce_op, int group_size, str group_name) -> Tensor",
      torch::dispatch(
          c10::DispatchKey::CompositeExplicitAutograd,
          ::reduce_scatter_tensor));

  m.def(
      "reduce_scatter_tensor_coalesced(Tensor[] inputs, str reduce_op, int group_size, str group_name) -> Tensor[]",
      torch::dispatch(
          c10::DispatchKey::CompositeExplicitAutograd,
          ::reduce_scatter_tensor_coalesced));

  m.def(
      "wait_tensor(Tensor tensor) -> Tensor",
      torch::dispatch(
          c10::DispatchKey::CompositeExplicitAutograd, ::wait_tensor));
}
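As registered above, each collective has an out-of-place form that clones its input (e.g. all_reduce) and an in-place form that mutates it (all_reduce_), with wait_tensor resolving the pending work. A minimal sketch of the aliasing difference, mirroring the assertions in the test file above (assumes an initialized process group registered under the name "default" and a CUDA device):

    import torch

    ops = torch.ops._c10d_functional
    x = torch.ones(10, 10, device="cuda")

    out = ops.all_reduce(x, "avg", "default")   # out-of-place: works on a clone of x
    out = ops.wait_tensor(out)
    assert out is not x

    y = ops.all_reduce_(x, "avg", "default")    # in-place: returns (and mutates) x itself
    y = ops.wait_tensor(y)
    assert y is x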
torch/csrc/distributed/c10d/Functional.hpp (new file, 12 lines)
@@ -0,0 +1,12 @@
#include <torch/csrc/distributed/c10d/ProcessGroup.hpp>

namespace c10d_functional {

void register_process_group(
    const std::string& tag,
    c10::intrusive_ptr<c10d::ProcessGroup> pg);

c10::intrusive_ptr<c10d::ProcessGroup> resolve_process_group(
    const std::string& tag);

} // namespace c10d_functional
torch/csrc/distributed/c10d/GroupRegistry.cpp (new file, 61 lines)
@@ -0,0 +1,61 @@
#include <torch/csrc/distributed/c10d/GroupRegistry.hpp>

#include <torch/csrc/distributed/c10d/RankLocal.hpp>

namespace {

// Each rank operates on a different `c10d::ProcessGroup` instance for the same
// logical process group. Use `RankLocal<GroupRegistry>::get()` to ensure each
// rank gets a unique registry.
class GroupRegistry {
 public:
  void register_group(
      const std::string& group_name,
      c10::intrusive_ptr<c10d::ProcessGroup> group) {
    std::unique_lock write_lock(lock_);
    auto [_, inserted] = registry_.emplace(group_name, group);
    TORCH_CHECK(
        inserted,
        "A process group is already registered under the name",
        group_name);
  }

  c10::intrusive_ptr<c10d::ProcessGroup> resolve_group(
      const std::string& group_name) {
    std::shared_lock read_lock(lock_);
    auto it = registry_.find(group_name);
    TORCH_CHECK(
        it != registry_.end(),
        "Could not resolve the process group registered under the name ",
        group_name);

    auto group = it->second.lock();
    TORCH_CHECK(
        group != nullptr,
        "Process group registered under the name ",
        group_name,
        " has already been destroyed.");
    return group;
  }

 private:
  std::map<std::string, c10::weak_intrusive_ptr<c10d::ProcessGroup>> registry_;
  std::shared_mutex lock_;
};

} // namespace

namespace c10d {

void register_process_group(
    const std::string& group_name,
    c10::intrusive_ptr<c10d::ProcessGroup> group) {
  RankLocal<::GroupRegistry>::get().register_group(group_name, group);
}

c10::intrusive_ptr<c10d::ProcessGroup> resolve_process_group(
    const std::string& group_name) {
  return RankLocal<::GroupRegistry>::get().resolve_group(group_name);
}

} // namespace c10d
torch/csrc/distributed/c10d/GroupRegistry.hpp (new file, 14 lines)
@@ -0,0 +1,14 @@
#pragma once

#include <torch/csrc/distributed/c10d/ProcessGroup.hpp>

namespace c10d {

C10_EXPORT void register_process_group(
    const std::string& group_name,
    c10::intrusive_ptr<c10d::ProcessGroup> group);

C10_EXPORT c10::intrusive_ptr<c10d::ProcessGroup> resolve_process_group(
    const std::string& group_name);

} // namespace c10d
@@ -171,4 +171,10 @@ void ProcessGroup::enableCollectivesTiming() {
   }
 }
 
+void ProcessGroup::release_resources() {
+  store_.reset();
+  deviceTypeToBackend_.clear();
+  backendTypeToBackend_.clear();
+}
+
 } // namespace c10d
@@ -688,12 +688,14 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
   void setGroupName(const std::string& name);
   void enableCollectivesTiming();
 
+  void release_resources() override;
+
  protected:
   // Implementations of this interface need to call this to setup
   // appropriate logging etc.
   void init();
 
-  const c10::intrusive_ptr<c10d::Store> store_;
+  c10::intrusive_ptr<c10d::Store> store_;
   const int rank_;
   const int size_;
   const c10::intrusive_ptr<Options> options_;
torch/csrc/distributed/c10d/RankLocal.hpp (new file, 73 lines)
@@ -0,0 +1,73 @@

#pragma once

#include <shared_mutex>

#include <torch/csrc/autograd/function.h>

namespace c10d {

// `RankLocal` maintains a unique instance of T for each non-autograd thread.
// For non-autograd threads, `RankLocal<T>::get()` functions similar to
// thread_local. For autograd threads, `RankLocal<T>::get()` returns the
// instance of T corresponding to the enqueuing non-autograd thread. The
// mechanism allows for rank-specific context shared between forward and
// backward. It works for both the one-rank-per-process and one-rank-per-thread
// scenarios.
//
// NOTE: RankLocal doesn't make the underlying objects thread-safe.
template <typename T>
class RankLocal {
 public:
  RankLocal(const RankLocal&) = delete;
  RankLocal& operator=(const RankLocal&) = delete;

  static T& get() {
    // Fast path: non-autograd threads can simply return
    // the object reference cached in TLS.
    if (cached_ != nullptr) {
      return *cached_;
    }
    const auto node = torch::autograd::get_current_node();
    auto fwd_thread_id = node == nullptr ? at::RecordFunction::currentThreadId()
                                         : node->thread_id();
    // Optimistically acquire the read lock first, since most likely we are in
    // an autograd thread and the object has already been constructed.
    {
      std::shared_lock read_lock(lock_);
      auto it = thread_id_to_rank_local_.find(fwd_thread_id);
      if (it != thread_id_to_rank_local_.end()) {
        // Cache for non-autograd threads
        if (node == nullptr) {
          cached_ = &it->second;
        }
        return it->second;
      }
    }

    std::unique_lock write_lock(lock_);
    auto [it, _] = thread_id_to_rank_local_.try_emplace(fwd_thread_id);
    // Cache for non-autograd threads
    if (node == nullptr) {
      cached_ = &it->second;
    }
    return it->second;
  }

 private:
  RankLocal(){};
  thread_local static T* cached_;
  static std::unordered_map<uint64_t, T> thread_id_to_rank_local_;
  static std::shared_mutex lock_;
};

template <typename T>
thread_local T* RankLocal<T>::cached_ = nullptr;

template <typename T>
std::unordered_map<uint64_t, T> RankLocal<T>::thread_id_to_rank_local_;

template <typename T>
std::shared_mutex RankLocal<T>::lock_;

} // namespace c10d
@@ -3,6 +3,7 @@
 #include <c10/util/intrusive_ptr.h>
 #include <c10/util/string_view.h>
 #include <torch/csrc/distributed/c10d/FileStore.hpp>
+#include <torch/csrc/distributed/c10d/GroupRegistry.hpp>
 #include <torch/csrc/distributed/c10d/TCPStore.hpp>
 #include <torch/csrc/distributed/c10d/Utils.hpp>
 #ifndef _WIN32

@@ -800,6 +801,26 @@ This class does not support ``__members__`` property.)");
           py::return_value_policy::copy, // seems safest
           py::call_guard<py::gil_scoped_release>());
 
+  // TODO(yifu): _{register, resolve}_process_group currently only work for
+  // c10d_functional. Later, we'll unify the name -> group mapping across
+  // Python and C++, and spanning both functional and non-functional
+  // collectives.
+  module.def(
+      "_register_process_group",
+      [](const std::string& group_name,
+         c10::intrusive_ptr<::c10d::ProcessGroup> group) {
+        ::c10d::register_process_group(group_name, group);
+      },
+      py::arg("group_name"),
+      py::arg("group"));
+
+  module.def(
+      "_resolve_process_group",
+      [](const std::string& group_name) {
+        return ::c10d::resolve_process_group(group_name);
+      },
+      py::arg("group_name"));
+
   py::class_<::c10d::BroadcastOptions>(module, "BroadcastOptions")
       .def(py::init<>())
       .def_readwrite("rootRank", &::c10d::BroadcastOptions::rootRank)
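A small sketch of the registration round trip these bindings expose (the group name "train_pg" is an arbitrary example; the bindings are private, underscore-prefixed APIs and assume an initialized default process group):

    import torch
    import torch.distributed as dist

    c10d = torch._C._distributed_c10d
    c10d._register_process_group("train_pg", dist.group.WORLD)
    pg = c10d._resolve_process_group("train_pg")
    assert pg.size() == dist.get_world_size()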
@@ -539,7 +539,7 @@ def _reduce_scatter_tensor_meta(input, reduce_op, tag, rankset, group_size):
     out_size[0] //= group_size
     return input.new_empty(out_size)
 
-def _all_reduce_coalesced_meta(self, reduceOp, tag, rankset, group_size):
+def _all_reduce_coalesced_meta(self, *args):
     return [torch.empty_like(t) for t in self]
 
 def _reduce_scatter_tensor_coalesced_meta(inputs, reduceOp, tag, rankset, group_size):

@@ -566,6 +566,27 @@ def _all_to_all_single_meta(input, output_split_sizes, input_split_sizes, tag, r
     out_size[0] = sum(output_split_sizes)
     return input.new_empty(out_size)
 
+def _all_gather_into_tensor_native_meta(input, group_size, group_name):
+    shape = list(input.size())
+    shape[0] *= group_size
+    return input.new_empty(shape)
+
+def _all_gather_into_tensor_coalesced_native_meta(inputs, group_size, group_name):
+    return [
+        _all_gather_into_tensor_native_meta(input, group_size, group_name)
+        for input in inputs
+    ]
+
+def _reduce_scatter_tensor_native_meta(input, group_size, group_name):
+    shape = list(input.size())
+    shape[0] //= group_size
+    return input.new_empty(shape)
+
+def _reduce_scatter_tensor_coalesced_native_meta(inputs, group_size, group_name):
+    return [
+        _reduce_scatter_tensor_native_meta(input, group_size, group_name)
+        for input in inputs
+    ]
+
 def _register_ops():
     ops_defs = [

@@ -596,6 +617,15 @@ if not torch._running_with_deploy():
     c10_lib = torch.library.Library("c10d_functional", "DEF")
     c10_lib_impl = torch.library.Library("c10d_functional", "IMPL")
     _register_ops()
+
+    _c10_lib_impl = torch.library.Library("_c10d_functional", "IMPL")
+    _c10_lib_impl.impl("all_reduce", _all_reduce_meta, "Meta")
+    _c10_lib_impl.impl("all_reduce_coalesced", _all_reduce_coalesced_meta, "Meta")
+    _c10_lib_impl.impl("wait_tensor", _wait_tensor_meta, "Meta")
+    _c10_lib_impl.impl("all_gather_into_tensor", _all_gather_into_tensor_native_meta, "Meta")
+    _c10_lib_impl.impl("all_gather_into_tensor_coalesced", _all_gather_into_tensor_coalesced_native_meta, "Meta")
+    _c10_lib_impl.impl("reduce_scatter_tensor", _reduce_scatter_tensor_native_meta, "Meta")
+    _c10_lib_impl.impl("reduce_scatter_tensor_coalesced", _reduce_scatter_tensor_coalesced_native_meta, "Meta")
 else:
     warnings.warn("PyTorch Distributed functional collectives do not work with torch::deploy.")
 
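These Meta-dispatch registrations only need to propagate shapes: all_gather grows the leading dimension by the group size, and reduce_scatter shrinks it. A minimal standalone sketch of that shape math (not the registrations themselves), assuming a hypothetical group size of 2:

    import torch

    def all_gather_meta_shape(t: torch.Tensor, group_size: int) -> torch.Size:
        shape = list(t.size())
        shape[0] *= group_size
        return torch.Size(shape)

    def reduce_scatter_meta_shape(t: torch.Tensor, group_size: int) -> torch.Size:
        shape = list(t.size())
        shape[0] //= group_size
        return torch.Size(shape)

    x = torch.empty(4, 8, device="meta")
    assert all_gather_meta_shape(x, 2) == torch.Size([8, 8])
    assert reduce_scatter_meta_shape(x, 2) == torch.Size([2, 8])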