## Summary

This PR adds 3 intra-node GPU allreduce algorithms to PyTorch:
- One-shot allreduce (inspired by FasterTransformer): all ranks simultaneously read and accumulate data from all other ranks.
- Two-shot allreduce (inspired by FasterTransformer): all ranks simultaneously read and accumulate `1 / world_size` of the data from the other ranks, then all ranks read the accumulated data from the other ranks (effectively a one-shot reduce-scatter followed by a one-shot all-gather).
- Hybrid cube mesh allreduce (original): a one-shot allreduce variant that avoids transmission over PCIe on HCM topology.

## Micro Benchmarks

## Details

The intra-node algorithms are organized behind `c10d::IntraNodeComm`, which is responsible for:
- Managing handshaking and CUDA IPC handle exchange among ranks.
- Querying NVLink connectivity and detecting topology.
- Performing algorithm selection based on the available info.
- Launching the selected allreduce kernel.

`c10d::IntraNodeComm` is integrated into `c10d::ProcessGroupNCCL` as follows:
- When the `ENABLE_INTRA_NODE_COMM` environment variable is set, `c10d::ProcessGroupNCCL` initializes a `c10d::IntraNodeComm` for its ranks.
- If the setup is not suitable for intra-node comm (e.g. not all ranks are on the same node), the rendezvous logic guarantees that all participants fall back consistently.
- `c10d::ProcessGroupNCCL::allreduce` consults `c10d::IntraNodeComm` on whether to use intra-node allreduce and carries out the communication accordingly.

We currently detect two types of topologies from the NVLink connection mesh (see the selection sketch after this section):
- Fully connected: every GPU pair has a direct NVLink connection (e.g. NVSwitch, or a fully connected subset of a hybrid cube mesh).
  - `msg <= 256KB`: one-shot allreduce.
  - `256KB < msg <= 10MB`: two-shot allreduce.
  - `msg > 10MB`: instructs the caller to fall back to NCCL.
- Hybrid cube mesh:
  - `msg <= 256KB`: one-shot allreduce.
  - `msg > 256KB`: instructs the caller to fall back to NCCL.

## Next Steps

- Fine-tune algorithm selection based on GPU model, topology, and link speed.
- Potentially optimize the two-shot allreduce implementation. According to FasterTransformer, two-shot allreduce is preferred up to 50MB. There might be room for improvement, but PyTorch does impose more constraints:
  - FasterTransformer uses a single process to drive multiple devices. It can use `cudaDeviceEnablePeerAccess` to enable device-level peer access.
  - PyTorch uses multiple processes to drive multiple devices. With CUDA IPC, a device can only share specific memory regions with other devices, so extra copies may be unavoidable.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/114001
Approved by: https://github.com/yf225
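The topology-plus-message-size dispatch described above can be summarized in a few lines of C++. The sketch below is illustrative only: the `Topology` and `AllReduceAlgo` enums and the `selectAllReduceAlgo` function are placeholder names for this write-up, not the actual `c10d::IntraNodeComm` interface.

```cpp
#include <cstddef>

// Placeholder enums mirroring the selection rules described in the PR summary.
enum class Topology { FULLY_CONNECTED, HYBRID_CUBE_MESH, UNKNOWN };
enum class AllReduceAlgo { ONE_SHOT, TWO_SHOT, NCCL_FALLBACK };

// Sketch of the topology / message-size based dispatch (thresholds from the PR text).
AllReduceAlgo selectAllReduceAlgo(Topology topo, size_t msgSizeBytes) {
  constexpr size_t kOneShotThreshold = 256 * 1024;       // 256KB
  constexpr size_t kTwoShotThreshold = 10 * 1024 * 1024; // 10MB
  switch (topo) {
    case Topology::FULLY_CONNECTED:
      if (msgSizeBytes <= kOneShotThreshold) return AllReduceAlgo::ONE_SHOT;
      if (msgSizeBytes <= kTwoShotThreshold) return AllReduceAlgo::TWO_SHOT;
      return AllReduceAlgo::NCCL_FALLBACK;
    case Topology::HYBRID_CUBE_MESH:
      if (msgSizeBytes <= kOneShotThreshold) return AllReduceAlgo::ONE_SHOT;
      return AllReduceAlgo::NCCL_FALLBACK;
    default:
      // Unknown or unsupported topology: let the caller use NCCL.
      return AllReduceAlgo::NCCL_FALLBACK;
  }
}
```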
#pragma once

#include <cuda.h>

#define NVML_NO_UNVERSIONED_FUNC_DEFS
#include <nvml.h>

#define C10_CUDA_DRIVER_CHECK(EXPR)                                        \
  do {                                                                     \
    CUresult __err = EXPR;                                                 \
    if (__err != CUDA_SUCCESS) {                                           \
      const char* err_str;                                                 \
      CUresult get_error_str_err C10_UNUSED =                              \
          c10::cuda::DriverAPI::get()->cuGetErrorString_(__err, &err_str); \
      if (get_error_str_err != CUDA_SUCCESS) {                             \
        AT_ERROR("CUDA driver error: unknown error");                      \
      } else {                                                             \
        AT_ERROR("CUDA driver error: ", err_str);                          \
      }                                                                    \
    }                                                                      \
  } while (0)

#define C10_LIBCUDA_DRIVER_API(_) \
  _(cuMemAddressReserve)          \
  _(cuMemRelease)                 \
  _(cuMemMap)                     \
  _(cuMemAddressFree)             \
  _(cuMemSetAccess)               \
  _(cuMemUnmap)                   \
  _(cuMemCreate)                  \
  _(cuGetErrorString)

#define C10_NVML_DRIVER_API(_)             \
  _(nvmlInit_v2)                           \
  _(nvmlDeviceGetHandleByPciBusId_v2)      \
  _(nvmlDeviceGetNvLinkRemoteDeviceType)   \
  _(nvmlDeviceGetNvLinkRemotePciInfo_v2)   \
  _(nvmlDeviceGetComputeRunningProcesses)

namespace c10 {
namespace cuda {

struct DriverAPI {
#define CREATE_MEMBER(name) decltype(&name) name##_;
  C10_LIBCUDA_DRIVER_API(CREATE_MEMBER)
  C10_NVML_DRIVER_API(CREATE_MEMBER)
#undef CREATE_MEMBER
  static DriverAPI* get();
  static void* get_nvml_handle();
};

} // namespace cuda
} // namespace c10
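For context, the `CREATE_MEMBER` X-macro above turns each listed CUDA driver / NVML symbol into a function-pointer member of the same type, named with a trailing underscore (e.g. `cuMemAddressReserve_`). A minimal usage sketch, assuming `DriverAPI::get()` returns a populated table; the `reserve_va` helper is hypothetical and not part of this header:

```cpp
// Assumes the header shown above (and thus <cuda.h>) has been included.

// Hypothetical helper: reserve a virtual address range through the
// dynamically loaded driver entry point instead of linking libcuda directly.
void reserve_va(CUdeviceptr* ptr, size_t size) {
  auto* api = c10::cuda::DriverAPI::get();
  // cuMemAddressReserve_ has type decltype(&cuMemAddressReserve); errors are
  // surfaced via the C10_CUDA_DRIVER_CHECK macro defined above.
  C10_CUDA_DRIVER_CHECK(api->cuMemAddressReserve_(
      ptr, size, /*alignment=*/0, /*addr=*/0, /*flags=*/0));
}
```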