mirror of
https://github.com/vllm-project/vllm.git
synced 2025-10-20 14:53:52 +08:00
[Distributed] Add custom allreduce support for ROCM (#14125)
Signed-off-by: ilmarkov <imarkov@redhat.com> Co-authored-by: ilmarkov <imarkov@redhat.com>
This commit is contained in:
@ -242,6 +242,7 @@ set(VLLM_EXT_SRC
|
||||
"csrc/quantization/gguf/gguf_kernel.cu"
|
||||
"csrc/cuda_utils_kernels.cu"
|
||||
"csrc/prepare_inputs/advance_step.cu"
|
||||
"csrc/custom_all_reduce.cu"
|
||||
"csrc/torch_bindings.cpp")
|
||||
|
||||
if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
@ -283,7 +284,6 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
"csrc/mamba/causal_conv1d/causal_conv1d.cu"
|
||||
"csrc/quantization/aqlm/gemm_kernels.cu"
|
||||
"csrc/quantization/awq/gemm_kernels.cu"
|
||||
"csrc/custom_all_reduce.cu"
|
||||
"csrc/permute_cols.cu"
|
||||
"csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu"
|
||||
"csrc/quantization/fp4/nvfp4_quant_entry.cu"
|
||||
|
@ -12,7 +12,7 @@ static_assert(sizeof(void*) == sizeof(fptr_t));
|
||||
|
||||
fptr_t init_custom_ar(const std::vector<fptr_t>& fake_ipc_ptrs,
|
||||
torch::Tensor& rank_data, int64_t rank,
|
||||
bool full_nvlink) {
|
||||
bool fully_connected) {
|
||||
int world_size = fake_ipc_ptrs.size();
|
||||
if (world_size > 8)
|
||||
throw std::invalid_argument("world size > 8 is not supported");
|
||||
@ -27,7 +27,7 @@ fptr_t init_custom_ar(const std::vector<fptr_t>& fake_ipc_ptrs,
|
||||
}
|
||||
return (fptr_t) new vllm::CustomAllreduce(ipc_ptrs, rank_data.data_ptr(),
|
||||
rank_data.numel(), rank, world_size,
|
||||
full_nvlink);
|
||||
fully_connected);
|
||||
}
|
||||
|
||||
/**
|
||||
@ -142,3 +142,48 @@ void register_graph_buffers(fptr_t _fa,
|
||||
bytes.reserve(handles.size());
|
||||
fa->register_graph_buffers(bytes, offsets);
|
||||
}
|
||||
|
||||
std::tuple<fptr_t, torch::Tensor> allocate_shared_buffer_and_handle(
|
||||
int64_t size) {
|
||||
auto device_index = c10::cuda::current_device();
|
||||
at::DeviceGuard device_guard(at::Device(at::DeviceType::CUDA, device_index));
|
||||
void* buffer;
|
||||
cudaStreamCaptureMode mode = cudaStreamCaptureModeRelaxed;
|
||||
auto stream = c10::cuda::getCurrentCUDAStream().stream();
|
||||
AT_CUDA_CHECK(cudaThreadExchangeStreamCaptureMode(&mode));
|
||||
|
||||
// Allocate buffer
|
||||
#if defined(USE_ROCM)
|
||||
// data buffers need to be "uncached" for signal on MI200
|
||||
AT_CUDA_CHECK(
|
||||
hipExtMallocWithFlags((void**)&buffer, size, hipDeviceMallocUncached));
|
||||
#else
|
||||
AT_CUDA_CHECK(cudaMalloc((void**)&buffer, size));
|
||||
#endif
|
||||
AT_CUDA_CHECK(cudaMemsetAsync(buffer, 0, size, stream));
|
||||
AT_CUDA_CHECK(cudaStreamSynchronize(stream));
|
||||
AT_CUDA_CHECK(cudaThreadExchangeStreamCaptureMode(&mode));
|
||||
|
||||
// Create IPC memhandle for the allocated buffer.
|
||||
// Will use it in open_mem_handle.
|
||||
auto options =
|
||||
torch::TensorOptions().dtype(torch::kUInt8).device(torch::kCPU);
|
||||
auto handle =
|
||||
torch::empty({static_cast<int64_t>(sizeof(cudaIpcMemHandle_t))}, options);
|
||||
AT_CUDA_CHECK(
|
||||
cudaIpcGetMemHandle((cudaIpcMemHandle_t*)handle.data_ptr(), buffer));
|
||||
|
||||
return std::make_tuple(reinterpret_cast<fptr_t>(buffer), handle);
|
||||
}
|
||||
|
||||
fptr_t open_mem_handle(torch::Tensor& mem_handle) {
|
||||
void* ipc_ptr;
|
||||
AT_CUDA_CHECK(cudaIpcOpenMemHandle(
|
||||
(void**)&ipc_ptr, *((const cudaIpcMemHandle_t*)mem_handle.data_ptr()),
|
||||
cudaIpcMemLazyEnablePeerAccess));
|
||||
return reinterpret_cast<fptr_t>(ipc_ptr);
|
||||
}
|
||||
|
||||
void free_shared_buffer(fptr_t buffer) {
|
||||
AT_CUDA_CHECK(cudaFree(reinterpret_cast<void*>(buffer)));
|
||||
}
|
||||
|
@ -5,6 +5,10 @@
|
||||
#include <cuda_fp16.h>
|
||||
#include <cuda_runtime.h>
|
||||
|
||||
#if defined(USE_ROCM)
|
||||
typedef __hip_bfloat16 nv_bfloat16;
|
||||
#endif
|
||||
|
||||
#include <iostream>
|
||||
#include <array>
|
||||
#include <limits>
|
||||
@ -12,6 +16,7 @@
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
|
||||
namespace vllm {
|
||||
#define CUDACHECK(cmd) \
|
||||
do { \
|
||||
cudaError_t e = cmd; \
|
||||
@ -22,24 +27,37 @@
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
namespace vllm {
|
||||
|
||||
// Maximal number of blocks in allreduce kernel.
|
||||
constexpr int kMaxBlocks = 36;
|
||||
|
||||
// Default number of blocks in allreduce kernel.
|
||||
#ifndef USE_ROCM
|
||||
const int defaultBlockLimit = 36;
|
||||
CUpointer_attribute rangeStartAddrAttr = CU_POINTER_ATTRIBUTE_RANGE_START_ADDR;
|
||||
#else
|
||||
const int defaultBlockLimit = 16;
|
||||
hipPointer_attribute rangeStartAddrAttr =
|
||||
HIP_POINTER_ATTRIBUTE_RANGE_START_ADDR;
|
||||
#endif
|
||||
|
||||
// Counter may overflow, but it's fine since unsigned int overflow is
|
||||
// well-defined behavior.
|
||||
using FlagType = uint32_t;
|
||||
|
||||
// Two sets of peer counters are needed for two syncs: starting and ending an
|
||||
// operation. The reason is that it's possible for peer GPU block to arrive at
|
||||
// the second sync point while the current GPU block haven't passed the first
|
||||
// sync point. Thus, peer GPU may write counter+1 while current GPU is busy
|
||||
// waiting for counter. We use alternating counter array to avoid this
|
||||
// possibility.
|
||||
struct Signal {
|
||||
alignas(128) FlagType self_counter[kMaxBlocks][8];
|
||||
// Two sets of peer counters are needed for two syncs. The reason is that
|
||||
// it's possible for peer GPU block to arrive at the second sync point while
|
||||
// the current GPU block haven't passed the first sync point. Thus, peer GPU
|
||||
// may write counter+1 while current GPU is busy waiting for counter. We use
|
||||
// alternating counter array to avoid this possibility.
|
||||
alignas(128) FlagType peer_counter[2][kMaxBlocks][8];
|
||||
alignas(128) FlagType start[kMaxBlocks][8];
|
||||
alignas(128) FlagType end[kMaxBlocks][8];
|
||||
alignas(128) FlagType _flag[kMaxBlocks]; // incremental flags for each rank
|
||||
};
|
||||
|
||||
struct __align__(16) RankData {
|
||||
const void* __restrict__ ptrs[8];
|
||||
const void* ptrs[8];
|
||||
};
|
||||
|
||||
struct __align__(16) RankSignals {
|
||||
@ -134,27 +152,29 @@ DINLINE O downcast(array_t<float, O::size> val) {
|
||||
}
|
||||
}
|
||||
|
||||
#if !defined(USE_ROCM)
|
||||
|
||||
static DINLINE void st_flag_release(FlagType* flag_addr, FlagType flag) {
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700
|
||||
asm volatile("st.release.sys.global.u32 [%1], %0;" ::"r"(flag),
|
||||
"l"(flag_addr));
|
||||
#else
|
||||
#else
|
||||
asm volatile("membar.sys; st.volatile.global.u32 [%1], %0;" ::"r"(flag),
|
||||
"l"(flag_addr));
|
||||
#endif
|
||||
#endif
|
||||
}
|
||||
|
||||
static DINLINE FlagType ld_flag_acquire(FlagType* flag_addr) {
|
||||
FlagType flag;
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700
|
||||
asm volatile("ld.acquire.sys.global.u32 %0, [%1];"
|
||||
: "=r"(flag)
|
||||
: "l"(flag_addr));
|
||||
#else
|
||||
#else
|
||||
asm volatile("ld.volatile.global.u32 %0, [%1]; membar.gl;"
|
||||
: "=r"(flag)
|
||||
: "l"(flag_addr));
|
||||
#endif
|
||||
#endif
|
||||
return flag;
|
||||
}
|
||||
|
||||
@ -170,37 +190,99 @@ static DINLINE FlagType ld_flag_volatile(FlagType* flag_addr) {
|
||||
return flag;
|
||||
}
|
||||
|
||||
// is_start: whether this is the very first synchronization barrier.
|
||||
// need_fence: whether a memory fence is needed. If true, a release-acquire
|
||||
// semantic is used to enforce memory access order before and after this
|
||||
// barrier.
|
||||
template <int ngpus, bool is_start, bool need_fence = false>
|
||||
DINLINE void multi_gpu_barrier(const RankSignals& sg, Signal* self_sg,
|
||||
int rank) {
|
||||
if constexpr (!is_start) __syncthreads();
|
||||
static_assert(
|
||||
!(is_start && need_fence)); // Start barrier shouldn't need fence.
|
||||
// This function is meant to be used as the first synchronization in the all
|
||||
// reduce kernel. Thus, it doesn't need to make any visibility guarantees for
|
||||
// prior memory accesses. Note: volatile writes will not be reordered against
|
||||
// other volatile writes.
|
||||
template <int ngpus>
|
||||
DINLINE void barrier_at_start(const RankSignals& sg, Signal* self_sg,
|
||||
int rank) {
|
||||
uint32_t flag = self_sg->_flag[blockIdx.x] + 1;
|
||||
if (threadIdx.x < ngpus) {
|
||||
// Increment the counter. Technically we only need one counter, but we use
|
||||
// multiple per block to eliminate the need to share the counter via smem.
|
||||
auto val = self_sg->self_counter[blockIdx.x][threadIdx.x] += 1;
|
||||
auto peer_counter_ptr = &sg.signals[threadIdx.x]->start[blockIdx.x][rank];
|
||||
auto self_counter_ptr = &self_sg->start[blockIdx.x][threadIdx.x];
|
||||
// Write the expected counter value to peer and wait for correct value
|
||||
// from peer.
|
||||
st_flag_volatile(peer_counter_ptr, flag);
|
||||
while (ld_flag_volatile(self_counter_ptr) != flag);
|
||||
}
|
||||
__syncthreads();
|
||||
// use one thread to update flag
|
||||
if (threadIdx.x == 0) self_sg->_flag[blockIdx.x] = flag;
|
||||
}
|
||||
|
||||
// This function is meant to be used as the second or the final
|
||||
// synchronization barrier in the all reduce kernel. If it's the final
|
||||
// synchronization barrier, we don't need to make any visibility guarantees
|
||||
// for prior memory accesses.
|
||||
template <int ngpus, bool final_sync = false>
|
||||
DINLINE void barrier_at_end(const RankSignals& sg, Signal* self_sg, int rank) {
|
||||
__syncthreads();
|
||||
uint32_t flag = self_sg->_flag[blockIdx.x] + 1;
|
||||
if (threadIdx.x < ngpus) {
|
||||
auto peer_counter_ptr = &sg.signals[threadIdx.x]->end[blockIdx.x][rank];
|
||||
auto self_counter_ptr = &self_sg->end[blockIdx.x][threadIdx.x];
|
||||
// Write the expected counter value to peer and wait for correct value from
|
||||
// peer.
|
||||
auto peer_counter_ptr =
|
||||
&sg.signals[threadIdx.x]->peer_counter[val % 2][blockIdx.x][rank];
|
||||
auto self_counter_ptr =
|
||||
&self_sg->peer_counter[val % 2][blockIdx.x][threadIdx.x];
|
||||
if constexpr (need_fence) {
|
||||
st_flag_release(peer_counter_ptr, val);
|
||||
while (ld_flag_acquire(self_counter_ptr) != val);
|
||||
if constexpr (!final_sync) {
|
||||
st_flag_release(peer_counter_ptr, flag);
|
||||
while (ld_flag_acquire(self_counter_ptr) != flag);
|
||||
} else {
|
||||
st_flag_volatile(peer_counter_ptr, val);
|
||||
while (ld_flag_volatile(self_counter_ptr) != val);
|
||||
st_flag_volatile(peer_counter_ptr, flag);
|
||||
while (ld_flag_volatile(self_counter_ptr) != flag);
|
||||
}
|
||||
}
|
||||
if constexpr (is_start || need_fence) __syncthreads();
|
||||
if constexpr (!final_sync) __syncthreads();
|
||||
|
||||
// use one thread to update flag
|
||||
if (threadIdx.x == 0) self_sg->_flag[blockIdx.x] = flag;
|
||||
}
|
||||
|
||||
#else
|
||||
|
||||
template <int ngpus>
|
||||
DINLINE void barrier_at_start(const RankSignals& sg, Signal* self_sg,
|
||||
int rank) {
|
||||
uint32_t flag = self_sg->_flag[blockIdx.x] + 1;
|
||||
if (threadIdx.x < ngpus) {
|
||||
// simultaneously write to the corresponding flag of all ranks.
|
||||
// Latency = 1 p2p write
|
||||
__scoped_atomic_store_n(&sg.signals[threadIdx.x]->start[blockIdx.x][rank],
|
||||
flag, __ATOMIC_RELAXED, __MEMORY_SCOPE_SYSTEM);
|
||||
// wait until we got true from all ranks
|
||||
while (__scoped_atomic_load_n(&self_sg->start[blockIdx.x][threadIdx.x],
|
||||
__ATOMIC_RELAXED,
|
||||
__MEMORY_SCOPE_DEVICE) < flag);
|
||||
}
|
||||
__syncthreads();
|
||||
// use one thread to update flag
|
||||
if (threadIdx.x == 0) self_sg->_flag[blockIdx.x] = flag;
|
||||
}
|
||||
|
||||
template <int ngpus, bool final_sync = false>
|
||||
DINLINE void barrier_at_end(const RankSignals& sg, Signal* self_sg, int rank) {
|
||||
__syncthreads();
|
||||
uint32_t flag = self_sg->_flag[blockIdx.x] + 1;
|
||||
if (threadIdx.x < ngpus) {
|
||||
// simultaneously write to the corresponding flag of all ranks.
|
||||
// Latency = 1 p2p write
|
||||
__scoped_atomic_store_n(&sg.signals[threadIdx.x]->end[blockIdx.x][rank],
|
||||
flag,
|
||||
final_sync ? __ATOMIC_RELAXED : __ATOMIC_RELEASE,
|
||||
__MEMORY_SCOPE_SYSTEM);
|
||||
// wait until we got true from all ranks
|
||||
while (
|
||||
__scoped_atomic_load_n(&self_sg->end[blockIdx.x][threadIdx.x],
|
||||
final_sync ? __ATOMIC_RELAXED : __ATOMIC_ACQUIRE,
|
||||
__MEMORY_SCOPE_DEVICE) < flag);
|
||||
}
|
||||
if constexpr (!final_sync) __syncthreads();
|
||||
// use one thread to update flag
|
||||
if (threadIdx.x == 0) self_sg->_flag[blockIdx.x] = flag;
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
template <typename P, int ngpus, typename A>
|
||||
DINLINE P packed_reduce(const P* ptrs[], int idx) {
|
||||
A tmp = upcast(ptrs[0][idx]);
|
||||
@ -220,13 +302,13 @@ __global__ void __launch_bounds__(512, 1)
|
||||
// note: we don't reorder the address so the accumulation order is the same
|
||||
// for all ranks, ensuring bitwise identical results
|
||||
auto dp = *_dp;
|
||||
multi_gpu_barrier<ngpus, true>(sg, self_sg, rank);
|
||||
barrier_at_start<ngpus>(sg, self_sg, rank);
|
||||
// do the actual reduction
|
||||
for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size;
|
||||
idx += gridDim.x * blockDim.x) {
|
||||
((P*)result)[idx] = packed_reduce<P, ngpus, A>((const P**)&dp.ptrs[0], idx);
|
||||
}
|
||||
multi_gpu_barrier<ngpus, false>(sg, self_sg, rank);
|
||||
barrier_at_end<ngpus, true>(sg, self_sg, rank);
|
||||
}
|
||||
|
||||
template <typename P>
|
||||
@ -255,18 +337,20 @@ __global__ void __launch_bounds__(512, 1)
|
||||
tmps[i] = get_tmp_buf<P>(sg.signals[target]);
|
||||
}
|
||||
auto tmp_out = tmps[0];
|
||||
multi_gpu_barrier<ngpus, true>(sg, self_sg, rank);
|
||||
barrier_at_start<ngpus>(sg, self_sg, rank);
|
||||
|
||||
// stage 1: reduce scatter
|
||||
for (int idx = start + tid; idx < end; idx += stride) {
|
||||
tmp_out[idx - start] = packed_reduce<P, ngpus, A>(ptrs, idx);
|
||||
}
|
||||
multi_gpu_barrier<ngpus, false, true>(sg, self_sg, rank);
|
||||
barrier_at_end<ngpus>(sg, self_sg, rank);
|
||||
|
||||
// stage 2: allgather. Note: it's important to match the tid between
|
||||
// the two stages, because visibility across devices is only guaranteed
|
||||
// between threads that have the same tid. If thread i computes the sum of
|
||||
// start + i in the first stage, then thread i also gathers start + i from all
|
||||
// ranks.
|
||||
// start + i in the first stage, then thread i also gathers start + i from
|
||||
// all ranks.
|
||||
|
||||
for (int idx = tid; idx < largest_part; idx += stride) {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < ngpus; i++) {
|
||||
@ -287,21 +371,22 @@ class CustomAllreduce {
|
||||
public:
|
||||
int rank_;
|
||||
int world_size_;
|
||||
bool full_nvlink_;
|
||||
// Full NVLink or xGMI connection between GPUs.
|
||||
bool fully_connected_;
|
||||
|
||||
RankSignals sg_;
|
||||
// Stores an map from a pointer to its peer pointters from all ranks.
|
||||
// Stores an map from a pointer to its peer pointers from all ranks.
|
||||
std::unordered_map<void*, RankData*> buffers_;
|
||||
Signal* self_sg_;
|
||||
|
||||
// Stores rank data from all ranks. This is mainly for cuda graph purposes.
|
||||
// For cuda graph to work, all kernel arguments must be fixed during graph
|
||||
// capture time. However, the peer pointers are not known during graph capture
|
||||
// time. Therefore, during capture, we increment the rank data pointer and use
|
||||
// that as the argument to the kernel. The kernel arguments are stored in
|
||||
// graph_unreg_buffers_. The actual peer pointers will be filled in at the
|
||||
// memory pointed to by the pointers in graph_unreg_buffers_ when
|
||||
// the IPC handles are exchanged between ranks.
|
||||
// capture time. However, the peer pointers are not known during graph
|
||||
// capture time. Therefore, during capture, we increment the rank data
|
||||
// pointer and use that as the argument to the kernel. The kernel arguments
|
||||
// are stored in graph_unreg_buffers_. The actual peer pointers will be
|
||||
// filled in at the memory pointed to by the pointers in
|
||||
// graph_unreg_buffers_ when the IPC handles are exchanged between ranks.
|
||||
//
|
||||
// The overall process looks like this:
|
||||
// 1. Graph capture.
|
||||
@ -319,17 +404,18 @@ class CustomAllreduce {
|
||||
* Signals are an array of ipc-enabled buffers from all ranks.
|
||||
* For each of the buffer, the layout is as follows:
|
||||
* | -- sizeof(Signal) -- | ------ a few MB ----- |
|
||||
* The first section is for allreduce synchronization, and the second section
|
||||
* is for storing the intermediate results required by some allreduce algos.
|
||||
* The first section is for allreduce synchronization, and the second
|
||||
* section is for storing the intermediate results required by some
|
||||
* allreduce algos.
|
||||
*
|
||||
* Note: this class does not own any device memory. Any required buffers
|
||||
* are passed in from the constructor.
|
||||
*/
|
||||
CustomAllreduce(Signal** signals, void* rank_data, size_t rank_data_sz,
|
||||
int rank, int world_size, bool full_nvlink = true)
|
||||
int rank, int world_size, bool fully_connected = true)
|
||||
: rank_(rank),
|
||||
world_size_(world_size),
|
||||
full_nvlink_(full_nvlink),
|
||||
fully_connected_(fully_connected),
|
||||
self_sg_(signals[rank]),
|
||||
d_rank_data_base_(reinterpret_cast<RankData*>(rank_data)),
|
||||
d_rank_data_end_(d_rank_data_base_ + rank_data_sz / sizeof(RankData)) {
|
||||
@ -361,8 +447,7 @@ class CustomAllreduce {
|
||||
void* base_ptr;
|
||||
// note: must share the base address of each allocation, or we get wrong
|
||||
// address
|
||||
if (cuPointerGetAttribute(&base_ptr,
|
||||
CU_POINTER_ATTRIBUTE_RANGE_START_ADDR,
|
||||
if (cuPointerGetAttribute(&base_ptr, rangeStartAddrAttr,
|
||||
(CUdeviceptr)ptr) != CUDA_SUCCESS)
|
||||
throw std::runtime_error("failed to get pointer attr");
|
||||
CUDACHECK(cudaIpcGetMemHandle(
|
||||
@ -396,11 +481,11 @@ class CustomAllreduce {
|
||||
|
||||
// Note: when registering graph buffers, we intentionally choose to not
|
||||
// deduplicate the addresses. That means if the allocator reuses some
|
||||
// addresses, they will be registered again. This is to account for the remote
|
||||
// possibility of different allocation patterns between ranks. For example,
|
||||
// rank 1 may get the same input address for the second allreduce, but rank 2
|
||||
// got a different address. IPC handles have internal reference counting
|
||||
// mechanism so overhead should be small.
|
||||
// addresses, they will be registered again. This is to account for the
|
||||
// remote possibility of different allocation patterns between ranks. For
|
||||
// example, rank 1 may get the same input address for the second allreduce,
|
||||
// but rank 2 got a different address. IPC handles have internal reference
|
||||
// counting mechanism so overhead should be small.
|
||||
void register_graph_buffers(
|
||||
const std::vector<std::string>& handles,
|
||||
const std::vector<std::vector<int64_t>>& offsets) {
|
||||
@ -431,15 +516,15 @@ class CustomAllreduce {
|
||||
/**
|
||||
* Performs allreduce, assuming input has already been registered.
|
||||
*
|
||||
* Block and grid default configs are results after careful grid search. Using
|
||||
* 36 blocks give the best or close to the best runtime on the devices I
|
||||
* tried: A100, A10, A30, T4, V100. You'll notice that NCCL kernels also only
|
||||
* take a small amount of SMs. Not quite sure the underlying reason, but my
|
||||
* guess is that too many SMs will cause contention on NVLink bus.
|
||||
* Block and grid default configs are results after careful grid search.
|
||||
* Using 36 blocks give the best or close to the best runtime on the devices
|
||||
* I tried: A100, A10, A30, T4, V100. You'll notice that NCCL kernels also
|
||||
* only take a small amount of SMs. Not quite sure the underlying reason,
|
||||
* but my guess is that too many SMs will cause contention on NVLink bus.
|
||||
*/
|
||||
template <typename T>
|
||||
void allreduce(cudaStream_t stream, T* input, T* output, int size,
|
||||
int threads = 512, int block_limit = 36) {
|
||||
int threads = 512, int block_limit = defaultBlockLimit) {
|
||||
auto d = packed_t<T>::P::size;
|
||||
if (size % d != 0)
|
||||
throw std::runtime_error(
|
||||
@ -473,13 +558,11 @@ class CustomAllreduce {
|
||||
#define KL(ngpus, name) \
|
||||
name<T, ngpus><<<blocks, threads, 0, stream>>>(ptrs, sg_, self_sg_, output, \
|
||||
rank_, size);
|
||||
// TODO(hanzhi713): Threshold is different for A100 and H100.
|
||||
// Add per device threshold.
|
||||
#define REDUCE_CASE(ngpus) \
|
||||
case ngpus: { \
|
||||
if (world_size_ == 2) { \
|
||||
KL(ngpus, cross_device_reduce_1stage); \
|
||||
} else if (full_nvlink_) { \
|
||||
} else if (fully_connected_) { \
|
||||
if ((world_size_ <= 4 && bytes < 512 * 1024) || \
|
||||
(world_size_ <= 8 && bytes < 256 * 1024)) { \
|
||||
KL(ngpus, cross_device_reduce_1stage); \
|
||||
@ -497,7 +580,8 @@ class CustomAllreduce {
|
||||
REDUCE_CASE(8)
|
||||
default:
|
||||
throw std::runtime_error(
|
||||
"custom allreduce only supports num gpus in (2,4,6,8). Actual num "
|
||||
"custom allreduce only supports num gpus in (2,4,6,8). Actual "
|
||||
"num "
|
||||
"gpus = " +
|
||||
std::to_string(world_size_));
|
||||
}
|
||||
@ -511,10 +595,11 @@ class CustomAllreduce {
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* To inspect PTX/SASS, copy paste this header file to compiler explorer and add
|
||||
a template instantiation:
|
||||
* To inspect PTX/SASS, copy paste this header file to compiler explorer and
|
||||
add a template instantiation:
|
||||
* template void vllm::CustomAllreduce::allreduce<half>(cudaStream_t, half *,
|
||||
half *, int, int, int);
|
||||
*/
|
||||
} // namespace vllm
|
||||
} // namespace vllm
|
@ -1,9 +1,9 @@
|
||||
/**
|
||||
* This is a standalone test for custom allreduce.
|
||||
* To compile, make sure you have MPI and NCCL installed in your system.
|
||||
* export MPI_HOME=xxx
|
||||
* export MPI_HOME=XXX
|
||||
* nvcc -O2 -arch=native -std=c++17 custom_all_reduce_test.cu -o
|
||||
* custom_all_reduce_test -lnccl -I${MPI_HOME} -lmpi
|
||||
* custom_all_reduce_test -lnccl -I${MPI_HOME}/include -lmpi
|
||||
*
|
||||
* Warning: this C++ test is not designed to be very readable and was used
|
||||
* during the rapid prototyping process.
|
||||
@ -22,7 +22,15 @@
|
||||
#include "cuda_profiler_api.h"
|
||||
#include "custom_all_reduce.cuh"
|
||||
#include "mpi.h"
|
||||
#include "nccl.h"
|
||||
#ifdef USE_ROCM
|
||||
#include <hip/hip_bf16.h>
|
||||
typedef __hip_bfloat16 nv_bfloat16;
|
||||
#include "rccl/rccl.h"
|
||||
#include "custom_all_reduce_hip.cuh"
|
||||
#else
|
||||
#include "nccl.h"
|
||||
#include "custom_all_reduce.cuh"
|
||||
#endif
|
||||
|
||||
#define MPICHECK(cmd) \
|
||||
do { \
|
||||
@ -43,16 +51,29 @@
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
#ifdef USE_ROCM
|
||||
__global__ void dummy_kernel() {
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700
|
||||
for (int i = 0; i < 100; i++) {
|
||||
uint64_t start = wall_clock64();
|
||||
uint64_t cycles_elapsed;
|
||||
do {
|
||||
cycles_elapsed = wall_clock64() - start;
|
||||
} while (cycles_elapsed < 100);
|
||||
}
|
||||
for (int i = 0; i < 100; i++) __nanosleep(1000000); // 100ms
|
||||
}
|
||||
#else
|
||||
__global__ void dummy_kernel() {
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700
|
||||
for (int i = 0; i < 100; i++) __nanosleep(1000000); // 100ms
|
||||
#else
|
||||
for (int i = 0; i < 100; i++) {
|
||||
long long int start = clock64();
|
||||
while (clock64() - start < 150000000); // approximately 98.4ms on P40
|
||||
}
|
||||
#endif
|
||||
#endif
|
||||
}
|
||||
#endif
|
||||
|
||||
template <typename T>
|
||||
__global__ void set_data(T* data, int size, int myRank) {
|
||||
@ -121,8 +142,14 @@ void run(int myRank, int nRanks, ncclComm_t& comm, int threads, int block_limit,
|
||||
* registration, they are allocated and registered together in the test for
|
||||
* convenience.
|
||||
*/
|
||||
#ifdef USE_ROCM
|
||||
CUDACHECK(hipExtMallocWithFlags(
|
||||
(void**)&buffer, 2 * data_size * sizeof(T) + sizeof(vllm::Signal),
|
||||
hipDeviceMallocUncached));
|
||||
#else
|
||||
CUDACHECK(
|
||||
cudaMalloc(&buffer, 2 * data_size * sizeof(T) + sizeof(vllm::Signal)));
|
||||
#endif
|
||||
CUDACHECK(
|
||||
cudaMemset(buffer, 0, 2 * data_size * sizeof(T) + sizeof(vllm::Signal)));
|
||||
CUDACHECK(cudaMalloc(&self_data_copy, data_size * sizeof(T)));
|
||||
@ -311,13 +338,18 @@ int main(int argc, char** argv) {
|
||||
|
||||
bool performance_test = true;
|
||||
cudaProfilerStart();
|
||||
// Uncomment to scan through different block size configs.
|
||||
// for (int threads : {256, 512, 1024}) {
|
||||
// for (int block_limit = 16; block_limit < 112; block_limit += 4) {
|
||||
// run<half>(myRank, nRanks, comm, threads, block_limit, 1024 * 1024,
|
||||
// performance_test);
|
||||
// }
|
||||
// }
|
||||
// Uncomment to scan through different block size configs.
|
||||
// for (int threads : {256, 512, 1024}) {
|
||||
// for (int block_limit = 16; block_limit < 112; block_limit += 4) {
|
||||
// run<half>(myRank, nRanks, comm, threads, block_limit, 1024 * 1024,
|
||||
// performance_test);
|
||||
// }
|
||||
// }
|
||||
#ifdef USE_ROCM
|
||||
const int block_limit = 16;
|
||||
#else
|
||||
const int block_limit = 36;
|
||||
#endif
|
||||
// Scan through different sizes to test performance.
|
||||
for (int sz = 512; sz <= (8 << 20); sz *= 2) {
|
||||
run<half>(myRank, nRanks, comm, 512, 36, sz + 8 * 47, performance_test);
|
||||
@ -326,4 +358,4 @@ int main(int argc, char** argv) {
|
||||
cudaProfilerStop();
|
||||
MPICHECK(MPI_Finalize());
|
||||
return EXIT_SUCCESS;
|
||||
}
|
||||
}
|
@ -267,10 +267,10 @@ void causal_conv1d_fwd(const at::Tensor& x, const at::Tensor& weight,
|
||||
const std::optional<at::Tensor>& has_initial_state,
|
||||
bool silu_activation, int64_t pad_slot_id);
|
||||
|
||||
#ifndef USE_ROCM
|
||||
using fptr_t = int64_t;
|
||||
fptr_t init_custom_ar(const std::vector<int64_t>& fake_ipc_ptrs,
|
||||
torch::Tensor& rank_data, int64_t rank, bool full_nvlink);
|
||||
torch::Tensor& rank_data, int64_t rank,
|
||||
bool fully_connected);
|
||||
void all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out,
|
||||
fptr_t reg_buffer, int64_t reg_buffer_sz_bytes);
|
||||
void dispose(fptr_t _fa);
|
||||
@ -281,4 +281,7 @@ get_graph_buffer_ipc_meta(fptr_t _fa);
|
||||
void register_graph_buffers(fptr_t _fa,
|
||||
const std::vector<std::vector<int64_t>>& handles,
|
||||
const std::vector<std::vector<int64_t>>& offsets);
|
||||
#endif
|
||||
std::tuple<int64_t, torch::Tensor> allocate_shared_buffer_and_handle(
|
||||
int64_t size);
|
||||
int64_t open_mem_handle(torch::Tensor& mem_handle);
|
||||
void free_shared_buffer(int64_t buffer);
|
||||
|
@ -614,12 +614,11 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cuda_utils), cuda_utils) {
|
||||
&get_max_shared_memory_per_block_device_attribute);
|
||||
}
|
||||
|
||||
#ifndef USE_ROCM
|
||||
TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _custom_ar), custom_ar) {
|
||||
// Custom all-reduce kernels
|
||||
custom_ar.def(
|
||||
"init_custom_ar(int[] ipc_tensors, Tensor rank_data, "
|
||||
"int rank, bool full_nvlink) -> int");
|
||||
"int rank, bool fully_connected) -> int");
|
||||
custom_ar.impl("init_custom_ar", torch::kCUDA, &init_custom_ar);
|
||||
custom_ar.def(
|
||||
"all_reduce(int fa, Tensor inp, Tensor! out, int reg_buffer, "
|
||||
@ -632,7 +631,13 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _custom_ar), custom_ar) {
|
||||
custom_ar.def("register_buffer", ®ister_buffer);
|
||||
custom_ar.def("get_graph_buffer_ipc_meta", &get_graph_buffer_ipc_meta);
|
||||
custom_ar.def("register_graph_buffers", ®ister_graph_buffers);
|
||||
|
||||
custom_ar.def("allocate_shared_buffer_and_handle",
|
||||
&allocate_shared_buffer_and_handle);
|
||||
custom_ar.def("open_mem_handle(Tensor mem_handle) -> int", &open_mem_handle);
|
||||
custom_ar.impl("open_mem_handle", torch::kCPU, &open_mem_handle);
|
||||
|
||||
custom_ar.def("free_shared_buffer", &free_shared_buffer);
|
||||
}
|
||||
#endif
|
||||
|
||||
REGISTER_EXTENSION(TORCH_EXTENSION_NAME)
|
||||
|
@ -106,7 +106,7 @@ def eager_allreduce(
|
||||
# communicate independently
|
||||
num_communication = rank // tp_size + 1
|
||||
sz = 1024
|
||||
fa = get_tp_group().ca_comm
|
||||
fa = get_tp_group().device_communicator.ca_comm
|
||||
inp = torch.ones(sz, dtype=torch.float32, device=device)
|
||||
out = inp
|
||||
for _ in range(num_communication):
|
||||
|
@ -612,7 +612,16 @@ def multi_process_parallel(
|
||||
# as compared to multiprocessing.
|
||||
# NOTE: We need to set working_dir for distributed tests,
|
||||
# otherwise we may get import errors on ray workers
|
||||
ray.init(runtime_env={"working_dir": VLLM_PATH})
|
||||
# NOTE: Force ray not to use gitignore file as excluding, otherwise
|
||||
# it will not move .so files to working dir.
|
||||
# So we have to manually add some of large directories
|
||||
os.environ["RAY_RUNTIME_ENV_IGNORE_GITIGNORE"] = "1"
|
||||
ray.init(
|
||||
runtime_env={
|
||||
"working_dir": VLLM_PATH,
|
||||
"excludes":
|
||||
["build", ".git", "cmake-build-*", "shellcheck", "dist"]
|
||||
})
|
||||
|
||||
distributed_init_port = get_open_port()
|
||||
refs = []
|
||||
|
@ -1337,9 +1337,9 @@ def get_max_shared_memory_per_block_device_attribute(device: int) -> int:
|
||||
|
||||
# custom ar
|
||||
def init_custom_ar(ipc_tensors: list[torch.Tensor], rank_data: torch.Tensor,
|
||||
rank: int, full_nvlink: bool) -> int:
|
||||
rank: int, fully_connected: bool) -> int:
|
||||
return torch.ops._C_custom_ar.init_custom_ar(ipc_tensors, rank_data, rank,
|
||||
full_nvlink)
|
||||
fully_connected)
|
||||
|
||||
|
||||
def all_reduce(fa: int, inp: torch.Tensor, out: torch.Tensor, reg_buffer: int,
|
||||
@ -1369,6 +1369,18 @@ def register_graph_buffers(fa: int, handles: list[list[int]],
|
||||
torch.ops._C_custom_ar.register_graph_buffers(fa, handles, offsets)
|
||||
|
||||
|
||||
def allocate_shared_buffer_and_handle(size: int) -> tuple[int, torch.Tensor]:
|
||||
return torch.ops._C_custom_ar.allocate_shared_buffer_and_handle(size)
|
||||
|
||||
|
||||
def open_mem_handle(mem_handle: torch.Tensor):
|
||||
return torch.ops._C_custom_ar.open_mem_handle(mem_handle)
|
||||
|
||||
|
||||
def free_shared_buffer(ptr: int) -> None:
|
||||
torch.ops._C_custom_ar.free_shared_buffer(ptr)
|
||||
|
||||
|
||||
def get_flash_mla_metadata(
|
||||
cache_seqlens: torch.Tensor,
|
||||
num_heads_per_head_k: int,
|
||||
|
@ -1606,11 +1606,13 @@ class ParallelConfig:
|
||||
if self.use_ray:
|
||||
from vllm.executor import ray_utils
|
||||
ray_utils.assert_ray_available()
|
||||
if current_platform.is_rocm():
|
||||
device_capability = current_platform.get_device_capability()
|
||||
if (current_platform.is_rocm() and device_capability is not None
|
||||
and device_capability < (9, 4)):
|
||||
self.disable_custom_all_reduce = True
|
||||
logger.info(
|
||||
"Disabled the custom all-reduce kernel because it is not "
|
||||
"supported on AMD GPUs.")
|
||||
"supported on AMD GPUs older than MI300X.")
|
||||
if self.ray_workers_use_nsight and not self.use_ray:
|
||||
raise ValueError("Unable to use nsight profiling unless workers "
|
||||
"run with Ray.")
|
||||
|
@ -1,6 +1,5 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import ctypes
|
||||
from contextlib import contextmanager
|
||||
from typing import List, Optional, Union
|
||||
|
||||
@ -10,7 +9,6 @@ from torch.distributed import ProcessGroup
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.distributed.device_communicators.cuda_wrapper import CudaRTLibrary
|
||||
from vllm.distributed.device_communicators.custom_all_reduce_utils import (
|
||||
gpu_p2p_access_check)
|
||||
from vllm.distributed.parallel_state import in_the_same_node_as
|
||||
@ -22,7 +20,7 @@ try:
|
||||
ops.meta_size()
|
||||
custom_ar = True
|
||||
except Exception:
|
||||
# For AMD GPUs and CPUs
|
||||
# For CPUs
|
||||
custom_ar = False
|
||||
|
||||
logger = init_logger(__name__)
|
||||
@ -71,7 +69,9 @@ class CustomAllreduce:
|
||||
|
||||
if not custom_ar:
|
||||
# disable because of missing custom allreduce library
|
||||
# e.g. in a non-cuda environment
|
||||
# e.g. in a non-GPU environment
|
||||
logger.info("Custom allreduce is disabled because "
|
||||
"of missing custom allreduce library")
|
||||
return
|
||||
|
||||
self.group = group
|
||||
@ -129,11 +129,10 @@ class CustomAllreduce:
|
||||
# test nvlink first, this will filter out most of the cases
|
||||
# where custom allreduce is not supported
|
||||
# this checks hardware and driver support for NVLink
|
||||
assert current_platform.is_cuda()
|
||||
from vllm.platforms.cuda import CudaPlatform
|
||||
cuda_platform: CudaPlatform = current_platform
|
||||
full_nvlink = cuda_platform.is_full_nvlink(physical_device_ids)
|
||||
if world_size > 2 and not full_nvlink:
|
||||
assert current_platform.is_cuda_alike()
|
||||
fully_connected = current_platform.is_fully_connected(
|
||||
physical_device_ids)
|
||||
if world_size > 2 and not fully_connected:
|
||||
logger.warning(
|
||||
"Custom allreduce is disabled because it's not supported on"
|
||||
" more than two PCIe-only GPUs. To silence this warning, "
|
||||
@ -142,7 +141,8 @@ class CustomAllreduce:
|
||||
# test P2P capability, this checks software/cudaruntime support
|
||||
# this is expensive to compute at the first time
|
||||
# then we cache the result
|
||||
if not _can_p2p(rank, world_size):
|
||||
# On AMD GPU, p2p is always enabled between XGMI connected GPUs
|
||||
if not current_platform.is_rocm() and not _can_p2p(rank, world_size):
|
||||
logger.warning(
|
||||
"Custom allreduce is disabled because your platform lacks "
|
||||
"GPU P2P capability or P2P test failed. To silence this "
|
||||
@ -154,7 +154,8 @@ class CustomAllreduce:
|
||||
# Meta data composes of two parts: meta data for synchronization and a
|
||||
# temporary buffer for storing intermediate allreduce results.
|
||||
self.meta_ptrs = self.create_shared_buffer(ops.meta_size() + max_size,
|
||||
group=group)
|
||||
group=group,
|
||||
uncached=True)
|
||||
# This is a pre-registered IPC buffer. In eager mode, input tensors
|
||||
# are first copied into this buffer before allreduce is performed
|
||||
self.buffer_ptrs = self.create_shared_buffer(max_size, group=group)
|
||||
@ -169,46 +170,11 @@ class CustomAllreduce:
|
||||
self.max_size = max_size
|
||||
self.rank = rank
|
||||
self.world_size = world_size
|
||||
self.full_nvlink = full_nvlink
|
||||
self.fully_connected = fully_connected
|
||||
self._ptr = ops.init_custom_ar(self.meta_ptrs, self.rank_data, rank,
|
||||
self.full_nvlink)
|
||||
self.fully_connected)
|
||||
ops.register_buffer(self._ptr, self.buffer_ptrs)
|
||||
|
||||
@staticmethod
|
||||
def create_shared_buffer(
|
||||
size_in_bytes: int,
|
||||
group: Optional[ProcessGroup] = None) -> List[int]:
|
||||
"""
|
||||
Creates a shared buffer and returns a list of pointers
|
||||
representing the buffer on all processes in the group.
|
||||
"""
|
||||
lib = CudaRTLibrary()
|
||||
pointer = lib.cudaMalloc(size_in_bytes)
|
||||
handle = lib.cudaIpcGetMemHandle(pointer)
|
||||
world_size = dist.get_world_size(group=group)
|
||||
rank = dist.get_rank(group=group)
|
||||
handles = [None] * world_size
|
||||
dist.all_gather_object(handles, handle, group=group)
|
||||
|
||||
pointers: List[int] = []
|
||||
for i, h in enumerate(handles):
|
||||
if i == rank:
|
||||
pointers.append(pointer.value) # type: ignore
|
||||
else:
|
||||
pointers.append(
|
||||
lib.cudaIpcOpenMemHandle(h).value) # type: ignore
|
||||
|
||||
return pointers
|
||||
|
||||
@staticmethod
|
||||
def free_shared_buffer(pointers: List[int],
|
||||
group: Optional[ProcessGroup] = None,
|
||||
rank: Optional[int] = None) -> None:
|
||||
if rank is None:
|
||||
rank = dist.get_rank(group=group)
|
||||
lib = CudaRTLibrary()
|
||||
lib.cudaFree(ctypes.c_void_p(pointers[rank]))
|
||||
|
||||
@contextmanager
|
||||
def capture(self):
|
||||
"""
|
||||
@ -255,7 +221,7 @@ class CustomAllreduce:
|
||||
return False
|
||||
# for 4 or more non NVLink-capable GPUs, custom allreduce provides
|
||||
# little performance improvement over NCCL.
|
||||
if self.world_size == 2 or self.full_nvlink:
|
||||
if self.world_size == 2 or self.fully_connected:
|
||||
return inp_size < self.max_size
|
||||
return False
|
||||
|
||||
@ -306,3 +272,30 @@ class CustomAllreduce:
|
||||
|
||||
def __del__(self):
|
||||
self.close()
|
||||
|
||||
@staticmethod
|
||||
def create_shared_buffer(size_in_bytes: int,
|
||||
group: Optional[ProcessGroup] = None,
|
||||
uncached: Optional[bool] = False) -> List[int]:
|
||||
pointer, handle = ops.allocate_shared_buffer_and_handle(size_in_bytes)
|
||||
|
||||
world_size = dist.get_world_size(group=group)
|
||||
rank = dist.get_rank(group=group)
|
||||
handles = [None] * world_size
|
||||
dist.all_gather_object(handles, handle, group=group)
|
||||
|
||||
pointers: List[int] = []
|
||||
for i, h in enumerate(handles):
|
||||
if i == rank:
|
||||
pointers.append(pointer) # type: ignore
|
||||
else:
|
||||
pointers.append(ops.open_mem_handle(h))
|
||||
return pointers
|
||||
|
||||
@staticmethod
|
||||
def free_shared_buffer(pointers: List[int],
|
||||
group: Optional[ProcessGroup] = None,
|
||||
rank: Optional[int] = 0) -> None:
|
||||
if rank is None:
|
||||
rank = dist.get_rank(group=group)
|
||||
ops.free_shared_buffer(pointers[rank])
|
||||
|
@ -101,7 +101,7 @@ class CudaPlatformBase(Platform):
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
def is_full_nvlink(cls, device_ids: List[int]) -> bool:
|
||||
def is_fully_connected(cls, device_ids: List[int]) -> bool:
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
@ -362,7 +362,7 @@ class NvmlCudaPlatform(CudaPlatformBase):
|
||||
|
||||
@classmethod
|
||||
@with_nvml_context
|
||||
def is_full_nvlink(cls, physical_device_ids: List[int]) -> bool:
|
||||
def is_fully_connected(cls, physical_device_ids: List[int]) -> bool:
|
||||
"""
|
||||
query if the set of gpus are fully connected by nvlink (1 hop)
|
||||
"""
|
||||
@ -427,7 +427,7 @@ class NonNvmlCudaPlatform(CudaPlatformBase):
|
||||
return device_props.total_memory
|
||||
|
||||
@classmethod
|
||||
def is_full_nvlink(cls, physical_device_ids: List[int]) -> bool:
|
||||
def is_fully_connected(cls, physical_device_ids: List[int]) -> bool:
|
||||
logger.exception(
|
||||
"NVLink detection not possible, as context support was"
|
||||
" not found. Assuming no NVLink available.")
|
||||
|
@ -20,8 +20,9 @@ else:
|
||||
logger = init_logger(__name__)
|
||||
|
||||
try:
|
||||
from amdsmi import (amdsmi_get_gpu_asic_info, amdsmi_get_processor_handles,
|
||||
amdsmi_init, amdsmi_shut_down)
|
||||
from amdsmi import (AmdSmiException, amdsmi_get_gpu_asic_info,
|
||||
amdsmi_get_processor_handles, amdsmi_init,
|
||||
amdsmi_shut_down, amdsmi_topo_get_link_type)
|
||||
except ImportError as e:
|
||||
logger.warning("Failed to import from amdsmi with %r", e)
|
||||
|
||||
@ -135,10 +136,36 @@ class RocmPlatform(Platform):
|
||||
|
||||
@classmethod
|
||||
@lru_cache(maxsize=8)
|
||||
def get_device_capability(cls, device_id: int = 0) -> DeviceCapability:
|
||||
def get_device_capability(cls,
|
||||
device_id: int = 0
|
||||
) -> Optional[DeviceCapability]:
|
||||
major, minor = torch.cuda.get_device_capability(device_id)
|
||||
return DeviceCapability(major=major, minor=minor)
|
||||
|
||||
@staticmethod
|
||||
@with_amdsmi_context
|
||||
def is_fully_connected(physical_device_ids: List[int]) -> bool:
|
||||
"""
|
||||
Query if the set of gpus are fully connected by xgmi (1 hop)
|
||||
"""
|
||||
handles = [
|
||||
amdsmi_get_processor_handles()[i] for i in physical_device_ids
|
||||
]
|
||||
for i, handle in enumerate(handles):
|
||||
for j, peer_handle in enumerate(handles):
|
||||
if i < j:
|
||||
try:
|
||||
link_type = amdsmi_topo_get_link_type(
|
||||
handle, peer_handle)
|
||||
# type is 2 for XGMI
|
||||
if link_type["hops"] != 1 or link_type["type"] != 2:
|
||||
return False
|
||||
except AmdSmiException as error:
|
||||
logger.error("AMD 1 hop XGMI detection failed.",
|
||||
exc_info=error)
|
||||
return False
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
@with_amdsmi_context
|
||||
@lru_cache(maxsize=8)
|
||||
|
Reference in New Issue
Block a user