Free all blocks with outstanding events on OOM-retry (#19222)

Summary:
When a CUDA allocation fails with an out-of-memory error, the caching
allocator frees its cached blocks and retries the allocation. Previously,
it did not free blocks that still had outstanding stream uses (pending
CUDA events recorded via record_stream). This change synchronizes on those
outstanding events and then frees the associated blocks as well.

See #19219
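
For context, a minimal usage sketch (not part of the PR; the stream name and tensor size are illustrative) of the record_stream pattern whose deferred frees this change makes reclaimable:

import torch

side = torch.cuda.Stream()

x = torch.empty(1024, device='cuda')    # allocated on the default stream
with torch.cuda.stream(side):
    y = x * 2                           # x is also used on `side`

# Ask the allocator not to reuse x's block until the work already queued on
# `side` has completed; the block is held behind a recorded CUDA event.
x.record_stream(side)
del x

# After this change, an out-of-memory retry (and torch.cuda.empty_cache())
# first synchronizes on such outstanding events, so blocks like x's can be
# returned to the pool instead of remaining unavailable.
torch.cuda.empty_cache()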
Pull Request resolved: https://github.com/pytorch/pytorch/pull/19222

Differential Revision: D14925071

Pulled By: colesbury

fbshipit-source-id: a2e9fe957ec11b00ea8e6c0468436c519667c558
Author: Sam Gross
Date: 2019-04-15 11:13:33 -07:00
Committed by: Facebook Github Bot
Parent: 86619b8ba6
Commit: 7caad0ed33
2 changed files with 51 additions and 0 deletions


@@ -328,6 +328,7 @@ struct THCCachingAllocator
  void emptyCache()
  {
    std::lock_guard<std::recursive_mutex> lock(mutex);
    synchronize_and_free_events(nullopt);
    free_blocks(large_blocks, large_blocks.begin(), large_blocks.end());
    free_blocks(small_blocks, small_blocks.begin(), small_blocks.end());
  }
@@ -478,6 +479,10 @@ struct THCCachingAllocator
  void free_cached_blocks(int device)
  {
    // First ensure that all blocks that can't currently be allocated due to
    // outstanding events are returned to the pool.
    synchronize_and_free_events(device);

    // Free all non-split cached blocks on device
    Block lower_bound(device, nullptr, 0);
    Block upper_bound(device + 1, nullptr, 0);
@@ -511,6 +516,32 @@ struct THCCachingAllocator
    }
  }

  void synchronize_and_free_events(optional<int> device) {
    // Synchronize on outstanding events and then free associated blocks.
    // Limited to blocks on the given device if specified.

    auto remaining_events = decltype(cuda_events)();
    for (auto& e : cuda_events) {
      cudaEvent_t event = e.first;
      Block* block = e.second;
      if (device.has_value() && block->device != *device) {
        remaining_events.push_back(e);
        continue;
      }

      C10_CUDA_CHECK(cudaEventSynchronize(event));
      C10_CUDA_CHECK(cudaEventDestroy(event));

      block->event_count--;
      if (block->event_count == 0) {
        free_block(block);
      }
    }

    std::swap(cuda_events, remaining_events);
  }

  Block* find_allocated_block(void *ptr) {
    auto it = allocated_blocks.find(ptr);
    if (it == allocated_blocks.end()) {
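
To make the control flow of synchronize_and_free_events easier to follow outside the diff, here is a hypothetical, simplified Python model of the same bookkeeping (Block, Event, and free_block are illustrative stand-ins, not PyTorch internals):

from dataclasses import dataclass
from typing import Callable, List, Optional, Tuple

@dataclass
class Block:
    device: int
    event_count: int = 0

class Event:
    def synchronize(self):
        # Stand-in for cudaEventSynchronize: wait until the recorded
        # stream work has finished on the GPU.
        pass

def synchronize_and_free_events(
    cuda_events: List[Tuple[Event, Block]],
    device: Optional[int],
    free_block: Callable[[Block], None],
) -> List[Tuple[Event, Block]]:
    """Wait on outstanding events and free their blocks; when `device` is
    given, events recorded for other devices are kept for later."""
    remaining = []
    for event, block in cuda_events:
        if device is not None and block.device != device:
            remaining.append((event, block))
            continue
        event.synchronize()            # wait for the recorded stream to finish
        block.event_count -= 1
        if block.event_count == 0:
            free_block(block)          # block can return to the free pool
    return remaining

Blocking on the events is acceptable in these code paths because they only run from emptyCache() or the out-of-memory retry, not on every allocation.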


@@ -2049,6 +2049,26 @@ class TestCuda(TestCase):
        self.assertEqual(gpu_tensor1[0], 1)
        self.assertEqual(gpu_tensor0[0], 2)

    def test_caching_allocator_record_stream_oom(self):
        """allocations delayed by a record_stream call should still be freed on
        an out-of-memory in cuda_malloc_retry. see issue #19219"""
        stream = torch.cuda.Stream()

        with torch.cuda.stream(stream):
            y = torch.zeros(40 * 1024 * 1024, device='cuda')

        for _ in range(100):
            x = torch.empty(40 * 1024 * 1024, device='cuda')
            with torch.cuda.stream(stream):
                y += x
            # delays re-use of `x` until after all operations in `stream`
            x.record_stream(stream)
            del x

        # we've made a mess by allocating up to the device capacity. free any
        # cached blocks in case it affects future tests.
        torch.cuda.empty_cache()

    def test_reduction_gpu_memory_accessing(self):
        x = torch.ones(512, 8, dtype=torch.float32, device='cuda')
        torch.sum(x, 0)