Revert "Support IPC for Expandable Segments (#130890)"

This reverts commit 32c2f84e349ad6e34b8559d3f1f9c27020ae702f.

Reverted https://github.com/pytorch/pytorch/pull/130890 on behalf of https://github.com/zdevito due to variable shadowing broke internal tests ([comment](https://github.com/pytorch/pytorch/pull/130890#issuecomment-2245456085))
PyTorch MergeBot
2024-07-23 14:46:27 +00:00
parent f064dac588
commit 1e86387871
4 changed files with 62 additions and 239 deletions
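The stated reason for the revert is that variable shadowing in the new code broke internal builds. The diff below contains at least one such pattern: in MemHandleCacheEntry's constructor, the `std::string& handle` parameter is shadowed by a local `cudaIpcMemHandle_t handle` inside the SHAREABLE_CUDA_MALLOC branch. Whether that is the exact instance that failed internally is not stated in the revert message; the following minimal sketch (not PyTorch code) only illustrates the pattern that -Wshadow with -Werror rejects.

```cpp
// Minimal sketch of the shadowing pattern; names are illustrative.
#include <string>

struct Entry {
  explicit Entry(const std::string& handle) {
    // warning: declaration of 'handle' shadows a parameter [-Wshadow]
    int handle = 0; // shadows the constructor parameter above
    (void)handle;
  }
};

int main() {
  Entry e{"abc"};
  (void)e;
  return 0;
}
```

The usual fix is simply renaming the inner variable.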

View File

@@ -16,7 +16,6 @@
#if !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED)
#include <c10/cuda/driver_api.h>
#include <sys/syscall.h>
#include <sys/types.h>
#include <unistd.h>
#endif
@@ -124,11 +123,6 @@ constexpr size_t kMinLargeAlloc =
10485760; // allocations between 1 and 10 MiB may use kLargeBuffer
constexpr size_t kRoundLarge = 2097152; // round up large allocations to 2 MiB
enum ShareableHandleType : char {
SHAREABLE_CUDA_MALLOC = 'c',
SHAREABLE_CUDA_EXPANDABLE_SEGMENT = 'e'
};
namespace {
using stream_set = ska::flat_hash_set<cuda::CUDAStream>;
@@ -388,7 +382,7 @@ Instead these mappings have to be done manually. The allocator now has an
struct ExpandableSegment {
ExpandableSegment(
c10::DeviceIndex device,
std::optional<cudaStream_t> stream,
cudaStream_t stream,
size_t address_space_size,
size_t segment_size,
std::vector<c10::DeviceIndex> peers)
@@ -426,7 +420,6 @@ struct ExpandableSegment {
CUmemGenericAllocationHandle handle = 0;
CUmemAllocationProp prop = {};
prop.type = CU_MEM_ALLOCATION_TYPE_PINNED;
prop.requestedHandleTypes = CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR;
prop.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
// NOLINTNEXTLINE(bugprone-signed-char-misuse)
prop.location.id = static_cast<int>(device_);
@@ -436,13 +429,13 @@ struct ExpandableSegment {
for (auto j : c10::irange(begin, i)) {
auto h = handles_.at(j).value();
handles_.at(j) = std::nullopt;
C10_CUDA_DRIVER_CHECK(DriverAPI::get()->cuMemRelease_(h.handle));
C10_CUDA_DRIVER_CHECK(DriverAPI::get()->cuMemRelease_(h));
}
trimHandles();
return rangeFromHandles(begin, begin);
}
C10_CUDA_DRIVER_CHECK(status);
handles_.at(i) = Handle{handle, std::nullopt};
handles_.at(i) = handle;
}
mapAndSetAccess(begin, end);
return rangeFromHandles(begin, end);
@@ -461,94 +454,10 @@ struct ExpandableSegment {
return rangeFromHandles(begin, end);
}
// Set up IPC sharing for the range.
// Returns the (larger) range that was actually shared.
// Serializes data to std::ostream that can be passed to the
// other process, and then restored as an expandable segment
// via ExpandableSegment::fromShared(istream);
SegmentRange share(SegmentRange range, std::ostream& buf) {
auto begin = segmentLeft(range.ptr);
auto end = segmentRight(range.ptr + range.size);
ShareHeader header{getpid(), segment_size_, end - begin};
buf.write((const char*)&header, sizeof(ShareHeader));
for (auto i : c10::irange(begin, end)) {
auto& handle = handles_.at(i).value();
if (!handle.fd) {
int fd = 0;
C10_CUDA_DRIVER_CHECK(DriverAPI::get()->cuMemExportToShareableHandle_(
&fd, handle.handle, CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR, 0));
handle.fd = fd;
}
int fd = *handle.fd;
buf.write((const char*)&fd, sizeof(int));
}
return rangeFromHandles(begin, end);
}
static std::unique_ptr<ExpandableSegment> fromShared(
c10::DeviceIndex device,
std::vector<c10::DeviceIndex> peers,
std::istream& buf) {
ShareHeader header{};
buf.read((char*)&header, sizeof(ShareHeader));
auto segment = std::make_unique<ExpandableSegment>(
device,
std::nullopt,
header.num_handles * header.segment_size,
header.segment_size,
std::move(peers));
// older build setups (e.g. multiwheels) do not define this syscall number (added in 2020),
// but the kernel on the system might still support it.
#ifndef SYS_pidfd_open
#define SYS_pidfd_open 434
#endif
#ifndef SYS_pidfd_getfd
#define SYS_pidfd_getfd 438
#endif
auto pidfd = syscall(SYS_pidfd_open, header.pid, 0);
TORCH_CHECK(
pidfd != -1 || errno != ENOSYS,
"The kernel on this machine does not support the pidfd_open syscall needed to use IPC for CUDA tensors when expandable_segments:True is set. "
"Consider using expandable_segments:False via torch.cuda.memory._set_allocator_settings('expandable_segments:False') for this allocation.");
TORCH_CHECK(pidfd != -1, "pidfd_open:", std::strerror(errno));
for (auto i : c10::irange(header.num_handles)) {
(void)i;
int fd = 0;
buf.read((char*)&fd, sizeof(int));
auto myfd = syscall(SYS_pidfd_getfd, pidfd, fd, 0);
if (myfd == -1) {
auto err = errno;
close((int)pidfd);
for (auto& h : segment->handles_) {
C10_CUDA_DRIVER_CHECK(
DriverAPI::get()->cuMemRelease_(h.value().handle));
h = std::nullopt;
}
TORCH_CHECK(
err != ENOSYS,
"The kernel on this machine does not support the pidfd_getfd syscall needed to use IPC for CUDA tensors when expandable_segments:True is set. "
"Consider using expandable_segments:False via torch.cuda.memory._set_allocator_settings('expandable_segments:False') for this allocation.");
TORCH_CHECK(false, "pidfd_getfd: ", std::strerror(err));
}
CUmemGenericAllocationHandle handle = 0;
C10_CUDA_DRIVER_CHECK(DriverAPI::get()->cuMemImportFromShareableHandle_(
&handle,
// NOLINTNEXTLINE(performance-no-int-to-ptr)
(void*)(uintptr_t)myfd,
CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR));
close((int)myfd);
segment->handles_.emplace_back(Handle{handle, std::nullopt});
}
close((int)pidfd);
segment->mapAndSetAccess(0, header.num_handles);
return segment;
}
char* ptr() const {
// NOLINTNEXTLINE(performance-no-int-to-ptr)
return reinterpret_cast<char*>(ptr_);
}
size_t size() const {
return max_handles_ * segment_size_;
}
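The hunk above removes ExpandableSegment::share()/fromShared(). fromShared() imports the exporter's POSIX file descriptors via the pidfd_open/pidfd_getfd syscalls before handing them to cuMemImportFromShareableHandle. Below is a minimal sketch of just that fd-transfer step, assuming Linux 5.6+ and ptrace-level access to the exporting process; import_remote_fd and its error handling are illustrative, not the allocator's code.

```cpp
// Illustrative sketch of the pidfd-based fd transfer relied on by the
// removed fromShared() path (assumes Linux >= 5.6 and ptrace access to pid).
#include <sys/syscall.h>
#include <sys/types.h>
#include <unistd.h>
#include <cerrno>

// Fallback numbers, mirroring the reverted code, for old build headers.
#ifndef SYS_pidfd_open
#define SYS_pidfd_open 434
#endif
#ifndef SYS_pidfd_getfd
#define SYS_pidfd_getfd 438
#endif

// Duplicate `remote_fd` (a descriptor number valid inside process `pid`)
// into the calling process; returns -1 and leaves errno set on failure.
static int import_remote_fd(pid_t pid, int remote_fd) {
  long pidfd = syscall(SYS_pidfd_open, pid, 0);
  if (pidfd == -1) {
    return -1; // ENOSYS means the kernel predates pidfd_open
  }
  long local_fd = syscall(SYS_pidfd_getfd, pidfd, remote_fd, 0);
  int saved_errno = errno;
  close(static_cast<int>(pidfd));
  errno = saved_errno;
  return static_cast<int>(local_fd);
}
```

The imported descriptor can then be handed to cuMemImportFromShareableHandle, as the removed loop above does, and closed afterwards.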
@@ -583,7 +492,7 @@ struct ExpandableSegment {
ptr_ + i * segment_size_,
segment_size_,
0,
handles_.at(i).value().handle,
handles_.at(i).value(),
0ULL));
}
setAccess(device_, begin, end);
@@ -600,21 +509,13 @@ struct ExpandableSegment {
// cannot call c10::cuda::stream_synchronize because
// it might grab the GIL which can lead to a deadlock
// Locking order must be GIL -> Allocator Lock
if (stream_) {
C10_CUDA_CHECK(cudaStreamSynchronize(*stream_));
} else {
cuda::CUDAGuard device_guard(device_);
C10_CUDA_CHECK(cudaDeviceSynchronize());
}
C10_CUDA_CHECK(cudaStreamSynchronize(stream_));
for (auto i : c10::irange(begin, end)) {
Handle h = handles_.at(i).value();
CUmemGenericAllocationHandle h = handles_.at(i).value();
handles_.at(i) = std::nullopt;
C10_CUDA_DRIVER_CHECK(DriverAPI::get()->cuMemUnmap_(
ptr_ + segment_size_ * i, segment_size_));
if (h.fd) {
close(*h.fd);
}
C10_CUDA_DRIVER_CHECK(DriverAPI::get()->cuMemRelease_(h.handle));
C10_CUDA_DRIVER_CHECK(DriverAPI::get()->cuMemRelease_(h));
}
trimHandles();
}
@@ -650,20 +551,11 @@ struct ExpandableSegment {
ptr() + segment_size_ * begin, segment_size_ * (end - begin));
}
c10::DeviceIndex device_;
std::optional<cudaStream_t> stream_;
cudaStream_t stream_;
CUdeviceptr ptr_{};
size_t segment_size_;
size_t max_handles_;
struct Handle {
CUmemGenericAllocationHandle handle;
std::optional<int> fd;
};
struct ShareHeader {
pid_t pid;
size_t segment_size;
size_t num_handles;
};
std::vector<std::optional<Handle>> handles_;
std::vector<std::optional<CUmemGenericAllocationHandle>> handles_;
// devices on which this memory should be mapped in addition
// to the device where the physical memory lives (device_).
std::vector<c10::DeviceIndex> peers_;
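For reference, the Handle and ShareHeader structs removed in the hunk above define a simple byte layout for a shared segment: a ShareHeader followed by one int file descriptor per physical handle, written by share() and consumed by fromShared(). A hedged sketch of a standalone reader for that layout follows; SharedSegmentDesc and readShared are illustrative names, not part of the allocator.

```cpp
// Illustrative reader for the byte layout the removed share() emits:
// [ShareHeader][int fd] * num_handles. Field meanings mirror the structs
// shown in the hunk above; this is not the allocator implementation.
#include <cstddef>
#include <istream>
#include <sys/types.h>
#include <vector>

struct ShareHeader {
  pid_t pid;           // exporting process, needed later for pidfd_getfd
  size_t segment_size; // bytes mapped by each physical handle
  size_t num_handles;  // number of file descriptors that follow
};

struct SharedSegmentDesc {
  ShareHeader header{};
  std::vector<int> exporter_fds; // fd numbers valid in the exporting process
};

// Assumes both processes run the same library/ABI, so raw struct reads match.
static SharedSegmentDesc readShared(std::istream& buf) {
  SharedSegmentDesc desc;
  buf.read(reinterpret_cast<char*>(&desc.header), sizeof(ShareHeader));
  desc.exporter_fds.resize(desc.header.num_handles);
  for (int& fd : desc.exporter_fds) {
    buf.read(reinterpret_cast<char*>(&fd), sizeof(int));
  }
  return desc;
}
```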
@@ -672,7 +564,7 @@ struct ExpandableSegment {
struct ExpandableSegment {
ExpandableSegment(
c10::DeviceIndex device,
std::optional<cudaStream_t> stream,
cudaStream_t stream,
size_t address_space_size,
size_t segment_size,
std::vector<c10::DeviceIndex> peers) {
@@ -684,15 +576,6 @@ struct ExpandableSegment {
SegmentRange unmap(SegmentRange range) {
return SegmentRange(nullptr, 0);
}
SegmentRange share(SegmentRange range, std::ostream& ss) {
return SegmentRange(nullptr, 0);
}
static std::unique_ptr<ExpandableSegment> fromShared(
c10::DeviceIndex device,
std::vector<c10::DeviceIndex> peers,
std::istream& buf) {
return {};
}
char* ptr() const {
return nullptr;
}
@@ -1545,26 +1428,14 @@ class DeviceCachingAllocator {
}
ShareableHandle shareIpcHandle(Block* block) {
std::lock_guard<std::recursive_mutex> lock(mutex);
std::ostringstream ss;
ptrdiff_t offset = 0;
if (!block->expandable_segment_) {
ss.put(SHAREABLE_CUDA_MALLOC);
Block* base_block = block;
while (base_block->prev) {
base_block = base_block->prev;
}
offset = (char*)block->ptr - (char*)base_block->ptr;
cudaIpcMemHandle_t handle;
C10_CUDA_CHECK(cudaIpcGetMemHandle(&handle, base_block->ptr));
ss.write((char*)&handle, CUDA_IPC_HANDLE_SIZE);
} else {
ss.put(SHAREABLE_CUDA_EXPANDABLE_SEGMENT);
auto full_range = block->expandable_segment_->share(
SegmentRange(block->ptr, block->size), ss);
offset = (char*)block->ptr - (char*)full_range.ptr;
}
return ShareableHandle{offset, ss.str()};
size_t outSize = 0;
void* base = getBaseAllocation(block, &outSize);
auto offset = (char*)block->ptr - (char*)base;
cudaIpcMemHandle_t handle;
C10_CUDA_CHECK(cudaIpcGetMemHandle(&handle, base));
return ShareableHandle{
offset,
std::string((char*)&handle, (char*)&handle + CUDA_IPC_HANDLE_SIZE)};
}
void recordStream(Block* block, cuda::CUDAStream stream) {
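The replacement shareIpcHandle above (the code being restored) falls back to the classic cudaMalloc path: walk to the base block, take a cudaIpcMemHandle_t for the base allocation, and describe the sub-block as a byte offset into it. A minimal sketch of that handle-plus-offset pattern with the public runtime API follows; error handling and the transport of the handle between processes are omitted, and the function names are illustrative.

```cpp
// Sketch of the handle + offset pattern used by the restored shareIpcHandle.
#include <cuda_runtime.h>
#include <cstddef>

// Exporter: the IPC handle always refers to the base cudaMalloc allocation,
// so a suballocation is described as (handle, byte offset).
void exportBlock(void* base, void* sub, cudaIpcMemHandle_t* handle,
                 ptrdiff_t* offset) {
  cudaIpcGetMemHandle(handle, base);
  *offset = static_cast<char*>(sub) - static_cast<char*>(base);
}

// Importer: open the base mapping, then apply the offset.
void* importBlock(cudaIpcMemHandle_t handle, ptrdiff_t offset) {
  void* base = nullptr;
  cudaIpcOpenMemHandle(&base, handle, cudaIpcMemLazyEnablePeerAccess);
  return static_cast<char*>(base) + offset;
}
```

As the comment further down notes, cudaIpcOpenMemHandle may be opened only once per handle per context, which is why the surrounding code caches the opened pointer.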
@@ -2105,7 +1976,6 @@ class DeviceCachingAllocator {
}
void addPeerAccess(c10::DeviceIndex dev_to_access) {
std::lock_guard<std::recursive_mutex> lock(mutex);
if (std::find(
devices_with_peer_access_.begin(),
devices_with_peer_access_.end(),
@@ -2117,10 +1987,6 @@ class DeviceCachingAllocator {
es->addPeer(dev_to_access);
}
}
std::vector<c10::DeviceIndex> peers() const {
std::lock_guard<std::recursive_mutex> lock(mutex);
return devices_with_peer_access_;
}
bool hasAllocatedExpandableSegments() const {
return !expandable_segments_.empty();
@@ -3531,13 +3397,6 @@ class NativeCachingAllocator : public CUDAAllocator {
C10_CUDA_CHECK(err);
}
device_allocator[dev_to_access]->addPeerAccess(dev);
std::lock_guard<std::mutex> lock(IpcMutex);
for (auto& entry : ipcMemHandle_to_devptr) {
if (entry.second->device_ == dev_to_access &&
entry.second->expandable_segment_) {
entry.second->expandable_segment_->addPeer(dev);
}
}
}
cudaError_t memcpyAsync(
@@ -3564,103 +3423,58 @@ class NativeCachingAllocator : public CUDAAllocator {
this->free(ptr);
}
// In CUDA IPC, sender sends a tensor to receiver via shareIPCHandle,
// getIpcDevPtr is called by the receiving process to map the CUDA memory from
// the sending process into its own address space.
// When allocated with cudaMalloc we use the cudaIPCMemHandle_t APIs.
// These APIs only allow sharing a big memory block associated with a
// In CUDA IPC, sender sends a tensor to receiver, getIpcDevPtr
// is called by the receiving process to map the CUDA memory from the sending
// process into its own address space.
//
// CUDA IPC only allows sharing a big memory block associated with a
// cudaIpcMemHandle_t and it can be opened only **once** per context per
// process. There can be multiple types of storage in the same IPC mem block,
// so we must cache the device ptr to construct typed storage as it comes.
// When using cuMemCreate, via expandable segments, we use
// cuMemExportToShareableHandle to create a file descriptor that can be sent
// to the other process to import the object. Then we recreate part of the
// expandable segment necessary to map the allocation.
// ipcMemHandle_to_devptr caches the mapping from shareable handle to
// this process' memory mapping information for that share to ensure we do not
// create it twice. When the shared_ptr is no longer in use we clean up the
// cache.
//
// ipcMemHandle_to_devptr maps a cudaIpcMemHandle_t to a device pointer in the
// process that can be used to access the memory block in the sender process.
// It only saves a weak_ptr of the device pointer in the map; the shared_ptr
// will be used to reconstruct all storages in this CudaMalloc allocation, and
// it will be deleted in cudaIpcCloseMemHandle when its reference count is 0.
//
std::mutex IpcMutex;
struct MemHandleCacheEntry {
MemHandleCacheEntry(
c10::DeviceIndex device,
std::string& handle,
const DeviceCachingAllocator& allocator)
: device_(device), cuda_ipc_ptr_(nullptr) {
std::istringstream ss(handle);
auto type = ss.get();
if (type == SHAREABLE_CUDA_MALLOC) {
cudaIpcMemHandle_t handle;
ss.read((char*)&handle, CUDA_IPC_HANDLE_SIZE);
C10_CUDA_CHECK(cudaIpcOpenMemHandle(
&cuda_ipc_ptr_, handle, cudaIpcMemLazyEnablePeerAccess));
} else if (type == SHAREABLE_CUDA_EXPANDABLE_SEGMENT) {
expandable_segment_ =
ExpandableSegment::fromShared(device, allocator.peers(), ss);
} else {
TORCH_INTERNAL_ASSERT(
false, "unexpected or illformed shareable handle type");
}
}
MemHandleCacheEntry(const MemHandleCacheEntry&) = delete;
MemHandleCacheEntry& operator=(const MemHandleCacheEntry&) = delete;
~MemHandleCacheEntry() {
if (cuda_ipc_ptr_) {
cuda::CUDAGuard device_guard(device_);
C10_CUDA_CHECK(cudaIpcCloseMemHandle(cuda_ipc_ptr_));
}
}
void* ptr() {
if (cuda_ipc_ptr_) {
return cuda_ipc_ptr_;
} else {
return expandable_segment_->ptr();
}
}
c10::DeviceIndex device_;
std::unique_ptr<ExpandableSegment> expandable_segment_;
void* cuda_ipc_ptr_; // nullptr if expandable_segment_ is not null
std::weak_ptr<void> wp_;
};
ska::flat_hash_map<std::string, std::unique_ptr<MemHandleCacheEntry>>
ipcMemHandle_to_devptr;
ska::flat_hash_map<std::string, std::weak_ptr<void>> ipcMemHandle_to_devptr;
std::shared_ptr<void> getIpcDevPtr(std::string handle) override {
std::lock_guard<std::mutex> lock(IpcMutex);
auto iter = ipcMemHandle_to_devptr.find(handle);
if (iter != ipcMemHandle_to_devptr.end()) {
auto devptr = iter->second->wp_.lock();
// the weak_ptr should always be valid because we delete the entry from
// the cache when the shared_ptr is destructed, so we should never get
// here.
TORCH_INTERNAL_ASSERT(devptr, "entry in cache has missing shared_ptr");
return devptr;
auto devptr = iter->second.lock();
if (devptr)
return devptr;
}
// This ipcMemHandle hasn't been opened, or already expired, open it to
// enable IPC access to that mem block.
void* dev = nullptr;
auto ipc_handle =
reinterpret_cast<const cudaIpcMemHandle_t*>(handle.c_str());
C10_CUDA_CHECK(cudaIpcOpenMemHandle(
&dev, *ipc_handle, cudaIpcMemLazyEnablePeerAccess));
// devPtr has to be deleted on the same device it was created on.
c10::DeviceIndex curr_device = 0;
C10_CUDA_CHECK(c10::cuda::GetDevice(&curr_device));
auto inserted = ipcMemHandle_to_devptr.insert(
iter,
{handle,
std::make_unique<MemHandleCacheEntry>(
curr_device, handle, *device_allocator[curr_device])});
auto sp = std::shared_ptr<void>(
inserted->second->ptr(), [handle, this](void* ptr) {
auto sp =
std::shared_ptr<void>(dev, [handle, curr_device, this](void* ptr) {
cuda::CUDAGuard device_guard(curr_device);
std::lock_guard<std::mutex> deleter_lock(IpcMutex);
C10_CUDA_CHECK(cudaIpcCloseMemHandle(ptr));
ipcMemHandle_to_devptr.erase(handle);
});
inserted->second->wp_ = sp;
std::weak_ptr<void> wp = sp;
// To eliminate an additional search, we can use insert().
// It doesn't overwrite when the key already exists (ptr expired),
// but since the deleter for sp erases the entry,
// this should be safe to do now.
ipcMemHandle_to_devptr.insert(iter, {handle, wp});
return sp;
}
std::string name() override {
return "native";
}
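The restored getIpcDevPtr keeps only a weak_ptr per handle and relies on the shared_ptr's deleter to close the mapping and erase the cache entry. Below is a generic sketch of that caching pattern, with acquire/release as stand-ins for cudaIpcOpenMemHandle/cudaIpcCloseMemHandle; HandleCache and its members are illustrative names, not the allocator's.

```cpp
// Generic sketch of the weak_ptr cache + self-erasing deleter pattern.
#include <memory>
#include <mutex>
#include <string>
#include <unordered_map>

class HandleCache {
 public:
  std::shared_ptr<void> open(const std::string& key) {
    std::lock_guard<std::mutex> lock(mu_);
    if (auto it = cache_.find(key); it != cache_.end()) {
      if (auto sp = it->second.lock()) {
        return sp; // still alive: reuse the existing mapping
      }
    }
    void* resource = acquire(key); // e.g. cudaIpcOpenMemHandle
    std::shared_ptr<void> sp(resource, [this, key](void* p) {
      std::lock_guard<std::mutex> deleter_lock(mu_);
      release(p);        // e.g. cudaIpcCloseMemHandle
      cache_.erase(key); // drop the expired weak_ptr entry
    });
    cache_[key] = sp; // store only a weak reference
    return sp;
  }

 private:
  void* acquire(const std::string&) { return new int(0); } // stand-in
  void release(void* p) { delete static_cast<int*>(p); }   // stand-in
  std::mutex mu_;
  std::unordered_map<std::string, std::weak_ptr<void>> cache_;
};
```

Using operator[] here sidesteps the insert-vs-overwrite question the in-code comment discusses; the restored code uses insert() with a hint instead.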
@@ -3749,5 +3563,7 @@ struct BackendStaticInitializer {
std::atomic<CUDAAllocator*> allocator;
BackendStaticInitializer backend_static_initializer;
} // namespace cuda::CUDACachingAllocator
} // namespace c10

View File

@@ -98,6 +98,9 @@ TEST_CUDA_IPC = (
TEST_MULTIGPU = TEST_CUDA_IPC and torch.cuda.device_count() > 1
if TEST_CUDA_IPC:
torch.cuda.memory._set_allocator_settings("expandable_segments:False")
if not NO_MULTIPROCESSING_SPAWN:
# We want to use `spawn` if able because some of our tests check that the
# data loader terminates gracefully. To prevent hanging in the testing

View File

@@ -48,6 +48,9 @@ TEST_CUDA_IPC = (
TEST_MULTIGPU = TEST_CUDA_IPC and torch.cuda.device_count() > 1
if TEST_CUDA_IPC:
torch.cuda.memory._set_allocator_settings("expandable_segments:False")
class SubProcess(mp.Process):
def __init__(self, tensor):

View File

@@ -421,6 +421,7 @@ static std::string THPStorage_bytesAsHandleString(PyObject* handle) {
if (PyBytes_AsStringAndSize(handle, &buffer, &handle_size) == -1) {
TORCH_CHECK(handle_size == CUDA_IPC_HANDLE_SIZE, "incorrect handle");
}
TORCH_CHECK(handle_size == CUDA_IPC_HANDLE_SIZE, "incorrect handle size");
return std::string(buffer, handle_size);
END_HANDLE_TH_ERRORS_RET("")
}
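In the last hunk, the revert restores a size check after the existing error branch in THPStorage_bytesAsHandleString. The pre-existing TORCH_CHECK sits inside the == -1 failure branch, so it never runs when PyBytes_AsStringAndSize succeeds; the restored line validates the length on the success path as well. A small sketch of the same control-flow point, using a plain std::string in place of the PyBytes object (not the PyTorch source):

```cpp
// Sketch of why the unconditional check is needed: a check placed inside the
// failure branch never runs when the call succeeds.
#include <stdexcept>
#include <string>

constexpr size_t kIpcHandleSize = 64; // stand-in for CUDA_IPC_HANDLE_SIZE

std::string bytesAsHandleString(const std::string& handle) {
  // In the real code PyBytes_AsStringAndSize() can fail; a check inside that
  // failure branch cannot guard the success path.
  if (handle.empty()) {
    throw std::runtime_error("incorrect handle");
  }
  // Unconditional size check, analogous to the TORCH_CHECK the revert restores.
  if (handle.size() != kIpcHandleSize) {
    throw std::runtime_error("incorrect handle size");
  }
  return handle;
}
```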