Update on "[NOT FOR LANDING] experimental NVSHMEM integration"

cc H-Huang awgu kwen2501 wanchaol fegin fduwjj wz337 wconstab d4l3k c-p-i-o

[ghstack-poisoned]
Yifu Wang
2025-02-06 11:08:03 -08:00
parent 702bb7ade4
commit 792d2dabb4
5 changed files with 67 additions and 7 deletions

View File

@@ -980,8 +980,8 @@ class NVSHMEMSymmetricMemory : public SymmetricMemory {
cudaMemcpyHostToDevice));
}
~NVSHMEMSymmetricMemory() override {
// TODO
};
std::vector<void*> get_buffer_ptrs() override {
@@ -1107,7 +1107,11 @@ class NVSHMEMSymmetricMemory : public SymmetricMemory {
void put_signal(int dst_rank, int channel, size_t timeout_ms) override {
check_channel(channel, world_size_);
c10::cuda::CUDAGuard guard(device_idx_);
put_signal_kernel<<<
1,
C10_WARP_SIZE,
0,
at::cuda::getCurrentCUDAStream()>>>(
reinterpret_cast<uint32_t**>(signal_pads_dev_),
dst_rank,
channel,
@@ -1120,7 +1124,11 @@ class NVSHMEMSymmetricMemory : public SymmetricMemory {
void wait_signal(int src_rank, int channel, size_t timeout_ms) override {
check_channel(channel, world_size_);
c10::cuda::CUDAGuard guard(device_idx_);
wait_signal_kernel<<<
1,
C10_WARP_SIZE,
0,
at::cuda::getCurrentCUDAStream()>>>(
reinterpret_cast<uint32_t**>(signal_pads_dev_),
src_rank,
channel,
@@ -1168,6 +1176,7 @@ class NVSHMEMSymmetricMemoryAllocator : public SymmetricMemoryAllocator {
auto symm_mem = c10::make_intrusive<NVSHMEMSymmetricMemory>(
size, device_idx, rank, world_size);
void* ptr = symm_mem->get_buffer_ptrs()[rank];
// TODO: thread safety
ptr_to_symm_mem_[ptr] = symm_mem;
return ptr;
}

View File

@@ -20,6 +20,7 @@
#include <torch/csrc/distributed/c10d/CUDASymmetricMemory-inl.h>
#include <torch/csrc/distributed/c10d/CUDASymmetricMemory.hpp>
#include <torch/csrc/distributed/c10d/cuda/AsyncMM.cuh>
#include <torch/csrc/distributed/c10d/nvshmem_extension.cuh>
#define INT_SWITCH_CASE(name, val, ...) \
case val: { \
@@ -717,4 +718,6 @@ TORCH_LIBRARY_IMPL(symm_mem, CUDA, m) {
#endif
m.impl("stream_write_value32_", ::stream_write_value32_);
m.impl("memset32_", ::memset32_);
m.impl("nvshmem_hello", c10d::nvshmem_extension::nvshmem_hello);
}

View File

@@ -248,6 +248,9 @@ TORCH_LIBRARY_FRAGMENT(symm_mem, m) {
"stream_write_value32_(Tensor(a!) input, int offset, int val) -> Tensor(a!)");
m.def(
"memset32_(Tensor(a!) input, int offset, int val, int count) -> Tensor(a!)");
m.def(
"nvshmem_hello(Tensor(a!) input) -> Tensor(a!)");
}
TORCH_LIBRARY_IMPL(symm_mem, Meta, m) {

View File

@@ -2,6 +2,8 @@
#include <c10/cuda/CUDAGuard.h>
#include <torch/csrc/distributed/c10d/SymmetricMemory.hpp>
#include <nvshmem.h>
namespace c10d::nvshmem_extension {
@@ -79,8 +81,49 @@ void* nvshmem_malloc(size_t size) {
return ::nvshmem_malloc(size);
}
void* nvshmem_ptr(const void* dest, int pe) {
return ::nvshmem_ptr(dest, pe);
}
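// Ring broadcast over NVSHMEM: the root PE raises its own signal flag, and
// every PE waits on its local flag, which is set by its predecessor only
// after the predecessor's data put has been fenced. Each PE except the last
// then forwards its buffer to the next PE with nvshmem_int_put and signals
// that PE's flag, so the data travels root -> root+1 -> ... around the ring.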
__global__ void ring_bcast(int* data, size_t nelem, int root, uint64_t* psync) {
int mype = nvshmem_my_pe();
int npes = nvshmem_n_pes();
int peer = (mype + 1) % npes;
if (mype == root)
*psync = 1;
nvshmem_signal_wait_until(psync, NVSHMEM_CMP_NE, 0);
if (mype == npes - 1)
return;
nvshmem_int_put(data, data, nelem, peer);
nvshmem_fence();
nvshmemx_signal_op(psync, 1, NVSHMEM_SIGNAL_SET, peer);
*psync = 0;
}
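// Minimal end-to-end exercise of the NVSHMEM path: rendezvous over the input
// tensor's symmetric memory, then launch ring_bcast as an NVSHMEM collective
// (bracketed by stream barriers) so rank 0's buffer reaches every rank.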
at::Tensor nvshmem_hello(at::Tensor& input) {
auto symm_mem = c10d::symmetric_memory::rendezvous(input, "0");
int rank = symm_mem->get_rank();
int world_size = symm_mem->get_world_size();
void* buffer_ptr = symm_mem->get_buffer_ptrs()[rank];
void* signal_pad_ptr = symm_mem->get_signal_pad_ptrs()[rank];
size_t buffer_size = symm_mem->get_buffer_size();
int root = 0;
void* args[] = {&buffer_ptr, &buffer_size, &root, &signal_pad_ptr};
dim3 grid_dim(1), block_dim(1);
auto stream = at::cuda::getCurrentCUDAStream();
nvshmemx_barrier_all_on_stream(stream);
nvshmemx_collective_launch(
(const void*)ring_bcast, grid_dim, block_dim, args, 0, stream);
nvshmemx_barrier_all_on_stream(stream);
return input;
}
} // namespace c10d::nvshmem_extension

View File

@@ -13,6 +13,8 @@ void initialize_nvshmem_with_store(
void* nvshmem_malloc(size_t size);
void* nvshmem_ptr(const void* dest, int pe);
at::Tensor nvshmem_hello(at::Tensor& input);
} // namespace c10d::nvshmem_extension
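
Usage note (not part of this commit): since the op is registered on the symm_mem library, it should be reachable from Python as torch.ops.symm_mem.nvshmem_hello. The sketch below is a hypothetical illustration only; it assumes this NVSHMEM allocator is wired up as the active CUDA symmetric-memory backend and that torch.distributed._symmetric_memory.empty is available to allocate the buffer, neither of which this experimental patch demonstrates.

# Hypothetical usage sketch (not part of this commit). Assumes the NVSHMEM
# allocator above is the active symmetric-memory backend, one process per
# GPU (e.g. launched with torchrun), and a single node.
import os
import torch
import torch.distributed as dist
import torch.distributed._symmetric_memory as symm_mem

local_rank = int(os.environ["LOCAL_RANK"])
torch.cuda.set_device(local_rank)
dist.init_process_group("nccl")
rank = dist.get_rank()

# Allocate through the symmetric-memory allocator so that the rendezvous()
# inside nvshmem_hello() can look up this rank's buffer and signal pad.
t = symm_mem.empty(16, dtype=torch.int32, device=f"cuda:{local_rank}")
t.fill_(42 if rank == 0 else 0)

# Ring-broadcast rank 0's buffer to all ranks (see ring_bcast above).
torch.ops.symm_mem.nvshmem_hello(t)

torch.cuda.synchronize()
print(f"rank {rank}: {t.cpu().tolist()}")

dist.destroy_process_group()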