We introduced a dispatchable backend for ProcessGroup and its collectives in https://github.com/pytorch/pytorch/issues/86225. This PR is a follow-up cleanup that removes the options from ProcessGroup construction: users either set the timeout or backend later on, or create the backend directly after creating the PG. Also, ProcessGroupNCCL was using the Options class from ProcessGroup, but it should use the Options class from the Backend class, so this PR aligns the types and names with what we do on the C++ side. The public API signatures are unchanged, so they still use args named "pg_options". The tests are updated to match. This relands D62008954 with the internal errors fixed.

Differential Revision: [D62483294](https://our.internmc.facebook.com/intern/diff/D62483294/)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135653
Approved by: https://github.com/wz337, https://github.com/H-Huang
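For orientation, here is a minimal sketch of the resulting usage pattern on the C++ side: the ProcessGroup no longer carries backend options itself; a backend is attached per device type after the PG exists and is resolved on demand via `getBackend` (defined in the file below). The `resolveCudaBackend` helper name and the choice of CUDA as the device type are illustrative assumptions, not part of this change.

```cpp
#include <torch/csrc/distributed/c10d/ProcessGroup.hpp>

// Illustrative sketch: resolve the backend for a device type after the
// ProcessGroup and its backend have been created elsewhere.
c10::intrusive_ptr<c10d::Backend> resolveCudaBackend(
    const c10::intrusive_ptr<c10d::ProcessGroup>& pg) {
  // getBackend() looks up the backend attached for this device type and
  // caches it; it fails via TORCH_CHECK if none is associated or initialized.
  return pg->getBackend(c10::DeviceType::CUDA);
}
```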
#include <ATen/ThreadLocalState.h>
#include <torch/csrc/distributed/c10d/ProcessGroup.hpp>

#include <c10/util/Logging.h>
#include <fmt/format.h>
#include <string_view>

#include <torch/csrc/distributed/c10d/PrefixStore.hpp>
#include <torch/csrc/distributed/c10d/ProcessGroupGloo.hpp>
#include <torch/csrc/distributed/c10d/ProcessGroupMPI.hpp>
#include <torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp>
#include <torch/csrc/distributed/c10d/ProcessGroupUCC.hpp>
#include <torch/csrc/distributed/c10d/ProcessGroupWrapper.hpp>

namespace c10d {
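
// Returns the human-readable name of a collective / point-to-point op,
// used when logging and composing error messages.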
std::string opTypeToString(OpType opType) {
  switch (opType) {
    case OpType::BROADCAST:
      return "BROADCAST";
    case OpType::ALLREDUCE:
      return "ALLREDUCE";
    case OpType::ALLREDUCE_COALESCED:
      return "ALLREDUCE_COALESCED";
    case OpType::REDUCE:
      return "REDUCE";
    case OpType::ALLGATHER:
      return "ALLGATHER";
    case OpType::_ALLGATHER_BASE:
      return "_ALLGATHER_BASE";
    case OpType::ALLGATHER_COALESCED:
      return "ALLGATHER_COALESCED";
    case OpType::GATHER:
      return "GATHER";
    case OpType::SCATTER:
      return "SCATTER";
    case OpType::REDUCE_SCATTER:
      return "REDUCE_SCATTER";
    case OpType::ALLTOALL_BASE:
      return "ALLTOALL_BASE";
    case OpType::ALLTOALL:
      return "ALLTOALL";
    case OpType::SEND:
      return "SEND";
    case OpType::RECV:
      return "RECV";
    case OpType::RECVANYSOURCE:
      return "RECVANYSOURCE";
    case OpType::BARRIER:
      return "BARRIER";
    case OpType::UNKNOWN:
      return "UNKNOWN";
    case OpType::_REDUCE_SCATTER_BASE:
      return "_REDUCE_SCATTER_BASE";
    case OpType::COALESCED:
      return "COALESCED";
    case OpType::_ALLREDUCE_SPARSE:
      return "_ALLREDUCE_SPARSE";
    default:
      TORCH_INTERNAL_ASSERT(false, "Unknown op type!");
  }
  return "UNKNOWN";
}
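
// SEND, RECV, and RECVANYSOURCE are point-to-point ops. Batched P2P ops
// are executed like collectives, so they are excluded when batchP2P is set.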
bool isP2POp(OpType opType, bool batchP2P /*= false*/) {
  if (batchP2P)
    return false;
  return opType == OpType::SEND || opType == OpType::RECV ||
      opType == OpType::RECVANYSOURCE;
}
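
// Resolves the Backend for a device type, caching the result in
// deviceTypeToBackend_. Fails with TORCH_CHECK if no backend type is
// associated with the device type or the backend was never initialized.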
c10::intrusive_ptr<Backend> ProcessGroup::getBackend(
    c10::DeviceType deviceType) {
  // If there is a backend associated with this device type then return it
  if (deviceTypeToBackend_.find(deviceType) != deviceTypeToBackend_.end()) {
    return deviceTypeToBackend_.at(deviceType);
  }

  // Get the backend type associated with the device
  ProcessGroup::BackendType backendType{ProcessGroup::BackendType::UNDEFINED};
  try {
    backendType = deviceTypeToBackendType_.at(deviceType);
  } catch (const std::out_of_range& e) {
    TORCH_CHECK(
        false, "No backend type associated with device type ", deviceType);
  }

  // Check if the backend has already been initialized
  if (backendTypeToBackend_.find(backendType) != backendTypeToBackend_.end()) {
    auto backend = backendTypeToBackend_.at(backendType);
    deviceTypeToBackend_[deviceType] = backend;
    return backend;
  }

  TORCH_CHECK(
      false,
      "Could not retrieve or create the backend ",
      backendType,
      " for device type ",
      deviceType);
}
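
// Note that no backend is chosen at construction time; backends are
// created and attached per device type after the ProcessGroup exists.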
ProcessGroup::ProcessGroup(
    const c10::intrusive_ptr<::c10d::Store>& store,
    int rank,
    int size)
    : store_(store),
      rank_(rank),
      size_(size),
      backendType_(BackendType::UNDEFINED),
      dist_debug_level_(debug_level()) {
  C10_LOG_API_USAGE_ONCE("c10d.process_group");
}
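
// Store-less constructor; the backend type starts out UNDEFINED.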
ProcessGroup::ProcessGroup(int rank, int size)
    : rank_(rank), size_(size), backendType_(BackendType::UNDEFINED) {}

ProcessGroup::~ProcessGroup() = default;
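
// Logs a one-time API usage event tagged with the concrete backend name.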
void ProcessGroup::init() {
  C10_LOG_API_USAGE_ONCE(
      fmt::format("c10d.process_group_{}", getBackendName()));
}
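
// The group name (UID) lives on the backends, so reading it requires at
// least one backend to have been attached.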
const std::string& ProcessGroup::getGroupName() const {
  TORCH_CHECK(!deviceTypeToBackend_.empty(), "ProcessGroup name not set");
  return deviceTypeToBackend_.begin()->second->getGroupUid();
}
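
// Propagates the group UID to every attached backend.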
void ProcessGroup::setGroupName(const std::string& name) {
  for (auto& kv : deviceTypeToBackend_) {
    kv.second->setGroupUid(name);
  }
}
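
// Unlike the group name, the description is cached on the ProcessGroup.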
const std::string& ProcessGroup::getGroupDesc() const {
  return pg_desc_;
}
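
// Stores the description locally and mirrors it to all attached backends.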
void ProcessGroup::setGroupDesc(const std::string& name) {
  pg_desc_ = name;
  // Also set the group desc for all backends
  for (auto& kv : deviceTypeToBackend_) {
    kv.second->setGroupDesc(name);
  }
}
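
// Enables timing of collectives on every attached backend.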
void ProcessGroup::enableCollectivesTiming() {
  for (auto& kv : deviceTypeToBackend_) {
    kv.second->enableCollectivesTiming();
  }
}
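
// Explicitly drops the reference to the store and all cached backends.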
void ProcessGroup::release_resources() {
  store_.reset();
  deviceTypeToBackend_.clear();
  backendTypeToBackend_.clear();
}

} // namespace c10d