[SymmMem] Experimental NVSHMEM integration (#151261)

Adding NVSHMEM as a backend for `SymmetricMemory`; the implementation lives in `NVSHMEMSymmetricMemory.cu`.

Moving some helper functions from `CUDASymmetricMemory.cu` to `CUDASymmetricMemoryUtils.cpp` so that they can be shared by `NVSHMEMSymmetricMemory`. These are mostly side-band exchange helpers (`store_all_gather`, `IpcChannel`, etc.).

Adding the `TORCH_SYMMMEM` environment variable to control which implementation is used for CUDA tensors. Currently supported values: `CUDA` (the in-house implementation) and `NVSHMEM`.
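
For illustration, a minimal, hedged sketch of selecting the backend from Python. `TORCH_SYMMMEM` is read once when the allocator registers itself (at library load), so it has to be set before `torch` is imported, or simply exported in the launching environment:

```python
# Hedged sketch: pick the SymmetricMemory backend for CUDA tensors via env var.
# The value is read once at allocator-registration time, so set it before torch
# is imported (or export TORCH_SYMMMEM=NVSHMEM in the shell launching the job).
import os

os.environ.setdefault("TORCH_SYMMMEM", "NVSHMEM")  # or "CUDA" for the in-house impl (the default)

import torch  # noqa: E402  (imported after the env var is set, on purpose)

# From here on, SymmetricMemory allocations on CUDA devices are served by
# NVSHMEMSymmetricMemoryAllocator instead of CUDASymmetricMemoryAllocator.
```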

The NVSHMEM feature is gated by the build-time flag `USE_NVSHMEM=1`, and setting `NVSHMEM_HOME` is currently also required (TODO).

Ported most code from #146593.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/151261
Approved by: https://github.com/fegin, https://github.com/fduwjj
Ke Wen
2025-04-30 17:04:57 -07:00
committed by PyTorch MergeBot
parent 13add553b2
commit a7f1ddc184
15 changed files with 895 additions and 286 deletions

View File

@@ -740,6 +740,7 @@ cc_library(
"torch/csrc/distributed/c10d/intra_node_comm.cu",
"torch/csrc/distributed/c10d/CUDASymmetricMemory.cu",
"torch/csrc/distributed/c10d/CUDASymmetricMemoryOps.cu",
"torch/csrc/distributed/c10d/CUDASymmetricMemoryUtils.cpp",
"torch/csrc/distributed/c10d/cuda/AsyncMM.cu",
"torch/csrc/distributed/c10d/NanCheck.cu",
"torch/csrc/distributed/c10d/quantization/quantization_gpu.cu",

View File

@@ -695,6 +695,7 @@ libtorch_cuda_distributed_extra_sources = [
"torch/csrc/distributed/c10d/intra_node_comm.cu",
"torch/csrc/distributed/c10d/CUDASymmetricMemory.cu",
"torch/csrc/distributed/c10d/CUDASymmetricMemoryOps.cu",
"torch/csrc/distributed/c10d/CUDASymmetricMemoryUtils.cpp",
"torch/csrc/distributed/c10d/cuda/AsyncMM.cu",
"torch/csrc/distributed/c10d/cuda/utils.cpp",
"torch/csrc/distributed/c10d/NanCheck.cu",

View File

@@ -572,6 +572,7 @@ if(USE_CUDA)
${TORCH_SRC_DIR}/csrc/distributed/c10d/CudaDMAConnectivity.cpp
${TORCH_SRC_DIR}/csrc/distributed/c10d/CUDASymmetricMemory.cu
${TORCH_SRC_DIR}/csrc/distributed/c10d/CUDASymmetricMemoryOps.cu
${TORCH_SRC_DIR}/csrc/distributed/c10d/CUDASymmetricMemoryUtils.cpp
${TORCH_SRC_DIR}/csrc/distributed/c10d/ProcessGroupNCCL.cpp
PROPERTIES COMPILE_FLAGS "-DPYTORCH_C10_DRIVER_API_SUPPORTED=1"
)
@@ -976,6 +977,41 @@ elseif(USE_CUDA)
target_link_libraries(torch_cuda PRIVATE __caffe2_nccl)
target_compile_definitions(torch_cuda PRIVATE USE_NCCL)
endif()
# Use env vars for these for now, for prototyping purposes
set(USE_NVSHMEM $ENV{USE_NVSHMEM} CACHE BOOL "Enable NVSHMEM support")
set(NVSHMEM_HOME $ENV{NVSHMEM_HOME} CACHE PATH "Path to NVSHMEM build dir")
if(USE_NVSHMEM)
set(NVSHMEM_INCLUDE_DIR "${NVSHMEM_HOME}/include")
set(NVSHMEM_LIB_DIR "${NVSHMEM_HOME}/lib")
include_directories(${NVSHMEM_INCLUDE_DIR})
# Linking with nvshmem requires the source binary to be built with -rdc
# which is not viable for libtorch_cuda. So we isolate the linking of
# nvshmem in nvshmem_extension.
add_library(nvshmem_extension SHARED
"${TORCH_SRC_DIR}/csrc/distributed/c10d/nvshmem_extension.cu"
"${TORCH_SRC_DIR}/csrc/distributed/c10d/NVSHMEMSymmetricMemory.cu"
"${TORCH_SRC_DIR}/csrc/distributed/c10d/CUDASymmetricMemoryUtils.cpp"
"${TORCH_SRC_DIR}/csrc/distributed/c10d/cuda/utils.cpp"
)
set_target_properties(nvshmem_extension PROPERTIES CUDA_SEPARABLE_COMPILATION ON)
target_compile_options(nvshmem_extension PRIVATE $<$<COMPILE_LANGUAGE:CUDA>:-rdc=true>)
target_compile_options(nvshmem_extension PRIVATE "-U__CUDA_NO_HALF_OPERATORS__")
target_link_libraries(nvshmem_extension PRIVATE
${NVSHMEM_LIB_DIR}/libnvshmem.a
${NVSHMEM_LIB_DIR}/nvshmem_bootstrap_uid.so
)
target_link_libraries(nvshmem_extension PRIVATE mlx5)
target_link_libraries(torch_cuda PRIVATE nvshmem_extension)
install(TARGETS nvshmem_extension EXPORT Caffe2Targets DESTINATION lib)
install(
FILES ${NVSHMEM_LIB_DIR}/nvshmem_bootstrap_uid.so
DESTINATION lib
)
endif()
if(USE_UCC)
target_link_libraries(torch_cuda PRIVATE __caffe2_ucc)
target_compile_definitions(torch_cuda PRIVATE USE_UCC)

View File

@@ -30,8 +30,8 @@ template <std::memory_order Sem>
__device__ __forceinline__ uint32_t
cas(uint32_t* addr, uint32_t compare, uint32_t val) {
#if !defined(USE_ROCM) && defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 600)
cuda::atomic_ref<uint32_t, cuda::thread_scope_system> ref(*addr);
ref.compare_exchange_strong(compare, val, cuda::std::memory_order(Sem));
::cuda::atomic_ref<uint32_t, ::cuda::thread_scope_system> ref(*addr);
ref.compare_exchange_strong(compare, val, ::cuda::std::memory_order(Sem));
return compare;
#else
CUDA_KERNEL_ASSERT(false);

View File

@@ -1,275 +1,33 @@
#include <torch/csrc/distributed/c10d/CUDASymmetricMemory.hpp>
#include <torch/csrc/distributed/c10d/CUDASymmetricMemory-inl.h>
#include <torch/csrc/distributed/c10d/CUDASymmetricMemory.hpp>
#include <torch/csrc/distributed/c10d/CUDASymmetricMemoryUtils.hpp>
#include <torch/csrc/distributed/c10d/cuda/utils.hpp>
#include <ATen/ceil_div.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDACachingAllocator.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/util/env.h>
#include <c10/util/error.h>
#include <sys/socket.h>
#include <unistd.h>
#if !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED)
#include <c10/cuda/driver_api.h>
#endif
#include <sys/socket.h>
#include <sys/syscall.h>
#include <sys/un.h>
#include <unistd.h>
#if defined(CUDART_VERSION) && CUDART_VERSION >= 12030
#define CUDART_SUPPORTS_MULTICAST
#endif
namespace {
bool device_has_multicast_support(int device_idx) {
if (c10::utils::check_env("TORCH_SYMM_MEM_DISABLE_MULTICAST") == true) {
return false;
}
return c10d::cuda::deviceSupportsMulticast(device_idx);
}
bool allow_overlapping_devices() {
return c10::utils::check_env("TORCH_SYMM_MEM_ALLOW_OVERLAPPING_DEVICES") ==
true;
}
class IpcChannel {
public:
IpcChannel() : socket_name_(get_socket_name(getpid())) {
TORCH_CHECK(
// Local-only, Uses file paths instead of IP addresses
(socket_ = socket(AF_UNIX, SOCK_DGRAM, 0)) != 0,
"Failed to create socket: ",
c10::utils::str_error(errno));
struct sockaddr_un addr = {.sun_family = AF_UNIX};
std::copy(socket_name_.begin(), socket_name_.end(), addr.sun_path);
TORCH_CHECK(
bind(socket_, (struct sockaddr*)&addr, SUN_LEN(&addr)) == 0,
"Failed to bind socket: ",
c10::utils::str_error(errno));
}
~IpcChannel() {
close(socket_);
unlink(socket_name_.c_str());
}
// Because file descriptors are process-local kernel objects,
// and we cant pass them via normal socket payloads (like write() or send()).
// Unix domain sockets provide a mechanism to pass actual FDs via sendmsg()/recvmsg().
void send_fd(int dst_pid, int fd) {
struct sockaddr_un addr = {.sun_family = AF_UNIX};
auto socket_name = get_socket_name(dst_pid);
std::copy(socket_name.begin(), socket_name.end(), addr.sun_path);
struct iovec io = {.iov_base = (void*)("fd"), .iov_len = 2};
char cbuf[CMSG_SPACE(sizeof(int))];
memset(cbuf, 0, sizeof(cbuf));
struct msghdr msg {
.msg_name = (void*)&addr, .msg_namelen = sizeof(struct sockaddr_un),
.msg_iov = &io, .msg_iovlen = 1, .msg_control = cbuf,
.msg_controllen = sizeof(cbuf)
};
// This points to the first control message header
// With SCM_RIGHTS we let the kernel know that we are passing file descriptors.
auto cmsg = CMSG_FIRSTHDR(&msg);
cmsg->cmsg_len = CMSG_LEN(sizeof(int));
cmsg->cmsg_level = SOL_SOCKET;
cmsg->cmsg_type = SCM_RIGHTS;
if (fd != -1) {
std::copy(
reinterpret_cast<const char*>(&fd),
reinterpret_cast<const char*>(&fd) + sizeof(fd),
reinterpret_cast<char*>(CMSG_DATA(cmsg)));
} else {
msg.msg_controllen = 0;
}
TORCH_CHECK(
sendmsg(socket_, &msg, 0) > 0,
"Failed to send fd: ",
c10::utils::str_error(errno));
}
int recv_fd() {
char buf[2];
struct iovec io = {.iov_base = (void*)buf, .iov_len = sizeof(buf)};
char cbuf[CMSG_SPACE(sizeof(int))];
memset(cbuf, 0, sizeof(cbuf));
struct msghdr msg = {
.msg_iov = &io,
.msg_iovlen = 1,
.msg_control = cbuf,
.msg_controllen = sizeof(cbuf)};
TORCH_CHECK(
recvmsg(socket_, &msg, 0) > 0,
"Failed to receive fd: ",
c10::utils::str_error(errno));
if (msg.msg_controllen == 0) {
return -1;
}
auto cmsg = CMSG_FIRSTHDR(&msg);
TORCH_CHECK(cmsg != NULL);
TORCH_CHECK(cmsg->cmsg_len == CMSG_LEN(sizeof(int)));
TORCH_CHECK(
cmsg->cmsg_level == SOL_SOCKET && cmsg->cmsg_type == SCM_RIGHTS);
return *reinterpret_cast<int*>(CMSG_DATA(cmsg));
}
std::vector<int> all_gather_fds(
int rank,
const std::vector<int>& pids,
int fd) {
size_t world_size = pids.size();
std::vector<int> fds(pids.size());
fds[rank] = fd;
int dst_rank = (rank + 1) % world_size;
for (size_t step = 1; step < world_size; ++step) {
int src_rank = (rank + world_size - step) % world_size;
send_fd(pids[dst_rank], fd);
fd = recv_fd();
fds[src_rank] = fd;
}
return fds;
}
int broadcast_fds(
int rank,
int src_rank,
const std::vector<int>& pids,
int fd) {
size_t world_size = pids.size();
if (rank == src_rank) {
for (int dst_rank = 0; dst_rank < (int)world_size; ++dst_rank) {
if (dst_rank == rank) {
continue;
}
send_fd(pids[dst_rank], fd);
}
return fd;
}
return recv_fd();
}
private:
static std::string get_socket_name(int pid) {
std::string tmp_dir = "/tmp";
for (const char* env_var : {"TMPDIR", "TMP", "TEMP", "TEMPDIR"}) {
if (const auto path = c10::utils::get_env(env_var)) {
tmp_dir = path.value();
break;
}
}
std::ostringstream oss;
oss << tmp_dir << "/symm_mem-" << pid;
return oss.str();
}
std::string socket_name_;
int socket_;
};
constexpr size_t signal_pad_size = 2048;
const std::string store_comm_prefix = "CUDASymmetricMemory";
static size_t store_comm_seq_id = 0;
template <typename T>
std::vector<T> store_all_gather(
const c10::intrusive_ptr<c10d::Store>& store,
int rank,
int world_size,
T val) {
static_assert(std::is_trivially_copyable_v<T>);
std::vector<std::string> peer_keys;
for (int r = 0; r < world_size; ++r) {
std::ostringstream oss;
oss << store_comm_prefix << "/" << store_comm_seq_id << "/" << r;
peer_keys.push_back(oss.str());
}
++store_comm_seq_id;
{
std::vector<uint8_t> payload(
reinterpret_cast<uint8_t*>(&val),
reinterpret_cast<uint8_t*>(&val) + sizeof(T));
store->set(peer_keys[rank], payload);
}
std::vector<T> peer_vals;
for (int r = 0; r < world_size; ++r) {
if (r == rank) {
peer_vals.push_back(val);
continue;
}
store->wait({peer_keys[r]});
auto payload = store->get(peer_keys[r]);
TORCH_CHECK(payload.size() == sizeof(T));
T peer_val{};
std::memcpy(&peer_val, payload.data(), sizeof(T));
peer_vals.push_back(peer_val);
}
return peer_vals;
}
void store_barrier(
const c10::intrusive_ptr<c10d::Store>& store,
int rank,
int world_size) {
store_all_gather(store, rank, world_size, 0);
}
// This function returns a pointer of virtual address space that is
// mapped to the same physical memory as the given handler.
void map_block(
void** ptr,
c10d::symmetric_memory::HandleType handle,
size_t size,
int device_idx) {
#if !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED)
auto driver_api = c10::cuda::DriverAPI::get();
auto dev_ptr = reinterpret_cast<CUdeviceptr*>(ptr);
// Allocate virtual address space.
C10_CUDA_DRIVER_CHECK(
driver_api->cuMemAddressReserve_(dev_ptr, size, 0ULL, 0, 0ULL));
// Map the physical memory to the virtual address.
C10_CUDA_DRIVER_CHECK(driver_api->cuMemMap_(*dev_ptr, size, 0, handle, 0ULL));
// Set access permissions.
CUmemAccessDesc desc;
desc.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
// NOLINTNEXTLINE(bugprone-signed-char-misuse)
desc.location.id = static_cast<int>(device_idx);
desc.flags = CU_MEM_ACCESS_FLAGS_PROT_READWRITE;
C10_CUDA_DRIVER_CHECK(driver_api->cuMemSetAccess_(*dev_ptr, size, &desc, 1));
#else
TORCH_CHECK(
false, "CUDASymmetricMemory requires PYTORCH_C10_DRIVER_API_SUPPORTED");
#endif
}
} // namespace
namespace c10d {
namespace symmetric_memory {
/* Start of CUDASymmetricMemory implementation */
// A set of exchange methods with prefix "CUDASymmetricMemory"
static StoreExchange storeExchange = StoreExchange("CUDASymmetricMemory");
AllocationRef::AllocationRef(
void* ptr,
HandleType handle,
@@ -794,7 +552,7 @@ static void init_multicast_for_block(
mc_handle, 0, block->alloc_ref->handle, 0, block->block_size, 0));
map_block(&mc_addr, mc_handle, block->block_size, block->device_idx);
store_barrier(store, rank, world_size);
storeExchange.barrier(store, rank, world_size);
#endif
}
@@ -855,7 +613,7 @@ c10::intrusive_ptr<SymmetricMemory> CUDASymmetricMemoryAllocator::rendezvous(
.buffer_size = block->buffer_size,
.signal_pad_offset = block->signal_pad_offset,
.has_multicast_support = device_has_multicast_support(block->device_idx)};
auto reqs = store_all_gather(store, rank, world_size, local_req);
auto reqs = storeExchange.all_gather(store, rank, world_size, local_req);
validate_rendezvous_requests(reqs, world_size);
std::vector<int> pids(world_size);
@@ -885,7 +643,7 @@ c10::intrusive_ptr<SymmetricMemory> CUDASymmetricMemoryAllocator::rendezvous(
signal_pads[r] = (void*)((uintptr_t)buffers[r] + block->signal_pad_offset);
close(imported_fds[r]);
}
store_barrier(store, rank, world_size);
storeExchange.barrier(store, rank, world_size);
close(block_fd);
HandleType mc_handle{};
@@ -939,9 +697,13 @@ c10::intrusive_ptr<Block> CUDASymmetricMemoryAllocator::find_block(void* ptr) {
struct RegisterCUDASymmetricMemoryAllocator {
RegisterCUDASymmetricMemoryAllocator() {
register_allocator(
c10::DeviceType::CUDA,
c10::make_intrusive<CUDASymmetricMemoryAllocator>());
// Query backend used for CUDA tensor
// "CUDA" backend stands for this implementation
if (getSymmMemBackendCUDA() == "CUDA") {
register_allocator(
c10::DeviceType::CUDA,
c10::make_intrusive<CUDASymmetricMemoryAllocator>());
}
}
};

View File

@@ -1,17 +1,12 @@
#pragma once
#include <ATen/ATen.h>
#include <torch/csrc/distributed/c10d/CUDASymmetricMemoryTypes.hpp>
#include <torch/csrc/distributed/c10d/Store.hpp>
#include <torch/csrc/distributed/c10d/SymmetricMemory.hpp>
namespace c10d::symmetric_memory {
#if !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED)
using HandleType = CUmemGenericAllocationHandle;
#else
using HandleType = void*;
#endif
// Resource wrapper that owns a (vaddr, allocation handle) pair. Upon
// destruction, it unmaps the vaddr and releases the allocation handle.
struct AllocationRef : public c10::intrusive_ptr_target {

View File

@@ -0,0 +1,13 @@
#pragma once
namespace c10d::symmetric_memory {
constexpr size_t signal_pad_size = 2048;
#if !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED)
using HandleType = CUmemGenericAllocationHandle;
#else
using HandleType = void*;
#endif
} // namespace c10d::symmetric_memory

View File

@@ -0,0 +1,216 @@
#include <sys/socket.h>
#include <sys/syscall.h>
#include <sys/un.h>
#include <unistd.h>
#include <c10/util/error.h>
#if !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED)
#include <c10/cuda/driver_api.h>
#endif
#include <torch/csrc/distributed/c10d/CUDASymmetricMemoryUtils.hpp>
#include <torch/csrc/distributed/c10d/Store.hpp>
#include <torch/csrc/distributed/c10d/cuda/utils.hpp>
namespace c10d::symmetric_memory {
bool device_has_multicast_support(int device_idx) {
if (c10::utils::check_env("TORCH_SYMM_MEM_DISABLE_MULTICAST") == true) {
return false;
}
return c10d::cuda::deviceSupportsMulticast(device_idx);
}
bool allow_overlapping_devices() {
return c10::utils::check_env("TORCH_SYMM_MEM_ALLOW_OVERLAPPING_DEVICES") ==
true;
}
// Query environment variable to get the backend used for CUDA Symmetric Memory.
std::string getSymmMemBackendCUDA() {
static auto val = c10::utils::get_env("TORCH_SYMMMEM");
if (!val.has_value()) {
// In-house implementation: `CUDASymmetricMemory`
return "CUDA";
} else {
// Other backends:
// - "NVSHMEM": `NVSHMEMSymmetricMemory`
return val.value();
}
}
IpcChannel::IpcChannel()
: socket_name_(get_socket_name(getpid())),
socket_(socket(AF_UNIX, SOCK_DGRAM, 0)) {
// On success, a file descriptor for the new socket is returned.
// On error, -1 is returned, and errno is set to indicate the error.
TORCH_CHECK(
socket_ != -1, "Failed to create socket: ", c10::utils::str_error(errno));
struct sockaddr_un addr = {.sun_family = AF_UNIX};
std::copy(socket_name_.begin(), socket_name_.end(), addr.sun_path);
TORCH_CHECK(
bind(socket_, (struct sockaddr*)&addr, SUN_LEN(&addr)) == 0,
"Failed to bind socket: ",
c10::utils::str_error(errno));
}
IpcChannel::~IpcChannel() {
close(socket_);
unlink(socket_name_.c_str());
}
void IpcChannel::send_fd(int dst_pid, int fd) {
// Because file descriptors are process-local kernel objects, we can't pass
// them via normal socket payloads (like write() or send()). Unix domain
// sockets provide a mechanism to pass actual FDs via sendmsg()/recvmsg().
struct sockaddr_un addr = {.sun_family = AF_UNIX};
auto socket_name = get_socket_name(dst_pid);
std::copy(socket_name.begin(), socket_name.end(), addr.sun_path);
struct iovec io = {.iov_base = (void*)("fd"), .iov_len = 2};
// NOLINTNEXTLINE(*array*)
char cbuf[CMSG_SPACE(sizeof(int))];
memset(cbuf, 0, sizeof(cbuf));
struct msghdr msg {
.msg_name = (void*)&addr, .msg_namelen = sizeof(struct sockaddr_un),
.msg_iov = &io, .msg_iovlen = 1, .msg_control = cbuf,
.msg_controllen = sizeof(cbuf)
};
// This points to the first control message header
// With SCM_RIGHTS we let the kernel know that we are passing file
// descriptors.
auto cmsg = CMSG_FIRSTHDR(&msg);
cmsg->cmsg_len = CMSG_LEN(sizeof(int));
cmsg->cmsg_level = SOL_SOCKET;
cmsg->cmsg_type = SCM_RIGHTS;
if (fd != -1) {
std::copy(
reinterpret_cast<const char*>(&fd),
reinterpret_cast<const char*>(&fd) + sizeof(fd),
reinterpret_cast<char*>(CMSG_DATA(cmsg)));
} else {
msg.msg_controllen = 0;
}
TORCH_CHECK(
sendmsg(socket_, &msg, 0) > 0,
"Failed to send fd: ",
c10::utils::str_error(errno));
}
int IpcChannel::recv_fd() {
// NOLINTNEXTLINE(*array*)
char buf[2];
struct iovec io = {.iov_base = (void*)buf, .iov_len = sizeof(buf)};
// NOLINTNEXTLINE(*array*)
char cbuf[CMSG_SPACE(sizeof(int))];
memset(cbuf, 0, sizeof(cbuf));
struct msghdr msg = {
.msg_iov = &io,
.msg_iovlen = 1,
.msg_control = cbuf,
.msg_controllen = sizeof(cbuf)};
TORCH_CHECK(
recvmsg(socket_, &msg, 0) > 0,
"Failed to receive fd: ",
c10::utils::str_error(errno));
if (msg.msg_controllen == 0) {
return -1;
}
auto cmsg = CMSG_FIRSTHDR(&msg);
TORCH_CHECK(cmsg != nullptr);
TORCH_CHECK(cmsg->cmsg_len == CMSG_LEN(sizeof(int)));
TORCH_CHECK(cmsg->cmsg_level == SOL_SOCKET && cmsg->cmsg_type == SCM_RIGHTS);
return *reinterpret_cast<int*>(CMSG_DATA(cmsg));
}
std::vector<int> IpcChannel::all_gather_fds(
int rank,
const std::vector<int>& pids,
int fd) {
int world_size = (int)pids.size();
std::vector<int> fds(pids.size());
fds[rank] = fd;
int dst_rank = (rank + 1) % world_size;
for (int step = 1; step < world_size; ++step) {
int src_rank = (rank + world_size - step) % world_size;
send_fd(pids[dst_rank], fd);
fd = recv_fd();
fds[src_rank] = fd;
}
return fds;
}
int IpcChannel::broadcast_fds(
int rank,
int src_rank,
const std::vector<int>& pids,
int fd) {
int world_size = (int)pids.size();
if (rank == src_rank) {
for (int dst_rank = 0; dst_rank < (int)world_size; ++dst_rank) {
if (dst_rank == rank) {
continue;
}
send_fd(pids[dst_rank], fd);
}
return fd;
}
return recv_fd();
}
std::string IpcChannel::get_socket_name(int pid) {
const char* tmp_dir = "/tmp";
for (const char* env_var : {"TMPDIR", "TMP", "TEMP", "TEMPDIR"}) {
if (const char* path = getenv(env_var)) {
tmp_dir = path;
break;
}
}
std::ostringstream oss;
oss << tmp_dir << "/symm_mem-" << pid;
return oss.str();
}
void map_block(
void** ptr,
c10d::symmetric_memory::HandleType handle,
size_t size,
int device_idx) {
#if !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED)
auto driver_api = c10::cuda::DriverAPI::get();
auto dev_ptr = reinterpret_cast<CUdeviceptr*>(ptr);
// Allocate virtual address space
C10_CUDA_DRIVER_CHECK(
driver_api->cuMemAddressReserve_(dev_ptr, size, 0ULL, 0, 0ULL));
// Map the physical memory to the virtual address
C10_CUDA_DRIVER_CHECK(driver_api->cuMemMap_(*dev_ptr, size, 0, handle, 0ULL));
// Set access permissions
CUmemAccessDesc desc;
desc.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
// NOLINTNEXTLINE(bugprone-signed-char-misuse)
desc.location.id = static_cast<int>(device_idx);
desc.flags = CU_MEM_ACCESS_FLAGS_PROT_READWRITE;
C10_CUDA_DRIVER_CHECK(driver_api->cuMemSetAccess_(*dev_ptr, size, &desc, 1));
#else
TORCH_CHECK(
false, "CUDASymmetricMemory requires PYTORCH_C10_DRIVER_API_SUPPORTED");
#endif
}
} // namespace c10d::symmetric_memory

View File

@@ -0,0 +1,115 @@
#pragma once
#include <torch/csrc/distributed/c10d/CUDASymmetricMemoryTypes.hpp>
#include <torch/csrc/distributed/c10d/Store.hpp>
#include <torch/csrc/distributed/c10d/SymmetricMemory.hpp>
namespace c10d {
namespace symmetric_memory {
bool device_has_multicast_support(int device_idx);
bool allow_overlapping_devices();
// Query environment variable to get the backend used for CUDA Symmetric Memory.
std::string getSymmMemBackendCUDA();
class IpcChannel {
public:
IpcChannel();
~IpcChannel();
void send_fd(int dst_pid, int fd);
int recv_fd();
std::vector<int> all_gather_fds(
int rank,
const std::vector<int>& pids,
int fd);
int broadcast_fds(
int rank,
int src_rank,
const std::vector<int>& pids,
int fd);
private:
static std::string get_socket_name(int pid);
std::string socket_name_;
int socket_;
};
// A set of store-based exchange methods with a preset prefix, which is
// typically the type of the SymmetricMemory. Mostly used as static instances
// in the respective SymmetricMemory implementation files.
class StoreExchange {
public:
StoreExchange(const std::string& store_prefix)
: store_prefix_(store_prefix) {}
// Put template function in header file so that compiler can easily access it.
template <typename T>
std::vector<T> all_gather(
const c10::intrusive_ptr<c10d::Store>& store,
int rank,
int world_size,
T val) {
static_assert(std::is_trivially_copyable_v<T>);
std::vector<std::string> peer_keys;
peer_keys.reserve(world_size);
for (int r = 0; r < world_size; ++r) {
std::ostringstream oss;
oss << store_prefix_ << "/" << seq_id_ << "/" << r;
peer_keys.push_back(oss.str());
}
++seq_id_;
{
std::vector<uint8_t> payload(
reinterpret_cast<uint8_t*>(&val),
reinterpret_cast<uint8_t*>(&val) + sizeof(T));
store->set(peer_keys[rank], payload);
}
std::vector<T> peer_vals;
peer_vals.reserve(world_size);
for (int r = 0; r < world_size; ++r) {
if (r == rank) {
peer_vals.push_back(val);
continue;
}
store->wait({peer_keys[r]});
auto payload = store->get(peer_keys[r]);
TORCH_CHECK(payload.size() == sizeof(T));
T peer_val{};
std::memcpy(&peer_val, payload.data(), sizeof(T));
peer_vals.push_back(peer_val);
}
return peer_vals;
}
void barrier(
const c10::intrusive_ptr<c10d::Store>& store,
int rank,
int world_size) {
// TODO: implement an efficient one?
all_gather(store, rank, world_size, 0);
}
private:
const std::string store_prefix_;
size_t seq_id_ = 0;
};
// Returns a pointer to a virtual address that is mapped to the physical memory
// held by the handle.
void map_block(
void** ptr,
c10d::symmetric_memory::HandleType handle,
size_t size,
int device_idx);
} // namespace symmetric_memory
} // namespace c10d

View File

@@ -0,0 +1,329 @@
#include <torch/csrc/distributed/c10d/SymmetricMemory.hpp>
#include <torch/csrc/distributed/c10d/CUDASymmetricMemoryUtils.hpp>
#include <torch/csrc/distributed/c10d/CUDASymmetricMemory-inl.h>
#include <torch/csrc/distributed/c10d/cuda/utils.hpp>
#include <torch/csrc/distributed/c10d/nvshmem_extension.cuh>
#include <ATen/ceil_div.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDACachingAllocator.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/util/error.h>
namespace c10d {
namespace symmetric_memory {
/* Start of CUDASymmetricMemory implementation */
static StoreExchange storeExchange = StoreExchange("NVSHMEMSymmetricMemory");
struct NVSHMEMAllocation {
void* ptr;
size_t buffer_size;
int device_idx;
NVSHMEMAllocation(void* ptr, size_t buffer_size, int device_idx)
: ptr(ptr), buffer_size(buffer_size), device_idx(device_idx) {}
};
class NVSHMEMSymmetricMemory : public SymmetricMemory {
public:
NVSHMEMSymmetricMemory(
std::shared_ptr<NVSHMEMAllocation> allocation,
const std::string& group_name)
: allocation_(allocation),
buffer_size_(allocation->buffer_size),
device_idx_(allocation->device_idx),
group_name_(group_name) {
c10::cuda::CUDAGuard guard(device_idx_);
auto global_rank = get_group_info("0").rank;
auto group_info = get_group_info(group_name_);
auto store = group_info.store;
rank_ = group_info.rank;
world_size_ = group_info.world_size;
rank_to_global_rank_ =
storeExchange.all_gather(store, rank_, world_size_, global_rank);
LOG(INFO) << "[rank " << rank_ << "]"
<< "rank_to_global_rank: " << rank_to_global_rank_;
for (int r = 0; r < world_size_; ++r) {
buffers_.push_back(nvshmem_extension::nvshmem_ptr(
allocation->ptr, rank_to_global_rank_[r]));
}
// TODO: use the same allocation for signal pad
void* signal_pad_ptr = nvshmem_extension::nvshmem_malloc(signal_pad_size);
AT_CUDA_CHECK(cudaMemset(signal_pad_ptr, 0, signal_pad_size));
for (int r = 0; r < world_size_; ++r) {
signal_pads_.push_back(nvshmem_extension::nvshmem_ptr(
signal_pad_ptr, rank_to_global_rank_[r]));
}
const size_t arr_size = sizeof(void*) * world_size_;
buffers_dev_ = reinterpret_cast<void**>(
c10::cuda::CUDACachingAllocator::raw_alloc(arr_size));
signal_pads_dev_ = reinterpret_cast<void**>(
c10::cuda::CUDACachingAllocator::raw_alloc(arr_size));
AT_CUDA_CHECK(cudaMemcpy(
buffers_dev_, buffers_.data(), arr_size, cudaMemcpyHostToDevice));
AT_CUDA_CHECK(cudaMemcpy(
signal_pads_dev_,
signal_pads_.data(),
arr_size,
cudaMemcpyHostToDevice));
rank_to_global_rank_dev_ = reinterpret_cast<int*>(
c10::cuda::CUDACachingAllocator::raw_alloc(sizeof(int) * world_size_));
AT_CUDA_CHECK(cudaMemcpy(
rank_to_global_rank_dev_,
rank_to_global_rank_.data(),
sizeof(int) * world_size_,
cudaMemcpyHostToDevice));
}
~NVSHMEMSymmetricMemory() override{
// TODO
};
std::vector<void*> get_buffer_ptrs() override {
return buffers_;
}
std::vector<void*> get_signal_pad_ptrs() override {
return signal_pads_;
}
void** get_buffer_ptrs_dev() override {
return buffers_dev_;
}
void** get_signal_pad_ptrs_dev() override {
return signal_pads_dev_;
}
size_t get_buffer_size() override {
return buffer_size_;
}
size_t get_signal_pad_size() override {
return signal_pad_size;
};
bool has_multicast_support() override {
// TODO
return false;
}
void* get_multicast_ptr() override {
// TODO
return nullptr;
}
at::Tensor get_buffer(
int rank,
c10::IntArrayRef sizes,
c10::ScalarType dtype,
int64_t storage_offset) {
// TODO: deduplicate
const size_t numel = std::accumulate(
sizes.begin(),
sizes.end(),
static_cast<size_t>(1),
std::multiplies<size_t>());
const auto element_size = c10::elementSize(dtype);
const auto req_size = (numel + storage_offset) * element_size;
TORCH_CHECK(
req_size <= buffer_size_,
"NVSHMEMSymmetricMemory::get_buffer: the requested size (",
req_size,
" bytes) exceeds the allocated size (",
buffer_size_,
" bytes)");
auto data_ptr = reinterpret_cast<uint8_t*>(buffers_[rank]) +
storage_offset * element_size;
auto device = c10::Device(c10::DeviceType::CUDA, device_idx_);
auto options = at::TensorOptions().dtype(dtype).device(device);
return at::for_blob(data_ptr, sizes)
.options(options)
.target_device(device)
.make_tensor();
}
at::Tensor get_signal_pad(
int rank,
c10::IntArrayRef sizes,
std::optional<c10::ScalarType> dtype,
int64_t storage_offset) override {
// TODO: deduplicate
// If the dtype is unspecified, default it to UInt32, as it
// is the most common type for signaling purposes.
if (!dtype.has_value()) {
dtype = c10::ScalarType::UInt32;
}
// If the shape is unspecified, treat the signal pad as a 1d tensor.
const auto element_size = c10::elementSize(*dtype);
std::vector<int64_t> shape;
if (sizes.size() != 0) {
shape = sizes.vec();
} else {
shape.push_back(signal_pad_size / element_size);
}
const size_t numel = std::accumulate(
shape.begin(),
shape.end(),
static_cast<size_t>(1),
std::multiplies<size_t>());
const auto req_size = (numel + storage_offset) * element_size;
TORCH_CHECK(
req_size <= signal_pad_size,
"NVSHMEMSymmetricMemory::get_signal_pad: the requested size (",
req_size,
" bytes) exceeds the allocated size (",
signal_pad_size,
" bytes)");
auto data_ptr = reinterpret_cast<uint8_t*>(signal_pads_[rank]) +
storage_offset * element_size;
auto device = c10::Device(c10::DeviceType::CUDA, device_idx_);
auto options = at::TensorOptions().dtype(*dtype).device(device);
return at::for_blob(data_ptr, shape)
.options(options)
.target_device(device)
.make_tensor();
}
void barrier(int channel, size_t timeout_ms) override {
// TODO
}
void put_signal(int dst_rank, int channel, size_t timeout_ms) override {
// TODO
}
void wait_signal(int src_rank, int channel, size_t timeout_ms) override {
// TODO
}
int get_rank() override {
return rank_;
}
int get_world_size() override {
return world_size_;
}
virtual std::vector<int> get_rank_to_global_rank() override {
return rank_to_global_rank_;
};
int* get_rank_to_global_rank_dev() override {
return rank_to_global_rank_dev_;
};
private:
std::shared_ptr<NVSHMEMAllocation> allocation_;
size_t buffer_size_;
std::vector<void*> buffers_;
std::vector<void*> signal_pads_;
int device_idx_;
int rank_;
int world_size_;
void** buffers_dev_;
void** signal_pads_dev_;
std::string group_name_;
std::vector<int> rank_to_global_rank_;
int* rank_to_global_rank_dev_;
};
class NVSHMEMSymmetricMemoryAllocator : public SymmetricMemoryAllocator {
public:
void* alloc(
size_t size,
int device_idx,
const std::optional<std::string>& group_name) override {
TORCH_CHECK(
group_name == std::nullopt,
"NVSHMEMSymmetricMemoryAllocator::alloc "
"must not be called with a group_name");
auto group_info = get_group_info("0");
auto store = group_info.store;
int rank = group_info.rank;
int world_size = group_info.world_size;
nvshmem_extension::initialize_nvshmem_with_store(store, rank, world_size);
auto ptr = nvshmem_extension::nvshmem_malloc(size);
auto allocation =
std::make_shared<NVSHMEMAllocation>(ptr, size, device_idx);
// TODO: thread safety
allocations_.emplace(ptr, allocation);
return ptr;
}
void free(void* ptr) override {
// TODO: thread safety
ptr_to_symm_mem_.erase(ptr);
};
size_t get_alloc_size(void* ptr) override {
auto it = ptr_to_symm_mem_.find(ptr);
if (it == ptr_to_symm_mem_.end()) {
TORCH_CHECK(
false, ptr, " is not allocated with NVSHMEMSymmetricMemoryAllocator");
}
return it->second->get_buffer_size();
};
c10::intrusive_ptr<SymmetricMemory> rendezvous(
void* ptr,
const std::optional<std::string>& group_name) override {
TORCH_CHECK(group_name.has_value());
{
auto it = symm_mems_.find(std::make_tuple(ptr, *group_name));
if (it != symm_mems_.end()) {
return it->second;
}
}
auto it = allocations_.find(ptr);
TORCH_CHECK(it != allocations_.end());
auto symm_mem =
c10::make_intrusive<NVSHMEMSymmetricMemory>(it->second, *group_name);
symm_mems_[std::make_tuple(ptr, *group_name)] = symm_mem;
return symm_mem;
};
bool has_multicast_support(int device_idx) override {
// TODO
return false;
};
private:
std::unordered_map<void*, c10::intrusive_ptr<SymmetricMemory>>
ptr_to_symm_mem_;
std::unordered_map<void*, std::shared_ptr<NVSHMEMAllocation>> allocations_;
std::map<std::tuple<void*, std::string>, c10::intrusive_ptr<SymmetricMemory>>
symm_mems_;
};
struct RegisterNVSHMEMSymmetricMemoryAllocator {
RegisterNVSHMEMSymmetricMemoryAllocator() {
// Query backend used for CUDA tensor
if (getSymmMemBackendCUDA() == "NVSHMEM") {
register_allocator(
c10::DeviceType::CUDA,
c10::make_intrusive<NVSHMEMSymmetricMemoryAllocator>());
}
}
};
static RegisterNVSHMEMSymmetricMemoryAllocator register_allocator_;
} // namespace symmetric_memory
} // namespace c10d

View File

@@ -6,6 +6,7 @@ using namespace c10d::symmetric_memory;
static bool is_finalizing_ = false;
// NOLINTNEXTLINE(cppcoreguidelines-special-member-functions)
class AllocatorMap {
public:
AllocatorMap(const AllocatorMap&) = delete;
@@ -212,7 +213,9 @@ namespace {
at::Tensor one_shot_all_reduce_meta(
const at::Tensor& input,
// NOLINTNEXTLINE(performance-unnecessary-value-param)
std::string reduce_op,
// NOLINTNEXTLINE(performance-unnecessary-value-param)
std::string group_name) {
return at::empty_like(input);
}
@@ -220,7 +223,9 @@ at::Tensor one_shot_all_reduce_meta(
at::Tensor one_shot_all_reduce_copy_meta(
const at::Tensor& symm_buffer,
const at::Tensor& local_input,
// NOLINTNEXTLINE(performance-unnecessary-value-param)
std::string reduce_op,
// NOLINTNEXTLINE(performance-unnecessary-value-param)
std::string group_name) {
return at::empty_like(local_input);
}
@@ -269,6 +274,8 @@ TORCH_LIBRARY_FRAGMENT(symm_mem, m) {
"stream_write_value32_(Tensor(a!) input, int offset, int val) -> Tensor(a!)");
m.def(
"memset32_(Tensor(a!) input, int offset, int val, int count) -> Tensor(a!)");
m.def("nvshmem_broadcast(Tensor(a!) input, str group_name) -> Tensor(a!)");
}
TORCH_LIBRARY_IMPL(symm_mem, Meta, m) {

View File

@@ -71,6 +71,14 @@ class TORCH_API SymmetricMemory : public c10::intrusive_ptr_target {
virtual int get_rank() = 0;
virtual int get_world_size() = 0;
virtual std::vector<int> get_rank_to_global_rank() {
TORCH_CHECK(false, "NYI");
};
virtual int* get_rank_to_global_rank_dev() {
TORCH_CHECK(false, "NYI");
};
};
class SymmetricMemoryAllocator : public c10::intrusive_ptr_target {

View File

@@ -0,0 +1,126 @@
#include <torch/csrc/distributed/c10d/nvshmem_extension.cuh>
#include <c10/cuda/CUDAGuard.h>
#include <torch/csrc/distributed/c10d/CUDASymmetricMemory-inl.h>
#include <torch/csrc/distributed/c10d/CUDASymmetricMemoryUtils.hpp>
#include <torch/csrc/distributed/c10d/SymmetricMemory.hpp>
#include <cuda_awbarrier_primitives.h>
#include <nvshmem.h>
namespace c10d::nvshmem_extension {
using c10d::symmetric_memory::StoreExchange;
static StoreExchange storeExchange = StoreExchange("nvshmem_ext");
// Bootstrap based on user's setting for NCCL
// Long term, this may be a bit unclean; short term, it improves UX
void maybe_initialize_env_vars() {
auto nccl_socket_if_name = c10::utils::get_env("NCCL_SOCKET_IFNAME");
auto nccl_hca_list = c10::utils::get_env("NCCL_IB_HCA");
auto nccl_ib_gid_index = c10::utils::get_env("NCCL_IB_GID_INDEX");
auto nvshmem_socket_if_name =
c10::utils::get_env("NVSHMEM_BOOTSTRAP_UID_SOCK_IFNAME");
auto nvshmem_hca_list = c10::utils::get_env("NCCL_IB_HCA");
auto nvshmem_ib_gid_index = c10::utils::get_env("NVSHMEM_IB_GID_INDEX");
if (!nvshmem_socket_if_name.has_value() && nccl_socket_if_name.has_value()) {
c10::utils::set_env(
"NVSHMEM_BOOTSTRAP_UID_SOCK_IFNAME", nccl_socket_if_name->c_str());
}
if (!nvshmem_hca_list.has_value() && nccl_hca_list.has_value()) {
c10::utils::set_env("NVSHMEM_ENABLE_NIC_PE_MAPPING", "1");
c10::utils::set_env("NVSHMEM_HCA_LIST", nccl_hca_list->c_str());
}
if (!nvshmem_ib_gid_index.has_value() && nccl_ib_gid_index.has_value()) {
c10::utils::set_env("NVSHMEM_IB_GID_INDEX", nccl_ib_gid_index->c_str());
}
}
void initialize_nvshmem_with_store(
c10::intrusive_ptr<c10d::Store> store,
int rank,
int world_size) {
static bool is_initialized = false;
if (is_initialized) {
return;
}
maybe_initialize_env_vars();
nvshmemx_uniqueid_t unique_id;
TORCH_CHECK(
nvshmemx_get_uniqueid(&unique_id) == 0, "nvshmemx_get_uniqueid failed");
// Using an existing store_all_gather due to laziness.
// TODO(yifu): should use broadcast
auto unique_ids = storeExchange.all_gather(store, rank, world_size, unique_id);
nvshmemx_init_attr_t attr;
nvshmemx_set_attr_uniqueid_args(rank, world_size, &unique_ids[0], &attr);
TORCH_CHECK(
nvshmemx_init_attr(NVSHMEMX_INIT_WITH_UNIQUEID, &attr) == 0,
"nvshmemx_init_attr failed");
is_initialized = true;
}
void* nvshmem_malloc(size_t size) {
return ::nvshmem_malloc(size);
}
void* nvshmem_ptr(const void* dest, int pe) {
return ::nvshmem_ptr(dest, pe);
}
std::unordered_map<std::string, nvshmem_team_t> group_name_to_team_;
nvshmem_team_t group_to_team(
const std::string& group_name,
const std::vector<int>& global_ranks) {
auto it = group_name_to_team_.find(group_name);
if (it != group_name_to_team_.end()) {
return it->second;
}
TORCH_CHECK(global_ranks.size() > 1);
int stride = global_ranks[1] - global_ranks[0];
for (size_t r = 1; r < global_ranks.size(); ++r) {
TORCH_CHECK(global_ranks[r] - global_ranks[r - 1] == stride);
}
nvshmem_team_t team;
TORCH_CHECK(
nvshmem_team_split_strided(
NVSHMEM_TEAM_WORLD,
global_ranks[0],
stride,
global_ranks.size(),
nullptr,
0,
&team) == 0);
group_name_to_team_[group_name] = team;
TORCH_CHECK(team != NVSHMEM_TEAM_INVALID);
return team;
}
at::Tensor nvshmem_broadcast(at::Tensor& input, const std::string& group_name) {
auto input_hdl = c10d::symmetric_memory::rendezvous(input, group_name);
int rank = input_hdl->get_rank();
int world_size = input_hdl->get_world_size();
auto team = group_to_team(group_name, input_hdl->get_rank_to_global_rank());
void* buffer_ptr = input_hdl->get_buffer_ptrs()[rank];
auto stream = at::cuda::getCurrentCUDAStream();
nvshmemx_broadcastmem_on_stream(team, buffer_ptr, buffer_ptr, input_hdl->get_buffer_size(), 0, stream);
return input;
}
} // namespace c10d::nvshmem_extension
TORCH_LIBRARY_IMPL(symm_mem, CUDA, m) {
m.impl("nvshmem_broadcast", c10d::nvshmem_extension::nvshmem_broadcast);
}
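
The hunk above registers a `symm_mem::nvshmem_broadcast` op backed by `nvshmem_extension::nvshmem_broadcast`. Below is a hedged sketch of driving it from Python under `TORCH_SYMMMEM=NVSHMEM`; the allocation helpers used here (`symm_mem.empty`, `symm_mem.rendezvous`) are assumptions about the existing Python-side SymmetricMemory API and are not part of this diff.

```python
# Hedged sketch, not part of this PR: broadcast a SymmetricMemory buffer with
# the op registered above. Assumes a torchrun launch with one GPU per rank and
# TORCH_SYMMMEM=NVSHMEM in the environment; symm_mem.empty/rendezvous are
# assumed helpers for allocating and exchanging the buffer.
import os

import torch
import torch.distributed as dist
import torch.distributed._symmetric_memory as symm_mem

rank = int(os.environ["RANK"])
torch.cuda.set_device(rank)
dist.init_process_group("nccl")
group_name = dist.group.WORLD.group_name

t = symm_mem.empty(1024, dtype=torch.float32, device="cuda")  # assumed allocation helper
symm_mem.rendezvous(t, group_name)  # establishes the NVSHMEM-backed peer mapping
if rank == 0:
    t.fill_(42.0)

# Broadcast the buffer contents from the first rank of the team to all ranks, in place.
torch.ops.symm_mem.nvshmem_broadcast(t, group_name)
torch.cuda.synchronize()
dist.destroy_process_group()
```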

View File

@@ -0,0 +1,20 @@
#pragma once
#include <ATen/ATen.h>
#include <torch/csrc/distributed/c10d/Store.hpp>
namespace c10d::nvshmem_extension {
void initialize_nvshmem_with_store(
c10::intrusive_ptr<c10d::Store> store,
int rank,
int world_size);
void* nvshmem_malloc(size_t size);
void* nvshmem_ptr(const void* dest, int pe);
at::Tensor nvshmem_broadcast(at::Tensor& input, const std::string& group_name);
} // namespace c10d::nvshmem_extension

View File

@@ -37,26 +37,6 @@ def enable_symm_mem_for_group(group_name: str) -> None:
f"symmetric_memory-{global_ranks_str}",
c10d._get_process_group_store(group),
)
# Use one store-based broadcast to bootstrap a file store from the process
# and simultaneously verify that all ranks are on the same host.
hostname = socket.gethostname()
if group.rank() == 0:
uid = str(uuid.uuid4())
msg = f"{hostname}/{uid}"
store.set("init", msg)
else:
msg = store.get("init").decode("utf-8")
tokens = msg.split("/")
assert len(tokens) == 2, tokens
rank_0_hostname, uid = tokens
if hostname != rank_0_hostname:
raise RuntimeError(
"init_symmetric_memory_for_process_group() failed for "
f'group "{group_name}". Rank 0 and rank {group.rank()} '
f"are on different hosts ({rank_0_hostname} and {hostname})"
)
store = torch._C._distributed_c10d.FileStore(f"/tmp/{uid}", group.size())
# TODO: check device connectivity
_group_name_to_store[group_name] = store
_SymmetricMemory.set_group_info(
group_name,