Mirror of https://github.com/pytorch/pytorch.git
Introduce 3 low-latency, intra-node allreduce algorithms for small messages to PyTorch (#114001)
## Summary

This PR adds 3 intra-node GPU allreduce algorithms to PyTorch:
- One-shot allreduce (inspired by FasterTransformer): all ranks simultaneously read and accumulate data from other ranks.
- Two-shot allreduce (inspired by FasterTransformer): all ranks simultaneously read and accumulate `1 / world_size` of the data from other ranks, then all ranks read the accumulated data from other ranks (effectively a one-shot reduce-scatter + a one-shot all-gather).
- Hybrid cube mesh allreduce (original): a one-shot allreduce variant that avoids transmission over PCIe on HCM topologies.

## Micro Benchmarks

(Benchmark charts from the original PR description are not reproduced here.)

## Details

The intra-node algos are organized behind `c10d::IntraNodeComm`, which is responsible for:
- Managing handshaking and CUDA IPC handle exchange among ranks.
- Querying NVLink connections and detecting the topology.
- Performing algo selection based on the available info.
- Launching the selected allreduce kernel.

`c10d::IntraNodeComm` is integrated into `c10d::ProcessGroupNCCL` as follows:
- When the `ENABLE_INTRA_NODE_COMM` environment variable is set, `c10d::ProcessGroupNCCL` initializes a `c10d::IntraNodeComm` for its ranks.
- If the setup is not suitable for intra-node comm (e.g. not all ranks are from the same node), the rendezvous logic guarantees all participants fall back consistently.
- `c10d::ProcessGroupNCCL::allreduce` consults `c10d::IntraNodeComm` on whether to use intra-node allreduce and carries out the communication accordingly.

We currently detect two types of topologies from the NVLink connection mesh:
- Fully connected: all GPU pairs have a direct NVLink connection (e.g. NVSwitch, or a fully connected sub-set of a hybrid cube mesh).
  - `msg <= 256KB`: one-shot allreduce.
  - `256KB < msg <= 10MB`: two-shot allreduce.
  - `msg > 10MB`: instructs the caller to fall back to NCCL.
- Hybrid cube mesh:
  - `msg <= 256KB`: one-shot allreduce.
  - `msg > 256KB`: instructs the caller to fall back to NCCL.

## Next Steps

- Fine-tune algo selection based on GPU model, topology, and link speed.
- Potentially optimize the two-shot allreduce impl. According to FasterTransformer, two-shot allreduce is preferred up to 50MB. There might be room for improvement, but PyTorch does impose more constraints:
  - FasterTransformer uses a single process to drive multiple devices. It can use `cudaDeviceEnablePeerAccess` to enable device-level peer access.
  - PyTorch uses multiple processes to drive multiple devices. With CUDA IPC, a device can only share a specific memory region with other devices. This means extra copies may be unavoidable.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/114001
Approved by: https://github.com/yf225
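As a usage illustration (not part of this commit), here is a minimal sketch of how the new path would be exercised from Python. It assumes a single node with at least 2 NVLink-connected sm80+ GPUs and a launcher such as `torchrun` that sets `RANK`/`WORLD_SIZE`/`MASTER_ADDR`/`MASTER_PORT`; the choice between one-shot, two-shot, and the NCCL fallback then happens inside `ProcessGroupNCCL::allreduce` based on message size:

```python
# Hedged usage sketch: opting into the intra-node allreduce path.
# Assumes one node, >=2 NVLink-connected GPUs (sm80+), and env vars
# RANK/WORLD_SIZE/MASTER_ADDR/MASTER_PORT set by a launcher (e.g. torchrun).
import os
import torch
import torch.distributed as dist

os.environ["ENABLE_INTRA_NODE_COMM"] = "1"  # enable c10d::IntraNodeComm

def main():
    rank = int(os.environ["RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
    torch.cuda.set_device(rank)
    dist.init_process_group("nccl", rank=rank, world_size=world_size)

    # A small bf16 SUM allreduce is eligible for the one-shot/two-shot kernels;
    # messages above 10MB, other dtypes, or other reduce ops fall back to NCCL.
    t = torch.full((128 * 1024,), float(rank), dtype=torch.bfloat16, device="cuda")
    dist.all_reduce(t, op=dist.ReduceOp.SUM)
    torch.cuda.synchronize()

    dist.destroy_process_group()

if __name__ == "__main__":
    main()
```

Whether the intra-node kernel was actually taken can be verified (as the new test in this PR does) via `torch._C._distributed_c10d._get_intra_node_comm_usage_counter()`.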
Commit 4edc921857 (parent cd47e335d1), committed by PyTorch MergeBot.
@@ -1452,7 +1452,10 @@ cu_library(
    # https://github.com/pytorch/pytorch/issues/79236
    # To solve it we add it into the `caffe2_cuda`,
    # this is also aligned with the CMake build.
    srcs = [":caffe2_cu_srcs"] + ["torch/csrc/distributed/c10d/quantization/quantization_gpu.cu"],
    srcs = [":caffe2_cu_srcs"] + [
        "torch/csrc/distributed/c10d/intra_node_comm.cu",
        "torch/csrc/distributed/c10d/quantization/quantization_gpu.cu",
    ],
    copts = CAFFE2_COPTS + torch_cuda_half_options,
    visibility = ["//visibility:public"],
    deps = [
@@ -1619,6 +1622,7 @@ cc_library(
        exclude = [
            "torch/csrc/cuda/python_nccl.cpp",
            "torch/csrc/cuda/nccl.cpp",
            "torch/csrc/distributed/c10d/intra_node_comm.cu",
            "torch/csrc/distributed/c10d/quantization/quantization_gpu.cu",
        ],
    )) + torch_sources,
@@ -674,6 +674,8 @@ libtorch_cuda_distributed_extra_sources = [
    "torch/csrc/distributed/c10d/ProcessGroupUCC.cpp",
    "torch/csrc/distributed/c10d/UCCTracing.cpp",
    "torch/csrc/distributed/c10d/UCCUtils.cpp",
    "torch/csrc/distributed/c10d/intra_node_comm.cpp",
    "torch/csrc/distributed/c10d/intra_node_comm.cu",
    "torch/csrc/distributed/rpc/tensorpipe_cuda.cpp",
    "torch/csrc/distributed/c10d/quantization/quantization_gpu.cu",
]
@@ -37,7 +37,7 @@ void* DriverAPI::get_nvml_handle() {
  return nvml_hanle;
}

DriverAPI* DriverAPI::get() {
C10_EXPORT DriverAPI* DriverAPI::get() {
  static DriverAPI singleton = create_driver_api();
  return &singleton;
}
@@ -28,9 +28,11 @@
  _(cuMemCreate) \
  _(cuGetErrorString)

#define C10_NVML_DRIVER_API(_) \
  _(nvmlInit_v2) \
  _(nvmlDeviceGetHandleByPciBusId_v2) \
#define C10_NVML_DRIVER_API(_) \
  _(nvmlInit_v2) \
  _(nvmlDeviceGetHandleByPciBusId_v2) \
  _(nvmlDeviceGetNvLinkRemoteDeviceType) \
  _(nvmlDeviceGetNvLinkRemotePciInfo_v2) \
  _(nvmlDeviceGetComputeRunningProcesses)

namespace c10 {
@@ -641,6 +641,10 @@ if(USE_CUDA)
      append_filelist("libtorch_cuda_distributed_base_sources" Caffe2_GPU_SRCS)
      if(NOT WIN32)
        append_filelist("libtorch_cuda_distributed_extra_sources" Caffe2_GPU_SRCS)
        set_source_files_properties(
          ${TORCH_SRC_DIR}/csrc/distributed/c10d/intra_node_comm.cpp
          PROPERTIES COMPILE_FLAGS "-DPYTORCH_C10_DRIVER_API_SUPPORTED=1"
        )
      endif()
    endif()
    set_source_files_properties(
@@ -15,7 +15,7 @@ import warnings
from contextlib import contextmanager
from datetime import datetime, timedelta
from itertools import chain, product
from unittest import mock
from unittest import SkipTest, mock

import torch
import torch.distributed as c10d
@@ -3113,6 +3113,65 @@ class CommTest(test_c10d_common.AbstractCommTest, MultiProcessTestCase):
        for i, t in enumerate(tensors):
            self.assertEqual(t, torch.full_like(t, self.world_size * (i + (self.world_size + 1.) / 2.)))

    @requires_nccl()
    @skip_if_lt_x_gpu(2)
    @skip_if_rocm
    def test_intra_node_comm_all_reduce(self):
        from torch._C._distributed_c10d import _get_intra_node_comm_usage_counter
        from torch.testing._internal.common_cuda import SM80OrLater
        for peer in range(self.world_size):
            if peer == self.rank:
                continue
            if not torch._C._cuda_canDeviceAccessPeer(self.rank, peer):
                raise SkipTest("Test requires p2p access")

        if not SM80OrLater:
            raise SkipTest("Test requires sm>=80")

        store = c10d.FileStore(self.file_name, self.world_size)
        os.environ["ENABLE_INTRA_NODE_COMM"] = "1"
        os.environ["TEST_INTRA_NODE_COMM"] = "1"
        torch.cuda.set_device(self.rank)
        c10d.init_process_group(
            backend="nccl", rank=self.rank, world_size=self.world_size, store=store
        )
        expect = self.world_size * (self.world_size - 1) // 2

        # IntraNodeComm currently only supports sum and bf16.
        # Verify that it is not used in the next two configurations.
        t = torch.full((4 * 1024 // 2,), self.rank).cuda()
        c10d.all_reduce(t, c10d.ReduceOp.SUM)
        self.assertTrue(t.eq(expect).all())
        self.assertEqual(_get_intra_node_comm_usage_counter(), 0)

        t = torch.full((4 * 1024 // 2,), self.rank, dtype=torch.bfloat16).cuda()
        c10d.all_reduce(t, c10d.ReduceOp.AVG)
        self.assertEqual(_get_intra_node_comm_usage_counter(), 0)

        # Verify that IntraNodeComm is used up to 10MB
        t = torch.full((4 * 1024 // 2,), self.rank, dtype=torch.bfloat16).cuda()
        c10d.all_reduce(t, c10d.ReduceOp.SUM)
        self.assertTrue(t.eq(expect).all())
        self.assertEqual(_get_intra_node_comm_usage_counter(), 1)

        t = torch.full((512 * 1024 // 2,), self.rank, dtype=torch.bfloat16).cuda()
        c10d.all_reduce(t, c10d.ReduceOp.SUM)
        self.assertTrue(t.eq(expect).all())
        self.assertEqual(_get_intra_node_comm_usage_counter(), 2)

        t = torch.full((10 * 1024 ** 2 // 2,), self.rank, dtype=torch.bfloat16).cuda()
        c10d.all_reduce(t, c10d.ReduceOp.SUM)
        self.assertTrue(t.eq(expect).all())
        self.assertEqual(_get_intra_node_comm_usage_counter(), 3)

        # Verify that IntraNodeComm is not used beyond 10MB
        t = torch.full((10 * 1024 ** 2 // 2 + 1,), self.rank, dtype=torch.bfloat16).cuda()
        c10d.all_reduce(t, c10d.ReduceOp.SUM)
        self.assertTrue(t.eq(expect).all())
        self.assertEqual(_get_intra_node_comm_usage_counter(), 3)

        c10d.destroy_process_group()

    @requires_nccl()
    @skip_if_lt_x_gpu(2)
    def test_sequence_num_set_default_pg_nccl(self):
@@ -712,7 +712,8 @@ ProcessGroupNCCL::ProcessGroupNCCL(
      terminateProcessGroup_(false),
      terminateHeartbeatMonitorThread_(false),
      collectiveDebugInfoMode_(false),
      uid_(process_group_id++) {
      uid_(process_group_id++),
      intraNodeComm_(initIntraNodeComm()) {
  TORCH_CHECK_WITH(
      ValueError,
      at::cuda::getNumGPUs() != 0,
@@ -895,6 +896,12 @@ void ProcessGroupNCCL::performNocolorSplit(at::Device device) {
#endif
}

c10::intrusive_ptr<intra_node_comm::IntraNodeComm> ProcessGroupNCCL::
    initIntraNodeComm() {
  return intra_node_comm::IntraNodeComm::rendezvous(
      store_, std::to_string(uid_), rank_, size_);
}

void ProcessGroupNCCL::runHealthCheck() {
  // Run health check in a separate thread and wait on CV to handle timeouts,
  // since majority of getNCCLComm failures are hangs.
@@ -2802,6 +2809,16 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::allreduce_impl(
c10::intrusive_ptr<Work> ProcessGroupNCCL::allreduce(
    std::vector<at::Tensor>& tensors,
    const AllreduceOptions& opts) {
  if (intraNodeComm_ != nullptr && tensors.size() == 1 &&
      opts.reduceOp == ReduceOp::SUM) {
    using namespace intra_node_comm;
    auto algo = intraNodeComm_->selectAllReduceAlgo(tensors[0]);
    if (algo != intra_node_comm::AllReduceAlgo::NONE) {
      intraNodeComm_->allReduce(tensors[0], algo);
      return c10::make_intrusive<IntraNodeCommWork>();
    }
  }

  check_gpu_tensors_different_devices(tensors);

  // @lint-ignore CLANGTIDY
@@ -13,6 +13,7 @@
#include <torch/csrc/distributed/c10d/Backend.hpp>
#include <torch/csrc/distributed/c10d/NCCLUtils.hpp>
#include <torch/csrc/distributed/c10d/Store.hpp>
#include <torch/csrc/distributed/c10d/intra_node_comm.hpp>

#include <ATen/DynamicLibrary.h>
#include <ATen/cuda/CUDAContext.h>
@@ -546,6 +547,8 @@ class TORCH_API ProcessGroupNCCL : public Backend {
  // Provide an API for users to define their own ways to store NCCL debug info.
  void registerDebugInfoWriter(std::unique_ptr<DebugInfoWriter> writer);

  c10::intrusive_ptr<intra_node_comm::IntraNodeComm> initIntraNodeComm();

  // Provides an API to abort the ProcessGroup (similar to ncclCommAbort)
  // instead of relying on ProcessGroupNCCL destructor.
  void abort(c10::optional<std::string> abortReason = c10::nullopt);
@@ -940,6 +943,8 @@ class TORCH_API ProcessGroupNCCL : public Backend {
  std::unique_ptr<DebugInfoWriter> debugInfoWriter_ = nullptr;

  size_t uid_;

  c10::intrusive_ptr<intra_node_comm::IntraNodeComm> intraNodeComm_;
};

TORCH_API std::string dump_nccl_trace();
@@ -21,6 +21,7 @@
#ifdef USE_C10D_NCCL
#include <torch/csrc/distributed/c10d/NCCLUtils.hpp>
#include <torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp>
#include <torch/csrc/distributed/c10d/intra_node_comm.hpp>
#endif

#ifdef USE_C10D_MPI
@@ -2328,6 +2329,10 @@ options :class:`~torch.distributed.ProcessGroupNCCL.Options`).
          "perform_nocolor_split",
          &::c10d::ProcessGroupNCCL::performNocolorSplit);

  module.def(
      "_get_intra_node_comm_usage_counter",
      &::c10d::intra_node_comm::getIntraNodeCommUsageCounter);

#ifdef NCCL_HAS_COMM_CTA_CGA
  py::class_<ncclConfig_t>(
      processGroupNCCL,
torch/csrc/distributed/c10d/intra_node_comm.cpp (new file, 448 lines)
@@ -0,0 +1,448 @@
|
||||
#include <torch/csrc/distributed/c10d/intra_node_comm.hpp>
|
||||
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
#include <c10/util/Logging.h>
|
||||
#include <torch/csrc/distributed/c10d/Utils.hpp>
|
||||
|
||||
#include <iostream>
|
||||
#include <random>
|
||||
|
||||
#include <fcntl.h>
|
||||
#include <pthread.h>
|
||||
#include <semaphore.h>
|
||||
#include <sys/mman.h>
|
||||
#include <sys/stat.h>
|
||||
#include <unistd.h>
|
||||
|
||||
#if !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED)
|
||||
#include <c10/cuda/driver_api.h>
|
||||
#include <nvml.h>
|
||||
#endif
|
||||
|
||||
#include <cuda_runtime.h>
|
||||
|
||||
namespace c10d {
|
||||
namespace intra_node_comm {
|
||||
|
||||
static std::vector<std::string> ENABLE_INTRA_NODE_COMM = {
|
||||
"ENABLE_INTRA_NODE_COMM"};
|
||||
// Forces detectTopology() to return Topology::FULLY_CONNECTED, so
|
||||
// IntraNodeComm can be used even without NVLink connection. This is only used
|
||||
// for testing purposes.
|
||||
static std::vector<std::string> TEST_INTRA_NODE_COMM = {"TEST_INTRA_NODE_COMM"};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
// CUDA Functions
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
bool isIntraNodeCommSupported();
|
||||
|
||||
std::optional<HybridCubeMesh> getHybridCubeMesh(NvlMesh nvlMesh);
|
||||
|
||||
void* initP2pState();
|
||||
|
||||
void* initTopoInfo(Topology topology, NvlMesh nvlMesh, size_t rank);
|
||||
|
||||
AllReduceAlgo selectAllReduceAlgo(
|
||||
const at::Tensor& input,
|
||||
Topology topology,
|
||||
size_t worldSize);
|
||||
|
||||
at::Tensor allReduce(
|
||||
const at::Tensor& input,
|
||||
std::array<void*, kMaxDevices> p2pStates,
|
||||
std::array<void*, kMaxDevices> buffers,
|
||||
void* topoInfo,
|
||||
size_t rank,
|
||||
size_t worldSize,
|
||||
AllReduceAlgo algo,
|
||||
at::cuda::CUDAStream& stream);
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
// Topology Detection
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// TODO: find a better way to determine this
|
||||
static constexpr size_t kMaxNvLinks = 20;
|
||||
|
||||
static std::ostream& operator<<(std::ostream& os, const NvlMesh& nvlMesh) {
|
||||
std::ostringstream oss;
|
||||
for (size_t i = 0; i < kMaxDevices; ++i) {
|
||||
for (size_t j = 0; j < kMaxDevices; ++j) {
|
||||
oss << nvlMesh[i][j] << " ";
|
||||
}
|
||||
oss << std::endl;
|
||||
}
|
||||
os << oss.str();
|
||||
return os;
|
||||
}
|
||||
|
||||
static bool isSame(NvlMesh lhs, NvlMesh rhs) {
|
||||
for (size_t i = 0; i < kMaxDevices; ++i) {
|
||||
for (size_t j = 0; j < kMaxDevices; ++j) {
|
||||
if (lhs[i][j] != rhs[i][j]) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
/**
|
||||
* Query the nvlink connection among devices.
|
||||
*/
|
||||
static NvlMesh getNvlMesh(std::vector<std::string> rankToBusId) {
|
||||
#if !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED)
|
||||
using namespace c10::cuda;
|
||||
|
||||
NvlMesh nvlMesh = {};
|
||||
auto driverApi = DriverAPI::get();
|
||||
if (driverApi == nullptr) {
|
||||
return nvlMesh;
|
||||
}
|
||||
|
||||
const auto worldSize = rankToBusId.size();
|
||||
std::vector<nvmlDevice_t> devices(worldSize, 0);
|
||||
std::unordered_map<std::string, size_t> busIdToRank;
|
||||
std::vector<size_t> switchLinkCount(worldSize, 0);
|
||||
|
||||
for (size_t r = 0; r < worldSize; ++r) {
|
||||
busIdToRank.emplace(std::make_pair(rankToBusId[r], r));
|
||||
TORCH_CHECK(
|
||||
driverApi->nvmlDeviceGetHandleByPciBusId_v2_(
|
||||
rankToBusId[r].c_str(), &devices[r]) == NVML_SUCCESS);
|
||||
}
|
||||
|
||||
// For each device, loop over devices connected to it via NVLink
|
||||
for (size_t idx = 0; idx < worldSize; ++idx) {
|
||||
for (size_t link = 0; link < kMaxNvLinks; ++link) {
|
||||
nvmlReturn_t ret;
|
||||
nvmlIntNvLinkDeviceType_t deviceType;
|
||||
ret = driverApi->nvmlDeviceGetNvLinkRemoteDeviceType_(
|
||||
devices[idx], link, &deviceType);
|
||||
if (ret != NVML_SUCCESS) {
|
||||
// We've exhausted the NVLinks connected to this device.
|
||||
// This error is benign. There doesn't seem to be a reliable
|
||||
// way to obtain the maximum link value that can be passed to
|
||||
// the API, so we simply increment the link value until the
|
||||
// API fails or we hit a predefined maximum value.
|
||||
break;
|
||||
}
|
||||
// Remote device is GPU
|
||||
if (deviceType == NVML_NVLINK_DEVICE_TYPE_GPU) {
|
||||
nvmlPciInfo_t pciInfo;
|
||||
ret = driverApi->nvmlDeviceGetNvLinkRemotePciInfo_v2_(
|
||||
devices[idx], link, &pciInfo);
|
||||
if (ret != NVML_SUCCESS) {
|
||||
// Unexpected error. Return an empty NvlMesh
|
||||
return {};
|
||||
}
|
||||
auto it = busIdToRank.find(pciInfo.busId);
|
||||
if (it != busIdToRank.end()) {
|
||||
if (idx != it->second) {
|
||||
nvlMesh[idx][it->second] += 1;
|
||||
}
|
||||
}
|
||||
// Remote device is NVSwitch
|
||||
} else if (deviceType == NVML_NVLINK_DEVICE_TYPE_SWITCH) {
|
||||
switchLinkCount[idx] += 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
// Process NVSwitch connections. For simplicity, we assume
|
||||
// all NVSwitches are interconnected.
|
||||
for (size_t i = 0; i < worldSize; ++i) {
|
||||
for (size_t j = 0; j < worldSize; ++j) {
|
||||
if (i == j) {
|
||||
continue;
|
||||
}
|
||||
nvlMesh[i][j] += std::min(switchLinkCount[i], switchLinkCount[j]);
|
||||
}
|
||||
}
|
||||
return nvlMesh;
|
||||
#else
|
||||
return {};
|
||||
#endif
|
||||
}
|
||||
|
||||
/**
|
||||
* Determine if the devices form a hybrid cube mesh
|
||||
* topology given a NvlMesh.
|
||||
*/
|
||||
static bool isHybridCubeMesh(const NvlMesh nvlMesh) {
|
||||
std::array<size_t, kMaxDevices> numNeighbors = {};
|
||||
for (size_t i = 0; i < kMaxDevices; ++i) {
|
||||
for (size_t j = 0; j < kMaxDevices; ++j) {
|
||||
if (nvlMesh[i][j] > 0) {
|
||||
numNeighbors[i] += 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
for (size_t i = 0; i < kMaxDevices; ++i) {
|
||||
// TODO: this is insufficient and needs revisiting
|
||||
if (numNeighbors[i] != 4) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
/**
|
||||
* Detect topology given an NvlMesh.
|
||||
*/
|
||||
static Topology detectTopology(const NvlMesh nvlMesh, size_t worldSize) {
|
||||
if (getCvarBool(TEST_INTRA_NODE_COMM, false)) {
|
||||
return Topology::FULLY_CONNECTED;
|
||||
}
|
||||
bool fullyConnected = true;
|
||||
for (size_t i = 0; i < worldSize - 1; ++i) {
|
||||
for (size_t j = i + 1; j < worldSize; ++j) {
|
||||
if (nvlMesh[i][j] == 0 || nvlMesh[j][i] == 0) {
|
||||
fullyConnected = false;
|
||||
}
|
||||
}
|
||||
}
|
||||
if (fullyConnected) {
|
||||
LOG(INFO) << "IntraNodeComm: Topology::FULLY_CONNECTED";
|
||||
return Topology::FULLY_CONNECTED;
|
||||
}
|
||||
if (worldSize == kMaxDevices && getHybridCubeMesh(nvlMesh) != std::nullopt) {
|
||||
LOG(INFO) << "IntraNodeComm: Topology::HYBRID_CUBE_MESH";
|
||||
return Topology::HYBRID_CUBE_MESH;
|
||||
}
|
||||
LOG(INFO) << "IntraNodeComm: Topology::UNKNOWN";
|
||||
return Topology::UNKNOWN;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
// Rendezvous and Initialization
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
IntraNodeComm::IntraNodeComm(
|
||||
Topology topology,
|
||||
std::array<void*, kMaxDevices> p2pStates,
|
||||
std::array<void*, kMaxDevices> buffers,
|
||||
void* topoInfo,
|
||||
size_t rank,
|
||||
size_t worldSize)
|
||||
: topology_(topology),
|
||||
p2pStates_(p2pStates),
|
||||
buffers_(buffers),
|
||||
topoInfo_(topoInfo),
|
||||
rank_(rank),
|
||||
worldSize_(worldSize) {}
|
||||
|
||||
IntraNodeComm::~IntraNodeComm() {
|
||||
// Intentionally releasing resources without synchronizing devices. The
|
||||
// teardown logic is safe for a properly sync'd user program. We don't want
|
||||
// improperly sync'd user program to hang here.
|
||||
for (size_t r = 0; r < worldSize_; ++r) {
|
||||
if (r == rank_) {
|
||||
continue;
|
||||
}
|
||||
AT_CUDA_CHECK(cudaIpcCloseMemHandle(p2pStates_[r]));
|
||||
AT_CUDA_CHECK(cudaIpcCloseMemHandle(buffers_[r]));
|
||||
}
|
||||
AT_CUDA_CHECK(cudaFree(p2pStates_[rank_]));
|
||||
AT_CUDA_CHECK(cudaFree(buffers_[rank_]));
|
||||
if (topoInfo_ != nullptr) {
|
||||
AT_CUDA_CHECK(cudaFree(topoInfo_));
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Use c10d::Store to perform allgather on a trivially copyable type.
|
||||
*/
|
||||
template <typename T>
|
||||
std::vector<T> storeAllGather(
|
||||
c10::intrusive_ptr<c10d::Store> store,
|
||||
const std::string& prefix,
|
||||
size_t rank,
|
||||
size_t worldSize,
|
||||
T val) {
|
||||
static_assert(std::is_trivially_copyable<T>::value);
|
||||
|
||||
std::vector<std::string> peerKeys;
|
||||
for (size_t r = 0; r < worldSize; ++r) {
|
||||
std::ostringstream oss;
|
||||
oss << prefix << "-" << r;
|
||||
peerKeys.push_back(oss.str());
|
||||
}
|
||||
|
||||
{
|
||||
std::vector<uint8_t> payload(
|
||||
reinterpret_cast<uint8_t*>(&val),
|
||||
reinterpret_cast<uint8_t*>(&val) + sizeof(T));
|
||||
store->set(peerKeys[rank], payload);
|
||||
}
|
||||
|
||||
std::vector<T> peerVals;
|
||||
for (size_t r = 0; r < worldSize; ++r) {
|
||||
if (r == rank) {
|
||||
peerVals.push_back(val);
|
||||
continue;
|
||||
}
|
||||
store->wait({peerKeys[r]});
|
||||
auto payload = store->get(peerKeys[r]);
|
||||
TORCH_CHECK(payload.size() == sizeof(T));
|
||||
T peerVal;
|
||||
std::memcpy(&peerVal, payload.data(), sizeof(T));
|
||||
peerVals.push_back(peerVal);
|
||||
}
|
||||
return peerVals;
|
||||
}
|
||||
|
||||
c10::intrusive_ptr<IntraNodeComm> IntraNodeComm::rendezvous(
|
||||
c10::intrusive_ptr<c10d::Store> store,
|
||||
const std::string& prefix,
|
||||
size_t rank,
|
||||
size_t worldSize) {
|
||||
#if !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED)
|
||||
if (!isIntraNodeCommSupported() ||
|
||||
!getCvarBool(ENABLE_INTRA_NODE_COMM, false) || worldSize < 2 ||
|
||||
worldSize > kMaxDevices) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
int deviceIdx = at::cuda::current_device();
|
||||
c10::cuda::CUDAGuard guard(deviceIdx);
|
||||
|
||||
// First hand shake: exchange hostname and device bus ID
|
||||
struct DevInfo {
|
||||
char hostname[HOST_NAME_MAX + 1];
|
||||
char busId[80];
|
||||
};
|
||||
|
||||
DevInfo devInfo{};
|
||||
gethostname(devInfo.hostname, sizeof(devInfo.hostname));
|
||||
cudaDeviceProp prop{};
|
||||
AT_CUDA_CHECK(cudaGetDeviceProperties(&prop, deviceIdx));
|
||||
snprintf(
|
||||
devInfo.busId,
|
||||
sizeof(devInfo.busId),
|
||||
NVML_DEVICE_PCI_BUS_ID_FMT,
|
||||
prop.pciDomainID,
|
||||
prop.pciBusID,
|
||||
prop.pciDeviceID);
|
||||
|
||||
auto peerDevInfos = storeAllGather(
|
||||
store, prefix + "-IntraNodeCommHandShake-0", rank, worldSize, devInfo);
|
||||
|
||||
std::vector<std::string> rankToBusId;
|
||||
for (const auto& info : peerDevInfos) {
|
||||
if (strcmp(info.hostname, peerDevInfos.front().hostname) != 0) {
|
||||
LOG(WARNING) << "Aborting IntraNodeComm::rendezvous because some "
|
||||
"participants are not on the same host ("
|
||||
<< info.hostname << ", " << devInfo.hostname << ")";
|
||||
return nullptr;
|
||||
}
|
||||
rankToBusId.emplace_back(info.busId);
|
||||
}
|
||||
|
||||
// Verify unique devices
|
||||
{
|
||||
std::unordered_set uniqueBusIds(rankToBusId.begin(), rankToBusId.end());
|
||||
TORCH_CHECK(
|
||||
uniqueBusIds.size() == worldSize,
|
||||
"IntraNodeComm::rendezvous: detected overlapping devices across ranks. "
|
||||
"Please properly set device via torch.cuda.set_device() before "
|
||||
"initiating rendezvous.");
|
||||
}
|
||||
|
||||
// Query nvlink connection
|
||||
auto nvlMesh = getNvlMesh(rankToBusId);
|
||||
|
||||
// Detect topology
|
||||
Topology topology = detectTopology(nvlMesh, worldSize);
|
||||
|
||||
// Initialize p2p state
|
||||
auto p2pState = initP2pState();
|
||||
|
||||
// Allocate buffer
|
||||
void* buffer = nullptr;
|
||||
AT_CUDA_CHECK(cudaMalloc(&buffer, kMaxIntraNodeSize * 2));
|
||||
|
||||
// Second handshake: exchange topology and CUDA IPC handles
|
||||
struct IpcInfo {
|
||||
NvlMesh nvlMesh;
|
||||
Topology topology;
|
||||
cudaIpcMemHandle_t p2pStateHandle, bufferHandle;
|
||||
};
|
||||
|
||||
// Make p2p state and buffer available for IPC
|
||||
cudaIpcMemHandle_t p2pStateHandle, bufferHandle;
|
||||
AT_CUDA_CHECK(cudaIpcGetMemHandle(&p2pStateHandle, p2pState));
|
||||
AT_CUDA_CHECK(cudaIpcGetMemHandle(&bufferHandle, buffer));
|
||||
|
||||
IpcInfo ipcInfo{
|
||||
.nvlMesh = nvlMesh,
|
||||
.topology = topology,
|
||||
.p2pStateHandle = p2pStateHandle,
|
||||
.bufferHandle = bufferHandle};
|
||||
|
||||
auto peerIpcInfos = storeAllGather(
|
||||
store, prefix + "-IntraNodeCommHandShake-2", rank, worldSize, ipcInfo);
|
||||
|
||||
for (const auto& info : peerIpcInfos) {
|
||||
if (!isSame(info.nvlMesh, peerIpcInfos.front().nvlMesh) ||
|
||||
info.topology != peerIpcInfos.front().topology) {
|
||||
LOG(WARNING) << "Aborting IntraNodeComm::rendezvous because some "
|
||||
"participants are observing different topologies ("
|
||||
<< int(info.topology) << " and " << int(topology) << ")";
|
||||
AT_CUDA_CHECK(cudaFree(p2pState));
|
||||
AT_CUDA_CHECK(cudaFree(buffer));
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
std::array<void*, kMaxDevices> p2pStates = {}, buffers = {};
|
||||
for (size_t r = 0; r < peerIpcInfos.size(); ++r) {
|
||||
if (r == rank) {
|
||||
p2pStates[r] = p2pState;
|
||||
buffers[r] = buffer;
|
||||
} else {
|
||||
AT_CUDA_CHECK(cudaIpcOpenMemHandle(
|
||||
&p2pStates[r],
|
||||
peerIpcInfos[r].p2pStateHandle,
|
||||
cudaIpcMemLazyEnablePeerAccess));
|
||||
AT_CUDA_CHECK(cudaIpcOpenMemHandle(
|
||||
&buffers[r],
|
||||
peerIpcInfos[r].bufferHandle,
|
||||
cudaIpcMemLazyEnablePeerAccess));
|
||||
}
|
||||
}
|
||||
void* topoInfo = initTopoInfo(topology, nvlMesh, rank);
|
||||
return c10::make_intrusive<IntraNodeComm>(
|
||||
topology, p2pStates, buffers, topoInfo, rank, worldSize);
|
||||
#else
|
||||
return nullptr;
|
||||
#endif
|
||||
}
|
||||
|
||||
AllReduceAlgo IntraNodeComm::selectAllReduceAlgo(const at::Tensor& input) {
|
||||
return c10d::intra_node_comm::selectAllReduceAlgo(
|
||||
input, topology_, worldSize_);
|
||||
}
|
||||
|
||||
static int64_t usageCounter = 0;
|
||||
|
||||
at::Tensor IntraNodeComm::allReduce(
|
||||
const at::Tensor& input,
|
||||
AllReduceAlgo algo) {
|
||||
// Report usage for testing purposes.
|
||||
// We don't care about overflowing.
|
||||
++usageCounter;
|
||||
auto stream = at::cuda::getCurrentCUDAStream();
|
||||
c10::cuda::CUDACachingAllocator::recordStream(
|
||||
input.storage().data_ptr(), stream);
|
||||
return c10d::intra_node_comm::allReduce(
|
||||
input, p2pStates_, buffers_, topoInfo_, rank_, worldSize_, algo, stream);
|
||||
}
|
||||
|
||||
int64_t getIntraNodeCommUsageCounter() {
|
||||
return usageCounter;
|
||||
}
|
||||
|
||||
} // namespace intra_node_comm
|
||||
} // namespace c10d
|
torch/csrc/distributed/c10d/intra_node_comm.cu (new file, 708 lines)
@@ -0,0 +1,708 @@
|
||||
#include <torch/csrc/distributed/c10d/intra_node_comm.hpp>
|
||||
|
||||
#include <ATen/Dispatch.h>
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
|
||||
namespace c10d {
|
||||
namespace intra_node_comm {
|
||||
|
||||
static constexpr size_t kBytesPerThread = 16;
|
||||
static constexpr size_t kMaxAllReduceBlocks = 24;
|
||||
static constexpr size_t kThreadsPerBlock = 1024;
|
||||
static constexpr size_t kWarpSize = 32;
|
||||
|
||||
static constexpr size_t kHcmThreshBytes = 256 * 1024;
|
||||
static constexpr size_t kOneShotThreshBytes = 256 * 1024;
|
||||
static constexpr size_t kTwoShotThreshBytes = 10 * 1024 * 1024;
|
||||
|
||||
#if defined(USE_ROCM)
|
||||
using __nv_bfloat162 = uint32_t;
|
||||
#endif
|
||||
|
||||
struct __align__(16) bf16x8 {
|
||||
__nv_bfloat162 vals[4];
|
||||
};
|
||||
|
||||
#define DEVICE_INLINE __device__ inline __attribute__((always_inline))
|
||||
|
||||
DEVICE_INLINE __nv_bfloat162
|
||||
bf16hadd2(const __nv_bfloat162 x, const __nv_bfloat162 y) {
|
||||
#if defined(USE_ROCM) || (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800))
|
||||
CUDA_KERNEL_ASSERT(false);
|
||||
#else
|
||||
return __hadd2(x, y);
|
||||
#endif
|
||||
}
|
||||
|
||||
DEVICE_INLINE bf16x8 add_bf16x8(bf16x8 a, bf16x8 b) {
|
||||
bf16x8 c;
|
||||
c.vals[0] = bf16hadd2(a.vals[0], b.vals[0]);
|
||||
c.vals[1] = bf16hadd2(a.vals[1], b.vals[1]);
|
||||
c.vals[2] = bf16hadd2(a.vals[2], b.vals[2]);
|
||||
c.vals[3] = bf16hadd2(a.vals[3], b.vals[3]);
|
||||
return c;
|
||||
}
|
||||
|
||||
/**
|
||||
* NOTE [cross device memory synchronization]
|
||||
*
|
||||
* The multi-stage algorithms (e.g. two-shot, hcm allreduce) require the writes
|
||||
* of a thread to be visible by threads with the same block/thread ID on other
|
||||
* devices. To satisfy CUDA's memory consistency model, every thread has to
|
||||
* release its writes at the system scope, and the consuming thread has to
|
||||
* acquire the writes at the system scope. This incurs high overhead and
|
||||
* attempts at optimizing this process can be prone to race conditions.
|
||||
*
|
||||
* Instead, we go around caching by having each thread:
|
||||
*
|
||||
* - Directly write to global memory via st.cs (cache-streaming).
|
||||
* - Synchronize with threads within the block.
|
||||
* - Perform cross device synchronization at block level (via system scope
|
||||
* atomic ops).
|
||||
* - Synchronize with threads within the block.
|
||||
* - Directly read from global memory via ld.nc (non-coherent/non-cached).
|
||||
*/
|
||||
template <typename T>
|
||||
DEVICE_INLINE void streamLoad128(bf16x8& val, const T* addr) {
|
||||
#if defined(USE_ROCM) || (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800))
|
||||
CUDA_KERNEL_ASSERT(false);
|
||||
#else
|
||||
unsigned long long int low, high;
|
||||
asm("ld.global.nc.v2.u64 {%0, %1}, [%2];"
|
||||
: "=l"(low), "=l"(high)
|
||||
: "l"(addr));
|
||||
reinterpret_cast<unsigned long long int*>(&val)[0] = low;
|
||||
reinterpret_cast<unsigned long long int*>(&val)[1] = high;
|
||||
#endif
|
||||
}
|
||||
|
||||
__device__ inline void streamStore128(at::BFloat16* addr, const bf16x8& val) {
|
||||
#if defined(USE_ROCM) || (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800))
|
||||
CUDA_KERNEL_ASSERT(false);
|
||||
#else
|
||||
unsigned long long int low, high;
|
||||
low = reinterpret_cast<const unsigned long long int*>(&val)[0];
|
||||
high = reinterpret_cast<const unsigned long long int*>(&val)[1];
|
||||
asm("st.global.cs.v2.u64 [%0], {%1, %2};" : : "l"(addr), "l"(low), "l"(high));
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
DEVICE_INLINE void load128(bf16x8& val, const T* addr) {
|
||||
*reinterpret_cast<uint4*>(&val) = reinterpret_cast<const uint4*>(addr)[0];
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
DEVICE_INLINE void store128(T* addr, const bf16x8& val) {
|
||||
*reinterpret_cast<uint4*>(addr) = reinterpret_cast<const uint4*>(&val)[0];
|
||||
}
|
||||
|
||||
DEVICE_INLINE void releaseSignal(uint32_t* addr) {
|
||||
#if defined(USE_ROCM) || (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800))
|
||||
CUDA_KERNEL_ASSERT(false);
|
||||
#else
|
||||
atomicAdd_system(addr, 1);
|
||||
#endif
|
||||
}
|
||||
|
||||
DEVICE_INLINE void acquireSignal(uint32_t* addr) {
|
||||
#if defined(USE_ROCM) || (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800))
|
||||
CUDA_KERNEL_ASSERT(false);
|
||||
#else
|
||||
volatile uint32_t* signal = addr;
|
||||
uint32_t val;
|
||||
do {
|
||||
val = *signal;
|
||||
} while (val == 0 || atomicCAS_system(addr, val, val - 1) != val);
|
||||
#endif
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
// Fully Connected Algos
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
struct P2pState {
|
||||
uint32_t signals0[kMaxAllReduceBlocks][kMaxDevices];
|
||||
uint32_t signals1[kMaxAllReduceBlocks][kMaxDevices];
|
||||
};
|
||||
|
||||
template <uint32_t kWorldSize, bool kAligned>
|
||||
static __global__ void oneShotAllReduceKernel(
|
||||
at::BFloat16* input,
|
||||
size_t N,
|
||||
size_t N_aligned,
|
||||
std::array<P2pState*, kMaxDevices> p2pStates,
|
||||
std::array<at::BFloat16*, kMaxDevices> buffers,
|
||||
size_t rank) {
|
||||
const size_t numelPerThread = kBytesPerThread / sizeof(at::BFloat16);
|
||||
const size_t offset =
|
||||
(blockDim.x * blockIdx.x + threadIdx.x) * numelPerThread;
|
||||
const size_t stride = blockDim.x * gridDim.x * numelPerThread;
|
||||
|
||||
// Wait for all other ranks to enter the kernel
|
||||
if (threadIdx.x < kWorldSize) {
|
||||
auto targetRank = threadIdx.x;
|
||||
releaseSignal(&p2pStates[targetRank]->signals0[blockIdx.x][rank]);
|
||||
acquireSignal(&p2pStates[rank]->signals0[blockIdx.x][targetRank]);
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// The source pointers. Distributed round-robin for the different warps
|
||||
const at::BFloat16* srcs[kWorldSize];
|
||||
#pragma unroll kWorldSize
|
||||
for (int ii = 0; ii < kWorldSize; ++ii) {
|
||||
int srcRank = (rank + ii) % kWorldSize;
|
||||
srcs[ii] = buffers[srcRank];
|
||||
}
|
||||
|
||||
for (size_t i = offset; i < N_aligned; i += stride) {
|
||||
bf16x8 vals[kWorldSize];
|
||||
#pragma unroll kWorldSize
|
||||
for (size_t ii = 0; ii < kWorldSize; ++ii) {
|
||||
streamLoad128(vals[ii], &srcs[ii][i]);
|
||||
}
|
||||
|
||||
bf16x8 sums;
|
||||
memset(reinterpret_cast<void*>(&sums), 0, sizeof(sums));
|
||||
|
||||
#pragma unroll kWorldSize
|
||||
for (size_t ii = 0; ii < kWorldSize; ++ii) {
|
||||
sums = add_bf16x8(sums, vals[ii]);
|
||||
}
|
||||
if constexpr (kAligned) {
|
||||
streamStore128(&input[i], sums);
|
||||
} else {
|
||||
for (size_t ii = 0; ii < numelPerThread; ++ii) {
|
||||
if (i + ii < N) {
|
||||
input[i + ii] = reinterpret_cast<at::BFloat16*>(&sums)[ii];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <uint32_t kWorldSize>
|
||||
static __launch_bounds__(1024) __global__ void twoShotAllReduceKernel(
|
||||
at::BFloat16* input,
|
||||
size_t N_aligned,
|
||||
std::array<P2pState*, kMaxDevices> p2pStates,
|
||||
std::array<at::BFloat16*, kMaxDevices> buffers,
|
||||
size_t rank) {
|
||||
const size_t numelPerThread = kBytesPerThread / sizeof(at::BFloat16);
|
||||
const size_t offset =
|
||||
(blockDim.x * blockIdx.x + threadIdx.x) * numelPerThread;
|
||||
const size_t stride = blockDim.x * gridDim.x * numelPerThread;
|
||||
const size_t N_per_rank = N_aligned / kWorldSize;
|
||||
const size_t N_start = N_per_rank * rank;
|
||||
|
||||
// Wait for all other ranks to enter the kernel
|
||||
if (threadIdx.x < kWorldSize) {
|
||||
auto targetRank = threadIdx.x;
|
||||
releaseSignal(&p2pStates[targetRank]->signals0[blockIdx.x][rank]);
|
||||
acquireSignal(&p2pStates[rank]->signals0[blockIdx.x][targetRank]);
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// The source pointers. Distributed round-robin for the different warps
|
||||
at::BFloat16* srcs[kWorldSize];
|
||||
size_t srcRanks[kWorldSize];
|
||||
#pragma unroll kWorldSize
|
||||
for (int ii = 0; ii < kWorldSize; ++ii) {
|
||||
int srcRank = (rank + ii) % kWorldSize;
|
||||
srcs[ii] = buffers[srcRank];
|
||||
srcRanks[ii] = srcRank;
|
||||
}
|
||||
|
||||
for (size_t i = offset; i < N_per_rank; i += stride) {
|
||||
bf16x8 vals[kWorldSize];
|
||||
#pragma unroll kWorldSize
|
||||
for (size_t ii = 0; ii < kWorldSize; ++ii) {
|
||||
streamLoad128(vals[ii], &srcs[ii][N_start + i]);
|
||||
}
|
||||
|
||||
bf16x8 sums;
|
||||
memset(reinterpret_cast<void*>(&sums), 0, sizeof(sums));
|
||||
|
||||
#pragma unroll kWorldSize
|
||||
for (size_t ii = 0; ii < kWorldSize; ++ii) {
|
||||
sums = add_bf16x8(sums, vals[ii]);
|
||||
}
|
||||
streamStore128(&srcs[0][N_start + i], sums);
|
||||
// Store local sums into input now so we can avoid
|
||||
// a global memory access later for it.
|
||||
streamStore128(&input[N_start + i], sums);
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
if (threadIdx.x < kWorldSize) {
|
||||
auto targetRank = threadIdx.x;
|
||||
releaseSignal(&p2pStates[targetRank]->signals1[blockIdx.x][rank]);
|
||||
acquireSignal(&p2pStates[rank]->signals1[blockIdx.x][targetRank]);
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
for (size_t i = offset; i < N_per_rank; i += stride) {
|
||||
#pragma unroll kWorldSize - 1
|
||||
for (size_t ii = 1; ii < kWorldSize; ++ii) {
|
||||
size_t k = N_start + i + (srcRanks[ii] - rank) * N_per_rank;
|
||||
bf16x8 val;
|
||||
streamLoad128(val, &srcs[ii][k]);
|
||||
streamStore128(&input[k], val);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
// Hybrid Cube Mesh Algos
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/**
|
||||
* NOTE [hybrid cube mesh]
|
||||
*
|
||||
* In a hybrid cube mesh topology, every device has exactly 4 neighbors
|
||||
* (directly connected via NVLink). For every device X, it has exactly 1
|
||||
* neighbor Y that is a neighbor of the 3 non-neighbors of X. We call Y the
|
||||
* relay neighbor of X. This property is symmetrical: X is also guaranteed to
|
||||
* be the relay neighbor of Y.
|
||||
*
|
||||
* With this property, we can perform a variant of one-shot allreduce algo that
|
||||
* only moves data across NVLinks:
|
||||
*
|
||||
* - Each device performs a one-shot allreduce among itself and its 3 non-relay neighbors.
|
||||
* - Each device exchanges data with its relay neighbor.
|
||||
*
|
||||
* HybridCubeMesh is a data structure for describing the topology:
|
||||
*
|
||||
* - hcm[X][0:3] are the 3 neighbors of X.
|
||||
* - hcm[X][3] is the relay neighbor of X.
|
||||
* - For load balancing purpose, we also ensure that if hcm[X][k] = Y,
|
||||
* hcm[Y][k] = X.
|
||||
*/
|
||||
std::optional<HybridCubeMesh> getHybridCubeMesh(NvlMesh nvlMesh) {
|
||||
std::array<std::unordered_set<size_t>, kMaxDevices> neighbors = {};
|
||||
std::array<size_t, kMaxDevices> neighborMasks = {};
|
||||
for (size_t i = 0; i < kMaxDevices; ++i) {
|
||||
for (size_t j = 0; j < kMaxDevices; ++j) {
|
||||
if (nvlMesh[i][j] > 0) {
|
||||
neighbors[i].insert(j);
|
||||
neighborMasks[i] |= (1ul << j);
|
||||
}
|
||||
}
|
||||
}
|
||||
HybridCubeMesh hcm = {};
|
||||
for (auto& row : hcm) {
|
||||
row.fill(-1);
|
||||
}
|
||||
// A topology is an HCM if:
|
||||
// - Every device has exactly 4 neighbors.
|
||||
// - For every device, it has exactly 1 relay neighbor that is
|
||||
// a neighbor of the 3 non-neighbor of the device.
|
||||
for (size_t i = 0; i < kMaxDevices; ++i) {
|
||||
if (neighbors[i].size() != 4) {
|
||||
return std::nullopt;
|
||||
}
|
||||
// Condition 1: check the number of neighbors
|
||||
std::vector<size_t> relayNeighbors;
|
||||
for (size_t j = 0; j < kMaxDevices; ++j) {
|
||||
if ((neighborMasks[i] & neighborMasks[j]) == 0) {
|
||||
relayNeighbors.push_back(j);
|
||||
}
|
||||
}
|
||||
// Condition 2: check the number of relay neighbors
|
||||
if (relayNeighbors.size() != 1) {
|
||||
return std::nullopt;
|
||||
}
|
||||
neighbors[i].erase(relayNeighbors[0]);
|
||||
hcm[i][3] = relayNeighbors[0];
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < kMaxDevices; ++i) {
|
||||
for (size_t k = 0; k < 3; ++k) {
|
||||
// We can only fill hcm[i][k] with j if hcm[j][k] is not filled
|
||||
for (size_t j : neighbors[i]) {
|
||||
if (hcm[j][k] == -1) {
|
||||
hcm[i][k] = j;
|
||||
hcm[j][k] = i;
|
||||
break;
|
||||
}
|
||||
}
|
||||
TORCH_CHECK(hcm[i][k] != -1);
|
||||
neighbors[i].erase(hcm[i][k]);
|
||||
}
|
||||
}
|
||||
return hcm;
|
||||
}
|
||||
|
||||
template <bool kAligned>
|
||||
static __global__ void hybridCubeMeshAllReduceKernel(
|
||||
at::BFloat16* input,
|
||||
size_t N,
|
||||
size_t N_aligned,
|
||||
std::array<P2pState*, kMaxDevices> p2pStates,
|
||||
std::array<at::BFloat16*, kMaxDevices> buffers,
|
||||
int hcmInfo[4],
|
||||
size_t rank) {
|
||||
const size_t numelPerThread = kBytesPerThread / sizeof(at::BFloat16);
|
||||
const size_t offset =
|
||||
(blockDim.x * blockIdx.x + threadIdx.x) * numelPerThread;
|
||||
const size_t stride = blockDim.x * gridDim.x * numelPerThread;
|
||||
const int relayRank = hcmInfo[3];
|
||||
|
||||
// Wait for HCM neighbors to enter the kernel
|
||||
if (threadIdx.x < 3) {
|
||||
auto targetRank = hcmInfo[threadIdx.x];
|
||||
releaseSignal(&p2pStates[targetRank]->signals0[blockIdx.x][rank]);
|
||||
acquireSignal(&p2pStates[rank]->signals0[blockIdx.x][targetRank]);
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
const at::BFloat16* srcs[4] = {
|
||||
buffers[rank],
|
||||
buffers[hcmInfo[0]],
|
||||
buffers[hcmInfo[1]],
|
||||
buffers[hcmInfo[2]],
|
||||
};
|
||||
at::BFloat16* localRelay = buffers[rank] + kMaxIntraNodeSize / 2;
|
||||
at::BFloat16* remoteRelay = buffers[relayRank] + kMaxIntraNodeSize / 2;
|
||||
|
||||
for (size_t i = offset; i < N_aligned; i += stride) {
|
||||
bf16x8 vals[4];
|
||||
|
||||
#pragma unroll 4
|
||||
for (size_t ii = 0; ii < 4; ++ii) {
|
||||
streamLoad128(vals[ii], &srcs[ii][i]);
|
||||
}
|
||||
|
||||
bf16x8 sums;
|
||||
memset(reinterpret_cast<void*>(&sums), 0, sizeof(sums));
|
||||
|
||||
#pragma unroll 4
|
||||
for (size_t ii = 0; ii < 4; ++ii) {
|
||||
sums = add_bf16x8(sums, vals[ii]);
|
||||
}
|
||||
// Cached store for local sums
|
||||
store128(&localRelay[i], sums);
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
if (threadIdx.x == 0) {
|
||||
releaseSignal(&p2pStates[relayRank]->signals0[blockIdx.x][rank]);
|
||||
acquireSignal(&p2pStates[rank]->signals0[blockIdx.x][relayRank]);
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
for (size_t i = offset; i < N_aligned; i += stride) {
|
||||
bf16x8 localSum, remoteSum;
|
||||
// Cached load for local sums
|
||||
load128(localSum, &localRelay[i]);
|
||||
streamLoad128(remoteSum, &remoteRelay[i]);
|
||||
localSum = add_bf16x8(localSum, remoteSum);
|
||||
if constexpr (kAligned) {
|
||||
streamStore128(&input[i], localSum);
|
||||
} else {
|
||||
for (size_t ii = 0; ii < numelPerThread; ++ii) {
|
||||
if (i + ii < N) {
|
||||
input[i + ii] = reinterpret_cast<at::BFloat16*>(&localSum)[ii];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
static inline size_t divUp(uint32_t a, uint32_t b) {
|
||||
return (a + b - 1) / b;
|
||||
}
|
||||
|
||||
static inline size_t alignUp(uint32_t a, uint32_t b) {
|
||||
return divUp(a, b) * b;
|
||||
}
|
||||
|
||||
static void checkInput(const at::Tensor& input, size_t rank) {
|
||||
TORCH_CHECK(
|
||||
input.dtype() == at::kBFloat16,
|
||||
"oneShotAllReduce only supports bf16 for now");
|
||||
TORCH_CHECK(input.is_non_overlapping_and_dense());
|
||||
TORCH_CHECK(input.device().is_cuda());
|
||||
TORCH_CHECK(static_cast<size_t>(input.get_device()) == rank);
|
||||
}
|
||||
|
||||
static void getLaunchConfig(
|
||||
size_t N_aligned,
|
||||
size_t elemSize,
|
||||
dim3& blocks,
|
||||
dim3& threads) {
|
||||
blocks = dim3(0, 1, 1);
|
||||
threads = dim3(0, 1, 1);
|
||||
|
||||
const auto numelPerThread = kBytesPerThread / elemSize;
|
||||
const auto numelPerWarp = numelPerThread * kWarpSize;
|
||||
TORCH_CHECK(N_aligned % numelPerThread == 0);
|
||||
TORCH_CHECK(N_aligned % numelPerWarp == 0);
|
||||
if (N_aligned < numelPerThread * kThreadsPerBlock) {
|
||||
threads.x = N_aligned / numelPerWarp * kWarpSize;
|
||||
blocks.x = 1;
|
||||
} else {
|
||||
auto warpsRequired = N_aligned / numelPerWarp;
|
||||
auto threadsRequired = N_aligned / numelPerThread;
|
||||
blocks.x =
|
||||
std::min(divUp(threadsRequired, kThreadsPerBlock), kMaxAllReduceBlocks);
|
||||
auto warpsPerBlock = divUp(warpsRequired, blocks.x);
|
||||
threads.x = std::min(kThreadsPerBlock, warpsPerBlock * kWarpSize);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
static auto castArr(std::array<void*, kMaxDevices> arr) {
|
||||
std::array<T, kMaxDevices> arr_;
|
||||
for (size_t i = 0; i < kMaxDevices; ++i) {
|
||||
arr_[i] = reinterpret_cast<T>(arr[i]);
|
||||
}
|
||||
return arr_;
|
||||
}
|
||||
|
||||
bool isIntraNodeCommSupported() {
|
||||
#if defined(USE_ROCM) || (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800))
|
||||
return false;
|
||||
#else
|
||||
return true;
|
||||
#endif
|
||||
}
|
||||
|
||||
void* initP2pState() {
|
||||
void* state = nullptr;
|
||||
AT_CUDA_CHECK(cudaMalloc(&state, sizeof(P2pState)));
|
||||
AT_CUDA_CHECK(cudaMemset(state, 0, sizeof(P2pState)));
|
||||
return state;
|
||||
}
|
||||
|
||||
void* initTopoInfo(Topology topology, NvlMesh nvlMesh, size_t rank) {
|
||||
void* topoInfo = nullptr;
|
||||
if (topology != Topology::HYBRID_CUBE_MESH) {
|
||||
return topoInfo;
|
||||
}
|
||||
auto hcm = getHybridCubeMesh(nvlMesh);
|
||||
int hcmInfo[4];
|
||||
std::copy((*hcm)[rank].begin(), (*hcm)[rank].begin() + 4, hcmInfo);
|
||||
AT_CUDA_CHECK(cudaMalloc(&topoInfo, sizeof(hcmInfo)));
|
||||
AT_CUDA_CHECK(
|
||||
cudaMemcpy(topoInfo, hcmInfo, sizeof(hcmInfo), cudaMemcpyHostToDevice));
|
||||
return topoInfo;
|
||||
}
|
||||
|
||||
at::Tensor oneShotAllReduce(
|
||||
const at::Tensor& input,
|
||||
std::array<void*, kMaxDevices> p2pStates,
|
||||
std::array<void*, kMaxDevices> buffers,
|
||||
size_t rank,
|
||||
size_t worldSize,
|
||||
at::cuda::CUDAStream& stream) {
|
||||
checkInput(input, rank);
|
||||
|
||||
size_t numelPerWarp = kBytesPerThread / input.element_size() * kWarpSize;
|
||||
size_t N_aligned = alignUp(input.numel(), numelPerWarp);
|
||||
TORCH_CHECK(N_aligned <= kMaxIntraNodeSize / input.element_size());
|
||||
|
||||
dim3 blocks, threads;
|
||||
getLaunchConfig(N_aligned, input.element_size(), blocks, threads);
|
||||
|
||||
at::cuda::OptionalCUDAGuard guard(input.get_device());
|
||||
AT_CUDA_CHECK(cudaMemcpyAsync(
|
||||
buffers[rank],
|
||||
input.data_ptr(),
|
||||
input.numel() * input.element_size(),
|
||||
cudaMemcpyDeviceToDevice,
|
||||
stream));
|
||||
|
||||
#define X(kWorldSize, kAligned) \
|
||||
if (worldSize == kWorldSize) { \
|
||||
oneShotAllReduceKernel<kWorldSize, kAligned> \
|
||||
<<<blocks, threads, 0, stream>>>( \
|
||||
input.data_ptr<at::BFloat16>(), \
|
||||
input.numel(), \
|
||||
N_aligned, \
|
||||
castArr<P2pState*>(p2pStates), \
|
||||
castArr<at::BFloat16*>(buffers), \
|
||||
rank); \
|
||||
C10_CUDA_KERNEL_LAUNCH_CHECK(); \
|
||||
}
|
||||
|
||||
#define DISPATCH_ALL_WORLD_SIZES(kAligned) \
|
||||
X(2, kAligned); \
|
||||
X(3, kAligned); \
|
||||
X(4, kAligned); \
|
||||
X(5, kAligned); \
|
||||
X(6, kAligned); \
|
||||
X(7, kAligned); \
|
||||
X(8, kAligned);
|
||||
|
||||
if (N_aligned == static_cast<size_t>(input.numel())) {
|
||||
DISPATCH_ALL_WORLD_SIZES(true);
|
||||
} else {
|
||||
DISPATCH_ALL_WORLD_SIZES(false);
|
||||
}
|
||||
|
||||
#undef DISPATCH_ALL_WORLD_SIZES
|
||||
#undef X
|
||||
return input;
|
||||
}
|
||||
|
||||
at::Tensor twoShotAllReduce(
|
||||
const at::Tensor& input,
|
||||
std::array<void*, kMaxDevices> p2pStates,
|
||||
std::array<void*, kMaxDevices> buffers,
|
||||
size_t rank,
|
||||
size_t worldSize,
|
||||
at::cuda::CUDAStream& stream) {
|
||||
checkInput(input, rank);
|
||||
|
||||
size_t numelPerWarp = kBytesPerThread / input.element_size() * kWarpSize;
|
||||
size_t N_aligned = alignUp(input.numel(), worldSize * numelPerWarp);
|
||||
size_t N_per_rank = N_aligned / worldSize;
|
||||
TORCH_CHECK(N_aligned <= kMaxIntraNodeSize / input.element_size());
|
||||
|
||||
dim3 blocks, threads;
|
||||
getLaunchConfig(N_per_rank, input.element_size(), blocks, threads);
|
||||
|
||||
auto output = N_aligned == static_cast<size_t>(input.numel())
|
||||
? input
|
||||
: input.new_empty(N_aligned);
|
||||
|
||||
at::cuda::OptionalCUDAGuard guard(input.get_device());
|
||||
AT_CUDA_CHECK(cudaMemcpyAsync(
|
||||
buffers[rank],
|
||||
input.data_ptr(),
|
||||
input.numel() * input.element_size(),
|
||||
cudaMemcpyDeviceToDevice,
|
||||
stream));
|
||||
|
||||
#define X(kWorldSize) \
|
||||
if (worldSize == kWorldSize) { \
|
||||
twoShotAllReduceKernel<kWorldSize><<<blocks, threads, 0, stream>>>( \
|
||||
output.data_ptr<at::BFloat16>(), \
|
||||
N_aligned, \
|
||||
castArr<P2pState*>(p2pStates), \
|
||||
castArr<at::BFloat16*>(buffers), \
|
||||
rank); \
|
||||
C10_CUDA_KERNEL_LAUNCH_CHECK(); \
|
||||
}
|
||||
X(2);
|
||||
X(3);
|
||||
X(4);
|
||||
X(5);
|
||||
X(6);
|
||||
X(7);
|
||||
X(8);
|
||||
#undef X
|
||||
|
||||
if (output.data_ptr() != input.data_ptr()) {
|
||||
AT_CUDA_CHECK(cudaMemcpyAsync(
|
||||
input.data_ptr(),
|
||||
output.data_ptr(),
|
||||
input.numel() * input.element_size(),
|
||||
cudaMemcpyDeviceToDevice,
|
||||
stream));
|
||||
}
|
||||
return input;
|
||||
}
|
||||
|
||||
at::Tensor hybridCubeMeshAllReduce(
|
||||
const at::Tensor& input,
|
||||
std::array<void*, kMaxDevices> p2pStates,
|
||||
std::array<void*, kMaxDevices> buffers,
|
||||
int hcmInfo[4],
|
||||
size_t rank,
|
||||
size_t worldSize,
|
||||
at::cuda::CUDAStream& stream) {
|
||||
checkInput(input, rank);
|
||||
|
||||
size_t numelPerWarp = kBytesPerThread / input.element_size() * kWarpSize;
|
||||
size_t N_aligned = alignUp(input.numel(), numelPerWarp);
|
||||
TORCH_CHECK(N_aligned <= kMaxIntraNodeSize / input.element_size());
|
||||
|
||||
dim3 blocks, threads;
|
||||
getLaunchConfig(N_aligned, input.element_size(), blocks, threads);
|
||||
|
||||
at::cuda::OptionalCUDAGuard guard(input.get_device());
|
||||
AT_CUDA_CHECK(cudaMemcpyAsync(
|
||||
buffers[rank],
|
||||
input.data_ptr(),
|
||||
input.numel() * input.element_size(),
|
||||
cudaMemcpyDeviceToDevice,
|
||||
stream));
|
||||
|
||||
#define X(kAligned) \
|
||||
hybridCubeMeshAllReduceKernel<kAligned><<<blocks, threads, 0, stream>>>( \
|
||||
input.data_ptr<at::BFloat16>(), \
|
||||
input.numel(), \
|
||||
N_aligned, \
|
||||
castArr<P2pState*>(p2pStates), \
|
||||
castArr<at::BFloat16*>(buffers), \
|
||||
hcmInfo, \
|
||||
rank); \
|
||||
C10_CUDA_KERNEL_LAUNCH_CHECK();
|
||||
|
||||
if (N_aligned == static_cast<size_t>(input.numel())) {
|
||||
X(true);
|
||||
} else {
|
||||
X(false);
|
||||
}
|
||||
#undef X
|
||||
return input;
|
||||
}
|
||||
|
||||
AllReduceAlgo selectAllReduceAlgo(
|
||||
const at::Tensor& input,
|
||||
Topology topology,
|
||||
size_t worldSize) {
|
||||
// Only support bf16 for now
|
||||
if (input.dtype() != at::kBFloat16 ||
|
||||
input.numel() * input.element_size() > kMaxIntraNodeSize) {
|
||||
return AllReduceAlgo::NONE;
|
||||
}
|
||||
const auto numel = input.numel();
|
||||
const auto numelPerWarp = kBytesPerThread / input.element_size() * kWarpSize;
|
||||
if (topology == Topology::HYBRID_CUBE_MESH) {
|
||||
TORCH_CHECK(
|
||||
worldSize == 8, "hyperCubeAllReduce only supports exactly 8 GPUs");
|
||||
if (alignUp(numel, numelPerWarp) <= kHcmThreshBytes) {
|
||||
return AllReduceAlgo::HCM;
|
||||
}
|
||||
}
|
||||
if (topology == Topology::FULLY_CONNECTED) {
|
||||
if (alignUp(numel, numelPerWarp) <= kOneShotThreshBytes) {
|
||||
return AllReduceAlgo::ONE_SHOT;
|
||||
}
|
||||
if (alignUp(numel, numelPerWarp * worldSize) <= kTwoShotThreshBytes) {
|
||||
return AllReduceAlgo::TWO_SHOT;
|
||||
}
|
||||
}
|
||||
return AllReduceAlgo::NONE;
|
||||
}
|
||||
|
||||
at::Tensor allReduce(
|
||||
const at::Tensor& input,
|
||||
std::array<void*, kMaxDevices> p2pStates,
|
||||
std::array<void*, kMaxDevices> buffers,
|
||||
void* topoInfo,
|
||||
size_t rank,
|
||||
size_t worldSize,
|
||||
AllReduceAlgo algo,
|
||||
at::cuda::CUDAStream& stream) {
|
||||
switch (algo) {
|
||||
case AllReduceAlgo::ONE_SHOT:
|
||||
return oneShotAllReduce(
|
||||
input, p2pStates, buffers, rank, worldSize, stream);
|
||||
case AllReduceAlgo::TWO_SHOT:
|
||||
return twoShotAllReduce(
|
||||
input, p2pStates, buffers, rank, worldSize, stream);
|
||||
case AllReduceAlgo::HCM:
|
||||
return hybridCubeMeshAllReduce(
|
||||
input, p2pStates, buffers, (int*)topoInfo, rank, worldSize, stream);
|
||||
default:
|
||||
C10_THROW_ERROR(ValueError, "IntraNodeComm: invalid algo");
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace intra_node_comm
|
||||
} // namespace c10d
|
torch/csrc/distributed/c10d/intra_node_comm.hpp (new file, 102 lines)
@@ -0,0 +1,102 @@
|
||||
#pragma once
|
||||
|
||||
#include <ATen/ATen.h>
|
||||
#include <ATen/cuda/CUDAEvent.h>
|
||||
#include <c10/cuda/CUDAStream.h>
|
||||
#include <torch/csrc/distributed/c10d/Store.hpp>
|
||||
#include <torch/csrc/distributed/c10d/Work.hpp>
|
||||
|
||||
namespace c10d {
|
||||
namespace intra_node_comm {
|
||||
|
||||
constexpr size_t kMaxDevices = 8;
|
||||
constexpr size_t kMaxIntraNodeSize = 10 * 1024 * 1024;
|
||||
|
||||
using NvlMesh = std::array<std::array<size_t, kMaxDevices>, kMaxDevices>;
|
||||
using HybridCubeMesh = std::array<std::array<int, 4>, kMaxDevices>;
|
||||
|
||||
enum class Topology { UNKNOWN = 0, FULLY_CONNECTED = 1, HYBRID_CUBE_MESH = 2 };
|
||||
|
||||
enum class AllReduceAlgo { NONE = 0, ONE_SHOT = 1, TWO_SHOT = 2, HCM = 3 };
|
||||
|
||||
class TORCH_API IntraNodeComm : public c10::intrusive_ptr_target {
|
||||
public:
|
||||
IntraNodeComm(
|
||||
Topology topology,
|
||||
std::array<void*, kMaxDevices> p2pStates,
|
||||
std::array<void*, kMaxDevices> buffers,
|
||||
void* topoInfo,
|
||||
size_t rank,
|
||||
size_t worldSize);
|
||||
|
||||
~IntraNodeComm();
|
||||
|
||||
/**
|
||||
* Rendezvous via a c10d::Store.
|
||||
* This function may return nullptr if intra-node comm is not applicable.
|
||||
* It guarantees that all participants either succeed or abort.
|
||||
*/
|
||||
static c10::intrusive_ptr<IntraNodeComm> rendezvous(
|
||||
c10::intrusive_ptr<c10d::Store> store,
|
||||
const std::string& prefix,
|
||||
size_t rank,
|
||||
size_t worldSize);
|
||||
|
||||
/**
|
||||
* Selects an AllReduceAlgo that we think will outperform NCCL.
|
||||
* Returns AllReduceAlgo::NONE if we don't think we can outperform nccl.
|
||||
*/
|
||||
AllReduceAlgo selectAllReduceAlgo(const at::Tensor& input);
|
||||
|
||||
at::Tensor allReduce(const at::Tensor& input, AllReduceAlgo algo);
|
||||
|
||||
private:
|
||||
Topology topology_;
|
||||
std::array<void*, kMaxDevices> p2pStates_;
|
||||
std::array<void*, kMaxDevices> buffers_;
|
||||
void* topoInfo_;
|
||||
size_t rank_;
|
||||
size_t worldSize_;
|
||||
};
|
||||
|
||||
/**
|
||||
* NOTE [IntraNodeComm Stream Semantics]
|
||||
*
|
||||
* ProcessGroupNCCL launches kernels differently from the conventional PyTorch
|
||||
* CUDA semantics: it always launches collective kernels onto a dedicated
|
||||
* communication stream. Therefore, it needs to:
|
||||
*
|
||||
* - Synchronize the calling stream and the comm stream.
|
||||
* - Ensure the memory safety of the operands (via record_stream or stashing).
|
||||
* - Synchronize the waiting stream with the comm stream.
|
||||
*
|
||||
* Unconditionally performing these tasks makes sense when we expect most of the
|
||||
* communication to benefit from compute/comm overlap. However, IntraNodeComm
|
||||
* primarily aims to optimize small, latency-sensitive, blocking communication,
|
||||
* in which the overhead incurred by the above steps can be quite pronounced.
|
||||
*
|
||||
* Thus, IntraNodeComm follows the conventional PyTorch CUDA semantics and
|
||||
* launches kernels onto the stream specified by the user. Although the user
|
||||
* can perform the necessary synchronization via wait_stream, to provide a UX
|
||||
* consistent with that of ProcessGroupNCCL, the necessary stream
|
||||
* synchronization can also be performed via IntraNodeCommWork::wait().
|
||||
*/
|
||||
class IntraNodeCommWork : public c10d::Work {
|
||||
public:
|
||||
IntraNodeCommWork() : c10d::Work() {
|
||||
event_.record();
|
||||
}
|
||||
|
||||
bool wait(std::chrono::milliseconds timeout = kNoTimeout) override {
|
||||
event_.block(at::cuda::getCurrentCUDAStream());
|
||||
return true;
|
||||
}
|
||||
|
||||
private:
|
||||
at::cuda::CUDAEvent event_;
|
||||
};
|
||||
|
||||
TORCH_API int64_t getIntraNodeCommUsageCounter();
|
||||
|
||||
} // namespace intra_node_comm
|
||||
} // namespace c10d
|