Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 12:54:11 +08:00)
[SymmMem] Experimental NVSHMEM integration (#151261)
Adds NVSHMEM as a backend for `SymmetricMemory`; the implementation lives in `NVSHMEMSymmetricMemory.cu`. Moves some helper functions from `CUDASymmetricMemory.cu` into `CUDASymmetricMemoryUtils.cpp` so that they can be shared by `NVSHMEMSymmetricMemory`; these are mostly side-band exchange helpers (`store_all_gather`, `IpcChannel`, etc.).

Adds the `TORCH_SYMMMEM` environment variable to control which implementation is used for CUDA tensors. Currently supported values: `CUDA` (the in-house implementation) and `NVSHMEM`. The NVSHMEM feature is gated by the build-time flag `USE_NVSHMEM=1`, and setting `NVSHMEM_HOME` is required (TODO).

Ported most code from #146593.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/151261
Approved by: https://github.com/fegin, https://github.com/fduwjj
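For illustration, a minimal end-to-end sketch of how the new backend selection is meant to be used (not part of this diff): the build is configured with `USE_NVSHMEM=1` and `NVSHMEM_HOME` pointing at an NVSHMEM installation, and the backend for CUDA tensors is then chosen at import time via `TORCH_SYMMMEM`. The `torch.distributed._symmetric_memory` calls below are assumed for the example; only the environment variable and the backend names come from this PR.

```python
# Illustrative sketch; assumes a USE_NVSHMEM=1 build and the existing
# torch.distributed._symmetric_memory Python API.
import os

# Must be set before `import torch`: the allocator registrars read
# TORCH_SYMMMEM when the C++ extension is loaded. Default is "CUDA".
os.environ["TORCH_SYMMMEM"] = "NVSHMEM"

import torch
import torch.distributed as dist
import torch.distributed._symmetric_memory as symm_mem

dist.init_process_group("nccl")
rank = dist.get_rank()
torch.cuda.set_device(rank)

# Allocation goes through the registered SymmetricMemory allocator, i.e.
# NVSHMEMSymmetricMemoryAllocator when TORCH_SYMMMEM=NVSHMEM.
t = symm_mem.empty(1024, dtype=torch.float32, device="cuda")
hdl = symm_mem.rendezvous(t, dist.group.WORLD.group_name)
```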
@@ -740,6 +740,7 @@ cc_library(
        "torch/csrc/distributed/c10d/intra_node_comm.cu",
        "torch/csrc/distributed/c10d/CUDASymmetricMemory.cu",
        "torch/csrc/distributed/c10d/CUDASymmetricMemoryOps.cu",
        "torch/csrc/distributed/c10d/CUDASymmetricMemoryUtils.cpp",
        "torch/csrc/distributed/c10d/cuda/AsyncMM.cu",
        "torch/csrc/distributed/c10d/NanCheck.cu",
        "torch/csrc/distributed/c10d/quantization/quantization_gpu.cu",
@@ -695,6 +695,7 @@ libtorch_cuda_distributed_extra_sources = [
    "torch/csrc/distributed/c10d/intra_node_comm.cu",
    "torch/csrc/distributed/c10d/CUDASymmetricMemory.cu",
    "torch/csrc/distributed/c10d/CUDASymmetricMemoryOps.cu",
    "torch/csrc/distributed/c10d/CUDASymmetricMemoryUtils.cpp",
    "torch/csrc/distributed/c10d/cuda/AsyncMM.cu",
    "torch/csrc/distributed/c10d/cuda/utils.cpp",
    "torch/csrc/distributed/c10d/NanCheck.cu",
@@ -572,6 +572,7 @@ if(USE_CUDA)
      ${TORCH_SRC_DIR}/csrc/distributed/c10d/CudaDMAConnectivity.cpp
      ${TORCH_SRC_DIR}/csrc/distributed/c10d/CUDASymmetricMemory.cu
      ${TORCH_SRC_DIR}/csrc/distributed/c10d/CUDASymmetricMemoryOps.cu
      ${TORCH_SRC_DIR}/csrc/distributed/c10d/CUDASymmetricMemoryUtils.cpp
      ${TORCH_SRC_DIR}/csrc/distributed/c10d/ProcessGroupNCCL.cpp
      PROPERTIES COMPILE_FLAGS "-DPYTORCH_C10_DRIVER_API_SUPPORTED=1"
    )
@@ -976,6 +977,41 @@ elseif(USE_CUDA)
    target_link_libraries(torch_cuda PRIVATE __caffe2_nccl)
    target_compile_definitions(torch_cuda PRIVATE USE_NCCL)
  endif()

  # Use env var for these for now for prototyping purposes
  set(USE_NVSHMEM $ENV{USE_NVSHMEM} CACHE BOOL "Enable NVSHMEM support")
  set(NVSHMEM_HOME $ENV{NVSHMEM_HOME} CACHE PATH "Path to NVSHMEM build dir")

  if(USE_NVSHMEM)
    set(NVSHMEM_INCLUDE_DIR "${NVSHMEM_HOME}/include")
    set(NVSHMEM_LIB_DIR "${NVSHMEM_HOME}/lib")

    include_directories(${NVSHMEM_INCLUDE_DIR})

    # Linking with nvshmem requires the source binary to be built with -rdc
    # which is not viable for libtorch_cuda. So we isolate the linking of
    # nvshmem in nvshmem_extension.
    add_library(nvshmem_extension SHARED
      "${TORCH_SRC_DIR}/csrc/distributed/c10d/nvshmem_extension.cu"
      "${TORCH_SRC_DIR}/csrc/distributed/c10d/NVSHMEMSymmetricMemory.cu"
      "${TORCH_SRC_DIR}/csrc/distributed/c10d/CUDASymmetricMemoryUtils.cpp"
      "${TORCH_SRC_DIR}/csrc/distributed/c10d/cuda/utils.cpp"
    )
    set_target_properties(nvshmem_extension PROPERTIES CUDA_SEPARABLE_COMPILATION ON)
    target_compile_options(nvshmem_extension PRIVATE $<$<COMPILE_LANGUAGE:CUDA>:-rdc=true>)
    target_compile_options(nvshmem_extension PRIVATE "-U__CUDA_NO_HALF_OPERATORS__")
    target_link_libraries(nvshmem_extension PRIVATE
      ${NVSHMEM_LIB_DIR}/libnvshmem.a
      ${NVSHMEM_LIB_DIR}/nvshmem_bootstrap_uid.so
    )
    target_link_libraries(nvshmem_extension PRIVATE mlx5)
    target_link_libraries(torch_cuda PRIVATE nvshmem_extension)
    install(TARGETS nvshmem_extension EXPORT Caffe2Targets DESTINATION lib)
    install(
      FILES ${NVSHMEM_LIB_DIR}/nvshmem_bootstrap_uid.so
      DESTINATION lib
    )
  endif()
  if(USE_UCC)
    target_link_libraries(torch_cuda PRIVATE __caffe2_ucc)
    target_compile_definitions(torch_cuda PRIVATE USE_UCC)
@@ -30,8 +30,8 @@ template <std::memory_order Sem>
__device__ __forceinline__ uint32_t
cas(uint32_t* addr, uint32_t compare, uint32_t val) {
#if !defined(USE_ROCM) && defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 600)
  cuda::atomic_ref<uint32_t, cuda::thread_scope_system> ref(*addr);
  ref.compare_exchange_strong(compare, val, cuda::std::memory_order(Sem));
  ::cuda::atomic_ref<uint32_t, ::cuda::thread_scope_system> ref(*addr);
  ref.compare_exchange_strong(compare, val, ::cuda::std::memory_order(Sem));
  return compare;
#else
  CUDA_KERNEL_ASSERT(false);
@@ -1,275 +1,33 @@
#include <torch/csrc/distributed/c10d/CUDASymmetricMemory.hpp>
#include <torch/csrc/distributed/c10d/CUDASymmetricMemory-inl.h>
#include <torch/csrc/distributed/c10d/CUDASymmetricMemory.hpp>
#include <torch/csrc/distributed/c10d/CUDASymmetricMemoryUtils.hpp>
#include <torch/csrc/distributed/c10d/cuda/utils.hpp>

#include <ATen/ceil_div.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDACachingAllocator.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/util/env.h>
#include <c10/util/error.h>

#include <sys/socket.h>
#include <unistd.h>

#if !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED)
#include <c10/cuda/driver_api.h>
#endif

#include <sys/socket.h>
#include <sys/syscall.h>
#include <sys/un.h>
#include <unistd.h>

#if defined(CUDART_VERSION) && CUDART_VERSION >= 12030
#define CUDART_SUPPORTS_MULTICAST
#endif

namespace {

bool device_has_multicast_support(int device_idx) {
  if (c10::utils::check_env("TORCH_SYMM_MEM_DISABLE_MULTICAST") == true) {
    return false;
  }
  return c10d::cuda::deviceSupportsMulticast(device_idx);
}

bool allow_overlapping_devices() {
  return c10::utils::check_env("TORCH_SYMM_MEM_ALLOW_OVERLAPPING_DEVICES") ==
      true;
}

class IpcChannel {
 public:
  IpcChannel() : socket_name_(get_socket_name(getpid())) {
    TORCH_CHECK(
        // Local-only, Uses file paths instead of IP addresses
        (socket_ = socket(AF_UNIX, SOCK_DGRAM, 0)) != 0,
        "Failed to create socket: ",
        c10::utils::str_error(errno));

    struct sockaddr_un addr = {.sun_family = AF_UNIX};
    std::copy(socket_name_.begin(), socket_name_.end(), addr.sun_path);

    TORCH_CHECK(
        bind(socket_, (struct sockaddr*)&addr, SUN_LEN(&addr)) == 0,
        "Failed to bind socket: ",
        c10::utils::str_error(errno));
  }

  ~IpcChannel() {
    close(socket_);
    unlink(socket_name_.c_str());
  }

  // Because file descriptors are process-local kernel objects,
  // and we can't pass them via normal socket payloads (like write() or send()).
  // Unix domain sockets provide a mechanism to pass actual FDs via sendmsg()/recvmsg().
  void send_fd(int dst_pid, int fd) {
    struct sockaddr_un addr = {.sun_family = AF_UNIX};
    auto socket_name = get_socket_name(dst_pid);
    std::copy(socket_name.begin(), socket_name.end(), addr.sun_path);

    struct iovec io = {.iov_base = (void*)("fd"), .iov_len = 2};

    char cbuf[CMSG_SPACE(sizeof(int))];
    memset(cbuf, 0, sizeof(cbuf));

    struct msghdr msg {
      .msg_name = (void*)&addr, .msg_namelen = sizeof(struct sockaddr_un),
      .msg_iov = &io, .msg_iovlen = 1, .msg_control = cbuf,
      .msg_controllen = sizeof(cbuf)
    };

    // This points to the first control message header
    // With SCM_RIGHTS we let the kernel know that we are passing file descriptors.
    auto cmsg = CMSG_FIRSTHDR(&msg);
    cmsg->cmsg_len = CMSG_LEN(sizeof(int));
    cmsg->cmsg_level = SOL_SOCKET;
    cmsg->cmsg_type = SCM_RIGHTS;

    if (fd != -1) {
      std::copy(
          reinterpret_cast<const char*>(&fd),
          reinterpret_cast<const char*>(&fd) + sizeof(fd),
          reinterpret_cast<char*>(CMSG_DATA(cmsg)));
    } else {
      msg.msg_controllen = 0;
    }

    TORCH_CHECK(
        sendmsg(socket_, &msg, 0) > 0,
        "Failed to send fd: ",
        c10::utils::str_error(errno));
  }

  int recv_fd() {
    char buf[2];
    struct iovec io = {.iov_base = (void*)buf, .iov_len = sizeof(buf)};

    char cbuf[CMSG_SPACE(sizeof(int))];
    memset(cbuf, 0, sizeof(cbuf));

    struct msghdr msg = {
        .msg_iov = &io,
        .msg_iovlen = 1,
        .msg_control = cbuf,
        .msg_controllen = sizeof(cbuf)};

    TORCH_CHECK(
        recvmsg(socket_, &msg, 0) > 0,
        "Failed to receive fd: ",
        c10::utils::str_error(errno));

    if (msg.msg_controllen == 0) {
      return -1;
    }

    auto cmsg = CMSG_FIRSTHDR(&msg);
    TORCH_CHECK(cmsg != NULL);
    TORCH_CHECK(cmsg->cmsg_len == CMSG_LEN(sizeof(int)));
    TORCH_CHECK(
        cmsg->cmsg_level == SOL_SOCKET && cmsg->cmsg_type == SCM_RIGHTS);
    return *reinterpret_cast<int*>(CMSG_DATA(cmsg));
  }

  std::vector<int> all_gather_fds(
      int rank,
      const std::vector<int>& pids,
      int fd) {
    size_t world_size = pids.size();
    std::vector<int> fds(pids.size());
    fds[rank] = fd;

    int dst_rank = (rank + 1) % world_size;
    for (size_t step = 1; step < world_size; ++step) {
      int src_rank = (rank + world_size - step) % world_size;
      send_fd(pids[dst_rank], fd);
      fd = recv_fd();
      fds[src_rank] = fd;
    }
    return fds;
  }

  int broadcast_fds(
      int rank,
      int src_rank,
      const std::vector<int>& pids,
      int fd) {
    size_t world_size = pids.size();

    if (rank == src_rank) {
      for (int dst_rank = 0; dst_rank < (int)world_size; ++dst_rank) {
        if (dst_rank == rank) {
          continue;
        }
        send_fd(pids[dst_rank], fd);
      }
      return fd;
    }
    return recv_fd();
  }

 private:
  static std::string get_socket_name(int pid) {
    std::string tmp_dir = "/tmp";
    for (const char* env_var : {"TMPDIR", "TMP", "TEMP", "TEMPDIR"}) {
      if (const auto path = c10::utils::get_env(env_var)) {
        tmp_dir = path.value();
        break;
      }
    }
    std::ostringstream oss;
    oss << tmp_dir << "/symm_mem-" << pid;
    return oss.str();
  }

  std::string socket_name_;
  int socket_;
};

constexpr size_t signal_pad_size = 2048;
const std::string store_comm_prefix = "CUDASymmetricMemory";

static size_t store_comm_seq_id = 0;

template <typename T>
std::vector<T> store_all_gather(
    const c10::intrusive_ptr<c10d::Store>& store,
    int rank,
    int world_size,
    T val) {
  static_assert(std::is_trivially_copyable_v<T>);

  std::vector<std::string> peer_keys;
  for (int r = 0; r < world_size; ++r) {
    std::ostringstream oss;
    oss << store_comm_prefix << "/" << store_comm_seq_id << "/" << r;
    peer_keys.push_back(oss.str());
  }
  ++store_comm_seq_id;

  {
    std::vector<uint8_t> payload(
        reinterpret_cast<uint8_t*>(&val),
        reinterpret_cast<uint8_t*>(&val) + sizeof(T));
    store->set(peer_keys[rank], payload);
  }

  std::vector<T> peer_vals;
  for (int r = 0; r < world_size; ++r) {
    if (r == rank) {
      peer_vals.push_back(val);
      continue;
    }
    store->wait({peer_keys[r]});
    auto payload = store->get(peer_keys[r]);
    TORCH_CHECK(payload.size() == sizeof(T));
    T peer_val{};
    std::memcpy(&peer_val, payload.data(), sizeof(T));
    peer_vals.push_back(peer_val);
  }
  return peer_vals;
}

void store_barrier(
    const c10::intrusive_ptr<c10d::Store>& store,
    int rank,
    int world_size) {
  store_all_gather(store, rank, world_size, 0);
}

// This function returns a pointer of virtual address space that is
// mapped to the same physical memory as the given handler.
void map_block(
    void** ptr,
    c10d::symmetric_memory::HandleType handle,
    size_t size,
    int device_idx) {
#if !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED)
  auto driver_api = c10::cuda::DriverAPI::get();
  auto dev_ptr = reinterpret_cast<CUdeviceptr*>(ptr);
  // Allocate virtual address space.
  C10_CUDA_DRIVER_CHECK(
      driver_api->cuMemAddressReserve_(dev_ptr, size, 0ULL, 0, 0ULL));
  // Map the physical memory to the virtual address.
  C10_CUDA_DRIVER_CHECK(driver_api->cuMemMap_(*dev_ptr, size, 0, handle, 0ULL));

  // Set access permissions.
  CUmemAccessDesc desc;
  desc.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
  // NOLINTNEXTLINE(bugprone-signed-char-misuse)
  desc.location.id = static_cast<int>(device_idx);
  desc.flags = CU_MEM_ACCESS_FLAGS_PROT_READWRITE;
  C10_CUDA_DRIVER_CHECK(driver_api->cuMemSetAccess_(*dev_ptr, size, &desc, 1));
#else
  TORCH_CHECK(
      false, "CUDASymmetricMemory requires PYTORCH_C10_DRIVER_API_SUPPORTED");
#endif
}

} // namespace

namespace c10d {
namespace symmetric_memory {

/* Start of CUDASymmetricMemory implementation */

// A set of exchange methods with prefix "CUDASymmetricMemory"
static StoreExchange storeExchange = StoreExchange("CUDASymmetricMemory");

AllocationRef::AllocationRef(
    void* ptr,
    HandleType handle,
@@ -794,7 +552,7 @@ static void init_multicast_for_block(
      mc_handle, 0, block->alloc_ref->handle, 0, block->block_size, 0));

  map_block(&mc_addr, mc_handle, block->block_size, block->device_idx);
  store_barrier(store, rank, world_size);
  storeExchange.barrier(store, rank, world_size);
#endif
}

@@ -855,7 +613,7 @@ c10::intrusive_ptr<SymmetricMemory> CUDASymmetricMemoryAllocator::rendezvous(
      .buffer_size = block->buffer_size,
      .signal_pad_offset = block->signal_pad_offset,
      .has_multicast_support = device_has_multicast_support(block->device_idx)};
  auto reqs = store_all_gather(store, rank, world_size, local_req);
  auto reqs = storeExchange.all_gather(store, rank, world_size, local_req);
  validate_rendezvous_requests(reqs, world_size);

  std::vector<int> pids(world_size);
@@ -885,7 +643,7 @@ c10::intrusive_ptr<SymmetricMemory> CUDASymmetricMemoryAllocator::rendezvous(
    signal_pads[r] = (void*)((uintptr_t)buffers[r] + block->signal_pad_offset);
    close(imported_fds[r]);
  }
  store_barrier(store, rank, world_size);
  storeExchange.barrier(store, rank, world_size);
  close(block_fd);

  HandleType mc_handle{};
@@ -939,9 +697,13 @@ c10::intrusive_ptr<Block> CUDASymmetricMemoryAllocator::find_block(void* ptr) {

struct RegisterCUDASymmetricMemoryAllocator {
  RegisterCUDASymmetricMemoryAllocator() {
    register_allocator(
        c10::DeviceType::CUDA,
        c10::make_intrusive<CUDASymmetricMemoryAllocator>());
    // Query backend used for CUDA tensor
    // "CUDA" backend stands for this implementation
    if (getSymmMemBackendCUDA() == "CUDA") {
      register_allocator(
          c10::DeviceType::CUDA,
          c10::make_intrusive<CUDASymmetricMemoryAllocator>());
    }
  }
};
@@ -1,17 +1,12 @@
#pragma once

#include <ATen/ATen.h>
#include <torch/csrc/distributed/c10d/CUDASymmetricMemoryTypes.hpp>
#include <torch/csrc/distributed/c10d/Store.hpp>
#include <torch/csrc/distributed/c10d/SymmetricMemory.hpp>

namespace c10d::symmetric_memory {

#if !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED)
using HandleType = CUmemGenericAllocationHandle;
#else
using HandleType = void*;
#endif

// Resource wrapper that owns a (vaddr, allocation handle) pair. Upon
// destruction, it unmaps the vaddr and releases the allocation handle.
struct AllocationRef : public c10::intrusive_ptr_target {
torch/csrc/distributed/c10d/CUDASymmetricMemoryTypes.hpp (new file, 13 lines)
@@ -0,0 +1,13 @@
#pragma once

namespace c10d::symmetric_memory {

constexpr size_t signal_pad_size = 2048;

#if !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED)
using HandleType = CUmemGenericAllocationHandle;
#else
using HandleType = void*;
#endif

} // namespace c10d::symmetric_memory
torch/csrc/distributed/c10d/CUDASymmetricMemoryUtils.cpp (new file, 216 lines)
@@ -0,0 +1,216 @@
#include <sys/socket.h>
#include <sys/syscall.h>
#include <sys/un.h>
#include <unistd.h>

#include <c10/util/error.h>

#if !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED)
#include <c10/cuda/driver_api.h>
#endif

#include <torch/csrc/distributed/c10d/CUDASymmetricMemoryUtils.hpp>
#include <torch/csrc/distributed/c10d/Store.hpp>
#include <torch/csrc/distributed/c10d/cuda/utils.hpp>

namespace c10d::symmetric_memory {

bool device_has_multicast_support(int device_idx) {
  if (c10::utils::check_env("TORCH_SYMM_MEM_DISABLE_MULTICAST") == true) {
    return false;
  }
  return c10d::cuda::deviceSupportsMulticast(device_idx);
}

bool allow_overlapping_devices() {
  return c10::utils::check_env("TORCH_SYMM_MEM_ALLOW_OVERLAPPING_DEVICES") ==
      true;
}

// Query environment variable to get the backend used for CUDA Symmetric Memory.
std::string getSymmMemBackendCUDA() {
  static auto val = c10::utils::get_env("TORCH_SYMMMEM");
  if (!val.has_value()) {
    // In-house implementation: `CUDASymmetricMemory`
    return "CUDA";
  } else {
    // Other backends:
    // - "NVSHMEM": `NVSHMEMSymmetricMemory`
    return val.value();
  }
}

IpcChannel::IpcChannel()
    : socket_name_(get_socket_name(getpid())),
      socket_(socket(AF_UNIX, SOCK_DGRAM, 0)) {
  // On success, a file descriptor for the new socket is returned.
  // On error, -1 is returned, and errno is set to indicate the error.
  TORCH_CHECK(
      socket_ != -1, "Failed to create socket: ", c10::utils::str_error(errno));

  struct sockaddr_un addr = {.sun_family = AF_UNIX};
  std::copy(socket_name_.begin(), socket_name_.end(), addr.sun_path);

  TORCH_CHECK(
      bind(socket_, (struct sockaddr*)&addr, SUN_LEN(&addr)) == 0,
      "Failed to bind socket: ",
      c10::utils::str_error(errno));
}

IpcChannel::~IpcChannel() {
  close(socket_);
  unlink(socket_name_.c_str());
}

void IpcChannel::send_fd(int dst_pid, int fd) {
  // Because file descriptors are process-local kernel objects, and we can't
  // pass them via normal socket payloads (like write() or send()). Unix domain
  // sockets provide a mechanism to pass actual FDs via sendmsg()/recvmsg().
  struct sockaddr_un addr = {.sun_family = AF_UNIX};
  auto socket_name = get_socket_name(dst_pid);
  std::copy(socket_name.begin(), socket_name.end(), addr.sun_path);

  struct iovec io = {.iov_base = (void*)("fd"), .iov_len = 2};

  // NOLINTNEXTLINE(*array*)
  char cbuf[CMSG_SPACE(sizeof(int))];
  memset(cbuf, 0, sizeof(cbuf));

  struct msghdr msg {
    .msg_name = (void*)&addr, .msg_namelen = sizeof(struct sockaddr_un),
    .msg_iov = &io, .msg_iovlen = 1, .msg_control = cbuf,
    .msg_controllen = sizeof(cbuf)
  };

  // This points to the first control message header
  // With SCM_RIGHTS we let the kernel know that we are passing file
  // descriptors.
  auto cmsg = CMSG_FIRSTHDR(&msg);
  cmsg->cmsg_len = CMSG_LEN(sizeof(int));
  cmsg->cmsg_level = SOL_SOCKET;
  cmsg->cmsg_type = SCM_RIGHTS;

  if (fd != -1) {
    std::copy(
        reinterpret_cast<const char*>(&fd),
        reinterpret_cast<const char*>(&fd) + sizeof(fd),
        reinterpret_cast<char*>(CMSG_DATA(cmsg)));
  } else {
    msg.msg_controllen = 0;
  }

  TORCH_CHECK(
      sendmsg(socket_, &msg, 0) > 0,
      "Failed to send fd: ",
      c10::utils::str_error(errno));
}

int IpcChannel::recv_fd() {
  // NOLINTNEXTLINE(*array*)
  char buf[2];
  struct iovec io = {.iov_base = (void*)buf, .iov_len = sizeof(buf)};

  // NOLINTNEXTLINE(*array*)
  char cbuf[CMSG_SPACE(sizeof(int))];
  memset(cbuf, 0, sizeof(cbuf));

  struct msghdr msg = {
      .msg_iov = &io,
      .msg_iovlen = 1,
      .msg_control = cbuf,
      .msg_controllen = sizeof(cbuf)};

  TORCH_CHECK(
      recvmsg(socket_, &msg, 0) > 0,
      "Failed to receive fd: ",
      c10::utils::str_error(errno));

  if (msg.msg_controllen == 0) {
    return -1;
  }

  auto cmsg = CMSG_FIRSTHDR(&msg);
  TORCH_CHECK(cmsg != nullptr);
  TORCH_CHECK(cmsg->cmsg_len == CMSG_LEN(sizeof(int)));
  TORCH_CHECK(cmsg->cmsg_level == SOL_SOCKET && cmsg->cmsg_type == SCM_RIGHTS);
  return *reinterpret_cast<int*>(CMSG_DATA(cmsg));
}

std::vector<int> IpcChannel::all_gather_fds(
    int rank,
    const std::vector<int>& pids,
    int fd) {
  int world_size = (int)pids.size();
  std::vector<int> fds(pids.size());
  fds[rank] = fd;

  int dst_rank = (rank + 1) % world_size;
  for (int step = 1; step < world_size; ++step) {
    int src_rank = (rank + world_size - step) % world_size;
    send_fd(pids[dst_rank], fd);
    fd = recv_fd();
    fds[src_rank] = fd;
  }
  return fds;
}

int IpcChannel::broadcast_fds(
    int rank,
    int src_rank,
    const std::vector<int>& pids,
    int fd) {
  int world_size = (int)pids.size();

  if (rank == src_rank) {
    for (int dst_rank = 0; dst_rank < (int)world_size; ++dst_rank) {
      if (dst_rank == rank) {
        continue;
      }
      send_fd(pids[dst_rank], fd);
    }
    return fd;
  }
  return recv_fd();
}

std::string IpcChannel::get_socket_name(int pid) {
  const char* tmp_dir = "/tmp";
  for (const char* env_var : {"TMPDIR", "TMP", "TEMP", "TEMPDIR"}) {
    if (const char* path = getenv(env_var)) {
      tmp_dir = path;
      break;
    }
  }
  std::ostringstream oss;
  oss << tmp_dir << "/symm_mem-" << pid;
  return oss.str();
}

void map_block(
    void** ptr,
    c10d::symmetric_memory::HandleType handle,
    size_t size,
    int device_idx) {
#if !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED)
  auto driver_api = c10::cuda::DriverAPI::get();
  auto dev_ptr = reinterpret_cast<CUdeviceptr*>(ptr);
  // Allocate virtual address space
  C10_CUDA_DRIVER_CHECK(
      driver_api->cuMemAddressReserve_(dev_ptr, size, 0ULL, 0, 0ULL));
  // Map the physical memory to the virtual address
  C10_CUDA_DRIVER_CHECK(driver_api->cuMemMap_(*dev_ptr, size, 0, handle, 0ULL));

  // Set access permissions
  CUmemAccessDesc desc;
  desc.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
  // NOLINTNEXTLINE(bugprone-signed-char-misuse)
  desc.location.id = static_cast<int>(device_idx);
  desc.flags = CU_MEM_ACCESS_FLAGS_PROT_READWRITE;
  C10_CUDA_DRIVER_CHECK(driver_api->cuMemSetAccess_(*dev_ptr, size, &desc, 1));
#else
  TORCH_CHECK(
      false, "CUDASymmetricMemory requires PYTORCH_C10_DRIVER_API_SUPPORTED");
#endif
}

} // namespace c10d::symmetric_memory
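`IpcChannel::all_gather_fds` above implements a ring exchange: every rank always sends to `(rank + 1) % world_size`, and after `step` hops it is holding the descriptor that originated at `(rank - step) % world_size`. A small Python model of the same indexing, with plain integers and per-rank queues standing in for the SOCK_DGRAM sockets (illustrative only, not part of this PR):

```python
# Pure-Python model of IpcChannel::all_gather_fds (integers instead of fds).
import queue
import threading

def ring_all_gather(rank, world_size, value, mailboxes):
    # Always send to the next rank; after step k we hold the value that
    # originated at (rank - k) % world_size, mirroring the C++ loop above.
    out = [None] * world_size
    out[rank] = value
    dst = (rank + 1) % world_size
    for step in range(1, world_size):
        src = (rank - step) % world_size
        mailboxes[dst].put(value)      # send_fd(pids[dst_rank], fd)
        value = mailboxes[rank].get()  # fd = recv_fd()
        out[src] = value
    return out

world_size = 4
mailboxes = [queue.Queue() for _ in range(world_size)]
results = [None] * world_size

def worker(r):
    results[r] = ring_all_gather(r, world_size, 100 + r, mailboxes)

threads = [threading.Thread(target=worker, args=(r,)) for r in range(world_size)]
for th in threads:
    th.start()
for th in threads:
    th.join()
assert all(res == [100, 101, 102, 103] for res in results)
```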
torch/csrc/distributed/c10d/CUDASymmetricMemoryUtils.hpp (new file, 115 lines)
@@ -0,0 +1,115 @@
#pragma once

#include <torch/csrc/distributed/c10d/CUDASymmetricMemoryTypes.hpp>
#include <torch/csrc/distributed/c10d/Store.hpp>
#include <torch/csrc/distributed/c10d/SymmetricMemory.hpp>

namespace c10d {
namespace symmetric_memory {

bool device_has_multicast_support(int device_idx);

bool allow_overlapping_devices();

// Query environment variable to get the backend used for CUDA Symmetric Memory.
std::string getSymmMemBackendCUDA();

class IpcChannel {
 public:
  IpcChannel();
  ~IpcChannel();

  void send_fd(int dst_pid, int fd);
  int recv_fd();

  std::vector<int> all_gather_fds(
      int rank,
      const std::vector<int>& pids,
      int fd);

  int broadcast_fds(
      int rank,
      int src_rank,
      const std::vector<int>& pids,
      int fd);

 private:
  static std::string get_socket_name(int pid);

  std::string socket_name_;
  int socket_;
};

// A set of store-based exchange methods with a preset prefix, typically the
// type of the SymmetricMemory. Mostly used as static instances in the
// respective SymmetricMemory implementation files.
class StoreExchange {
 public:
  StoreExchange(const std::string& store_prefix)
      : store_prefix_(store_prefix) {}

  // Put template function in header file so that compiler can easily access it.
  template <typename T>
  std::vector<T> all_gather(
      const c10::intrusive_ptr<c10d::Store>& store,
      int rank,
      int world_size,
      T val) {
    static_assert(std::is_trivially_copyable_v<T>);

    std::vector<std::string> peer_keys;
    peer_keys.reserve(world_size);
    for (int r = 0; r < world_size; ++r) {
      std::ostringstream oss;
      oss << store_prefix_ << "/" << seq_id_ << "/" << r;
      peer_keys.push_back(oss.str());
    }
    ++seq_id_;

    {
      std::vector<uint8_t> payload(
          reinterpret_cast<uint8_t*>(&val),
          reinterpret_cast<uint8_t*>(&val) + sizeof(T));
      store->set(peer_keys[rank], payload);
    }

    std::vector<T> peer_vals;
    peer_vals.reserve(world_size);
    for (int r = 0; r < world_size; ++r) {
      if (r == rank) {
        peer_vals.push_back(val);
        continue;
      }
      store->wait({peer_keys[r]});
      auto payload = store->get(peer_keys[r]);
      TORCH_CHECK(payload.size() == sizeof(T));
      T peer_val{};
      std::memcpy(&peer_val, payload.data(), sizeof(T));
      peer_vals.push_back(peer_val);
    }
    return peer_vals;
  }

  void barrier(
      const c10::intrusive_ptr<c10d::Store>& store,
      int rank,
      int world_size) {
    // TODO: implement an efficient one?
    all_gather(store, rank, world_size, 0);
  }

 private:
  const std::string store_prefix_;
  size_t seq_id_ = 0;
};

// Returns a pointer of virtual address that is mapped to the physical memory
// held by the handle.
void map_block(
    void** ptr,
    c10d::symmetric_memory::HandleType handle,
    size_t size,
    int device_idx);

} // namespace symmetric_memory
} // namespace c10d
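The store-based exchange that `StoreExchange::all_gather` performs (each rank sets its own key under a per-call sequence id, then waits on and reads every peer's key) can be sketched with the Python `TCPStore` bindings. This is an illustration of the protocol, not part of the PR; the C++ version memcpys trivially-copyable values where this sketch pickles them.

```python
# Illustrative Python model of StoreExchange::all_gather.
import pickle
from datetime import timedelta

from torch.distributed import TCPStore

def store_all_gather(store, prefix, seq_id, rank, world_size, val):
    # One key per rank, namespaced by the prefix and a monotonically
    # increasing sequence id so successive calls never collide.
    keys = [f"{prefix}/{seq_id}/{r}" for r in range(world_size)]
    store.set(keys[rank], pickle.dumps(val))
    out = []
    for r in range(world_size):
        if r == rank:
            out.append(val)
            continue
        store.wait([keys[r]])
        out.append(pickle.loads(store.get(keys[r])))
    return out

# Single-rank demo (rank 0 of a world of 1).
store = TCPStore("127.0.0.1", 29500, world_size=1, is_master=True,
                 timeout=timedelta(seconds=30))
print(store_all_gather(store, "NVSHMEMSymmetricMemory", 0, rank=0, world_size=1, val=42))
```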
torch/csrc/distributed/c10d/NVSHMEMSymmetricMemory.cu (new file, 329 lines)
@@ -0,0 +1,329 @@
#include <torch/csrc/distributed/c10d/SymmetricMemory.hpp>
#include <torch/csrc/distributed/c10d/CUDASymmetricMemoryUtils.hpp>
#include <torch/csrc/distributed/c10d/CUDASymmetricMemory-inl.h>
#include <torch/csrc/distributed/c10d/cuda/utils.hpp>
#include <torch/csrc/distributed/c10d/nvshmem_extension.cuh>

#include <ATen/ceil_div.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDACachingAllocator.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/util/error.h>

namespace c10d {
namespace symmetric_memory {

/* Start of CUDASymmetricMemory implementation */

static StoreExchange storeExchange = StoreExchange("NVSHMEMSymmetricMemory");

struct NVSHMEMAllocation {
  void* ptr;
  size_t buffer_size;
  int device_idx;

  NVSHMEMAllocation(void* ptr, size_t buffer_size, int device_idx)
      : ptr(ptr), buffer_size(buffer_size), device_idx(device_idx) {}
};

class NVSHMEMSymmetricMemory : public SymmetricMemory {
 public:
  NVSHMEMSymmetricMemory(
      std::shared_ptr<NVSHMEMAllocation> allocation,
      const std::string& group_name)
      : allocation_(allocation),
        buffer_size_(allocation->buffer_size),
        device_idx_(allocation->device_idx),
        group_name_(group_name) {
    c10::cuda::CUDAGuard guard(device_idx_);

    auto global_rank = get_group_info("0").rank;
    auto group_info = get_group_info(group_name_);
    auto store = group_info.store;
    rank_ = group_info.rank;
    world_size_ = group_info.world_size;
    rank_to_global_rank_ =
        storeExchange.all_gather(store, rank_, world_size_, global_rank);
    LOG(INFO) << "[rank " << rank_ << "]"
              << "rank_to_global_rank: " << rank_to_global_rank_;

    for (int r = 0; r < world_size_; ++r) {
      buffers_.push_back(nvshmem_extension::nvshmem_ptr(
          allocation->ptr, rank_to_global_rank_[r]));
    }

    // TODO: use the same allocation for signal pad
    void* signal_pad_ptr = nvshmem_extension::nvshmem_malloc(signal_pad_size);
    AT_CUDA_CHECK(cudaMemset(signal_pad_ptr, 0, signal_pad_size));

    for (int r = 0; r < world_size_; ++r) {
      signal_pads_.push_back(nvshmem_extension::nvshmem_ptr(
          signal_pad_ptr, rank_to_global_rank_[r]));
    }

    const size_t arr_size = sizeof(void*) * world_size_;
    buffers_dev_ = reinterpret_cast<void**>(
        c10::cuda::CUDACachingAllocator::raw_alloc(arr_size));
    signal_pads_dev_ = reinterpret_cast<void**>(
        c10::cuda::CUDACachingAllocator::raw_alloc(arr_size));

    AT_CUDA_CHECK(cudaMemcpy(
        buffers_dev_, buffers_.data(), arr_size, cudaMemcpyHostToDevice));
    AT_CUDA_CHECK(cudaMemcpy(
        signal_pads_dev_,
        signal_pads_.data(),
        arr_size,
        cudaMemcpyHostToDevice));

    rank_to_global_rank_dev_ = reinterpret_cast<int*>(
        c10::cuda::CUDACachingAllocator::raw_alloc(sizeof(int) * world_size_));
    AT_CUDA_CHECK(cudaMemcpy(
        rank_to_global_rank_dev_,
        rank_to_global_rank_.data(),
        sizeof(int) * world_size_,
        cudaMemcpyHostToDevice));
  }

  ~NVSHMEMSymmetricMemory() override{
      // TODO
  };

  std::vector<void*> get_buffer_ptrs() override {
    return buffers_;
  }

  std::vector<void*> get_signal_pad_ptrs() override {
    return signal_pads_;
  }

  void** get_buffer_ptrs_dev() override {
    return buffers_dev_;
  }

  void** get_signal_pad_ptrs_dev() override {
    return signal_pads_dev_;
  }

  size_t get_buffer_size() override {
    return buffer_size_;
  }

  size_t get_signal_pad_size() override {
    return signal_pad_size;
  };

  bool has_multicast_support() override {
    // TODO
    return false;
  }

  void* get_multicast_ptr() override {
    // TODO
    return nullptr;
  }

  at::Tensor get_buffer(
      int rank,
      c10::IntArrayRef sizes,
      c10::ScalarType dtype,
      int64_t storage_offset) {
    // TODO: deduplicate
    const size_t numel = std::accumulate(
        sizes.begin(),
        sizes.end(),
        static_cast<size_t>(1),
        std::multiplies<size_t>());
    const auto element_size = c10::elementSize(dtype);
    const auto req_size = (numel + storage_offset) * element_size;
    TORCH_CHECK(
        req_size <= buffer_size_,
        "NVSHMEMSymmetricMemory::get_buffer: the requested size (",
        req_size,
        " bytes) exceeds the allocated size (",
        buffer_size_,
        " bytes)");
    auto data_ptr = reinterpret_cast<uint8_t*>(buffers_[rank]) +
        storage_offset * element_size;
    auto device = c10::Device(c10::DeviceType::CUDA, device_idx_);
    auto options = at::TensorOptions().dtype(dtype).device(device);
    return at::for_blob(data_ptr, sizes)
        .options(options)
        .target_device(device)
        .make_tensor();
  }

  at::Tensor get_signal_pad(
      int rank,
      c10::IntArrayRef sizes,
      std::optional<c10::ScalarType> dtype,
      int64_t storage_offset) override {
    // TODO: deduplicate
    // If the dtype is unspecified, default it to UInt32, as it
    // is the most common type for signaling purposes.
    if (!dtype.has_value()) {
      dtype = c10::ScalarType::UInt32;
    }

    // If the shape is unspecified, treat the signal pad as a 1d tensor.
    const auto element_size = c10::elementSize(*dtype);
    std::vector<int64_t> shape;
    if (sizes.size() != 0) {
      shape = sizes.vec();
    } else {
      shape.push_back(signal_pad_size / element_size);
    }

    const size_t numel = std::accumulate(
        shape.begin(),
        shape.end(),
        static_cast<size_t>(1),
        std::multiplies<size_t>());
    const auto req_size = (numel + storage_offset) * element_size;
    TORCH_CHECK(
        req_size <= signal_pad_size,
        "NVSHMEMSymmetricMemory::get_signal_pad: the requested size (",
        req_size,
        " bytes) exceeds the allocated size (",
        signal_pad_size,
        " bytes)");
    auto data_ptr = reinterpret_cast<uint8_t*>(signal_pads_[rank]) +
        storage_offset * element_size;
    auto device = c10::Device(c10::DeviceType::CUDA, device_idx_);
    auto options = at::TensorOptions().dtype(*dtype).device(device);
    return at::for_blob(data_ptr, shape)
        .options(options)
        .target_device(device)
        .make_tensor();
  }

  void barrier(int channel, size_t timeout_ms) override {
    // TODO
  }

  void put_signal(int dst_rank, int channel, size_t timeout_ms) override {
    // TODO
  }

  void wait_signal(int src_rank, int channel, size_t timeout_ms) override {
    // TODO
  }

  int get_rank() override {
    return rank_;
  }

  int get_world_size() override {
    return world_size_;
  }

  virtual std::vector<int> get_rank_to_global_rank() override {
    return rank_to_global_rank_;
  };

  int* get_rank_to_global_rank_dev() override {
    return rank_to_global_rank_dev_;
  };

 private:
  std::shared_ptr<NVSHMEMAllocation> allocation_;
  size_t buffer_size_;
  std::vector<void*> buffers_;
  std::vector<void*> signal_pads_;
  int device_idx_;
  int rank_;
  int world_size_;
  void** buffers_dev_;
  void** signal_pads_dev_;
  std::string group_name_;

  std::vector<int> rank_to_global_rank_;
  int* rank_to_global_rank_dev_;
};

class NVSHMEMSymmetricMemoryAllocator : public SymmetricMemoryAllocator {
 public:
  void* alloc(
      size_t size,
      int device_idx,
      const std::optional<std::string>& group_name) override {
    TORCH_CHECK(
        group_name == std::nullopt,
        "NVSHMEMSymmetricMemoryAllocator::alloc "
        "must not be called with a group_name");

    auto group_info = get_group_info("0");
    auto store = group_info.store;
    int rank = group_info.rank;
    int world_size = group_info.world_size;

    nvshmem_extension::initialize_nvshmem_with_store(store, rank, world_size);
    auto ptr = nvshmem_extension::nvshmem_malloc(size);
    auto allocation =
        std::make_shared<NVSHMEMAllocation>(ptr, size, device_idx);
    // TODO: thread safety
    allocations_.emplace(ptr, allocation);
    return ptr;
  }

  void free(void* ptr) override {
    // TODO: thread safety
    ptr_to_symm_mem_.erase(ptr);
  };

  size_t get_alloc_size(void* ptr) override {
    auto it = ptr_to_symm_mem_.find(ptr);
    if (it == ptr_to_symm_mem_.end()) {
      TORCH_CHECK(
          false, ptr, " is not allocated with NVSHMEMSymmetricMemoryAllocator");
    }
    return it->second->get_buffer_size();
  };

  c10::intrusive_ptr<SymmetricMemory> rendezvous(
      void* ptr,
      const std::optional<std::string>& group_name) override {
    TORCH_CHECK(group_name.has_value());
    {
      auto it = symm_mems_.find(std::make_tuple(ptr, *group_name));
      if (it != symm_mems_.end()) {
        return it->second;
      }
    }
    auto it = allocations_.find(ptr);
    TORCH_CHECK(it != allocations_.end());
    auto symm_mem =
        c10::make_intrusive<NVSHMEMSymmetricMemory>(it->second, *group_name);

    symm_mems_[std::make_tuple(ptr, *group_name)] = symm_mem;
    return symm_mem;
  };

  bool has_multicast_support(int device_idx) override {
    // TODO
    return false;
  };

 private:
  std::unordered_map<void*, c10::intrusive_ptr<SymmetricMemory>>
      ptr_to_symm_mem_;

  std::unordered_map<void*, std::shared_ptr<NVSHMEMAllocation>> allocations_;
  std::map<std::tuple<void*, std::string>, c10::intrusive_ptr<SymmetricMemory>>
      symm_mems_;
};

struct RegisterNVSHMEMSymmetricMemoryAllocator {
  RegisterNVSHMEMSymmetricMemoryAllocator() {
    // Query backend used for CUDA tensor
    if (getSymmMemBackendCUDA() == "NVSHMEM") {
      register_allocator(
          c10::DeviceType::CUDA,
          c10::make_intrusive<NVSHMEMSymmetricMemoryAllocator>());
    }
  }
};

static RegisterNVSHMEMSymmetricMemoryAllocator register_allocator_;

} // namespace symmetric_memory
} // namespace c10d
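From Python, the peer buffer views built above (one `nvshmem_ptr` per rank) would surface through the existing `_SymmetricMemory` handle methods; the snippet below assumes the `torch.distributed._symmetric_memory` bindings (`empty`, `rendezvous`, `get_buffer`) and is illustrative only. Each `get_buffer(peer, ...)` tensor aliases the peer's symmetric allocation.

```python
# Illustrative only; run with TORCH_SYMMMEM=NVSHMEM so this allocator is used.
import torch
import torch.distributed as dist
import torch.distributed._symmetric_memory as symm_mem

dist.init_process_group("nccl")
rank = dist.get_rank()
torch.cuda.set_device(rank)

t = symm_mem.empty(64, dtype=torch.float32, device="cuda")
hdl = symm_mem.rendezvous(t, dist.group.WORLD.group_name)

t.fill_(rank)              # write into the local symmetric buffer
torch.cuda.synchronize()
dist.barrier()

# View the next rank's buffer directly (maps to get_buffer() above).
peer = (rank + 1) % hdl.world_size
peer_buf = hdl.get_buffer(peer, t.shape, t.dtype)
print(peer_buf[:4])        # shows the peer's data
```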
@@ -6,6 +6,7 @@ using namespace c10d::symmetric_memory;

static bool is_finalizing_ = false;

// NOLINTNEXTLINE(cppcoreguidelines-special-member-functions)
class AllocatorMap {
 public:
  AllocatorMap(const AllocatorMap&) = delete;
@@ -212,7 +213,9 @@ namespace {

at::Tensor one_shot_all_reduce_meta(
    const at::Tensor& input,
    // NOLINTNEXTLINE(performance-unnecessary-value-param)
    std::string reduce_op,
    // NOLINTNEXTLINE(performance-unnecessary-value-param)
    std::string group_name) {
  return at::empty_like(input);
}
@@ -220,7 +223,9 @@ at::Tensor one_shot_all_reduce_meta(
at::Tensor one_shot_all_reduce_copy_meta(
    const at::Tensor& symm_buffer,
    const at::Tensor& local_input,
    // NOLINTNEXTLINE(performance-unnecessary-value-param)
    std::string reduce_op,
    // NOLINTNEXTLINE(performance-unnecessary-value-param)
    std::string group_name) {
  return at::empty_like(local_input);
}
@@ -269,6 +274,8 @@ TORCH_LIBRARY_FRAGMENT(symm_mem, m) {
      "stream_write_value32_(Tensor(a!) input, int offset, int val) -> Tensor(a!)");
  m.def(
      "memset32_(Tensor(a!) input, int offset, int val, int count) -> Tensor(a!)");

  m.def("nvshmem_broadcast(Tensor(a!) input, str group_name) -> Tensor(a!)");
}

TORCH_LIBRARY_IMPL(symm_mem, Meta, m) {
@@ -71,6 +71,14 @@ class TORCH_API SymmetricMemory : public c10::intrusive_ptr_target {

  virtual int get_rank() = 0;
  virtual int get_world_size() = 0;

  virtual std::vector<int> get_rank_to_global_rank() {
    TORCH_CHECK(false, "NYI");
  };

  virtual int* get_rank_to_global_rank_dev() {
    TORCH_CHECK(false, "NYI");
  };
};

class SymmetricMemoryAllocator : public c10::intrusive_ptr_target {
torch/csrc/distributed/c10d/nvshmem_extension.cu (new file, 126 lines)
@@ -0,0 +1,126 @@
#include <torch/csrc/distributed/c10d/nvshmem_extension.cuh>

#include <c10/cuda/CUDAGuard.h>

#include <torch/csrc/distributed/c10d/CUDASymmetricMemory-inl.h>
#include <torch/csrc/distributed/c10d/CUDASymmetricMemoryUtils.hpp>
#include <torch/csrc/distributed/c10d/SymmetricMemory.hpp>

#include <cuda_awbarrier_primitives.h>
#include <nvshmem.h>

namespace c10d::nvshmem_extension {

using c10d::symmetric_memory::StoreExchange;
static StoreExchange storeExchange = StoreExchange("nvshmem_ext");

// Bootstrap based on user's setting for NCCL
// Long term, this may be a bit unclean; short term, it improves UX
void maybe_initialize_env_vars() {
  auto nccl_socket_if_name = c10::utils::get_env("NCCL_SOCKET_IFNAME");
  auto nccl_hca_list = c10::utils::get_env("NCCL_IB_HCA");
  auto nccl_ib_gid_index = c10::utils::get_env("NCCL_IB_GID_INDEX");
  auto nvshmem_socket_if_name =
      c10::utils::get_env("NVSHMEM_BOOTSTRAP_UID_SOCK_IFNAME");
  auto nvshmem_hca_list = c10::utils::get_env("NCCL_IB_HCA");
  auto nvshmem_ib_gid_index = c10::utils::get_env("NVSHMEM_IB_GID_INDEX");

  if (!nvshmem_socket_if_name.has_value() && nccl_socket_if_name.has_value()) {
    c10::utils::set_env(
        "NVSHMEM_BOOTSTRAP_UID_SOCK_IFNAME", nccl_socket_if_name->c_str());
  }
  if (!nvshmem_hca_list.has_value() && nccl_hca_list.has_value()) {
    c10::utils::set_env("NVSHMEM_ENABLE_NIC_PE_MAPPING", "1");
    c10::utils::set_env("NVSHMEM_HCA_LIST", nccl_hca_list->c_str());
  }
  if (!nvshmem_ib_gid_index.has_value() && nccl_ib_gid_index.has_value()) {
    c10::utils::set_env("NVSHMEM_IB_GID_INDEX", nccl_ib_gid_index->c_str());
  }
}

void initialize_nvshmem_with_store(
    c10::intrusive_ptr<c10d::Store> store,
    int rank,
    int world_size) {
  static bool is_initialized = false;
  if (is_initialized) {
    return;
  }

  maybe_initialize_env_vars();

  nvshmemx_uniqueid_t unique_id;
  TORCH_CHECK(
      nvshmemx_get_uniqueid(&unique_id) == 0, "nvshmemx_get_uniqueid failed");

  // Using an existing store_all_gather due to laziness.
  // TODO(yifu): should use broadcast
  auto unique_ids = storeExchange.all_gather(store, rank, world_size, unique_id);

  nvshmemx_init_attr_t attr;
  nvshmemx_set_attr_uniqueid_args(rank, world_size, &unique_ids[0], &attr);

  TORCH_CHECK(
      nvshmemx_init_attr(NVSHMEMX_INIT_WITH_UNIQUEID, &attr) == 0,
      "nvshmemx_init_attr failed");

  is_initialized = true;
}

void* nvshmem_malloc(size_t size) {
  return ::nvshmem_malloc(size);
}

void* nvshmem_ptr(const void* dest, int pe) {
  return ::nvshmem_ptr(dest, pe);
}

std::unordered_map<std::string, nvshmem_team_t> group_name_to_team_;

nvshmem_team_t group_to_team(
    const std::string& group_name,
    const std::vector<int>& global_ranks) {
  auto it = group_name_to_team_.find(group_name);
  if (it != group_name_to_team_.end()) {
    return it->second;
  }
  TORCH_CHECK(global_ranks.size() > 1);
  int stride = global_ranks[1] - global_ranks[0];
  for (size_t r = 1; r < global_ranks.size(); ++r) {
    TORCH_CHECK(global_ranks[r] - global_ranks[r - 1] == stride);
  }

  nvshmem_team_t team;
  TORCH_CHECK(
      nvshmem_team_split_strided(
          NVSHMEM_TEAM_WORLD,
          global_ranks[0],
          stride,
          global_ranks.size(),
          nullptr,
          0,
          &team) == 0);
  group_name_to_team_[group_name] = team;
  TORCH_CHECK(team != NVSHMEM_TEAM_INVALID);
  return team;
}

at::Tensor nvshmem_broadcast(at::Tensor& input, const std::string& group_name) {
  auto input_hdl = c10d::symmetric_memory::rendezvous(input, group_name);
  int rank = input_hdl->get_rank();
  int world_size = input_hdl->get_world_size();
  auto team = group_to_team(group_name, input_hdl->get_rank_to_global_rank());
  void* buffer_ptr = input_hdl->get_buffer_ptrs()[rank];

  auto stream = at::cuda::getCurrentCUDAStream();
  nvshmemx_broadcastmem_on_stream(team, buffer_ptr, buffer_ptr, input_hdl->get_buffer_size(), 0, stream);
  return input;
}


} // namespace c10d::nvshmem_extension


TORCH_LIBRARY_IMPL(symm_mem, CUDA, m) {
  m.impl("nvshmem_broadcast", c10d::nvshmem_extension::nvshmem_broadcast);
}
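Once the CUDA implementation is registered, the op added in the `TORCH_LIBRARY_FRAGMENT` change above is callable from Python as `torch.ops.symm_mem.nvshmem_broadcast`. A hedged usage sketch (the allocation helpers are assumed; the broadcast root is PE 0 of the team, per the `nvshmemx_broadcastmem_on_stream` call above):

```python
# Illustrative only; requires a USE_NVSHMEM=1 build and TORCH_SYMMMEM=NVSHMEM.
import torch
import torch.distributed as dist
import torch.distributed._symmetric_memory as symm_mem

dist.init_process_group("nccl")
rank = dist.get_rank()
torch.cuda.set_device(rank)
group_name = dist.group.WORLD.group_name

t = symm_mem.empty(1024, dtype=torch.uint8, device="cuda")
symm_mem.rendezvous(t, group_name)
if rank == 0:
    t.fill_(7)

# Schema: nvshmem_broadcast(Tensor(a!) input, str group_name) -> Tensor(a!)
torch.ops.symm_mem.nvshmem_broadcast(t, group_name)
torch.cuda.synchronize()
assert int(t[0]) == 7  # every rank now holds rank 0's data
```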
torch/csrc/distributed/c10d/nvshmem_extension.cuh (new file, 20 lines)
@@ -0,0 +1,20 @@
#pragma once

#include <ATen/ATen.h>

#include <torch/csrc/distributed/c10d/Store.hpp>

namespace c10d::nvshmem_extension {

void initialize_nvshmem_with_store(
    c10::intrusive_ptr<c10d::Store> store,
    int rank,
    int world_size);

void* nvshmem_malloc(size_t size);

void* nvshmem_ptr(const void* dest, int pe);

at::Tensor nvshmem_broadcast(at::Tensor& input, const std::string& group_name);

} // namespace c10d::nvshmem_extension
@@ -37,26 +37,6 @@ def enable_symm_mem_for_group(group_name: str) -> None:
        f"symmetric_memory-{global_ranks_str}",
        c10d._get_process_group_store(group),
    )
    # Use one store-based broadcast to bootstrap a file store from the process
    # and simultaneously verify that all ranks are on the same host.
    hostname = socket.gethostname()
    if group.rank() == 0:
        uid = str(uuid.uuid4())
        msg = f"{hostname}/{uid}"
        store.set("init", msg)
    else:
        msg = store.get("init").decode("utf-8")
        tokens = msg.split("/")
        assert len(tokens) == 2, tokens
        rank_0_hostname, uid = tokens
        if hostname != rank_0_hostname:
            raise RuntimeError(
                "init_symmetric_memory_for_process_group() failed for "
                f'group "{group_name}". Rank 0 and rank {group.rank()} '
                f"are on different hosts ({rank_0_hostname} and {hostname})"
            )
    store = torch._C._distributed_c10d.FileStore(f"/tmp/{uid}", group.size())
    # TODO: check device connectiivity
    _group_name_to_store[group_name] = store
    _SymmetricMemory.set_group_info(
        group_name,