Revert "Introduce 3 low-latency, intra-node allreduce algorithms for small messages to PyTorch (#114001)"

This reverts commit adfbd2b219f4995d3f13870927022b67550f8b0e.

Reverted https://github.com/pytorch/pytorch/pull/114001 on behalf of https://github.com/atalman due to OSSCI oncall, breaks periodic jobs ([comment](https://github.com/pytorch/pytorch/pull/114001#issuecomment-1856539040))
commit 7ecddaef23 (parent 67232199b1)
Author: PyTorch MergeBot
Date: 2023-12-14 20:33:10 +00:00
12 changed files with 6 additions and 1354 deletions


@@ -1452,10 +1452,7 @@ cu_library(
# https://github.com/pytorch/pytorch/issues/79236
# To solve it we add it into the `caffe2_cuda`,
# this is also aligned with the CMake build.
srcs = [":caffe2_cu_srcs"] + [
"torch/csrc/distributed/c10d/intra_node_comm.cu",
"torch/csrc/distributed/c10d/quantization/quantization_gpu.cu",
],
srcs = [":caffe2_cu_srcs"] + ["torch/csrc/distributed/c10d/quantization/quantization_gpu.cu"],
copts = CAFFE2_COPTS + torch_cuda_half_options,
visibility = ["//visibility:public"],
deps = [
@@ -1622,7 +1619,6 @@ cc_library(
exclude = [
"torch/csrc/cuda/python_nccl.cpp",
"torch/csrc/cuda/nccl.cpp",
"torch/csrc/distributed/c10d/intra_node_comm.cu",
"torch/csrc/distributed/c10d/quantization/quantization_gpu.cu",
],
)) + torch_sources,


@@ -674,8 +674,6 @@ libtorch_cuda_distributed_extra_sources = [
"torch/csrc/distributed/c10d/ProcessGroupUCC.cpp",
"torch/csrc/distributed/c10d/UCCTracing.cpp",
"torch/csrc/distributed/c10d/UCCUtils.cpp",
"torch/csrc/distributed/c10d/intra_node_comm.cpp",
"torch/csrc/distributed/c10d/intra_node_comm.cu",
"torch/csrc/distributed/rpc/tensorpipe_cuda.cpp",
"torch/csrc/distributed/c10d/quantization/quantization_gpu.cu",
]


@@ -37,7 +37,7 @@ void* DriverAPI::get_nvml_handle() {
return nvml_hanle;
}
C10_EXPORT DriverAPI* DriverAPI::get() {
DriverAPI* DriverAPI::get() {
static DriverAPI singleton = create_driver_api();
return &singleton;
}


@@ -31,8 +31,6 @@
#define C10_NVML_DRIVER_API(_) \
_(nvmlInit_v2) \
_(nvmlDeviceGetHandleByPciBusId_v2) \
_(nvmlDeviceGetNvLinkRemoteDeviceType) \
_(nvmlDeviceGetNvLinkRemotePciInfo_v2) \
_(nvmlDeviceGetComputeRunningProcesses)
namespace c10 {


@@ -641,10 +641,6 @@ if(USE_CUDA)
append_filelist("libtorch_cuda_distributed_base_sources" Caffe2_GPU_SRCS)
if(NOT WIN32)
append_filelist("libtorch_cuda_distributed_extra_sources" Caffe2_GPU_SRCS)
set_source_files_properties(
${TORCH_SRC_DIR}/csrc/distributed/c10d/intra_node_comm.cpp
PROPERTIES COMPILE_FLAGS "-DPYTORCH_C10_DRIVER_API_SUPPORTED=1"
)
endif()
endif()
set_source_files_properties(


@@ -34,7 +34,6 @@ from test_c10d_common import gpus_for_rank, DoubleGpuNet, ConvNet, ModuleForDdpC
from torch import nn
from torch._C._distributed_c10d import OpType
from torch.nn.parallel import DistributedDataParallel
from torch.testing._internal.common_cuda import SM80OrLater
from torch.testing._internal.common_distributed import (
MultiProcessTestCase,
init_multigpu_helper,
@@ -3114,56 +3113,6 @@ class CommTest(test_c10d_common.AbstractCommTest, MultiProcessTestCase):
for i, t in enumerate(tensors):
self.assertEqual(t, torch.full_like(t, self.world_size * (i + (self.world_size + 1.) / 2.)))
@requires_nccl()
@skip_if_lt_x_gpu(2)
def test_intra_node_comm_all_reduce(self):
if not SM80OrLater:
return
from torch._C._distributed_c10d import _get_intra_node_comm_usage_counter
store = c10d.FileStore(self.file_name, self.world_size)
os.environ["ENABLE_INTRA_NODE_COMM"] = "1"
os.environ["TEST_INTRA_NODE_COMM"] = "1"
torch.cuda.set_device(self.rank)
c10d.init_process_group(
backend="nccl", rank=self.rank, world_size=self.world_size, store=store
)
expect = self.world_size * (self.world_size - 1) // 2
# IntraNodeComm currently only supports sum and bf16.
# Verify that it is not used in the next two configurations.
t = torch.full((4 * 1024 // 2,), self.rank).cuda()
c10d.all_reduce(t, c10d.ReduceOp.SUM)
self.assertTrue(t.eq(expect).all())
self.assertEqual(_get_intra_node_comm_usage_counter(), 0)
t = torch.full((4 * 1024 // 2,), self.rank, dtype=torch.bfloat16).cuda()
c10d.all_reduce(t, c10d.ReduceOp.AVG)
self.assertEqual(_get_intra_node_comm_usage_counter(), 0)
# Verify that IntraNodeComm is used up to 10MB
t = torch.full((4 * 1024 // 2,), self.rank, dtype=torch.bfloat16).cuda()
c10d.all_reduce(t, c10d.ReduceOp.SUM)
self.assertTrue(t.eq(expect).all())
self.assertEqual(_get_intra_node_comm_usage_counter(), 1)
t = torch.full((512 * 1024 // 2,), self.rank, dtype=torch.bfloat16).cuda()
c10d.all_reduce(t, c10d.ReduceOp.SUM)
self.assertTrue(t.eq(expect).all())
self.assertEqual(_get_intra_node_comm_usage_counter(), 2)
t = torch.full((10 * 1024 ** 2 // 2,), self.rank, dtype=torch.bfloat16).cuda()
c10d.all_reduce(t, c10d.ReduceOp.SUM)
self.assertTrue(t.eq(expect).all())
self.assertEqual(_get_intra_node_comm_usage_counter(), 3)
# Verify that IntraNodeComm is not used beyond 10MB
t = torch.full((10 * 1024 ** 2 // 2 + 1,), self.rank, dtype=torch.bfloat16).cuda()
c10d.all_reduce(t, c10d.ReduceOp.SUM)
self.assertTrue(t.eq(expect).all())
self.assertEqual(_get_intra_node_comm_usage_counter(), 3)
c10d.destroy_process_group()
@requires_nccl()
@skip_if_lt_x_gpu(2)
def test_sequence_num_set_default_pg_nccl(self):


@@ -712,8 +712,7 @@ ProcessGroupNCCL::ProcessGroupNCCL(
terminateProcessGroup_(false),
terminateHeartbeatMonitorThread_(false),
collectiveDebugInfoMode_(false),
uid_(process_group_id++),
intraNodeComm_(initIntraNodeComm()) {
uid_(process_group_id++) {
TORCH_CHECK_WITH(
ValueError,
at::cuda::getNumGPUs() != 0,
@@ -896,12 +895,6 @@ void ProcessGroupNCCL::performNocolorSplit(at::Device device) {
#endif
}
c10::intrusive_ptr<intra_node_comm::IntraNodeComm> ProcessGroupNCCL::
initIntraNodeComm() {
return intra_node_comm::IntraNodeComm::rendezvous(
store_, std::to_string(uid_), rank_, size_);
}
void ProcessGroupNCCL::runHealthCheck() {
// Run health check in a separate thread and wait on CV to handle timeouts,
// since majority of getNCCLComm failures are hangs.
@@ -2808,16 +2801,6 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::allreduce_impl(
c10::intrusive_ptr<Work> ProcessGroupNCCL::allreduce(
std::vector<at::Tensor>& tensors,
const AllreduceOptions& opts) {
if (intraNodeComm_ != nullptr && tensors.size() == 1 &&
opts.reduceOp == ReduceOp::SUM) {
using namespace intra_node_comm;
auto algo = intraNodeComm_->selectAllReduceAlgo(tensors[0]);
if (algo != intra_node_comm::AllReduceAlgo::NONE) {
intraNodeComm_->allReduce(tensors[0], algo);
return c10::make_intrusive<IntraNodeCommWork>();
}
}
check_gpu_tensors_different_devices(tensors);
// @lint-ignore CLANGTIDY


@@ -13,7 +13,6 @@
#include <torch/csrc/distributed/c10d/Backend.hpp>
#include <torch/csrc/distributed/c10d/NCCLUtils.hpp>
#include <torch/csrc/distributed/c10d/Store.hpp>
#include <torch/csrc/distributed/c10d/intra_node_comm.hpp>
#include <ATen/DynamicLibrary.h>
#include <ATen/cuda/CUDAContext.h>
@@ -547,8 +546,6 @@ class TORCH_API ProcessGroupNCCL : public Backend {
// Provide an API for users to define their own ways to store NCCL debug info.
void registerDebugInfoWriter(std::unique_ptr<DebugInfoWriter> writer);
c10::intrusive_ptr<intra_node_comm::IntraNodeComm> initIntraNodeComm();
// Provides an API to abort the ProcessGroup (similar to ncclCommAbort)
// instead of relying on ProcessGroupNCCL destructor.
void abort(c10::optional<std::string> abortReason = c10::nullopt);
@@ -946,8 +943,6 @@ class TORCH_API ProcessGroupNCCL : public Backend {
std::unique_ptr<DebugInfoWriter> debugInfoWriter_ = nullptr;
size_t uid_;
c10::intrusive_ptr<intra_node_comm::IntraNodeComm> intraNodeComm_;
};
TORCH_API std::string dump_nccl_trace();


@@ -21,7 +21,6 @@
#ifdef USE_C10D_NCCL
#include <torch/csrc/distributed/c10d/NCCLUtils.hpp>
#include <torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp>
#include <torch/csrc/distributed/c10d/intra_node_comm.hpp>
#endif
#ifdef USE_C10D_MPI
@@ -2329,10 +2328,6 @@ options :class:`~torch.distributed.ProcessGroupNCCL.Options`).
"perform_nocolor_split",
&::c10d::ProcessGroupNCCL::performNocolorSplit);
module.def(
"_get_intra_node_comm_usage_counter",
&::c10d::intra_node_comm::getIntraNodeCommUsageCounter);
#ifdef NCCL_HAS_COMM_CTA_CGA
py::class_<ncclConfig_t>(
processGroupNCCL,


@@ -1,448 +0,0 @@
#include <torch/csrc/distributed/c10d/intra_node_comm.hpp>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/util/Logging.h>
#include <torch/csrc/distributed/c10d/Utils.hpp>
#include <iostream>
#include <random>
#include <fcntl.h>
#include <pthread.h>
#include <semaphore.h>
#include <sys/mman.h>
#include <sys/stat.h>
#include <unistd.h>
#if !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED)
#include <c10/cuda/driver_api.h>
#include <nvml.h>
#endif
#include <cuda_runtime.h>
namespace c10d {
namespace intra_node_comm {
static std::vector<std::string> ENABLE_INTRA_NODE_COMM = {
"ENABLE_INTRA_NODE_COMM"};
// Forces detectTopology() to return Topology::FULLY_CONNECTED, so
// IntraNodeComm can be used even without NVLink connection. This is only used
// for testing purposes.
static std::vector<std::string> TEST_INTRA_NODE_COMM = {"TEST_INTRA_NODE_COMM"};
////////////////////////////////////////////////////////////////////////////////
// CUDA Functions
////////////////////////////////////////////////////////////////////////////////
bool isIntraNodeCommSupported();
std::optional<HybridCubeMesh> getHybridCubeMesh(NvlMesh nvlMesh);
void* initP2pState();
void* initTopoInfo(Topology topology, NvlMesh nvlMesh, size_t rank);
AllReduceAlgo selectAllReduceAlgo(
const at::Tensor& input,
Topology topology,
size_t worldSize);
at::Tensor allReduce(
const at::Tensor& input,
std::array<void*, kMaxDevices> p2pStates,
std::array<void*, kMaxDevices> buffers,
void* topoInfo,
size_t rank,
size_t worldSize,
AllReduceAlgo algo,
at::cuda::CUDAStream& stream);
////////////////////////////////////////////////////////////////////////////////
// Topology Detection
////////////////////////////////////////////////////////////////////////////////
// TODO: find a better way to determine this
static constexpr size_t kMaxNvLinks = 20;
static std::ostream& operator<<(std::ostream& os, const NvlMesh& nvlMesh) {
std::ostringstream oss;
for (size_t i = 0; i < kMaxDevices; ++i) {
for (size_t j = 0; j < kMaxDevices; ++j) {
oss << nvlMesh[i][j] << " ";
}
oss << std::endl;
}
os << oss.str();
return os;
}
static bool isSame(NvlMesh lhs, NvlMesh rhs) {
for (size_t i = 0; i < kMaxDevices; ++i) {
for (size_t j = 0; j < kMaxDevices; ++j) {
if (lhs[i][j] != rhs[i][j]) {
return false;
}
}
}
return true;
}
/**
* Query the NVLink connections among devices.
*/
static NvlMesh getNvlMesh(std::vector<std::string> rankToBusId) {
#if !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED)
using namespace c10::cuda;
NvlMesh nvlMesh = {};
auto driverApi = DriverAPI::get();
if (driverApi == nullptr) {
return nvlMesh;
}
const auto worldSize = rankToBusId.size();
std::vector<nvmlDevice_t> devices(worldSize, 0);
std::unordered_map<std::string, size_t> busIdToRank;
std::vector<size_t> switchLinkCount(worldSize, 0);
for (size_t r = 0; r < worldSize; ++r) {
busIdToRank.emplace(std::make_pair(rankToBusId[r], r));
TORCH_CHECK(
driverApi->nvmlDeviceGetHandleByPciBusId_v2_(
rankToBusId[r].c_str(), &devices[r]) == NVML_SUCCESS);
}
// For each device, loop over devices connected to it via NVLink
for (size_t idx = 0; idx < worldSize; ++idx) {
for (size_t link = 0; link < kMaxNvLinks; ++link) {
nvmlReturn_t ret;
nvmlIntNvLinkDeviceType_t deviceType;
ret = driverApi->nvmlDeviceGetNvLinkRemoteDeviceType_(
devices[idx], link, &deviceType);
if (ret != NVML_SUCCESS) {
// We've exhausted the NVLinks connected to this device.
// This error is benign. There doesn't seem to be a reliable
// way to obtain the maximum link value that can be passed to
// the API, so we simply increment the link value until the
// API fails or we hit a predefined maximum value.
break;
}
// Remote device is GPU
if (deviceType == NVML_NVLINK_DEVICE_TYPE_GPU) {
nvmlPciInfo_t pciInfo;
ret = driverApi->nvmlDeviceGetNvLinkRemotePciInfo_v2_(
devices[idx], link, &pciInfo);
if (ret != NVML_SUCCESS) {
// Unexpected error. Return an empty NvlMesh
return {};
}
auto it = busIdToRank.find(pciInfo.busId);
if (it != busIdToRank.end()) {
if (idx != it->second) {
nvlMesh[idx][it->second] += 1;
}
}
// Remote device is NVSwitch
} else if (deviceType == NVML_NVLINK_DEVICE_TYPE_SWITCH) {
switchLinkCount[idx] += 1;
}
}
}
// Process NVSwitch connections. For simplicity, we assume
// all NVSwitches are interconnected.
for (size_t i = 0; i < worldSize; ++i) {
for (size_t j = 0; j < worldSize; ++j) {
if (i == j) {
continue;
}
nvlMesh[i][j] += std::min(switchLinkCount[i], switchLinkCount[j]);
}
}
return nvlMesh;
#else
return {};
#endif
}
/**
* Determine if the devices form a hybrid cube mesh
* topology given a NvlMesh.
*/
static bool isHybridCubeMesh(const NvlMesh nvlMesh) {
std::array<size_t, kMaxDevices> numNeighbors = {};
for (size_t i = 0; i < kMaxDevices; ++i) {
for (size_t j = 0; j < kMaxDevices; ++j) {
if (nvlMesh[i][j] > 0) {
numNeighbors[i] += 1;
}
}
}
for (size_t i = 0; i < kMaxDevices; ++i) {
// TODO: this is insufficient and needs revisiting
if (numNeighbors[i] != 4) {
return false;
}
}
return true;
}
/**
* Detect topology given a NvlMesh.
*/
static Topology detectTopology(const NvlMesh nvlMesh, size_t worldSize) {
if (getCvarBool(TEST_INTRA_NODE_COMM, false)) {
return Topology::FULLY_CONNECTED;
}
bool fullyConnected = true;
for (size_t i = 0; i < worldSize - 1; ++i) {
for (size_t j = i + 1; j < worldSize; ++j) {
if (nvlMesh[i][j] == 0 || nvlMesh[j][i] == 0) {
fullyConnected = false;
}
}
}
if (fullyConnected) {
LOG(INFO) << "IntraNodeComm: Topology::FULLY_CONNECTED";
return Topology::FULLY_CONNECTED;
}
if (worldSize == kMaxDevices && getHybridCubeMesh(nvlMesh) != std::nullopt) {
LOG(INFO) << "IntraNodeComm: Topology::HYBRID_CUBE_MESH";
return Topology::HYBRID_CUBE_MESH;
}
LOG(INFO) << "IntraNodeComm: Topology::UNKNOWN";
return Topology::UNKNOWN;
};
////////////////////////////////////////////////////////////////////////////////
// Rendezvous and Initialization
////////////////////////////////////////////////////////////////////////////////
IntraNodeComm::IntraNodeComm(
Topology topology,
std::array<void*, kMaxDevices> p2pStates,
std::array<void*, kMaxDevices> buffers,
void* topoInfo,
size_t rank,
size_t worldSize)
: topology_(topology),
p2pStates_(p2pStates),
buffers_(buffers),
topoInfo_(topoInfo),
rank_(rank),
worldSize_(worldSize) {}
IntraNodeComm::~IntraNodeComm() {
// Intentionally releasing resources without synchronizing devices. The
// teardown logic is safe for properly sync'd user programs. We don't want
// improperly sync'd user programs to hang here.
for (size_t r = 0; r < worldSize_; ++r) {
if (r == rank_) {
continue;
}
AT_CUDA_CHECK(cudaIpcCloseMemHandle(p2pStates_[r]));
AT_CUDA_CHECK(cudaIpcCloseMemHandle(buffers_[r]));
}
AT_CUDA_CHECK(cudaFree(p2pStates_[rank_]));
AT_CUDA_CHECK(cudaFree(buffers_[rank_]));
if (topoInfo_ != nullptr) {
AT_CUDA_CHECK(cudaFree(topoInfo_));
}
}
/**
* Use c10d::Store to perform allgather on a trivially copyable type.
*/
template <typename T>
std::vector<T> storeAllGather(
c10::intrusive_ptr<c10d::Store> store,
const std::string& prefix,
size_t rank,
size_t worldSize,
T val) {
static_assert(std::is_trivially_copyable<T>::value);
std::vector<std::string> peerKeys;
for (size_t r = 0; r < worldSize; ++r) {
std::ostringstream oss;
oss << prefix << "-" << r;
peerKeys.push_back(oss.str());
}
{
std::vector<uint8_t> payload(
reinterpret_cast<uint8_t*>(&val),
reinterpret_cast<uint8_t*>(&val) + sizeof(T));
store->set(peerKeys[rank], payload);
}
std::vector<T> peerVals;
for (size_t r = 0; r < worldSize; ++r) {
if (r == rank) {
peerVals.push_back(val);
continue;
}
store->wait({peerKeys[r]});
auto payload = store->get(peerKeys[r]);
TORCH_CHECK(payload.size() == sizeof(T));
T peerVal;
std::memcpy(&peerVal, payload.data(), sizeof(T));
peerVals.push_back(peerVal);
}
return peerVals;
}
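// Not part of the original file: a standalone sketch of the byte round trip
// storeAllGather() performs, with a plain std::map standing in for the
// c10d::Store so it can run without a process group. Struct and key names
// here are illustrative only.
#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <cstring>
#include <map>
#include <string>
#include <vector>

struct DeviceInfo {
  int deviceIdx;
  char busId[16];
};

int main() {
  std::map<std::string, std::vector<uint8_t>> fakeStore;
  const size_t worldSize = 2;
  // "set" phase: every rank publishes its trivially copyable struct as bytes.
  for (size_t rank = 0; rank < worldSize; ++rank) {
    DeviceInfo info{static_cast<int>(rank), {}};
    std::snprintf(info.busId, sizeof(info.busId), "0000:%02zx:00.0", rank);
    fakeStore["handshake-" + std::to_string(rank)] = std::vector<uint8_t>(
        reinterpret_cast<uint8_t*>(&info),
        reinterpret_cast<uint8_t*>(&info) + sizeof(DeviceInfo));
  }
  // "wait/get" phase: every rank reads back each peer's bytes and memcpy's
  // them into a struct, exactly like the store->get()/std::memcpy step above.
  for (size_t rank = 0; rank < worldSize; ++rank) {
    const auto& payload = fakeStore.at("handshake-" + std::to_string(rank));
    DeviceInfo peer{};
    std::memcpy(&peer, payload.data(), sizeof(DeviceInfo));
    std::printf("rank %zu -> busId %s\n", rank, peer.busId);
  }
  return 0;
}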
c10::intrusive_ptr<IntraNodeComm> IntraNodeComm::rendezvous(
c10::intrusive_ptr<c10d::Store> store,
const std::string& prefix,
size_t rank,
size_t worldSize) {
#if !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED)
if (!isIntraNodeCommSupported() ||
!getCvarBool(ENABLE_INTRA_NODE_COMM, false) || worldSize < 2 ||
worldSize > kMaxDevices) {
return nullptr;
}
int deviceIdx = at::cuda::current_device();
c10::cuda::CUDAGuard guard(deviceIdx);
// First hand shake: exchange hostname and device bus ID
struct DevInfo {
char hostname[HOST_NAME_MAX + 1];
char busId[80];
};
DevInfo devInfo{};
gethostname(devInfo.hostname, sizeof(devInfo.hostname));
cudaDeviceProp prop{};
AT_CUDA_CHECK(cudaGetDeviceProperties(&prop, deviceIdx));
snprintf(
devInfo.busId,
sizeof(devInfo.busId),
NVML_DEVICE_PCI_BUS_ID_FMT,
prop.pciDomainID,
prop.pciBusID,
prop.pciDeviceID);
auto peerDevInfos = storeAllGather(
store, prefix + "-IntraNodeCommHandShake-0", rank, worldSize, devInfo);
std::vector<std::string> rankToBusId;
for (const auto& info : peerDevInfos) {
if (strcmp(info.hostname, peerDevInfos.front().hostname) != 0) {
LOG(WARNING) << "Aborting IntraNodeComm::rendezvous because some "
"participants are not on the same host ("
<< info.hostname << ", " << devInfo.hostname << ")";
return nullptr;
}
rankToBusId.emplace_back(info.busId);
}
// Verify unique devices
{
std::unordered_set uniqueBusIds(rankToBusId.begin(), rankToBusId.end());
TORCH_CHECK(
uniqueBusIds.size() == worldSize,
"IntraNodeComm::rendezvous: detected overlapping devices across ranks. "
"Please properly set device via torch.cuda.set_device() before "
"initiating rendezvous.");
}
// Query nvlink connection
auto nvlMesh = getNvlMesh(rankToBusId);
// Detect topology
Topology topology = detectTopology(nvlMesh, worldSize);
// Initialize p2p state
auto p2pState = initP2pState();
// Allocate buffer
void* buffer = nullptr;
AT_CUDA_CHECK(cudaMalloc(&buffer, kMaxIntraNodeSize * 2));
// Second handshake: exchange topology and CUDA IPC handles
struct IpcInfo {
NvlMesh nvlMesh;
Topology topology;
cudaIpcMemHandle_t p2pStateHandle, bufferHandle;
};
// Make p2p state and buffer available for IPC
cudaIpcMemHandle_t p2pStateHandle, bufferHandle;
AT_CUDA_CHECK(cudaIpcGetMemHandle(&p2pStateHandle, p2pState));
AT_CUDA_CHECK(cudaIpcGetMemHandle(&bufferHandle, buffer));
IpcInfo ipcInfo{
.nvlMesh = nvlMesh,
.topology = topology,
.p2pStateHandle = p2pStateHandle,
.bufferHandle = bufferHandle};
auto peerIpcInfos = storeAllGather(
store, prefix + "-IntraNodeCommHandShake-2", rank, worldSize, ipcInfo);
for (const auto& info : peerIpcInfos) {
if (!isSame(info.nvlMesh, peerIpcInfos.front().nvlMesh) ||
info.topology != peerIpcInfos.front().topology) {
LOG(WARNING) << "Aborting IntraNodeComm::rendezvous because some "
"participants are observing different topologies ("
<< int(info.topology) << " and " << int(topology) << ")";
AT_CUDA_CHECK(cudaFree(p2pState));
AT_CUDA_CHECK(cudaFree(buffer));
return nullptr;
}
}
std::array<void*, kMaxDevices> p2pStates = {}, buffers = {};
for (size_t r = 0; r < peerIpcInfos.size(); ++r) {
if (r == rank) {
p2pStates[r] = p2pState;
buffers[r] = buffer;
} else {
AT_CUDA_CHECK(cudaIpcOpenMemHandle(
&p2pStates[r],
peerIpcInfos[r].p2pStateHandle,
cudaIpcMemLazyEnablePeerAccess));
AT_CUDA_CHECK(cudaIpcOpenMemHandle(
&buffers[r],
peerIpcInfos[r].bufferHandle,
cudaIpcMemLazyEnablePeerAccess));
}
}
void* topoInfo = initTopoInfo(topology, nvlMesh, rank);
return c10::make_intrusive<IntraNodeComm>(
topology, p2pStates, buffers, topoInfo, rank, worldSize);
#else
return nullptr;
#endif
}
AllReduceAlgo IntraNodeComm::selectAllReduceAlgo(const at::Tensor& input) {
return c10d::intra_node_comm::selectAllReduceAlgo(
input, topology_, worldSize_);
}
static int64_t usageCounter = 0;
at::Tensor IntraNodeComm::allReduce(
const at::Tensor& input,
AllReduceAlgo algo) {
// Report usage for testing purposes.
// We don't care about overflowing.
++usageCounter;
auto stream = at::cuda::getCurrentCUDAStream();
c10::cuda::CUDACachingAllocator::recordStream(
input.storage().data_ptr(), stream);
return c10d::intra_node_comm::allReduce(
input, p2pStates_, buffers_, topoInfo_, rank_, worldSize_, algo, stream);
}
int64_t getIntraNodeCommUsageCounter() {
return usageCounter;
}
} // namespace intra_node_comm
} // namespace c10d

View File

@@ -1,708 +0,0 @@
#include <torch/csrc/distributed/c10d/intra_node_comm.hpp>
#include <ATen/Dispatch.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
namespace c10d {
namespace intra_node_comm {
static constexpr size_t kBytesPerThread = 16;
static constexpr size_t kMaxAllReduceBlocks = 24;
static constexpr size_t kThreadsPerBlock = 1024;
static constexpr size_t kWarpSize = 32;
static constexpr size_t kHcmThreshBytes = 256 * 1024;
static constexpr size_t kOneShotThreshBytes = 256 * 1024;
static constexpr size_t kTwoShotThreshBytes = 10 * 1024 * 1024;
#if defined(USE_ROCM)
using __nv_bfloat162 = uint32_t;
#endif
struct __align__(16) bf16x8 {
__nv_bfloat162 vals[4];
};
#define DEVICE_INLINE __device__ inline __attribute__((always_inline))
DEVICE_INLINE __nv_bfloat162
bf16hadd2(const __nv_bfloat162 x, const __nv_bfloat162 y) {
#if defined(USE_ROCM) || (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800))
CUDA_KERNEL_ASSERT(false);
#else
return __hadd2(x, y);
#endif
}
DEVICE_INLINE bf16x8 add_bf16x8(bf16x8 a, bf16x8 b) {
bf16x8 c;
c.vals[0] = bf16hadd2(a.vals[0], b.vals[0]);
c.vals[1] = bf16hadd2(a.vals[1], b.vals[1]);
c.vals[2] = bf16hadd2(a.vals[2], b.vals[2]);
c.vals[3] = bf16hadd2(a.vals[3], b.vals[3]);
return c;
}
/**
* NOTE [cross device memory synchronization]
*
* The multi-stage algorithms (e.g. two-shot, hcm allreduce) require the writes
* of a thread to be visible by threads with the same block/thread ID on other
* devices. To satisfy CUDA's memory consistency model, every thread has to
* release its writes at the system scope, and the consuming thread has to
* acquire the writes at the system scope. This incurs high overhead and
* attempts at optimizing this process can be prone to race conditions.
*
* Instead, we go around caching by having each thread:
*
* - Directly write to global memory via st.cs (cache-streaming).
* - Synchronize with threads within the block.
* - Perform cross device synchronization at block level (via system scope
* atomic ops).
* - Synchronize with threads within the block.
* - Directly read from global memory via ld.nc (non-coherent/non-cached).
*/
template <typename T>
DEVICE_INLINE void streamLoad128(bf16x8& val, const T* addr) {
#if defined(USE_ROCM) || (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800))
CUDA_KERNEL_ASSERT(false);
#else
unsigned long long int low, high;
asm("ld.global.nc.v2.u64 {%0, %1}, [%2];"
: "=l"(low), "=l"(high)
: "l"(addr));
reinterpret_cast<unsigned long long int*>(&val)[0] = low;
reinterpret_cast<unsigned long long int*>(&val)[1] = high;
#endif
}
__device__ inline void streamStore128(at::BFloat16* addr, const bf16x8& val) {
#if defined(USE_ROCM) || (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800))
CUDA_KERNEL_ASSERT(false);
#else
unsigned long long int low, high;
low = reinterpret_cast<const unsigned long long int*>(&val)[0];
high = reinterpret_cast<const unsigned long long int*>(&val)[1];
asm("st.global.cs.v2.u64 [%0], {%1, %2};" : : "l"(addr), "l"(low), "l"(high));
#endif
}
template <typename T>
DEVICE_INLINE void load128(bf16x8& val, const T* addr) {
*reinterpret_cast<uint4*>(&val) = reinterpret_cast<const uint4*>(addr)[0];
}
template <typename T>
DEVICE_INLINE void store128(T* addr, const bf16x8& val) {
*reinterpret_cast<uint4*>(addr) = reinterpret_cast<const uint4*>(&val)[0];
}
DEVICE_INLINE void releaseSignal(uint32_t* addr) {
#if defined(USE_ROCM) || (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800))
CUDA_KERNEL_ASSERT(false);
#else
atomicAdd_system(addr, 1);
#endif
}
DEVICE_INLINE void acquireSignal(uint32_t* addr) {
#if defined(USE_ROCM) || (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800))
CUDA_KERNEL_ASSERT(false);
#else
volatile uint32_t* signal = addr;
uint32_t val;
do {
val = *signal;
} while (val == 0 || atomicCAS_system(addr, val, val - 1) != val);
#endif
}
////////////////////////////////////////////////////////////////////////////////
// Fully Connected Algos
////////////////////////////////////////////////////////////////////////////////
struct P2pState {
uint32_t signals0[kMaxAllReduceBlocks][kMaxDevices];
uint32_t signals1[kMaxAllReduceBlocks][kMaxDevices];
};
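// Not part of the original file: a sketch of the block-scope, cross-device
// barrier that the kernels below assemble from releaseSignal/acquireSignal
// (see NOTE [cross device memory synchronization]). Each block posts a signal
// into every peer's P2pState and then spins until every peer has posted back,
// so blocks with the same blockIdx.x rendezvous across all ranks.
DEVICE_INLINE void blockBarrierSketch(
    std::array<P2pState*, kMaxDevices> p2pStates,
    size_t rank,
    size_t worldSize) {
  if (threadIdx.x < worldSize) {
    const auto peer = threadIdx.x;
    // Tell `peer` that this (rank, block) pair has arrived ...
    releaseSignal(&p2pStates[peer]->signals0[blockIdx.x][rank]);
    // ... then wait until `peer` has signaled us in the same way.
    acquireSignal(&p2pStates[rank]->signals0[blockIdx.x][peer]);
  }
  __syncthreads();
}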
template <uint32_t kWorldSize, bool kAligned>
static __global__ void oneShotAllReduceKernel(
at::BFloat16* input,
size_t N,
size_t N_aligned,
std::array<P2pState*, kMaxDevices> p2pStates,
std::array<at::BFloat16*, kMaxDevices> buffers,
size_t rank) {
const size_t numelPerThread = kBytesPerThread / sizeof(at::BFloat16);
const size_t offset =
(blockDim.x * blockIdx.x + threadIdx.x) * numelPerThread;
const size_t stride = blockDim.x * gridDim.x * numelPerThread;
// Wait for all other ranks to enter the kernel
if (threadIdx.x < kWorldSize) {
auto targetRank = threadIdx.x;
releaseSignal(&p2pStates[targetRank]->signals0[blockIdx.x][rank]);
acquireSignal(&p2pStates[rank]->signals0[blockIdx.x][targetRank]);
}
__syncthreads();
// The source pointers. Distributed round-robin for the different warps
const at::BFloat16* srcs[kWorldSize];
#pragma unroll kWorldSize
for (int ii = 0; ii < kWorldSize; ++ii) {
int srcRank = (rank + ii) % kWorldSize;
srcs[ii] = buffers[srcRank];
}
for (size_t i = offset; i < N_aligned; i += stride) {
bf16x8 vals[kWorldSize];
#pragma unroll kWorldSize
for (size_t ii = 0; ii < kWorldSize; ++ii) {
streamLoad128(vals[ii], &srcs[ii][i]);
}
bf16x8 sums;
memset(reinterpret_cast<void*>(&sums), 0, sizeof(sums));
#pragma unroll kWorldSize
for (size_t ii = 0; ii < kWorldSize; ++ii) {
sums = add_bf16x8(sums, vals[ii]);
}
if constexpr (kAligned) {
streamStore128(&input[i], sums);
} else {
for (size_t ii = 0; ii < numelPerThread; ++ii) {
if (i + ii < N) {
input[i + ii] = reinterpret_cast<at::BFloat16*>(&sums)[ii];
}
}
}
}
}
template <uint32_t kWorldSize>
static __launch_bounds__(1024) __global__ void twoShotAllReduceKernel(
at::BFloat16* input,
size_t N_aligned,
std::array<P2pState*, kMaxDevices> p2pStates,
std::array<at::BFloat16*, kMaxDevices> buffers,
size_t rank) {
const size_t numelPerThread = kBytesPerThread / sizeof(at::BFloat16);
const size_t offset =
(blockDim.x * blockIdx.x + threadIdx.x) * numelPerThread;
const size_t stride = blockDim.x * gridDim.x * numelPerThread;
const size_t N_per_rank = N_aligned / kWorldSize;
const size_t N_start = N_per_rank * rank;
// Wait for all other ranks to enter the kernel
if (threadIdx.x < kWorldSize) {
auto targetRank = threadIdx.x;
releaseSignal(&p2pStates[targetRank]->signals0[blockIdx.x][rank]);
acquireSignal(&p2pStates[rank]->signals0[blockIdx.x][targetRank]);
}
__syncthreads();
// The source pointers. Distributed round-robin for the different warps
at::BFloat16* srcs[kWorldSize];
size_t srcRanks[kWorldSize];
#pragma unroll kWorldSize
for (int ii = 0; ii < kWorldSize; ++ii) {
int srcRank = (rank + ii) % kWorldSize;
srcs[ii] = buffers[srcRank];
srcRanks[ii] = srcRank;
}
for (size_t i = offset; i < N_per_rank; i += stride) {
bf16x8 vals[kWorldSize];
#pragma unroll kWorldSize
for (size_t ii = 0; ii < kWorldSize; ++ii) {
streamLoad128(vals[ii], &srcs[ii][N_start + i]);
}
bf16x8 sums;
memset(reinterpret_cast<void*>(&sums), 0, sizeof(sums));
#pragma unroll kWorldSize
for (size_t ii = 0; ii < kWorldSize; ++ii) {
sums = add_bf16x8(sums, vals[ii]);
}
streamStore128(&srcs[0][N_start + i], sums);
// Store local sums into input now so we can avoid
// a global memory access later for it.
streamStore128(&input[N_start + i], sums);
}
__syncthreads();
if (threadIdx.x < kWorldSize) {
auto targetRank = threadIdx.x;
releaseSignal(&p2pStates[targetRank]->signals1[blockIdx.x][rank]);
acquireSignal(&p2pStates[rank]->signals1[blockIdx.x][targetRank]);
}
__syncthreads();
for (size_t i = offset; i < N_per_rank; i += stride) {
#pragma unroll kWorldSize - 1
for (size_t ii = 1; ii < kWorldSize; ++ii) {
size_t k = N_start + i + (srcRanks[ii] - rank) * N_per_rank;
bf16x8 val;
streamLoad128(val, &srcs[ii][k]);
streamStore128(&input[k], val);
}
}
}
////////////////////////////////////////////////////////////////////////////////
// Hybrid Cube Mesh Algos
////////////////////////////////////////////////////////////////////////////////
/**
* NOTE [hybrid cube mesh]
*
* In a hybrid cube mesh topology, every device has exactly 4 neighbors
* (directly connected via NVLink). For every device X, it has exactly 1
* neighbor Y that is a neighbor of the 3 non-neighbors of X. We call Y the
* relay neighbor of X. This property is symmetrical: X is also guaranteed to
* be the relay neighbor of Y.
*
* With this property, we can perform a variant of one-shot allreduce algo that
* only moves data across NVLinks:
*
* - Each device one-shot allreduce among itself and 3 non-relay neighbors.
* - Each device exchange data with its relay neighbor.
*
* HybridCubeMesh is a data structure for describing the topology:
*
* - hcm[X][0:3] are the 3 neighbors of X.
* - hcm[X][3] is the relay neighbor of X.
* - For load balancing purpose, we also ensure that if hcm[X][k] = Y,
* hcm[Y][k] = X.
*/
std::optional<HybridCubeMesh> getHybridCubeMesh(NvlMesh nvlMesh) {
std::array<std::unordered_set<size_t>, kMaxDevices> neighbors = {};
std::array<size_t, kMaxDevices> neighborMasks = {};
for (size_t i = 0; i < kMaxDevices; ++i) {
for (size_t j = 0; j < kMaxDevices; ++j) {
if (nvlMesh[i][j] > 0) {
neighbors[i].insert(j);
neighborMasks[i] |= (1ul << j);
}
}
}
HybridCubeMesh hcm = {};
for (auto& row : hcm) {
row.fill(-1);
}
// A topology is an HCM if:
// - Every device has exactly 4 neighbors.
// - For every device, it has exactly 1 relay neighbor that is
// a neighbor of the 3 non-neighbors of the device.
for (size_t i = 0; i < kMaxDevices; ++i) {
if (neighbors[i].size() != 4) {
return std::nullopt;
}
// Condition 1: check the number of neighbors
std::vector<size_t> relayNeighbors;
for (size_t j = 0; j < kMaxDevices; ++j) {
if ((neighborMasks[i] & neighborMasks[j]) == 0) {
relayNeighbors.push_back(j);
}
}
// Condition 2: check the number of relay neighbors
if (relayNeighbors.size() != 1) {
return std::nullopt;
}
neighbors[i].erase(relayNeighbors[0]);
hcm[i][3] = relayNeighbors[0];
}
for (size_t i = 0; i < kMaxDevices; ++i) {
for (size_t k = 0; k < 3; ++k) {
// We can only fill hcm[i][k] with j if hcm[j][k] is not filled
for (size_t j : neighbors[i]) {
if (hcm[j][k] == -1) {
hcm[i][k] = j;
hcm[j][k] = i;
break;
}
}
TORCH_CHECK(hcm[i][k] != -1);
neighbors[i].erase(hcm[i][k]);
}
}
return hcm;
}
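// Not part of the original file: a self-contained check of the two HCM
// conditions spelled out above, run on a hypothetical 8-GPU wiring made of
// two fully connected quads {0,1,2,3} and {4,5,6,7} plus the cross links
// 0-4, 1-5, 2-6, 3-7. Every GPU then has exactly 4 neighbors, and its
// cross-link partner is its unique relay neighbor (no shared neighbors),
// which is what getHybridCubeMesh() requires.
#include <cstddef>
#include <cstdio>

int main() {
  constexpr size_t kDevices = 8;
  bool adj[kDevices][kDevices] = {};
  auto connect = [&](size_t a, size_t b) { adj[a][b] = adj[b][a] = true; };
  for (size_t q = 0; q < kDevices; q += 4) {
    // Fully connect each quad.
    for (size_t i = q; i < q + 4; ++i) {
      for (size_t j = i + 1; j < q + 4; ++j) {
        connect(i, j);
      }
    }
  }
  for (size_t i = 0; i < 4; ++i) {
    connect(i, i + 4); // cross links; these become the relay pairs
  }
  for (size_t i = 0; i < kDevices; ++i) {
    size_t degree = 0, relayCount = 0, relay = kDevices;
    for (size_t j = 0; j < kDevices; ++j) {
      degree += adj[i][j];
    }
    for (size_t j = 0; j < kDevices; ++j) {
      if (i == j) {
        continue;
      }
      // Same test as (neighborMasks[i] & neighborMasks[j]) == 0 above.
      bool sharesNeighbor = false;
      for (size_t k = 0; k < kDevices; ++k) {
        sharesNeighbor |= adj[i][k] && adj[j][k];
      }
      if (!sharesNeighbor) {
        ++relayCount;
        relay = j;
      }
    }
    // For a hybrid cube mesh we expect degree == 4 and relayCount == 1;
    // for this particular wiring the relay of i is i ^ 4.
    std::printf("gpu %zu: degree=%zu relayCount=%zu relay=%zu\n",
                i, degree, relayCount, relay);
  }
  return 0;
}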
template <bool kAligned>
static __global__ void hybridCubeMeshAllReduceKernel(
at::BFloat16* input,
size_t N,
size_t N_aligned,
std::array<P2pState*, kMaxDevices> p2pStates,
std::array<at::BFloat16*, kMaxDevices> buffers,
int hcmInfo[4],
size_t rank) {
const size_t numelPerThread = kBytesPerThread / sizeof(at::BFloat16);
const size_t offset =
(blockDim.x * blockIdx.x + threadIdx.x) * numelPerThread;
const size_t stride = blockDim.x * gridDim.x * numelPerThread;
const int relayRank = hcmInfo[3];
// Wait for HCM neighbors to enter the kernel
if (threadIdx.x < 3) {
auto targetRank = hcmInfo[threadIdx.x];
releaseSignal(&p2pStates[targetRank]->signals0[blockIdx.x][rank]);
acquireSignal(&p2pStates[rank]->signals0[blockIdx.x][targetRank]);
}
__syncthreads();
const at::BFloat16* srcs[4] = {
buffers[rank],
buffers[hcmInfo[0]],
buffers[hcmInfo[1]],
buffers[hcmInfo[2]],
};
at::BFloat16* localRelay = buffers[rank] + kMaxIntraNodeSize / 2;
at::BFloat16* remoteRelay = buffers[relayRank] + kMaxIntraNodeSize / 2;
for (size_t i = offset; i < N_aligned; i += stride) {
bf16x8 vals[4];
#pragma unroll 4
for (size_t ii = 0; ii < 4; ++ii) {
streamLoad128(vals[ii], &srcs[ii][i]);
}
bf16x8 sums;
memset(reinterpret_cast<void*>(&sums), 0, sizeof(sums));
#pragma unroll 4
for (size_t ii = 0; ii < 4; ++ii) {
sums = add_bf16x8(sums, vals[ii]);
}
// Cached store for local sums
store128(&localRelay[i], sums);
}
__syncthreads();
if (threadIdx.x == 0) {
releaseSignal(&p2pStates[relayRank]->signals0[blockIdx.x][rank]);
acquireSignal(&p2pStates[rank]->signals0[blockIdx.x][relayRank]);
}
__syncthreads();
for (size_t i = offset; i < N_aligned; i += stride) {
bf16x8 localSum, remoteSum;
// Cached load for local sums
load128(localSum, &localRelay[i]);
streamLoad128(remoteSum, &remoteRelay[i]);
localSum = add_bf16x8(localSum, remoteSum);
if constexpr (kAligned) {
streamStore128(&input[i], localSum);
} else {
for (size_t ii = 0; ii < numelPerThread; ++ii) {
if (i + ii < N) {
input[i + ii] = reinterpret_cast<at::BFloat16*>(&localSum)[ii];
}
}
}
}
}
static inline size_t divUp(uint32_t a, uint32_t b) {
return (a + b - 1) / b;
}
static inline size_t alignUp(uint32_t a, uint32_t b) {
return divUp(a, b) * b;
}
static void checkInput(const at::Tensor& input, size_t rank) {
TORCH_CHECK(
input.dtype() == at::kBFloat16,
"oneShotAllReduce only supports bf16 for now");
TORCH_CHECK(input.is_non_overlapping_and_dense());
TORCH_CHECK(input.device().is_cuda());
TORCH_CHECK(static_cast<size_t>(input.get_device()) == rank);
}
static void getLaunchConfig(
size_t N_aligned,
size_t elemSize,
dim3& blocks,
dim3& threads) {
blocks = dim3(0, 1, 1);
threads = dim3(0, 1, 1);
const auto numelPerThread = kBytesPerThread / elemSize;
const auto numelPerWarp = numelPerThread * kWarpSize;
TORCH_CHECK(N_aligned % numelPerThread == 0);
TORCH_CHECK(N_aligned % numelPerWarp == 0);
if (N_aligned < numelPerThread * kThreadsPerBlock) {
threads.x = N_aligned / numelPerWarp * kWarpSize;
blocks.x = 1;
} else {
auto warpsRequired = N_aligned / numelPerWarp;
auto threadsRequired = N_aligned / numelPerThread;
blocks.x =
std::min(divUp(threadsRequired, kThreadsPerBlock), kMaxAllReduceBlocks);
auto warpsPerBlock = divUp(warpsRequired, blocks.x);
threads.x = std::min(kThreadsPerBlock, warpsPerBlock * kWarpSize);
}
}
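// Not part of the original file: a standalone rerun of the getLaunchConfig()
// arithmetic above for two concrete bf16 sizes, to make the resulting launch
// shapes explicit. Expected output: "2048 -> 1 x 256" and
// "5242880 -> 24 x 1024" (10 MiB of bf16 caps out at 24 blocks of 1024).
#include <algorithm>
#include <cstddef>
#include <cstdio>

int main() {
  const size_t kBytesPerThread = 16, kThreadsPerBlock = 1024, kWarpSize = 32,
               kMaxAllReduceBlocks = 24, elemSize = 2; // bf16
  auto divUp = [](size_t a, size_t b) { return (a + b - 1) / b; };
  const size_t sizes[] = {2048, 5242880};
  for (size_t N_aligned : sizes) {
    const size_t numelPerThread = kBytesPerThread / elemSize; // 8
    const size_t numelPerWarp = numelPerThread * kWarpSize; // 256
    size_t blocks = 1, threads = 0;
    if (N_aligned < numelPerThread * kThreadsPerBlock) {
      // Small input: a single block, one warp per 256 elements.
      threads = N_aligned / numelPerWarp * kWarpSize;
    } else {
      const size_t warpsRequired = N_aligned / numelPerWarp;
      const size_t threadsRequired = N_aligned / numelPerThread;
      blocks =
          std::min(divUp(threadsRequired, kThreadsPerBlock), kMaxAllReduceBlocks);
      threads = std::min(kThreadsPerBlock, divUp(warpsRequired, blocks) * kWarpSize);
    }
    std::printf("%zu -> %zu x %zu\n", N_aligned, blocks, threads);
  }
  return 0;
}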
template <typename T>
static auto castArr(std::array<void*, kMaxDevices> arr) {
std::array<T, kMaxDevices> arr_;
for (size_t i = 0; i < kMaxDevices; ++i) {
arr_[i] = reinterpret_cast<T>(arr[i]);
}
return arr_;
}
bool isIntraNodeCommSupported() {
#if defined(USE_ROCM) || (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800))
return false;
#else
return true;
#endif
}
void* initP2pState() {
void* state = nullptr;
AT_CUDA_CHECK(cudaMalloc(&state, sizeof(P2pState)));
AT_CUDA_CHECK(cudaMemset(state, 0, sizeof(P2pState)));
return state;
}
void* initTopoInfo(Topology topology, NvlMesh nvlMesh, size_t rank) {
void* topoInfo = nullptr;
if (topology != Topology::HYBRID_CUBE_MESH) {
return topoInfo;
}
auto hcm = getHybridCubeMesh(nvlMesh);
int hcmInfo[4];
std::copy((*hcm)[rank].begin(), (*hcm)[rank].begin() + 4, hcmInfo);
AT_CUDA_CHECK(cudaMalloc(&topoInfo, sizeof(hcmInfo)));
AT_CUDA_CHECK(
cudaMemcpy(topoInfo, hcmInfo, sizeof(hcmInfo), cudaMemcpyHostToDevice));
return topoInfo;
}
at::Tensor oneShotAllReduce(
const at::Tensor& input,
std::array<void*, kMaxDevices> p2pStates,
std::array<void*, kMaxDevices> buffers,
size_t rank,
size_t worldSize,
at::cuda::CUDAStream& stream) {
checkInput(input, rank);
size_t numelPerWarp = kBytesPerThread / input.element_size() * kWarpSize;
size_t N_aligned = alignUp(input.numel(), numelPerWarp);
TORCH_CHECK(N_aligned <= kMaxIntraNodeSize / input.element_size());
dim3 blocks, threads;
getLaunchConfig(N_aligned, input.element_size(), blocks, threads);
at::cuda::OptionalCUDAGuard guard(input.get_device());
AT_CUDA_CHECK(cudaMemcpyAsync(
buffers[rank],
input.data_ptr(),
input.numel() * input.element_size(),
cudaMemcpyDeviceToDevice,
stream));
#define X(kWorldSize, kAligned) \
if (worldSize == kWorldSize) { \
oneShotAllReduceKernel<kWorldSize, kAligned> \
<<<blocks, threads, 0, stream>>>( \
input.data_ptr<at::BFloat16>(), \
input.numel(), \
N_aligned, \
castArr<P2pState*>(p2pStates), \
castArr<at::BFloat16*>(buffers), \
rank); \
C10_CUDA_KERNEL_LAUNCH_CHECK(); \
}
#define DISPATCH_ALL_WORLD_SIZES(kAligned) \
X(2, kAligned); \
X(3, kAligned); \
X(4, kAligned); \
X(5, kAligned); \
X(6, kAligned); \
X(7, kAligned); \
X(8, kAligned);
if (N_aligned == static_cast<size_t>(input.numel())) {
DISPATCH_ALL_WORLD_SIZES(true);
} else {
DISPATCH_ALL_WORLD_SIZES(false);
}
#undef DISPATCH_ALL_WORLD_SIZES
#undef X
return input;
}
at::Tensor twoShotAllReduce(
const at::Tensor& input,
std::array<void*, kMaxDevices> p2pStates,
std::array<void*, kMaxDevices> buffers,
size_t rank,
size_t worldSize,
at::cuda::CUDAStream& stream) {
checkInput(input, rank);
size_t numelPerWarp = kBytesPerThread / input.element_size() * kWarpSize;
size_t N_aligned = alignUp(input.numel(), worldSize * numelPerWarp);
size_t N_per_rank = N_aligned / worldSize;
TORCH_CHECK(N_aligned <= kMaxIntraNodeSize / input.element_size());
dim3 blocks, threads;
getLaunchConfig(N_per_rank, input.element_size(), blocks, threads);
auto output = N_aligned == static_cast<size_t>(input.numel())
? input
: input.new_empty(N_aligned);
at::cuda::OptionalCUDAGuard guard(input.get_device());
AT_CUDA_CHECK(cudaMemcpyAsync(
buffers[rank],
input.data_ptr(),
input.numel() * input.element_size(),
cudaMemcpyDeviceToDevice,
stream));
#define X(kWorldSize) \
if (worldSize == kWorldSize) { \
twoShotAllReduceKernel<kWorldSize><<<blocks, threads, 0, stream>>>( \
output.data_ptr<at::BFloat16>(), \
N_aligned, \
castArr<P2pState*>(p2pStates), \
castArr<at::BFloat16*>(buffers), \
rank); \
C10_CUDA_KERNEL_LAUNCH_CHECK(); \
}
X(2);
X(3);
X(4);
X(5);
X(6);
X(7);
X(8);
#undef X
if (output.data_ptr() != input.data_ptr()) {
AT_CUDA_CHECK(cudaMemcpyAsync(
input.data_ptr(),
output.data_ptr(),
input.numel() * input.element_size(),
cudaMemcpyDeviceToDevice,
stream));
}
return input;
}
at::Tensor hybridCubeMeshAllReduce(
const at::Tensor& input,
std::array<void*, kMaxDevices> p2pStates,
std::array<void*, kMaxDevices> buffers,
int hcmInfo[4],
size_t rank,
size_t worldSize,
at::cuda::CUDAStream& stream) {
checkInput(input, rank);
size_t numelPerWarp = kBytesPerThread / input.element_size() * kWarpSize;
size_t N_aligned = alignUp(input.numel(), numelPerWarp);
TORCH_CHECK(N_aligned <= kMaxIntraNodeSize / input.element_size());
dim3 blocks, threads;
getLaunchConfig(N_aligned, input.element_size(), blocks, threads);
at::cuda::OptionalCUDAGuard guard(input.get_device());
AT_CUDA_CHECK(cudaMemcpyAsync(
buffers[rank],
input.data_ptr(),
input.numel() * input.element_size(),
cudaMemcpyDeviceToDevice,
stream));
#define X(kAligned) \
hybridCubeMeshAllReduceKernel<kAligned><<<blocks, threads, 0, stream>>>( \
input.data_ptr<at::BFloat16>(), \
input.numel(), \
N_aligned, \
castArr<P2pState*>(p2pStates), \
castArr<at::BFloat16*>(buffers), \
hcmInfo, \
rank); \
C10_CUDA_KERNEL_LAUNCH_CHECK();
if (N_aligned == static_cast<size_t>(input.numel())) {
X(true);
} else {
X(false);
}
#undef X
return input;
}
AllReduceAlgo selectAllReduceAlgo(
const at::Tensor& input,
Topology topology,
size_t worldSize) {
// Only support bf16 for now
if (input.dtype() != at::kBFloat16 ||
input.numel() * input.element_size() > kMaxIntraNodeSize) {
return AllReduceAlgo::NONE;
}
const auto numel = input.numel();
const auto numelPerWarp = kBytesPerThread / input.element_size() * kWarpSize;
if (topology == Topology::HYBRID_CUBE_MESH) {
TORCH_CHECK(
worldSize == 8, "hyperCubeAllReduce only supports exactly 8 GPUs");
if (alignUp(numel, numelPerWarp) <= kHcmThreshBytes) {
return AllReduceAlgo::HCM;
}
}
if (topology == Topology::FULLY_CONNECTED) {
if (alignUp(numel, numelPerWarp) <= kOneShotThreshBytes) {
return AllReduceAlgo::ONE_SHOT;
}
if (alignUp(numel, numelPerWarp * worldSize) <= kTwoShotThreshBytes) {
return AllReduceAlgo::TWO_SHOT;
}
}
return AllReduceAlgo::NONE;
}
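// Not part of the original file: the same threshold logic as
// selectAllReduceAlgo() above, rerun standalone for a few bf16 element counts
// on 8 fully connected GPUs. As in the code above, the warp-aligned element
// count is what gets compared against the *_ThreshBytes constants, and inputs
// larger than kMaxIntraNodeSize bytes fall back to NONE.
#include <cstddef>
#include <cstdio>

int main() {
  const size_t kWarpSize = 32, kBytesPerThread = 16, elemSize = 2; // bf16
  const size_t kMaxIntraNodeSize = 10 * 1024 * 1024;
  const size_t kOneShotThreshBytes = 256 * 1024;
  const size_t kTwoShotThreshBytes = 10 * 1024 * 1024;
  const size_t worldSize = 8;
  const size_t numelPerWarp = kBytesPerThread / elemSize * kWarpSize; // 256
  auto alignUp = [](size_t a, size_t b) { return (a + b - 1) / b * b; };
  const size_t numels[] = {100000, 1000000, 6000000};
  for (size_t numel : numels) {
    const char* algo = "NONE";
    if (numel * elemSize <= kMaxIntraNodeSize) {
      if (alignUp(numel, numelPerWarp) <= kOneShotThreshBytes) {
        algo = "ONE_SHOT";
      } else if (alignUp(numel, numelPerWarp * worldSize) <= kTwoShotThreshBytes) {
        algo = "TWO_SHOT";
      }
    }
    // Expected: 100000 -> ONE_SHOT, 1000000 -> TWO_SHOT,
    //           6000000 -> NONE (12 MB of bf16 exceeds the 10 MB cap).
    std::printf("%zu bf16 elements -> %s\n", numel, algo);
  }
  return 0;
}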
at::Tensor allReduce(
const at::Tensor& input,
std::array<void*, kMaxDevices> p2pStates,
std::array<void*, kMaxDevices> buffers,
void* topoInfo,
size_t rank,
size_t worldSize,
AllReduceAlgo algo,
at::cuda::CUDAStream& stream) {
switch (algo) {
case AllReduceAlgo::ONE_SHOT:
return oneShotAllReduce(
input, p2pStates, buffers, rank, worldSize, stream);
case AllReduceAlgo::TWO_SHOT:
return twoShotAllReduce(
input, p2pStates, buffers, rank, worldSize, stream);
case AllReduceAlgo::HCM:
return hybridCubeMeshAllReduce(
input, p2pStates, buffers, (int*)topoInfo, rank, worldSize, stream);
default:
C10_THROW_ERROR(ValueError, "IntraNodeComm: invalid algo");
}
}
} // namespace intra_node_comm
} // namespace c10d


@@ -1,102 +0,0 @@
#pragma once
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAEvent.h>
#include <c10/cuda/CUDAStream.h>
#include <torch/csrc/distributed/c10d/Store.hpp>
#include <torch/csrc/distributed/c10d/Work.hpp>
namespace c10d {
namespace intra_node_comm {
constexpr size_t kMaxDevices = 8;
constexpr size_t kMaxIntraNodeSize = 10 * 1024 * 1024;
using NvlMesh = std::array<std::array<size_t, kMaxDevices>, kMaxDevices>;
using HybridCubeMesh = std::array<std::array<int, 4>, kMaxDevices>;
enum class Topology { UNKNOWN = 0, FULLY_CONNECTED = 1, HYBRID_CUBE_MESH = 2 };
enum class AllReduceAlgo { NONE = 0, ONE_SHOT = 1, TWO_SHOT = 2, HCM = 3 };
class TORCH_API IntraNodeComm : public c10::intrusive_ptr_target {
public:
IntraNodeComm(
Topology topology,
std::array<void*, kMaxDevices> p2pStates,
std::array<void*, kMaxDevices> buffers,
void* topoInfo,
size_t rank,
size_t worldSize);
~IntraNodeComm();
/**
* Rendezvous via a c10d::Store.
* This function may return nullptr if intra-node comm is not applicable.
* It guarantees that all participants either succeed or abort.
*/
static c10::intrusive_ptr<IntraNodeComm> rendezvous(
c10::intrusive_ptr<c10d::Store> store,
const std::string& prefix,
size_t rank,
size_t worldSize);
/**
* Selects an AllReduceAlgo that we think will outperform nccl.
* Returns AllReduceAlgo::NONE if we don't think we can outperform nccl.
*/
AllReduceAlgo selectAllReduceAlgo(const at::Tensor& input);
at::Tensor allReduce(const at::Tensor& input, AllReduceAlgo algo);
private:
Topology topology_;
std::array<void*, kMaxDevices> p2pStates_;
std::array<void*, kMaxDevices> buffers_;
void* topoInfo_;
size_t rank_;
size_t worldSize_;
};
/**
* NOTE [IntraNodeComm Stream Semantics]
*
* ProcessGroupNCCL launches kernels differently from the conventional PyTorch
* CUDA semantics: it always launches collective kernels onto a dedicated
* communication stream. Therefore, it needs to:
*
* - Synchronize the calling stream and the comm stream.
* - Ensure the memory safety of the operands (via record_stream or stashing).
* - Synchronize the waiting stream with the comm stream.
*
* Unconditionally performing these tasks makes sense when we expect most of the
* communication to benefit from compute/comm overlap. However, IntraNodeComm
* primarily aims to optimize small, latency-sensitive, blocking communication,
* in which the overhead incurred by the above steps can be quite pronounced.
*
* Thus, IntraNodeComm follows the conventional PyTorch CUDA semantics and
* launches kernels onto the stream specified by the user. Although the user
* can perform the necessary synchronization via wait_stream, to provide a UX
* consistent with that of ProcessGroupNCCL, the necessary stream
* synchronization can also be performed via IntraNodeCommWork::wait().
*/
class IntraNodeCommWork : public c10d::Work {
public:
IntraNodeCommWork() : c10d::Work() {
event_.record();
}
bool wait(std::chrono::milliseconds timeout = kNoTimeout) override {
event_.block(at::cuda::getCurrentCUDAStream());
return true;
}
private:
at::cuda::CUDAEvent event_;
};
TORCH_API int64_t getIntraNodeCommUsageCounter();
} // namespace intra_node_comm
} // namespace c10d
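
For context on the stream-semantics note above: the handoff that IntraNodeCommWork::wait() performs is the standard ATen event record/block pattern. A minimal sketch, assuming the collective kernel was just launched on the calling thread's current stream (the function name and the consumer stream here are illustrative only):

#include <ATen/cuda/CUDAEvent.h>
#include <c10/cuda/CUDAStream.h>

void waitSketch() {
  // The collective ran on the current stream (conventional PyTorch CUDA
  // semantics), so record an event right behind it ...
  at::cuda::CUDAEvent done;
  done.record(at::cuda::getCurrentCUDAStream());
  // ... and later make a consumer stream wait on that event, without
  // blocking the host.
  auto consumer = c10::cuda::getStreamFromPool();
  done.block(consumer);
}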