Update on "[NOT FOR LANDING] experimental NVSHMEM integration"

cc H-Huang awgu kwen2501 wanchaol fegin fduwjj wz337 wconstab d4l3k c-p-i-o

[ghstack-poisoned]
Commit 792d2dabb4 by Yifu Wang, 2025-02-06 11:08:03 -08:00 (parent 702bb7ade4)
5 changed files with 67 additions and 7 deletions

View File

@@ -980,8 +980,8 @@ class NVSHMEMSymmetricMemory : public SymmetricMemory {
         cudaMemcpyHostToDevice));
   }

-  ~NVSHMEMSymmetricMemory() override {
-    // TODO
+  ~NVSHMEMSymmetricMemory() override{
+    // TODO
   };

   std::vector<void*> get_buffer_ptrs() override {
@@ -1107,7 +1107,11 @@ class NVSHMEMSymmetricMemory : public SymmetricMemory {
   void put_signal(int dst_rank, int channel, size_t timeout_ms) override {
     check_channel(channel, world_size_);
     c10::cuda::CUDAGuard guard(device_idx_);
-    put_signal_kernel<<<1, C10_WARP_SIZE, 0, at::cuda::getCurrentCUDAStream()>>>(
+    put_signal_kernel<<<
+        1,
+        C10_WARP_SIZE,
+        0,
+        at::cuda::getCurrentCUDAStream()>>>(
         reinterpret_cast<uint32_t**>(signal_pads_dev_),
         dst_rank,
         channel,
@@ -1120,7 +1124,11 @@ class NVSHMEMSymmetricMemory : public SymmetricMemory {
   void wait_signal(int src_rank, int channel, size_t timeout_ms) override {
     check_channel(channel, world_size_);
     c10::cuda::CUDAGuard guard(device_idx_);
-    wait_signal_kernel<<<1, C10_WARP_SIZE, 0, at::cuda::getCurrentCUDAStream()>>>(
+    wait_signal_kernel<<<
+        1,
+        C10_WARP_SIZE,
+        0,
+        at::cuda::getCurrentCUDAStream()>>>(
         reinterpret_cast<uint32_t**>(signal_pads_dev_),
         src_rank,
         channel,
@@ -1168,6 +1176,7 @@ class NVSHMEMSymmetricMemoryAllocator : public SymmetricMemoryAllocator {
     auto symm_mem = c10::make_intrusive<NVSHMEMSymmetricMemory>(
         size, device_idx, rank, world_size);
     void* ptr = symm_mem->get_buffer_ptrs()[rank];
+    // TODO: thread safety
     ptr_to_symm_mem_[ptr] = symm_mem;
     return ptr;
   }
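
To make the signal-pad contract above concrete, here is a hypothetical pairing of the two primitives by a caller. This sketch is not part of the commit; the handle, ranks, channel, and timeout are illustrative:

#include <torch/csrc/distributed/c10d/SymmetricMemory.hpp>

// Rank 0 publishes data to its symmetric buffer; rank 1 consumes it.
void handshake_step(
    const c10::intrusive_ptr<c10d::symmetric_memory::SymmetricMemory>& symm_mem,
    int rank) {
  constexpr int kChannel = 0;          // must satisfy check_channel()
  constexpr size_t kTimeoutMs = 10000;
  if (rank == 0) {
    // ... write into symm_mem->get_buffer_ptrs()[0] on the current stream ...
    symm_mem->put_signal(/*dst_rank=*/1, kChannel, kTimeoutMs);
  } else if (rank == 1) {
    // Enqueued on the current stream; completes once rank 0's signal lands
    // on this rank's signal pad for the channel.
    symm_mem->wait_signal(/*src_rank=*/0, kChannel, kTimeoutMs);
    // ... now safe to read what rank 0 wrote ...
  }
}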

View File

@@ -20,6 +20,7 @@
 #include <torch/csrc/distributed/c10d/CUDASymmetricMemory-inl.h>
 #include <torch/csrc/distributed/c10d/CUDASymmetricMemory.hpp>
 #include <torch/csrc/distributed/c10d/cuda/AsyncMM.cuh>
+#include <torch/csrc/distributed/c10d/nvshmem_extension.cuh>

 #define INT_SWITCH_CASE(name, val, ...) \
   case val: {                           \
@@ -717,4 +718,6 @@ TORCH_LIBRARY_IMPL(symm_mem, CUDA, m) {
 #endif
   m.impl("stream_write_value32_", ::stream_write_value32_);
   m.impl("memset32_", ::memset32_);
+
+  m.impl("nvshmem_hello", c10d::nvshmem_extension::nvshmem_hello);
 }
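
Note that m.impl only binds a CUDA entry point to an operator schema; the matching m.def for nvshmem_hello is declared in the symm_mem library fragment in the next file, with a call sketch after it.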

View File

@@ -248,6 +248,9 @@ TORCH_LIBRARY_FRAGMENT(symm_mem, m) {
   m.def(
       "stream_write_value32_(Tensor(a!) input, int offset, int val) -> Tensor(a!)");
   m.def(
       "memset32_(Tensor(a!) input, int offset, int val, int count) -> Tensor(a!)");
+  m.def(
+      "nvshmem_hello(Tensor(a!) input) -> Tensor(a!)");
 }

 TORCH_LIBRARY_IMPL(symm_mem, Meta, m) {
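
With the schema (m.def) and the CUDA kernel (m.impl) both registered, the op is reachable through the dispatcher, e.g. as torch.ops.symm_mem.nvshmem_hello from Python. A minimal C++ call sketch; the wrapper name is illustrative:

#include <ATen/core/dispatch/Dispatcher.h>

at::Tensor call_nvshmem_hello(at::Tensor& input) {
  // Resolve symm_mem::nvshmem_hello once and cache the typed handle.
  static auto op = c10::Dispatcher::singleton()
                       .findSchemaOrThrow("symm_mem::nvshmem_hello", "")
                       .typed<at::Tensor(at::Tensor&)>();
  return op.call(input);
}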

View File

@@ -2,6 +2,8 @@
 #include <c10/cuda/CUDAGuard.h>

+#include <torch/csrc/distributed/c10d/SymmetricMemory.hpp>
+
 #include <nvshmem.h>

 namespace c10d::nvshmem_extension {

@@ -79,8 +81,49 @@ void* nvshmem_malloc(size_t size) {
   return ::nvshmem_malloc(size);
 }

-void* nvshmem_ptr(const void *dest, int pe) {
+void* nvshmem_ptr(const void* dest, int pe) {
   return ::nvshmem_ptr(dest, pe);
 }
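+
+// Ring broadcast from `root`: the root raises its own flag; every other PE
+// blocks until its predecessor raises the flag, forwards the payload to the
+// next PE, fences so the put is ordered before the signal, and then raises
+// the peer's flag. Both `data` and `psync` must be symmetric allocations.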
+__global__ void ring_bcast(int* data, size_t nelem, int root, uint64_t* psync) {
+  int mype = nvshmem_my_pe();
+  int npes = nvshmem_n_pes();
+  int peer = (mype + 1) % npes;
+
+  if (mype == root)
+    *psync = 1;
+
+  nvshmem_signal_wait_until(psync, NVSHMEM_CMP_NE, 0);
+
+  if (mype == npes - 1)
+    return;
+
+  nvshmem_int_put(data, data, nelem, peer);
+  nvshmem_fence();
+  nvshmemx_signal_op(psync, 1, NVSHMEM_SIGNAL_SET, peer);
+
+  *psync = 0;
+}
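+
+// Host-side wrapper: rendezvous() exchanges the symmetric buffer and signal
+// pad handles for `input`, then ring_bcast is launched collectively on every
+// PE, bracketed by stream barriers. Note that get_buffer_size() is in bytes
+// while ring_bcast's nelem counts ints, so the put as written covers 4x the
+// buffer; one of the rough edges of this not-for-landing experiment.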
+at::Tensor nvshmem_hello(at::Tensor& input) {
+  auto symm_mem = c10d::symmetric_memory::rendezvous(input, "0");
+  int rank = symm_mem->get_rank();
+  int world_size = symm_mem->get_world_size();
+  void* buffer_ptr = symm_mem->get_buffer_ptrs()[rank];
+  void* signal_pad_ptr = symm_mem->get_signal_pad_ptrs()[rank];
+  size_t buffer_size = symm_mem->get_buffer_size();
+  int root = 0;
+  void* args[] = {&buffer_ptr, &buffer_size, &root, &signal_pad_ptr};
+
+  dim3 grid_dim(1), block_dim(1);
+  auto stream = at::cuda::getCurrentCUDAStream();
+  nvshmemx_barrier_all_on_stream(stream);
+  nvshmemx_collective_launch(
+      (const void*)ring_bcast, grid_dim, block_dim, args, 0, stream);
+  nvshmemx_barrier_all_on_stream(stream);
+  return input;
+}

 } // namespace c10d::nvshmem_extension
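
For comparison, the same kernel can be driven by a plain NVSHMEM host program with none of the PyTorch plumbing. A minimal sketch, assuming the ring_bcast kernel above is visible to this translation unit (referenced by qualified name below) and one process is launched per GPU; the sizes and root choice are illustrative:

#include <cstdio>
#include <cuda_runtime.h>
#include <nvshmem.h>
#include <nvshmemx.h>

int main() {
  nvshmem_init();
  // Bind this PE to a GPU on its node before touching the device.
  cudaSetDevice(nvshmem_team_my_pe(NVSHMEMX_TEAM_NODE));

  // Symmetric allocations: same offset in every PE's symmetric heap.
  size_t nelem = 1024;
  int* data = static_cast<int*>(nvshmem_malloc(nelem * sizeof(int)));
  uint64_t* psync =
      static_cast<uint64_t*>(nvshmem_calloc(1, sizeof(uint64_t)));
  // (A real test would initialize `data` on the root here.)

  cudaStream_t stream;
  cudaStreamCreate(&stream);

  int root = 0;
  void* args[] = {&data, &nelem, &root, &psync};
  nvshmemx_barrier_all_on_stream(stream);
  nvshmemx_collective_launch(
      (const void*)c10d::nvshmem_extension::ring_bcast,
      dim3(1), dim3(1), args, 0, stream);
  nvshmemx_barrier_all_on_stream(stream);
  cudaStreamSynchronize(stream);

  printf("PE %d done\n", nvshmem_my_pe());
  nvshmem_free(data);
  nvshmem_free(psync);
  nvshmem_finalize();
  return 0;
}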

View File

@@ -13,6 +13,8 @@ void initialize_nvshmem_with_store(
 void* nvshmem_malloc(size_t size);

-void* nvshmem_ptr(const void *dest, int pe);
+void* nvshmem_ptr(const void* dest, int pe);
+
+at::Tensor nvshmem_hello(at::Tensor& input);

 } // namespace c10d::nvshmem_extension