mirror of https://github.com/pytorch/pytorch.git
Update on "[NOT FOR LANDING] experimental NVSHMEM integration"
cc H-Huang awgu kwen2501 wanchaol fegin fduwjj wz337 wconstab d4l3k c-p-i-o [ghstack-poisoned]
@@ -980,8 +980,8 @@ class NVSHMEMSymmetricMemory : public SymmetricMemory {
         cudaMemcpyHostToDevice));
   }
 
-  ~NVSHMEMSymmetricMemory() override {
-    // TODO
+  ~NVSHMEMSymmetricMemory() override{
+    // TODO
   };
 
   std::vector<void*> get_buffer_ptrs() override {
@@ -1107,7 +1107,11 @@ class NVSHMEMSymmetricMemory : public SymmetricMemory {
   void put_signal(int dst_rank, int channel, size_t timeout_ms) override {
     check_channel(channel, world_size_);
     c10::cuda::CUDAGuard guard(device_idx_);
-    put_signal_kernel<<<1, C10_WARP_SIZE, 0, at::cuda::getCurrentCUDAStream()>>>(
+    put_signal_kernel<<<
+        1,
+        C10_WARP_SIZE,
+        0,
+        at::cuda::getCurrentCUDAStream()>>>(
         reinterpret_cast<uint32_t**>(signal_pads_dev_),
         dst_rank,
         channel,
@@ -1120,7 +1124,11 @@ class NVSHMEMSymmetricMemory : public SymmetricMemory {
   void wait_signal(int src_rank, int channel, size_t timeout_ms) override {
     check_channel(channel, world_size_);
     c10::cuda::CUDAGuard guard(device_idx_);
-    wait_signal_kernel<<<1, C10_WARP_SIZE, 0, at::cuda::getCurrentCUDAStream()>>>(
+    wait_signal_kernel<<<
+        1,
+        C10_WARP_SIZE,
+        0,
+        at::cuda::getCurrentCUDAStream()>>>(
         reinterpret_cast<uint32_t**>(signal_pads_dev_),
         src_rank,
         channel,
@@ -1168,6 +1176,7 @@ class NVSHMEMSymmetricMemoryAllocator : public SymmetricMemoryAllocator {
     auto symm_mem = c10::make_intrusive<NVSHMEMSymmetricMemory>(
         size, device_idx, rank, world_size);
     void* ptr = symm_mem->get_buffer_ptrs()[rank];
+    // TODO: thread safety
     ptr_to_symm_mem_[ptr] = symm_mem;
     return ptr;
   }
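The "// TODO: thread safety" line added above flags that ptr_to_symm_mem_ is mutated without synchronization, so concurrent alloc() calls could race on the map. A minimal sketch of one way the TODO could be resolved, assuming the map is a plain std::unordered_map member and that adding a mutex member is acceptable; the class and function names below are illustrative, not part of this commit:

// Hypothetical sketch (not in the diff): guard the pointer -> handle map with a
// mutex so concurrent allocations cannot race on it.
#include <mutex>
#include <unordered_map>

#include <torch/csrc/distributed/c10d/SymmetricMemory.hpp>

using c10d::symmetric_memory::SymmetricMemory;

class SymmMemMapSketch {
 public:
  void record(void* ptr, c10::intrusive_ptr<SymmetricMemory> handle) {
    std::lock_guard<std::mutex> lock(mutex_); // serializes map mutation
    ptr_to_symm_mem_[ptr] = std::move(handle);
  }

  c10::intrusive_ptr<SymmetricMemory> find(void* ptr) {
    std::lock_guard<std::mutex> lock(mutex_);
    auto it = ptr_to_symm_mem_.find(ptr);
    return it == ptr_to_symm_mem_.end() ? c10::intrusive_ptr<SymmetricMemory>()
                                        : it->second;
  }

 private:
  std::mutex mutex_; // assumed new member
  std::unordered_map<void*, c10::intrusive_ptr<SymmetricMemory>> ptr_to_symm_mem_;
};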
@@ -20,6 +20,7 @@
 #include <torch/csrc/distributed/c10d/CUDASymmetricMemory-inl.h>
 #include <torch/csrc/distributed/c10d/CUDASymmetricMemory.hpp>
 #include <torch/csrc/distributed/c10d/cuda/AsyncMM.cuh>
+#include <torch/csrc/distributed/c10d/nvshmem_extension.cuh>
 
 #define INT_SWITCH_CASE(name, val, ...) \
   case val: { \
@@ -717,4 +718,6 @@ TORCH_LIBRARY_IMPL(symm_mem, CUDA, m) {
 #endif
   m.impl("stream_write_value32_", ::stream_write_value32_);
   m.impl("memset32_", ::memset32_);
+
+  m.impl("nvshmem_hello", c10d::nvshmem_extension::nvshmem_hello);
 }
@@ -248,6 +248,9 @@ TORCH_LIBRARY_FRAGMENT(symm_mem, m) {
       "stream_write_value32_(Tensor(a!) input, int offset, int val) -> Tensor(a!)");
   m.def(
       "memset32_(Tensor(a!) input, int offset, int val, int count) -> Tensor(a!)");
+
+  m.def(
+      "nvshmem_hello(Tensor(a!) input) -> Tensor(a!)");
 }
 
 TORCH_LIBRARY_IMPL(symm_mem, Meta, m) {
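Together with the CUDA registration in the hunk above (m.impl("nvshmem_hello", ...)), this schema makes the op reachable through the dispatcher, and on the Python side as torch.ops.symm_mem.nvshmem_hello. A hedged sketch of invoking it from C++; the wrapper name is illustrative, and the input is assumed to be a CUDA tensor backed by the symmetric-memory allocator with every rank calling collectively:

// Hypothetical caller (not in the diff): look the op up by its registered schema
// name and dispatch it; this resolves to the kernel registered in
// TORCH_LIBRARY_IMPL(symm_mem, CUDA, m).
#include <ATen/ATen.h>
#include <ATen/core/dispatch/Dispatcher.h>

at::Tensor call_nvshmem_hello(at::Tensor& input) {
  static auto op = c10::Dispatcher::singleton()
                       .findSchemaOrThrow("symm_mem::nvshmem_hello", "")
                       .typed<at::Tensor(at::Tensor&)>();
  return op.call(input);
}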
@@ -2,6 +2,8 @@
 
 #include <c10/cuda/CUDAGuard.h>
 
+#include <torch/csrc/distributed/c10d/SymmetricMemory.hpp>
+
 #include <nvshmem.h>
 
 namespace c10d::nvshmem_extension {
@@ -79,8 +81,49 @@ void* nvshmem_malloc(size_t size) {
   return ::nvshmem_malloc(size);
 }
 
-void* nvshmem_ptr(const void *dest, int pe) {
+void* nvshmem_ptr(const void* dest, int pe) {
   return ::nvshmem_ptr(dest, pe);
 }
 
+__global__ void ring_bcast(int* data, size_t nelem, int root, uint64_t* psync) {
+  int mype = nvshmem_my_pe();
+  int npes = nvshmem_n_pes();
+  int peer = (mype + 1) % npes;
+
+  if (mype == root)
+    *psync = 1;
+
+  nvshmem_signal_wait_until(psync, NVSHMEM_CMP_NE, 0);
+
+  if (mype == npes - 1)
+    return;
+
+  nvshmem_int_put(data, data, nelem, peer);
+  nvshmem_fence();
+  nvshmemx_signal_op(psync, 1, NVSHMEM_SIGNAL_SET, peer);
+
+  *psync = 0;
+}
+
+at::Tensor nvshmem_hello(at::Tensor& input) {
+  auto symm_mem = c10d::symmetric_memory::rendezvous(input, "0");
+  int rank = symm_mem->get_rank();
+  int world_size = symm_mem->get_world_size();
+
+  void* buffer_ptr = symm_mem->get_buffer_ptrs()[rank];
+  void* signal_pad_ptr = symm_mem->get_signal_pad_ptrs()[rank];
+  size_t buffer_size = symm_mem->get_buffer_size();
+  int root = 0;
+  void* args[] = {&buffer_ptr, &buffer_size, &root, &signal_pad_ptr};
+
+  dim3 grid_dim(1), block_dim(1);
+  auto stream = at::cuda::getCurrentCUDAStream();
+  nvshmemx_barrier_all_on_stream(stream);
+  nvshmemx_collective_launch(
+      (const void*)ring_bcast, grid_dim, block_dim, args, 0, stream);
+  nvshmemx_barrier_all_on_stream(stream);
+
+  return input;
+}
+
 } // namespace c10d::nvshmem_extension
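In ring_bcast above, each PE blocks on its local psync word, copies its buffer to the next PE with nvshmem_int_put, orders the copy ahead of the notification with nvshmem_fence, and then sets the neighbor's psync with nvshmemx_signal_op, so the root's buffer hops around the ring one PE per step; nvshmem_hello drives the kernel through nvshmemx_collective_launch bracketed by stream barriers. A standalone sketch of the same primitives outside the SymmetricMemory wrapper; the buffer size, bootstrap, error handling, and the psync reset on the last PE are assumptions rather than part of the diff:

// Hedged standalone sketch of the NVSHMEM primitives used by ring_bcast.
// Assumes the job is launched under an NVSHMEM-supported bootstrap
// (e.g. its launcher or MPI), one process per GPU.
#include <cuda_runtime.h>
#include <nvshmem.h>
#include <nvshmemx.h>

__global__ void ring_bcast_sketch(int* data, size_t nelem, int root, uint64_t* psync) {
  int mype = nvshmem_my_pe();
  int npes = nvshmem_n_pes();
  int peer = (mype + 1) % npes;

  if (mype == root)
    *psync = 1; // the root seeds its own signal so it starts immediately
  nvshmem_signal_wait_until(psync, NVSHMEM_CMP_NE, 0); // wait for the upstream PE
  if (mype != npes - 1) {
    nvshmem_int_put(data, data, nelem, peer); // forward the buffer to the next PE
    nvshmem_fence();                          // order the put before the signal
    nvshmemx_signal_op(psync, 1, NVSHMEM_SIGNAL_SET, peer); // wake the next PE
  }
  *psync = 0; // reset for reuse (the PyTorch kernel skips this on the last PE)
}

int main() {
  nvshmem_init();
  size_t nelem = 1024;
  int* data = static_cast<int*>(nvshmem_malloc(nelem * sizeof(int)));
  uint64_t* psync = static_cast<uint64_t*>(nvshmem_calloc(1, sizeof(uint64_t)));
  int root = 0;
  // (the root would normally fill `data` before the launch)
  cudaStream_t stream;
  cudaStreamCreate(&stream);

  void* args[] = {&data, &nelem, &root, &psync};
  nvshmemx_barrier_all_on_stream(stream);
  nvshmemx_collective_launch(
      (const void*)ring_bcast_sketch, dim3(1), dim3(1), args, 0, stream);
  nvshmemx_barrier_all_on_stream(stream);
  cudaStreamSynchronize(stream);

  nvshmem_free(psync);
  nvshmem_free(data);
  nvshmem_finalize();
  return 0;
}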
@@ -13,6 +13,8 @@ void initialize_nvshmem_with_store(
 
 void* nvshmem_malloc(size_t size);
 
-void* nvshmem_ptr(const void *dest, int pe);
+void* nvshmem_ptr(const void* dest, int pe);
 
-}
+at::Tensor nvshmem_hello(at::Tensor& input);
+
+} // namespace c10d::nvshmem_extension