Implement reference counting for shared IPC CUDA tensors (#16854)

Summary:
This is to fix #16141 and similar issues.

The idea is to track a reference to every shared CUDA Storage and to deallocate the memory only after the consumer process deallocates the received Storage.

ezyang Done with cleanup. Same (marginally better) performance as the file-per-share solution, but it handles millions of shared tensors easily. Note [ ] documentation in progress.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/16854

Differential Revision: D13994490

Pulled By: VitalyFedyunin

fbshipit-source-id: 565148ec3ac4fafb32d37fde0486b325bed6fbd1
This commit is contained in:
Vitaly Fedyunin
2019-03-25 10:18:29 -07:00
committed by Facebook Github Bot
parent f5ea528687
commit 5653a914f7
15 changed files with 832 additions and 79 deletions

View File

@ -19,6 +19,7 @@ struct C10_API StorageImpl final : public c10::intrusive_ptr_target {
data_ptr_(std::move(data_ptr)),
numel_(numel),
resizable_(resizable),
received_cuda_(false),
allocator_(allocator) {
if (resizable) {
AT_ASSERTM(
@ -210,11 +211,24 @@ struct C10_API StorageImpl final : public c10::intrusive_ptr_target {
resizable_ = false;
}
// This method can be used only after storage construction and must not be used
// to modify the storage status afterwards
void set_received_cuda(bool received_cuda) {
received_cuda_ = received_cuda;
}
bool received_cuda() {
return received_cuda_;
}
private:
caffe2::TypeMeta data_type_;
DataPtr data_ptr_;
int64_t numel_;
bool resizable_;
// Indicates that the Storage was received from another process and does not have
// a process-local CUDA memory allocation
bool received_cuda_;
Allocator* allocator_;
};
} // namespace c10

View File

@ -16,8 +16,10 @@
#include <vector>
namespace c10 {
namespace cuda {
C10_DEFINE_REGISTRY(FreeCudaMemoryCallbacksRegistry, FreeMemoryCallback);
namespace cuda {
namespace CUDACachingAllocator {
//
@ -47,6 +49,8 @@ namespace CUDACachingAllocator {
// work.
//
namespace {
using stream_set = std::unordered_set<cuda::CUDAStream>;
@ -154,7 +158,7 @@ struct THCCachingAllocator
std::vector<DeviceStats> device_stats;
// lock around all operations
std::mutex mutex;
std::recursive_mutex mutex;
// lock around calls to cudaFree (to prevent deadlocks with NCCL)
std::mutex cuda_free_mutex;
@ -186,7 +190,7 @@ struct THCCachingAllocator
/** allocates a block which is safe to use from the provided stream */
void malloc(void** devPtr, size_t size, cudaStream_t stream)
{
std::lock_guard<std::mutex> lock(mutex);
std::lock_guard<std::recursive_mutex> lock(mutex);
int device;
C10_CUDA_CHECK(cudaGetDevice(&device));
@ -201,14 +205,29 @@ struct THCCachingAllocator
Block search_key(device, stream, size);
auto& pool = get_pool(size);
Block* block = nullptr;
Block* remaining = nullptr;
auto find_free_block = [&]()->Block*{
auto it = pool.lower_bound(&search_key);
if (it != pool.end() && (*it)->device == device &&
(*it)->stream == stream) {
Block* block = *it;
pool.erase(it);
return block;
}
return nullptr;
};
auto it = pool.lower_bound(&search_key);
if (it != pool.end() && (*it)->device == device && (*it)->stream == stream) {
block = *it;
pool.erase(it);
} else {
Block* block = find_free_block();
if (block == nullptr) {
bool freed_memory = false;
for (const auto& name : FreeCudaMemoryCallbacksRegistry()->Keys()) {
freed_memory |=
FreeCudaMemoryCallbacksRegistry()->Create(name)->Execute();
}
if (freed_memory) {
block = find_free_block();
}
}
if (block == nullptr) {
void* ptr;
size_t alloc_size = get_allocation_size(size);
cudaError_t err = cuda_malloc_retry(device, &ptr, alloc_size);
@ -253,8 +272,10 @@ struct THCCachingAllocator
block = new Block(device, stream, alloc_size, &pool, ptr);
}
Block* remaining = nullptr;
AT_ASSERT(block);
if (should_split(block, size)) {
remaining = block;
block = new Block(device, stream, size, &pool, block->ptr);
@ -280,7 +301,7 @@ struct THCCachingAllocator
void free(void* ptr)
{
std::lock_guard<std::mutex> lock(mutex);
std::lock_guard<std::recursive_mutex> lock(mutex);
if (!ptr) {
return;
}
@ -305,14 +326,14 @@ struct THCCachingAllocator
/** returns cached blocks to the system allocator */
void emptyCache()
{
std::lock_guard<std::mutex> lock(mutex);
std::lock_guard<std::recursive_mutex> lock(mutex);
free_blocks(large_blocks, large_blocks.begin(), large_blocks.end());
free_blocks(small_blocks, small_blocks.begin(), small_blocks.end());
}
void* getBaseAllocation(void* ptr, size_t* outSize)
{
std::lock_guard<std::mutex> lock(mutex);
std::lock_guard<std::recursive_mutex> lock(mutex);
Block* block = find_allocated_block(ptr);
if (!block) {
AT_ERROR("invalid device pointer: %p", ptr);
@ -348,14 +369,14 @@ struct THCCachingAllocator
void cacheInfo(int dev_id, size_t* total, size_t* largest)
{
std::lock_guard<std::mutex> lock(mutex);
std::lock_guard<std::recursive_mutex> lock(mutex);
cacheInfoAux(large_blocks, dev_id, total, largest);
cacheInfoAux(small_blocks, dev_id, total, largest);
}
void recordStream(void* ptr, cuda::CUDAStream stream)
{
std::lock_guard<std::mutex> lock(mutex);
std::lock_guard<std::recursive_mutex> lock(mutex);
Block* block = find_allocated_block(ptr);
if (!block) {
AT_ERROR("invalid device pointer: %p", ptr);

View File

@ -4,10 +4,24 @@
#include <c10/cuda/CUDAStream.h>
#include <c10/core/Allocator.h>
#include <c10/cuda/CUDAMacros.h>
#include <c10/util/Registry.h>
#include <mutex>
namespace c10 {
// The caching allocator will execute every registered callback if it is unable to
// find a free block inside the already allocated area.
class C10_CUDA_API FreeMemoryCallback {
public:
virtual ~FreeMemoryCallback() {};
virtual bool Execute() = 0;
};
C10_DECLARE_REGISTRY(FreeCudaMemoryCallbacksRegistry, FreeMemoryCallback);
#define REGISTER_FREE_MEMORY_CALLBACK(name, ...) \
C10_REGISTER_CLASS(FreeCudaMemoryCallbacksRegistry, name, __VA_ARGS__);
namespace cuda {
// TODO: Turn this into an honest to goodness class. I briefly attempted to do

View File

@ -28,57 +28,65 @@ Python 2 can only create subprocesses using ``fork``, and it's not supported
by the CUDA runtime.
Unlike CPU tensors, the sending process is required to keep the original tensor
as long as the receiving process retains a copy of the tensor.
This shouldn't be a problem for sharing model parameters (which stay live
for the entire execution of the model), but passing other
kinds of data should be done with care.
as long as the receiving process retains a copy of the tensor. The ref counting is
implemented under the hood, but it requires users to follow the best practices below.
Here is an example program which handles these requirements correctly:
1. Release memory ASAP in the consumer.
::
import torch
import torch.multiprocessing as mp
## Good
x = queue.get()
# do something with x
del x
torch.set_default_tensor_type(torch.cuda.FloatTensor)
::
def sender(q, e):
for i in range(10):
s_sample = [torch.zeros(1), torch.ones(1)]
q.put(s_sample)
e.wait()
del s_sample
e.clear()
## Bad
x = queue.get()
# do something with x
# do everything else (the producer has to keep x in memory)
if __name__ == "__main__":
ctx = mp.get_context("spawn")
q = ctx.Queue()
e = ctx.Event()
p = ctx.Process(target=sender, args=(q, e))
p.start()
2. Keep the producer process running until all consumers exit. This will prevent
the situation where the producer process releases memory which is still in use
by the consumer.
for i in range(10):
print('=== ITER {} ==='.format(i))
r_sample = q.get()
del r_sample
e.set()
::
p.join()
## producer
# send tensors, do something
event.wait()
In the example above, calling `e.wait()`
on the sender side ensures the tensor `s_sample` doesn't get deleted while the
receiver is working on it. The receiver signals when it is done
with the tensor using `e.set()`, being careful to `del` its reference
to the received tensor first. It is INSUFFICIENT to promise never to call
`r_sample` again; while `r_sample` is live, it may be confused with
any subsequent tensors allocated by the source process at the same address.
If a receiver wants to save the data of `r_sample` for future use while
letting the source process deallocate the original, it must
`clone()` it.
::
## consumer
# receive tensors and use them
event.set()
3. Don't pass received tensors onward.
::
# not going to work
x = queue.get()
queue_2.put(x)
::
# you need to create a process-local copy
x = queue.get()
x_clone = x.clone()
queue_2.put(x_clone)
::
# putting into and getting from the same queue in the same process will likely result in a segfault
queue.put(tensor)
x = queue.get()
This behavior is very confusing, and we are tracking a fix for it
at https://github.com/pytorch/pytorch/issues/16141
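For reference, here is a minimal end-to-end sketch that combines the first two
practices above (the process and variable names are illustrative)::

    import torch
    import torch.multiprocessing as mp

    def consumer(queue, event, count):
        for _ in range(count):
            t = queue.get()
            # ... use t, or t.clone() if the data must outlive the producer ...
            del t          # practice 1: release memory as soon as possible
        event.set()        # tell the producer it is safe to exit

    if __name__ == "__main__":
        ctx = mp.get_context("spawn")
        q, e = ctx.Queue(), ctx.Event()
        p = ctx.Process(target=consumer, args=(q, e, 10))
        p.start()
        for i in range(10):
            t = torch.ones(5).cuda() * i
            q.put(t)
            del t          # the producer may drop its reference right away
        e.wait()           # practice 2: stay alive until the consumer is done
        p.join()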
Sharing strategies
------------------

View File

@ -12,7 +12,7 @@ import torch.multiprocessing as mp
import torch.utils.hooks
from torch.nn import Parameter
from common_utils import (TestCase, run_tests, IS_WINDOWS, NO_MULTIPROCESSING_SPAWN, TEST_WITH_ASAN,
load_tests)
load_tests, slowTest)
from multiprocessing.reduction import ForkingPickler
# load_tests from common_utils is used to automatically filter tests for
@ -56,6 +56,30 @@ def send_tensor(queue, event, tp):
event.wait()
def send_and_delete_tensors(queue, event, tp, count, size=5):
for i in range(count):
t = torch.full([size], i).type(tp)
queue.put(t)
del t
event.wait()
def receive_and_send_sum(queue, out_queue, event, tp, count, size=5):
s = torch.full([size], 0).type(tp)
for i in range(count):
t = queue.get()
s += t
out_queue.put(s)
event.wait()
def receive_and_send(queue, out_queue, event, count):
for i in range(count):
t = queue.get()
out_queue.put(t.clone())
event.wait()
def call_backward():
x = torch.randn(3, 3, requires_grad=True)
x.sum().backward()
@ -150,6 +174,8 @@ class leak_checker(object):
return self
def __exit__(self, *args):
if torch.cuda.is_available():
torch.cuda.ipc_collect()
if args[0] is None:
# Check that the 10th available file-descriptor at the end of the
# test is no more than 4 higher than the 10th available at the
@ -193,6 +219,11 @@ class leak_checker(object):
class TestMultiprocessing(TestCase):
def tearDown(self):
# This will keep tests isolated from each-other
if torch.cuda.is_available():
torch.cuda.ipc_collect()
def _test_sharing(self, ctx=mp, type=torch.FloatTensor, repeat=1):
def test_fill():
x = torch.zeros(5, 5).type(type)
@ -222,6 +253,9 @@ class TestMultiprocessing(TestCase):
t2 = q.get()
self.assertTrue(t1.eq(1).all())
self.assertTrue(id(t1.storage()) == id(t2.storage()))
# We need to delete these tensors to allow the producer (child process)
# to collect them properly
del t1, t2
e.set()
p.join(1)
self.assertFalse(p.is_alive())
@ -322,10 +356,55 @@ class TestMultiprocessing(TestCase):
@unittest.skipIf(NO_MULTIPROCESSING_SPAWN, "Disabled for environments that \
don't support multiprocessing with spawn start method")
@unittest.skipIf(not TEST_CUDA_IPC, 'CUDA IPC not available')
def test_cuda(self):
def test_cuda_simple(self):
torch.cuda.FloatTensor([1]) # initialize CUDA outside of leak checker
self._test_sharing(mp.get_context('spawn'), torch.cuda.FloatTensor)
@unittest.skipIf(NO_MULTIPROCESSING_SPAWN, "Disabled for environments that \
don't support multiprocessing with spawn start method")
@unittest.skipIf(not TEST_CUDA_IPC, 'CUDA IPC not available')
def test_cuda_memory_allocation(self):
ctx = mp.get_context('spawn')
q = ctx.Queue()
e = ctx.Event()
p = ctx.Process(target=send_and_delete_tensors, args=(q, e, torch.cuda.IntTensor, 5))
p.start()
t = []
for _ in range(5):
t.append(q.get())
self.assertEqual(t[0], torch.full([5], 0))
del t
e.set()
p.join(1)
@slowTest
@unittest.skipIf(NO_MULTIPROCESSING_SPAWN, "Disabled for environments that \
don't support multiprocessing with spawn start method")
@unittest.skipIf(not TEST_CUDA_IPC, 'CUDA IPC not available')
def test_cuda_send_many(self, name=None, size=5, count=100000):
ctx = mp.get_context('spawn')
q1 = ctx.Queue()
q2 = ctx.Queue()
q3 = ctx.Queue()
e1 = ctx.Event()
e2 = ctx.Event()
e3 = ctx.Event()
p1 = ctx.Process(target=send_and_delete_tensors, args=(q1, e1, torch.cuda.LongTensor, count, size))
p2 = ctx.Process(target=receive_and_send, args=(q1, q2, e2, count))
p3 = ctx.Process(target=receive_and_send_sum, args=(q2, q3, e3, torch.cuda.LongTensor, count, size))
p1.start()
p2.start()
p3.start()
result = q3.get()
self.assertEqual(result[0], int(count * (count - 1) / 2))
del result
e1.set()
e2.set()
e3.set()
p1.join(1)
p2.join(1)
p3.join(1)
@unittest.skipIf(NO_MULTIPROCESSING_SPAWN, "Disabled for environments that \
don't support multiprocessing with spawn start method")
@unittest.skipIf(not TEST_CUDA_IPC, 'CUDA IPC not available')
@ -355,6 +434,7 @@ class TestMultiprocessing(TestCase):
self.assertEqual(v, torch.arange(i * 5., (i + 1) * 5).sum())
self.assertEqual(device, i % 2)
self.assertEqual(tensor_size, 5)
# You might think this should be the case, but it's not! After
# data from the CUDA caching allocator goes through IPC, the
# size of the storage is the size of the *cached cudaMalloc for
@ -363,6 +443,15 @@ class TestMultiprocessing(TestCase):
#
# self.assertEqual(storage_size, 5)
# Collect the current process' (producer) files; make sure nothing holds
# a ref to the sent tensors
del _tensor
del tensors
# We need to collect, as the CUDA MP implementation holds one shared
# memory 'file' open for performance reasons
torch.cuda.ipc_collect()
@unittest.skipIf(IS_WINDOWS, 'not applicable to Windows (only fails with fork)')
@unittest.skipIf(not torch.cuda.is_available(), 'CUDA not available')
def test_cuda_bad_call(self):

View File

@ -489,6 +489,7 @@ if (BUILD_PYTHON)
endif()
set(TORCH_PYTHON_SRCS
${TORCH_SRC_DIR}/csrc/CudaIPCTypes.cpp
${TORCH_SRC_DIR}/csrc/DataLoader.cpp
${TORCH_SRC_DIR}/csrc/Device.cpp
${TORCH_SRC_DIR}/csrc/Dtype.cpp

240
torch/csrc/CudaIPCTypes.cpp Normal file
View File

@ -0,0 +1,240 @@
#ifdef USE_CUDA
#include <torch/csrc/CudaIPCTypes.h>
#include <TH/THAllocator.h>
#include <map>
#include <mutex>
#include <random>
#ifdef _MSC_VER
#include <windows.h>
#else
#include <sys/types.h>
#include <unistd.h>
#endif
namespace torch {
namespace {
void warnProducerTerminatedBeforeSharedTensorsReleased() {
static bool warned = false;
if (!warned) {
LOG(WARNING)
<< "Producer process has been terminated before all shared CUDA tensors released. See Note [Sharing CUDA tensors]";
warned = true;
}
}
struct CudaIPCGlobalEntities {
std::mutex ref_counters_mutex_;
std::atomic<int64_t> sync_events_used_;
std::map<std::string, std::shared_ptr<CudaIPCRefCountersFile>>
ref_counters_files_;
std::shared_ptr<CudaIPCRefCountersFile> next_available_ref_counters_file_;
CudaIPCSentDataLimbo CudaIPCSentDataLimbo_;
CudaIPCGlobalEntities() : ref_counters_files_() {}
~CudaIPCGlobalEntities() {
CudaIPCSentDataLimbo_.collect();
safe_clean_current_file();
if (next_available_ref_counters_file_) {
warnProducerTerminatedBeforeSharedTensorsReleased();
}
}
void safe_clean_current_file() {
std::lock_guard<std::mutex> lock(ref_counters_mutex_);
if (next_available_ref_counters_file_ &&
next_available_ref_counters_file_->offsets_in_use() == 0) {
ref_counters_files_.erase(next_available_ref_counters_file_->handle());
next_available_ref_counters_file_.reset();
}
}
};
CudaIPCGlobalEntities cuda_ipc_global_entities;
CudaIPCSentDataLimbo::~CudaIPCSentDataLimbo() {
collect();
if (shared_blocks_.size() > 0) {
warnProducerTerminatedBeforeSharedTensorsReleased();
}
}
bool CudaIPCSentDataLimbo::collect() {
bool freed_memory = false;
std::lock_guard<std::mutex> lock(limbo_mutex_);
std::vector<std::unique_ptr<CudaIPCSentData>> kept_blocks;
for (auto& sd : shared_blocks_) {
if (sd->counter_value() > 0) {
kept_blocks.push_back(std::move(sd));
} else {
freed_memory = true;
sd.reset();
}
}
shared_blocks_ = std::move(kept_blocks);
return freed_memory;
}
void CudaIPCSentDataLimbo::add(std::unique_ptr<CudaIPCSentData> shared_block) {
std::lock_guard<std::mutex> lock(limbo_mutex_);
static bool warned = false;
if (shared_blocks_.size() > CUDA_IPC_WARN_AFTER_X_BLOCKS_IN_LIMBO &&
!warned) {
LOG(WARNING)
<< "Producer process tried to deallocate over "
<< CUDA_IPC_WARN_AFTER_X_BLOCKS_IN_LIMBO
<< " memory blocks referred by consumer processes. Deallocation might be significantly slowed down. "
<< "We assume it will never going to be the case, but if it is, please file but to https://github.com/pytorch/pytorch";
warned = true;
}
shared_blocks_.push_back(std::move(shared_block));
}
void CudaIPCSentDataDelete(void* ptr) {
std::unique_ptr<CudaIPCSentData> sent_data(
static_cast<CudaIPCSentData*>(ptr));
if (sent_data->counter_value() > 0) {
cuda_ipc_global_entities.CudaIPCSentDataLimbo_.add(std::move(sent_data));
}
cuda_ipc_global_entities.CudaIPCSentDataLimbo_.collect();
}
void ReturnRefCounter(const std::string& handle, uint64_t offset) {
std::lock_guard<std::mutex> lock(
cuda_ipc_global_entities.ref_counters_mutex_);
cuda_ipc_global_entities.ref_counters_files_[handle]->return_offset(offset);
if (cuda_ipc_global_entities.ref_counters_files_[handle]->offsets_in_use() ==
0 &&
!cuda_ipc_global_entities.ref_counters_files_[handle]->have_offsets()) {
cuda_ipc_global_entities.ref_counters_files_.erase(handle);
}
}
} // namespace
CudaIPCSentData::CudaIPCSentData(
std::string handle,
int64_t offset,
int64_t* counter_ptr,
at::Device device)
: handle_(handle),
offset_(offset),
counter_ptr_(counter_ptr),
original_ptr_(),
device_(device) {
#ifndef __HIP_PLATFORM_HCC__
// CUDA has an unofficial limit on the number of recorded blocking interprocess
// events. To avoid exhausting all events, we switch to stream synchronization
// before the limit is reached.
//
// ```python
// import torch
// a = [ torch.cuda.Event(
// enable_timing=False, blocking=True, interprocess=True) for i in range(30000) ]
// [i.record() for i in a]
// ```
//
if (cuda_ipc_global_entities.sync_events_used_.load() < CUDA_IPC_MAXIMUM_EVENTS_TO_USE) {
// TODO: It would be more efficient to create the event inside the main thread
// (at the moment of the queue.put). The reason this is more efficient is that
// the main thread may have queued extra work on the stream, which this event
// would consequently wait for (uselessly).
cuda_ipc_global_entities.sync_events_used_ ++;
C10_CUDA_CHECK(cudaEventCreateWithFlags(
&event_,
cudaEventDisableTiming | cudaEventInterprocess |
cudaEventBlockingSync));
C10_CUDA_CHECK(cudaEventRecord(
event_, c10::cuda::getCurrentCUDAStream(device.index())));
event_sync_required_ = true;
} else {
auto stream = c10::cuda::getCurrentCUDAStream(device.index());
C10_CUDA_CHECK(cudaStreamSynchronize(stream));
event_sync_required_ = false;
}
#else
// cuIpcGetEventHandle is not supported with HIP, so we have to sync the
// stream instead of passing an event
auto stream = c10::cuda::getCurrentCUDAStream(device.index());
C10_CUDA_CHECK(cudaStreamSynchronize(stream));
event_sync_required_ = false;
#endif
}
CudaIPCSentData::~CudaIPCSentData() {
ReturnRefCounter(handle_, offset_);
#ifndef __HIP_PLATFORM_HCC__
try {
if (event_sync_required_) {
at::cuda::CUDAGuard device_guard(device_.index());
cudaEventDestroy(event_);
cuda_ipc_global_entities.sync_events_used_ --;
}
} catch (...) { /* No throw */
}
#endif
}
int64_t CudaIPCSentData::counter_value() {
return *counter_ptr_;
}
at::DataPtr GetNewRefCountedSentData(void* data, at::Device device) {
{
std::lock_guard<std::mutex> lock(
cuda_ipc_global_entities.ref_counters_mutex_);
if (!cuda_ipc_global_entities.next_available_ref_counters_file_) {
static std::random_device rd;
std::string ref_counter_handle = "/torch_";
#ifdef _MSC_VER
ref_counter_handle += std::to_string(GetCurrentProcessId());
#else
ref_counter_handle += std::to_string(getpid());
#endif
ref_counter_handle += "_";
ref_counter_handle += std::to_string(rd());
int flags = TH_ALLOCATOR_MAPPED_SHAREDMEM | TH_ALLOCATOR_MAPPED_EXCLUSIVE;
at::DataPtr sptr = THRefcountedMapAllocator::makeDataPtr(
ref_counter_handle.c_str(),
flags,
sizeof(int64_t) * CUDA_IPC_REF_COUNTER_FILE_SIZE,
nullptr);
auto rc = std::make_shared<CudaIPCRefCountersFile>(
ref_counter_handle, CUDA_IPC_REF_COUNTER_FILE_SIZE, std::move(sptr));
cuda_ipc_global_entities.ref_counters_files_[ref_counter_handle] = rc;
cuda_ipc_global_entities.next_available_ref_counters_file_ = rc;
}
}
cuda_ipc_global_entities.next_available_ref_counters_file_->set_counter(1);
auto sent_data = new CudaIPCSentData(
cuda_ipc_global_entities.next_available_ref_counters_file_->handle(),
cuda_ipc_global_entities.next_available_ref_counters_file_->get_offset(),
cuda_ipc_global_entities.next_available_ref_counters_file_->counter_ptr(),
device);
cuda_ipc_global_entities.next_available_ref_counters_file_->rotate_offset();
if (!cuda_ipc_global_entities.next_available_ref_counters_file_
->have_offsets()) {
cuda_ipc_global_entities.next_available_ref_counters_file_.reset();
}
return at::DataPtr(data, sent_data, CudaIPCSentDataDelete, device);
}
bool CudaIPCCollect() {
bool freed_memory = cuda_ipc_global_entities.CudaIPCSentDataLimbo_.collect();
if (cuda_ipc_global_entities.CudaIPCSentDataLimbo_.size() == 0) {
cuda_ipc_global_entities.safe_clean_current_file();
}
return freed_memory;
}
} // namespace torch
namespace c10 {
namespace {
REGISTER_FREE_MEMORY_CALLBACK("cuda_ipc_collect", CudaIPCCollectCallback);
}
} // namespace c10
#endif

146
torch/csrc/CudaIPCTypes.h Normal file
View File

@ -0,0 +1,146 @@
#pragma once
#ifdef USE_CUDA
#include <c10/core/Allocator.h>
#include <c10/cuda/CUDACachingAllocator.h>
#include <c10/cuda/CUDAException.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>
#include <c10/util/Logging.h>
#include <cuda_runtime_api.h>
#include <cstddef>
namespace torch {
bool CudaIPCCollect();
struct CudaIPCReceivedData final {
explicit CudaIPCReceivedData(std::shared_ptr<void> shared_ptr)
: shared_ptr_(std::move(shared_ptr)) {}
std::shared_ptr<void> shared_ptr_;
};
struct CudaIPCSentData final {
std::string handle_;
int64_t offset_;
int64_t* counter_ptr_; // Reference counter shared memory block
at::DataPtr original_ptr_; // Original mem allocation
cudaEvent_t event_; // Sync cuEventDestroy
bool event_sync_required_;
at::Device device_;
CudaIPCSentData(
std::string handle,
int64_t offset,
int64_t* counter_ptr,
at::Device device);
~CudaIPCSentData();
int64_t counter_value();
std::string handle() {
return handle_;
}
int64_t offset() {
return offset_;
}
void set_original_ptr(at::DataPtr data_ptr) {
original_ptr_ = std::move(data_ptr);
}
};
at::DataPtr GetNewRefCountedSentData(void* data, at::Device device);
namespace {
constexpr int64_t CUDA_IPC_REF_COUNTER_FILE_SIZE = 10000;
constexpr int64_t CUDA_IPC_WARN_AFTER_X_BLOCKS_IN_LIMBO = 1000;
// It was determined empirically that CUDA (v10.1 and below) has a limit on the
// number of recorded blocking interprocess events; it is around ~22,000.
// To give us leeway, we picked 1000, as it gives us enough events to share
// tensors effectively.
constexpr int64_t CUDA_IPC_MAXIMUM_EVENTS_TO_USE = 1000;
// All data blocks pending deletion that still have a non-zero reference counter go here
struct CudaIPCSentDataLimbo final {
~CudaIPCSentDataLimbo();
bool collect();
void add(std::unique_ptr<CudaIPCSentData> shared_block);
uint64_t size() {
return shared_blocks_.size();
}
private:
// TODO: Could be changed to a FIFO in order to avoid a full traversal on every
// collect()
std::vector<std::unique_ptr<CudaIPCSentData>> shared_blocks_;
std::mutex limbo_mutex_;
};
struct CudaIPCRefCountersFile final {
CudaIPCRefCountersFile(
std::string handle,
uint64_t size,
at::DataPtr data_ptr)
: next_offset_(0),
size_(size),
used_slots_(0),
handle_(handle),
refcounted_shared_mem_(std::move(data_ptr)) {}
int64_t* counter_ptr() {
return static_cast<int64_t*>(refcounted_shared_mem_.get()) + next_offset_;
}
void set_counter(uint64_t value) {
*counter_ptr() = value;
}
bool have_offsets() {
return next_offset_ < size_;
}
bool offsets_in_use() {
return used_slots_;
}
int64_t get_offset() {
return next_offset_;
}
void rotate_offset() {
next_offset_++;
used_slots_++;
}
void return_offset(uint64_t offset /* unused */) {
used_slots_--;
}
std::string handle() {
return handle_;
}
private:
uint64_t next_offset_;
uint64_t size_;
uint64_t used_slots_;
std::string handle_;
at::DataPtr refcounted_shared_mem_;
};
} // namespace
} // namespace torch
namespace c10 {
namespace {
class CudaIPCCollectCallback : public FreeMemoryCallback {
public:
~CudaIPCCollectCallback() {};
bool Execute() override {
return torch::CudaIPCCollect();
}
};
} // namespace
} // namespace c10
#endif

View File

@ -16,6 +16,7 @@
#include <torch/csrc/THP.h>
#include <torch/csrc/copy_utils.h>
#include <torch/csrc/DynamicTypes.h>
#include <torch/csrc/CudaIPCTypes.h>
#include <torch/csrc/generic/Storage.cpp>
#include <TH/THGenerateAllTypes.h>

View File

@ -13,7 +13,7 @@
#endif
#include <torch/csrc/cuda/THCP.h>
#include <torch/csrc/CudaIPCTypes.h>
#include <torch/csrc/utils/pybind.h>
#include <torch/csrc/autograd/generated/VariableType.h>
#include <torch/csrc/utils/python_strings.h>
@ -217,6 +217,14 @@ PyObject * THCPModule_cudaSynchronize(PyObject *_unused)
END_HANDLE_TH_ERRORS
}
PyObject * THCPModule_cudaIPCCollect(PyObject *_unused /* unused */)
{
HANDLE_TH_ERRORS
torch::CudaIPCCollect();
Py_RETURN_NONE;
END_HANDLE_TH_ERRORS
}
PyObject * THCPModule_cudaSleep(PyObject *_unused, PyObject *cycles)
{
HANDLE_TH_ERRORS
@ -453,6 +461,7 @@ static struct PyMethodDef _THCPModule_methods[] = {
{"_cuda_initialSeed", (PyCFunction)THCPModule_initialSeed, METH_NOARGS, nullptr},
{"_cuda_cudaHostAllocator", (PyCFunction)THCPModule_cudaHostAllocator, METH_NOARGS, nullptr},
{"_cuda_synchronize", (PyCFunction)THCPModule_cudaSynchronize, METH_NOARGS, nullptr},
{"_cuda_ipc_collect", (PyCFunction)THCPModule_cudaIPCCollect, METH_NOARGS, nullptr},
{"_cuda_sleep", (PyCFunction)THCPModule_cudaSleep, METH_O, nullptr},
{"_cuda_lock_mutex", (PyCFunction)THCPModule_cudaLockMutex, METH_NOARGS, nullptr},
{"_cuda_unlock_mutex", (PyCFunction)THCPModule_cudaUnlockMutex, METH_NOARGS, nullptr},

View File

@ -12,6 +12,7 @@
#include <torch/csrc/cuda/override_macros.h>
#include <torch/csrc/copy_utils.h>
#include <torch/csrc/DynamicTypes.h>
#include <torch/csrc/CudaIPCTypes.h>
#define THC_GENERIC_FILE "torch/csrc/generic/Storage.cpp"
#include <THC/THCGenerateAllTypes.h>

View File

@ -216,13 +216,26 @@ static PyObject * THPStorage_(shareCuda)(THPStorage *self)
{
HANDLE_TH_ERRORS
THWStorage *storage = self->cdata;
if (storage->received_cuda()) {
AT_ERROR(
"Attempted to send CUDA tensor received from another process; this is not currently supported. Consider cloning before sending.");
}
at::DeviceGuard device_guard(storage->device());
THPObjectPtr tuple(PyTuple_New(4));
THPObjectPtr tuple(PyTuple_New(8));
THPObjectPtr device(PyLong_FromLong(storage->device().index()));
THPObjectPtr _handle(Py_None);
Py_INCREF(Py_None);
THPObjectPtr size_bytes(PyLong_FromLong(storage->numel() * sizeof(scalar_t)));
THPObjectPtr _offset_bytes(PyLong_FromLong(0));
THPObjectPtr _ref_counter(Py_None);
Py_INCREF(Py_None);
THPObjectPtr _ref_counter_offset(PyLong_FromLong(0));
THPObjectPtr _event_handle(Py_None);
Py_INCREF(Py_None);
THPObjectPtr _event_sync_required(Py_None);
Py_INCREF(Py_None);
if (THWStorage_(data)(LIBRARY_STATE storage)) {
size_t base_size;
void *base_ptr = c10::cuda::CUDACachingAllocator::getBaseAllocation(THWStorage_(data)(LIBRARY_STATE storage), &base_size);
@ -233,9 +246,33 @@ static PyObject * THPStorage_(shareCuda)(THPStorage *self)
_handle = PyBytes_FromStringAndSize((char *)&handle, CUDA_IPC_HANDLE_SIZE);
_offset_bytes = PyLong_FromSsize_t((Py_ssize_t)offset_bytes);
// Put the storage data behind a new ref counting context
// See Note [CUDA IPC Refcounting implementation explained]
at::DataPtr sent_data_ptr = torch::GetNewRefCountedSentData(storage->data(), storage->device());
auto old_data_ptr = storage->set_data_ptr(std::move(sent_data_ptr));
auto sent_data = static_cast<torch::CudaIPCSentData*>(storage->data_ptr().get_context());
sent_data->set_original_ptr(std::move(old_data_ptr));
_ref_counter = PyBytes_FromString((sent_data->handle()).c_str());
_ref_counter_offset = PyLong_FromLong(sent_data->offset());
cudaIpcEventHandle_t ipc_event_handle;
#ifndef __HIP_PLATFORM_HCC__
if (sent_data->event_sync_required_) {
THCudaCheck(cudaIpcGetEventHandle(&ipc_event_handle, sent_data->event_));
}
#else
// ipc_event_handle is unused by the storage receiver, so we can leave it uninitialized.
#endif
_event_handle = PyBytes_FromStringAndSize((char *)&ipc_event_handle, CUDA_IPC_HANDLE_SIZE);
_event_sync_required = PyBool_FromLong(sent_data->event_sync_required_);
}
if (!tuple || !device || !_handle || !size_bytes || !_offset_bytes) {
if (!tuple || !device || !_handle || !size_bytes || !_offset_bytes || !_event_handle) {
return nullptr;
}
PyTuple_SET_ITEM(tuple.get(), 0, device.release());
@ -248,40 +285,111 @@ static PyObject * THPStorage_(shareCuda)(THPStorage *self)
// as key in shared_cache(multiprocessing/reduction.py).
// Offset in numel cannot uniquely represent a storage.
PyTuple_SET_ITEM(tuple.get(), 3, _offset_bytes.release());
PyTuple_SET_ITEM(tuple.get(), 4, _ref_counter.release());
PyTuple_SET_ITEM(tuple.get(), 5, _ref_counter_offset.release());
PyTuple_SET_ITEM(tuple.get(), 6, _event_handle.release());
PyTuple_SET_ITEM(tuple.get(), 7, _event_sync_required.release());
return tuple.release();
END_HANDLE_TH_ERRORS
}
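From the Python side, `_share_cuda_` now returns an 8-item tuple instead of 4; the unpack below mirrors how `torch/multiprocessing/reductions.py` (later in this diff) consumes it. This is a private API, shown only for illustration:

```python
# Private API; field order matches the tuple built by THPStorage_(shareCuda) above.
(device, handle, storage_size_bytes, storage_offset_bytes,
 ref_counter_handle, ref_counter_offset,
 event_handle, event_sync_required) = tensor.storage()._share_cuda_()
```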
static PyObject * THPStorage_(releaseIPCCounter)(PyObject *_unused, PyObject *args)
{
HANDLE_TH_ERRORS
THPUtils_assert(PyTuple_GET_SIZE(args) == 2, "tuple of 2 items expected");
PyObject *_ref_counter = PyTuple_GET_ITEM(args, 0);
PyObject *_ref_counter_offset = PyTuple_GET_ITEM(args, 1);
if (!(PyBytes_Check(_ref_counter) &&
THPUtils_checkLong(_ref_counter_offset))) {
THPUtils_invalidArguments(
args,
nullptr,
"_release_ipc_counter in CUDA mode",
1,
"(bytes _ref_counter, int _ref_counter_offset)");
return nullptr;
}
std::string ref_counter_handle = PyBytes_AS_STRING(_ref_counter);
ptrdiff_t ref_counter_offset =
(ptrdiff_t)THPUtils_unpackLong(_ref_counter_offset);
// We don't want to break existing code, so resource deletion is done on a
// best-effort basis. An exception is expected if the producer process was
// terminated before the consumer released the data.
int flags =
TH_ALLOCATOR_MAPPED_SHAREDMEM | TH_ALLOCATOR_MAPPED_NOCREATE;
try {
auto sptr = THRefcountedMapAllocator::makeDataPtr(
ref_counter_handle.c_str(),
flags,
sizeof(int64_t) * torch::CUDA_IPC_REF_COUNTER_FILE_SIZE,
nullptr);
*(static_cast<int64_t*>(sptr.get()) + ref_counter_offset) -= 1;
} catch (c10::Error) {
// Already warned inside of producer process
}
Py_RETURN_NONE;
END_HANDLE_TH_ERRORS
}
static std::string THPStorage_(bytesAsHandleString)(PyObject *handle) {
char* buffer;
Py_ssize_t handle_size;
if (PyBytes_AsStringAndSize(handle, &buffer, &handle_size) == -1) {
return nullptr;
}
THPUtils_assert(
handle_size == CUDA_IPC_HANDLE_SIZE, "incorrect handle size");
return std::string(buffer, handle_size);
}
static PyObject * THPStorage_(newSharedCuda)(PyObject *_unused, PyObject *args)
{
HANDLE_TH_ERRORS
THPUtils_assert(PyTuple_GET_SIZE(args) == 4, "tuple of 4 items expected");
THPUtils_assert(PyTuple_GET_SIZE(args) == 8, "tuple of 8 items expected");
PyObject *_device = PyTuple_GET_ITEM(args, 0);
PyObject *_handle = PyTuple_GET_ITEM(args, 1);
PyObject *_size_bytes = PyTuple_GET_ITEM(args, 2);
PyObject *_offset_bytes = PyTuple_GET_ITEM(args, 3);
if (!(THPUtils_checkLong(_device) && THPUtils_checkLong(_size_bytes)
&& (_handle != Py_None && PyBytes_Check(_handle))
&& THPUtils_checkLong(_offset_bytes))) {
THPUtils_invalidArguments(args, nullptr, "_new_shared in CUDA mode", 1,
"(int device, bytes handle, int storage_size_bytes, int storage_offset_bytes)");
PyObject *_ref_counter = PyTuple_GET_ITEM(args, 4);
PyObject *_ref_counter_offset = PyTuple_GET_ITEM(args, 5);
PyObject *_event_handle = PyTuple_GET_ITEM(args, 6);
PyObject *_event_sync_required = PyTuple_GET_ITEM(args, 7);
if (!(THPUtils_checkLong(_device) && THPUtils_checkLong(_size_bytes) &&
PyBytes_Check(_handle) && PyBytes_Check(_ref_counter) &&
PyBytes_Check(_event_handle) && THPUtils_checkLong(_offset_bytes) &&
THPUtils_checkLong(_ref_counter_offset) && PyBool_Check(_event_sync_required))) {
THPUtils_invalidArguments(
args,
nullptr,
"_new_shared in CUDA mode",
1,
"(int device, bytes handle, int storage_size_bytes, int storage_offset_bytes, bytes _ref_counter, int _ref_counter_offset, bytes event_handle, bool event_sync_required)");
return nullptr;
}
// Storage constructor requires size in numel.
size_t storage_size = (size_t)THPUtils_unpackLong(_size_bytes) / sizeof(scalar_t);
ptrdiff_t storage_offset_bytes = (ptrdiff_t)THPUtils_unpackLong(_offset_bytes);
int64_t device = THPUtils_unpackLong(_device);
at::cuda::CUDAGuard device_guard(device);
char *buffer;
Py_ssize_t handle_size;
if (PyBytes_AsStringAndSize(_handle, &buffer, &handle_size) == -1) {
return nullptr;
#ifndef __HIP_PLATFORM_HCC__
if (PyObject_IsTrue(_event_sync_required)) {
// Ensure that the producer has prepared all of the tensor's data
std::string s_ipc_event_handle =
THPStorage_(bytesAsHandleString)(_event_handle);
auto ipc_event_handle = reinterpret_cast<const cudaIpcEventHandle_t*>(
s_ipc_event_handle.c_str());
cudaEvent_t event;
cudaIpcOpenEventHandle(&event, *ipc_event_handle);
AT_CUDA_CHECK(
cudaStreamWaitEvent(c10::cuda::getCurrentCUDAStream(device), event, 0));
}
THPUtils_assert(handle_size == CUDA_IPC_HANDLE_SIZE, "incorrect handle size");
std::string s_handle = std::string(buffer, handle_size);
#else
// Already synchronized inside producer stream
#endif
std::string s_handle = THPStorage_(bytesAsHandleString)(_handle);
std::shared_ptr<void> basePtr = c10::cuda::CUDACachingAllocator::getIpcDevPtr(s_handle);
// Offset the basePtr to reconstruct the real storage
@ -289,11 +397,50 @@ static PyObject * THPStorage_(newSharedCuda)(PyObject *_unused, PyObject *args)
void* devPtr = basePtr.get();
devPtr = (char*)devPtr + storage_offset_bytes;
std::string ref_counter_handle = PyBytes_AS_STRING(_ref_counter);
ptrdiff_t ref_counter_offset = (ptrdiff_t)THPUtils_unpackLong(_ref_counter_offset);
auto c = new torch::CudaIPCReceivedData(std::move(basePtr));
auto sp = std::shared_ptr<void>(
(void*)c, [ref_counter_handle, ref_counter_offset, device](void* ptr) {
delete static_cast<torch::CudaIPCReceivedData*>(ptr);
// Sync the default stream to make sure all operations related to the storage are
// finished (otherwise another process may reuse the memory and corrupt
// its data).
// Ideally, all shared-memory reference counting could be replaced by
// sending an untriggered CUDA event from the producer to the consumer and
// using this event as the criterion for memory release. However, CUDA (as of 10.1)
// does not support the creation of untriggered events, and the performance
// impact of having thousands of shared events is unknown.
// TODO: Instead of cudaStreamSynchronize it is possible to add a stream
// callback and release the counter inside of it (need to check the performance impact).
cudaStreamSynchronize(c10::cuda::getCurrentCUDAStream(device));
// We don't want to break existing code, so resource deletion is done on a
// best-effort basis. An exception is expected if the producer process was
// terminated before the consumer released the data.
int flags =
TH_ALLOCATOR_MAPPED_SHAREDMEM | TH_ALLOCATOR_MAPPED_NOCREATE;
try {
auto sptr = THRefcountedMapAllocator::makeDataPtr(
ref_counter_handle.c_str(),
flags,
sizeof(int64_t) * torch::CUDA_IPC_REF_COUNTER_FILE_SIZE,
nullptr);
*(static_cast<int64_t*>(sptr.get()) + ref_counter_offset) -= 1;
} catch (c10::Error) {
// Already warned inside of producer process
}
});
THWStoragePtr base(THWStorage_(newWithDataAndAllocator)(
LIBRARY_STATE
THCIpcDeleter::makeDataPtr(std::move(basePtr), devPtr),
THCIpcDeleter::makeDataPtr(std::move(sp), devPtr),
storage_size, /* allocator */ nullptr));
base->set_resizable(false);
base->set_received_cuda(true);
return THPStorage_(New)(base.release());
END_HANDLE_TH_ERRORS
@ -382,6 +529,7 @@ static PyMethodDef THPStorage_(sharingMethods)[] = {
#ifdef THC_GENERIC_FILE
{"_share_cuda_", (PyCFunction)THPStorage_(shareCuda), METH_NOARGS, nullptr},
{"_new_shared_cuda", (PyCFunction)THPStorage_(newSharedCuda), METH_VARARGS | METH_STATIC, nullptr},
{"_release_ipc_counter", (PyCFunction)THPStorage_(releaseIPCCounter), METH_VARARGS | METH_STATIC, nullptr},
#else
{"_share_fd_", (PyCFunction)THPStorage_(shareFd), METH_NOARGS, nullptr},
{"_new_shared_fd", (PyCFunction)THPStorage_(newSharedFd), METH_VARARGS | METH_STATIC, nullptr},

View File

@ -358,6 +358,19 @@ def synchronize():
return torch._C._cuda_synchronize()
def ipc_collect():
r"""Force collects GPU memory after it has been released by CUDA IPC.
.. note::
Checks if any sent CUDA tensors could be cleaned from memory. Force
closes the shared memory file used for reference counting if there are no
active counters. Useful when the producer process has stopped actively sending
tensors and wants to release unused memory.
"""
_lazy_init()
return torch._C._cuda_ipc_collect()
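For illustration, a hedged sketch of a producer that calls this once it has stopped sending tensors (the function and sizes are made up for the example):

```python
import torch

def producer(queue, count):
    for i in range(count):
        queue.put(torch.ones(5, device='cuda') * i)
    # No more tensors will be sent; reclaim any shared CUDA blocks whose
    # consumers have already dropped their references.
    torch.cuda.ipc_collect()
```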
def current_stream(device=None):
r"""Returns the currently selected :class:`Stream` for a given device.

View File

@ -0,0 +1,32 @@
# CUDA IPC Refcounting implementation explained
Since shared CUDA memory belongs to the producer process, we need to take special precautions to make sure it stays allocated for the entire life-span of the shared tensor.
It could be done manually by syncing on an event:
```python
# Producer
queue.put(tensor)
event.wait()
# Consumer
tensor = queue.get()
safe_to_use_tensor = tensor.clone()
event.set()
```
However, this requires blocking the producer process (and gets complicated when there are multiple consumers and various race conditions need to be handled).
Instead, we implement cross-process reference counting for shared CUDA (and HIP) tensors, which takes care of keeping the producer's memory allocated for the tensor's entire life-span.
Details of implementation follow.
At the moment of sending the tensor, we wrap the tensor's DataPtr in an additional structure, CudaIPCSentData. It still points to the same memory, but has different behavior on destruction.
Instead of simply removing the allocated block, it checks whether there are any active references to this block (references are stored in shared memory files described by the CudaIPCRefCountersFile structure). If any exist, the block's DataPtr is moved to the global state CudaIPCSentDataLimbo instead of being deleted.
Each individual CudaIPCRefCountersFile contains multiple reference counters for multiple tensors. The current implementation hands out the next available reference counter sequentially by increasing the offset.
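Conceptually, each ref-counter file behaves like a fixed-size array of int64 slots handed out by increasing offset; here is a rough Python analogue of the bookkeeping (not the actual implementation):

```python
class RefCounterFile:
    """Rough analogue of CudaIPCRefCountersFile: an array of int64 counters
    backed by shared memory, one slot per sent block."""

    def __init__(self, size):
        self.counters = [0] * size  # a shared-memory mapping in the real code
        self.next_offset = 0        # next slot to hand out
        self.used_slots = 0         # slots still referenced by live sent blocks

    def have_offsets(self):
        return self.next_offset < len(self.counters)

    def allocate(self):
        # set_counter(1) + get_offset() + rotate_offset() in the C++ code
        offset = self.next_offset
        self.counters[offset] = 1
        self.next_offset += 1
        self.used_slots += 1
        return offset

    def return_offset(self, offset):
        # Called when the corresponding sent block is finally destroyed.
        self.used_slots -= 1
```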
CudaIPCSentDataLimbo keeps references to data blocks that are no longer in use by the producer process (i.e., the tensor has gone out of scope) but are still in use (or will be in use) by a consumer. It also tries to reduce the number of stored blocks by scanning the limbo list for blocks whose ref count has gone to zero; the scan is triggered on various events, such as the CUDA caching allocator failing to find a suitable block for the next allocation, the attempted deletion of any shared block, or an explicit call to cuda_ipc_collect.
The consumer side wraps the received data in a different structure, CudaIPCReceivedData. On destruction, it takes care of decreasing the reference count for the received tensor.
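With the reference counting in place, the manual event dance shown earlier is no longer needed; here is a hedged sketch of the resulting user-level flow:

```python
# Producer (must stay alive until all consumers are done)
queue.put(tensor)
del tensor           # the CUDA block is parked in CudaIPCSentDataLimbo
                     # while any consumer still holds a reference

# Consumer
tensor = queue.get()
# ... use tensor ...
del tensor           # decrements the shared counter; the producer (or an
                     # explicit torch.cuda.ipc_collect()) can now free the block
```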

View File

@ -87,7 +87,7 @@ def rebuild_tensor(cls, storage, metadata):
def rebuild_cuda_tensor(tensor_cls, tensor_size, tensor_stride, tensor_offset,
storage_cls, storage_device, storage_handle, storage_size_bytes, storage_offset_bytes,
requires_grad):
requires_grad, ref_counter_handle, ref_counter_offset, event_handle, event_sync_required):
# If storage_handle is None, storage points to nullptr.
if storage_handle is None or storage_size_bytes == 0:
storage = storage_cls(0)
@ -99,8 +99,15 @@ def rebuild_cuda_tensor(tensor_cls, tensor_size, tensor_stride, tensor_offset,
storage_device,
storage_handle,
storage_size_bytes,
storage_offset_bytes)
storage_offset_bytes,
ref_counter_handle,
ref_counter_offset,
event_handle,
event_sync_required)
shared_cache[(storage_handle, storage_offset_bytes)] = StorageWeakRef(storage)
else:
# We are already ref counting this Storage, but the producer needs the new ref counter to be released.
storage_cls._release_ipc_counter(ref_counter_handle, ref_counter_offset)
t = torch._utils._rebuild_tensor(storage, tensor_offset, tensor_size, tensor_stride)
if tensor_cls == torch.nn.parameter.Parameter:
@ -211,11 +218,16 @@ def reduce_tensor(tensor):
# thing.
#
if storage.is_cuda:
(device, handle, storage_size_bytes, storage_offset_bytes) = storage._share_cuda_()
(device,
handle,
storage_size_bytes,
storage_offset_bytes,
ref_counter_handle,
ref_counter_offset,
event_handle,
event_sync_required) = storage._share_cuda_()
tensor_offset = tensor.storage_offset()
shared_cache[handle] = StorageWeakRef(storage)
# _backward_hooks purposely omitted here, see
# Note [Don't serialize hooks]
return (rebuild_cuda_tensor,
@ -228,7 +240,11 @@ def reduce_tensor(tensor):
handle, # identifier which CUDA allocation is the storage in.
storage_size_bytes, # size(in bytes) of the storage
storage_offset_bytes, # offset(in bytes) of the storage in the CUDA allocation
tensor.requires_grad))
tensor.requires_grad,
ref_counter_handle,
ref_counter_offset,
event_handle,
event_sync_required))
# _backward_hooks purposely omitted here, see Note [Don't serialize hooks]
metadata = (tensor.storage_offset(), tensor.size(), tensor.stride(), tensor.requires_grad)