Revert D23752058: [pytorch][PR] Don't split oversize cached blocks

Test Plan: revert-hammer

Differential Revision: D23752058 (67dcd62310)

Original commit changeset: ccb7c13e3cf8

fbshipit-source-id: 12ae9702135ea510e9714ed97fb75ca3b9f97c27
Authored by Natalia Gimelshein on 2021-04-14 09:22:57 -07:00, committed by Facebook GitHub Bot
parent e7e164f9e6, commit f94c95a2dd
6 changed files with 35 additions and 234 deletions

View File

@@ -18,7 +18,6 @@
 #include <unordered_map>
 #include <unordered_set>
 #include <vector>
-#include <regex>
 
 namespace c10 {
@@ -35,18 +34,13 @@ namespace CUDACachingAllocator {
 // - The allocator attempts to find the smallest cached block that will fit the
 //   requested size. If the block is larger than the requested size, it may be
 //   split. If no block is found, the allocator will delegate to cudaMalloc.
-// - If the cudaMalloc fails, the allocator will attempt to free one cached
-//   block of sufficient size that is not split and retry the allocation.
-//   If this also fails, the allocator will attempt to free all cached blocks
-//   that are not split and retry the allocation.
+// - If the cudaMalloc fails, the allocator will free all cached blocks that
+//   are not split and retry the allocation.
 // - Large (>1MB) and small allocations are stored in separate pools.
 //   Small requests are packed into 2MB buffers. Large requests will use the
 //   smallest available free block or allocate a new block using cudaMalloc.
-// - To reduce fragmentation, requests between 1MB and 10MB will allocate and
+//   To reduce fragmentation, requests between 1MB and 10MB will allocate and
 //   split a 20MB block, if no free block of sufficient size is available.
-// - To further reduce fragmentation, blocks >= 200MB are not allowed to be
-//   split. These oversize cached blocks will still satisfy requests within
-//   20MB of the oversize cached block size.
 //
 // With this allocator, allocations and frees should logically be considered
 // "usages" of the memory segment associated with streams, just like kernel
@@ -211,9 +205,9 @@ struct AllocParams {
     block(nullptr),
     err(cudaSuccess) {}
 
-  int device() const { return search_key.device; }
-  cudaStream_t stream() const { return search_key.stream; }
-  size_t size() const { return search_key.size; }
+  int device() { return search_key.device; }
+  cudaStream_t stream() { return search_key.stream; }
+  size_t size() { return search_key.size; }
 
   Block search_key;
   BlockPool* pool;
@@ -269,63 +263,6 @@ cudaError_t cudaMallocMaybeCapturing(void** p, size_t size) {
 
 } // namespace
 
-class CachingAllocatorConfig {
- public:
-  static size_t max_split_size() { return instance().m_max_split_size; }
-
- private:
-  static std::once_flag s_flag;
-  static CachingAllocatorConfig* s_instance;
-  static CachingAllocatorConfig &instance() {
-    std::call_once(s_flag, &CachingAllocatorConfig::init);
-    return *s_instance;
-  }
-  static void init() {
-    s_instance = new CachingAllocatorConfig();
-    s_instance->parseArgs();
-  }
-
-  CachingAllocatorConfig()
-      : m_max_split_size(std::numeric_limits<size_t>::max())
-  { }
-  size_t m_max_split_size;
-
-  void parseArgs() {
-    const char *val = getenv("PYTORCH_CUDA_ALLOC_CONF");
-    if (val != NULL) {
-      const std::string config(val);
-      std::regex exp("[\\s,]+");
-      std::sregex_token_iterator it(config.begin(), config.end(), exp, -1);
-      std::sregex_token_iterator end;
-      std::vector<std::string> options(it, end);
-      for (auto option : options) {
-        std::regex exp2("[:]+");
-        std::sregex_token_iterator it2(option.begin(), option.end(), exp2, -1);
-        std::sregex_token_iterator end2;
-        std::vector<std::string> kv(it2, end2);
-        if (kv.size() >= 2) {
-          /* Maximum split size in MB. Limited to large size blocks */
-          if (kv[0].compare("max_split_size_mb") == 0) {
-            size_t val2 = stoi(kv[1]);
-            TORCH_CHECK(val2 > kLargeBuffer/(1024*1024), "CachingAllocator option max_split_size_mb too small, must be >= ",
-                        kLargeBuffer/(1024*1024), "");
-            val2 = std::max(val2, kLargeBuffer/(1024*1024));
-            val2 = std::min(val2, (std::numeric_limits<size_t>::max() / (1024*1024)));
-            m_max_split_size = val2 * 1024 * 1024;
-          }
-          else {
-            TORCH_CHECK(false, "Unrecognized CachingAllocator option: ", kv[0]);
-          }
-        }
-      }
-    }
-  }
-};
-
-CachingAllocatorConfig *CachingAllocatorConfig::s_instance;
-std::once_flag CachingAllocatorConfig::s_flag;
-
 class DeviceCachingAllocator {
  private:
@@ -377,10 +314,7 @@ class DeviceCachingAllocator {
 
   DeviceCachingAllocator() :
       large_blocks(BlockComparator, /*is_small=*/false),
-      small_blocks(BlockComparator, /*is_small=*/true)
-  {
-    stats.max_split_size = CachingAllocatorConfig::max_split_size();
-  }
+      small_blocks(BlockComparator, /*is_small=*/true) {}
 
   // All public methods (except the above) acquire the allocator mutex.
   // Thus, do not call a public method from another public method.
@@ -406,10 +340,8 @@ class DeviceCachingAllocator {
       || (trigger_free_memory_callbacks(params) && get_free_block(params))
       // Attempt allocate
       || alloc_block(params, false)
-      // Free enough available cached blocks to satisfy alloc and retry alloc.
-      || (release_available_cached_blocks(params) && alloc_block(params, false))
       // Free all non-split cached blocks and retry alloc.
-      || (release_cached_blocks() && alloc_block(params, true));
+      || (free_cached_blocks() && alloc_block(params, true));
 
     if (!block_found) {
       // For any error code other than cudaErrorMemoryAllocation,
@@ -454,10 +386,7 @@ class DeviceCachingAllocator {
         format_size(device_free), " free; ",
         allowed_info,
         format_size(stats.reserved_bytes[static_cast<size_t>(StatType::AGGREGATE)].current),
-        " reserved in total by PyTorch)",
-        " If reserved memory is >> allocated memory try setting max_split_size_mb to avoid"
-        " fragmentation. See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF",
-        "");
+        " reserved in total by PyTorch)");
     }
 
     TORCH_INTERNAL_ASSERT(params.err == cudaSuccess &&
@@ -509,8 +438,6 @@ class DeviceCachingAllocator {
     update_stat_array(stats.allocated_bytes, block->size, params.stat_types);
     update_stat_array(stats.active, 1, params.stat_types);
     update_stat_array(stats.active_bytes, block->size, params.stat_types);
-    if (block->size >= CachingAllocatorConfig::max_split_size())
-      update_stat(stats.oversize_allocations, 1);
 
     return block;
   }
@@ -529,8 +456,6 @@ class DeviceCachingAllocator {
     stat_types[static_cast<size_t>(get_stat_type_for_pool(*(block->pool)))] = true;
     update_stat_array(stats.allocation, -1, {stat_types});
     update_stat_array(stats.allocated_bytes, -block->size, {stat_types});
-    if (block->size >= CachingAllocatorConfig::max_split_size())
-      update_stat(stats.oversize_allocations, -1);
 
     if (!block->stream_uses.empty()) {
       insert_events(block);
@@ -578,7 +503,7 @@ class DeviceCachingAllocator {
   /** returns cached blocks to the system allocator **/
   void emptyCache() {
     std::lock_guard<std::recursive_mutex> lock(mutex);
-    release_cached_blocks();
+    free_cached_blocks();
   }
 
   /** Retrieves info (total size + largest block) of the memory cache **/
@@ -621,8 +546,6 @@ class DeviceCachingAllocator {
     stats.num_alloc_retries = 0;
     stats.num_ooms = 0;
-    reset_accumulated_stat(stats.oversize_allocations);
-    reset_accumulated_stat(stats.oversize_segments);
   }
 
   /** Resets the historical peak stats for the device **/
@@ -639,8 +562,6 @@ class DeviceCachingAllocator {
       reset_peak_stat(stats.active_bytes[statType]);
       reset_peak_stat(stats.inactive_split_bytes[statType]);
     }
-    reset_peak_stat(stats.oversize_allocations);
-    reset_peak_stat(stats.oversize_segments);
   }
 
   /** Dump a complete snapshot of the memory held by the allocator. Potentially VERY expensive. **/
@@ -871,11 +792,9 @@ class DeviceCachingAllocator {
 
   bool should_split(const Block* block, size_t size) {
     size_t remaining = block->size - size;
-    if (block->pool->is_small) {
-      return remaining >= kMinBlockSize;
-    } else {
-      return (size < CachingAllocatorConfig::max_split_size()) && (remaining > kSmallSize);
-    }
+    return (block->pool->is_small) ?
+        (remaining >= kMinBlockSize) :
+        (remaining > kSmallSize);
   }
 
   static size_t get_allocation_size(size_t size) {
@@ -893,12 +812,6 @@ class DeviceCachingAllocator {
     auto it = pool.blocks.lower_bound(&p.search_key);
     if (it == pool.blocks.end() || (*it)->stream != p.stream())
       return false;
-    // Do not return an oversized block for a large request
-    if ((p.size() < CachingAllocatorConfig::max_split_size()) && ((*it)->size >= CachingAllocatorConfig::max_split_size()))
-      return false;
-    // Allow oversized block size to be rounded up but within a limit
-    if ((p.size() >= CachingAllocatorConfig::max_split_size()) && ((*it)->size >= p.size() + kLargeBuffer))
-      return false;
     p.block = *it;
     pool.blocks.erase(it);
     return true;
@@ -955,67 +868,27 @@ class DeviceCachingAllocator {
     p.block = new Block(p.device(), p.stream(), size, p.pool, (char*)ptr);
     update_stat_array(stats.segment, 1, p.stat_types);
     update_stat_array(stats.reserved_bytes, size, p.stat_types);
-    if (size >= CachingAllocatorConfig::max_split_size())
-      update_stat(stats.oversize_segments, 1);
 
     // p.block came from new, not cudaMalloc. It should not be nullptr here.
     TORCH_INTERNAL_ASSERT(p.block != nullptr && p.block->ptr != nullptr);
     return true;
   }
 
-  /** Free one or more oversize blocks to the system allocator. But only enough **/
-  /** to satisfy the target size **/
-  bool release_available_cached_blocks(const AllocParams& p)
-  {
-    if (CachingAllocatorConfig::max_split_size() == std::numeric_limits<size_t>::max())
-      return false;
-    BlockPool& pool = *p.pool;
-    Block key = p.search_key;
-    key.size = (key.size < CachingAllocatorConfig::max_split_size()) ? CachingAllocatorConfig::max_split_size() : key.size;
-    auto it = pool.blocks.lower_bound(&key);
-    if (it == pool.blocks.end() || (*it)->stream != p.stream()) {
-      // No single block is large enough; free multiple oversize blocks, starting with the largest
-      if (it == pool.blocks.begin())
-        return false;
-      size_t totalReleased = 0;
-      --it; // Back up one item. Now on the largest block for the correct stream
-      while ((totalReleased < key.size) && ((*it)->size >= CachingAllocatorConfig::max_split_size())
-             && ((*it)->stream == p.stream())) {
-        auto cur = it;
-        totalReleased += (*it)->size;
-        if (it != pool.blocks.begin()) {
-          --it;
-          release_block(*cur);
-        }
-        else {
-          release_block(*cur);
-          break;
-        }
-      }
-      if (totalReleased < key.size)
-        return false;
-    }
-    else {
-      release_block(*it);
-    }
-    return true;
-  }
-
-  bool release_cached_blocks()
+  bool free_cached_blocks()
   {
     // First ensure that all blocks that can't currently be allocated due to
     // outstanding events are returned to the pool.
     synchronize_and_free_events();
 
-    // Free all non-split cached blocks to system allocator
-    release_blocks(large_blocks);
-    release_blocks(small_blocks);
+    // Free all non-split cached blocks
+    free_blocks(large_blocks);
+    free_blocks(small_blocks);
 
     for (auto it = graph_pools_freeable.begin(); it != graph_pools_freeable.end(); ) {
       // See notifyCaptureDestroy for the strategy here.
       TORCH_INTERNAL_ASSERT(it->second->use_count == 0);
-      release_blocks(it->second->small_blocks);
-      release_blocks(it->second->large_blocks);
+      free_blocks(it->second->small_blocks);
+      free_blocks(it->second->large_blocks);
       if (it->second->cudaMalloc_count == 0) {
         auto erase_count = graph_pools.erase(it->first);
         TORCH_INTERNAL_ASSERT(erase_count == 1);
@@ -1028,39 +901,35 @@ class DeviceCachingAllocator {
     return true;
   }
 
-  void release_block(Block *block)
-  {
-    C10_CUDA_CHECK(cudaFree((void*)block->ptr));
-    total_allocated_memory -= block->size;
-
-    StatTypes stat_types;
-    stat_types[static_cast<size_t>(StatType::AGGREGATE)] = true;
-    stat_types[static_cast<size_t>(get_stat_type_for_pool(*(block->pool)))] = true;
-    update_stat_array(stats.segment, -1, stat_types);
-    update_stat_array(stats.reserved_bytes, -block->size, stat_types);
-    if (block->size >= CachingAllocatorConfig::max_split_size())
-      update_stat(stats.oversize_segments, -1);
-
-    block->pool->blocks.erase(block);
-    delete block;
-  }
-
-  void release_blocks(BlockPool& pool)
+  void free_blocks(BlockPool& pool)
   {
     // Frees all non-split blocks
     auto it = pool.blocks.begin();
     while (it != pool.blocks.end()) {
       Block* block = *it;
       if (!block->prev && !block->next) {
-        release_block(block);
+        C10_CUDA_CHECK(cudaFree((void*)block->ptr));
+        total_allocated_memory -= block->size;
 
         if (pool.owner_PrivatePool) {
           // The cudaFreed block belonged to a CUDA graph's PrivatePool.
           TORCH_INTERNAL_ASSERT(pool.owner_PrivatePool->cudaMalloc_count > 0);
           pool.owner_PrivatePool->cudaMalloc_count--;
         }
+
+        StatTypes stat_types;
+        stat_types[static_cast<size_t>(StatType::AGGREGATE)] = true;
+        stat_types[static_cast<size_t>(get_stat_type_for_pool(pool))] = true;
+        update_stat_array(stats.segment, -1, stat_types);
+        update_stat_array(stats.reserved_bytes, -block->size, stat_types);
+
+        auto cur = it;
+        ++it;
+        pool.blocks.erase(cur);
+        delete block;
+      } else {
+        ++it;
       }
-      ++it;
     }
   }

View File

@@ -85,15 +85,6 @@ struct DeviceStats {
   // COUNT: total number of OOMs (i.e. failed calls to CUDA after cache flush)
   int64_t num_ooms = 0;
-
-  // COUNT: total number of oversize blocks allocated from pool
-  Stat oversize_allocations;
-  // COUNT: total number of oversize blocks requiring malloc
-  Stat oversize_segments;
-
-  // SIZE: maximum block size that is allowed to be split.
-  int64_t max_split_size = 0;
 };
 
 // Struct containing info of an allocation block (i.e. a fractional part of a cudaMalloc)..

View File

@@ -260,21 +260,6 @@ Use of a caching allocator can interfere with memory checking tools such as
 ``cuda-memcheck``. To debug memory errors using ``cuda-memcheck``, set
 ``PYTORCH_NO_CUDA_MEMORY_CACHING=1`` in your environment to disable caching.
 
-The behavior of the caching allocator can be controlled via the environment variable
-``PYTORCH_CUDA_ALLOC_CONF``.
-The format is ``PYTORCH_CUDA_ALLOC_CONF=<option>:<value>,<option2>:<value2>...``
-
-Available options:
-
-* ``max_split_size_mb`` prevents the allocator from splitting blocks larger
-  than this size (in MB). This can help prevent fragmentation and may allow
-  some borderline workloads to complete without running out of memory.
-  Performance cost can range from 'zero' to 'substantial' depending on
-  allocation patterns. Default value is unlimited, i.e. all blocks can be
-  split. The :meth:`~torch.cuda.memory_stats` and
-  :meth:`~torch.cuda.memory_summary` methods are useful for tuning. This
-  option should be used as a last resort for a workload that is aborting
-  due to 'out of memory' and showing a large amount of inactive split blocks.
-
 .. _cufft-plan-cache:
 
 cuFFT plan cache
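For reference, here is a minimal sketch of how the removed option was meant to be exercised on a build that still contains D23752058 (i.e. before this revert); the value 256 and the workload are arbitrary examples, not part of this diff.

# Sketch only: honored only on builds that still contain D23752058,
# and the variable must be set before the first CUDA allocation.
import os
os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "max_split_size_mb:256")

import torch

x = torch.empty(64 * 1024 * 1024, device="cuda")  # any allocation-heavy workload
stats = torch.cuda.memory_stats()
print("reserved :", stats["reserved_bytes.all.current"])
print("allocated:", stats["allocated_bytes.all.current"])
print(torch.cuda.memory_summary(abbreviated=True))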

View File

@@ -356,7 +356,6 @@ PyObject * THCPModule_memoryStats(PyObject *_unused, PyObject *arg)
   py::dict result;
   result["num_alloc_retries"] = stats.num_alloc_retries;
   result["num_ooms"] = stats.num_ooms;
-  result["max_split_size"] = stats.max_split_size;
   result["allocation"] = statArrayToDict(stats.allocation);
   result["segment"] = statArrayToDict(stats.segment);
   result["active"] = statArrayToDict(stats.active);
@@ -365,8 +364,6 @@ PyObject * THCPModule_memoryStats(PyObject *_unused, PyObject *arg)
   result["reserved_bytes"] = statArrayToDict(stats.reserved_bytes);
   result["active_bytes"] = statArrayToDict(stats.active_bytes);
   result["inactive_split_bytes"] = statArrayToDict(stats.inactive_split_bytes);
-  result["oversize_allocations"] = statToDict(stats.oversize_allocations);
-  result["oversize_segments"] = statToDict(stats.oversize_segments);
 
   return result.release().ptr();
   END_HANDLE_TH_ERRORS

View File

@@ -164,17 +164,6 @@ def memory_stats(device: Union[Device, int] = None) -> Dict[str, Any]:
       result in a cache flush and retry.
     - ``"num_ooms"``: number of out-of-memory errors thrown.
 
-    The caching allocator can be configured via ENV to not split blocks larger than a
-    defined size (see the Memory Management section of the CUDA semantics documentation).
-    This helps avoid memory fragmentation but may have a performance
-    penalty. Additional outputs to assist with tuning and evaluating impact:
-
-    - ``"max_split_size"``: blocks above this size will not be split.
-    - ``"oversize_allocations.{current,peak,allocated,freed}"``:
-      number of over-size allocation requests received by the memory allocator.
-    - ``"oversize_segments.{current,peak,allocated,freed}"``:
-      number of over-size reserved segments from ``cudaMalloc()``.
-
     Args:
         device (torch.device or int, optional): selected device. Returns
             statistics for the current device, given by :func:`~torch.cuda.current_device`,
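The keys documented in the removed docstring exist only on builds that predate this revert; the defensive read below is a sketch (not part of the diff) that works on either build.

# Sketch: read the stats named above, tolerating their absence after this revert.
import torch

stats = torch.cuda.memory_stats()
for key in ("max_split_size",
            "oversize_allocations.current",
            "oversize_segments.current"):
    print(f"{key}: {stats.get(key, 'n/a (removed by this revert)')}")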
@@ -501,29 +490,6 @@ def memory_summary(device: Union[Device, int] = None, abbreviated: bool = False)
                 formatter(freed, freed_prefval)),
             )
 
-    metrics_to_display = [
-        ("oversize_allocations", "Oversize allocations", _format_count),
-        ("oversize_segments", "Oversize GPU segments", _format_count),
-    ]
-
-    for metric_key, metric_name, formatter in metrics_to_display:
-        lines.append("-" * 75)
-
-        prefix = metric_key + "."
-
-        current = stats[prefix + "current"]
-        peak = stats[prefix + "peak"]
-        allocated = stats[prefix + "allocated"]
-        freed = stats[prefix + "freed"]
-
-        lines.append(" {:<21} | {} | {} | {} | {} ".format(
-            metric_name,
-            formatter(current, current),
-            formatter(peak, peak),
-            formatter(allocated, allocated),
-            formatter(freed, freed)),
-        )
-
     lines.append("=" * 75)
     fmt_dict = {"_": "", "device": device}

View File

@@ -34,7 +34,6 @@ SystemEnv = namedtuple('SystemEnv', [
     'hip_compiled_version',
     'hip_runtime_version',
     'miopen_runtime_version',
-    'caching_allocator_config',
 ])
@@ -273,11 +272,6 @@ def get_pip_packages(run_lambda):
     return 'pip3', out3
 
-
-def get_cachingallocator_config():
-    ca_config = os.environ.get('PYTORCH_CUDA_ALLOC_CONF', '')
-    return ca_config
-
 
 def get_env_info():
     run_lambda = run
     pip_version, pip_list_output = get_pip_packages(run_lambda)
@@ -319,7 +313,6 @@ def get_env_info():
         gcc_version=get_gcc_version(run_lambda),
         clang_version=get_clang_version(run_lambda),
         cmake_version=get_cmake_version(run_lambda),
-        caching_allocator_config=get_cachingallocator_config(),
     )
 
 env_info_fmt = """