Revert "Introduce a prototype for SymmetricMemory (#128582)"
This reverts commit 7a39755da28d5a109bf0c37f72b364d3a83137b1. Reverted https://github.com/pytorch/pytorch/pull/128582 on behalf of https://github.com/fbgheith due to breaking internal builds ([comment](https://github.com/pytorch/pytorch/pull/128582#issuecomment-2176685232))
@@ -68,7 +68,6 @@ include_patterns = [
    'aten/src/ATen/native/cudnn/*.cpp',
    'c10/**/*.h',
    'c10/**/*.cpp',
    'distributed/c10d/*SymmetricMemory.*',
    'torch/csrc/**/*.h',
    'torch/csrc/**/*.hpp',
    'torch/csrc/**/*.cpp',
@@ -744,7 +744,6 @@ cc_library(
        "torch/csrc/cuda/python_nccl.cpp",
        "torch/csrc/cuda/nccl.cpp",
        "torch/csrc/distributed/c10d/intra_node_comm.cu",
        "torch/csrc/distributed/c10d/CUDASymmetricMemory.cu",
        "torch/csrc/distributed/c10d/Utils.cu",
        "torch/csrc/distributed/c10d/quantization/quantization_gpu.cu",
    ],
@@ -501,7 +501,6 @@ libtorch_distributed_base_sources = [
    "torch/csrc/distributed/c10d/ProcessGroupMPI.cpp",
    "torch/csrc/distributed/c10d/ProcessGroupWrapper.cpp",
    "torch/csrc/distributed/c10d/Store.cpp",
    "torch/csrc/distributed/c10d/SymmetricMemory.cpp",
    "torch/csrc/distributed/c10d/TCPStore.cpp",
    "torch/csrc/distributed/c10d/TCPStoreBackend.cpp",
    "torch/csrc/distributed/c10d/TCPStoreLibUvBackend.cpp",
@@ -685,7 +684,6 @@ libtorch_cuda_distributed_extra_sources = [
    "torch/csrc/distributed/c10d/UCCUtils.cpp",
    "torch/csrc/distributed/c10d/intra_node_comm.cpp",
    "torch/csrc/distributed/c10d/intra_node_comm.cu",
    "torch/csrc/distributed/c10d/CUDASymmetricMemory.cu",
    "torch/csrc/distributed/c10d/Utils.cu",
    "torch/csrc/distributed/rpc/tensorpipe_cuda.cpp",
    "torch/csrc/distributed/c10d/quantization/quantization_gpu.cu",
@@ -26,9 +26,6 @@
  _(cuMemSetAccess) \
  _(cuMemUnmap) \
  _(cuMemCreate) \
  _(cuMemGetAllocationGranularity) \
  _(cuMemExportToShareableHandle) \
  _(cuMemImportFromShareableHandle) \
  _(cuGetErrorString)

#define C10_NVML_DRIVER_API(_) \
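The block above is an X-macro table: each driver symbol appears exactly once, and the macro parameter `_` decides what gets generated at each expansion site. A minimal, self-contained sketch of the pattern (the `foo*` names are hypothetical stand-ins, not the real driver_api.h machinery):

#include <cstdio>

// Hypothetical stand-ins for driver symbols that would be resolved at runtime.
int fooCreate() { return 1; }
int fooDestroy() { return 2; }

// The X-macro table: each entry is listed exactly once.
#define MY_DRIVER_API(_) \
  _(fooCreate)           \
  _(fooDestroy)

// One expansion declares a function-pointer member per table entry.
#define DECLARE_MEMBER(name) decltype(&name) name##_ = &name;
struct DriverApi {
  MY_DRIVER_API(DECLARE_MEMBER)
};

int main() {
  DriverApi api;
  std::printf("%d %d\n", api.fooCreate_(), api.fooDestroy_()); // prints "1 2"
}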
@@ -560,7 +560,6 @@ if(USE_CUDA)
      append_filelist("libtorch_cuda_distributed_extra_sources" Caffe2_GPU_SRCS)
      set_source_files_properties(
        ${TORCH_SRC_DIR}/csrc/distributed/c10d/intra_node_comm.cpp
        ${TORCH_SRC_DIR}/csrc/distributed/c10d/CUDASymmetricMemory.cu
        PROPERTIES COMPILE_FLAGS "-DPYTORCH_C10_DRIVER_API_SUPPORTED=1"
      )
    endif()
@@ -1,156 +0,0 @@
# Owner(s): ["module: c10d"]

import torch

import torch.distributed as dist
from torch._C._distributed_c10d import _SymmetricMemory
from torch.distributed.distributed_c10d import _get_process_group_store

from torch.testing._internal.common_distributed import (
    MultiProcessTestCase,
    skip_if_lt_x_gpu,
)
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    run_tests,
    skip_but_pass_in_sandcastle_if,
    skipIfRocm,
)


def requires_cuda_p2p_access():
    cuda_p2p_access_available = (
        torch.cuda.is_available() and torch.cuda.device_count() >= 2
    )
    num_devices = torch.cuda.device_count()
    for i in range(num_devices - 1):
        for j in range(i + 1, num_devices):
            if not torch.cuda.can_device_access_peer(i, j):
                cuda_p2p_access_available = False
                break
        if not cuda_p2p_access_available:
            break

    return skip_but_pass_in_sandcastle_if(
        not cuda_p2p_access_available,
        "cuda p2p access is not available",
    )


@instantiate_parametrized_tests
@requires_cuda_p2p_access()
class SymmetricMemoryTest(MultiProcessTestCase):
    def setUp(self) -> None:
        super().setUp()
        self._spawn_processes()

    @property
    def world_size(self) -> int:
        return 2

    @property
    def device(self) -> torch.device:
        return torch.device(f"cuda:{self.rank}")

    def _init_process(self):
        torch.cuda.set_device(self.device)
        store = dist.FileStore(self.file_name, self.world_size)
        dist.init_process_group(
            backend="nccl",
            world_size=self.world_size,
            rank=self.rank,
            store=store,
        )
        _SymmetricMemory.set_group_info(
            "0",
            self.rank,
            self.world_size,
            _get_process_group_store(dist.GroupMember.WORLD),
        )

    def _verify_symmetric_memory(self, symm_mem):
        self.assertEqual(symm_mem.world_size, 2)

        buf = symm_mem.get_buffer(0, (64, 64), torch.float32)
        if symm_mem.rank == 0:
            symm_mem.wait_signal(src_rank=1)
            self.assertTrue(buf.eq(42).all())
        else:
            buf.fill_(42)
            symm_mem.put_signal(dst_rank=0)

        symm_mem.barrier()

        if symm_mem.rank == 0:
            symm_mem.barrier()
            self.assertTrue(buf.eq(43).all())
        else:
            buf.fill_(43)
            symm_mem.barrier()

        symm_mem.barrier()

    @skipIfRocm
    @skip_if_lt_x_gpu(2)
    def test_empty_strided_p2p(self) -> None:
        self._init_process()

        shape = (64, 64)
        stride = (64, 1)
        dtype = torch.float32
        device = self.device
        group_name = "0"
        alloc_args = (shape, stride, dtype, device, group_name)

        t = torch.empty(shape, dtype=dtype, device=device)
        with self.assertRaises(RuntimeError):
            _SymmetricMemory.rendezvous(t)

        t = _SymmetricMemory.empty_strided_p2p(*alloc_args)
        symm_mem = _SymmetricMemory.rendezvous(t)

        del t
        self._verify_symmetric_memory(symm_mem)

    @skipIfRocm
    @skip_if_lt_x_gpu(2)
    def test_empty_strided_p2p_persistent(self) -> None:
        self._init_process()

        shape = (64, 64)
        stride = (64, 1)
        dtype = torch.float32
        device = self.device
        alloc_id = 42  # Persistent allocation
        group_name = "0"
        alloc_args = (shape, stride, dtype, device, group_name, alloc_id)

        t = _SymmetricMemory.empty_strided_p2p(*alloc_args)
        data_ptr = t.data_ptr()

        # Verify that persistent allocation would fail if there's an active
        # allocation with the same alloc_id.
        with self.assertRaises(RuntimeError):
            _SymmetricMemory.empty_strided_p2p(*alloc_args)

        # Verify that persistent allocation would succeed in the absence of
        # active allocations with the same alloc_id, and the returned tensor
        # would have the same data pointer.
        del t
        t = _SymmetricMemory.empty_strided_p2p(*alloc_args)
        self.assertEqual(t.data_ptr(), data_ptr)

        # Verify that get_symmetric_memory would fail if called before
        # rendezvous.
        with self.assertRaises(RuntimeError):
            _SymmetricMemory.get_symmetric_memory(t)

        symm_mem_0 = _SymmetricMemory.rendezvous(t)
        symm_mem_1 = _SymmetricMemory.get_symmetric_memory(t)
        self.assertEqual(id(symm_mem_0), id(symm_mem_1))

        self._verify_symmetric_memory(symm_mem_0)


if __name__ == "__main__":
    run_tests()
@@ -637,33 +637,3 @@ class ProcessGroupCudaP2P(Backend):
        storage_offset: Optional[int] = 0,
    ) -> torch.Tensor: ...
    def _shutdown(self) -> None: ...

class _SymmetricMemory:
    @staticmethod
    def set_group_info(
        group_name: str, rank: int, world_size: int, store: Store
    ) -> None: ...
    @staticmethod
    def empty_strided_p2p(
        size: torch.types._size,
        stride: torch.types._size,
        dtype: torch.dtype,
        device: torch.device,
        group_name: str,
    ) -> torch.Tensor: ...
    @property
    def rank(self) -> int: ...
    @property
    def world_size(self) -> int: ...
    @staticmethod
    def rendezvous(tensor: torch.Tensor) -> _SymmetricMemory: ...
    def get_buffer(
        self,
        rank: int,
        sizes: torch.Size,
        dtype: torch.dtype,
        storage_offset: Optional[int] = 0,
    ) -> torch.Tensor: ...
    def barrier(self, channel: int = 0) -> None: ...
    def put_signal(self, dst_rank: int, channel: int = 0) -> None: ...
    def wait_signal(self, src_rank: int, channel: int = 0) -> None: ...
@@ -1,539 +0,0 @@
#include <torch/csrc/distributed/c10d/CUDASymmetricMemory.cuh>

#include <ATen/ceil_div.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDACachingAllocator.h>
#include <c10/cuda/CUDAGuard.h>

#if !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED)
#include <c10/cuda/driver_api.h>
#endif

#include <sys/syscall.h>
#include <unistd.h>

namespace {

constexpr size_t signal_pad_size = 2048;
const std::string store_comm_prefix = "CUDASymmetricMemory";

static size_t store_comm_seq_id = 0;

template <typename T>
std::vector<T> store_all_gather(
    const c10::intrusive_ptr<c10d::Store>& store,
    int rank,
    int world_size,
    T val) {
  static_assert(std::is_trivially_copyable_v<T>);

  std::vector<std::string> peer_keys;
  for (int r = 0; r < world_size; ++r) {
    std::ostringstream oss;
    oss << store_comm_prefix << "/" << store_comm_seq_id << "/" << r;
    peer_keys.push_back(oss.str());
  }
  ++store_comm_seq_id;

  {
    std::vector<uint8_t> payload(
        reinterpret_cast<uint8_t*>(&val),
        reinterpret_cast<uint8_t*>(&val) + sizeof(T));
    store->set(peer_keys[rank], payload);
  }

  std::vector<T> peer_vals;
  for (int r = 0; r < world_size; ++r) {
    if (r == rank) {
      peer_vals.push_back(val);
      continue;
    }
    store->wait({peer_keys[r]});
    auto payload = store->get(peer_keys[r]);
    TORCH_CHECK(payload.size() == sizeof(T));
    T peer_val{};
    std::memcpy(&peer_val, payload.data(), sizeof(T));
    peer_vals.push_back(peer_val);
  }
  return peer_vals;
}

void store_barrier(
    const c10::intrusive_ptr<c10d::Store>& store,
    int rank,
    int world_size) {
  store_all_gather(store, rank, world_size, 0);
}

int import_remote_fd(int pid, int fd) {
#if defined(SYS_pidfd_open) and defined(SYS_pidfd_getfd)
  int pidfd = syscall(SYS_pidfd_open, pid, 0);
  return syscall(SYS_pidfd_getfd, pidfd, fd, 0);
#else
  TORCH_CHECK(
      false,
      "CUDASymmetricMemory requires pidfd_open ",
      "and pidfd_getfd support");
#endif
}

void map_block(
    void** ptr,
    c10d::symmetric_memory::HandleType handle,
    size_t size,
    int device_idx) {
#if !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED)
  auto driver_api = c10::cuda::DriverAPI::get();
  auto dev_ptr = reinterpret_cast<CUdeviceptr*>(ptr);
  C10_CUDA_DRIVER_CHECK(
      driver_api->cuMemAddressReserve_(dev_ptr, size, 0ULL, 0, 0ULL));
  C10_CUDA_DRIVER_CHECK(driver_api->cuMemMap_(*dev_ptr, size, 0, handle, 0ULL));

  CUmemAccessDesc desc;
  desc.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
  // NOLINTNEXTLINE(bugprone-signed-char-misuse)
  desc.location.id = static_cast<int>(device_idx);
  desc.flags = CU_MEM_ACCESS_FLAGS_PROT_READWRITE;
  C10_CUDA_DRIVER_CHECK(driver_api->cuMemSetAccess_(*dev_ptr, size, &desc, 1));
#else
  TORCH_CHECK(
      false, "CUDASymmetricMemory requires PYTORCH_C10_DRIVER_API_SUPPORTED");
#endif
}

} // namespace

namespace c10d {
namespace symmetric_memory {

CUDASymmetricMemory::CUDASymmetricMemory(
    std::vector<HandleType> handles,
    size_t block_size,
    std::vector<void*> buffers,
    std::vector<void*> signal_pads,
    size_t buffer_size,
    int local_device_idx,
    int rank,
    int world_size)
    : handles_(std::move(handles)),
      block_size_(block_size),
      buffers_(std::move(buffers)),
      signal_pads_(std::move(signal_pads)),
      buffer_size_(buffer_size),
      local_device_idx_(local_device_idx),
      rank_(rank),
      world_size_(world_size) {
  const size_t arr_size = sizeof(void*) * world_size_;
  buffers_dev_ = reinterpret_cast<void**>(
      c10::cuda::CUDACachingAllocator::raw_alloc(arr_size));
  signal_pads_dev_ = reinterpret_cast<void**>(
      c10::cuda::CUDACachingAllocator::raw_alloc(arr_size));

  c10::cuda::CUDAGuard guard(local_device_idx);
  AT_CUDA_CHECK(cudaMemcpy(
      buffers_dev_, buffers_.data(), arr_size, cudaMemcpyHostToDevice));
  AT_CUDA_CHECK(cudaMemcpy(
      signal_pads_dev_, signal_pads_.data(), arr_size, cudaMemcpyHostToDevice));
}

CUDASymmetricMemory::~CUDASymmetricMemory() {
#if !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED)
  c10::cuda::CUDAGuard guard(local_device_idx_);
  C10_CUDA_CHECK(cudaDeviceSynchronize());

  auto driver_api = c10::cuda::DriverAPI::get();
  for (int r = 0; r < world_size_; ++r) {
    C10_CUDA_DRIVER_CHECK(driver_api->cuMemUnmap_(
        reinterpret_cast<CUdeviceptr>(buffers_[r]), block_size_));
    C10_CUDA_DRIVER_CHECK(driver_api->cuMemRelease_(handles_[r]));
  }
  c10::cuda::CUDACachingAllocator::raw_delete(buffers_dev_);
  c10::cuda::CUDACachingAllocator::raw_delete(signal_pads_dev_);
#else
  TORCH_CHECK(
      false, "CUDASymmetricMemory requires PYTORCH_C10_DRIVER_API_SUPPORTED");
#endif
}

std::vector<void*> CUDASymmetricMemory::get_buffer_ptrs() {
  return buffers_;
}

std::vector<void*> CUDASymmetricMemory::get_signal_pad_ptrs() {
  return signal_pads_;
}

void** CUDASymmetricMemory::get_buffer_ptrs_dev() {
  return buffers_dev_;
}

void** CUDASymmetricMemory::get_signal_pad_ptrs_dev() {
  return signal_pads_dev_;
}

size_t CUDASymmetricMemory::get_buffer_size() {
  return buffer_size_;
}

size_t CUDASymmetricMemory::get_signal_pad_size() {
  return signal_pad_size;
}

at::Tensor CUDASymmetricMemory::get_buffer(
    int rank,
    c10::IntArrayRef sizes,
    c10::ScalarType dtype,
    int64_t storage_offset) {
  const auto numel =
      std::accumulate(sizes.begin(), sizes.end(), 1, std::multiplies<int>());
  const auto element_size = c10::elementSize(dtype);
  const auto req_size = (numel + storage_offset) * element_size;
  TORCH_CHECK(
      req_size <= buffer_size_,
      "CUDASymmetricMemory::get_buffer: the requested size (",
      req_size,
      " bytes) exceeds the allocated size (",
      buffer_size_,
      " bytes)");
  auto device = c10::Device(c10::DeviceType::CUDA, local_device_idx_);
  auto options = at::TensorOptions().dtype(dtype).device(device);
  return at::for_blob(buffers_[rank], sizes)
      .storage_offset(storage_offset)
      .options(options)
      .target_device(device)
      .make_tensor();
}

void check_channel(int channel, int world_size) {
  TORCH_CHECK(
      channel >= 0,
      "channel for barrier(), put_signal() and wait_signal() ",
      "must be greater than 0 (got ",
      channel,
      ")");
  const size_t num_channels = signal_pad_size / sizeof(uint32_t) * world_size;
  TORCH_CHECK(
      static_cast<size_t>(channel) < num_channels,
      "The maximum supported channel for barrier(), put_signal() and wait_signal() is ",
      num_channels - 1,
      " (got ",
      channel,
      ")");
}

__device__ __forceinline__ void release_signal(uint32_t* addr) {
#if defined(USE_ROCM) || (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800))
  CUDA_KERNEL_ASSERT(false);
#else
  volatile uint32_t* signal = addr;
  uint32_t val;
  do {
    val = *signal;
  } while (val != 0 || atomicCAS_system(addr, 0, 1) != 0);
#endif
}

__device__ __forceinline__ void acquire_signal(uint32_t* addr) {
#if defined(USE_ROCM) || (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800))
  CUDA_KERNEL_ASSERT(false);
#else
  volatile uint32_t* signal = addr;
  uint32_t val;
  do {
    val = *signal;
  } while (val != 1 || atomicCAS_system(addr, 1, 0) != 1);
#endif
}

static __global__ void barrier_kernel(
    uint32_t** signal_pads,
    int channel,
    int rank,
    int world_size) {
  if (threadIdx.x < world_size) {
    auto target_rank = threadIdx.x;
    release_signal(signal_pads[target_rank] + world_size * channel + rank);
    acquire_signal(signal_pads[rank] + world_size * channel + target_rank);
  }
}

void CUDASymmetricMemory::barrier(int channel) {
  check_channel(channel, world_size_);
  c10::cuda::CUDAGuard guard(local_device_idx_);
  barrier_kernel<<<1, C10_WARP_SIZE, 0, at::cuda::getCurrentCUDAStream()>>>(
      reinterpret_cast<uint32_t**>(signal_pads_dev_),
      channel,
      rank_,
      world_size_);
  C10_CUDA_KERNEL_LAUNCH_CHECK();
}

static __global__ void put_signal_kernel(
    uint32_t** signal_pads,
    int dst_rank,
    int channel,
    int rank,
    int world_size) {
  if (threadIdx.x == 0) {
    release_signal(signal_pads[dst_rank] + world_size * channel + rank);
  }
}

void CUDASymmetricMemory::put_signal(int dst_rank, int channel) {
  check_channel(channel, world_size_);
  c10::cuda::CUDAGuard guard(local_device_idx_);
  put_signal_kernel<<<1, C10_WARP_SIZE, 0, at::cuda::getCurrentCUDAStream()>>>(
      reinterpret_cast<uint32_t**>(signal_pads_dev_),
      dst_rank,
      channel,
      rank_,
      world_size_);
  C10_CUDA_KERNEL_LAUNCH_CHECK();
}

static __global__ void wait_signal_kernel(
    uint32_t** signal_pads,
    int src_rank,
    int channel,
    int rank,
    int world_size) {
  if (threadIdx.x == 0) {
    acquire_signal(signal_pads[rank] + world_size * channel + src_rank);
  }
  __threadfence_system();
}

void CUDASymmetricMemory::wait_signal(int src_rank, int channel) {
  check_channel(channel, world_size_);
  c10::cuda::CUDAGuard guard(local_device_idx_);
  wait_signal_kernel<<<1, C10_WARP_SIZE, 0, at::cuda::getCurrentCUDAStream()>>>(
      reinterpret_cast<uint32_t**>(signal_pads_dev_),
      src_rank,
      channel,
      rank_,
      world_size_);
  C10_CUDA_KERNEL_LAUNCH_CHECK();
}

int CUDASymmetricMemory::get_rank() {
  return rank_;
}

int CUDASymmetricMemory::get_world_size() {
  return world_size_;
}

void* CUDASymmetricMemoryAllocator::alloc(
    size_t size,
    int device_idx,
    const std::string& group_name) {
#if !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED)
  auto driver_api = c10::cuda::DriverAPI::get();

  CUmemAllocationProp prop = {};
  prop.type = CU_MEM_ALLOCATION_TYPE_PINNED;
  prop.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
  // NOLINTNEXTLINE(bugprone-signed-char-misuse)
  prop.location.id = device_idx;
  prop.requestedHandleTypes = CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR;

  size_t signal_pad_offset = at::round_up(size, 16UL);
  size_t block_size = signal_pad_offset + signal_pad_size;

  size_t granularity;
  C10_CUDA_DRIVER_CHECK(driver_api->cuMemGetAllocationGranularity_(
      &granularity, &prop, CU_MEM_ALLOC_GRANULARITY_RECOMMENDED));
  block_size = at::round_up(block_size, granularity);

  HandleType handle;
  C10_CUDA_DRIVER_CHECK(
      driver_api->cuMemCreate_(&handle, block_size, &prop, 0));

  void* ptr = nullptr;
  map_block(&ptr, handle, block_size, device_idx);

  c10::cuda::CUDAGuard guard(device_idx);
  AT_CUDA_CHECK(cudaMemset(ptr, 0, block_size));

  auto block = c10::make_intrusive<Block>(
      handle, device_idx, block_size, size, signal_pad_offset, group_name);
  {
    std::unique_lock lock(mutex_);
    ptr_to_block_.emplace(ptr, std::move(block));
  }
  return ptr;
#else
  TORCH_CHECK(
      false, "CUDASymmetricMemory requires PYTORCH_C10_DRIVER_API_SUPPORTED");
#endif
}

void CUDASymmetricMemoryAllocator::free(void* ptr) {
#if !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED)
  auto block = find_block(ptr);
  if (block == nullptr) {
    return;
  }
  // Initializing CUDASymmetricMemory with an allocation transfers its
  // ownership to the CUDASymmetricMemory object.
  if (block->symm_mem == nullptr) {
    auto driver_api = c10::cuda::DriverAPI::get();
    C10_CUDA_DRIVER_CHECK(driver_api->cuMemUnmap_(
        reinterpret_cast<CUdeviceptr>(ptr), block->block_size));
    C10_CUDA_DRIVER_CHECK(driver_api->cuMemRelease_(block->handle));
  }
  {
    std::unique_lock lock(mutex_);
    ptr_to_block_.erase(ptr);
  }
#else
  TORCH_CHECK(
      false, "CUDASymmetricMemory requires PYTORCH_C10_DRIVER_API_SUPPORTED");
#endif
}

size_t CUDASymmetricMemoryAllocator::get_alloc_size(void* ptr) {
  auto block = find_block(ptr);
  TORCH_CHECK(
      block != nullptr,
      "CUDASymmetricMemoryAllocator::get_alloc_size: input must be allocated ",
      "via CUDASymmetricMemoryAllocator::alloc");
  return block->buffer_size;
}

struct RendezvousRequest {
  int device_idx;
  int block_fd;
  int pid;
  size_t block_size;
  size_t buffer_size;
  size_t signal_pad_offset;
};

void validate_rendezvous_requests(
    const std::vector<RendezvousRequest> reqs,
    int world_size) {
  TORCH_CHECK(reqs.size() == (size_t)world_size);

  std::unordered_set<int> device_indices;
  device_indices.reserve(world_size);
  for (auto req : reqs) {
    device_indices.insert(req.device_idx);
  }
  if (device_indices.size() < (size_t)world_size) {
    TORCH_CHECK(
        false,
        "CUDASymmetricMemoryAllocator::rendezvous: ",
        "detected allocations from overlapping devices ",
        "from different ranks.");
  }

  for (int r = 1; r < world_size; ++r) {
    TORCH_CHECK(reqs[r].block_size == reqs[0].block_size);
    TORCH_CHECK(reqs[r].buffer_size == reqs[0].buffer_size);
    TORCH_CHECK(reqs[r].signal_pad_offset == reqs[0].signal_pad_offset);
  }
}

c10::intrusive_ptr<SymmetricMemory> CUDASymmetricMemoryAllocator::rendezvous(
    void* ptr) {
#if !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED)
  auto block = find_block(ptr);
  TORCH_CHECK(
      block != nullptr,
      "CUDASymmetricMemoryAllocator::rendezvous: input must be allocated ",
      "via CUDASymmetricMemoryAllocator::alloc");

  if (block->symm_mem != nullptr) {
    return block->symm_mem;
  }

  auto group_info = get_group_info(block->group_name);
  auto driver_api = c10::cuda::DriverAPI::get();
  int block_fd;
  C10_CUDA_DRIVER_CHECK(driver_api->cuMemExportToShareableHandle_(
      &block_fd, block->handle, CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR, 0));

  auto local_req = RendezvousRequest{
      .device_idx = block->device_idx,
      .block_fd = block_fd,
      .pid = getpid(),
      .block_size = block->block_size,
      .buffer_size = block->buffer_size,
      .signal_pad_offset = block->signal_pad_offset};
  auto reqs = store_all_gather(
      group_info.store, group_info.rank, group_info.world_size, local_req);
  validate_rendezvous_requests(reqs, group_info.world_size);

  std::vector<HandleType> handles(group_info.world_size);
  std::vector<void*> buffers(group_info.world_size, nullptr);
  std::vector<void*> signal_pads(group_info.world_size, nullptr);
  for (int r = 0; r < group_info.world_size; ++r) {
    if (r == group_info.rank) {
      handles[r] = block->handle;
      buffers[r] = ptr;
      signal_pads[r] = (void*)((uintptr_t)ptr + block->signal_pad_offset);
      continue;
    }
    int imported_fd = import_remote_fd(reqs[r].pid, reqs[r].block_fd);
    C10_CUDA_DRIVER_CHECK(driver_api->cuMemImportFromShareableHandle_(
        &handles[r],
        (void*)(uintptr_t)imported_fd,
        CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR));
    map_block(&buffers[r], handles[r], block->block_size, block->device_idx);
    signal_pads[r] = (void*)((uintptr_t)buffers[r] + block->signal_pad_offset);
    close(imported_fd);
  }
  store_barrier(group_info.store, group_info.rank, group_info.world_size);
  close(block_fd);

  // Initializing CUDASymmetricMemory with an allocation transfers its
  // ownership to the CUDASymmetricMemory object. This allows outstanding
  // references to the CUDASymmetricMemory object to keep the allocation
  // alive.
  block->symm_mem = c10::make_intrusive<CUDASymmetricMemory>(
      std::move(handles),
      block->block_size,
      std::move(buffers),
      std::move(signal_pads),
      block->buffer_size,
      block->device_idx,
      group_info.rank,
      group_info.world_size);
  return block->symm_mem;
#else
  TORCH_CHECK(
      false, "CUDASymmetricMemory requires PYTORCH_C10_DRIVER_API_SUPPORTED");
#endif
}

bool CUDASymmetricMemoryAllocator::is_rendezvous_completed(void* ptr) {
  auto block = find_block(ptr);
  TORCH_CHECK(
      block != nullptr,
      "CUDASymmetricMemoryAllocator::is_rendezvous_completed: input must be allocated ",
      "via CUDASymmetricMemoryAllocator::alloc");
  return block->symm_mem != nullptr;
}

c10::intrusive_ptr<Block> CUDASymmetricMemoryAllocator::find_block(void* ptr) {
  std::shared_lock lock(mutex_);
  auto it = ptr_to_block_.find(ptr);
  if (it == ptr_to_block_.end()) {
    return nullptr;
  }
  return it->second;
}

struct RegisterCUDASymmetricMemoryAllocator {
  RegisterCUDASymmetricMemoryAllocator() {
    register_allocator(
        c10::DeviceType::CUDA,
        c10::make_intrusive<CUDASymmetricMemoryAllocator>());
  }
};

static RegisterCUDASymmetricMemoryAllocator register_allocator_;

} // namespace symmetric_memory
} // namespace c10d
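The rendezvous path in the file above ships each rank's allocation to its peers as a POSIX file descriptor: the handle is exported with cuMemExportToShareableHandle, the (pid, fd) pair travels through the store, and import_remote_fd duplicates the peer's descriptor locally. A standalone sketch of that duplication step, assuming Linux 5.6+ with pidfd support and sufficient ptrace permission on the peer process (peer_pid and peer_fd are assumed to arrive out of band, e.g. via the store):

#include <sys/syscall.h>
#include <unistd.h>

int import_fd(pid_t peer_pid, int peer_fd) {
  // Obtain a pidfd referring to the peer process.
  int pidfd = static_cast<int>(syscall(SYS_pidfd_open, peer_pid, 0));
  if (pidfd < 0) {
    return -1;
  }
  // Duplicate the peer's descriptor into this process.
  int local_fd = static_cast<int>(syscall(SYS_pidfd_getfd, pidfd, peer_fd, 0));
  close(pidfd); // the pidfd is only needed for the duplication itself
  return local_fd;
}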
@@ -1,109 +0,0 @@
#pragma once

#include <ATen/ATen.h>
#include <torch/csrc/distributed/c10d/Store.hpp>
#include <torch/csrc/distributed/c10d/SymmetricMemory.hpp>

namespace c10d {
namespace symmetric_memory {

#if !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED)
using HandleType = CUmemGenericAllocationHandle;
#else
using HandleType = void*;
#endif

class CUDASymmetricMemory : public SymmetricMemory {
 public:
  CUDASymmetricMemory(
      std::vector<HandleType> handles,
      size_t block_size,
      std::vector<void*> buffers,
      std::vector<void*> signal_pads,
      size_t buffer_size,
      int local_device_idx,
      int rank,
      int world_size);

  ~CUDASymmetricMemory() override;

  std::vector<void*> get_buffer_ptrs() override;
  std::vector<void*> get_signal_pad_ptrs() override;
  void** get_buffer_ptrs_dev() override;
  void** get_signal_pad_ptrs_dev() override;
  size_t get_buffer_size() override;
  size_t get_signal_pad_size() override;

  at::Tensor get_buffer(
      int rank,
      c10::IntArrayRef sizes,
      c10::ScalarType dtype,
      int64_t storage_offset) override;

  void barrier(int channel) override;
  void put_signal(int dst_rank, int channel) override;
  void wait_signal(int src_rank, int channel) override;

  int get_rank() override;
  int get_world_size() override;

 private:
  std::vector<HandleType> handles_;
  size_t block_size_;
  std::vector<void*> buffers_;
  std::vector<void*> signal_pads_;
  size_t buffer_size_;
  int local_device_idx_;
  int rank_;
  int world_size_;
  void** buffers_dev_;
  void** signal_pads_dev_;
  std::optional<std::function<void(void)>> finalizer_;
};

struct Block : public c10::intrusive_ptr_target {
  HandleType handle;
  int device_idx;
  size_t block_size;
  size_t buffer_size;
  size_t signal_pad_offset;
  std::string group_name;
  c10::intrusive_ptr<CUDASymmetricMemory> symm_mem = nullptr;

  Block(
      HandleType handle,
      int device_idx,
      size_t block_size,
      size_t buffer_size,
      size_t signal_pad_offset,
      const std::string& group_name)
      : handle(handle),
        device_idx(device_idx),
        block_size(block_size),
        buffer_size(buffer_size),
        signal_pad_offset(signal_pad_offset),
        group_name(group_name),
        symm_mem(nullptr) {}
};

class CUDASymmetricMemoryAllocator : public SymmetricMemoryAllocator {
 public:
  void* alloc(
      size_t size,
      int device_idx,
      const std::string& group_name) override;

  void free(void* ptr) override;
  size_t get_alloc_size(void* ptr) override;
  c10::intrusive_ptr<SymmetricMemory> rendezvous(void* ptr) override;
  bool is_rendezvous_completed(void* ptr) override;

 private:
  c10::intrusive_ptr<Block> find_block(void* ptr);

  std::shared_mutex mutex_;
  std::unordered_map<void*, c10::intrusive_ptr<Block>> ptr_to_block_;
};

} // namespace symmetric_memory
} // namespace c10d
@@ -10,7 +10,6 @@ constexpr auto kProcessGroupCudaP2PDefaultTimeout =

namespace c10d {

// NOTE: this class will be removed soon in favor of SymmetricMemory
class TORCH_API ProcessGroupCudaP2P : public Backend {
 public:
  struct Options : Backend::Options {
@@ -1,189 +0,0 @@
#include <torch/csrc/distributed/c10d/SymmetricMemory.hpp>

namespace {

using namespace c10d::symmetric_memory;

class AllocatorMap {
 public:
  static AllocatorMap& get() {
    static AllocatorMap instance;
    return instance;
  }

  void register_allocator(
      c10::DeviceType device_type,
      c10::intrusive_ptr<SymmetricMemoryAllocator> allocator) {
    map_[device_type] = std::move(allocator);
  }

  c10::intrusive_ptr<SymmetricMemoryAllocator> get_allocator(
      c10::DeviceType device_type) {
    auto it = map_.find(device_type);
    TORCH_CHECK(
        it != map_.end(),
        "SymmetricMemory does not support device type ",
        device_type);
    return it->second;
  }

  ~AllocatorMap() {
    for (auto& it : map_) {
      it.second.release();
    }
  }

 private:
  AllocatorMap() = default;
  AllocatorMap(const AllocatorMap&) = delete;
  AllocatorMap& operator=(const AllocatorMap&) = delete;

  std::unordered_map<
      c10::DeviceType,
      c10::intrusive_ptr<SymmetricMemoryAllocator>>
      map_;
};

static std::unordered_map<std::string, GroupInfo> group_info_map{};

// Data structures for tracking persistent allocations
static std::unordered_map<uint64_t, void*> alloc_id_to_dev_ptr{};
static std::unordered_map<uint64_t, c10::weak_intrusive_ptr<c10::StorageImpl>>
    alloc_id_to_storage{};

static at::Tensor empty_strided_p2p_persistent(
    c10::IntArrayRef size,
    c10::IntArrayRef stride,
    c10::ScalarType dtype,
    c10::Device device,
    const std::string& group_name,
    uint64_t alloc_id) {
  // Make the allocation fail if a previous allocation with the same alloc_id
  // is still active.
  auto storage = alloc_id_to_storage.find(alloc_id);
  if (storage != alloc_id_to_storage.end() && storage->second.use_count() > 0) {
    TORCH_CHECK(
        false,
        "SymmetricMemory::empty_strided_p2p_persistent: ",
        "can not allocate with alloc_id == ",
        alloc_id,
        " because a previous allocation with the same alloc_id "
        "is still active.");
  }

  const size_t numel =
      std::accumulate(size.begin(), size.end(), 1, std::multiplies<int>());
  const size_t element_size = c10::elementSize(dtype);
  const size_t alloc_size = numel * element_size;

  auto allocator = get_allocator(device.type());
  void* dev_ptr = nullptr;
  if (alloc_id_to_dev_ptr.find(alloc_id) != alloc_id_to_dev_ptr.end()) {
    dev_ptr = alloc_id_to_dev_ptr[alloc_id];
    TORCH_CHECK(
        alloc_size == allocator->get_alloc_size(dev_ptr),
        "SymmetricMemory::empty_strided_p2p_persistent: ",
        "requested allocation size (",
        alloc_size,
        ") is different from the size of a previous allocation ",
        "with the same alloc_id ",
        allocator->get_alloc_size(dev_ptr));
  } else {
    dev_ptr = allocator->alloc(alloc_size, device.index(), group_name);
    alloc_id_to_dev_ptr[alloc_id] = dev_ptr;
  }

  auto options = at::TensorOptions().dtype(dtype).device(device);
  auto allocated = at::from_blob(dev_ptr, size, stride, options);

  // Track the allocation's activeness
  alloc_id_to_storage.erase(alloc_id);
  alloc_id_to_storage.emplace(
      alloc_id, allocated.storage().getWeakStorageImpl());
  return allocated;
}

} // namespace

namespace c10d {
namespace symmetric_memory {

void register_allocator(
    c10::DeviceType device_type,
    c10::intrusive_ptr<SymmetricMemoryAllocator> allocator) {
  return AllocatorMap::get().register_allocator(
      device_type, std::move(allocator));
}

c10::intrusive_ptr<SymmetricMemoryAllocator> get_allocator(
    c10::DeviceType device_type) {
  return AllocatorMap::get().get_allocator(device_type);
}

void set_group_info(
    const std::string& group_name,
    int rank,
    int world_size,
    c10::intrusive_ptr<Store> store) {
  TORCH_CHECK(group_info_map.find(group_name) == group_info_map.end());
  GroupInfo group_info;
  group_info.rank = rank;
  group_info.world_size = world_size;
  group_info.store = std::move(store);
  group_info_map.emplace(group_name, std::move(group_info));
}

const GroupInfo& get_group_info(const std::string& group_name) {
  TORCH_CHECK(
      group_info_map.find(group_name) != group_info_map.end(),
      "get_group_info: no group info associated with the group name ",
      group_name);
  return group_info_map[group_name];
}

at::Tensor empty_strided_p2p(
    c10::IntArrayRef size,
    c10::IntArrayRef stride,
    c10::ScalarType dtype,
    c10::Device device,
    const std::string& group_name,
    std::optional<uint64_t> alloc_id) {
  if (alloc_id.has_value()) {
    return empty_strided_p2p_persistent(
        size, stride, dtype, device, group_name, *alloc_id);
  }
  const size_t numel =
      std::accumulate(size.begin(), size.end(), 1, std::multiplies<int>());
  const size_t element_size = c10::elementSize(dtype);
  const size_t alloc_size = numel * element_size;

  auto allocator = get_allocator(device.type());
  void* dev_ptr = allocator->alloc(alloc_size, device.index(), group_name);

  auto options = at::TensorOptions().dtype(dtype).device(device);
  return at::from_blob(
      dev_ptr,
      size,
      stride,
      [allocator = std::move(allocator)](void* ptr) { allocator->free(ptr); },
      options);
}

TORCH_API c10::intrusive_ptr<SymmetricMemory> rendezvous(
    const at::Tensor& tensor) {
  auto allocator = get_allocator(tensor.device().type());
  return allocator->rendezvous(tensor.data_ptr());
}

c10::intrusive_ptr<SymmetricMemory> get_symmetric_memory(
    const at::Tensor& tensor) {
  auto allocator = get_allocator(tensor.device().type());
  TORCH_CHECK(
      allocator->is_rendezvous_completed(tensor.data_ptr()),
      "SymmetricMemory: must invoke rendezvous on a tensor ",
      "before calling get_symmetric_memory on it");
  return allocator->rendezvous(tensor.data_ptr());
}

} // namespace symmetric_memory
} // namespace c10d
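The register_allocator/get_allocator pair above is driven by the static-registration idiom: CUDASymmetricMemory.cu defines a file-scope RegisterCUDASymmetricMemoryAllocator object whose constructor runs when the library is loaded, so the CUDA backend registers itself without any explicit init call. A minimal, runnable sketch of the idiom (toy names, not the c10d API):

#include <functional>
#include <iostream>
#include <map>
#include <string>

using Factory = std::function<void()>;

std::map<std::string, Factory>& registry() {
  // Function-local static avoids the static initialization order fiasco.
  static std::map<std::string, Factory> r;
  return r;
}

struct Register {
  Register(const std::string& key, Factory f) { registry()[key] = std::move(f); }
};

// Runs before main() whenever this translation unit is linked in.
static Register register_cuda("cuda", [] { std::cout << "cuda backend\n"; });

int main() {
  registry()["cuda"](); // prints "cuda backend"
}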
@@ -1,152 +0,0 @@
#pragma once

#include <ATen/ATen.h>
#include <torch/csrc/distributed/c10d/Store.hpp>

namespace c10d {
namespace symmetric_memory {

// SymmetricMemory represents symmetric allocations across a group of devices.
// The allocations represented by a SymmetricMemory object are accessible by
// all devices in the group. The class can be used for op-level custom
// communication patterns (via the get_buffer APIs and the synchronization
// primitives), as well as custom communication kernels (via the buffer and
// signal_pad device pointers).
//
// To acquire a SymmetricMemory object, each rank first allocates
// identical-sized memory via SymmetricMemoryAllocator::alloc(), then invokes
// SymmetricMemoryAllocator::rendezvous() on the memory to establish the
// association across peer buffers. The rendezvous is a one-time process, and
// the mapping between a local memory region and the associated SymmetricMemory
// object is unique.
//
// NOTE [symmetric memory signal pad]
// Signal pads are P2P-accessible memory regions designated for
// synchronization. SymmetricMemory offers built-in synchronization primitives
// such as barriers, put_signal, and wait_signal, which are all based on signal
// pads. Users may utilize signal pads for their own synchronization logic,
// provided that the signal pads remain zero-filled following successful
// synchronization.
//
// NOTE [symmetric memory synchronization channel]
// Synchronization channels allow users to use a single SymmetricMemory object
// to perform isolated synchronizations on different streams. For example,
// consider the case in which two barriers are issued on two streams for
// different purposes. Without the concept of channels, we cannot guarantee the
// correctness of the barriers since signals issued from the barrier on stream A
// can be received by the barrier on stream B. By specifying different channels
// for these two barriers, they can operate correctly in parallel.
class TORCH_API SymmetricMemory : public c10::intrusive_ptr_target {
 public:
  virtual ~SymmetricMemory() {}

  virtual std::vector<void*> get_buffer_ptrs() = 0;
  virtual std::vector<void*> get_signal_pad_ptrs() = 0;

  // get_buffer_ptrs_dev() and get_signal_pad_ptrs_dev() each return a pointer
  // to a device array of size world_size, containing buffer pointers and
  // signal pad pointers, respectively.
  virtual void** get_buffer_ptrs_dev() = 0;
  virtual void** get_signal_pad_ptrs_dev() = 0;
  virtual size_t get_buffer_size() = 0;
  virtual size_t get_signal_pad_size() = 0;

  virtual at::Tensor get_buffer(
      int rank,
      c10::IntArrayRef sizes,
      c10::ScalarType dtype,
      int64_t storage_offset) = 0;

  virtual void barrier(int channel) = 0;
  virtual void put_signal(int dst_rank, int channel) = 0;
  virtual void wait_signal(int src_rank, int channel) = 0;

  virtual int get_rank() = 0;
  virtual int get_world_size() = 0;
};

class SymmetricMemoryAllocator : public c10::intrusive_ptr_target {
 public:
  virtual ~SymmetricMemoryAllocator(){};

  virtual void* alloc(
      size_t size,
      int device_idx,
      const std::string& group_name) = 0;

  virtual void free(void* ptr) = 0;
  virtual size_t get_alloc_size(void* ptr) = 0;
  virtual c10::intrusive_ptr<SymmetricMemory> rendezvous(void* ptr) = 0;
  virtual bool is_rendezvous_completed(void* ptr) = 0;
};

C10_EXPORT void register_allocator(
    c10::DeviceType device_type,
    c10::intrusive_ptr<SymmetricMemoryAllocator> allocator);

C10_EXPORT c10::intrusive_ptr<SymmetricMemoryAllocator> get_allocator(
    c10::DeviceType device_type);

// Set a store for rendezvousing symmetric allocations on a group of devices
// identified by `group_name`. The concept of groups is logical; users can
// utilize predefined groups (e.g., a group of devices identified by a
// ProcessGroup) or create custom ones. Note that a SymmetricMemoryAllocator
// backend might employ a more efficient communication channel for the actual
// rendezvous process and only use the store for bootstrapping purposes.
TORCH_API void set_group_info(
    const std::string& group_name,
    int rank,
    int world_size,
    c10::intrusive_ptr<Store> store);

struct GroupInfo {
  int rank;
  int world_size;
  c10::intrusive_ptr<c10d::Store> store;
};

C10_EXPORT const GroupInfo& get_group_info(const std::string& group_name);

// Identical to empty_strided, but allows symmetric memory access to be
// established for the allocated tensor via SymmetricMemory::rendezvous(). This
// function itself is not a collective operation. It invokes
// SymmetricMemoryAllocator::alloc() for the requested device under the hood.
//
// NOTE [symmetric memory persistent allocation]
// If an `alloc_id` is supplied, empty_strided_p2p will perform persistent
// allocation. This makes the function cache allocated memory and ensure that
// invocations with the same `alloc_id` receive tensors backed by the same
// memory address. For safety, if a previous persistent allocation is still
// active (i.e., the storage of the returned tensor is still alive), persistent
// allocations with the same `alloc_id` will fail. This determinism, coupled
// with memory planning of communication buffers (e.g., by Inductor), allows
// communication algorithms to reliably reuse previously established remote
// memory access.
TORCH_API at::Tensor empty_strided_p2p(
    c10::IntArrayRef size,
    c10::IntArrayRef stride,
    c10::ScalarType dtype,
    c10::Device device,
    const std::string& group_name,
    std::optional<uint64_t> alloc_id);

// Establishes symmetric memory access on tensors allocated via
// empty_strided_p2p() and empty_strided_p2p_persistent(). rendezvous() is a
// one-time process, and the mapping between a local memory region and the
// associated SymmetricMemory object is unique. Subsequent calls to
// rendezvous() with the same tensor, or tensors allocated with
// empty_strided_p2p_persistent() using the same alloc_id, will receive the
// cached SymmetricMemory object.
//
// The function has a collective semantic and must be invoked simultaneously
// from all rendezvous participants.
TORCH_API c10::intrusive_ptr<SymmetricMemory> rendezvous(
    const at::Tensor& tensor);

// Returns the SymmetricMemory object associated with the tensor. It can only
// be invoked after rendezvous() but does not need to be invoked collectively.
TORCH_API c10::intrusive_ptr<SymmetricMemory> get_symmetric_memory(
    const at::Tensor& tensor);

} // namespace symmetric_memory
} // namespace c10d
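To make the contract documented above concrete, here is a hedged usage sketch of the declared API; it assumes one GPU per rank, a job with at least two ranks, and an already-constructed c10d store, and it is not compilable without this (now reverted) header:

#include <torch/csrc/distributed/c10d/SymmetricMemory.hpp>

using namespace c10d::symmetric_memory;

at::Tensor exchange_example(int rank, int world_size,
                            c10::intrusive_ptr<c10d::Store> store) {
  // One-time bootstrap: associate a logical group name with a store.
  set_group_info("demo", rank, world_size, std::move(store));

  // Symmetric allocation: every rank must request the identical size/stride.
  auto t = empty_strided_p2p(
      {64, 64}, {64, 1}, c10::ScalarType::Float,
      c10::Device(c10::DeviceType::CUDA, rank), "demo",
      /*alloc_id=*/std::nullopt);

  // Collective: establishes P2P mappings across all ranks.
  auto symm_mem = rendezvous(t);

  // View a peer's buffer as a tensor and synchronize via signal pads.
  auto peer = symm_mem->get_buffer(
      (rank + 1) % world_size, {64, 64}, c10::ScalarType::Float, 0);
  symm_mem->barrier(/*channel=*/0);
  return peer;
}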
@@ -41,7 +41,6 @@
#include <fmt/format.h>
#include <pybind11/chrono.h>
#include <torch/csrc/distributed/c10d/PrefixStore.hpp>
#include <torch/csrc/distributed/c10d/SymmetricMemory.hpp>

#include <torch/csrc/distributed/c10d/comm.hpp>
#include <torch/csrc/distributed/c10d/debug.h>
@@ -976,44 +975,6 @@ This class does not support ``__members__`` property.)");
          "global_ranks_in_group",
          &::c10d::DistributedBackendOptions::global_ranks_in_group);

  using SymmetricMemory = ::c10d::symmetric_memory::SymmetricMemory;
  py::class_<SymmetricMemory, c10::intrusive_ptr<SymmetricMemory>>(
      module, "_SymmetricMemory")
      .def_static("set_group_info", &::c10d::symmetric_memory::set_group_info)
      .def_static(
          "empty_strided_p2p",
          ::c10d::symmetric_memory::empty_strided_p2p,
          py::arg("size"),
          py::arg("stride"),
          py::arg("dtype"),
          py::arg("device"),
          py::arg("group_name"),
          py::arg("alloc_id") = py::none())
      .def_static("rendezvous", &::c10d::symmetric_memory::rendezvous)
      .def_static(
          "get_symmetric_memory",
          &::c10d::symmetric_memory::get_symmetric_memory)
      .def_property_readonly("rank", &SymmetricMemory::get_rank)
      .def_property_readonly("world_size", &SymmetricMemory::get_world_size)
      .def(
          "get_buffer",
          &SymmetricMemory::get_buffer,
          py::arg("rank"),
          py::arg("sizes"),
          py::arg("dtype"),
          py::arg("storage_offset") = 0)
      .def("barrier", &SymmetricMemory::barrier, py::arg("channel") = 0)
      .def(
          "put_signal",
          &SymmetricMemory::put_signal,
          py::arg("dst_rank"),
          py::arg("channel") = 0)
      .def(
          "wait_signal",
          &SymmetricMemory::wait_signal,
          py::arg("src_rank"),
          py::arg("channel") = 0);

  auto store =
      py::class_<::c10d::Store, c10::intrusive_ptr<::c10d::Store>, PythonStore>(
          module,
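The binding above is standard pybind11: def_static for the free functions, py::arg with defaults for optional parameters such as alloc_id and channel, and def_property_readonly to surface get_rank/get_world_size as attributes. A compilable toy module showing the same recipe (Toy is a hypothetical type, not c10d):

#include <pybind11/pybind11.h>
namespace py = pybind11;

struct Toy {
  int rank() const { return 0; }
  static int make(int x = 1) { return x; }
};

PYBIND11_MODULE(toy, m) {
  py::class_<Toy>(m, "Toy")
      .def(py::init<>())
      .def_static("make", &Toy::make, py::arg("x") = 1) // default arg, like alloc_id
      .def_property_readonly("rank", &Toy::rank);       // like .rank / .world_size
}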
@@ -218,8 +218,23 @@ IntraNodeComm::~IntraNodeComm() {
  if (!isInitialized_) {
    return;
  }
  auto allocator = get_allocator(c10::DeviceType::CUDA);
  allocator->free(symmetricMemoryPtr_);
  // Intentionally releasing resources without synchronizing devices. The
  // teardown logic is safe for a properly sync'd user program. We don't want
  // an improperly sync'd user program to hang here.
  for (size_t r = 0; r < worldSize_; ++r) {
    if (r == rank_) {
      continue;
    }
    AT_CUDA_CHECK(cudaIpcCloseMemHandle(p2pStates_[r]));
    AT_CUDA_CHECK(cudaIpcCloseMemHandle(buffers_[r]));
  }
  AT_CUDA_CHECK(cudaFree(p2pStates_[rank_]));
  AT_CUDA_CHECK(cudaFree(buffers_[rank_]));
  if (topoInfo_ != nullptr) {
    AT_CUDA_CHECK(cudaFree(topoInfo_));
  }
  AT_CUDA_CHECK(cudaFree(p2pStatesDev_));
  AT_CUDA_CHECK(cudaFree(buffersDev_));
}

bool IntraNodeComm::isEnabled() {
@@ -329,19 +344,83 @@ bool IntraNodeComm::rendezvous() {
  // Detect topology
  Topology topology = detectTopology(nvlMesh, worldSize_);

  set_group_info("IntraNodeComm", rank_, worldSize_, store_);
  auto allocator = get_allocator(c10::DeviceType::CUDA);
  symmetricMemoryPtr_ =
      allocator->alloc(bufferSize_, deviceIdx, "IntraNodeComm");
  symmetricMemory_ = allocator->rendezvous(symmetricMemoryPtr_);
  TORCH_CHECK(symmetricMemory_->get_signal_pad_size() >= kP2pStateSize);
  // Initialize p2p state
  auto p2pState = initP2pState();

  // Allocate buffer
  void* buffer = nullptr;
  AT_CUDA_CHECK(cudaMalloc(&buffer, bufferSize_));

  // Second handshake: exchange topology and CUDA IPC handles
  struct IpcInfo {
    NvlMesh nvlMesh;
    Topology topology;
    cudaIpcMemHandle_t p2pStateHandle, bufferHandle;
  };

  // Make p2p state and buffer available for IPC
  cudaIpcMemHandle_t p2pStateHandle, bufferHandle;
  AT_CUDA_CHECK(cudaIpcGetMemHandle(&p2pStateHandle, p2pState));
  AT_CUDA_CHECK(cudaIpcGetMemHandle(&bufferHandle, buffer));

  IpcInfo ipcInfo{
      .nvlMesh = nvlMesh,
      .topology = topology,
      .p2pStateHandle = p2pStateHandle,
      .bufferHandle = bufferHandle};

  auto peerIpcInfos =
      storeAllGather(store_, "handshake-1", rank_, worldSize_, ipcInfo);

  for (const auto& info : peerIpcInfos) {
    if (!isSame(info.nvlMesh, peerIpcInfos.front().nvlMesh) ||
        info.topology != peerIpcInfos.front().topology) {
      LOG(WARNING) << "Aborting IntraNodeComm::rendezvous because some "
                      "participants are observing different topologies ("
                   << int(info.topology) << " and " << int(topology) << ")";
      AT_CUDA_CHECK(cudaFree(p2pState));
      AT_CUDA_CHECK(cudaFree(buffer));
      return false;
    }
  }

  std::array<void*, kMaxDevices> p2pStates = {}, buffers = {};
  for (size_t r = 0; r < peerIpcInfos.size(); ++r) {
    if (r == rank_) {
      p2pStates[r] = p2pState;
      buffers[r] = buffer;
    } else {
      AT_CUDA_CHECK(cudaIpcOpenMemHandle(
          &p2pStates[r],
          peerIpcInfos[r].p2pStateHandle,
          cudaIpcMemLazyEnablePeerAccess));
      AT_CUDA_CHECK(cudaIpcOpenMemHandle(
          &buffers[r],
          peerIpcInfos[r].bufferHandle,
          cudaIpcMemLazyEnablePeerAccess));
    }
  }
  void* p2pStatesDev = nullptr;
  AT_CUDA_CHECK(cudaMalloc(&p2pStatesDev, sizeof(p2pStates)));
  AT_CUDA_CHECK(cudaMemcpy(
      p2pStatesDev,
      p2pStates.data(),
      sizeof(p2pStates),
      cudaMemcpyHostToDevice));

  void* buffersDev = nullptr;
  AT_CUDA_CHECK(cudaMalloc(&buffersDev, sizeof(buffers)));
  AT_CUDA_CHECK(cudaMemcpy(
      buffersDev, buffers.data(), sizeof(buffers), cudaMemcpyHostToDevice));

  void* topoInfo = initTopoInfo(topology, nvlMesh, rank_);

  isInitialized_ = true;
  topology_ = topology;
  p2pStatesDev_ = symmetricMemory_->get_signal_pad_ptrs_dev();
  buffersDev_ = symmetricMemory_->get_buffer_ptrs_dev();
  std::copy(p2pStates.begin(), p2pStates.end(), p2pStates_.begin());
  std::copy(buffers.begin(), buffers.end(), buffers_.begin());
  p2pStatesDev_ = p2pStatesDev;
  buffersDev_ = buffersDev;
  topoInfo_ = topoInfo;
  return true;
#endif
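The restored code path exchanges opaque cudaIpcMemHandle_t blobs through the store and maps each peer's allocation with the legacy CUDA IPC runtime API. A reduced two-process sketch of the mechanism (the transport between the processes is left abstract here; IntraNodeComm moves the handles with storeAllGather, and error checks are omitted):

#include <cuda_runtime.h>

// Process A: allocate and export a shareable handle.
void exporter(void** buf, cudaIpcMemHandle_t* handle) {
  cudaMalloc(buf, 1 << 20);          // allocation to be shared
  cudaIpcGetMemHandle(handle, *buf); // 64-byte blob, copyable between processes
  // ... send the handle to process B, e.g. via a store key ...
}

// Process B: map the peer allocation into the local address space.
void importer(const cudaIpcMemHandle_t& handle, void** peer_buf) {
  cudaIpcOpenMemHandle(peer_buf, handle, cudaIpcMemLazyEnablePeerAccess);
  // ... read/write *peer_buf with P2P access ...
  cudaIpcCloseMemHandle(*peer_buf);  // unmap when done
}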
@@ -132,8 +132,6 @@ struct P2pState {
  uint32_t signals1[kMaxAllReduceBlocks][kMaxDevices];
};

static_assert(sizeof(P2pState) <= kP2pStateSize);

template <uint32_t kWorldSize, bool kAligned>
static __global__ void oneShotAllReduceKernel(
    at::BFloat16* input,
@@ -524,7 +522,7 @@ at::Tensor IntraNodeComm::oneShotAllReduce(
  const bool fuseInputCopy = isAligned && blocks.x < kMaxAllReduceBlocks;
  if (!fuseInputCopy) {
    AT_CUDA_CHECK(cudaMemcpyAsync(
        symmetricMemory_->get_buffer_ptrs_dev()[rank_],
        buffers_[rank_],
        input.data_ptr(),
        input.numel() * input.element_size(),
        cudaMemcpyDeviceToDevice,
@@ -584,7 +582,7 @@ at::Tensor IntraNodeComm::twoShotAllReduce(

  at::cuda::OptionalCUDAGuard guard(input.get_device());
  AT_CUDA_CHECK(cudaMemcpyAsync(
      symmetricMemory_->get_buffer_ptrs_dev()[rank_],
      buffers_[rank_],
      input.data_ptr(),
      input.numel() * input.element_size(),
      cudaMemcpyDeviceToDevice,
@@ -634,7 +632,7 @@ at::Tensor IntraNodeComm::hybridCubeMeshAllReduce(

  at::cuda::OptionalCUDAGuard guard(input.get_device());
  AT_CUDA_CHECK(cudaMemcpyAsync(
      symmetricMemory_->get_buffer_ptrs_dev()[rank_],
      buffers_[rank_],
      input.data_ptr(),
      input.numel() * input.element_size(),
      cudaMemcpyDeviceToDevice,
@@ -757,7 +755,15 @@ at::Tensor IntraNodeComm::getBuffer(
    const std::vector<int64_t>& sizes,
    c10::ScalarType dtype,
    int64_t storageOffset) {
  return symmetricMemory_->get_buffer(rank, sizes, dtype, storageOffset);
  const auto numel = std::accumulate(sizes.begin(), sizes.end(), 0);
  const auto elementSize = c10::elementSize(dtype);
  TORCH_CHECK((numel + storageOffset) * elementSize <= bufferSize_);
  auto options = at::TensorOptions().dtype(dtype).device(
      at::kCUDA, at::cuda::current_device());
  return at::for_blob(buffers_[rank], sizes)
      .storage_offset(storageOffset)
      .options(options)
      .make_tensor();
}

} // namespace intra_node_comm
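getBuffer above wraps raw peer memory as a tensor with at::for_blob, which aliases the given pointer rather than copying. A small sketch of that wrapping (requires libtorch; CPU memory is used here so the sketch runs without a GPU):

#include <ATen/ATen.h>

int main() {
  static float data[64];                 // externally owned storage
  auto t = at::for_blob(data, {8, 8})
               .options(at::TensorOptions().dtype(at::kFloat))
               .make_tensor();           // no copy: t aliases data
  t.fill_(1.0f);                         // writes through to data
  return data[0] == 1.0f ? 0 : 1;
}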
@@ -4,16 +4,12 @@
#include <ATen/cuda/CUDAEvent.h>
#include <c10/cuda/CUDAStream.h>
#include <torch/csrc/distributed/c10d/Store.hpp>
#include <torch/csrc/distributed/c10d/SymmetricMemory.hpp>
#include <torch/csrc/distributed/c10d/Work.hpp>

namespace c10d::intra_node_comm {

using namespace c10d::symmetric_memory;

constexpr size_t kMaxDevices = 8;
constexpr size_t kDefaultBufferSize = 10ull * 1024 * 1024;
constexpr size_t kP2pStateSize = 2048;

using NvlMesh = std::array<std::array<size_t, kMaxDevices>, kMaxDevices>;
using HybridCubeMesh = std::array<std::array<int, 4>, kMaxDevices>;
@@ -31,7 +27,6 @@ enum class AllReduceAlgo : uint8_t {
  HCM = 3
};

// NOTE: this class will be removed soon in favor of SymmetricMemory
class TORCH_API IntraNodeComm : public c10::intrusive_ptr_target {
 public:
  IntraNodeComm(
@@ -102,8 +97,8 @@ class TORCH_API IntraNodeComm : public c10::intrusive_ptr_target {
   */
  bool isInitialized_ = false;
  Topology topology_ = Topology::UNKNOWN;
  void* symmetricMemoryPtr_ = nullptr;
  c10::intrusive_ptr<SymmetricMemory> symmetricMemory_ = nullptr;
  std::array<void*, kMaxDevices> p2pStates_{};
  std::array<void*, kMaxDevices> buffers_{};
  void* p2pStatesDev_{};
  void* buffersDev_{};
  void* topoInfo_{};