mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
Deprecate overlapped functions in CUDAAllocatorConfig (#165289)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165289 Approved by: https://github.com/albanD ghstack dependencies: #165288
This commit is contained in:
committed by
PyTorch MergeBot
parent
4888ed440e
commit
a1114beed2
@ -183,11 +183,6 @@ struct CUDACachingHostAllocatorImpl
|
||||
return true;
|
||||
}
|
||||
|
||||
bool pinned_use_background_threads() override {
|
||||
return c10::cuda::CUDACachingAllocator::CUDAAllocatorConfig::
|
||||
pinned_use_background_threads();
|
||||
}
|
||||
|
||||
EventPool::Event create_event_internal(DeviceIndex idx) {
|
||||
// Leak the event pool to avoid shutdown issue.
|
||||
static auto* event_pool = new EventPool();
|
||||
|
@ -3,6 +3,7 @@
|
||||
#include <c10/core/AllocatorConfig.h>
|
||||
#include <c10/cuda/CUDAException.h>
|
||||
#include <c10/cuda/CUDAMacros.h>
|
||||
#include <c10/util/Deprecated.h>
|
||||
#include <c10/util/Exception.h>
|
||||
#include <c10/util/env.h>
|
||||
|
||||
@ -17,9 +18,14 @@ enum class Expandable_Segments_Handle_Type : int {
|
||||
// Environment config parser
|
||||
class C10_CUDA_API CUDAAllocatorConfig {
|
||||
public:
|
||||
C10_DEPRECATED_MESSAGE(
|
||||
"c10::cuda::CUDACachingAllocator::CUDAAllocatorConfig::max_split_size() is deprecated. Please use c10::CachingAllocator::AcceleratorAllocatorConfig::max_split_size() instead.")
|
||||
static size_t max_split_size() {
|
||||
return c10::CachingAllocator::AcceleratorAllocatorConfig::max_split_size();
|
||||
}
|
||||
|
||||
C10_DEPRECATED_MESSAGE(
|
||||
"c10::cuda::CUDACachingAllocator::CUDAAllocatorConfig::garbage_collection_threshold() is deprecated. Please use c10::CachingAllocator::AcceleratorAllocatorConfig::garbage_collection_threshold() instead.")
|
||||
static double garbage_collection_threshold() {
|
||||
return c10::CachingAllocator::AcceleratorAllocatorConfig::
|
||||
garbage_collection_threshold();
|
||||
@ -64,6 +70,8 @@ class C10_CUDA_API CUDAAllocatorConfig {
|
||||
return instance().m_pinned_num_register_threads;
|
||||
}
|
||||
|
||||
C10_DEPRECATED_MESSAGE(
|
||||
"c10::cuda::CUDACachingAllocator::CUDAAllocatorConfig::pinned_use_background_threads() is deprecated. Please use c10::CachingAllocator::AcceleratorAllocatorConfig::pinned_use_background_threads() instead.")
|
||||
static bool pinned_use_background_threads() {
|
||||
return c10::CachingAllocator::AcceleratorAllocatorConfig::
|
||||
pinned_use_background_threads();
|
||||
@ -80,11 +88,15 @@ class C10_CUDA_API CUDAAllocatorConfig {
|
||||
return 128;
|
||||
}
|
||||
|
||||
C10_DEPRECATED_MESSAGE(
|
||||
"c10::cuda::CUDACachingAllocator::CUDAAllocatorConfig::roundup_power2_divisions() is deprecated. Please use c10::CachingAllocator::AcceleratorAllocatorConfig::roundup_power2_divisions() instead.")
|
||||
static size_t roundup_power2_divisions(size_t size) {
|
||||
return c10::CachingAllocator::AcceleratorAllocatorConfig::
|
||||
roundup_power2_divisions(size);
|
||||
}
|
||||
|
||||
C10_DEPRECATED_MESSAGE(
|
||||
"c10::cuda::CUDACachingAllocator::CUDAAllocatorConfig::roundup_power2_divisions() is deprecated. Please use c10::CachingAllocator::AcceleratorAllocatorConfig::roundup_power2_divisions() instead.")
|
||||
static std::vector<size_t> roundup_power2_divisions() {
|
||||
return c10::CachingAllocator::AcceleratorAllocatorConfig::
|
||||
roundup_power2_divisions();
|
||||
@ -95,6 +107,8 @@ class C10_CUDA_API CUDAAllocatorConfig {
|
||||
max_non_split_rounding_size();
|
||||
}
|
||||
|
||||
C10_DEPRECATED_MESSAGE(
|
||||
"c10::cuda::CUDACachingAllocator::CUDAAllocatorConfig::last_allocator_settings() is deprecated. Please use c10::CachingAllocator::AcceleratorAllocatorConfig::last_allocator_settings() instead.")
|
||||
static std::string last_allocator_settings() {
|
||||
return c10::CachingAllocator::getAllocatorSettings();
|
||||
}
|
||||
|
@ -1270,7 +1270,7 @@ class DeviceCachingAllocator {
|
||||
large_blocks(/*small=*/false),
|
||||
small_blocks(/*small=*/true) {
|
||||
stats.max_split_size =
|
||||
static_cast<int64_t>(CUDAAllocatorConfig::max_split_size());
|
||||
static_cast<int64_t>(AcceleratorAllocatorConfig::max_split_size());
|
||||
context_recorder_.store(nullptr);
|
||||
}
|
||||
|
||||
@ -1405,7 +1405,8 @@ class DeviceCachingAllocator {
|
||||
// Do garbage collection if the flag is set.
|
||||
if (C10_UNLIKELY(
|
||||
set_fraction &&
|
||||
CUDAAllocatorConfig::garbage_collection_threshold() > 0.0)) {
|
||||
AcceleratorAllocatorConfig::garbage_collection_threshold() >
|
||||
0.0)) {
|
||||
garbage_collect_cached_blocks(context);
|
||||
}
|
||||
// Attempt allocate
|
||||
@ -1657,7 +1658,7 @@ class DeviceCachingAllocator {
|
||||
stats.active_bytes[stat_type].increase(block->size);
|
||||
stats.requested_bytes[stat_type].increase(block->requested_size);
|
||||
});
|
||||
if (block->size >= CUDAAllocatorConfig::max_split_size())
|
||||
if (block->size >= AcceleratorAllocatorConfig::max_split_size())
|
||||
stats.oversize_allocations.increase(1);
|
||||
|
||||
auto allocated_bytes_gauge =
|
||||
@ -1926,7 +1927,7 @@ class DeviceCachingAllocator {
|
||||
block->pool->owner_MempoolId(),
|
||||
context ? context : block->context_when_allocated);
|
||||
|
||||
if (block->size >= CUDAAllocatorConfig::max_split_size())
|
||||
if (block->size >= AcceleratorAllocatorConfig::max_split_size())
|
||||
stats.oversize_allocations.decrease(1);
|
||||
|
||||
// If the block has been used on more than one stream, handle accordingly.
|
||||
@ -2499,7 +2500,8 @@ class DeviceCachingAllocator {
|
||||
if (size < kMinBlockSize) {
|
||||
return kMinBlockSize;
|
||||
} else {
|
||||
auto divisions = CUDAAllocatorConfig::roundup_power2_divisions(size);
|
||||
auto divisions =
|
||||
AcceleratorAllocatorConfig::roundup_power2_divisions(size);
|
||||
if (divisions > 1 && size > (kMinBlockSize * divisions)) {
|
||||
return roundup_power2_next_division(size, divisions);
|
||||
} else {
|
||||
@ -2993,7 +2995,7 @@ class DeviceCachingAllocator {
|
||||
if (block->pool->is_small || CUDAAllocatorConfig::expandable_segments()) {
|
||||
return remaining >= kMinBlockSize;
|
||||
} else {
|
||||
return (size < CUDAAllocatorConfig::max_split_size()) &&
|
||||
return (size < AcceleratorAllocatorConfig::max_split_size()) &&
|
||||
(remaining > kSmallSize);
|
||||
}
|
||||
}
|
||||
@ -3013,7 +3015,7 @@ class DeviceCachingAllocator {
|
||||
|
||||
if (C10_UNLIKELY(
|
||||
set_fraction &&
|
||||
CUDAAllocatorConfig::garbage_collection_threshold() > 0.0)) {
|
||||
AcceleratorAllocatorConfig::garbage_collection_threshold() > 0.0)) {
|
||||
// Track block reuse interval only when garbage collection is enabled.
|
||||
++pool.get_free_blocks_call_count;
|
||||
}
|
||||
@ -3055,13 +3057,13 @@ class DeviceCachingAllocator {
|
||||
}
|
||||
|
||||
// Do not return an oversized block for a large request
|
||||
if ((p.size() < CUDAAllocatorConfig::max_split_size()) &&
|
||||
((*it)->size >= CUDAAllocatorConfig::max_split_size()))
|
||||
if ((p.size() < AcceleratorAllocatorConfig::max_split_size()) &&
|
||||
((*it)->size >= AcceleratorAllocatorConfig::max_split_size()))
|
||||
return false;
|
||||
// Allow oversized block size to be rounded up but within a limit
|
||||
if ((p.size() >= CUDAAllocatorConfig::max_split_size()) &&
|
||||
if ((p.size() >= AcceleratorAllocatorConfig::max_split_size()) &&
|
||||
((*it)->size >=
|
||||
p.size() + CUDAAllocatorConfig::max_non_split_rounding_size()))
|
||||
p.size() + AcceleratorAllocatorConfig::max_non_split_rounding_size()))
|
||||
return false;
|
||||
p.block = *it;
|
||||
pool.blocks.erase(it);
|
||||
@ -3084,7 +3086,7 @@ class DeviceCachingAllocator {
|
||||
// therefore should be of less overheads.
|
||||
|
||||
size_t gc_threshold = static_cast<size_t>(
|
||||
CUDAAllocatorConfig::garbage_collection_threshold() *
|
||||
AcceleratorAllocatorConfig::garbage_collection_threshold() *
|
||||
static_cast<double>(allowed_memory_maximum));
|
||||
// No need to trigger GC yet
|
||||
if (total_allocated_memory <= gc_threshold) {
|
||||
@ -3232,7 +3234,7 @@ class DeviceCachingAllocator {
|
||||
stats.segment[stat_type].increase(1);
|
||||
stats.reserved_bytes[stat_type].increase(size);
|
||||
});
|
||||
if (size >= CUDAAllocatorConfig::max_split_size())
|
||||
if (size >= AcceleratorAllocatorConfig::max_split_size())
|
||||
stats.oversize_segments.increase(1);
|
||||
auto reserved_bytes_gauge =
|
||||
STATIC_GAUGE(pytorch.CUDACachingAllocator.reserved_bytes);
|
||||
@ -3261,7 +3263,7 @@ class DeviceCachingAllocator {
|
||||
bool release_available_cached_blocks(
|
||||
const AllocParams& p,
|
||||
const std::shared_ptr<GatheredContext>& context) {
|
||||
if (CUDAAllocatorConfig::max_split_size() ==
|
||||
if (AcceleratorAllocatorConfig::max_split_size() ==
|
||||
std::numeric_limits<size_t>::max())
|
||||
return false;
|
||||
BlockPool& pool = *p.pool;
|
||||
@ -3269,8 +3271,8 @@ class DeviceCachingAllocator {
|
||||
// because of std::unique_ptr, block cannot be trivially copied
|
||||
// Use constructor for search key.
|
||||
Block key(p.search_key.device, p.search_key.stream, p.search_key.size);
|
||||
key.size = (key.size < CUDAAllocatorConfig::max_split_size())
|
||||
? CUDAAllocatorConfig::max_split_size()
|
||||
key.size = (key.size < AcceleratorAllocatorConfig::max_split_size())
|
||||
? AcceleratorAllocatorConfig::max_split_size()
|
||||
: key.size;
|
||||
auto it = pool.blocks.lower_bound(&key);
|
||||
if (it == pool.blocks.end() || (*it)->stream != p.stream() ||
|
||||
@ -3283,7 +3285,7 @@ class DeviceCachingAllocator {
|
||||
--it; // Back up one item. Now on the largest block for the correct
|
||||
// stream
|
||||
while ((totalReleased < key.size) &&
|
||||
((*it)->size >= CUDAAllocatorConfig::max_split_size()) &&
|
||||
((*it)->size >= AcceleratorAllocatorConfig::max_split_size()) &&
|
||||
((*it)->stream == p.stream())) {
|
||||
auto cur = it;
|
||||
bool is_first = cur == pool.blocks.begin();
|
||||
@ -3408,7 +3410,7 @@ class DeviceCachingAllocator {
|
||||
stats.reserved_bytes[static_cast<int64_t>(StatType::AGGREGATE)]
|
||||
.current);
|
||||
|
||||
if (block->size >= CUDAAllocatorConfig::max_split_size())
|
||||
if (block->size >= AcceleratorAllocatorConfig::max_split_size())
|
||||
stats.oversize_segments.decrease(1);
|
||||
pool->blocks.erase(block);
|
||||
delete block;
|
||||
@ -4059,8 +4061,8 @@ class NativeCachingAllocator : public CUDAAllocator {
|
||||
|
||||
auto& md = result.config_metadata;
|
||||
md.garbage_collection_threshold =
|
||||
CUDAAllocatorConfig::garbage_collection_threshold();
|
||||
md.max_split_size = CUDAAllocatorConfig::max_split_size();
|
||||
AcceleratorAllocatorConfig::garbage_collection_threshold();
|
||||
md.max_split_size = AcceleratorAllocatorConfig::max_split_size();
|
||||
md.pinned_num_register_threads =
|
||||
CUDAAllocatorConfig::pinned_num_register_threads();
|
||||
md.expandable_segments = CUDAAllocatorConfig::expandable_segments();
|
||||
@ -4068,11 +4070,12 @@ class NativeCachingAllocator : public CUDAAllocator {
|
||||
CUDAAllocatorConfig::release_lock_on_cudamalloc();
|
||||
md.pinned_use_host_register =
|
||||
CUDAAllocatorConfig::pinned_use_cuda_host_register();
|
||||
md.last_allocator_settings = CUDAAllocatorConfig::last_allocator_settings();
|
||||
md.last_allocator_settings =
|
||||
AcceleratorAllocatorConfig::last_allocator_settings();
|
||||
md.graph_capture_record_stream_reuse =
|
||||
CUDAAllocatorConfig::graph_capture_record_stream_reuse();
|
||||
md.roundup_power2_divisions =
|
||||
CUDAAllocatorConfig::roundup_power2_divisions();
|
||||
AcceleratorAllocatorConfig::roundup_power2_divisions();
|
||||
|
||||
return result;
|
||||
}
|
||||
|
Reference in New Issue
Block a user