[Distributed] Add custom allreduce support for ROCM (#14125)

Signed-off-by: ilmarkov <imarkov@redhat.com>
Co-authored-by: ilmarkov <imarkov@redhat.com>
Ilya Markov
2025-04-01 07:49:12 +02:00
committed by GitHub
parent e6e3c55ef2
commit b7b7676d67
13 changed files with 373 additions and 160 deletions

CMakeLists.txt

@ -242,6 +242,7 @@ set(VLLM_EXT_SRC
"csrc/quantization/gguf/gguf_kernel.cu"
"csrc/cuda_utils_kernels.cu"
"csrc/prepare_inputs/advance_step.cu"
"csrc/custom_all_reduce.cu"
"csrc/torch_bindings.cpp")
if(VLLM_GPU_LANG STREQUAL "CUDA")
@ -283,7 +284,6 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
"csrc/mamba/causal_conv1d/causal_conv1d.cu"
"csrc/quantization/aqlm/gemm_kernels.cu"
"csrc/quantization/awq/gemm_kernels.cu"
"csrc/custom_all_reduce.cu"
"csrc/permute_cols.cu"
"csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu"
"csrc/quantization/fp4/nvfp4_quant_entry.cu"

csrc/custom_all_reduce.cu

@ -12,7 +12,7 @@ static_assert(sizeof(void*) == sizeof(fptr_t));
fptr_t init_custom_ar(const std::vector<fptr_t>& fake_ipc_ptrs,
torch::Tensor& rank_data, int64_t rank,
bool full_nvlink) {
bool fully_connected) {
int world_size = fake_ipc_ptrs.size();
if (world_size > 8)
throw std::invalid_argument("world size > 8 is not supported");
@ -27,7 +27,7 @@ fptr_t init_custom_ar(const std::vector<fptr_t>& fake_ipc_ptrs,
}
return (fptr_t) new vllm::CustomAllreduce(ipc_ptrs, rank_data.data_ptr(),
rank_data.numel(), rank, world_size,
full_nvlink);
fully_connected);
}
/**
@ -142,3 +142,48 @@ void register_graph_buffers(fptr_t _fa,
bytes.reserve(handles.size());
fa->register_graph_buffers(bytes, offsets);
}
std::tuple<fptr_t, torch::Tensor> allocate_shared_buffer_and_handle(
int64_t size) {
auto device_index = c10::cuda::current_device();
at::DeviceGuard device_guard(at::Device(at::DeviceType::CUDA, device_index));
void* buffer;
cudaStreamCaptureMode mode = cudaStreamCaptureModeRelaxed;
auto stream = c10::cuda::getCurrentCUDAStream().stream();
AT_CUDA_CHECK(cudaThreadExchangeStreamCaptureMode(&mode));
// Allocate buffer
#if defined(USE_ROCM)
// data buffers need to be "uncached" for signal on MI200
AT_CUDA_CHECK(
hipExtMallocWithFlags((void**)&buffer, size, hipDeviceMallocUncached));
#else
AT_CUDA_CHECK(cudaMalloc((void**)&buffer, size));
#endif
AT_CUDA_CHECK(cudaMemsetAsync(buffer, 0, size, stream));
AT_CUDA_CHECK(cudaStreamSynchronize(stream));
AT_CUDA_CHECK(cudaThreadExchangeStreamCaptureMode(&mode));
// Create IPC memhandle for the allocated buffer.
// Will use it in open_mem_handle.
auto options =
torch::TensorOptions().dtype(torch::kUInt8).device(torch::kCPU);
auto handle =
torch::empty({static_cast<int64_t>(sizeof(cudaIpcMemHandle_t))}, options);
AT_CUDA_CHECK(
cudaIpcGetMemHandle((cudaIpcMemHandle_t*)handle.data_ptr(), buffer));
return std::make_tuple(reinterpret_cast<fptr_t>(buffer), handle);
}
fptr_t open_mem_handle(torch::Tensor& mem_handle) {
void* ipc_ptr;
AT_CUDA_CHECK(cudaIpcOpenMemHandle(
(void**)&ipc_ptr, *((const cudaIpcMemHandle_t*)mem_handle.data_ptr()),
cudaIpcMemLazyEnablePeerAccess));
return reinterpret_cast<fptr_t>(ipc_ptr);
}
void free_shared_buffer(fptr_t buffer) {
AT_CUDA_CHECK(cudaFree(reinterpret_cast<void*>(buffer)));
}
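Taken together, these three entry points implement the standard CUDA/HIP IPC handshake: every rank allocates a buffer, publishes its memory handle, and maps each peer's handle into its own address space. A minimal sketch of one rank's view (illustrative only; create_shared is a hypothetical helper, MPI stands in for the handle exchange as in the standalone test further below, while the Python path exchanges handles via torch.distributed):

#include <cuda_runtime.h>
#include <mpi.h>

// Allocate a local buffer, all-gather the IPC handles, and open every peer's
// buffer. Error handling omitted for brevity; at most 8 ranks, as above.
void* create_shared(int rank, int world_size, size_t size, void* ptrs[8]) {
  void* buf = nullptr;
  cudaMalloc(&buf, size);  // ROCm would use hipExtMallocWithFlags(..., hipDeviceMallocUncached)
  cudaIpcMemHandle_t mine, all[8];
  cudaIpcGetMemHandle(&mine, buf);
  MPI_Allgather(&mine, sizeof(mine), MPI_BYTE, all, sizeof(mine), MPI_BYTE,
                MPI_COMM_WORLD);
  for (int r = 0; r < world_size; r++) {
    if (r == rank)
      ptrs[r] = buf;  // our own allocation, no handle needed
    else
      cudaIpcOpenMemHandle(&ptrs[r], all[r], cudaIpcMemLazyEnablePeerAccess);
  }
  return buf;
}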

csrc/custom_all_reduce.cuh

@ -5,6 +5,10 @@
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#if defined(USE_ROCM)
#include <hip/hip_bf16.h>
typedef __hip_bfloat16 nv_bfloat16;
#endif
#include <iostream>
#include <array>
#include <limits>
@ -12,6 +16,7 @@
#include <unordered_map>
#include <vector>
namespace vllm {
#define CUDACHECK(cmd) \
do { \
cudaError_t e = cmd; \
@ -22,24 +27,37 @@
} \
} while (0)
namespace vllm {
// Maximal number of blocks in allreduce kernel.
constexpr int kMaxBlocks = 36;
// Default number of blocks in allreduce kernel.
#ifndef USE_ROCM
const int defaultBlockLimit = 36;
CUpointer_attribute rangeStartAddrAttr = CU_POINTER_ATTRIBUTE_RANGE_START_ADDR;
#else
const int defaultBlockLimit = 16;
hipPointer_attribute rangeStartAddrAttr =
HIP_POINTER_ATTRIBUTE_RANGE_START_ADDR;
#endif
// Counter may overflow, but it's fine since unsigned int overflow is
// well-defined behavior.
using FlagType = uint32_t;
// Two sets of peer counters are needed for two syncs: starting and ending an
// operation. The reason is that it's possible for peer GPU block to arrive at
// the second sync point while the current GPU block haven't passed the first
// sync point. Thus, peer GPU may write counter+1 while current GPU is busy
// waiting for counter. We use alternating counter array to avoid this
// possibility.
struct Signal {
alignas(128) FlagType self_counter[kMaxBlocks][8];
// Two sets of peer counters are needed for two syncs. The reason is that
// it's possible for a peer GPU block to arrive at the second sync point
// while the current GPU block hasn't passed the first sync point. Thus, the
// peer GPU may write counter+1 while the current GPU is busy waiting for
// counter. We use alternating counter arrays to avoid this possibility.
alignas(128) FlagType peer_counter[2][kMaxBlocks][8];
alignas(128) FlagType start[kMaxBlocks][8];
alignas(128) FlagType end[kMaxBlocks][8];
alignas(128) FlagType _flag[kMaxBlocks]; // incremental flags for each rank
};
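Since the intermediate buffer lives right after this header, sizeof(Signal) is effectively the metadata size the Python side adds to max_size when sizing the shared allocation (ops.meta_size() + max_size below). A worked size check, as a standalone sketch under the layout above (the 2560 figure assumes kMaxBlocks = 36 and a typical ABI):

#include <cstdint>
#include <cstdio>

constexpr int kMaxBlocks = 36;
using FlagType = uint32_t;

struct Signal {
  alignas(128) FlagType start[kMaxBlocks][8];  // 36 * 8 * 4 = 1152 bytes
  alignas(128) FlagType end[kMaxBlocks][8];    // another 1152 bytes
  alignas(128) FlagType _flag[kMaxBlocks];     // 144 bytes, padded out to the
};                                             // 128-byte struct alignment

int main() {
  printf("%zu\n", sizeof(Signal));  // 1152 + 1152 + 144 -> padded to 2560
  return 0;
}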
struct __align__(16) RankData {
const void* __restrict__ ptrs[8];
const void* ptrs[8];
};
struct __align__(16) RankSignals {
@ -134,27 +152,29 @@ DINLINE O downcast(array_t<float, O::size> val) {
}
}
#if !defined(USE_ROCM)
static DINLINE void st_flag_release(FlagType* flag_addr, FlagType flag) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700
asm volatile("st.release.sys.global.u32 [%1], %0;" ::"r"(flag),
"l"(flag_addr));
#else
asm volatile("membar.sys; st.volatile.global.u32 [%1], %0;" ::"r"(flag),
"l"(flag_addr));
#endif
}
static DINLINE FlagType ld_flag_acquire(FlagType* flag_addr) {
FlagType flag;
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700
asm volatile("ld.acquire.sys.global.u32 %0, [%1];"
: "=r"(flag)
: "l"(flag_addr));
#else
asm volatile("ld.volatile.global.u32 %0, [%1]; membar.gl;"
: "=r"(flag)
: "l"(flag_addr));
#endif
return flag;
}
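For readers who prefer named memory orders over inline PTX: on sm_70+ these two helpers behave like system-scope release/acquire operations. A rough libcu++ equivalent (an illustrative sketch only, assuming a toolkit that ships cuda/atomic; the kernel itself keeps the asm form shown above):

#include <cuda/atomic>

using FlagType = unsigned int;

__device__ void st_flag_release_alt(FlagType* flag_addr, FlagType flag) {
  cuda::atomic_ref<FlagType, cuda::thread_scope_system> ref(*flag_addr);
  ref.store(flag, cuda::memory_order_release);  // ~ st.release.sys.global.u32
}

__device__ FlagType ld_flag_acquire_alt(FlagType* flag_addr) {
  cuda::atomic_ref<FlagType, cuda::thread_scope_system> ref(*flag_addr);
  return ref.load(cuda::memory_order_acquire);  // ~ ld.acquire.sys.global.u32
}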
@ -170,37 +190,99 @@ static DINLINE FlagType ld_flag_volatile(FlagType* flag_addr) {
return flag;
}
// is_start: whether this is the very first synchronization barrier.
// need_fence: whether a memory fence is needed. If true, a release-acquire
// semantic is used to enforce memory access order before and after this
// barrier.
template <int ngpus, bool is_start, bool need_fence = false>
DINLINE void multi_gpu_barrier(const RankSignals& sg, Signal* self_sg,
int rank) {
if constexpr (!is_start) __syncthreads();
static_assert(
!(is_start && need_fence)); // Start barrier shouldn't need fence.
// This function is meant to be used as the first synchronization in the all
// reduce kernel. Thus, it doesn't need to make any visibility guarantees for
// prior memory accesses. Note: volatile writes will not be reordered against
// other volatile writes.
template <int ngpus>
DINLINE void barrier_at_start(const RankSignals& sg, Signal* self_sg,
int rank) {
uint32_t flag = self_sg->_flag[blockIdx.x] + 1;
if (threadIdx.x < ngpus) {
// Increment the counter. Technically we only need one counter, but we use
// multiple per block to eliminate the need to share the counter via smem.
auto val = self_sg->self_counter[blockIdx.x][threadIdx.x] += 1;
auto peer_counter_ptr = &sg.signals[threadIdx.x]->start[blockIdx.x][rank];
auto self_counter_ptr = &self_sg->start[blockIdx.x][threadIdx.x];
// Write the expected counter value to peer and wait for correct value
// from peer.
st_flag_volatile(peer_counter_ptr, flag);
while (ld_flag_volatile(self_counter_ptr) != flag);
}
__syncthreads();
// use one thread to update flag
if (threadIdx.x == 0) self_sg->_flag[blockIdx.x] = flag;
}
// This function is meant to be used as the second or the final
// synchronization barrier in the all reduce kernel. If it's the final
// synchronization barrier, we don't need to make any visibility guarantees
// for prior memory accesses.
template <int ngpus, bool final_sync = false>
DINLINE void barrier_at_end(const RankSignals& sg, Signal* self_sg, int rank) {
__syncthreads();
uint32_t flag = self_sg->_flag[blockIdx.x] + 1;
if (threadIdx.x < ngpus) {
auto peer_counter_ptr = &sg.signals[threadIdx.x]->end[blockIdx.x][rank];
auto self_counter_ptr = &self_sg->end[blockIdx.x][threadIdx.x];
// Write the expected counter value to peer and wait for correct value from
// peer.
auto peer_counter_ptr =
&sg.signals[threadIdx.x]->peer_counter[val % 2][blockIdx.x][rank];
auto self_counter_ptr =
&self_sg->peer_counter[val % 2][blockIdx.x][threadIdx.x];
if constexpr (need_fence) {
st_flag_release(peer_counter_ptr, val);
while (ld_flag_acquire(self_counter_ptr) != val);
if constexpr (!final_sync) {
st_flag_release(peer_counter_ptr, flag);
while (ld_flag_acquire(self_counter_ptr) != flag);
} else {
st_flag_volatile(peer_counter_ptr, val);
while (ld_flag_volatile(self_counter_ptr) != val);
st_flag_volatile(peer_counter_ptr, flag);
while (ld_flag_volatile(self_counter_ptr) != flag);
}
}
if constexpr (is_start || need_fence) __syncthreads();
if constexpr (!final_sync) __syncthreads();
// use one thread to update flag
if (threadIdx.x == 0) self_sg->_flag[blockIdx.x] = flag;
}
#else
template <int ngpus>
DINLINE void barrier_at_start(const RankSignals& sg, Signal* self_sg,
int rank) {
uint32_t flag = self_sg->_flag[blockIdx.x] + 1;
if (threadIdx.x < ngpus) {
// simultaneously write to the corresponding flag of all ranks.
// Latency = 1 p2p write
__scoped_atomic_store_n(&sg.signals[threadIdx.x]->start[blockIdx.x][rank],
flag, __ATOMIC_RELAXED, __MEMORY_SCOPE_SYSTEM);
// wait until every rank's flag reaches the expected value
while (__scoped_atomic_load_n(&self_sg->start[blockIdx.x][threadIdx.x],
__ATOMIC_RELAXED,
__MEMORY_SCOPE_DEVICE) < flag);
}
__syncthreads();
// use one thread to update flag
if (threadIdx.x == 0) self_sg->_flag[blockIdx.x] = flag;
}
template <int ngpus, bool final_sync = false>
DINLINE void barrier_at_end(const RankSignals& sg, Signal* self_sg, int rank) {
__syncthreads();
uint32_t flag = self_sg->_flag[blockIdx.x] + 1;
if (threadIdx.x < ngpus) {
// simultaneously write to the corresponding flag of all ranks.
// Latency = 1 p2p write
__scoped_atomic_store_n(&sg.signals[threadIdx.x]->end[blockIdx.x][rank],
flag,
final_sync ? __ATOMIC_RELAXED : __ATOMIC_RELEASE,
__MEMORY_SCOPE_SYSTEM);
// wait until every rank's flag reaches the expected value
while (
__scoped_atomic_load_n(&self_sg->end[blockIdx.x][threadIdx.x],
final_sync ? __ATOMIC_RELAXED : __ATOMIC_ACQUIRE,
__MEMORY_SCOPE_DEVICE) < flag);
}
if constexpr (!final_sync) __syncthreads();
// use one thread to update flag
if (threadIdx.x == 0) self_sg->_flag[blockIdx.x] = flag;
}
#endif
template <typename P, int ngpus, typename A>
DINLINE P packed_reduce(const P* ptrs[], int idx) {
A tmp = upcast(ptrs[0][idx]);
@ -220,13 +302,13 @@ __global__ void __launch_bounds__(512, 1)
// note: we don't reorder the address so the accumulation order is the same
// for all ranks, ensuring bitwise identical results
auto dp = *_dp;
multi_gpu_barrier<ngpus, true>(sg, self_sg, rank);
barrier_at_start<ngpus>(sg, self_sg, rank);
// do the actual reduction
for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size;
idx += gridDim.x * blockDim.x) {
((P*)result)[idx] = packed_reduce<P, ngpus, A>((const P**)&dp.ptrs[0], idx);
}
multi_gpu_barrier<ngpus, false>(sg, self_sg, rank);
barrier_at_end<ngpus, true>(sg, self_sg, rank);
}
template <typename P>
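The comment above about keeping the accumulation order identical on every rank matters because floating-point addition is not associative, so different orders can round to different results. A minimal host-side illustration:

#include <cstdio>

int main() {
  float a = 1e8f, b = -1e8f, c = 1.0f;
  // Same operands, different association, different rounded results:
  printf("%g vs %g\n", (a + b) + c, a + (b + c));  // prints "1 vs 0"
  return 0;
}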
@ -255,18 +337,20 @@ __global__ void __launch_bounds__(512, 1)
tmps[i] = get_tmp_buf<P>(sg.signals[target]);
}
auto tmp_out = tmps[0];
multi_gpu_barrier<ngpus, true>(sg, self_sg, rank);
barrier_at_start<ngpus>(sg, self_sg, rank);
// stage 1: reduce scatter
for (int idx = start + tid; idx < end; idx += stride) {
tmp_out[idx - start] = packed_reduce<P, ngpus, A>(ptrs, idx);
}
multi_gpu_barrier<ngpus, false, true>(sg, self_sg, rank);
barrier_at_end<ngpus>(sg, self_sg, rank);
// stage 2: allgather. Note: it's important to match the tid between
// the two stages, because visibility across devices is only guaranteed
// between threads that have the same tid. If thread i computes the sum of
// start + i in the first stage, then thread i also gathers start + i from all
// ranks.
// start + i in the first stage, then thread i also gathers start + i from
// all ranks.
for (int idx = tid; idx < largest_part; idx += stride) {
#pragma unroll
for (int i = 0; i < ngpus; i++) {
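The tid-matching rule above is ordinary release/acquire pairing applied per flag slot: thread t's release on one rank synchronizes only with thread t's acquire on another. A host-side analogue with std::atomic (illustrative only; two OS threads stand in for the same tid on two ranks):

#include <atomic>
#include <cassert>
#include <thread>

int partial_sum = 0;             // stands in for tmp_out[idx - start]
std::atomic<unsigned> flag{0};   // stands in for one end[block][rank] slot

void stage1_on_peer() {          // "thread t" on the peer rank
  partial_sum = 42;              // write the reduced chunk
  flag.store(1, std::memory_order_release);  // like st_flag_release
}

void stage2_locally() {          // the same "thread t" on this rank
  while (flag.load(std::memory_order_acquire) != 1) {}  // like ld_flag_acquire
  assert(partial_sum == 42);     // visible because acquire pairs with release
}

int main() {
  std::thread a(stage1_on_peer), b(stage2_locally);
  a.join();
  b.join();
  return 0;
}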
@ -287,21 +371,22 @@ class CustomAllreduce {
public:
int rank_;
int world_size_;
bool full_nvlink_;
// Full NVLink or xGMI connection between GPUs.
bool fully_connected_;
RankSignals sg_;
// Stores an map from a pointer to its peer pointters from all ranks.
// Stores a map from a pointer to its peer pointers from all ranks.
std::unordered_map<void*, RankData*> buffers_;
Signal* self_sg_;
// Stores rank data from all ranks. This is mainly for cuda graph purposes.
// For cuda graph to work, all kernel arguments must be fixed during graph
// capture time. However, the peer pointers are not known during graph capture
// time. Therefore, during capture, we increment the rank data pointer and use
// that as the argument to the kernel. The kernel arguments are stored in
// graph_unreg_buffers_. The actual peer pointers will be filled in at the
// memory pointed to by the pointers in graph_unreg_buffers_ when
// the IPC handles are exchanged between ranks.
// capture time. However, the peer pointers are not known during graph
// capture time. Therefore, during capture, we increment the rank data
// pointer and use that as the argument to the kernel. The kernel arguments
// are stored in graph_unreg_buffers_. The actual peer pointers will be
// filled in at the memory pointed to by the pointers in
// graph_unreg_buffers_ when the IPC handles are exchanged between ranks.
//
// The overall process looks like this:
// 1. Graph capture.
@ -319,17 +404,18 @@ class CustomAllreduce {
* Signals are an array of ipc-enabled buffers from all ranks.
* For each of the buffer, the layout is as follows:
* | -- sizeof(Signal) -- | ------ a few MB ----- |
* The first section is for allreduce synchronization, and the second section
* is for storing the intermediate results required by some allreduce algos.
* The first section is for allreduce synchronization, and the second
* section is for storing the intermediate results required by some
* allreduce algos.
*
* Note: this class does not own any device memory. Any required buffers
* are passed in from the constructor.
*/
CustomAllreduce(Signal** signals, void* rank_data, size_t rank_data_sz,
int rank, int world_size, bool full_nvlink = true)
int rank, int world_size, bool fully_connected = true)
: rank_(rank),
world_size_(world_size),
full_nvlink_(full_nvlink),
fully_connected_(fully_connected),
self_sg_(signals[rank]),
d_rank_data_base_(reinterpret_cast<RankData*>(rank_data)),
d_rank_data_end_(d_rank_data_base_ + rank_data_sz / sizeof(RankData)) {
@ -361,8 +447,7 @@ class CustomAllreduce {
void* base_ptr;
// note: must share the base address of each allocation, or we get wrong
// address
if (cuPointerGetAttribute(&base_ptr,
CU_POINTER_ATTRIBUTE_RANGE_START_ADDR,
if (cuPointerGetAttribute(&base_ptr, rangeStartAddrAttr,
(CUdeviceptr)ptr) != CUDA_SUCCESS)
throw std::runtime_error("failed to get pointer attr");
CUDACHECK(cudaIpcGetMemHandle(
@ -396,11 +481,11 @@ class CustomAllreduce {
// Note: when registering graph buffers, we intentionally choose to not
// deduplicate the addresses. That means if the allocator reuses some
// addresses, they will be registered again. This is to account for the remote
// possibility of different allocation patterns between ranks. For example,
// rank 1 may get the same input address for the second allreduce, but rank 2
// got a different address. IPC handles have internal reference counting
// mechanism so overhead should be small.
// addresses, they will be registered again. This is to account for the
// remote possibility of different allocation patterns between ranks. For
// example, rank 1 may get the same input address for the second allreduce,
// but rank 2 got a different address. IPC handles have internal reference
// counting mechanism so overhead should be small.
void register_graph_buffers(
const std::vector<std::string>& handles,
const std::vector<std::vector<int64_t>>& offsets) {
@ -431,15 +516,15 @@ class CustomAllreduce {
/**
* Performs allreduce, assuming input has already been registered.
*
* Block and grid default configs are results after careful grid search. Using
* 36 blocks give the best or close to the best runtime on the devices I
* tried: A100, A10, A30, T4, V100. You'll notice that NCCL kernels also only
* take a small amount of SMs. Not quite sure the underlying reason, but my
* guess is that too many SMs will cause contention on NVLink bus.
* Block and grid default configs are results after careful grid search.
* Using 36 blocks gives the best or close to the best runtime on the
* devices I tried: A100, A10, A30, T4, V100. On ROCm the default is 16
* blocks (see defaultBlockLimit above). You'll notice that NCCL kernels
* also only take a small number of SMs. Not quite sure of the underlying
* reason, but my guess is that too many SMs will cause contention on the
* NVLink bus.
*/
template <typename T>
void allreduce(cudaStream_t stream, T* input, T* output, int size,
int threads = 512, int block_limit = 36) {
int threads = 512, int block_limit = defaultBlockLimit) {
auto d = packed_t<T>::P::size;
if (size % d != 0)
throw std::runtime_error(
@ -473,13 +558,11 @@ class CustomAllreduce {
#define KL(ngpus, name) \
name<T, ngpus><<<blocks, threads, 0, stream>>>(ptrs, sg_, self_sg_, output, \
rank_, size);
// TODO(hanzhi713): Threshold is different for A100 and H100.
// Add per device threshold.
#define REDUCE_CASE(ngpus) \
case ngpus: { \
if (world_size_ == 2) { \
KL(ngpus, cross_device_reduce_1stage); \
} else if (full_nvlink_) { \
} else if (fully_connected_) { \
if ((world_size_ <= 4 && bytes < 512 * 1024) || \
(world_size_ <= 8 && bytes < 256 * 1024)) { \
KL(ngpus, cross_device_reduce_1stage); \
@ -497,7 +580,8 @@ class CustomAllreduce {
REDUCE_CASE(8)
default:
throw std::runtime_error(
"custom allreduce only supports num gpus in (2,4,6,8). Actual num "
"custom allreduce only supports num gpus in (2,4,6,8). Actual "
"num "
"gpus = " +
std::to_string(world_size_));
}
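Spelled out, the REDUCE_CASE policy above picks a kernel as follows (a host-side sketch with the same thresholds; select_algo is a hypothetical helper, and the non-fully-connected case with more than two GPUs never reaches this switch because the Python layer disables custom allreduce there):

#include <cstddef>

enum class Algo { OneStage, TwoStage, None };

Algo select_algo(int world_size, bool fully_connected, size_t bytes) {
  if (world_size == 2) return Algo::OneStage;  // one-shot is always best for 2
  if (fully_connected) {
    bool small = (world_size <= 4 && bytes < 512 * 1024) ||
                 (world_size <= 8 && bytes < 256 * 1024);
    return small ? Algo::OneStage : Algo::TwoStage;
  }
  return Algo::None;  // filtered out earlier on the Python side
}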
@ -511,10 +595,11 @@ class CustomAllreduce {
}
}
};
/**
* To inspect PTX/SASS, copy paste this header file to compiler explorer and add
a template instantiation:
* To inspect PTX/SASS, copy paste this header file to compiler explorer and
add a template instantiation:
* template void vllm::CustomAllreduce::allreduce<half>(cudaStream_t, half *,
half *, int, int, int);
*/
} // namespace vllm

csrc/custom_all_reduce_test.cu

@ -1,9 +1,9 @@
/**
* This is a standalone test for custom allreduce.
* To compile, make sure you have MPI and NCCL installed in your system.
* export MPI_HOME=xxx
* export MPI_HOME=XXX
* nvcc -O2 -arch=native -std=c++17 custom_all_reduce_test.cu -o
* custom_all_reduce_test -lnccl -I${MPI_HOME} -lmpi
* custom_all_reduce_test -lnccl -I${MPI_HOME}/include -lmpi
*
* Warning: this C++ test is not designed to be very readable and was used
* during the rapid prototyping process.
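For the ROCm build of this test, a plausible counterpart of the nvcc line would be (an untested sketch; it assumes RCCL and MPI are installed and that hipcc targets the local GPU by default):

hipcc -O2 -std=c++17 custom_all_reduce_test.cu -o custom_all_reduce_test \
  -lrccl -I${MPI_HOME}/include -lmpi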
@ -22,7 +22,15 @@
#include "cuda_profiler_api.h"
#include "custom_all_reduce.cuh"
#include "mpi.h"
#include "nccl.h"
#ifdef USE_ROCM
#include <hip/hip_bf16.h>
typedef __hip_bfloat16 nv_bfloat16;
#include "rccl/rccl.h"
#include "custom_all_reduce_hip.cuh"
#else
#include "nccl.h"
#include "custom_all_reduce.cuh"
#endif
#define MPICHECK(cmd) \
do { \
@ -43,16 +51,29 @@
} \
} while (0)
#ifdef USE_ROCM
__global__ void dummy_kernel() {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700
for (int i = 0; i < 100; i++) {
uint64_t start = wall_clock64();
uint64_t cycles_elapsed;
do {
cycles_elapsed = wall_clock64() - start;
} while (cycles_elapsed < 100);
}
for (int i = 0; i < 100; i++) __nanosleep(1000000); // 100ms
}
#else
__global__ void dummy_kernel() {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700
for (int i = 0; i < 100; i++) __nanosleep(1000000); // 100ms
#else
for (int i = 0; i < 100; i++) {
long long int start = clock64();
while (clock64() - start < 150000000); // approximately 98.4ms on P40
}
#endif
#endif
}
#endif
template <typename T>
__global__ void set_data(T* data, int size, int myRank) {
@ -121,8 +142,14 @@ void run(int myRank, int nRanks, ncclComm_t& comm, int threads, int block_limit,
* registration, they are allocated and registered together in the test for
* convenience.
*/
#ifdef USE_ROCM
CUDACHECK(hipExtMallocWithFlags(
(void**)&buffer, 2 * data_size * sizeof(T) + sizeof(vllm::Signal),
hipDeviceMallocUncached));
#else
CUDACHECK(
cudaMalloc(&buffer, 2 * data_size * sizeof(T) + sizeof(vllm::Signal)));
#endif
CUDACHECK(
cudaMemset(buffer, 0, 2 * data_size * sizeof(T) + sizeof(vllm::Signal)));
CUDACHECK(cudaMalloc(&self_data_copy, data_size * sizeof(T)));
@ -311,13 +338,18 @@ int main(int argc, char** argv) {
bool performance_test = true;
cudaProfilerStart();
// Uncomment to scan through different block size configs.
// for (int threads : {256, 512, 1024}) {
// for (int block_limit = 16; block_limit < 112; block_limit += 4) {
// run<half>(myRank, nRanks, comm, threads, block_limit, 1024 * 1024,
// performance_test);
// }
// }
#ifdef USE_ROCM
const int block_limit = 16;
#else
const int block_limit = 36;
#endif
// Scan through different sizes to test performance.
for (int sz = 512; sz <= (8 << 20); sz *= 2) {
run<half>(myRank, nRanks, comm, 512, block_limit, sz + 8 * 47, performance_test);
@ -326,4 +358,4 @@ int main(int argc, char** argv) {
cudaProfilerStop();
MPICHECK(MPI_Finalize());
return EXIT_SUCCESS;
}

csrc/ops.h

@ -267,10 +267,10 @@ void causal_conv1d_fwd(const at::Tensor& x, const at::Tensor& weight,
const std::optional<at::Tensor>& has_initial_state,
bool silu_activation, int64_t pad_slot_id);
#ifndef USE_ROCM
using fptr_t = int64_t;
fptr_t init_custom_ar(const std::vector<int64_t>& fake_ipc_ptrs,
torch::Tensor& rank_data, int64_t rank, bool full_nvlink);
torch::Tensor& rank_data, int64_t rank,
bool fully_connected);
void all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out,
fptr_t reg_buffer, int64_t reg_buffer_sz_bytes);
void dispose(fptr_t _fa);
@ -281,4 +281,7 @@ get_graph_buffer_ipc_meta(fptr_t _fa);
void register_graph_buffers(fptr_t _fa,
const std::vector<std::vector<int64_t>>& handles,
const std::vector<std::vector<int64_t>>& offsets);
#endif
std::tuple<int64_t, torch::Tensor> allocate_shared_buffer_and_handle(
int64_t size);
int64_t open_mem_handle(torch::Tensor& mem_handle);
void free_shared_buffer(int64_t buffer);
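The fptr_t/int64_t convention used throughout these declarations relies on pointers surviving a round trip through int64_t (the static_assert near the top of custom_all_reduce.cu checks exactly this). In miniature:

#include <cstdint>

static_assert(sizeof(void*) == sizeof(int64_t), "fptr_t must hold a pointer");

int64_t to_handle(void* p) { return reinterpret_cast<int64_t>(p); }
void* from_handle(int64_t h) { return reinterpret_cast<void*>(h); }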

csrc/torch_bindings.cpp

@ -614,12 +614,11 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cuda_utils), cuda_utils) {
&get_max_shared_memory_per_block_device_attribute);
}
#ifndef USE_ROCM
TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _custom_ar), custom_ar) {
// Custom all-reduce kernels
custom_ar.def(
"init_custom_ar(int[] ipc_tensors, Tensor rank_data, "
"int rank, bool full_nvlink) -> int");
"int rank, bool fully_connected) -> int");
custom_ar.impl("init_custom_ar", torch::kCUDA, &init_custom_ar);
custom_ar.def(
"all_reduce(int fa, Tensor inp, Tensor! out, int reg_buffer, "
@ -632,7 +631,13 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _custom_ar), custom_ar) {
custom_ar.def("register_buffer", &register_buffer);
custom_ar.def("get_graph_buffer_ipc_meta", &get_graph_buffer_ipc_meta);
custom_ar.def("register_graph_buffers", &register_graph_buffers);
custom_ar.def("allocate_shared_buffer_and_handle",
&allocate_shared_buffer_and_handle);
custom_ar.def("open_mem_handle(Tensor mem_handle) -> int", &open_mem_handle);
custom_ar.impl("open_mem_handle", torch::kCPU, &open_mem_handle);
custom_ar.def("free_shared_buffer", &free_shared_buffer);
}
#endif
REGISTER_EXTENSION(TORCH_EXTENSION_NAME)

tests/distributed/test_custom_all_reduce.py

@ -106,7 +106,7 @@ def eager_allreduce(
# communicate independently
num_communication = rank // tp_size + 1
sz = 1024
fa = get_tp_group().ca_comm
fa = get_tp_group().device_communicator.ca_comm
inp = torch.ones(sz, dtype=torch.float32, device=device)
out = inp
for _ in range(num_communication):

tests/utils.py

@ -612,7 +612,16 @@ def multi_process_parallel(
# as compared to multiprocessing.
# NOTE: We need to set working_dir for distributed tests,
# otherwise we may get import errors on ray workers
ray.init(runtime_env={"working_dir": VLLM_PATH})
# NOTE: Force Ray not to apply the .gitignore file when excluding files;
# otherwise it will not copy .so files to the working dir.
# Instead, we manually exclude some large directories below.
os.environ["RAY_RUNTIME_ENV_IGNORE_GITIGNORE"] = "1"
ray.init(
runtime_env={
"working_dir": VLLM_PATH,
"excludes":
["build", ".git", "cmake-build-*", "shellcheck", "dist"]
})
distributed_init_port = get_open_port()
refs = []

vllm/_custom_ops.py

@ -1337,9 +1337,9 @@ def get_max_shared_memory_per_block_device_attribute(device: int) -> int:
# custom ar
def init_custom_ar(ipc_tensors: list[torch.Tensor], rank_data: torch.Tensor,
rank: int, full_nvlink: bool) -> int:
rank: int, fully_connected: bool) -> int:
return torch.ops._C_custom_ar.init_custom_ar(ipc_tensors, rank_data, rank,
full_nvlink)
fully_connected)
def all_reduce(fa: int, inp: torch.Tensor, out: torch.Tensor, reg_buffer: int,
@ -1369,6 +1369,18 @@ def register_graph_buffers(fa: int, handles: list[list[int]],
torch.ops._C_custom_ar.register_graph_buffers(fa, handles, offsets)
def allocate_shared_buffer_and_handle(size: int) -> tuple[int, torch.Tensor]:
return torch.ops._C_custom_ar.allocate_shared_buffer_and_handle(size)
def open_mem_handle(mem_handle: torch.Tensor) -> int:
return torch.ops._C_custom_ar.open_mem_handle(mem_handle)
def free_shared_buffer(ptr: int) -> None:
torch.ops._C_custom_ar.free_shared_buffer(ptr)
def get_flash_mla_metadata(
cache_seqlens: torch.Tensor,
num_heads_per_head_k: int,

vllm/config.py

@ -1606,11 +1606,13 @@ class ParallelConfig:
if self.use_ray:
from vllm.executor import ray_utils
ray_utils.assert_ray_available()
if current_platform.is_rocm():
device_capability = current_platform.get_device_capability()
if (current_platform.is_rocm() and device_capability is not None
and device_capability < (9, 4)):
self.disable_custom_all_reduce = True
logger.info(
"Disabled the custom all-reduce kernel because it is not "
"supported on AMD GPUs.")
"supported on AMD GPUs older than MI300X.")
if self.ray_workers_use_nsight and not self.use_ray:
raise ValueError("Unable to use nsight profiling unless workers "
"run with Ray.")

vllm/distributed/device_communicators/custom_all_reduce.py

@ -1,6 +1,5 @@
# SPDX-License-Identifier: Apache-2.0
import ctypes
from contextlib import contextmanager
from typing import List, Optional, Union
@ -10,7 +9,6 @@ from torch.distributed import ProcessGroup
import vllm.envs as envs
from vllm import _custom_ops as ops
from vllm.distributed.device_communicators.cuda_wrapper import CudaRTLibrary
from vllm.distributed.device_communicators.custom_all_reduce_utils import (
gpu_p2p_access_check)
from vllm.distributed.parallel_state import in_the_same_node_as
@ -22,7 +20,7 @@ try:
ops.meta_size()
custom_ar = True
except Exception:
# For AMD GPUs and CPUs
# For CPUs
custom_ar = False
logger = init_logger(__name__)
@ -71,7 +69,9 @@ class CustomAllreduce:
if not custom_ar:
# disable because of missing custom allreduce library
# e.g. in a non-cuda environment
# e.g. in a non-GPU environment
logger.info("Custom allreduce is disabled because "
"of missing custom allreduce library")
return
self.group = group
@ -129,11 +129,10 @@ class CustomAllreduce:
# test nvlink/xgmi connectivity first, this will filter out most of the
# cases where custom allreduce is not supported
# this checks hardware and driver support for NVLink/xGMI
assert current_platform.is_cuda()
from vllm.platforms.cuda import CudaPlatform
cuda_platform: CudaPlatform = current_platform
full_nvlink = cuda_platform.is_full_nvlink(physical_device_ids)
if world_size > 2 and not full_nvlink:
assert current_platform.is_cuda_alike()
fully_connected = current_platform.is_fully_connected(
physical_device_ids)
if world_size > 2 and not fully_connected:
logger.warning(
"Custom allreduce is disabled because it's not supported on"
" more than two PCIe-only GPUs. To silence this warning, "
@ -142,7 +141,8 @@ class CustomAllreduce:
# test P2P capability, this checks software/cudaruntime support
# this is expensive to compute at the first time
# then we cache the result
if not _can_p2p(rank, world_size):
# On AMD GPUs, P2P is always enabled between XGMI-connected GPUs
if not current_platform.is_rocm() and not _can_p2p(rank, world_size):
logger.warning(
"Custom allreduce is disabled because your platform lacks "
"GPU P2P capability or P2P test failed. To silence this "
@ -154,7 +154,8 @@ class CustomAllreduce:
# Metadata consists of two parts: metadata for synchronization and a
# temporary buffer for storing intermediate allreduce results.
self.meta_ptrs = self.create_shared_buffer(ops.meta_size() + max_size,
group=group)
group=group,
uncached=True)
# This is a pre-registered IPC buffer. In eager mode, input tensors
# are first copied into this buffer before allreduce is performed
self.buffer_ptrs = self.create_shared_buffer(max_size, group=group)
@ -169,46 +170,11 @@ class CustomAllreduce:
self.max_size = max_size
self.rank = rank
self.world_size = world_size
self.full_nvlink = full_nvlink
self.fully_connected = fully_connected
self._ptr = ops.init_custom_ar(self.meta_ptrs, self.rank_data, rank,
self.full_nvlink)
self.fully_connected)
ops.register_buffer(self._ptr, self.buffer_ptrs)
@staticmethod
def create_shared_buffer(
size_in_bytes: int,
group: Optional[ProcessGroup] = None) -> List[int]:
"""
Creates a shared buffer and returns a list of pointers
representing the buffer on all processes in the group.
"""
lib = CudaRTLibrary()
pointer = lib.cudaMalloc(size_in_bytes)
handle = lib.cudaIpcGetMemHandle(pointer)
world_size = dist.get_world_size(group=group)
rank = dist.get_rank(group=group)
handles = [None] * world_size
dist.all_gather_object(handles, handle, group=group)
pointers: List[int] = []
for i, h in enumerate(handles):
if i == rank:
pointers.append(pointer.value) # type: ignore
else:
pointers.append(
lib.cudaIpcOpenMemHandle(h).value) # type: ignore
return pointers
@staticmethod
def free_shared_buffer(pointers: List[int],
group: Optional[ProcessGroup] = None,
rank: Optional[int] = None) -> None:
if rank is None:
rank = dist.get_rank(group=group)
lib = CudaRTLibrary()
lib.cudaFree(ctypes.c_void_p(pointers[rank]))
@contextmanager
def capture(self):
"""
@ -255,7 +221,7 @@ class CustomAllreduce:
return False
# for 4 or more non NVLink-capable GPUs, custom allreduce provides
# little performance improvement over NCCL.
if self.world_size == 2 or self.full_nvlink:
if self.world_size == 2 or self.fully_connected:
return inp_size < self.max_size
return False
@ -306,3 +272,30 @@ class CustomAllreduce:
def __del__(self):
self.close()
@staticmethod
def create_shared_buffer(size_in_bytes: int,
group: Optional[ProcessGroup] = None,
uncached: bool = False) -> List[int]:
"""
Creates a shared buffer and returns a list of pointers
representing the buffer on all processes in the group.
"""
# NOTE: `uncached` is not consumed here; on ROCm the C++ side always
# allocates uncached memory (see allocate_shared_buffer_and_handle).
pointer, handle = ops.allocate_shared_buffer_and_handle(size_in_bytes)
world_size = dist.get_world_size(group=group)
rank = dist.get_rank(group=group)
handles = [None] * world_size
dist.all_gather_object(handles, handle, group=group)
pointers: List[int] = []
for i, h in enumerate(handles):
if i == rank:
pointers.append(pointer) # type: ignore
else:
pointers.append(ops.open_mem_handle(h))
return pointers
@staticmethod
def free_shared_buffer(pointers: List[int],
group: Optional[ProcessGroup] = None,
rank: Optional[int] = 0) -> None:
if rank is None:
rank = dist.get_rank(group=group)
ops.free_shared_buffer(pointers[rank])

vllm/platforms/cuda.py

@ -101,7 +101,7 @@ class CudaPlatformBase(Platform):
return True
@classmethod
def is_full_nvlink(cls, device_ids: List[int]) -> bool:
def is_fully_connected(cls, device_ids: List[int]) -> bool:
raise NotImplementedError
@classmethod
@ -362,7 +362,7 @@ class NvmlCudaPlatform(CudaPlatformBase):
@classmethod
@with_nvml_context
def is_full_nvlink(cls, physical_device_ids: List[int]) -> bool:
def is_fully_connected(cls, physical_device_ids: List[int]) -> bool:
"""
query if the set of gpus are fully connected by nvlink (1 hop)
"""
@ -427,7 +427,7 @@ class NonNvmlCudaPlatform(CudaPlatformBase):
return device_props.total_memory
@classmethod
def is_full_nvlink(cls, physical_device_ids: List[int]) -> bool:
def is_fully_connected(cls, physical_device_ids: List[int]) -> bool:
logger.exception(
"NVLink detection not possible, as context support was"
" not found. Assuming no NVLink available.")

vllm/platforms/rocm.py

@ -20,8 +20,9 @@ else:
logger = init_logger(__name__)
try:
from amdsmi import (amdsmi_get_gpu_asic_info, amdsmi_get_processor_handles,
amdsmi_init, amdsmi_shut_down)
from amdsmi import (AmdSmiException, amdsmi_get_gpu_asic_info,
amdsmi_get_processor_handles, amdsmi_init,
amdsmi_shut_down, amdsmi_topo_get_link_type)
except ImportError as e:
logger.warning("Failed to import from amdsmi with %r", e)
@ -135,10 +136,36 @@ class RocmPlatform(Platform):
@classmethod
@lru_cache(maxsize=8)
def get_device_capability(cls, device_id: int = 0) -> DeviceCapability:
def get_device_capability(cls,
device_id: int = 0
) -> Optional[DeviceCapability]:
major, minor = torch.cuda.get_device_capability(device_id)
return DeviceCapability(major=major, minor=minor)
@staticmethod
@with_amdsmi_context
def is_fully_connected(physical_device_ids: List[int]) -> bool:
"""
Query if the set of gpus are fully connected by xgmi (1 hop)
"""
handles = [
amdsmi_get_processor_handles()[i] for i in physical_device_ids
]
for i, handle in enumerate(handles):
for j, peer_handle in enumerate(handles):
if i < j:
try:
link_type = amdsmi_topo_get_link_type(
handle, peer_handle)
# type is 2 for XGMI
if link_type["hops"] != 1 or link_type["type"] != 2:
return False
except AmdSmiException as error:
logger.error("AMD 1 hop XGMI detection failed.",
exc_info=error)
return False
return True
@classmethod
@with_amdsmi_context
@lru_cache(maxsize=8)