[c10d] NCCL Process Group implementation (#8182)

* [c10d] Process Group NCCL implementation

* Addressed comments

* Added one missing return and clang format again

* Use cmake/Modules for everything and fix gloo build

* Fixed compiler warnings

* Deleted duplicated FindNCCL
Teng Li
2018-06-08 10:33:27 -07:00
committed by Pieter Noordhuis
parent d301d9df7a
commit a994b432ee
15 changed files with 983 additions and 36 deletions

View File

@ -1,45 +1,49 @@
# Try to find the Gloo library and headers.
# Gloo_FOUND - system has Gloo lib
# Gloo_INCLUDE_DIRS - the Gloo include directory
# Gloo_LIBRARIES - libraries needed to use Gloo
# Gloo_LIBRARY/Gloo_NATIVE_LIBRARY - libraries needed to use Gloo
find_path(Gloo_INCLUDE_DIR
NAMES gloo/common/common.h
DOC "The directory where Gloo includes reside"
)
find_library(Gloo_NATIVE_LIBRARY
NAMES gloo
DOC "The Gloo library (without CUDA)"
)
find_library(Gloo_CUDA_LIBRARY
NAMES gloo_cuda
DOC "The Gloo library (with CUDA)"
)
set(Gloo_INCLUDE_DIRS ${Gloo_INCLUDE_DIR})
# use the CUDA library depending on the Gloo_USE_CUDA variable
if (DEFINED Gloo_USE_CUDA)
if (${Gloo_USE_CUDA})
set(Gloo_LIBRARY ${Gloo_CUDA_LIBRARY})
else()
set(Gloo_LIBRARY ${Gloo_NATIVE_LIBRARY})
endif()
if (${Gloo_USE_CUDA})
set(Gloo_LIBRARY ${Gloo_CUDA_LIBRARY})
set(Gloo_NATIVE_LIBRARY ${Gloo_NATIVE_LIBRARY})
else()
set(Gloo_LIBRARY ${Gloo_NATIVE_LIBRARY})
set(Gloo_NATIVE_LIBRARY ${Gloo_NATIVE_LIBRARY})
endif()
else()
# else try to use the CUDA library if found
if (${Gloo_CUDA_LIBRARY} STREQUAL "Gloo_CUDA_LIBRARY-NOTFOUND")
set(Gloo_LIBRARY ${Gloo_NATIVE_LIBRARY})
else()
set(Gloo_LIBRARY ${Gloo_CUDA_LIBRARY})
endif()
# else try to use the CUDA library if found
if (${Gloo_CUDA_LIBRARY} STREQUAL "Gloo_CUDA_LIBRARY-NOTFOUND")
set(Gloo_LIBRARY ${Gloo_NATIVE_LIBRARY})
set(Gloo_NATIVE_LIBRARY ${Gloo_NATIVE_LIBRARY})
else()
set(Gloo_LIBRARY ${Gloo_CUDA_LIBRARY})
set(Gloo_NATIVE_LIBRARY ${Gloo_NATIVE_LIBRARY})
endif()
endif()
include(FindPackageHandleStandardArgs)
find_package_handle_standard_args(Gloo
FOUND_VAR Gloo_FOUND
REQUIRED_VARS Gloo_INCLUDE_DIR Gloo_LIBRARY
)
mark_as_advanced(Gloo_FOUND)

View File

@ -45,6 +45,15 @@ include(FindPackageHandleStandardArgs)
find_package_handle_standard_args(NCCL DEFAULT_MSG NCCL_INCLUDE_DIRS NCCL_LIBRARIES)
if(NCCL_FOUND)
set (NCCL_HEADER_FILE "${NCCL_INCLUDE_DIRS}/nccl.h")
message (STATUS "Determining NCCL version from the header file: ${NCCL_HEADER_FILE}")
file (STRINGS ${NCCL_HEADER_FILE} NCCL_MAJOR_VERSION_DEFINED
REGEX "^[ \t]*#define[ \t]+NCCL_MAJOR[ \t]+[0-9]+.*$" LIMIT_COUNT 1)
if (NCCL_MAJOR_VERSION_DEFINED)
string (REGEX REPLACE "^[ \t]*#define[ \t]+NCCL_MAJOR[ \t]+" ""
NCCL_MAJOR_VERSION ${NCCL_MAJOR_VERSION_DEFINED})
message (STATUS "NCCL_MAJOR_VERSION: ${NCCL_MAJOR_VERSION}")
endif ()
message(STATUS "Found NCCL (include: ${NCCL_INCLUDE_DIRS}, library: ${NCCL_LIBRARIES})")
mark_as_advanced(NCCL_ROOT_DIR NCCL_INCLUDE_DIRS NCCL_LIBRARIES)
endif()

View File

@ -1,12 +1,10 @@
cmake_minimum_required(VERSION 3.2 FATAL_ERROR)
# Find modules.
# Note: this does NOT include <root>/cmake/Modules, because that would
# make find_package resolve to <root>/cmake/Modules/FindGloo.cmake
# instead of resolving to ../tmp_install/share/cmake/Gloo.
list(APPEND CMAKE_MODULE_PATH
/usr/lib/x86_64-linux-gnu/
${CMAKE_CURRENT_SOURCE_DIR}/../../../cmake/public
${CMAKE_CURRENT_SOURCE_DIR}/../../../cmake/Modules
${CMAKE_CURRENT_SOURCE_DIR}/../../../cmake/Modules_CUDA_fix)
# Polyfill for upstream FindCUDA
@ -21,17 +19,39 @@ if(NOT Caffe2_FOUND)
endif()
find_package(Gloo REQUIRED)
if(NOT Gloo_FOUND)
if(Gloo_FOUND)
message(STATUS "Gloo_LIBRARY: ${Gloo_LIBRARY}")
message(STATUS "Gloo_NATIVE_LIBRARY: ${Gloo_NATIVE_LIBRARY}")
message(STATUS "Gloo_INCLUDE_DIR: ${Gloo_INCLUDE_DIR}")
else()
message(FATAL_ERROR "Gloo not found")
endif()
find_package(MPI)
if(MPI_FOUND)
MESSAGE(STATUS "MPI_INCLUDE_PATH: ${MPI_INCLUDE_PATH}")
MESSAGE(STATUS "MPI_LIBRARIES: ${MPI_LIBRARIES}")
MESSAGE(STATUS "MPIEXEC: ${MPIEXEC}")
message(STATUS "MPI_INCLUDE_PATH: ${MPI_INCLUDE_PATH}")
message(STATUS "MPI_LIBRARIES: ${MPI_LIBRARIES}")
message(STATUS "MPIEXEC: ${MPIEXEC}")
else()
MESSAGE(STATUS "Not able to find MPI, will compile c10d without MPI support")
message(STATUS "Not able to find MPI, will compile c10d without MPI support")
endif()
find_package(NCCL)
IF(NCCL_FOUND)
message(STATUS "NCCL_LIBRARIES: ${NCCL_LIBRARIES}")
message(STATUS "NCCL_INCLUDE_DIRS: ${NCCL_INCLUDE_DIRS}")
IF(NCCL_MAJOR_VERSION AND NOT (NCCL_MAJOR_VERSION LESS 2))
message(STATUS "NCCL Version 2 or higher found, will "
"compile with NCCL distributed backend")
SET(DISTRIBUTED_NCCL_FOUND TRUE)
else()
message(STATUS "Found NCCL, but the NCCL version is either not 2+ or not "
"determinable, will not compile with NCCL distributed "
"backend")
endif()
else()
message(STATUS "Not able to find NCCL, will not "
"compile with NCCL distributed backend")
endif()
find_package(CUDA REQUIRED)
@ -72,7 +92,7 @@ set(C10D_GLOO_SRCS
add_library(c10d_gloo ${C10D_GLOO_SRCS})
target_include_directories(c10d_gloo PUBLIC ${GLOO_INCLUDE_DIR})
target_link_libraries(c10d_gloo PUBLIC c10d gloo gloo_cuda)
target_link_libraries(c10d_gloo PUBLIC c10d ${Gloo_NATIVE_LIBRARY} ${Gloo_LIBRARY})
if(MPI_FOUND)
set(C10D_MPI_SRCS
@ -83,6 +103,15 @@ if(MPI_FOUND)
target_link_libraries(c10d_mpi PUBLIC c10d ${MPI_LIBRARIES})
endif()
if(DISTRIBUTED_NCCL_FOUND)
set(C10D_NCCL_SRCS
ProcessGroupNCCL.cpp
)
add_library(c10d_nccl ${C10D_NCCL_SRCS})
target_include_directories(c10d_nccl PUBLIC ${NCCL_INCLUDE_DIRS})
target_link_libraries(c10d_nccl PUBLIC c10d ${NCCL_LIBRARIES})
endif()
add_subdirectory(example)
enable_testing()

View File

@ -25,9 +25,9 @@ void CUDADevice::setDevice(int device) {
}
}
CUDAEvent CUDAEvent::create() {
CUDAEvent CUDAEvent::create(unsigned int flags) {
CUDAEvent event;
C10D_CUDA_CHECK(cudaEventCreate(&event.event_));
C10D_CUDA_CHECK(cudaEventCreateWithFlags(&event.event_, flags));
return event;
}

View File

@ -34,7 +34,7 @@ class CUDAEvent {
~CUDAEvent();
static CUDAEvent create();
static CUDAEvent create(unsigned int flags = cudaEventDefault);
// Must not be copyable.
CUDAEvent& operator=(const CUDAEvent&) = delete;

View File

@ -0,0 +1,64 @@
#pragma once
#include <nccl.h>
#include <memory>
#include <stdexcept>
#include <string>
#define C10D_NCCL_CHECK(cmd) \
do { \
ncclResult_t error = cmd; \
if (error != ncclSuccess) { \
std::string err = "NCCL error in: " + std::string(__FILE__) + ":" + \
std::to_string(__LINE__) + ", " + \
std::string(ncclGetErrorString(error)); \
throw std::runtime_error(err); \
} \
} while (0)
namespace c10d {
// RAII wrapper for NCCL communicator
class NCCLComm {
public:
explicit NCCLComm(ncclComm_t ncclComm) : ncclComm_(ncclComm) {}
NCCLComm() : NCCLComm(nullptr) {}
~NCCLComm() {
if (ncclComm_) {
C10D_NCCL_CHECK(ncclCommDestroy(ncclComm_));
}
}
static std::shared_ptr<NCCLComm> create(
int numRanks,
int rank,
ncclUniqueId commId) {
auto comm = std::make_shared<NCCLComm>();
C10D_NCCL_CHECK(
ncclCommInitRank(&(comm->ncclComm_), numRanks, commId, rank));
return comm;
}
// Must not be copyable
NCCLComm(const NCCLComm&) = delete;
NCCLComm& operator=(const NCCLComm&) = delete;
// Move constructible
NCCLComm(NCCLComm&& other) : ncclComm_(nullptr) {
std::swap(ncclComm_, other.ncclComm_);
}
// Move assignable
NCCLComm& operator=(NCCLComm&& other) {
std::swap(ncclComm_, other.ncclComm_);
return *this;
}
ncclComm_t getNcclComm() {
return ncclComm_;
}
protected:
ncclComm_t ncclComm_;
};
} // namespace c10d
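As a rough, hedged illustration (not part of this commit), the RAII wrapper and check macro above could be used to bring up one communicator per local GPU inside a single NCCL group call; the helper name and the way the unique ID reaches each process are assumptions for illustration only:

// Hedged sketch: create one NCCLComm per local GPU. Assumes rank 0 already
// produced `uniqueId` via ncclGetUniqueId() and shared it with every process.
#include "NCCLUtils.hpp"
#include <cuda_runtime.h>
#include <stdexcept>
#include <vector>

std::vector<std::shared_ptr<c10d::NCCLComm>> initLocalComms(
    const std::vector<int>& devices,
    int processRank,
    int processCount,
    ncclUniqueId uniqueId) {
  std::vector<std::shared_ptr<c10d::NCCLComm>> comms(devices.size());
  // Each GPU acts as its own NCCL rank, so the NCCL world size is
  // processCount * devices.size(), mirroring ProcessGroupNCCL::getNCCLComm().
  C10D_NCCL_CHECK(ncclGroupStart());
  for (size_t i = 0; i < devices.size(); ++i) {
    if (cudaSetDevice(devices[i]) != cudaSuccess) {
      throw std::runtime_error("cudaSetDevice failed");
    }
    const int numRanks = processCount * static_cast<int>(devices.size());
    const int rank =
        processRank * static_cast<int>(devices.size()) + static_cast<int>(i);
    comms[i] = c10d::NCCLComm::create(numRanks, rank, uniqueId);
  }
  C10D_NCCL_CHECK(ncclGroupEnd());
  return comms;
}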

View File

@ -86,6 +86,8 @@ bool ProcessGroupMPI::WorkMPI::isSuccess() const {
return !workException_;
}
void ProcessGroupMPI::WorkMPI::synchronize() {}
bool ProcessGroupMPI::WorkMPI::wait() {
std::unique_lock<std::mutex> lock(workMutex_);
while (!completed_) {

View File

@ -75,6 +75,9 @@ class ProcessGroupMPI : public ProcessGroup {
// if false, the exception function can be called to get details.
bool isSuccess() const override;
// No op for the case of MPI
virtual void synchronize() override;
// Waits until request completes. Blocking operation
// Returns false if the work completed with an exception
bool wait() override;

View File

@ -0,0 +1,422 @@
#include "ProcessGroupNCCL.hpp"
#include "private/CUDAUtils.hpp"
#include <THC.h>
#include <map>
#include <unordered_set>
namespace c10d {
namespace {
// NCCL op mapping
std::map<ReduceOp, ncclRedOp_t> ncclOp = {
{ReduceOp::MIN, ncclMin},
{ReduceOp::MAX, ncclMax},
{ReduceOp::SUM, ncclSum},
{ReduceOp::PRODUCT, ncclProd},
};
// NCCL data type mapping
std::map<at::ScalarType, ncclDataType_t> ncclDataType = {
{at::kChar, ncclInt8},
{at::kByte, ncclUint8},
{at::kFloat, ncclFloat},
{at::kDouble, ncclDouble},
{at::kInt, ncclInt32},
{at::kLong, ncclInt64},
{at::kHalf, ncclHalf},
};
// Helper function that gets the data type and issues error if not supported
ncclDataType_t getNcclDataType(at::ScalarType type) {
try {
return ncclDataType.at(type);
} catch (std::out_of_range& e) {
throw std::runtime_error("Unsupported data type for NCCL process group");
}
}
// Get the deviceList String from the list of devices
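// e.g. devices {0, 4, 5} would map to the key "0,4,5"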
std::string getKeyFromDevices(const std::vector<int>& devices) {
std::string deviceList;
for (auto device : devices) {
if (deviceList.empty()) {
deviceList = std::to_string(device);
} else {
deviceList += "," + std::to_string(device);
}
}
return deviceList;
}
// Get the list of devices from list of tensors
std::vector<int> getDevicesOfTensors(const std::vector<at::Tensor>& tensors) {
std::vector<int> res;
for (auto& tensor : tensors) {
res.push_back(tensor.get_device());
}
return res;
}
// Helper that makes the input ncclStreams wait for the THC stream
void syncStreams(
THCState* thcState,
const std::vector<int>& devices,
std::vector<CUDAEvent>& ncclEvents,
std::vector<CUDAStream>& ncclStreams) {
CUDADevice gpuGuard;
for (size_t i = 0; i < devices.size(); ++i) {
gpuGuard.setDevice(devices[i]);
auto currentThcStream =
THCState_getCurrentStreamOnDevice(thcState, devices[i]);
CUDAStream& ncclStream = ncclStreams[i];
CUDAEvent& ncclEvent = ncclEvents[i];
C10D_CUDA_CHECK(cudaEventRecord(ncclEvent.getEvent(), currentThcStream));
C10D_CUDA_CHECK(
cudaStreamWaitEvent(ncclStream.getStream(), ncclEvent.getEvent(), 0));
}
}
} // namespace
ProcessGroupNCCL::WorkNCCL::WorkNCCL(const std::vector<int>& devices)
: devices_(devices) {
CUDADevice gpuGuard;
cudaEvents_.resize(devices.size());
// Now create the CUDA events
for (size_t i = 0; i < devices.size(); ++i) {
gpuGuard.setDevice(devices[i]);
cudaEvents_[i] = CUDAEvent::create(cudaEventDisableTiming);
}
}
ProcessGroupNCCL::WorkNCCL::~WorkNCCL() {}
// Check if the NCCL kernels are queued on the GPUs
bool ProcessGroupNCCL::WorkNCCL::isCompleted() const {
return true;
}
// Helper that checks if the NCCL kernels are completed on the GPUs
bool ProcessGroupNCCL::WorkNCCL::finishedGPUExecution() const {
CUDADevice gpuGuard;
for (size_t i = 0; i < devices_.size(); ++i) {
gpuGuard.setDevice(devices_[i]);
auto& cudaEvent = cudaEvents_[i];
// Checking the work's corresponding CUDA events' status
auto ret = cudaEventQuery(cudaEvent.getEvent());
if (ret != cudaSuccess && ret != cudaErrorNotReady) {
C10D_CUDA_CHECK(ret);
}
if (ret == cudaErrorNotReady) {
return false;
}
}
return true;
}
// Same as synchronize(), and will always return true
bool ProcessGroupNCCL::WorkNCCL::wait() {
synchronize();
return true;
}
// Waiting on the work's corresponding CUDA events
void ProcessGroupNCCL::WorkNCCL::synchronize() {
auto thcState = ::at::globalContext().lazyInitCUDA();
CUDADevice gpuGuard;
for (size_t i = 0; i < devices_.size(); ++i) {
gpuGuard.setDevice(devices_[i]);
auto thcStream = THCState_getCurrentStreamOnDevice(thcState, devices_[i]);
auto& cudaEvent = cudaEvents_[i];
// Let THC stream wait for the NCCL stream
C10D_CUDA_CHECK(cudaStreamWaitEvent(thcStream, cudaEvent.getEvent(), 0));
}
}
bool ProcessGroupNCCL::WorkNCCL::isSuccess() const {
return true;
}
const std::exception& ProcessGroupNCCL::WorkNCCL::exception() const {
throw std::runtime_error(
"exception() is not supported by NCCL process "
"group's work, since isSuccess() will always return true, and "
"isCompleted() and wait() will either succeed or throw");
}
ProcessGroupNCCL::ProcessGroupNCCL(
const std::shared_ptr<Store>& store,
int rank,
int size)
: ProcessGroup(rank, size), store_(store) {
C10D_CUDA_CHECK(cudaGetDeviceCount(&numGPUs_));
thcState_ = ::at::globalContext().lazyInitCUDA();
}
ProcessGroupNCCL::~ProcessGroupNCCL() {}
void ProcessGroupNCCL::broadcastUniqueNCCLId(
const std::string& devicesKey,
ncclUniqueId* ncclId) {
// Rank 0 writes the unique ID to the store to broadcast it
if (rank_ == 0) {
auto ncclIdVal = std::vector<uint8_t>(
reinterpret_cast<uint8_t*>(ncclId),
reinterpret_cast<uint8_t*>(ncclId) + NCCL_UNIQUE_ID_BYTES);
store_->set(devicesKey, ncclIdVal);
// Other ranks read it from the store
} else {
auto ncclIdVal = store_->get(devicesKey);
// Just a sanity check
if (ncclIdVal.size() != NCCL_UNIQUE_ID_BYTES) {
throw std::runtime_error(
"Unexpected NCCL unique ID length received "
"from the store");
}
// Now put the data back to the input pointer
memcpy(ncclId, ncclIdVal.data(), NCCL_UNIQUE_ID_BYTES);
}
}
std::vector<std::shared_ptr<NCCLComm>>& ProcessGroupNCCL::getNCCLComm(
const std::string& devicesKey,
const std::vector<int>& devices) {
// Sanity check
if (devicesKey.empty()) {
throw std::runtime_error(
"Not able to create/get the NCCL Communicator since "
"the GPU devices are not known");
}
if (devNCCLCommMap_.find(devicesKey) != devNCCLCommMap_.end()) {
// Reuse the cached communicator if there is one.
return devNCCLCommMap_[devicesKey];
}
// NCCL communicator not cached, create a new entry
std::vector<std::shared_ptr<NCCLComm>> ncclComms;
ncclComms.resize(devices.size());
// Create the unique NCCL ID and broadcast it
ncclUniqueId ncclId;
if (rank_ == 0) {
C10D_NCCL_CHECK(ncclGetUniqueId(&ncclId));
}
// Broadcast so that each process can have a unique NCCL ID
broadcastUniqueNCCLId(devicesKey, &ncclId);
CUDADevice gpuGuard;
std::vector<CUDAEvent> eventVal;
std::vector<CUDAStream> streamVal;
eventVal.resize(devices.size());
streamVal.resize(devices.size());
// Create the NCCL communicators for each GPU
C10D_NCCL_CHECK(ncclGroupStart());
for (size_t i = 0; i < devices.size(); ++i) {
// GPU world size and GPU rank
int numRanks = getSize() * devices.size();
int rank = getRank() * devices.size() + i;
gpuGuard.setDevice(devices[i]);
ncclComms[i] = NCCLComm::create(numRanks, rank, ncclId);
// Also create the NCCL streams and events
streamVal[i] = CUDAStream::create();
// Events created with the cudaEventDisableTiming flag (and without the
// cudaEventBlockingSync flag) provide the best performance when used with
// cudaStreamWaitEvent() and cudaEventQuery(). Since we don't measure
// performance with these events, the flag should be set.
eventVal[i] = CUDAEvent::create(cudaEventDisableTiming);
}
C10D_NCCL_CHECK(ncclGroupEnd());
// Move the NCCL resource to cache
devNCCLCommMap_.emplace(devicesKey, std::move(ncclComms));
ncclStreams_.emplace(devicesKey, std::move(streamVal));
ncclEvents_.emplace(devicesKey, std::move(eventVal));
return devNCCLCommMap_[devicesKey];
}
// Helper function that checks the input and output tensors for validity
void ProcessGroupNCCL::tensorCheckHelper(
const std::vector<at::Tensor>& input,
const std::vector<at::Tensor>& output,
int outputOverInput) {
if (input.size() != output.size()) {
throw std::runtime_error(
"Input tensor sequence should have the same "
"number of tensors as the output tensor sequence");
}
if (input.size() == 0) {
throw std::runtime_error("The number of input tensors should not be zero");
}
if (input.size() > static_cast<size_t>(numGPUs_)) {
throw std::runtime_error(
"The number of input tensors is larger than "
"the number of available GPUs");
}
// Make sure each tensor is on a separate device
std::unordered_set<int> usedDevices;
usedDevices.reserve(input.size());
auto inputNumElement = input[0].numel();
auto elementType = input[0].type().scalarType();
for (size_t i = 0; i < input.size(); ++i) {
// Check to make sure it's a GPU dense tensor
if (!(input[i].type().is_cuda() && !input[i].type().is_sparse() &&
output[i].type().is_cuda() && !output[i].type().is_sparse())) {
throw std::runtime_error(
"Only CUDA dense tensor is supported for NCCL "
"collective operations");
}
// Check the tensor type is identical
if (input[i].type().scalarType() != elementType ||
output[i].type().scalarType() != elementType) {
throw std::runtime_error(
"Expecting all GPU tensors to have identical "
"type");
}
// Check the input tensor size is identical
if (input[i].numel() != inputNumElement) {
throw std::runtime_error(
"Expecting all input tensors to have identical "
"number of elements");
}
// Check that the output tensor size equals the input tensor size times outputOverInput
if (output[i].numel() != inputNumElement * outputOverInput) {
throw std::runtime_error(
"The number of elements of output tensor does "
"not match the number of elements of the input "
"tensor");
}
// Contiguous verification
if (!input[i].is_contiguous() || !output[i].is_contiguous()) {
throw std::runtime_error("Expecting all GPU tensors to be contiguous");
}
bool inserted;
std::tie(std::ignore, inserted) = usedDevices.insert(input[i].get_device());
// Device verification, if the insertion didn't take place
if (!inserted) {
throw std::runtime_error("Expecting inputs on different GPU devices");
}
// Now check the output device
if (input[i].get_device() != output[i].get_device()) {
throw std::runtime_error(
"Expecting input and output tensors to be on "
"the same device");
}
}
}
std::shared_ptr<ProcessGroup::Work> ProcessGroupNCCL::allreduce(
std::vector<at::Tensor>& tensors,
const AllreduceOptions& opts) {
tensorCheckHelper(tensors, tensors);
auto devices = getDevicesOfTensors(tensors);
auto key = getKeyFromDevices(devices);
auto& ncclComms = getNCCLComm(key, devices);
// First let NCCL streams wait for THC stream
syncStreams(thcState_, devices, ncclEvents_[key], ncclStreams_[key]);
// Work itself will create the CUDA events on all GPUs of tensors
auto work = std::make_shared<ProcessGroupNCCL::WorkNCCL>(devices);
CUDADevice gpuGuard;
C10D_NCCL_CHECK(ncclGroupStart());
for (size_t i = 0; i < tensors.size(); ++i) {
gpuGuard.setDevice(devices[i]);
CUDAStream& ncclStream = ncclStreams_[key][i];
C10D_NCCL_CHECK(ncclAllReduce(
tensors[i].data_ptr(),
tensors[i].data_ptr(),
tensors[i].numel(),
getNcclDataType(tensors[i].type().scalarType()),
ncclOp[opts.reduceOp],
ncclComms[i]->getNcclComm(),
ncclStream.getStream()));
}
C10D_NCCL_CHECK(ncclGroupEnd());
// Event should only be recorded after the ncclGroupEnd()
for (size_t i = 0; i < tensors.size(); ++i) {
CUDAStream& ncclStream = ncclStreams_[key][i];
CUDAEvent& cudaEvent = work->cudaEvents_[i];
C10D_CUDA_CHECK(
cudaEventRecord(cudaEvent.getEvent(), ncclStream.getStream()));
}
return work;
}
std::shared_ptr<ProcessGroup::Work> ProcessGroupNCCL::broadcast(
std::vector<at::Tensor>& tensors,
const BroadcastOptions& opts) {
tensorCheckHelper(tensors, tensors);
auto devices = getDevicesOfTensors(tensors);
auto key = getKeyFromDevices(devices);
auto& ncclComms = getNCCLComm(key, devices);
// First let NCCL streams wait for THC stream
syncStreams(thcState_, devices, ncclEvents_[key], ncclStreams_[key]);
// Work itself will create the CUDA events on all GPUs of tensors
auto work = std::make_shared<ProcessGroupNCCL::WorkNCCL>(devices);
CUDADevice gpuGuard;
C10D_NCCL_CHECK(ncclGroupStart());
for (size_t i = 0; i < tensors.size(); ++i) {
gpuGuard.setDevice(devices[i]);
CUDAStream& ncclStream = ncclStreams_[key][i];
// Global rank of the root GPU
int root = opts.rootRank * tensors.size() + opts.rootTensor;
C10D_NCCL_CHECK(ncclBcast(
tensors[i].data_ptr(),
tensors[i].numel(),
getNcclDataType(tensors[i].type().scalarType()),
root,
ncclComms[i]->getNcclComm(),
ncclStream.getStream()));
}
C10D_NCCL_CHECK(ncclGroupEnd());
// Event should only be recorded after the ncclGroupEnd()
for (size_t i = 0; i < tensors.size(); ++i) {
CUDAStream& ncclStream = ncclStreams_[key][i];
CUDAEvent& cudaEvent = work->cudaEvents_[i];
C10D_CUDA_CHECK(
cudaEventRecord(cudaEvent.getEvent(), ncclStream.getStream()));
}
return work;
}
} // namespace c10d

View File

@ -0,0 +1,165 @@
#pragma once
#include "CUDAUtils.hpp"
#include "NCCLUtils.hpp"
#include "ProcessGroup.hpp"
#include "Store.hpp"
// forward declaration
struct THCState;
namespace c10d {
// ProcessGroupNCCL implements NCCL bindings for c10d.
//
// All functions of the class are expected to be called in the same order
// across all processes in the process group. This is the only way that we
// can guarantee to match up the same calls among all processes.
//
// All NCCL functions provided by this class are asynchronous functions. More
// specifically, each NCCL call is scheduled on a separate CUDA stream that is
// different from the current THC CUDA stream. This is for the purpose of
// potentially achieving concurrency and better performance. As a result,
// it is the callers' responsibility to make sure that the CUDA stream their
// code works on (the THC stream) waits for the NCCL operation issued by
// this class.
//
// This can be done by calling:
//
// either WorkNCCL::wait() or WorkNCCL::synchronize(); both achieve the same
// functionality and are synonyms.
//
// Note that WorkNCCL::isSuccess() and WorkNCCL::isCompleted() will always
// return true since ProcessGroupNCCL is single threaded. Every single NCCL
// or CUDA failure will simply raise std::runtime_error.
//
// Therefore, WorkNCCL::exception() is not supported since isSuccess() always
// returns true.
//
// Also note that WorkNCCL::finishedGPUExecution() is a helper function only
// provided by ProcessGroupNCCL to check if the NCCL operation of WorkNCCL has
// finished execution on the GPU (not just scheduled).
//
// Example on using the NCCL process group
//
// ProcessGroupNCCL pg(store, rank, size);
// std::shared_ptr<WorkNCCL> work = pg.allreduce(tensors);
//
// // At this point, the NCCL kernel has already been queued successfully
// // Now, let the THC stream wait for NCCL to finish; this function is an
// // async operation as well
//
// work->wait();
//
// // Now continue on other work in the THC stream.
class ProcessGroupNCCL : public ProcessGroup {
public:
class WorkNCCL : public ProcessGroup::Work {
public:
// Constructor takes a list of CUDA devices
WorkNCCL(const std::vector<int>& devices);
virtual ~WorkNCCL();
// Checks if request has completed. In this specific case of NCCL, it checks
// if the NCCL operation has completed on the GPU in its own NCCL stream.
// Non-blocking operation.
bool isCompleted() const override;
// Lets the current THC stream wait on the completion of the NCCL work.
// Always returns true and will throw if there are exceptions.
// Non-blocking operation
bool wait() override;
// Will always return true
bool isSuccess() const override;
// Same as wait()
void synchronize() override;
// Not supported by WorkNCCL
const std::exception& exception() const override;
// Helper function that checks if the NCCL kernels have finished
// execution on the GPUs
bool finishedGPUExecution() const;
protected:
// The cached list of CUDA devices to operate on
std::vector<int> devices_;
// The CUDA events that are used to track this workitem on
// multiple CUDA devices
std::vector<CUDAEvent> cudaEvents_;
friend class ProcessGroupNCCL;
};
// Constructor will also check the number of available GPUs in the system
ProcessGroupNCCL(const std::shared_ptr<Store>& store, int rank, int size);
virtual ~ProcessGroupNCCL();
std::shared_ptr<ProcessGroup::Work> broadcast(
std::vector<at::Tensor>& tensors,
const BroadcastOptions& opts = BroadcastOptions()) override;
std::shared_ptr<ProcessGroup::Work> allreduce(
std::vector<at::Tensor>& tensors,
const AllreduceOptions& opts = AllreduceOptions()) override;
protected:
// Helper that broadcasts the NCCL unique ID to all ranks through the store
void broadcastUniqueNCCLId(
const std::string& devicesKey,
ncclUniqueId* ncclId);
// Helper that either looks up the cached NCCL communicators or creates
// a new set of NCCL communicators as a cache entry
std::vector<std::shared_ptr<NCCLComm>>& getNCCLComm(
const std::string& devicesKey,
const std::vector<int>& devices);
// Tensor checker helper
void tensorCheckHelper(
const std::vector<at::Tensor>& input,
const std::vector<at::Tensor>& output,
int outputOverInput = 1);
// Store used to broadcast the NCCL unique ID to all other ranks
std::shared_ptr<Store> store_;
// The NCCL communicators that the process group has cached.
// The key is the list of GPU devices that an operation is operating on.
// The GPU devices are stored as a device sequence and the cached NCCL
// communicators are associated with this GPU device sequence.
//
// e.g. If the process group op only uses device 0, then the device string
// stored (the key of the hashmap) would be "0".
//
// If the process group op uses devices 0 - 7 and each tensor of the
// input tensor list is on device 0, 1, 2, 3, 4, 5, 6, 7 respectively,
// then the device string (key) stored would be "0,1,2,3,4,5,6,7".
//
// If the process group op uses devices 0 - 7 and each tensor of the
// input tensor list is on device 0, 4, 5, 6, 7, 1, 2, 3 respectively,
// then the device string stored would be "0,4,5,6,7,1,2,3".
//
// Note that the order of the devices in the tensor list matters.
std::unordered_map<std::string, std::vector<std::shared_ptr<NCCLComm>>>
devNCCLCommMap_;
// The CUDA streams used by NCCL kernels
std::unordered_map<std::string, std::vector<CUDAStream>> ncclStreams_;
// The CUDA events used to sync NCCL streams
std::unordered_map<std::string, std::vector<CUDAEvent>> ncclEvents_;
// Caches the number of GPUs available in the current system
int numGPUs_;
// Store copy of pointer to THCState retrieved from ::at::globalContext().
THCState* thcState_;
};
} // namespace c10d
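To make the call sequence described in the comment block above concrete, here is a minimal, hedged sketch of driving the process group from user code; it mirrors the test further down, and the store path and tensor shape are placeholder choices rather than part of this commit:

// Hedged sketch: allreduce one CUDA tensor across `size` processes.
#include "FileStore.hpp"
#include "ProcessGroupNCCL.hpp"
#include <ATen/ATen.h>
#include <memory>
#include <vector>

void runAllreduce(int rank, int size) {
  // FileStore on a path reachable by every process (placeholder path).
  auto store = std::make_shared<c10d::FileStore>("/tmp/nccl_store");
  c10d::ProcessGroupNCCL pg(store, rank, size);

  // One float tensor on the default CUDA device, filled with this rank.
  std::vector<at::Tensor> tensors = {
      at::getType(at::kCUDA, at::kFloat).tensor({3, 3}).fill_(rank)};

  // Queues the NCCL kernel on the process group's internal NCCL stream.
  auto work = pg.allreduce(tensors);

  // Lets the current THC stream wait for the NCCL stream; like the NCCL
  // call itself, this is asynchronous with respect to the host.
  work->wait();
}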

View File

@ -17,3 +17,6 @@ if(MPI_FOUND)
add_definitions(-DMPIEXEC=${MPIEXEC})
c10d_add_test(ProcessGroupMPITest.cpp c10d c10d_mpi)
endif()
if(DISTRIBUTED_NCCL_FOUND)
c10d_add_test(ProcessGroupNCCLTest.cpp c10d c10d_cuda_test c10d_nccl)
endif()

View File

@ -6,8 +6,8 @@ namespace c10d {
namespace test {
namespace {
__global__ void waitClocks(const size_t count) {
clock_t start = clock();
__global__ void waitClocks(const uint64_t count) {
clock_t start = clock64();
clock_t offset = 0;
while (offset < count) {
offset = clock() - start;
@ -16,7 +16,7 @@ __global__ void waitClocks(const size_t count) {
} // namespace
void cudaSleep(CUDAStream& stream, size_t clocks) {
void cudaSleep(CUDAStream& stream, uint64_t clocks) {
waitClocks<<<1, 1, 0, stream.getStream()>>>(clocks);
}

View File

@ -8,7 +8,7 @@
namespace c10d {
namespace test {
void cudaSleep(CUDAStream& stream, size_t clocks);
void cudaSleep(CUDAStream& stream, uint64_t clocks);
int cudaNumDevices();

View File

@ -0,0 +1,239 @@
#include "ProcessGroupNCCL.hpp"
#include "CUDAUtils.hpp"
#include "FileStore.hpp"
#include "private/CUDAUtils.hpp"
#include "test/CUDATest.hpp"
#include "test/TestUtils.hpp"
#include <iostream>
using namespace c10d::test;
using c10d::CUDADevice;
using c10d::CUDAStream;
using c10d::ProcessGroup;
using c10d::THCStreamGuard;
class NCCLTestBase {
public:
NCCLTestBase(const std::string& path) : path_(path) {}
NCCLTestBase(NCCLTestBase&& other) {
path_ = std::move(other.path_);
pg_ = std::move(other.pg_);
}
::c10d::ProcessGroupNCCL& getProcessGroup() {
return *pg_;
}
void initialize(int rank, int size) {
auto store = std::make_shared<::c10d::FileStore>(path_);
pg_ = std::unique_ptr<::c10d::ProcessGroupNCCL>(
new ::c10d::ProcessGroupNCCL(store, rank, size));
}
protected:
std::string path_;
std::unique_ptr<::c10d::ProcessGroupNCCL> pg_;
};
class NCCLTest : public NCCLTestBase {
public:
NCCLTest(const std::string& path)
: NCCLTestBase(path),
numDevices_(cudaNumDevices()),
state_(::at::globalContext().lazyInitCUDA()) {
const auto& type = at::getType(at::kCUDA, at::kFloat);
// Each device has a single tensor to perform the NCCL op on
inputs_.resize(numDevices_);
for (auto i = 0; i < numDevices_; i++) {
CUDADevice device(i);
inputs_[i] = type.tensor({3, 3});
}
// Allocate a stream per device.
//
// The "current stream" is set globally per device in THC, so we
// can't make two tensors on the same device use different streams
// and pass this along to the collective (since it uses the THC
// getters to retrieve the current stream).
//
streams_.resize(numDevices_);
for (auto i = 0; i < numDevices_; i++) {
CUDADevice device(i);
streams_[i] = CUDAStream::create();
}
}
std::vector<THCStreamGuard> createStreamGuard() {
std::vector<THCStreamGuard> guards;
for (auto& stream : streams_) {
guards.push_back(std::move(THCStreamGuard(state_, stream)));
}
return guards;
}
void wait(std::shared_ptr<ProcessGroup::Work>& work) {
auto guards = createStreamGuard();
work->wait();
}
std::vector<at::Tensor> getTensors() {
std::vector<at::Tensor> outputs(numDevices_);
// For the duration of this function, make THC use our streams
auto guards = createStreamGuard();
// Copy inputs to outputs
for (auto i = 0; i < numDevices_; i++) {
cudaStreamSynchronize(streams_[i].getStream());
outputs[i] = inputs_[i].toBackend(at::kCPU);
}
return outputs;
}
int numDevices() const {
return numDevices_;
}
protected:
const int numDevices_;
THCState* state_;
std::vector<at::Tensor> inputs_;
std::vector<CUDAStream> streams_;
};
class AllreduceNCCLTest : public NCCLTest {
public:
AllreduceNCCLTest(const std::string& path) : NCCLTest(path) {}
std::shared_ptr<c10d::ProcessGroup::Work> run() {
// For the duration of this function, make THC use our streams
auto guards = createStreamGuard();
// Launch sleep on every device
for (auto i = 0; i < numDevices_; i++) {
CUDADevice device(i);
cudaSleep(streams_[i], 2000 * 1000 * 1000);
}
// Launch value initialization for every tensor
for (auto i = 0; i < numDevices_; i++) {
CUDADevice device(i);
inputs_[i].fill_(pg_->getRank() * numDevices_ + i);
}
return pg_->allreduce(inputs_);
}
};
class BroadcastNCCLTest : public NCCLTest {
public:
BroadcastNCCLTest(const std::string& path) : NCCLTest(path) {}
std::shared_ptr<c10d::ProcessGroup::Work> run(int rootRank, int rootTensor) {
// For the duration of this function, make THC use our streams
auto guards = createStreamGuard();
// Launch sleep on every device
for (auto i = 0; i < numDevices_; i++) {
CUDADevice device(i);
cudaSleep(streams_[i], 2000 * 1000 * 1000);
}
// Launch value initialization for every tensor
for (auto i = 0; i < numDevices_; i++) {
CUDADevice device(i);
inputs_[i].fill_(pg_->getRank() * numDevices_ + i);
}
::c10d::BroadcastOptions options;
options.rootRank = rootRank;
options.rootTensor = rootTensor;
return pg_->broadcast(inputs_, options);
}
};
void testAllreduce(const std::string& path, int rank, int size) {
auto test = AllreduceNCCLTest(path);
test.initialize(rank, size);
auto work = test.run();
// Wait for work to finish
test.wait(work);
// Validation: GPU g (global index) was filled with g, so after allreduce
// every element should equal 0 + 1 + ... + (totalNumGPUs - 1)
const int totalNumGPUs = test.numDevices() * size;
const auto expected = (totalNumGPUs * (totalNumGPUs - 1)) / 2;
auto tensors = test.getTensors();
for (size_t j = 0; j < tensors.size(); j++) {
auto& tensor = tensors[j];
auto data = tensor.data<float>();
for (auto k = 0; k < tensor.numel(); k++) {
if (data[k] != expected) {
throw std::runtime_error("BOOM!");
}
}
}
std::cout << "Allreduce test successful" << std::endl;
}
void testBroadcast(const std::string& path, int rank, int size) {
auto test = BroadcastNCCLTest(path);
test.initialize(rank, size);
const int numDevices = test.numDevices();
// Try every combination of root rank and root tensor
for (auto rootRank = 0; rootRank < size; rootRank++) {
for (auto rootTensor = 0; rootTensor < numDevices; rootTensor++) {
auto work = test.run(rootRank, rootTensor);
// Wait for work to complete
test.wait(work);
// Check results
const auto expected = (rootRank * numDevices + rootTensor);
auto tensors = test.getTensors();
for (size_t j = 0; j < tensors.size(); j++) {
auto& tensor = tensors[j];
auto data = tensor.data<float>();
for (auto k = 0; k < tensor.numel(); k++) {
if (data[k] != expected) {
throw std::runtime_error("BOOM!");
}
}
}
}
}
std::cout << "Broadcast test successful" << std::endl;
}
int main(int argc, char** argv) {
// Use the WORLD_SIZE and RANK environment variables to do multi-node
// distributed testing
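// e.g. run one process per node with WORLD_SIZE=2 and RANK=0/1; every
// process must resolve the same store file, which the TMPFILE override in
// TestUtils.hpp (see below) is presumably meant to make possible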
auto sizeEnv = std::getenv("WORLD_SIZE");
auto rankEnv = std::getenv("RANK");
int size = 1;
int rank = 0;
if (sizeEnv && rankEnv) {
size = std::stoi(std::string(sizeEnv));
rank = std::stoi(std::string(rankEnv));
std::cout << "Multi-node world size: " << size << " rank: " << rank
<< std::endl;
}
{
TemporaryFile file;
testAllreduce(file.path, rank, size);
}
{
TemporaryFile file;
testBroadcast(file.path, rank, size);
}
return EXIT_SUCCESS;
}

View File

@ -36,6 +36,13 @@ class Semaphore {
};
std::string tmppath() {
// TMPFILE is for manual test execution, during which the user will specify
// the full temp file path using the environment variable TMPFILE
const char* tmpfile = getenv("TMPFILE");
if (tmpfile) {
return std::string(tmpfile);
}
const char* tmpdir = getenv("TMPDIR");
if (tmpdir == nullptr) {
tmpdir = "/tmp";