mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
This PR fixes more missing-prototypes violations in the torch_cpu source following PRs #100053, #100147 and #100245 Pull Request resolved: https://github.com/pytorch/pytorch/pull/100849 Approved by: https://github.com/albanD
157 lines
4.6 KiB
C++
157 lines
4.6 KiB
C++
#include <ATen/ThreadLocalState.h>
|
|
#include <torch/csrc/distributed/c10d/ProcessGroup.hpp>
|
|
|
|
#include <c10/util/Logging.h>
|
|
#include <fmt/format.h>
|
|
|
|
#include <torch/csrc/distributed/c10d/PrefixStore.hpp>
|
|
#include <torch/csrc/distributed/c10d/ProcessGroupGloo.hpp>
|
|
#include <torch/csrc/distributed/c10d/ProcessGroupMPI.hpp>
|
|
#include <torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp>
|
|
#include <torch/csrc/distributed/c10d/ProcessGroupUCC.hpp>
|
|
#include <torch/csrc/distributed/c10d/ProcessGroupWrapper.hpp>
|
|
|
|
namespace c10d {
|
|
|
|
static ProcessGroup::BackendType strToBackendType(std::string backend) {
|
|
if (backend == "undefined") {
|
|
return ProcessGroup::BackendType::UNDEFINED;
|
|
} else if (backend == "gloo") {
|
|
return ProcessGroup::BackendType::GLOO;
|
|
} else if (backend == "nccl") {
|
|
return ProcessGroup::BackendType::NCCL;
|
|
} else if (backend == "ucc") {
|
|
return ProcessGroup::BackendType::UCC;
|
|
} else if (backend == "mpi") {
|
|
return ProcessGroup::BackendType::MPI;
|
|
} else {
|
|
return ProcessGroup::BackendType::CUSTOM;
|
|
}
|
|
}
|
|
|
|
static std::string backendTypeToStr(ProcessGroup::BackendType backendType) {
|
|
switch (backendType) {
|
|
case ProcessGroup::BackendType::UNDEFINED:
|
|
return "undefined";
|
|
case ProcessGroup::BackendType::GLOO:
|
|
return "gloo";
|
|
case ProcessGroup::BackendType::NCCL:
|
|
return "nccl";
|
|
case ProcessGroup::BackendType::UCC:
|
|
return "ucc";
|
|
case ProcessGroup::BackendType::MPI:
|
|
return "mpi";
|
|
case ProcessGroup::BackendType::CUSTOM:
|
|
return "custom";
|
|
default:
|
|
TORCH_INTERNAL_ASSERT(false, "Unknown backend type");
|
|
}
|
|
}
|
|
|
|
std::string opTypeToString(OpType opType) {
|
|
switch (opType) {
|
|
case OpType::BROADCAST:
|
|
return "BROADCAST";
|
|
case OpType::ALLREDUCE:
|
|
return "ALLREDUCE";
|
|
case OpType::ALLREDUCE_COALESCED:
|
|
return "ALLREDUCE_COALESCED";
|
|
case OpType::REDUCE:
|
|
return "REDUCE";
|
|
case OpType::ALLGATHER:
|
|
return "ALLGATHER";
|
|
case OpType::_ALLGATHER_BASE:
|
|
return "_ALLGATHER_BASE";
|
|
case OpType::ALLGATHER_COALESCED:
|
|
return "ALLGATHER_COALESCED";
|
|
case OpType::GATHER:
|
|
return "GATHER";
|
|
case OpType::SCATTER:
|
|
return "SCATTER";
|
|
case OpType::REDUCE_SCATTER:
|
|
return "REDUCE_SCATTER";
|
|
case OpType::ALLTOALL_BASE:
|
|
return "ALLTOALL_BASE";
|
|
case OpType::ALLTOALL:
|
|
return "ALLTOALL";
|
|
case OpType::SEND:
|
|
return "SEND";
|
|
case OpType::RECV:
|
|
return "RECV";
|
|
case OpType::RECVANYSOURCE:
|
|
return "RECVANYSOURCE";
|
|
case OpType::BARRIER:
|
|
return "BARRIER";
|
|
case OpType::UNKNOWN:
|
|
return "UNKNOWN";
|
|
case OpType::_REDUCE_SCATTER_BASE:
|
|
return "_REDUCE_SCATTER_BASE";
|
|
default:
|
|
TORCH_INTERNAL_ASSERT(false, "Unknown op type!");
|
|
}
|
|
return "UNKNOWN";
|
|
}
|
|
|
|
bool isP2POp(OpType opType, bool batchP2P /*= false*/) {
|
|
if (batchP2P)
|
|
return false;
|
|
return opType == OpType::SEND || opType == OpType::RECV ||
|
|
opType == OpType::RECVANYSOURCE;
|
|
}
|
|
|
|
c10::intrusive_ptr<Backend> ProcessGroup::getBackend(
|
|
c10::DeviceType deviceType) {
|
|
// If there is a backend associated with this device type then return it
|
|
if (deviceTypeToBackend_.find(deviceType) != deviceTypeToBackend_.end()) {
|
|
return deviceTypeToBackend_.at(deviceType);
|
|
}
|
|
|
|
// Get the backend type associated with the device
|
|
ProcessGroup::BackendType backendType;
|
|
try {
|
|
backendType = deviceTypeToBackendType_.at(deviceType);
|
|
} catch (const std::out_of_range& e) {
|
|
TORCH_CHECK(
|
|
false, "No backend type associated with device type ", deviceType);
|
|
}
|
|
|
|
// Check if the backend has already been initialized
|
|
if (backendTypeToBackend_.find(backendType) != backendTypeToBackend_.end()) {
|
|
auto backend = backendTypeToBackend_.at(backendType);
|
|
deviceTypeToBackend_[deviceType] = backend;
|
|
return backend;
|
|
}
|
|
|
|
TORCH_CHECK(
|
|
false,
|
|
"Could not retrieve or create the backend ",
|
|
backendType,
|
|
" for device type ",
|
|
deviceType);
|
|
}
|
|
|
|
ProcessGroup::ProcessGroup(
|
|
const c10::intrusive_ptr<::c10d::Store>& store,
|
|
int rank,
|
|
int size,
|
|
c10::intrusive_ptr<Options> options)
|
|
: store_(store),
|
|
rank_(rank),
|
|
size_(size),
|
|
options_(options),
|
|
backendType_(strToBackendType(options->backend)),
|
|
dist_debug_level_(debug_level()) {
|
|
C10_LOG_API_USAGE_ONCE("c10d.process_group");
|
|
}
|
|
|
|
ProcessGroup::ProcessGroup(int rank, int size)
|
|
: rank_(rank), size_(size), backendType_(BackendType::UNDEFINED) {}
|
|
|
|
ProcessGroup::~ProcessGroup() = default;
|
|
|
|
void ProcessGroup::init() {
|
|
C10_LOG_API_USAGE_ONCE(
|
|
fmt::format("c10d.process_group_{}", getBackendName()));
|
|
}
|
|
} // namespace c10d
|