mirror of https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[PyTorch] Rework stat collection in CUDACachingAllocator (#71669)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/71669

Stat collection here was relatively inefficient: we looped over every stat type once for each stat we wanted to update. Now one loop over the selected stat types covers all the stats.

ghstack-source-id: 148013645

Reviewed By: ngimel

Differential Revision: D33725458

fbshipit-source-id: 39ef5d65a73d4ef67f259de8c02c7df29487d990
(cherry picked from commit 7ca46689b72ba7611517447a292445571bd02dd7)
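The pattern change is easiest to see in isolation. Below is a minimal, self-contained C++ sketch of the before/after; the Stat, StatArray, StatTypes, update_stat, and for_each_selected_stat_type names mirror the diff, while the simplified bodies and the three-type count are illustrative assumptions rather than the allocator's exact code.

#include <array>
#include <cstddef>
#include <cstdint>

constexpr size_t kNumStatTypes = 3; // stand-in for StatType::NUM_TYPES

struct Stat {
  int64_t current = 0;
  int64_t peak = 0;
};
using StatArray = std::array<Stat, kNumStatTypes>;
using StatTypes = std::array<bool, kNumStatTypes>;

void update_stat(Stat& stat, int64_t amount) {
  stat.current += amount;
  if (stat.current > stat.peak)
    stat.peak = stat.current;
}

// Before: every stat array paid for its own scan of the stat-type mask,
// so updating four arrays meant four loops and four rounds of branching.
void update_stat_array(StatArray& a, int64_t amount, const StatTypes& types) {
  for (size_t t = 0; t < types.size(); ++t) {
    if (types[t])
      update_stat(a[t], amount);
  }
}

// After: scan the selected stat types once and update every stat for that
// type inside a single loop body.
template <typename Func>
void for_each_selected_stat_type(const StatTypes& types, Func f) {
  for (size_t t = 0; t < types.size(); ++t) {
    if (types[t])
      f(t);
  }
}

void on_alloc(StatArray& allocation, StatArray& allocated_bytes,
              const StatTypes& types, int64_t size) {
  for_each_selected_stat_type(types, [&](size_t t) {
    update_stat(allocation[t], 1);
    update_stat(allocated_bytes[t], size);
  });
}

A call site that updates four stat arrays now pays for one scan of the mask instead of four, and the branch on each mask entry is evaluated once per type rather than once per type per array.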
committed by PyTorch MergeBot
parent ca2ff12ea3
commit 4aade95029
@@ -107,12 +107,12 @@ constexpr size_t kMinLargeAlloc =
     10485760; // allocations between 1 and 10 MiB may use kLargeBuffer
 constexpr size_t kRoundLarge = 2097152; // round up large allocations to 2 MiB
 
-typedef std::bitset<static_cast<size_t>(StatType::NUM_TYPES)> StatTypes;
+using StatTypes = std::array<bool, static_cast<size_t>(StatType::NUM_TYPES)>;
 
 void update_stat(Stat& stat, int64_t amount) {
   stat.current += amount;
 
-  TORCH_INTERNAL_ASSERT(
+  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
       stat.current >= 0,
       "Negative tracked stat in CUDA allocator (likely logic error).");
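Note the consequence of replacing std::bitset with std::array<bool, N> above: a default-constructed std::bitset starts with all bits zero, while a local std::array<bool, N> declared without an initializer has indeterminate elements. That is why the hunks below add `= {false}` to every StatTypes declaration. A standalone illustration (not allocator code):

#include <array>
#include <bitset>

void init_demo() {
  std::bitset<3> b;                 // default-constructed: all bits are zero
  std::array<bool, 3> a1;           // local aggregate: elements are indeterminate
  std::array<bool, 3> a2 = {false}; // first element false, the rest value-initialized to false
  (void)b; (void)a1; (void)a2;      // silence unused-variable warnings
}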
@@ -134,15 +134,23 @@ void reset_peak_stat(Stat& stat) {
   stat.peak = stat.current;
 }
 
+template <typename Func>
+void for_each_selected_stat_type(const StatTypes& stat_types, Func f) {
+  for (const auto stat_type : c10::irange(stat_types.size())) {
+    if (stat_types[stat_type]) {
+      f(stat_type);
+    }
+  }
+}
+
 void update_stat_array(
     StatArray& stat_array,
     int64_t amount,
     const StatTypes& stat_types) {
-  for (const auto stat_type : c10::irange(stat_types.size())) {
-    if (stat_types[stat_type]) {
-      update_stat(stat_array[stat_type], amount);
-    }
-  }
+  for_each_selected_stat_type(
+      stat_types, [&stat_array, amount](size_t stat_type) {
+        update_stat(stat_array[stat_type], amount);
+      });
 }
 
 struct Block;
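The hunks that follow all repeat one calling pattern: build a StatTypes mask with the aggregate slot and the pool-specific slot set, then make a single pass over the selected types. A condensed sketch of that pattern, where the StatType enumerator values are assumptions inferred from the names appearing in the diff:

#include <array>
#include <cstddef>

enum class StatType : size_t { AGGREGATE = 0, SMALL_POOL = 1, LARGE_POOL = 2, NUM_TYPES = 3 };
using StatTypes = std::array<bool, static_cast<size_t>(StatType::NUM_TYPES)>;

template <typename Func>
void for_each_selected_stat_type(const StatTypes& stat_types, Func f) {
  for (size_t t = 0; t < stat_types.size(); ++t) {
    if (stat_types[t])
      f(t);
  }
}

void demo() {
  StatTypes stat_types = {false}; // explicit init now that StatTypes is a std::array
  stat_types[static_cast<size_t>(StatType::AGGREGATE)] = true;
  stat_types[static_cast<size_t>(StatType::LARGE_POOL)] = true; // stands in for get_stat_type_for_pool(pool)
  for_each_selected_stat_type(stat_types, [](size_t stat_type) {
    (void)stat_type; // update every stat that applies to this type in one pass
  });
}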
@@ -264,7 +272,7 @@ struct AllocParams {
   BlockPool* pool;
   size_t alloc_size;
   Block* block;
-  StatTypes stat_types;
+  StatTypes stat_types = {false};
   cudaError_t err;
 };
@@ -563,25 +571,29 @@ class DeviceCachingAllocator {
       } else {
         // A new split inactive block is being created from a previously unsplit
        // block, size remaining->size bytes.
-        update_stat_array(
-            stats.inactive_split_bytes, remaining->size, params.stat_types);
-        update_stat_array(stats.inactive_split, 1, params.stat_types);
+        for_each_selected_stat_type(params.stat_types, [&](size_t stat_type) {
+          update_stat(stats.inactive_split_bytes[stat_type], remaining->size);
+          update_stat(stats.inactive_split[stat_type], 1);
+        });
       }
     } else if (already_split) {
       // An already-split block is becoming active
-      update_stat_array(
-          stats.inactive_split_bytes, -block->size, params.stat_types);
-      update_stat_array(stats.inactive_split, -1, params.stat_types);
+      for_each_selected_stat_type(params.stat_types, [&](size_t stat_type) {
+        update_stat(stats.inactive_split_bytes[stat_type], -block->size);
+        update_stat(stats.inactive_split[stat_type], -1);
+      });
     }
 
     block->allocated = true;
     bool inserted = active_blocks.insert(block).second;
     TORCH_INTERNAL_ASSERT_DEBUG_ONLY(inserted);
 
-    update_stat_array(stats.allocation, 1, params.stat_types);
-    update_stat_array(stats.allocated_bytes, block->size, params.stat_types);
-    update_stat_array(stats.active, 1, params.stat_types);
-    update_stat_array(stats.active_bytes, block->size, params.stat_types);
+    for_each_selected_stat_type(params.stat_types, [&](size_t stat_type) {
+      update_stat(stats.allocation[stat_type], 1);
+      update_stat(stats.allocated_bytes[stat_type], block->size);
+      update_stat(stats.active[stat_type], 1);
+      update_stat(stats.active_bytes[stat_type], block->size);
+    });
     if (block->size >= CachingAllocatorConfig::max_split_size())
       update_stat(stats.oversize_allocations, 1);
@@ -605,12 +617,14 @@ class DeviceCachingAllocator {
     auto orig_block_ptr = block->ptr;
     auto orig_block_size = block->size;
 
-    StatTypes stat_types;
+    StatTypes stat_types = {false};
     stat_types[static_cast<size_t>(StatType::AGGREGATE)] = true;
     stat_types[static_cast<size_t>(get_stat_type_for_pool(*(block->pool)))] =
         true;
-    update_stat_array(stats.allocation, -1, {stat_types});
-    update_stat_array(stats.allocated_bytes, -block->size, {stat_types});
+    for_each_selected_stat_type(stat_types, [&](size_t stat_type) {
+      update_stat(stats.allocation[stat_type], -1);
+      update_stat(stats.allocated_bytes[stat_type], -block->size);
+    });
     if (block->size >= CachingAllocatorConfig::max_split_size())
       update_stat(stats.oversize_allocations, -1);
@@ -916,15 +930,18 @@ class DeviceCachingAllocator {
       net_change_inactive_split_size += block->size;
     }
 
-    StatTypes stat_types;
+    StatTypes stat_types = {false};
     stat_types[static_cast<size_t>(StatType::AGGREGATE)] = true;
     stat_types[static_cast<size_t>(get_stat_type_for_pool(pool))] = true;
-    update_stat_array(
-        stats.inactive_split, net_change_inactive_split_blocks, stat_types);
-    update_stat_array(
-        stats.inactive_split_bytes, net_change_inactive_split_size, stat_types);
-    update_stat_array(stats.active, -1, stat_types);
-    update_stat_array(stats.active_bytes, -original_block_size, stat_types);
+    for_each_selected_stat_type(stat_types, [&](size_t stat_type) {
+      update_stat(
+          stats.inactive_split[stat_type], net_change_inactive_split_blocks);
+      update_stat(
+          stats.inactive_split_bytes[stat_type],
+          net_change_inactive_split_size);
+      update_stat(stats.active[stat_type], -1);
+      update_stat(stats.active_bytes[stat_type], -original_block_size);
+    });
   }
 
   /** combine previously split blocks. returns the size of the subsumed block,
@@ -1089,8 +1106,10 @@ class DeviceCachingAllocator {
 
     total_allocated_memory += size;
     p.block = new Block(p.device(), p.stream(), size, p.pool, (char*)ptr);
-    update_stat_array(stats.segment, 1, p.stat_types);
-    update_stat_array(stats.reserved_bytes, size, p.stat_types);
+    for_each_selected_stat_type(p.stat_types, [&](size_t stat_type) {
+      update_stat(stats.segment[stat_type], 1);
+      update_stat(stats.reserved_bytes[stat_type], size);
+    });
     if (size >= CachingAllocatorConfig::max_split_size())
       update_stat(stats.oversize_segments, 1);
@@ -1179,11 +1198,13 @@ class DeviceCachingAllocator {
       pool->owner_PrivatePool->cudaMalloc_count--;
     }
 
-    StatTypes stat_types;
+    StatTypes stat_types = {false};
     stat_types[static_cast<size_t>(StatType::AGGREGATE)] = true;
     stat_types[static_cast<size_t>(get_stat_type_for_pool(*pool))] = true;
-    update_stat_array(stats.segment, -1, stat_types);
-    update_stat_array(stats.reserved_bytes, -block->size, stat_types);
+    for_each_selected_stat_type(stat_types, [&](size_t stat_type) {
+      update_stat(stats.segment[stat_type], -1);
+      update_stat(stats.reserved_bytes[stat_type], -block->size);
+    });
     if (block->size >= CachingAllocatorConfig::max_split_size())
       update_stat(stats.oversize_segments, -1);