Revert "Introduce 3 low-latency, intra-node allreduce algorithms for small messages to PyTorch (#114001)"

This reverts commit 4edc921857f39ba9510b6ab1c454149cfb2de157.

Reverted https://github.com/pytorch/pytorch/pull/114001 on behalf of https://github.com/jeanschmidt due to breaking multiple internal tests; this might be flakiness, but multiple retries did not yield an improvement. Please check the internal diff ([comment](https://github.com/pytorch/pytorch/pull/114001#issuecomment-1863036417))
PyTorch MergeBot
2023-12-19 16:01:19 +00:00
parent b6d0d0819a
commit 91e184fd74
12 changed files with 7 additions and 1363 deletions
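For context, the reverted change added an opt-in fast path for small, single-node allreduces. A minimal sketch of how it was exercised, based on the test removed in this diff (assumes a torchrun launch on a single node with P2P-capable SM80+ GPUs and a build that still contains the reverted change; this is illustrative only, not a supported API after this revert):

import os

import torch
import torch.distributed as dist

# Opt-in switch checked by the (now removed) IntraNodeComm rendezvous; it must
# be set before the NCCL process group is created.
os.environ["ENABLE_INTRA_NODE_COMM"] = "1"

local_rank = int(os.environ["LOCAL_RANK"])  # provided by torchrun
torch.cuda.set_device(local_rank)
dist.init_process_group(backend="nccl")

# The fast path only covered single-tensor bf16 SUM allreduces of at most
# 10 MB; other dtypes, reduce ops, or sizes fell back to the regular NCCL path.
t = torch.full((512 * 1024 // 2,), local_rank, dtype=torch.bfloat16, device="cuda")
dist.all_reduce(t, op=dist.ReduceOp.SUM)

# Counter the removed test used to confirm the fast path was actually taken.
from torch._C._distributed_c10d import _get_intra_node_comm_usage_counter
print(_get_intra_node_comm_usage_counter())

dist.destroy_process_group()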


@@ -1452,10 +1452,7 @@ cu_library(
# https://github.com/pytorch/pytorch/issues/79236
# To solve it we add it into the `caffe2_cuda`,
# this is also aligned with the CMake build.
srcs = [":caffe2_cu_srcs"] + [
"torch/csrc/distributed/c10d/intra_node_comm.cu",
"torch/csrc/distributed/c10d/quantization/quantization_gpu.cu",
],
srcs = [":caffe2_cu_srcs"] + ["torch/csrc/distributed/c10d/quantization/quantization_gpu.cu"],
copts = CAFFE2_COPTS + torch_cuda_half_options,
visibility = ["//visibility:public"],
deps = [
@@ -1622,7 +1619,6 @@ cc_library(
exclude = [
"torch/csrc/cuda/python_nccl.cpp",
"torch/csrc/cuda/nccl.cpp",
"torch/csrc/distributed/c10d/intra_node_comm.cu",
"torch/csrc/distributed/c10d/quantization/quantization_gpu.cu",
],
)) + torch_sources,


@@ -674,8 +674,6 @@ libtorch_cuda_distributed_extra_sources = [
"torch/csrc/distributed/c10d/ProcessGroupUCC.cpp",
"torch/csrc/distributed/c10d/UCCTracing.cpp",
"torch/csrc/distributed/c10d/UCCUtils.cpp",
"torch/csrc/distributed/c10d/intra_node_comm.cpp",
"torch/csrc/distributed/c10d/intra_node_comm.cu",
"torch/csrc/distributed/rpc/tensorpipe_cuda.cpp",
"torch/csrc/distributed/c10d/quantization/quantization_gpu.cu",
]


@@ -37,7 +37,7 @@ void* DriverAPI::get_nvml_handle() {
return nvml_hanle;
}
C10_EXPORT DriverAPI* DriverAPI::get() {
DriverAPI* DriverAPI::get() {
static DriverAPI singleton = create_driver_api();
return &singleton;
}


@@ -31,8 +31,6 @@
#define C10_NVML_DRIVER_API(_) \
_(nvmlInit_v2) \
_(nvmlDeviceGetHandleByPciBusId_v2) \
_(nvmlDeviceGetNvLinkRemoteDeviceType) \
_(nvmlDeviceGetNvLinkRemotePciInfo_v2) \
_(nvmlDeviceGetComputeRunningProcesses)
namespace c10 {


@@ -641,10 +641,6 @@ if(USE_CUDA)
append_filelist("libtorch_cuda_distributed_base_sources" Caffe2_GPU_SRCS)
if(NOT WIN32)
append_filelist("libtorch_cuda_distributed_extra_sources" Caffe2_GPU_SRCS)
set_source_files_properties(
${TORCH_SRC_DIR}/csrc/distributed/c10d/intra_node_comm.cpp
PROPERTIES COMPILE_FLAGS "-DPYTORCH_C10_DRIVER_API_SUPPORTED=1"
)
endif()
endif()
set_source_files_properties(


@@ -15,7 +15,7 @@ import warnings
from contextlib import contextmanager
from datetime import datetime, timedelta
from itertools import chain, product
from unittest import SkipTest, mock
from unittest import mock
import torch
import torch.distributed as c10d
@@ -3113,65 +3113,6 @@ class CommTest(test_c10d_common.AbstractCommTest, MultiProcessTestCase):
for i, t in enumerate(tensors):
self.assertEqual(t, torch.full_like(t, self.world_size * (i + (self.world_size + 1.) / 2.)))
@requires_nccl()
@skip_if_lt_x_gpu(2)
@skip_if_rocm
def test_intra_node_comm_all_reduce(self):
from torch._C._distributed_c10d import _get_intra_node_comm_usage_counter
from torch.testing._internal.common_cuda import SM80OrLater
for peer in range(self.world_size):
if peer == self.rank:
continue
if not torch._C._cuda_canDeviceAccessPeer(self.rank, peer):
raise SkipTest("Test requires p2p access")
if not SM80OrLater:
raise SkipTest("Test requires sm>=80")
store = c10d.FileStore(self.file_name, self.world_size)
os.environ["ENABLE_INTRA_NODE_COMM"] = "1"
os.environ["TEST_INTRA_NODE_COMM"] = "1"
torch.cuda.set_device(self.rank)
c10d.init_process_group(
backend="nccl", rank=self.rank, world_size=self.world_size, store=store
)
expect = self.world_size * (self.world_size - 1) // 2
# IntraNodeComm currently only supports sum and bf16.
# Verify that it is not used in the next two configurations.
t = torch.full((4 * 1024 // 2,), self.rank).cuda()
c10d.all_reduce(t, c10d.ReduceOp.SUM)
self.assertTrue(t.eq(expect).all())
self.assertEqual(_get_intra_node_comm_usage_counter(), 0)
t = torch.full((4 * 1024 // 2,), self.rank, dtype=torch.bfloat16).cuda()
c10d.all_reduce(t, c10d.ReduceOp.AVG)
self.assertEqual(_get_intra_node_comm_usage_counter(), 0)
# Verify that IntraNodeComm is used up to 10MB
t = torch.full((4 * 1024 // 2,), self.rank, dtype=torch.bfloat16).cuda()
c10d.all_reduce(t, c10d.ReduceOp.SUM)
self.assertTrue(t.eq(expect).all())
self.assertEqual(_get_intra_node_comm_usage_counter(), 1)
t = torch.full((512 * 1024 // 2,), self.rank, dtype=torch.bfloat16).cuda()
c10d.all_reduce(t, c10d.ReduceOp.SUM)
self.assertTrue(t.eq(expect).all())
self.assertEqual(_get_intra_node_comm_usage_counter(), 2)
t = torch.full((10 * 1024 ** 2 // 2,), self.rank, dtype=torch.bfloat16).cuda()
c10d.all_reduce(t, c10d.ReduceOp.SUM)
self.assertTrue(t.eq(expect).all())
self.assertEqual(_get_intra_node_comm_usage_counter(), 3)
# Verify that IntraNodeComm is not used beyond 10MB
t = torch.full((10 * 1024 ** 2 // 2 + 1,), self.rank, dtype=torch.bfloat16).cuda()
c10d.all_reduce(t, c10d.ReduceOp.SUM)
self.assertTrue(t.eq(expect).all())
self.assertEqual(_get_intra_node_comm_usage_counter(), 3)
c10d.destroy_process_group()
@requires_nccl()
@skip_if_lt_x_gpu(2)
def test_sequence_num_set_default_pg_nccl(self):


@@ -712,8 +712,7 @@ ProcessGroupNCCL::ProcessGroupNCCL(
terminateProcessGroup_(false),
terminateHeartbeatMonitorThread_(false),
collectiveDebugInfoMode_(false),
uid_(process_group_id++),
intraNodeComm_(initIntraNodeComm()) {
uid_(process_group_id++) {
TORCH_CHECK_WITH(
ValueError,
at::cuda::getNumGPUs() != 0,
@@ -896,12 +895,6 @@ void ProcessGroupNCCL::performNocolorSplit(at::Device device) {
#endif
}
c10::intrusive_ptr<intra_node_comm::IntraNodeComm> ProcessGroupNCCL::
initIntraNodeComm() {
return intra_node_comm::IntraNodeComm::rendezvous(
store_, std::to_string(uid_), rank_, size_);
}
void ProcessGroupNCCL::runHealthCheck() {
// Run health check in a separate thread and wait on CV to handle timeouts,
// since majority of getNCCLComm failures are hangs.
@@ -2846,16 +2839,6 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::allreduce_impl(
c10::intrusive_ptr<Work> ProcessGroupNCCL::allreduce(
std::vector<at::Tensor>& tensors,
const AllreduceOptions& opts) {
if (intraNodeComm_ != nullptr && tensors.size() == 1 &&
opts.reduceOp == ReduceOp::SUM) {
using namespace intra_node_comm;
auto algo = intraNodeComm_->selectAllReduceAlgo(tensors[0]);
if (algo != intra_node_comm::AllReduceAlgo::NONE) {
intraNodeComm_->allReduce(tensors[0], algo);
return c10::make_intrusive<IntraNodeCommWork>();
}
}
check_gpu_tensors_different_devices(tensors);
// @lint-ignore CLANGTIDY


@@ -13,7 +13,6 @@
#include <torch/csrc/distributed/c10d/Backend.hpp>
#include <torch/csrc/distributed/c10d/NCCLUtils.hpp>
#include <torch/csrc/distributed/c10d/Store.hpp>
#include <torch/csrc/distributed/c10d/intra_node_comm.hpp>
#include <ATen/DynamicLibrary.h>
#include <ATen/cuda/CUDAContext.h>
@@ -547,8 +546,6 @@ class TORCH_API ProcessGroupNCCL : public Backend {
// Provide an API for users to define their own ways to store NCCL debug info.
void registerDebugInfoWriter(std::unique_ptr<DebugInfoWriter> writer);
c10::intrusive_ptr<intra_node_comm::IntraNodeComm> initIntraNodeComm();
// Provides an API to abort the ProcessGroup (similar to ncclCommAbort)
// instead of relying on ProcessGroupNCCL destructor.
void abort(c10::optional<std::string> abortReason = c10::nullopt);
@@ -947,8 +944,6 @@ class TORCH_API ProcessGroupNCCL : public Backend {
std::unique_ptr<DebugInfoWriter> debugInfoWriter_ = nullptr;
size_t uid_;
c10::intrusive_ptr<intra_node_comm::IntraNodeComm> intraNodeComm_;
};
TORCH_API std::string dump_nccl_trace();


@@ -21,7 +21,6 @@
#ifdef USE_C10D_NCCL
#include <torch/csrc/distributed/c10d/NCCLUtils.hpp>
#include <torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp>
#include <torch/csrc/distributed/c10d/intra_node_comm.hpp>
#endif
#ifdef USE_C10D_MPI
@@ -2329,10 +2328,6 @@ options :class:`~torch.distributed.ProcessGroupNCCL.Options`).
"perform_nocolor_split",
&::c10d::ProcessGroupNCCL::performNocolorSplit);
module.def(
"_get_intra_node_comm_usage_counter",
&::c10d::intra_node_comm::getIntraNodeCommUsageCounter);
#ifdef NCCL_HAS_COMM_CTA_CGA
py::class_<ncclConfig_t>(
processGroupNCCL,


@@ -1,448 +0,0 @@
#include <torch/csrc/distributed/c10d/intra_node_comm.hpp>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/util/Logging.h>
#include <torch/csrc/distributed/c10d/Utils.hpp>
#include <iostream>
#include <random>
#include <fcntl.h>
#include <pthread.h>
#include <semaphore.h>
#include <sys/mman.h>
#include <sys/stat.h>
#include <unistd.h>
#if !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED)
#include <c10/cuda/driver_api.h>
#include <nvml.h>
#endif
#include <cuda_runtime.h>
namespace c10d {
namespace intra_node_comm {
static std::vector<std::string> ENABLE_INTRA_NODE_COMM = {
"ENABLE_INTRA_NODE_COMM"};
// Forces detectTopology() to return Topology::FULLY_CONNECTED, so
// IntraNodeComm can be used even without NVLink connection. This is only used
// for testing purposes.
static std::vector<std::string> TEST_INTRA_NODE_COMM = {"TEST_INTRA_NODE_COMM"};
////////////////////////////////////////////////////////////////////////////////
// CUDA Functions
////////////////////////////////////////////////////////////////////////////////
bool isIntraNodeCommSupported();
std::optional<HybridCubeMesh> getHybridCubeMesh(NvlMesh nvlMesh);
void* initP2pState();
void* initTopoInfo(Topology topology, NvlMesh nvlMesh, size_t rank);
AllReduceAlgo selectAllReduceAlgo(
const at::Tensor& input,
Topology topology,
size_t worldSize);
at::Tensor allReduce(
const at::Tensor& input,
std::array<void*, kMaxDevices> p2pStates,
std::array<void*, kMaxDevices> buffers,
void* topoInfo,
size_t rank,
size_t worldSize,
AllReduceAlgo algo,
at::cuda::CUDAStream& stream);
////////////////////////////////////////////////////////////////////////////////
// Topology Detection
////////////////////////////////////////////////////////////////////////////////
// TODO: find a better way to determine this
static constexpr size_t kMaxNvLinks = 20;
static std::ostream& operator<<(std::ostream& os, const NvlMesh& nvlMesh) {
std::ostringstream oss;
for (size_t i = 0; i < kMaxDevices; ++i) {
for (size_t j = 0; j < kMaxDevices; ++j) {
oss << nvlMesh[i][j] << " ";
}
oss << std::endl;
}
os << oss.str();
return os;
}
static bool isSame(NvlMesh lhs, NvlMesh rhs) {
for (size_t i = 0; i < kMaxDevices; ++i) {
for (size_t j = 0; j < kMaxDevices; ++j) {
if (lhs[i][j] != rhs[i][j]) {
return false;
}
}
}
return true;
}
/**
* Query the nvlink connection among devices.
*/
static NvlMesh getNvlMesh(std::vector<std::string> rankToBusId) {
#if !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED)
using namespace c10::cuda;
NvlMesh nvlMesh = {};
auto driverApi = DriverAPI::get();
if (driverApi == nullptr) {
return nvlMesh;
}
const auto worldSize = rankToBusId.size();
std::vector<nvmlDevice_t> devices(worldSize, 0);
std::unordered_map<std::string, size_t> busIdToRank;
std::vector<size_t> switchLinkCount(worldSize, 0);
for (size_t r = 0; r < worldSize; ++r) {
busIdToRank.emplace(std::make_pair(rankToBusId[r], r));
TORCH_CHECK(
driverApi->nvmlDeviceGetHandleByPciBusId_v2_(
rankToBusId[r].c_str(), &devices[r]) == NVML_SUCCESS);
}
// For each device, loop over devices connected to it via NVLink
for (size_t idx = 0; idx < worldSize; ++idx) {
for (size_t link = 0; link < kMaxNvLinks; ++link) {
nvmlReturn_t ret;
nvmlIntNvLinkDeviceType_t deviceType;
ret = driverApi->nvmlDeviceGetNvLinkRemoteDeviceType_(
devices[idx], link, &deviceType);
if (ret != NVML_SUCCESS) {
// We've exhausted the NVLinks connected to this device.
// This error is benign. There doesn't seem to be a reliable
// way to obtain the maximum link value that can be passed to
// the API, so we simply increment the link value until the
// API fails or we hit a predefined maximum value.
break;
}
// Remote device is GPU
if (deviceType == NVML_NVLINK_DEVICE_TYPE_GPU) {
nvmlPciInfo_t pciInfo;
ret = driverApi->nvmlDeviceGetNvLinkRemotePciInfo_v2_(
devices[idx], link, &pciInfo);
if (ret != NVML_SUCCESS) {
// Unexpected error. Return an empty NvlMesh
return {};
}
auto it = busIdToRank.find(pciInfo.busId);
if (it != busIdToRank.end()) {
if (idx != it->second) {
nvlMesh[idx][it->second] += 1;
}
}
// Remote device is NVSwitch
} else if (deviceType == NVML_NVLINK_DEVICE_TYPE_SWITCH) {
switchLinkCount[idx] += 1;
}
}
}
// Process NVSwitch connections. For simplicity, we assume
// all NVSwitches are interconnected.
for (size_t i = 0; i < worldSize; ++i) {
for (size_t j = 0; j < worldSize; ++j) {
if (i == j) {
continue;
}
nvlMesh[i][j] += std::min(switchLinkCount[i], switchLinkCount[j]);
}
}
return nvlMesh;
#else
return {};
#endif
}
/**
* Determine if the devices form a hybrid cube mesh
* topology given an NvlMesh.
*/
static bool isHybridCubeMesh(const NvlMesh nvlMesh) {
std::array<size_t, kMaxDevices> numNeighbors = {};
for (size_t i = 0; i < kMaxDevices; ++i) {
for (size_t j = 0; j < kMaxDevices; ++j) {
if (nvlMesh[i][j] > 0) {
numNeighbors[i] += 1;
}
}
}
for (size_t i = 0; i < kMaxDevices; ++i) {
// TODO: this is insufficient and needs revisiting
if (numNeighbors[i] != 4) {
return false;
}
}
return true;
}
/**
* Detect the topology given an NvlMesh.
*/
static Topology detectTopology(const NvlMesh nvlMesh, size_t worldSize) {
if (getCvarBool(TEST_INTRA_NODE_COMM, false)) {
return Topology::FULLY_CONNECTED;
}
bool fullyConnected = true;
for (size_t i = 0; i < worldSize - 1; ++i) {
for (size_t j = i + 1; j < worldSize; ++j) {
if (nvlMesh[i][j] == 0 || nvlMesh[j][i] == 0) {
fullyConnected = false;
}
}
}
if (fullyConnected) {
LOG(INFO) << "IntraNodeComm: Topology::FULLY_CONNECTED";
return Topology::FULLY_CONNECTED;
}
if (worldSize == kMaxDevices && getHybridCubeMesh(nvlMesh) != std::nullopt) {
LOG(INFO) << "IntraNodeComm: Topology::HYBRID_CUBE_MESH";
return Topology::HYBRID_CUBE_MESH;
}
LOG(INFO) << "IntraNodeComm: Topology::UNKNOWN";
return Topology::UNKNOWN;
};
////////////////////////////////////////////////////////////////////////////////
// Rendezvous and Initialization
////////////////////////////////////////////////////////////////////////////////
IntraNodeComm::IntraNodeComm(
Topology topology,
std::array<void*, kMaxDevices> p2pStates,
std::array<void*, kMaxDevices> buffers,
void* topoInfo,
size_t rank,
size_t worldSize)
: topology_(topology),
p2pStates_(p2pStates),
buffers_(buffers),
topoInfo_(topoInfo),
rank_(rank),
worldSize_(worldSize) {}
IntraNodeComm::~IntraNodeComm() {
// Intentionally releasing resources without synchronizing devices. The
// teardown logic is safe for properly synchronized user programs. We don't
// want improperly synchronized user programs to hang here.
for (size_t r = 0; r < worldSize_; ++r) {
if (r == rank_) {
continue;
}
AT_CUDA_CHECK(cudaIpcCloseMemHandle(p2pStates_[r]));
AT_CUDA_CHECK(cudaIpcCloseMemHandle(buffers_[r]));
}
AT_CUDA_CHECK(cudaFree(p2pStates_[rank_]));
AT_CUDA_CHECK(cudaFree(buffers_[rank_]));
if (topoInfo_ != nullptr) {
AT_CUDA_CHECK(cudaFree(topoInfo_));
}
}
/**
* Use c10d::Store to perform allgather on a trivially copyable type.
*/
template <typename T>
std::vector<T> storeAllGather(
c10::intrusive_ptr<c10d::Store> store,
const std::string& prefix,
size_t rank,
size_t worldSize,
T val) {
static_assert(std::is_trivially_copyable<T>::value);
std::vector<std::string> peerKeys;
for (size_t r = 0; r < worldSize; ++r) {
std::ostringstream oss;
oss << prefix << "-" << r;
peerKeys.push_back(oss.str());
}
{
std::vector<uint8_t> payload(
reinterpret_cast<uint8_t*>(&val),
reinterpret_cast<uint8_t*>(&val) + sizeof(T));
store->set(peerKeys[rank], payload);
}
std::vector<T> peerVals;
for (size_t r = 0; r < worldSize; ++r) {
if (r == rank) {
peerVals.push_back(val);
continue;
}
store->wait({peerKeys[r]});
auto payload = store->get(peerKeys[r]);
TORCH_CHECK(payload.size() == sizeof(T));
T peerVal;
std::memcpy(&peerVal, payload.data(), sizeof(T));
peerVals.push_back(peerVal);
}
return peerVals;
}
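The same allgather-over-Store pattern is also expressible from Python with the public c10d Store API; a small illustrative sketch (the helper name and key scheme below are hypothetical, not part of PyTorch):

from torch.distributed import TCPStore

# e.g. store = TCPStore("127.0.0.1", 29500, world_size, is_master=(rank == 0))

def store_all_gather(store, prefix: str, rank: int, world_size: int, value: str):
    # Publish this rank's value under a per-rank key, then wait for and read
    # every peer's key -- mirroring the C++ storeAllGather helper above.
    store.set(f"{prefix}-{rank}", value)
    peer_values = []
    for r in range(world_size):
        if r == rank:
            peer_values.append(value)
            continue
        store.wait([f"{prefix}-{r}"])
        peer_values.append(store.get(f"{prefix}-{r}").decode())
    return peer_values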
c10::intrusive_ptr<IntraNodeComm> IntraNodeComm::rendezvous(
c10::intrusive_ptr<c10d::Store> store,
const std::string& prefix,
size_t rank,
size_t worldSize) {
#if !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED)
if (!isIntraNodeCommSupported() ||
!getCvarBool(ENABLE_INTRA_NODE_COMM, false) || worldSize < 2 ||
worldSize > kMaxDevices) {
return nullptr;
}
int deviceIdx = at::cuda::current_device();
c10::cuda::CUDAGuard guard(deviceIdx);
// First hand shake: exchange hostname and device bus ID
struct DevInfo {
char hostname[HOST_NAME_MAX + 1];
char busId[80];
};
DevInfo devInfo{};
gethostname(devInfo.hostname, sizeof(devInfo.hostname));
cudaDeviceProp prop{};
AT_CUDA_CHECK(cudaGetDeviceProperties(&prop, deviceIdx));
snprintf(
devInfo.busId,
sizeof(devInfo.busId),
NVML_DEVICE_PCI_BUS_ID_FMT,
prop.pciDomainID,
prop.pciBusID,
prop.pciDeviceID);
auto peerDevInfos = storeAllGather(
store, prefix + "-IntraNodeCommHandShake-0", rank, worldSize, devInfo);
std::vector<std::string> rankToBusId;
for (const auto& info : peerDevInfos) {
if (strcmp(info.hostname, peerDevInfos.front().hostname) != 0) {
LOG(WARNING) << "Aborting IntraNodeComm::rendezvous because some "
"participants are not on the same host ("
<< info.hostname << ", " << devInfo.hostname << ")";
return nullptr;
}
rankToBusId.emplace_back(info.busId);
}
// Verify unique devices
{
std::unordered_set uniqueBusIds(rankToBusId.begin(), rankToBusId.end());
TORCH_CHECK(
uniqueBusIds.size() == worldSize,
"IntraNodeComm::rendezvous: detected overlapping devices across ranks. "
"Please properly set device via torch.cuda.set_device() before "
"initiating rendezvous.");
}
// Query nvlink connection
auto nvlMesh = getNvlMesh(rankToBusId);
// Detect topology
Topology topology = detectTopology(nvlMesh, worldSize);
// Initialize p2p state
auto p2pState = initP2pState();
// Allocate buffer
void* buffer = nullptr;
AT_CUDA_CHECK(cudaMalloc(&buffer, kMaxIntraNodeSize * 2));
// Second handshake: exchange topology and CUDA IPC handles
struct IpcInfo {
NvlMesh nvlMesh;
Topology topology;
cudaIpcMemHandle_t p2pStateHandle, bufferHandle;
};
// Make p2p state and buffer available for IPC
cudaIpcMemHandle_t p2pStateHandle, bufferHandle;
AT_CUDA_CHECK(cudaIpcGetMemHandle(&p2pStateHandle, p2pState));
AT_CUDA_CHECK(cudaIpcGetMemHandle(&bufferHandle, buffer));
IpcInfo ipcInfo{
.nvlMesh = nvlMesh,
.topology = topology,
.p2pStateHandle = p2pStateHandle,
.bufferHandle = bufferHandle};
auto peerIpcInfos = storeAllGather(
store, prefix + "-IntraNodeCommHandShake-2", rank, worldSize, ipcInfo);
for (const auto& info : peerIpcInfos) {
if (!isSame(info.nvlMesh, peerIpcInfos.front().nvlMesh) ||
info.topology != peerIpcInfos.front().topology) {
LOG(WARNING) << "Aborting IntraNodeComm::rendezvous because some "
"participants are observing different topologies ("
<< int(info.topology) << " and " << int(topology) << ")";
AT_CUDA_CHECK(cudaFree(p2pState));
AT_CUDA_CHECK(cudaFree(buffer));
return nullptr;
}
}
std::array<void*, kMaxDevices> p2pStates = {}, buffers = {};
for (size_t r = 0; r < peerIpcInfos.size(); ++r) {
if (r == rank) {
p2pStates[r] = p2pState;
buffers[r] = buffer;
} else {
AT_CUDA_CHECK(cudaIpcOpenMemHandle(
&p2pStates[r],
peerIpcInfos[r].p2pStateHandle,
cudaIpcMemLazyEnablePeerAccess));
AT_CUDA_CHECK(cudaIpcOpenMemHandle(
&buffers[r],
peerIpcInfos[r].bufferHandle,
cudaIpcMemLazyEnablePeerAccess));
}
}
void* topoInfo = initTopoInfo(topology, nvlMesh, rank);
return c10::make_intrusive<IntraNodeComm>(
topology, p2pStates, buffers, topoInfo, rank, worldSize);
#else
return nullptr;
#endif
}
AllReduceAlgo IntraNodeComm::selectAllReduceAlgo(const at::Tensor& input) {
return c10d::intra_node_comm::selectAllReduceAlgo(
input, topology_, worldSize_);
}
static int64_t usageCounter = 0;
at::Tensor IntraNodeComm::allReduce(
const at::Tensor& input,
AllReduceAlgo algo) {
// Report usage for testing purposes.
// We don't care about overflowing.
++usageCounter;
auto stream = at::cuda::getCurrentCUDAStream();
c10::cuda::CUDACachingAllocator::recordStream(
input.storage().data_ptr(), stream);
return c10d::intra_node_comm::allReduce(
input, p2pStates_, buffers_, topoInfo_, rank_, worldSize_, algo, stream);
}
int64_t getIntraNodeCommUsageCounter() {
return usageCounter;
}
} // namespace intra_node_comm
} // namespace c10d


@@ -1,708 +0,0 @@
#include <torch/csrc/distributed/c10d/intra_node_comm.hpp>
#include <ATen/Dispatch.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
namespace c10d {
namespace intra_node_comm {
static constexpr size_t kBytesPerThread = 16;
static constexpr size_t kMaxAllReduceBlocks = 24;
static constexpr size_t kThreadsPerBlock = 1024;
static constexpr size_t kWarpSize = 32;
static constexpr size_t kHcmThreshBytes = 256 * 1024;
static constexpr size_t kOneShotThreshBytes = 256 * 1024;
static constexpr size_t kTwoShotThreshBytes = 10 * 1024 * 1024;
#if defined(USE_ROCM)
using __nv_bfloat162 = uint32_t;
#endif
struct __align__(16) bf16x8 {
__nv_bfloat162 vals[4];
};
#define DEVICE_INLINE __device__ inline __attribute__((always_inline))
DEVICE_INLINE __nv_bfloat162
bf16hadd2(const __nv_bfloat162 x, const __nv_bfloat162 y) {
#if defined(USE_ROCM) || (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800))
CUDA_KERNEL_ASSERT(false);
#else
return __hadd2(x, y);
#endif
}
DEVICE_INLINE bf16x8 add_bf16x8(bf16x8 a, bf16x8 b) {
bf16x8 c;
c.vals[0] = bf16hadd2(a.vals[0], b.vals[0]);
c.vals[1] = bf16hadd2(a.vals[1], b.vals[1]);
c.vals[2] = bf16hadd2(a.vals[2], b.vals[2]);
c.vals[3] = bf16hadd2(a.vals[3], b.vals[3]);
return c;
}
/**
* NOTE [cross device memory synchronization]
*
* The multi-stage algorithms (e.g. two-shot, hcm allreduce) require the writes
* of a thread to be visible by threads with the same block/thread ID on other
* devices. To satisfy CUDA's memory consistency model, every thread has to
* release its writes at the system scope, and the consuming thread has to
* acquire the writes at the system scope. This incurs high overhead and
* attempts at optimizing this process can be prone to race conditions.
*
* Instead, we bypass caching by having each thread:
*
* - Directly write to global memory via st.cs (cache-streaming).
* - Synchronize with threads within the block.
* - Perform cross device synchronization at block level (via system scope
* atomic ops).
* - Synchronize with threads within the block.
* - Directly read from global memory via ld.nc (non-coherent/non-cached).
*/
template <typename T>
DEVICE_INLINE void streamLoad128(bf16x8& val, const T* addr) {
#if defined(USE_ROCM) || (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800))
CUDA_KERNEL_ASSERT(false);
#else
unsigned long long int low, high;
asm("ld.global.nc.v2.u64 {%0, %1}, [%2];"
: "=l"(low), "=l"(high)
: "l"(addr));
reinterpret_cast<unsigned long long int*>(&val)[0] = low;
reinterpret_cast<unsigned long long int*>(&val)[1] = high;
#endif
}
__device__ inline void streamStore128(at::BFloat16* addr, const bf16x8& val) {
#if defined(USE_ROCM) || (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800))
CUDA_KERNEL_ASSERT(false);
#else
unsigned long long int low, high;
low = reinterpret_cast<const unsigned long long int*>(&val)[0];
high = reinterpret_cast<const unsigned long long int*>(&val)[1];
asm("st.global.cs.v2.u64 [%0], {%1, %2};" : : "l"(addr), "l"(low), "l"(high));
#endif
}
template <typename T>
DEVICE_INLINE void load128(bf16x8& val, const T* addr) {
*reinterpret_cast<uint4*>(&val) = reinterpret_cast<const uint4*>(addr)[0];
}
template <typename T>
DEVICE_INLINE void store128(T* addr, const bf16x8& val) {
*reinterpret_cast<uint4*>(addr) = reinterpret_cast<const uint4*>(&val)[0];
}
DEVICE_INLINE void releaseSignal(uint32_t* addr) {
#if defined(USE_ROCM) || (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800))
CUDA_KERNEL_ASSERT(false);
#else
atomicAdd_system(addr, 1);
#endif
}
DEVICE_INLINE void acquireSignal(uint32_t* addr) {
#if defined(USE_ROCM) || (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800))
CUDA_KERNEL_ASSERT(false);
#else
volatile uint32_t* signal = addr;
uint32_t val;
do {
val = *signal;
} while (val == 0 || atomicCAS_system(addr, val, val - 1) != val);
#endif
}
////////////////////////////////////////////////////////////////////////////////
// Fully Connected Algos
////////////////////////////////////////////////////////////////////////////////
struct P2pState {
uint32_t signals0[kMaxAllReduceBlocks][kMaxDevices];
uint32_t signals1[kMaxAllReduceBlocks][kMaxDevices];
};
template <uint32_t kWorldSize, bool kAligned>
static __global__ void oneShotAllReduceKernel(
at::BFloat16* input,
size_t N,
size_t N_aligned,
std::array<P2pState*, kMaxDevices> p2pStates,
std::array<at::BFloat16*, kMaxDevices> buffers,
size_t rank) {
const size_t numelPerThread = kBytesPerThread / sizeof(at::BFloat16);
const size_t offset =
(blockDim.x * blockIdx.x + threadIdx.x) * numelPerThread;
const size_t stride = blockDim.x * gridDim.x * numelPerThread;
// Wait for all other ranks to enter the kernel
if (threadIdx.x < kWorldSize) {
auto targetRank = threadIdx.x;
releaseSignal(&p2pStates[targetRank]->signals0[blockIdx.x][rank]);
acquireSignal(&p2pStates[rank]->signals0[blockIdx.x][targetRank]);
}
__syncthreads();
// The source pointers. Distributed round-robin for the different warps
const at::BFloat16* srcs[kWorldSize];
#pragma unroll kWorldSize
for (int ii = 0; ii < kWorldSize; ++ii) {
int srcRank = (rank + ii) % kWorldSize;
srcs[ii] = buffers[srcRank];
}
for (size_t i = offset; i < N_aligned; i += stride) {
bf16x8 vals[kWorldSize];
#pragma unroll kWorldSize
for (size_t ii = 0; ii < kWorldSize; ++ii) {
streamLoad128(vals[ii], &srcs[ii][i]);
}
bf16x8 sums;
memset(reinterpret_cast<void*>(&sums), 0, sizeof(sums));
#pragma unroll kWorldSize
for (size_t ii = 0; ii < kWorldSize; ++ii) {
sums = add_bf16x8(sums, vals[ii]);
}
if constexpr (kAligned) {
streamStore128(&input[i], sums);
} else {
for (size_t ii = 0; ii < numelPerThread; ++ii) {
if (i + ii < N) {
input[i + ii] = reinterpret_cast<at::BFloat16*>(&sums)[ii];
}
}
}
}
}
template <uint32_t kWorldSize>
static __launch_bounds__(1024) __global__ void twoShotAllReduceKernel(
at::BFloat16* input,
size_t N_aligned,
std::array<P2pState*, kMaxDevices> p2pStates,
std::array<at::BFloat16*, kMaxDevices> buffers,
size_t rank) {
const size_t numelPerThread = kBytesPerThread / sizeof(at::BFloat16);
const size_t offset =
(blockDim.x * blockIdx.x + threadIdx.x) * numelPerThread;
const size_t stride = blockDim.x * gridDim.x * numelPerThread;
const size_t N_per_rank = N_aligned / kWorldSize;
const size_t N_start = N_per_rank * rank;
// Wait for all other ranks to enter the kernel
if (threadIdx.x < kWorldSize) {
auto targetRank = threadIdx.x;
releaseSignal(&p2pStates[targetRank]->signals0[blockIdx.x][rank]);
acquireSignal(&p2pStates[rank]->signals0[blockIdx.x][targetRank]);
}
__syncthreads();
// The source pointers. Distributed round-robin for the different warps
at::BFloat16* srcs[kWorldSize];
size_t srcRanks[kWorldSize];
#pragma unroll kWorldSize
for (int ii = 0; ii < kWorldSize; ++ii) {
int srcRank = (rank + ii) % kWorldSize;
srcs[ii] = buffers[srcRank];
srcRanks[ii] = srcRank;
}
for (size_t i = offset; i < N_per_rank; i += stride) {
bf16x8 vals[kWorldSize];
#pragma unroll kWorldSize
for (size_t ii = 0; ii < kWorldSize; ++ii) {
streamLoad128(vals[ii], &srcs[ii][N_start + i]);
}
bf16x8 sums;
memset(reinterpret_cast<void*>(&sums), 0, sizeof(sums));
#pragma unroll kWorldSize
for (size_t ii = 0; ii < kWorldSize; ++ii) {
sums = add_bf16x8(sums, vals[ii]);
}
streamStore128(&srcs[0][N_start + i], sums);
// Store local sums into input now so we can avoid
// a global memory access later for it.
streamStore128(&input[N_start + i], sums);
}
__syncthreads();
if (threadIdx.x < kWorldSize) {
auto targetRank = threadIdx.x;
releaseSignal(&p2pStates[targetRank]->signals1[blockIdx.x][rank]);
acquireSignal(&p2pStates[rank]->signals1[blockIdx.x][targetRank]);
}
__syncthreads();
for (size_t i = offset; i < N_per_rank; i += stride) {
#pragma unroll kWorldSize - 1
for (size_t ii = 1; ii < kWorldSize; ++ii) {
size_t k = N_start + i + (srcRanks[ii] - rank) * N_per_rank;
bf16x8 val;
streamLoad128(val, &srcs[ii][k]);
streamStore128(&input[k], val);
}
}
}
////////////////////////////////////////////////////////////////////////////////
// Hybrid Cube Mesh Algos
////////////////////////////////////////////////////////////////////////////////
/**
* NOTE [hybrid cube mesh]
*
* In a hybrid cube mesh topology, every device has exactly 4 neighbors
* (directly connected via NVLink). Every device X has exactly 1
* neighbor Y that is a neighbor of the 3 non-neighbors of X. We call Y the
* relay neighbor of X. This property is symmetrical: X is also guaranteed to
* be the relay neighbor of Y.
*
* With this property, we can perform a variant of one-shot allreduce algo that
* only moves data across NVLinks:
*
* - Each device performs a one-shot allreduce among itself and its 3
*   non-relay neighbors.
* - Each device exchanges data with its relay neighbor.
*
* HybridCubeMesh is a data structure for describing the topology:
*
* - hcm[X][0:3] are the 3 neighbors of X.
* - hcm[X][3] is the relay neighbor of X.
* - For load balancing purposes, we also ensure that if hcm[X][k] = Y,
* hcm[Y][k] = X.
*/
std::optional<HybridCubeMesh> getHybridCubeMesh(NvlMesh nvlMesh) {
std::array<std::unordered_set<size_t>, kMaxDevices> neighbors = {};
std::array<size_t, kMaxDevices> neighborMasks = {};
for (size_t i = 0; i < kMaxDevices; ++i) {
for (size_t j = 0; j < kMaxDevices; ++j) {
if (nvlMesh[i][j] > 0) {
neighbors[i].insert(j);
neighborMasks[i] |= (1ul << j);
}
}
}
HybridCubeMesh hcm = {};
for (auto& row : hcm) {
row.fill(-1);
}
// A topology is an HCM if:
// - Every device has exactly 4 neighbors.
// - Every device has exactly 1 relay neighbor that is
//   a neighbor of the 3 non-neighbors of the device.
for (size_t i = 0; i < kMaxDevices; ++i) {
if (neighbors[i].size() != 4) {
return std::nullopt;
}
// Condition 1: check the number of neighbors
std::vector<size_t> relayNeighbors;
for (size_t j = 0; j < kMaxDevices; ++j) {
if ((neighborMasks[i] & neighborMasks[j]) == 0) {
relayNeighbors.push_back(j);
}
}
// Condition 2: check the number of relay neighbors
if (relayNeighbors.size() != 1) {
return std::nullopt;
}
neighbors[i].erase(relayNeighbors[0]);
hcm[i][3] = relayNeighbors[0];
}
for (size_t i = 0; i < kMaxDevices; ++i) {
for (size_t k = 0; k < 3; ++k) {
// We can only fill hcm[i][k] with j if hcm[j][k] is not filled
for (size_t j : neighbors[i]) {
if (hcm[j][k] == -1) {
hcm[i][k] = j;
hcm[j][k] = i;
break;
}
}
TORCH_CHECK(hcm[i][k] != -1);
neighbors[i].erase(hcm[i][k]);
}
}
return hcm;
}
template <bool kAligned>
static __global__ void hybridCubeMeshAllReduceKernel(
at::BFloat16* input,
size_t N,
size_t N_aligned,
std::array<P2pState*, kMaxDevices> p2pStates,
std::array<at::BFloat16*, kMaxDevices> buffers,
int hcmInfo[4],
size_t rank) {
const size_t numelPerThread = kBytesPerThread / sizeof(at::BFloat16);
const size_t offset =
(blockDim.x * blockIdx.x + threadIdx.x) * numelPerThread;
const size_t stride = blockDim.x * gridDim.x * numelPerThread;
const int relayRank = hcmInfo[3];
// Wait for HCM neighbors to enter the kernel
if (threadIdx.x < 3) {
auto targetRank = hcmInfo[threadIdx.x];
releaseSignal(&p2pStates[targetRank]->signals0[blockIdx.x][rank]);
acquireSignal(&p2pStates[rank]->signals0[blockIdx.x][targetRank]);
}
__syncthreads();
const at::BFloat16* srcs[4] = {
buffers[rank],
buffers[hcmInfo[0]],
buffers[hcmInfo[1]],
buffers[hcmInfo[2]],
};
at::BFloat16* localRelay = buffers[rank] + kMaxIntraNodeSize / 2;
at::BFloat16* remoteRelay = buffers[relayRank] + kMaxIntraNodeSize / 2;
for (size_t i = offset; i < N_aligned; i += stride) {
bf16x8 vals[4];
#pragma unroll 4
for (size_t ii = 0; ii < 4; ++ii) {
streamLoad128(vals[ii], &srcs[ii][i]);
}
bf16x8 sums;
memset(reinterpret_cast<void*>(&sums), 0, sizeof(sums));
#pragma unroll 4
for (size_t ii = 0; ii < 4; ++ii) {
sums = add_bf16x8(sums, vals[ii]);
}
// Cached store for local sums
store128(&localRelay[i], sums);
}
__syncthreads();
if (threadIdx.x == 0) {
releaseSignal(&p2pStates[relayRank]->signals0[blockIdx.x][rank]);
acquireSignal(&p2pStates[rank]->signals0[blockIdx.x][relayRank]);
}
__syncthreads();
for (size_t i = offset; i < N_aligned; i += stride) {
bf16x8 localSum, remoteSum;
// Cached load for local sums
load128(localSum, &localRelay[i]);
streamLoad128(remoteSum, &remoteRelay[i]);
localSum = add_bf16x8(localSum, remoteSum);
if constexpr (kAligned) {
streamStore128(&input[i], localSum);
} else {
for (size_t ii = 0; ii < numelPerThread; ++ii) {
if (i + ii < N) {
input[i + ii] = reinterpret_cast<at::BFloat16*>(&localSum)[ii];
}
}
}
}
}
static inline size_t divUp(uint32_t a, uint32_t b) {
return (a + b - 1) / b;
}
static inline size_t alignUp(uint32_t a, uint32_t b) {
return divUp(a, b) * b;
}
static void checkInput(const at::Tensor& input, size_t rank) {
TORCH_CHECK(
input.dtype() == at::kBFloat16,
"oneShotAllReduce only supports bf16 for now");
TORCH_CHECK(input.is_non_overlapping_and_dense());
TORCH_CHECK(input.device().is_cuda());
TORCH_CHECK(static_cast<size_t>(input.get_device()) == rank);
}
static void getLaunchConfig(
size_t N_aligned,
size_t elemSize,
dim3& blocks,
dim3& threads) {
blocks = dim3(0, 1, 1);
threads = dim3(0, 1, 1);
const auto numelPerThread = kBytesPerThread / elemSize;
const auto numelPerWarp = numelPerThread * kWarpSize;
TORCH_CHECK(N_aligned % numelPerThread == 0);
TORCH_CHECK(N_aligned % numelPerWarp == 0);
if (N_aligned < numelPerThread * kThreadsPerBlock) {
threads.x = N_aligned / numelPerWarp * kWarpSize;
blocks.x = 1;
} else {
auto warpsRequired = N_aligned / numelPerWarp;
auto threadsRequired = N_aligned / numelPerThread;
blocks.x =
std::min(divUp(threadsRequired, kThreadsPerBlock), kMaxAllReduceBlocks);
auto warpsPerBlock = divUp(warpsRequired, blocks.x);
threads.x = std::min(kThreadsPerBlock, warpsPerBlock * kWarpSize);
}
}
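To make the launch math above concrete, a small Python transcription with one worked example (illustrative only; constants copied from the top of this file, bf16 element size = 2 bytes):

kBytesPerThread, kWarpSize, kThreadsPerBlock, kMaxAllReduceBlocks = 16, 32, 1024, 24

def div_up(a, b):
    return (a + b - 1) // b

def get_launch_config(n_aligned, elem_size):
    numel_per_thread = kBytesPerThread // elem_size
    numel_per_warp = numel_per_thread * kWarpSize
    assert n_aligned % numel_per_warp == 0
    if n_aligned < numel_per_thread * kThreadsPerBlock:
        return 1, n_aligned // numel_per_warp * kWarpSize
    warps_required = n_aligned // numel_per_warp
    threads_required = n_aligned // numel_per_thread
    blocks = min(div_up(threads_required, kThreadsPerBlock), kMaxAllReduceBlocks)
    warps_per_block = div_up(warps_required, blocks)
    threads = min(kThreadsPerBlock, warps_per_block * kWarpSize)
    return blocks, threads

print(get_launch_config(512 * 1024 // 2, 2))  # 512 KiB of bf16 -> (24, 1024)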
template <typename T>
static auto castArr(std::array<void*, kMaxDevices> arr) {
std::array<T, kMaxDevices> arr_;
for (size_t i = 0; i < kMaxDevices; ++i) {
arr_[i] = reinterpret_cast<T>(arr[i]);
}
return arr_;
}
bool isIntraNodeCommSupported() {
#if defined(USE_ROCM) || (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800))
return false;
#else
return true;
#endif
}
void* initP2pState() {
void* state = nullptr;
AT_CUDA_CHECK(cudaMalloc(&state, sizeof(P2pState)));
AT_CUDA_CHECK(cudaMemset(state, 0, sizeof(P2pState)));
return state;
}
void* initTopoInfo(Topology topology, NvlMesh nvlMesh, size_t rank) {
void* topoInfo = nullptr;
if (topology != Topology::HYBRID_CUBE_MESH) {
return topoInfo;
}
auto hcm = getHybridCubeMesh(nvlMesh);
int hcmInfo[4];
std::copy((*hcm)[rank].begin(), (*hcm)[rank].begin() + 4, hcmInfo);
AT_CUDA_CHECK(cudaMalloc(&topoInfo, sizeof(hcmInfo)));
AT_CUDA_CHECK(
cudaMemcpy(topoInfo, hcmInfo, sizeof(hcmInfo), cudaMemcpyHostToDevice));
return topoInfo;
}
at::Tensor oneShotAllReduce(
const at::Tensor& input,
std::array<void*, kMaxDevices> p2pStates,
std::array<void*, kMaxDevices> buffers,
size_t rank,
size_t worldSize,
at::cuda::CUDAStream& stream) {
checkInput(input, rank);
size_t numelPerWarp = kBytesPerThread / input.element_size() * kWarpSize;
size_t N_aligned = alignUp(input.numel(), numelPerWarp);
TORCH_CHECK(N_aligned <= kMaxIntraNodeSize / input.element_size());
dim3 blocks, threads;
getLaunchConfig(N_aligned, input.element_size(), blocks, threads);
at::cuda::OptionalCUDAGuard guard(input.get_device());
AT_CUDA_CHECK(cudaMemcpyAsync(
buffers[rank],
input.data_ptr(),
input.numel() * input.element_size(),
cudaMemcpyDeviceToDevice,
stream));
#define X(kWorldSize, kAligned) \
if (worldSize == kWorldSize) { \
oneShotAllReduceKernel<kWorldSize, kAligned> \
<<<blocks, threads, 0, stream>>>( \
input.data_ptr<at::BFloat16>(), \
input.numel(), \
N_aligned, \
castArr<P2pState*>(p2pStates), \
castArr<at::BFloat16*>(buffers), \
rank); \
C10_CUDA_KERNEL_LAUNCH_CHECK(); \
}
#define DISPATCH_ALL_WORLD_SIZES(kAligned) \
X(2, kAligned); \
X(3, kAligned); \
X(4, kAligned); \
X(5, kAligned); \
X(6, kAligned); \
X(7, kAligned); \
X(8, kAligned);
if (N_aligned == static_cast<size_t>(input.numel())) {
DISPATCH_ALL_WORLD_SIZES(true);
} else {
DISPATCH_ALL_WORLD_SIZES(false);
}
#undef DISPATCH_ALL_WORLD_SIZES
#undef X
return input;
}
at::Tensor twoShotAllReduce(
const at::Tensor& input,
std::array<void*, kMaxDevices> p2pStates,
std::array<void*, kMaxDevices> buffers,
size_t rank,
size_t worldSize,
at::cuda::CUDAStream& stream) {
checkInput(input, rank);
size_t numelPerWarp = kBytesPerThread / input.element_size() * kWarpSize;
size_t N_aligned = alignUp(input.numel(), worldSize * numelPerWarp);
size_t N_per_rank = N_aligned / worldSize;
TORCH_CHECK(N_aligned <= kMaxIntraNodeSize / input.element_size());
dim3 blocks, threads;
getLaunchConfig(N_per_rank, input.element_size(), blocks, threads);
auto output = N_aligned == static_cast<size_t>(input.numel())
? input
: input.new_empty(N_aligned);
at::cuda::OptionalCUDAGuard guard(input.get_device());
AT_CUDA_CHECK(cudaMemcpyAsync(
buffers[rank],
input.data_ptr(),
input.numel() * input.element_size(),
cudaMemcpyDeviceToDevice,
stream));
#define X(kWorldSize) \
if (worldSize == kWorldSize) { \
twoShotAllReduceKernel<kWorldSize><<<blocks, threads, 0, stream>>>( \
output.data_ptr<at::BFloat16>(), \
N_aligned, \
castArr<P2pState*>(p2pStates), \
castArr<at::BFloat16*>(buffers), \
rank); \
C10_CUDA_KERNEL_LAUNCH_CHECK(); \
}
X(2);
X(3);
X(4);
X(5);
X(6);
X(7);
X(8);
#undef X
if (output.data_ptr() != input.data_ptr()) {
AT_CUDA_CHECK(cudaMemcpyAsync(
input.data_ptr(),
output.data_ptr(),
input.numel() * input.element_size(),
cudaMemcpyDeviceToDevice,
stream));
}
return input;
}
at::Tensor hybridCubeMeshAllReduce(
const at::Tensor& input,
std::array<void*, kMaxDevices> p2pStates,
std::array<void*, kMaxDevices> buffers,
int hcmInfo[4],
size_t rank,
size_t worldSize,
at::cuda::CUDAStream& stream) {
checkInput(input, rank);
size_t numelPerWarp = kBytesPerThread / input.element_size() * kWarpSize;
size_t N_aligned = alignUp(input.numel(), numelPerWarp);
TORCH_CHECK(N_aligned <= kMaxIntraNodeSize / input.element_size());
dim3 blocks, threads;
getLaunchConfig(N_aligned, input.element_size(), blocks, threads);
at::cuda::OptionalCUDAGuard guard(input.get_device());
AT_CUDA_CHECK(cudaMemcpyAsync(
buffers[rank],
input.data_ptr(),
input.numel() * input.element_size(),
cudaMemcpyDeviceToDevice,
stream));
#define X(kAligned) \
hybridCubeMeshAllReduceKernel<kAligned><<<blocks, threads, 0, stream>>>( \
input.data_ptr<at::BFloat16>(), \
input.numel(), \
N_aligned, \
castArr<P2pState*>(p2pStates), \
castArr<at::BFloat16*>(buffers), \
hcmInfo, \
rank); \
C10_CUDA_KERNEL_LAUNCH_CHECK();
if (N_aligned == static_cast<size_t>(input.numel())) {
X(true);
} else {
X(false);
}
#undef X
return input;
}
AllReduceAlgo selectAllReduceAlgo(
const at::Tensor& input,
Topology topology,
size_t worldSize) {
// Only support bf16 for now
if (input.dtype() != at::kBFloat16 ||
input.numel() * input.element_size() > kMaxIntraNodeSize) {
return AllReduceAlgo::NONE;
}
const auto numel = input.numel();
const auto numelPerWarp = kBytesPerThread / input.element_size() * kWarpSize;
if (topology == Topology::HYBRID_CUBE_MESH) {
TORCH_CHECK(
worldSize == 8, "hyperCubeAllReduce only supports exactly 8 GPUs");
if (alignUp(numel, numelPerWarp) <= kHcmThreshBytes) {
return AllReduceAlgo::HCM;
}
}
if (topology == Topology::FULLY_CONNECTED) {
if (alignUp(numel, numelPerWarp) <= kOneShotThreshBytes) {
return AllReduceAlgo::ONE_SHOT;
}
if (alignUp(numel, numelPerWarp * worldSize) <= kTwoShotThreshBytes) {
return AllReduceAlgo::TWO_SHOT;
}
}
return AllReduceAlgo::NONE;
}
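For reference, the selection logic above transcribed into Python (an illustrative sketch only; constants mirror those at the top of this file, and, exactly as in the C++ code, aligned element counts are compared against the byte-named thresholds):

kBytesPerThread, kWarpSize = 16, 32
kMaxIntraNodeSize = 10 * 1024 * 1024  # from intra_node_comm.hpp
kHcmThreshBytes = 256 * 1024
kOneShotThreshBytes = 256 * 1024
kTwoShotThreshBytes = 10 * 1024 * 1024

def align_up(a, b):
    return (a + b - 1) // b * b

def select_allreduce_algo(numel, element_size, is_bf16, topology, world_size):
    # bf16 only, and the whole message must fit in kMaxIntraNodeSize bytes.
    if not is_bf16 or numel * element_size > kMaxIntraNodeSize:
        return "NONE"
    numel_per_warp = kBytesPerThread // element_size * kWarpSize
    if topology == "HYBRID_CUBE_MESH":
        assert world_size == 8, "HCM allreduce requires exactly 8 GPUs"
        if align_up(numel, numel_per_warp) <= kHcmThreshBytes:
            return "HCM"
    if topology == "FULLY_CONNECTED":
        if align_up(numel, numel_per_warp) <= kOneShotThreshBytes:
            return "ONE_SHOT"
        if align_up(numel, numel_per_warp * world_size) <= kTwoShotThreshBytes:
            return "TWO_SHOT"
    return "NONE"

# e.g. a 512 KiB bf16 tensor on a fully connected 8-GPU node:
print(select_allreduce_algo(512 * 1024 // 2, 2, True, "FULLY_CONNECTED", 8))  # ONE_SHOT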
at::Tensor allReduce(
const at::Tensor& input,
std::array<void*, kMaxDevices> p2pStates,
std::array<void*, kMaxDevices> buffers,
void* topoInfo,
size_t rank,
size_t worldSize,
AllReduceAlgo algo,
at::cuda::CUDAStream& stream) {
switch (algo) {
case AllReduceAlgo::ONE_SHOT:
return oneShotAllReduce(
input, p2pStates, buffers, rank, worldSize, stream);
case AllReduceAlgo::TWO_SHOT:
return twoShotAllReduce(
input, p2pStates, buffers, rank, worldSize, stream);
case AllReduceAlgo::HCM:
return hybridCubeMeshAllReduce(
input, p2pStates, buffers, (int*)topoInfo, rank, worldSize, stream);
default:
C10_THROW_ERROR(ValueError, "IntraNodeComm: invalid algo");
}
}
} // namespace intra_node_comm
} // namespace c10d


@@ -1,102 +0,0 @@
#pragma once
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAEvent.h>
#include <c10/cuda/CUDAStream.h>
#include <torch/csrc/distributed/c10d/Store.hpp>
#include <torch/csrc/distributed/c10d/Work.hpp>
namespace c10d {
namespace intra_node_comm {
constexpr size_t kMaxDevices = 8;
constexpr size_t kMaxIntraNodeSize = 10 * 1024 * 1024;
using NvlMesh = std::array<std::array<size_t, kMaxDevices>, kMaxDevices>;
using HybridCubeMesh = std::array<std::array<int, 4>, kMaxDevices>;
enum class Topology { UNKNOWN = 0, FULLY_CONNECTED = 1, HYBRID_CUBE_MESH = 2 };
enum class AllReduceAlgo { NONE = 0, ONE_SHOT = 1, TWO_SHOT = 2, HCM = 3 };
class TORCH_API IntraNodeComm : public c10::intrusive_ptr_target {
public:
IntraNodeComm(
Topology topology,
std::array<void*, kMaxDevices> p2pStates,
std::array<void*, kMaxDevices> buffers,
void* topoInfo,
size_t rank,
size_t worldSize);
~IntraNodeComm();
/**
* Rendezvous via a c10d::Store.
* This function may return nullptr if intra-node comm is not applicable.
* It guarantees that all participants either succeed or abort.
*/
static c10::intrusive_ptr<IntraNodeComm> rendezvous(
c10::intrusive_ptr<c10d::Store> store,
const std::string& prefix,
size_t rank,
size_t worldSize);
/**
* Selects an AllReduceAlgo that we think will outperform nccl.
* Returns AllReduceAlgo::NONE if we don't think we can outperform nccl.
*/
AllReduceAlgo selectAllReduceAlgo(const at::Tensor& input);
at::Tensor allReduce(const at::Tensor& input, AllReduceAlgo algo);
private:
Topology topology_;
std::array<void*, kMaxDevices> p2pStates_;
std::array<void*, kMaxDevices> buffers_;
void* topoInfo_;
size_t rank_;
size_t worldSize_;
};
/**
* NOTE [IntraNodeComm Stream Semantics]
*
* ProcessGroupNCCL launches kernels differently from the conventional PyTorch
* CUDA semantics: it always launches collective kernels onto a dedicated
* communication stream. Therefore, it needs to:
*
* - Synchronize the calling stream and the comm stream.
* - Ensure the memory safety of the operands (via record_stream or stashing).
* - Synchronize the waiting stream with the comm stream.
*
* Unconditionally performing these tasks makes sense when we expect most of the
* communication to benefit from compute/comm overlap. However, IntraNodeComm
* primarily aims to optimize small, latency-sensitive, blocking communication,
* in which the overhead incurred by the above steps can be quite pronounced.
*
* Thus, IntraNodeComm follows the conventional PyTorch CUDA semantics and
* launches kernels onto the stream specified by the user. Although the user
* can perform the necessary synchronization via wait_stream, to provide a UX
* consistent with that of ProcessGroupNCCL, the necessary stream
* synchronization can also be performed via IntraNodeCommWork::wait().
*/
class IntraNodeCommWork : public c10d::Work {
public:
IntraNodeCommWork() : c10d::Work() {
event_.record();
}
bool wait(std::chrono::milliseconds timeout = kNoTimeout) override {
event_.block(at::cuda::getCurrentCUDAStream());
return true;
}
private:
at::cuda::CUDAEvent event_;
};
TORCH_API int64_t getIntraNodeCommUsageCounter();
} // namespace intra_node_comm
} // namespace c10d