mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
[c10d] Prototype of group_split
for dist2 work (#157716)
This is to implement group_split as proposed in [docs.google.com/document/d/13R-1t_yESTvmAjcCN-wQjQQadIEu0JNIdS65uZawZzY/edit?tab=t.0#heading=h.3ctbqqopzc89](https://docs.google.com/document/d/13R-1t_yESTvmAjcCN-wQjQQadIEu0JNIdS65uZawZzY/edit?tab=t.0#heading=h.3ctbqqopzc89) Pull Request resolved: https://github.com/pytorch/pytorch/pull/157716 Approved by: https://github.com/d4l3k
This commit is contained in:
@ -28,7 +28,7 @@ class NCCLTestBase {
|
||||
|
||||
NCCLTestBase(NCCLTestBase&& other) noexcept = default;
|
||||
|
||||
std::shared_ptr<::c10d::ProcessGroupNCCL> getProcessGroup() {
|
||||
::c10::intrusive_ptr<::c10d::ProcessGroupNCCL> getProcessGroup() {
|
||||
return pg_;
|
||||
}
|
||||
|
||||
@ -39,7 +39,7 @@ class NCCLTestBase {
|
||||
void initialize(
|
||||
int rank,
|
||||
size_t size,
|
||||
std::optional<::std::shared_ptr<::c10d::ProcessGroupNCCL>> split_from =
|
||||
std::optional<::c10::intrusive_ptr<::c10d::ProcessGroupNCCL>> split_from =
|
||||
std::nullopt) {
|
||||
store_ = c10::make_intrusive<::c10d::FileStore>(path_, size);
|
||||
|
||||
@ -52,13 +52,13 @@ class NCCLTestBase {
|
||||
opts->split_color = ++color_;
|
||||
}
|
||||
#endif
|
||||
pg_ = std::make_unique<::c10d::ProcessGroupNCCL>(
|
||||
pg_ = c10::make_intrusive<::c10d::ProcessGroupNCCL>(
|
||||
store_, rank, size, std::move(opts));
|
||||
}
|
||||
|
||||
protected:
|
||||
std::string path_;
|
||||
std::shared_ptr<::c10d::ProcessGroupNCCL> pg_;
|
||||
::c10::intrusive_ptr<::c10d::ProcessGroupNCCL> pg_;
|
||||
std::chrono::milliseconds pgTimeout_;
|
||||
::c10::intrusive_ptr<::c10d::Store> store_;
|
||||
int color_{1};
|
||||
|
@ -201,6 +201,17 @@ class Dist2MultiProcessTestCase(MultiProcessTestCase):
|
||||
out_range = out[i * 10 : (i + 1) * 10]
|
||||
self.assertEqual(out_range, torch.full_like(out_range, i + 1))
|
||||
|
||||
def test_group_split(self) -> None:
    """Split rank 0 off into a 1-rank subgroup and check the result.

    ``split_group`` is called by every rank, but only rank 0 appears in
    the new-ranks list, so only rank 0 should receive a group handle;
    every other rank should get ``None`` back.
    """
    group = self.new_group()
    # Pass an explicit timeout so we can verify it is propagated to the
    # subgroup's backend options below.
    subgroup = group.split_group([0], timeout=timedelta(seconds=30))
    if self.rank == 0:
        assert subgroup is not None
        # The subgroup contains exactly the one rank we asked for.
        self.assertEqual(subgroup.size(), 1)
        backend = subgroup._get_backend(self.device)
        # The timeout given to split_group must land on the new backend.
        self.assertEqual(backend.options._timeout, timedelta(seconds=30))
    else:
        # Ranks outside the split receive no group handle.
        self.assertEqual(subgroup, None)
|
||||
|
||||
|
||||
class ProcessGroupGlooTest(Dist2MultiProcessTestCase):
|
||||
device = torch.device("cpu")
|
||||
|
@ -350,6 +350,13 @@ class ProcessGroup:
|
||||
) -> None: ...
|
||||
def rank(self) -> int: ...
|
||||
def size(self) -> int: ...
|
||||
def split_group(
|
||||
self,
|
||||
new_ranks: list[int],
|
||||
timeout: Optional[timedelta] = None,
|
||||
pg_options: Optional[Backend.Options] = None,
|
||||
group_desc: Optional[str] = None,
|
||||
) -> Optional[ProcessGroup]: ...
|
||||
def abort(self) -> None: ...
|
||||
def set_timeout(self, timeout: timedelta) -> None: ...
|
||||
def shutdown(self) -> None: ...
|
||||
|
@ -46,6 +46,7 @@ class TORCH_API Backend : public torch::CustomClassHolder {
|
||||
// backend name
|
||||
// NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
|
||||
const std::string backend;
|
||||
std::string group_name;
|
||||
};
|
||||
|
||||
explicit Backend(int rank, int size);
|
||||
@ -105,6 +106,14 @@ class TORCH_API Backend : public torch::CustomClassHolder {
|
||||
TORCH_INTERNAL_ASSERT(false, "getBackendName is not implemented.");
|
||||
}
|
||||
|
||||
// Subclasses must override this method to return the backend name
|
||||
virtual c10::intrusive_ptr<Options> getBackendOptions() {
|
||||
TORCH_CHECK(
|
||||
false,
|
||||
c10::str(
|
||||
"Backend ", getBackendName(), " does not implement endCoalescing"));
|
||||
}
|
||||
|
||||
virtual c10::intrusive_ptr<Work> broadcast(
|
||||
std::vector<at::Tensor>& /* tensors */,
|
||||
const BroadcastOptions& /* opts */ = BroadcastOptions()) {
|
||||
@ -379,6 +388,16 @@ class TORCH_API Backend : public torch::CustomClassHolder {
|
||||
" is missing implementation of enableCollectivesTiming.");
|
||||
}
|
||||
|
||||
// Default implementation of the split hook used by ProcessGroup::splitGroup.
// Backends that support splitting (e.g. Gloo, NCCL) override this to build a
// child backend for `ranks`; the base class always fails with a clear error
// naming the backend that lacks an implementation.
virtual c10::intrusive_ptr<Backend> splitBackend(
    const std::vector<int>& ranks,
    const c10::intrusive_ptr<Options> opts) {
  TORCH_CHECK(
      false,
      "Backend ",
      getBackendName(),
      " is missing implementation of splitBackend.");
}
|
||||
|
||||
bool hasHooks() const {
|
||||
return onCompletionHook_ != nullptr;
|
||||
}
|
||||
|
@ -573,6 +573,27 @@ size_t hashTensors(const std::vector<at::Tensor>& tensors) {
|
||||
return hash;
|
||||
}
|
||||
|
||||
// NCCL requires a non-negative int "color" to identify membership when
// splitting a communicator. Derive one deterministically from the rank list
// by folding each rank's hash into an accumulator (boost::hash_combine-style
// mixing) and mapping the result into the positive 32-bit range.
int genNcclSplitColor(const std::vector<int>& ranks) {
  std::size_t mixed = 0;
  for (const int rank : ranks) {
    // Same reducer as the std::hash + fold idiom: XOR with the golden-ratio
    // constant plus shifted copies of the accumulator to spread bits.
    mixed ^=
        std::hash<int>{}(rank) + 0x9e3779b9 + (mixed << 6) + (mixed >> 2);
  }

  // Clamp into [0, INT32_MAX) so the value fits NCCL's int color argument.
  constexpr int32_t max_c_int = std::numeric_limits<int32_t>::max();
  return static_cast<int>(
      std::abs(static_cast<int64_t>(mixed)) % max_c_int);
}
|
||||
|
||||
// Default value: 30 minutes
|
||||
int nccl_nonblocking_timeout() {
|
||||
static int timeout = -2; // -2 means not initialized
|
||||
|
@ -231,6 +231,7 @@ static std::map<at::ScalarType, ncclDataType_t> ncclDataType = {
|
||||
};
|
||||
|
||||
TORCH_API size_t hashTensors(const std::vector<at::Tensor>& tensors);
|
||||
TORCH_API int genNcclSplitColor(const std::vector<int>& ranks);
|
||||
TORCH_API std::string getNcclVersion();
|
||||
TORCH_API std::tuple<int, int, int> getNcclVersionTuple();
|
||||
TORCH_API int getNcclVersionNumber();
|
||||
|
@ -4,6 +4,7 @@
|
||||
|
||||
#include <c10/util/Logging.h>
|
||||
#include <fmt/format.h>
|
||||
#include <fmt/ranges.h>
|
||||
#include <string_view>
|
||||
|
||||
#include <torch/csrc/distributed/c10d/PrefixStore.hpp>
|
||||
@ -158,6 +159,63 @@ void ProcessGroup::release_resources() {
|
||||
backendTypeToBackend_.clear();
|
||||
}
|
||||
|
||||
// Create a new ProcessGroup containing only `ranks` (a strict, duplicate-free
// subset of this group). For each device backend of the parent, delegates to
// Backend::splitBackend; backends return nullptr for ranks that are not part
// of the new group, in which case this function returns a null handle on
// those ranks. `timeout` and `opts` override the parent backend's options
// when provided; `desc` overrides the auto-generated group description.
c10::intrusive_ptr<ProcessGroup> ProcessGroup::splitGroup(
    const std::vector<int>& ranks,
    const std::optional<std::chrono::milliseconds> timeout,
    const std::optional<c10::intrusive_ptr<Backend::Options>> opts,
    const std::optional<std::string>& desc) {
  TORCH_CHECK(
      ranks.size() > 0,
      "Split ranks cannot be empty. Please provide a non-empty list of ranks to split the group.");
  TORCH_CHECK(
      ranks.size() < static_cast<size_t>(size_),
      "the split group's size should be less than the world_size set by init_process_group");
  // A set collapses duplicates; a size mismatch therefore means dupes.
  std::set<int> ranks_set(ranks.begin(), ranks.end());
  TORCH_CHECK(
      ranks_set.size() == ranks.size(),
      "Split ranks should not have duplicates. Please provide a list of unique ranks to split the group.");
  // Sort so every rank derives the identical group name / backend inputs
  // regardless of the order the caller listed the ranks in.
  std::vector<int> sorted_ranks = ranks;
  std::sort(sorted_ranks.begin(), sorted_ranks.end());
  c10::intrusive_ptr<ProcessGroup> newGroup;
  // TODO: Figure out a better way for split group name.
  std::string groupName =
      c10::str(getGroupName(), ":split:", fmt::format("{}", sorted_ranks));
  for (const auto& pair : deviceTypeToBackendType_) {
    c10::DeviceType deviceType = pair.first;
    BackendType backendType = pair.second;

    auto parentBackend = getBackend(deviceType);
    // Caller-supplied options win; otherwise clone the parent's options.
    // NOTE(review): when `opts` is provided it is mutated in place below and
    // shared across all device backends in this loop — confirm intended.
    auto backendOpts =
        opts.has_value() ? opts.value() : parentBackend->getBackendOptions();
    backendOpts->group_name = groupName;
    backendOpts->timeout =
        timeout.has_value() ? timeout.value() : backendOpts->timeout;
    auto splitBackend = parentBackend->splitBackend(sorted_ranks, backendOpts);
    if (splitBackend == nullptr) {
      // This rank is not a member of the new group for this backend.
      continue;
    }

    // TODO: Figure out a better way for split group desc.
    // TODO: We can add a new field in Backend::Options to specify the group
    // desc
    std::string groupDesc = desc.has_value()
        ? desc.value()
        : c10::str(getGroupDesc(), ":split:", incrementSplitCount());
    splitBackend->setGroupDesc(groupDesc);

    // Lazily create the wrapper group on the first successful backend split;
    // later backends are attached to the same group below.
    if (!newGroup) {
      newGroup = c10::make_intrusive<ProcessGroup>(
          store_->clone(), splitBackend->getRank(), splitBackend->getSize());
      newGroup->setDefaultBackend(backendType_);
      newGroup->setGroupName(groupName);
      newGroup->setGroupDesc(groupDesc);
    }
    newGroup->setBackend(deviceType, backendType, splitBackend);
  }

  // Null on ranks that are not part of any split backend.
  return newGroup;
}
|
||||
|
||||
} // namespace c10d
|
||||
|
||||
namespace {
|
||||
|
@ -170,6 +170,10 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
|
||||
}
|
||||
}
|
||||
|
||||
int64_t incrementSplitCount() {
|
||||
return splitCounter_++;
|
||||
}
|
||||
|
||||
virtual void startCoalescing(c10::DeviceType deviceType) {
|
||||
// only nccl has implemented startCoalescing so only execute for nccl
|
||||
// backends
|
||||
@ -955,6 +959,14 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
|
||||
bound_device_id_ = device;
|
||||
}
|
||||
|
||||
// This creates a new subgroup using the specified ranks.
|
||||
// The current rank must be included in the list of new_ranks.
|
||||
virtual c10::intrusive_ptr<ProcessGroup> splitGroup(
|
||||
const std::vector<int>& ranks,
|
||||
const std::optional<std::chrono::milliseconds> timeout,
|
||||
const std::optional<c10::intrusive_ptr<Backend::Options>> opts,
|
||||
const std::optional<std::string>& groupDesc);
|
||||
|
||||
protected:
|
||||
// Implementations of this interface need to call this to setup
|
||||
// appropriate logging etc.
|
||||
@ -968,6 +980,7 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
|
||||
// NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
|
||||
BackendType backendType_;
|
||||
std::string pg_desc_;
|
||||
int64_t splitCounter_;
|
||||
|
||||
// Debug level setting. It is parsed once when ProcessGroup is constructed and
|
||||
// remains the same across use of this process group.
|
||||
|
@ -697,6 +697,35 @@ const std::vector<uint64_t>& ProcessGroupGloo::groupRanks() const {
|
||||
return options_->global_ranks_in_group;
|
||||
}
|
||||
|
||||
// Build a child Gloo backend for `ranks` (ranks within the parent group).
// Returns nullptr on ranks that are not part of the new group.
c10::intrusive_ptr<Backend> ProcessGroupGloo::splitBackend(
    const std::vector<int>& ranks,
    const c10::intrusive_ptr<Backend::Options> opts) {
  // Our position in `ranks` becomes our rank in the child group; ranks not
  // listed simply opt out of the split.
  const auto pos = std::find(ranks.begin(), ranks.end(), rank_);
  if (pos == ranks.end()) {
    return nullptr;
  }
  const int groupRank = static_cast<int>(std::distance(ranks.begin(), pos));

  auto glooOpts = c10::dynamic_intrusive_pointer_cast<Options>(opts);
  TORCH_CHECK(glooOpts != nullptr, "opts not a ProcessGroupGloo::Options.");

  // TODO: we need to get rid of globalRanksInGroup eventually.
  // Translate parent-group ranks into global ranks for the child's options.
  std::vector<uint64_t> globalRanksInGroup;
  globalRanksInGroup.reserve(ranks.size());
  for (const int r : ranks) {
    globalRanksInGroup.emplace_back(groupRanks()[r]);
  }
  glooOpts->global_ranks_in_group = std::move(globalRanksInGroup);

  auto store = std::dynamic_pointer_cast<GlooStore>(store_);
  TORCH_CHECK(
      store != nullptr,
      "store inside ProcessGroupGloo not a ProcessGroupGloo::GlooStore.");
  auto pg = c10::make_intrusive<ProcessGroupGloo>(
      store->_getStore()->clone(), groupRank, ranks.size(), glooOpts);
  return c10::static_intrusive_pointer_cast<Backend>(pg);
}
|
||||
|
||||
void ProcessGroupGloo::enqueue(c10::intrusive_ptr<AsyncWork> work) {
|
||||
std::unique_lock<std::mutex> lock(workMutex_);
|
||||
pgStatus_->lastEnqueuedSeq = static_cast<int64_t>(work->seq_);
|
||||
|
@ -188,6 +188,10 @@ class TORCH_API ProcessGroupGloo : public Backend {
|
||||
}
|
||||
#endif
|
||||
|
||||
const c10::intrusive_ptr<::c10d::Store>& _getStore() const {
|
||||
return store_;
|
||||
}
|
||||
|
||||
protected:
|
||||
c10::intrusive_ptr<::c10d::Store> store_;
|
||||
};
|
||||
@ -252,7 +256,6 @@ class TORCH_API ProcessGroupGloo : public Backend {
|
||||
}
|
||||
|
||||
std::vector<uint64_t> global_ranks_in_group;
|
||||
std::string group_name;
|
||||
std::vector<std::shared_ptr<::gloo::transport::Device>> devices;
|
||||
int threads;
|
||||
};
|
||||
@ -301,6 +304,14 @@ class TORCH_API ProcessGroupGloo : public Backend {
|
||||
}
|
||||
}
|
||||
|
||||
c10::intrusive_ptr<Backend::Options> getBackendOptions() override {
|
||||
return c10::static_intrusive_pointer_cast<Backend::Options>(options_);
|
||||
}
|
||||
|
||||
c10::intrusive_ptr<Backend> splitBackend(
|
||||
const std::vector<int>& ranks,
|
||||
const c10::intrusive_ptr<Backend::Options> opts) override;
|
||||
|
||||
const std::vector<uint64_t>& groupRanks() const;
|
||||
|
||||
c10::intrusive_ptr<Work> broadcast(
|
||||
|
@ -1311,6 +1311,45 @@ void ProcessGroupNCCL::enableCollectivesTiming() {
|
||||
enableTiming_.store(true);
|
||||
}
|
||||
|
||||
// Split off a child NCCL backend containing only `ranks` (ranks within this
// group). This is collective over the parent: members compute a shared color
// and split from the parent communicator, while non-members perform the
// required no-color split and receive nullptr.
c10::intrusive_ptr<Backend> ProcessGroupNCCL::splitBackend(
    const std::vector<int>& ranks,
    const c10::intrusive_ptr<Backend::Options> opts) {
  auto deviceIdx = guessDeviceId();
  // Fixed error message grammar ("has no device is bound" -> "has no device
  // bound"); this text is user-facing.
  TORCH_CHECK(
      deviceIdx >= 0,
      "ProcessGroupNCCL::splitBackend: rank ",
      rank_,
      " has no device bound to this rank.");
  auto device = at::Device(at::DeviceType::CUDA, deviceIdx);
  auto it = std::find(ranks.begin(), ranks.end(), rank_);
  int groupRank;
  if (it == ranks.end()) {
    // This rank is not in the new group, so no_color split should be called
    performNocolorSplit(device);
    return nullptr;
  } else {
    // Position in `ranks` becomes this rank's rank in the child group.
    groupRank = std::distance(ranks.begin(), it);
  }

  auto ncclOpts = c10::dynamic_intrusive_pointer_cast<Options>(opts);
  TORCH_CHECK(ncclOpts != nullptr, "opts not a ProcessGroupNCCL::Options.");

  // TODO: we need to get rid of globalRanksInGroup eventually.
  std::vector<uint64_t> globalRanksInGroup;
  globalRanksInGroup.reserve(ranks.size());
  for (auto rank : ranks) {
    globalRanksInGroup.emplace_back(groupRanks()[rank]);
  }
  // Record the parent so the child communicator is created via ncclCommSplit
  // from it; unsafe_reclaim_from_nonowning avoids bumping `this`'s refcount.
  ncclOpts->split_from =
      c10::intrusive_ptr<ProcessGroupNCCL>::unsafe_reclaim_from_nonowning(this);
  ncclOpts->global_ranks_in_group = std::move(globalRanksInGroup);
  // Every member derives the same deterministic, non-negative color from the
  // rank list, as NCCL requires for a colored split.
  auto color = genNcclSplitColor(ranks);
  ncclOpts->split_color = color;
  auto pg = c10::make_intrusive<ProcessGroupNCCL>(
      store_->clone(), groupRank, ranks.size(), ncclOpts);
  pg->eagerConnectSingleDevice(device);
  return c10::static_intrusive_pointer_cast<Backend>(pg);
}
|
||||
|
||||
bool ProcessGroupNCCL::waitForFutureOrTimeout(
|
||||
std::future<bool>& fut,
|
||||
const std::chrono::milliseconds& timeOutMilSec,
|
||||
|
@ -541,7 +541,7 @@ class TORCH_API ProcessGroupNCCL : public Backend {
|
||||
|
||||
// Optional "parent" backend and color to create communicators from
|
||||
// via `ncclCommSplit`
|
||||
std::shared_ptr<ProcessGroupNCCL> split_from;
|
||||
c10::intrusive_ptr<ProcessGroupNCCL> split_from;
|
||||
// Color to use for `ncclCommSplit`, values:
|
||||
// * Non-negative value: in group;
|
||||
// * NCCL_SPLIT_NOCOLOR (-1): not in group;
|
||||
@ -562,7 +562,6 @@ class TORCH_API ProcessGroupNCCL : public Backend {
|
||||
int split_color{-2};
|
||||
#endif
|
||||
std::vector<uint64_t> global_ranks_in_group;
|
||||
std::string group_name;
|
||||
};
|
||||
|
||||
// Helper class related to TORCH_NCCL_DESYNC_DEBUG
|
||||
@ -804,6 +803,10 @@ class TORCH_API ProcessGroupNCCL : public Backend {
|
||||
return options_;
|
||||
}
|
||||
|
||||
c10::intrusive_ptr<Backend::Options> getBackendOptions() override {
|
||||
return c10::static_intrusive_pointer_cast<Backend::Options>(options_);
|
||||
}
|
||||
|
||||
const std::string getBackendName() const override {
|
||||
return std::string(NCCL_BACKEND_NAME);
|
||||
}
|
||||
@ -972,6 +975,10 @@ class TORCH_API ProcessGroupNCCL : public Backend {
|
||||
|
||||
void enableCollectivesTiming() override;
|
||||
|
||||
c10::intrusive_ptr<Backend> splitBackend(
|
||||
const std::vector<int>& ranks,
|
||||
const c10::intrusive_ptr<Backend::Options> opts) override;
|
||||
|
||||
// Helper function for iteratively aborting communicators in the provided map
|
||||
void abortCommsFromMap(
|
||||
std::unordered_map<std::string, std::shared_ptr<NCCLComm>>& ncclCommsMap,
|
||||
|
@ -151,6 +151,21 @@ class PyProcessGroup : public ProcessGroup {
|
||||
group_desc);
|
||||
}
|
||||
|
||||
// Trampoline for Python subclasses of ProcessGroup: dispatches splitGroup to
// a Python override when one exists, otherwise falls back to the C++ base
// implementation (PYBIND11_OVERRIDE semantics).
c10::intrusive_ptr<ProcessGroup> splitGroup(
    const std::vector<int>& ranks,
    const std::optional<std::chrono::milliseconds> timeout,
    const std::optional<c10::intrusive_ptr<Backend::Options>> opts,
    const std::optional<std::string>& group_desc) override {
  PYBIND11_OVERRIDE(
      c10::intrusive_ptr<ProcessGroup>, /* Return type */
      ProcessGroup, /* Parent class */
      splitGroup, /* Name of function in C++ */
      ranks,
      timeout,
      opts,
      group_desc);
}
|
||||
|
||||
c10::intrusive_ptr<Work> allgather(
|
||||
std::vector<std::vector<at::Tensor>>& outputTensors,
|
||||
std::vector<at::Tensor>& inputTensors,
|
||||
|
@ -2063,6 +2063,14 @@ communication mechanism.
|
||||
.def("rank", &::c10d::ProcessGroup::getRank, R"(Get the rank of this process group.)")
|
||||
.def("size", &::c10d::ProcessGroup::getSize, R"(Get the size of this process group.)")
|
||||
.def("name", &::c10d::ProcessGroup::getBackendName, R"(Get the name of this process group.)")
|
||||
.def(
    "split_group",
    &::c10d::ProcessGroup::splitGroup,
    py::arg("ranks"),
    py::arg("timeout") = std::nullopt,
    py::arg("opts") = std::nullopt,
    // Keep the kwarg name in sync with the .pyi stub, which declares
    // `group_desc: Optional[str] = None` (was the camelCase "groupDesc",
    // which also broke the snake_case convention of the sibling args).
    py::arg("group_desc") = std::nullopt,
    py::call_guard<py::gil_scoped_release>())
|
||||
.def(
|
||||
"abort",
|
||||
&::c10d::ProcessGroup::abort,
|
||||
|
Reference in New Issue
Block a user