Files
pytorch/torch/csrc/distributed/c10d/intra_node_comm.cpp
Yifu Wang 4a09117d16 Introduce ProcessGroupCudaP2P (#122163)
## Context
This stack prototypes automatic micro-pipelining of `all-gather -> matmul` and `matmul -> reduce-scatter` via Inductor. The idea originates from the paper [Overlap Communication with Dependent Computation via
Decomposition in Large Deep Learning Models](https://dl.acm.org/doi/pdf/10.1145/3567955.3567959). The implementation and some key optimizations are heavily influenced by @lw's implementation in xformers.

The stack contains several components:
- `ProcessGroupCudaP2P` - a thin wrapper around `ProcessGroupNCCL`. In addition, it maintains a P2P workspace that enables SM-free, one-sided P2P communication, which is needed for optimal micro-pipelining.
- `fused_all_gather_matmul` and `fused_matmul_reduce_scatter` dispatcher ops.
- A post-grad fx pass that detects `all-gather -> matmul` and `matmul -> reduce-scatter` patterns and replaces them with the fused dispatcher ops (the targeted pattern is sketched below this list).
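
For illustration only, the kind of sharded compute the pass targets looks roughly like the sketch below, written with functional collectives. The helper name `tp_block`, the shapes, and the sharding scheme are assumptions for this example, not code from this PR:

```
# Hypothetical sketch of an eager pattern that, once traced, contains the
# all-gather -> matmul and matmul -> reduce-scatter subgraphs described above.
# `group`, shapes, and weights are illustrative assumptions.
import torch
from torch.distributed._functional_collectives import (
    all_gather_tensor,
    reduce_scatter_tensor,
)

def tp_block(x, w1, w2, group):
    # all-gather -> matmul: gather the sharded activations, then matmul.
    x_full = all_gather_tensor(x, 0, group)   # [S/tp, D] -> [S, D]
    h = torch.relu(x_full @ w1)               # candidate for fused_all_gather_matmul
    # matmul -> reduce-scatter: matmul, then reduce-scatter the partial result.
    y = h @ w2                                # candidate for fused_matmul_reduce_scatter
    return reduce_scatter_tensor(y, "sum", 0, group)
```

Under `torch.compile` with `torch._inductor.config._micro_pipeline_tp` enabled (see below), the resulting `all_gather_tensor -> mm` and `mm -> reduce_scatter_tensor` nodes in the post-grad graph are the kind of subgraphs the pass rewrites into the fused dispatcher ops.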

To enable the prototype feature:
- Set the distributed backend to `cuda_p2p`.
- Set `torch._inductor.config._micro_pipeline_tp` to `True` (see the enablement sketch below).
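
Putting the two settings together, a minimal enablement sketch might look like the following. It assumes a `torchrun`-style launch (for `LOCAL_RANK`) and reuses the hypothetical `tp_block` from the sketch above; only the backend name and the Inductor flag come from this PR description.

```
# Hypothetical enablement sketch; everything except backend="cuda_p2p" and
# _micro_pipeline_tp is an assumption about the surrounding setup.
import os
import torch
import torch.distributed as dist
import torch._inductor.config as inductor_config

# Bind each process to its GPU before rendezvous (torchrun sets LOCAL_RANK).
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
dist.init_process_group(backend="cuda_p2p")

# Let the post-grad pass fuse all-gather/matmul and matmul/reduce-scatter.
inductor_config._micro_pipeline_tp = True

compiled_block = torch.compile(tp_block)
```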

*NOTE: the prototype sets nothing in stone w.r.t. each component's design. The purpose is to establish a performant baseline with a reasonable design, on which each component can be further improved.*

## Benchmark
Setup:
- 8 x H100 (500W) + 3rd gen NVSwitch.
- Llama3 8B training w/ torchtitan.
- 8-way TP. The number of layers was reduced from 32 to 8 for benchmarking purposes.

Trace (baseline): https://interncache-all.fbcdn.net/manifold/perfetto-artifacts/tree/ui/index.html#!/?url=https://interncache-all.fbcdn.net/manifold/perfetto_internal_traces/tree/shared_trace/yifu_tmpjaz8zgx0
<img width="832" alt="image" src="https://github.com/pytorch/pytorch/assets/4156752/4addba77-5abc-4d2e-93ea-f68078587fe1">

Trace (w/ micro pipelining): https://interncache-all.fbcdn.net/manifold/perfetto-artifacts/tree/ui/index.html#!/?url=https://interncache-all.fbcdn.net/manifold/perfetto_internal_traces/tree/shared_trace/yifu_tmpn073b4wn
<img width="963" alt="image" src="https://github.com/pytorch/pytorch/assets/4156752/4f44e78d-8196-43ab-a1ea-27390f07e9d2">

## This PR
`ProcessGroupCudaP2P` is a thin wrapper around `ProcessGroupNCCL`. By default, it routes all collectives to the underlying `ProcessGroupNCCL`. In addition, `ProcessGroupCudaP2P` initializes a P2P workspace that allows direct GPU memory access among its members. The workspace can be used from Python to optimize intra-node communication patterns, or from CUDA to implement custom intra-node collectives.

`ProcessGroupCudaP2P` aims to bridge the gap for important communication patterns that can be optimized better via fine-grained P2P memory access than with the collectives available in the latest version of NCCL. It is meant to complement NCCL rather than replace it.
Usage:
```
    # Using ProcessGroupCudaP2P
    dist.init_process_group(backend="cuda_p2p", ...)

    # Using ProcessGroupCudaP2P while specifying ProcessGroupCudaP2P.Options
    pg_options = ProcessGroupCudaP2P.Options()
    dist.init_process_group(backend="cuda_p2p", pg_options=pg_options, ...)

    # Using ProcessGroupCudaP2P while specifying ProcessGroupNCCL.Options
    pg_options = ProcessGroupNCCL.Options()
    dist.init_process_group(backend="cuda_p2p", pg_options=pg_options, ...)

    # Using ProcessGroupCudaP2P while specifying both
    # ProcessGroupCudaP2P.Options and ProcessGroupNCCL.Options
    pg_options = ProcessGroupCudaP2P.Options()
    pg_options.nccl_options = ProcessGroupNCCL.Options()
    dist.init_process_group(backend="cuda_p2p", pg_options=pg_options, ...)

    # Down-casting the backend to access p2p buffers for cuda_p2p specific
    # optimizations
    if is_cuda_p2p_group(group):
        backend = get_cuda_p2p_backend(group)
        if required_p2p_buffer_size > backend.get_buffer_size():
            ...  # fallback
        p2p_buffer = backend.get_p2p_buffer(...)
    else:
        ...  # fallback
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/122163
Approved by: https://github.com/wanchaol
2024-05-24 18:33:18 +00:00

#include <torch/csrc/distributed/c10d/intra_node_comm.hpp>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/util/Logging.h>
#include <torch/csrc/distributed/c10d/Utils.hpp>
#include <iostream>
#include <utility>
#include <fcntl.h>
#include <pthread.h>
#include <semaphore.h>
#include <sys/mman.h>
#include <sys/stat.h>
#include <unistd.h>
#if !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED)
#include <c10/cuda/driver_api.h>
#include <nvml.h>
#endif
#include <cuda_runtime.h>
namespace c10d::intra_node_comm {
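
// Environment variable that, when set, enables IntraNodeComm (checked by
// IntraNodeComm::isEnabled()).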
static std::vector<std::string> ENABLE_INTRA_NODE_COMM = {
    "ENABLE_INTRA_NODE_COMM"};

// Forces detectTopology() to return Topology::FULLY_CONNECTED, so
// IntraNodeComm can be used even without an NVLink connection. This is only
// used for testing purposes.
static std::vector<std::string> TEST_INTRA_NODE_COMM = {"TEST_INTRA_NODE_COMM"};
////////////////////////////////////////////////////////////////////////////////
// CUDA Functions
////////////////////////////////////////////////////////////////////////////////
bool isIntraNodeCommSupported();
std::optional<HybridCubeMesh> getHybridCubeMesh(NvlMesh nvlMesh);
void* initP2pState();
void* initTopoInfo(Topology topology, NvlMesh nvlMesh, size_t rank);
////////////////////////////////////////////////////////////////////////////////
// Topology Detection
////////////////////////////////////////////////////////////////////////////////
static std::ostream& operator<<(std::ostream& os, const NvlMesh& nvlMesh) {
  std::ostringstream oss;
  for (size_t i = 0; i < kMaxDevices; ++i) {
    for (size_t j = 0; j < kMaxDevices; ++j) {
      oss << nvlMesh[i][j] << " ";
    }
    oss << '\n';
  }
  os << oss.str();
  return os;
}

static bool isSame(NvlMesh lhs, NvlMesh rhs) {
  for (size_t i = 0; i < kMaxDevices; ++i) {
    for (size_t j = 0; j < kMaxDevices; ++j) {
      if (lhs[i][j] != rhs[i][j]) {
        return false;
      }
    }
  }
  return true;
}

/**
 * Query the NVLink connections among devices.
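 * Returns an NvlMesh in which nvlMesh[i][j] is the number of NVLink links over
 * which rank i's device can reach rank j's device (direct GPU-to-GPU links
 * plus links routed through NVSwitches). Returns an all-zero mesh if the
 * driver API is unavailable.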
 */
static NvlMesh getNvlMesh(const std::vector<std::string>& rankToBusId) {
#if !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED)
  using namespace c10::cuda;
  NvlMesh nvlMesh = {};
  auto driverApi = DriverAPI::get();
  if (driverApi == nullptr) {
    return nvlMesh;
  }
  const auto worldSize = rankToBusId.size();
  std::vector<nvmlDevice_t> devices(worldSize, nullptr);
  std::unordered_map<std::string, size_t> busIdToRank;
  std::vector<size_t> switchLinkCount(worldSize, 0);
  for (size_t r = 0; r < worldSize; ++r) {
    busIdToRank.emplace(rankToBusId[r], r);
    TORCH_CHECK(
        driverApi->nvmlDeviceGetHandleByPciBusId_v2_(
            rankToBusId[r].c_str(), &devices[r]) == NVML_SUCCESS);
  }
  // TODO: find a better way to determine this
  constexpr size_t kMaxNvLinks = 20;
  // For each device, loop over devices connected to it via NVLink
  for (size_t idx = 0; idx < worldSize; ++idx) {
    for (size_t link = 0; link < kMaxNvLinks; ++link) {
      nvmlReturn_t ret;
      nvmlIntNvLinkDeviceType_t deviceType;
      ret = driverApi->nvmlDeviceGetNvLinkRemoteDeviceType_(
          devices[idx], link, &deviceType);
      if (ret != NVML_SUCCESS) {
        // We've exhausted the NVLinks connected to this device.
        // This error is benign. There doesn't seem to be a reliable
        // way to obtain the maximum link value that can be passed to
        // the API, so we simply increment the link value until the
        // API fails or we hit a predefined maximum value.
        break;
      }
      // Remote device is GPU
      if (deviceType == NVML_NVLINK_DEVICE_TYPE_GPU) {
        nvmlPciInfo_t pciInfo;
        ret = driverApi->nvmlDeviceGetNvLinkRemotePciInfo_v2_(
            devices[idx], link, &pciInfo);
        if (ret != NVML_SUCCESS) {
          // Unexpected error. Return an empty NvlMesh
          return {};
        }
        auto it = busIdToRank.find(pciInfo.busId);
        if (it != busIdToRank.end()) {
          if (idx != it->second) {
            nvlMesh[idx][it->second] += 1;
          }
        }
        // Remote device is NVSwitch
      } else if (deviceType == NVML_NVLINK_DEVICE_TYPE_SWITCH) {
        switchLinkCount[idx] += 1;
      }
    }
  }
  // Process NVSwitch connections. For simplicity, we assume
  // all NVSwitches are interconnected.
  for (size_t i = 0; i < worldSize; ++i) {
    for (size_t j = 0; j < worldSize; ++j) {
      if (i == j) {
        continue;
      }
      nvlMesh[i][j] += std::min(switchLinkCount[i], switchLinkCount[j]);
    }
  }
  return nvlMesh;
#else
  return {};
#endif
}

/**
 * Determine if the devices form a hybrid cube mesh
 * topology given an NvlMesh.
 */
static bool isHybridCubeMesh(const NvlMesh nvlMesh) {
  std::array<size_t, kMaxDevices> numNeighbors = {};
  for (size_t i = 0; i < kMaxDevices; ++i) {
    for (size_t j = 0; j < kMaxDevices; ++j) {
      if (nvlMesh[i][j] > 0) {
        numNeighbors[i] += 1;
      }
    }
  }
  for (size_t i = 0; i < kMaxDevices; ++i) {
    // TODO: this is insufficient and needs revisiting
    if (numNeighbors[i] != 4) {
      return false;
    }
  }
  return true;
}

/**
 * Detect the topology given an NvlMesh.
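 * A topology is FULLY_CONNECTED if every pair of participating devices has at
 * least one NVLink connection; HYBRID_CUBE_MESH is only considered when
 * worldSize equals kMaxDevices.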
 */
static Topology detectTopology(const NvlMesh nvlMesh, size_t worldSize) {
  if (getCvarBool(TEST_INTRA_NODE_COMM, false)) {
    return Topology::FULLY_CONNECTED;
  }
  bool fullyConnected = true;
  for (size_t i = 0; i < worldSize - 1; ++i) {
    for (size_t j = i + 1; j < worldSize; ++j) {
      if (nvlMesh[i][j] == 0 || nvlMesh[j][i] == 0) {
        fullyConnected = false;
      }
    }
  }
  if (fullyConnected) {
    LOG(INFO) << "IntraNodeComm: Topology::FULLY_CONNECTED";
    return Topology::FULLY_CONNECTED;
  }
  if (worldSize == kMaxDevices && getHybridCubeMesh(nvlMesh) != std::nullopt) {
    LOG(INFO) << "IntraNodeComm: Topology::HYBRID_CUBE_MESH";
    return Topology::HYBRID_CUBE_MESH;
  }
  LOG(INFO) << "IntraNodeComm: Topology::UNKNOWN";
  return Topology::UNKNOWN;
}
////////////////////////////////////////////////////////////////////////////////
// Rendezvous and Initialization
////////////////////////////////////////////////////////////////////////////////
IntraNodeComm::IntraNodeComm(
    c10::intrusive_ptr<c10d::Store> store,
    size_t rank,
    size_t worldSize,
    std::optional<size_t> bufferSize)
    : store_(std::move(store)),
      rank_(rank),
      worldSize_(worldSize),
      bufferSize_(bufferSize.has_value() ? *bufferSize : kDefaultBufferSize),
      barrierReady_(at::cuda::CUDAEvent()) {}

IntraNodeComm::~IntraNodeComm() {
  if (!isInitialized_) {
    return;
  }
  // Intentionally releasing resources without synchronizing devices. The
  // teardown logic is safe for a properly synchronized user program. We don't
  // want an improperly synchronized user program to hang here.
  for (size_t r = 0; r < worldSize_; ++r) {
    if (r == rank_) {
      continue;
    }
    AT_CUDA_CHECK(cudaIpcCloseMemHandle(p2pStates_[r]));
    AT_CUDA_CHECK(cudaIpcCloseMemHandle(buffers_[r]));
  }
  AT_CUDA_CHECK(cudaFree(p2pStates_[rank_]));
  AT_CUDA_CHECK(cudaFree(buffers_[rank_]));
  if (topoInfo_ != nullptr) {
    AT_CUDA_CHECK(cudaFree(topoInfo_));
  }
  AT_CUDA_CHECK(cudaFree(p2pStatesDev_));
  AT_CUDA_CHECK(cudaFree(buffersDev_));
}

bool IntraNodeComm::isEnabled() {
  return getCvarBool(ENABLE_INTRA_NODE_COMM, false);
}

/**
 * Use c10d::Store to perform allgather on a trivially copyable type.
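 * Each rank publishes its value under a per-rank key and then waits on every
 * peer's key, so the call blocks until all participating ranks have
 * contributed.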
 */
template <typename T>
std::vector<T> storeAllGather(
    const c10::intrusive_ptr<c10d::Store>& store,
    const std::string& prefix,
    size_t rank,
    size_t worldSize,
    T val) {
  static_assert(std::is_trivially_copyable_v<T>);
  std::vector<std::string> peerKeys;
  for (size_t r = 0; r < worldSize; ++r) {
    std::ostringstream oss;
    oss << prefix << "-" << r;
    peerKeys.push_back(oss.str());
  }
  {
    std::vector<uint8_t> payload(
        reinterpret_cast<uint8_t*>(&val),
        reinterpret_cast<uint8_t*>(&val) + sizeof(T));
    store->set(peerKeys[rank], payload);
  }
  std::vector<T> peerVals;
  for (size_t r = 0; r < worldSize; ++r) {
    if (r == rank) {
      peerVals.push_back(val);
      continue;
    }
    store->wait({peerKeys[r]});
    auto payload = store->get(peerKeys[r]);
    TORCH_CHECK(payload.size() == sizeof(T));
    T peerVal{};
    std::memcpy(&peerVal, payload.data(), sizeof(T));
    peerVals.push_back(peerVal);
  }
  return peerVals;
}
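
// Rendezvous flow (all exchanges go through the c10d::Store):
//   1. Handshake 0: exchange hostname and PCI bus ID; bail out if the
//      participants are not all on the same host, and check that every rank
//      uses a distinct device.
//   2. Query the NVLink mesh via NVML and detect the topology.
//   3. Allocate the local p2p state and exchange buffer.
//   4. Handshake 1: exchange the detected topology and CUDA IPC handles; bail
//      out if ranks observe different topologies.
//   5. Open the peers' IPC handles and publish per-rank pointer tables to the
//      device.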
bool IntraNodeComm::rendezvous() {
  if (isInitialized_) {
    return true;
  }
#if !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED)
  if (!isIntraNodeCommSupported() || worldSize_ < 2 ||
      worldSize_ > kMaxDevices) {
    return false;
  }
  auto deviceIdx = at::cuda::current_device();
  c10::cuda::CUDAGuard guard(deviceIdx);

  // First handshake: exchange hostname and device bus ID
  struct DevInfo {
    char hostname[HOST_NAME_MAX + 1];
    char busId[80];
  };
  DevInfo devInfo{};
  gethostname(devInfo.hostname, sizeof(devInfo.hostname));
  cudaDeviceProp prop{};
  AT_CUDA_CHECK(cudaGetDeviceProperties(&prop, deviceIdx));
  snprintf(
      devInfo.busId,
      sizeof(devInfo.busId),
      NVML_DEVICE_PCI_BUS_ID_FMT,
      prop.pciDomainID,
      prop.pciBusID,
      prop.pciDeviceID);
  auto peerDevInfos =
      storeAllGather(store_, "handshake-0", rank_, worldSize_, devInfo);
  std::vector<std::string> rankToBusId;
  for (const auto& info : peerDevInfos) {
    if (strcmp(info.hostname, peerDevInfos.front().hostname) != 0) {
      LOG(WARNING) << "Aborting IntraNodeComm::rendezvous because some "
                      "participants are not on the same host ("
                   << info.hostname << ", " << devInfo.hostname << ")";
      return false;
    }
    rankToBusId.emplace_back(info.busId);
  }
  // Verify unique devices
  {
    std::unordered_set uniqueBusIds(rankToBusId.begin(), rankToBusId.end());
    TORCH_CHECK(
        uniqueBusIds.size() == worldSize_,
        "IntraNodeComm::rendezvous: detected overlapping devices across ranks. "
        "Please properly set device via torch.cuda.set_device() before "
        "initiating rendezvous.");
  }
  // Query NVLink connections
  auto nvlMesh = getNvlMesh(rankToBusId);
  // Detect topology
  Topology topology = detectTopology(nvlMesh, worldSize_);
  // Initialize p2p state
  auto p2pState = initP2pState();
  // Allocate buffer
  void* buffer = nullptr;
  AT_CUDA_CHECK(cudaMalloc(&buffer, bufferSize_));

  // Second handshake: exchange topology and CUDA IPC handles
  struct IpcInfo {
    NvlMesh nvlMesh;
    Topology topology;
    cudaIpcMemHandle_t p2pStateHandle, bufferHandle;
  };
  // Make p2p state and buffer available for IPC
  cudaIpcMemHandle_t p2pStateHandle, bufferHandle;
  AT_CUDA_CHECK(cudaIpcGetMemHandle(&p2pStateHandle, p2pState));
  AT_CUDA_CHECK(cudaIpcGetMemHandle(&bufferHandle, buffer));
  IpcInfo ipcInfo{
      .nvlMesh = nvlMesh,
      .topology = topology,
      .p2pStateHandle = p2pStateHandle,
      .bufferHandle = bufferHandle};
  auto peerIpcInfos =
      storeAllGather(store_, "handshake-1", rank_, worldSize_, ipcInfo);
  for (const auto& info : peerIpcInfos) {
    if (!isSame(info.nvlMesh, peerIpcInfos.front().nvlMesh) ||
        info.topology != peerIpcInfos.front().topology) {
      LOG(WARNING) << "Aborting IntraNodeComm::rendezvous because some "
                      "participants are observing different topologies ("
                   << int(info.topology) << " and " << int(topology) << ")";
      AT_CUDA_CHECK(cudaFree(p2pState));
      AT_CUDA_CHECK(cudaFree(buffer));
      return false;
    }
  }
  std::array<void*, kMaxDevices> p2pStates = {}, buffers = {};
  for (size_t r = 0; r < peerIpcInfos.size(); ++r) {
    if (r == rank_) {
      p2pStates[r] = p2pState;
      buffers[r] = buffer;
    } else {
      AT_CUDA_CHECK(cudaIpcOpenMemHandle(
          &p2pStates[r],
          peerIpcInfos[r].p2pStateHandle,
          cudaIpcMemLazyEnablePeerAccess));
      AT_CUDA_CHECK(cudaIpcOpenMemHandle(
          &buffers[r],
          peerIpcInfos[r].bufferHandle,
          cudaIpcMemLazyEnablePeerAccess));
    }
  }
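  // Copy the per-rank pointer tables into device memory so device code can
  // look up peers' p2p states and buffers by rank.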
  void* p2pStatesDev = nullptr;
  AT_CUDA_CHECK(cudaMalloc(&p2pStatesDev, sizeof(p2pStates)));
  AT_CUDA_CHECK(cudaMemcpy(
      p2pStatesDev,
      p2pStates.data(),
      sizeof(p2pStates),
      cudaMemcpyHostToDevice));
  void* buffersDev = nullptr;
  AT_CUDA_CHECK(cudaMalloc(&buffersDev, sizeof(buffers)));
  AT_CUDA_CHECK(cudaMemcpy(
      buffersDev, buffers.data(), sizeof(buffers), cudaMemcpyHostToDevice));
  void* topoInfo = initTopoInfo(topology, nvlMesh, rank_);

  isInitialized_ = true;
  topology_ = topology;
  std::copy(p2pStates.begin(), p2pStates.end(), p2pStates_.begin());
  std::copy(buffers.begin(), buffers.end(), buffers_.begin());
  p2pStatesDev_ = p2pStatesDev;
  buffersDev_ = buffersDev;
  topoInfo_ = topoInfo;
  return true;
#endif
  return false;
}
} // namespace c10d::intra_node_comm