Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
Improve caching allocator for Pascal and newer GPUs. (#17120)
Summary:

NVIDIA changed the CUDA allocation behavior on Pascal GPUs. The page size increased from 1MB to 2MB, and allocations larger than 1MB are now always page-aligned. Previously, allocations larger than 1MB were aligned to 128KB boundaries.

This interacted poorly with the caching allocator. The remaining memory in a page could only be filled by small cudaMalloc calls, but the caching allocator never cudaMallocs a chunk smaller than 1MB. This behavior could also cause a large discrepancy between the memory usage reported by nvidia-smi and the memory usage reported by PyTorch, because nvidia-smi counts a partially used page as "full", while PyTorch only counts the memory actually requested.

This PR makes a few changes to the caching allocator to better support Pascal and Volta GPUs:

- All cudaMalloc calls are now multiples of 2MB (the page size)
- Requests between 1MB and 10MB allocate (and split) a 20MB block to reduce wasted space due to rounding
- Small requests are now packed into 2MB blocks (instead of 1MB)

This improves Mask R-CNN memory usage by 10-20% in internal tests on Volta GPUs. Maxwell performance seems to be largely unchanged, but it's possible that some use cases suffer slightly.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/17120
Differential Revision: D14301536
Pulled By: colesbury
fbshipit-source-id: a8282315ea8f7b8ca149b5066fdeaecd0d404edf
committed by: Facebook Github Bot
parent: 8420a2025b
commit: 079093a662
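To make the rounding change concrete, the sketch below (illustrative only; `round_old` and `round_new` are hypothetical helpers, not code from this PR) compares the old 128 KiB rounding with the new 2 MiB page-size rounding for a few request sizes:

```cpp
#include <cstddef>
#include <cstdio>

// Old behavior: large allocations rounded up to 128 KiB boundaries.
size_t round_old(size_t size) {
  const size_t k = 131072;  // 128 KiB
  return ((size + k - 1) / k) * k;
}

// New behavior: every cudaMalloc size is a multiple of the 2 MiB page size.
size_t round_new(size_t size) {
  const size_t k = 2097152;  // 2 MiB
  return ((size + k - 1) / k) * k;
}

int main() {
  for (size_t mib : {3, 5, 11}) {
    size_t request = mib * 1048576 + 1;  // just over an even MiB
    std::printf("request %zu -> old %zu, new %zu\n",
                request, round_old(request), round_new(request));
  }
}
```

Under the old scheme an allocation can end in the middle of a 2 MiB page, and only sub-1MB cudaMalloc calls (which the caching allocator never issues) could fill the rest of that page; under the new scheme every allocation ends on a page boundary.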
```diff
@@ -30,9 +30,11 @@ namespace CUDACachingAllocator {
 //   split. If no block is found, the allocator will delegate to cudaMalloc.
 // - If the cudaMalloc fails, the allocator will free all cached blocks that
 //   are not split and retry the allocation.
-// - Large (>1MB) and small allocation requests are handled separately. Large
-//   allocation requests can be filled by a cudaMalloc call of the exact size.
-//   Small requests will allocate and split a 1MB buffer, if necessary.
+// - Large (>1MB) and small allocations are stored in separate pools.
+//   Small requests are packed into 2MB buffers. Large requests will use the
+//   smallest available free block or allocate a new block using cudaMalloc.
+//   To reduce fragmentation, requests between 1MB and 10MB will allocate and
+//   split a 20MB block, if no free block of sufficient size is available.
 //
 // With this allocator, allocations and frees should logically be considered
 // "usages" of the memory segment associated with streams, just like kernel
```
```diff
@@ -49,9 +51,12 @@ namespace {
 
 using stream_set = std::unordered_set<cuda::CUDAStream>;
 
-const size_t kRoundSmall = 512;     // round up small allocs to 512 bytes
-const size_t kRoundLarge = 131072;  // round up large allocs to 128 KiB
-const size_t kSmallAlloc = 1048576; // largest "small" allocation is 1 MiB
+constexpr size_t kMinBlockSize = 512;       // all sizes are rounded to at least 512 bytes
+constexpr size_t kSmallSize = 1048576;      // largest "small" allocation is 1 MiB
+constexpr size_t kSmallBuffer = 2097152;    // "small" allocations are packed in 2 MiB blocks
+constexpr size_t kLargeBuffer = 20971520;   // "large" allocations may be packed in 20 MiB blocks
+constexpr size_t kMinLargeAlloc = 10485760; // allocations between 1 and 10 MiB may use kLargeBuffer
+constexpr size_t kRoundLarge = 2097152;     // round up large allocs to 2 MiB
 
 struct DeviceStats {
   uint64_t amount_allocated; // total amount allocated in bytes
```
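The new constants encode a few relationships worth spelling out. The `static_assert`s below are not part of the PR; they only restate what the values above imply:

```cpp
static_assert(kMinBlockSize == 512,
              "every request size is rounded to a multiple of 512 B");
static_assert(kSmallBuffer == 2 * kSmallSize,
              "small (<= 1 MiB) requests are packed into 2 MiB buffers");
static_assert(kMinLargeAlloc == 10 * kSmallSize,
              "requests between 1 and 10 MiB share a larger buffer");
static_assert(kLargeBuffer == 2 * kMinLargeAlloc,
              "that shared buffer is 20 MiB");
static_assert(kRoundLarge == kSmallBuffer,
              "large requests round up to the 2 MiB page size");
```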
```diff
@@ -82,20 +87,30 @@ struct DeviceStats {
   }
 };
 
+struct Block;
+typedef bool (*Comparison)(const Block*, const Block*);
+typedef std::set<Block*, Comparison> BlockPool;
+
 struct Block {
   int device;             // gpu
   cudaStream_t stream;    // allocation stream
   stream_set stream_uses; // streams on which the block was used
   size_t size;            // block size in bytes
-  char* ptr;              // memory address
+  BlockPool* pool;        // owning memory pool
+  void* ptr;              // memory address
   bool allocated;         // in-use flag
   Block* prev;            // prev block if split from a larger allocation
   Block* next;            // next block if split from a larger allocation
   int event_count;        // number of outstanding CUDA events
 
-  Block(int device, cudaStream_t stream, size_t size, char* ptr=NULL) :
-      device(device), stream(stream), stream_uses(), size(size), ptr(ptr),
-      allocated(0), prev(NULL), next(NULL), event_count(0) { }
+  Block(int device, cudaStream_t stream, size_t size, BlockPool* pool, void* ptr) :
+      device(device), stream(stream), stream_uses(), size(size), pool(pool),
+      ptr(ptr), allocated(0), prev(nullptr), next(nullptr), event_count(0) { }
+
+  // constructor for search key
+  Block(int device, cudaStream_t stream, size_t size) :
+      device(device), stream(stream), stream_uses(), size(size), pool(nullptr),
+      ptr(nullptr), allocated(0), prev(nullptr), next(nullptr), event_count(0) { }
 };
 
 static bool BlockComparator(const Block* a, const Block* b)
```
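The forward declaration plus the two typedefs make a pool an ordered `std::set` of `Block` pointers, and the second constructor exists purely to build stack-allocated search keys for lookups. A minimal sketch of that lookup pattern follows (the comparator here is simplified to size and address; the allocator's `BlockComparator` also orders by device and stream first):

```cpp
#include <cstddef>
#include <cstdio>
#include <set>

struct Block;
typedef bool (*Comparison)(const Block*, const Block*);
typedef std::set<Block*, Comparison> BlockPool;

struct Block {
  size_t size;
  void* ptr;
};

// Simplified ordering: by size, then address.
static bool BlockComparator(const Block* a, const Block* b) {
  if (a->size != b->size) return a->size < b->size;
  return a->ptr < b->ptr;
}

int main() {
  BlockPool pool(BlockComparator);
  Block b1{2097152, (void*)0x1000};
  Block b2{4194304, (void*)0x2000};
  pool.insert(&b1);
  pool.insert(&b2);

  Block search_key{3000000, nullptr};       // the request, built as a stack key
  auto it = pool.lower_bound(&search_key);  // smallest free block >= request
  if (it != pool.end()) {
    std::printf("best fit: %zu bytes\n", (*it)->size);
  }
}
```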
```diff
@@ -135,9 +150,6 @@ static std::string format_size(uint64_t size) {
 
 struct THCCachingAllocator
 {
-  typedef bool (*Comparison)(const Block*, const Block*);
-  typedef std::set<Block*, Comparison> FreeBlocks;
-
   // device statistics
   std::vector<DeviceStats> device_stats;
 
@@ -148,10 +160,10 @@ struct THCCachingAllocator
   std::mutex cuda_free_mutex;
 
   // cached blocks larger than 1 MB
-  FreeBlocks large_blocks;
+  BlockPool large_blocks;
 
   // cached blocks 1 MB or smaller
-  FreeBlocks small_blocks;
+  BlockPool small_blocks;
 
   // allocated blocks by device pointer
   std::unordered_map<void*, Block*> allocated_blocks;
```
```diff
@@ -183,23 +195,22 @@ struct THCCachingAllocator
     process_events();
 
     size = round_size(size);
-    bool small = size <= kSmallAlloc;
 
     DeviceStats &stats = get_stats_for_device(device);
 
     Block search_key(device, stream, size);
-    auto& free_blocks = small ? small_blocks : large_blocks;
+    auto& pool = get_pool(size);
 
-    Block* block = NULL;
-    Block* remaining = NULL;
+    Block* block = nullptr;
+    Block* remaining = nullptr;
 
-    auto it = free_blocks.lower_bound(&search_key);
-    if (it != free_blocks.end() && (*it)->device == device && (*it)->stream == stream) {
+    auto it = pool.lower_bound(&search_key);
+    if (it != pool.end() && (*it)->device == device && (*it)->stream == stream) {
       block = *it;
-      free_blocks.erase(it);
+      pool.erase(it);
     } else {
       void* ptr;
-      size_t alloc_size = small ? kSmallAlloc : size;
+      size_t alloc_size = get_allocation_size(size);
       cudaError_t err = cuda_malloc_retry(device, &ptr, alloc_size);
       if (err != cudaSuccess) {
         if (err == cudaErrorMemoryAllocation) {
```
```diff
@@ -239,13 +250,14 @@ struct THCCachingAllocator
         }
       }
       stats.increaseCached(alloc_size);
-      block = new Block(device, stream, alloc_size, (char*)ptr);
+      block = new Block(device, stream, alloc_size, &pool, ptr);
     }
 
-    if (block->size - size >= (small ? kRoundSmall : kSmallAlloc + 1)) {
+    AT_ASSERT(block);
+    if (should_split(block, size)) {
       remaining = block;
 
-      block = new Block(device, stream, size, block->ptr);
+      block = new Block(device, stream, size, &pool, block->ptr);
       block->prev = remaining->prev;
       if (block->prev) {
         block->prev->next = block;
@@ -253,15 +265,15 @@ struct THCCachingAllocator
       block->next = remaining;
 
       remaining->prev = block;
-      remaining->ptr += size;
+      remaining->ptr = static_cast<char*>(remaining->ptr) + size;
       remaining->size -= size;
-      free_blocks.insert(remaining);
+      pool.insert(remaining);
     }
 
     block->allocated = true;
     allocated_blocks[block->ptr] = block;
 
-    *devPtr = (void*)block->ptr;
+    *devPtr = block->ptr;
 
     stats.increaseAllocated(block->size);
   }
```
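The split path above carves the requested size off the front of a larger cached block and returns the tail to the pool. Here is a standalone sketch of just that pointer surgery (simplified `Block`; `split_front` is a hypothetical helper that mirrors the diff, not the allocator's actual code):

```cpp
#include <cstddef>

struct Block {
  size_t size = 0;
  void* ptr = nullptr;
  Block* prev = nullptr;
  Block* next = nullptr;
};

// Carve `size` bytes off the front of `remaining`; the caller keeps the
// returned head and re-inserts `remaining` (now the tail) into the free pool.
Block* split_front(Block* remaining, size_t size) {
  Block* block = new Block();
  block->size = size;
  block->ptr = remaining->ptr;

  // Link the new head into the doubly linked list of split siblings.
  block->prev = remaining->prev;
  if (block->prev) block->prev->next = block;
  block->next = remaining;
  remaining->prev = block;

  // The tail now starts `size` bytes further into the same cudaMalloc'd range.
  remaining->ptr = static_cast<char*>(remaining->ptr) + size;
  remaining->size -= size;
  return block;
}
```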
```diff
@@ -320,8 +332,8 @@ struct THCCachingAllocator
     return basePtr;
   }
 
-  // Accumulates sizes of all memory blocks for given device in given free list
-  void cacheInfoAux(FreeBlocks& blocks, int dev_id, size_t* total, size_t* largest)
+  // Accumulates sizes of all memory blocks for given device in given pool
+  void cacheInfoAux(BlockPool& blocks, int dev_id, size_t* total, size_t* largest)
   {
     Block search_key(dev_id, 0, 0);
     auto it = blocks.lower_bound(&search_key);
```
```diff
@@ -356,19 +368,18 @@ struct THCCachingAllocator
     block->stream_uses.insert(stream);
   }
 
-  /** moves a block into the free block list */
+  /** moves a block into a pool of cached free blocks */
   void free_block(Block* block)
   {
     AT_ASSERT(!block->allocated && block->event_count == 0);
-    bool small = block->size <= kSmallAlloc;
-    auto& free_blocks = small ? small_blocks : large_blocks;
-    try_merge_blocks(block, block->prev, free_blocks);
-    try_merge_blocks(block, block->next, free_blocks);
-    free_blocks.insert(block);
+    auto& pool = *block->pool;
+    try_merge_blocks(block, block->prev, pool);
+    try_merge_blocks(block, block->next, pool);
+    pool.insert(block);
   }
 
   /** combine previously split blocks */
-  void try_merge_blocks(Block* dst, Block* src, FreeBlocks& free_blocks)
+  void try_merge_blocks(Block* dst, Block* src, BlockPool& pool)
   {
     if (!src || src->allocated || src->event_count > 0) {
       return;
```
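`try_merge_blocks` is the inverse of splitting: a freed block absorbs an adjacent free sibling so the pool coalesces back toward whole cudaMalloc'd segments. A simplified sketch of the forward-merge case (the real function also handles the `prev` neighbor, which moves `dst->ptr`, and skips blocks that are allocated or have pending events):

```cpp
#include <cstddef>
#include <set>

struct Block {
  size_t size = 0;
  Block* prev = nullptr;
  Block* next = nullptr;
};
using BlockPool = std::set<Block*>;  // simplified: pointer-ordered

// Absorb the free sibling immediately after `dst` in the same segment.
void merge_with_next(Block* dst, BlockPool& pool) {
  Block* src = dst->next;
  if (!src) return;
  dst->next = src->next;                // unlink src from the sibling list
  if (dst->next) dst->next->prev = dst;
  dst->size += src->size;               // dst now spans both byte ranges
  pool.erase(src);                      // src is no longer a standalone block
  delete src;
}
```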
```diff
@@ -386,20 +397,45 @@ struct THCCachingAllocator
       }
     }
     dst->size += src->size;
-    free_blocks.erase(src);
+    pool.erase(src);
     delete src;
   }
 
-  size_t round_size(size_t size)
-  {
-    if (size < kRoundSmall) {
-      size = kRoundSmall;
-    } else if (size < kSmallAlloc) {
-      size += kRoundSmall - 1 - (size - 1) % kRoundSmall;
+  BlockPool& get_pool(size_t size) {
+    if (size <= kSmallSize) {
+      return small_blocks;
     } else {
-      size += kRoundLarge - 1 - (size - 1) % kRoundLarge;
+      return large_blocks;
     }
-    return size;
   }
 
+  bool should_split(Block* block, size_t size) {
+    size_t remaining = block->size - size;
+    if (block->pool == &small_blocks) {
+      return remaining >= kMinBlockSize;
+    } else if (block->pool == &large_blocks) {
+      return remaining > kSmallSize;
+    } else {
+      AT_ERROR("should_split: invalid pool");
+    }
+  }
+
+  size_t round_size(size_t size) {
+    if (size < kMinBlockSize) {
+      return kMinBlockSize;
+    } else {
+      return kMinBlockSize * ((size + kMinBlockSize - 1) / kMinBlockSize);
+    }
+  }
+
+  size_t get_allocation_size(size_t size) {
+    if (size <= kSmallSize) {
+      return kSmallBuffer;
+    } else if (size < kMinLargeAlloc) {
+      return kLargeBuffer;
+    } else {
+      return kRoundLarge * ((size + kRoundLarge - 1) / kRoundLarge);
+    }
+  }
+
   cudaError_t cuda_malloc_retry(int device, void** devPtr, size_t size)
```
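To see how the new helpers interact, here is a free-standing copy of the two sizing functions with the constants from this diff, applied to a few request sizes (an illustrative harness, not part of the PR):

```cpp
#include <cstddef>
#include <cstdio>

constexpr size_t kMinBlockSize = 512;
constexpr size_t kSmallSize = 1048576;
constexpr size_t kSmallBuffer = 2097152;
constexpr size_t kLargeBuffer = 20971520;
constexpr size_t kMinLargeAlloc = 10485760;
constexpr size_t kRoundLarge = 2097152;

size_t round_size(size_t size) {
  if (size < kMinBlockSize) return kMinBlockSize;
  return kMinBlockSize * ((size + kMinBlockSize - 1) / kMinBlockSize);
}

size_t get_allocation_size(size_t size) {
  if (size <= kSmallSize) return kSmallBuffer;     // packed into a 2 MiB buffer
  if (size < kMinLargeAlloc) return kLargeBuffer;  // 1-10 MiB share 20 MiB buffers
  return kRoundLarge * ((size + kRoundLarge - 1) / kRoundLarge);  // 2 MiB multiples
}

int main() {
  // 100 B      -> 512 B request, backed by a 2 MiB small buffer
  // 4 MiB      -> backed by a 20 MiB buffer that will be split
  // 32 MiB + 1 -> rounded up to the next 2 MiB multiple (34 MiB)
  const size_t requests[] = {100, 4194304, 33554433};
  for (size_t req : requests) {
    size_t rounded = round_size(req);
    std::printf("request %zu -> rounded %zu -> cudaMalloc %zu\n",
                req, rounded, get_allocation_size(rounded));
  }
}
```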
```diff
@@ -421,8 +457,8 @@ struct THCCachingAllocator
   void free_cached_blocks(int device)
   {
     // Free all non-split cached blocks on device
-    Block lower_bound(device, NULL, 0);
-    Block upper_bound(device + 1, NULL, 0);
+    Block lower_bound(device, nullptr, 0);
+    Block upper_bound(device + 1, nullptr, 0);
 
     free_blocks(
       large_blocks,
```
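`free_cached_blocks` relies on a small trick worth noting: because each pool is ordered by device first, two stack-allocated search keys for `device` and `device + 1` bracket exactly the blocks cached on one device. A simplified sketch of the same range scan (comparator reduced to device and size; illustrative only, with `bytes_cached_on` as a hypothetical helper):

```cpp
#include <cstddef>
#include <set>

struct Block {
  int device;
  size_t size;
};
typedef bool (*Comparison)(const Block*, const Block*);
typedef std::set<Block*, Comparison> BlockPool;

static bool BlockComparator(const Block* a, const Block* b) {
  if (a->device != b->device) return a->device < b->device;
  if (a->size != b->size) return a->size < b->size;
  return a < b;  // arbitrary tiebreak, enough for the sketch
}

// Sum cached bytes on one device by scanning [lower_bound, upper_bound).
size_t bytes_cached_on(BlockPool& pool, int device) {
  Block lo{device, 0};      // sentinel keys bracketing this device's blocks
  Block hi{device + 1, 0};
  size_t total = 0;
  auto end = pool.lower_bound(&hi);
  for (auto it = pool.lower_bound(&lo); it != end; ++it) {
    total += (*it)->size;
  }
  return total;
}
```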
```diff
@@ -434,7 +470,7 @@ struct THCCachingAllocator
       small_blocks.lower_bound(&upper_bound));
   }
 
-  void free_blocks(FreeBlocks& blocks, FreeBlocks::iterator it, FreeBlocks::iterator end)
+  void free_blocks(BlockPool& blocks, BlockPool::iterator it, BlockPool::iterator end)
   {
     // Frees all non-split blocks between `it` and `end`
     std::lock_guard<std::mutex> lock(cuda_free_mutex);
@@ -456,7 +492,7 @@ struct THCCachingAllocator
   Block* find_allocated_block(void *ptr) {
     auto it = allocated_blocks.find(ptr);
     if (it == allocated_blocks.end()) {
-      return NULL;
+      return nullptr;
     }
     return it->second;
   }
```