[PyTorch] Rework stat collection in CUDACachingAllocator (#71669)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/71669

Stat collection was relatively inefficient: rather than looping over the stat types once for each stat we want to update, we now do a single loop that updates all the stats for each selected type (a simplified sketch of the new pattern precedes the diff below).
ghstack-source-id: 148013645

Reviewed By: ngimel

Differential Revision: D33725458

fbshipit-source-id: 39ef5d65a73d4ef67f259de8c02c7df29487d990
(cherry picked from commit 7ca46689b72ba7611517447a292445571bd02dd7)
Author: Scott Wolchok
Date: 2022-02-01 09:13:08 -08:00
Committed by: PyTorch MergeBot
Parent: ca2ff12ea3
Commit: 4aade95029
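
For orientation before reading the diff, here is a minimal, self-contained sketch of the pattern this change introduces. The StatType/Stat/StatArray definitions below are simplified stand-ins rather than the allocator's real declarations (the pool categories are only illustrative), and main() is a hypothetical usage; what mirrors the actual change is the shape of for_each_selected_stat_type and the single-loop call site.

#include <algorithm>
#include <array>
#include <cstddef>
#include <cstdint>
#include <iostream>

// Simplified stand-ins for the allocator's stat bookkeeping types.
enum class StatType : uint64_t { AGGREGATE = 0, SMALL_POOL = 1, LARGE_POOL = 2, NUM_TYPES = 3 };

struct Stat {
  int64_t current = 0;
  int64_t peak = 0;
};

using StatArray = std::array<Stat, static_cast<size_t>(StatType::NUM_TYPES)>;
// The change: a plain bool-per-type array instead of std::bitset.
using StatTypes = std::array<bool, static_cast<size_t>(StatType::NUM_TYPES)>;

void update_stat(Stat& stat, int64_t amount) {
  stat.current += amount;
  stat.peak = std::max(stat.peak, stat.current);
}

// Invoke f(stat_type) once for every selected stat type (a plain loop here
// instead of c10::irange, to keep the sketch dependency-free).
template <typename Func>
void for_each_selected_stat_type(const StatTypes& stat_types, Func f) {
  for (size_t stat_type = 0; stat_type < stat_types.size(); ++stat_type) {
    if (stat_types[stat_type]) {
      f(stat_type);
    }
  }
}

int main() {
  StatArray allocation{};       // count of live allocations, per stat type
  StatArray allocated_bytes{};  // bytes in live allocations, per stat type

  // Unlike std::bitset, a local std::array<bool, N> is not zeroed by default,
  // hence the explicit "= {false}" added at every declaration site in the diff.
  StatTypes stat_types = {false};
  stat_types[static_cast<size_t>(StatType::AGGREGATE)] = true;
  stat_types[static_cast<size_t>(StatType::LARGE_POOL)] = true;

  // Old shape: one update_stat_array() call (and thus one loop over the
  // selected types) per stat. New shape: one loop, all stats updated inside it.
  const int64_t bytes = 2 * 1024 * 1024;
  for_each_selected_stat_type(stat_types, [&](size_t stat_type) {
    update_stat(allocation[stat_type], 1);
    update_stat(allocated_bytes[stat_type], bytes);
  });

  std::cout << allocated_bytes[static_cast<size_t>(StatType::AGGREGATE)].current
            << " bytes tracked for the aggregate stat type\n";
  return 0;
}

The helper also lets each call site decide exactly which stats to touch inside a single pass, which is why the diff replaces every group of update_stat_array calls with one for_each_selected_stat_type lambda.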

@@ -107,12 +107,12 @@ constexpr size_t kMinLargeAlloc =
     10485760; // allocations between 1 and 10 MiB may use kLargeBuffer
 constexpr size_t kRoundLarge = 2097152; // round up large allocations to 2 MiB
-typedef std::bitset<static_cast<size_t>(StatType::NUM_TYPES)> StatTypes;
+using StatTypes = std::array<bool, static_cast<size_t>(StatType::NUM_TYPES)>;
 void update_stat(Stat& stat, int64_t amount) {
   stat.current += amount;
-  TORCH_INTERNAL_ASSERT(
+  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
       stat.current >= 0,
       "Negative tracked stat in CUDA allocator (likely logic error).");
@@ -134,15 +134,23 @@ void reset_peak_stat(Stat& stat) {
   stat.peak = stat.current;
 }
+template <typename Func>
+void for_each_selected_stat_type(const StatTypes& stat_types, Func f) {
+  for (const auto stat_type : c10::irange(stat_types.size())) {
+    if (stat_types[stat_type]) {
+      f(stat_type);
+    }
+  }
+}
 void update_stat_array(
     StatArray& stat_array,
     int64_t amount,
     const StatTypes& stat_types) {
-  for (const auto stat_type : c10::irange(stat_types.size())) {
-    if (stat_types[stat_type]) {
-      update_stat(stat_array[stat_type], amount);
-    }
-  }
+  for_each_selected_stat_type(
+      stat_types, [&stat_array, amount](size_t stat_type) {
+        update_stat(stat_array[stat_type], amount);
+      });
 }
 struct Block;
@@ -264,7 +272,7 @@ struct AllocParams {
   BlockPool* pool;
   size_t alloc_size;
   Block* block;
-  StatTypes stat_types;
+  StatTypes stat_types = {false};
   cudaError_t err;
 };
@@ -563,25 +571,29 @@ class DeviceCachingAllocator {
       } else {
         // A new split inactive block is being created from a previously unsplit
         // block, size remaining->size bytes.
-        update_stat_array(
-            stats.inactive_split_bytes, remaining->size, params.stat_types);
-        update_stat_array(stats.inactive_split, 1, params.stat_types);
+        for_each_selected_stat_type(params.stat_types, [&](size_t stat_type) {
+          update_stat(stats.inactive_split_bytes[stat_type], remaining->size);
+          update_stat(stats.inactive_split[stat_type], 1);
+        });
       }
     } else if (already_split) {
       // An already-split block is becoming active
-      update_stat_array(
-          stats.inactive_split_bytes, -block->size, params.stat_types);
-      update_stat_array(stats.inactive_split, -1, params.stat_types);
+      for_each_selected_stat_type(params.stat_types, [&](size_t stat_type) {
+        update_stat(stats.inactive_split_bytes[stat_type], -block->size);
+        update_stat(stats.inactive_split[stat_type], -1);
+      });
     }
     block->allocated = true;
     bool inserted = active_blocks.insert(block).second;
     TORCH_INTERNAL_ASSERT_DEBUG_ONLY(inserted);
-    update_stat_array(stats.allocation, 1, params.stat_types);
-    update_stat_array(stats.allocated_bytes, block->size, params.stat_types);
-    update_stat_array(stats.active, 1, params.stat_types);
-    update_stat_array(stats.active_bytes, block->size, params.stat_types);
+    for_each_selected_stat_type(params.stat_types, [&](size_t stat_type) {
+      update_stat(stats.allocation[stat_type], 1);
+      update_stat(stats.allocated_bytes[stat_type], block->size);
+      update_stat(stats.active[stat_type], 1);
+      update_stat(stats.active_bytes[stat_type], block->size);
+    });
     if (block->size >= CachingAllocatorConfig::max_split_size())
       update_stat(stats.oversize_allocations, 1);
@@ -605,12 +617,14 @@ class DeviceCachingAllocator {
     auto orig_block_ptr = block->ptr;
     auto orig_block_size = block->size;
-    StatTypes stat_types;
+    StatTypes stat_types = {false};
     stat_types[static_cast<size_t>(StatType::AGGREGATE)] = true;
     stat_types[static_cast<size_t>(get_stat_type_for_pool(*(block->pool)))] =
         true;
-    update_stat_array(stats.allocation, -1, {stat_types});
-    update_stat_array(stats.allocated_bytes, -block->size, {stat_types});
+    for_each_selected_stat_type(stat_types, [&](size_t stat_type) {
+      update_stat(stats.allocation[stat_type], -1);
+      update_stat(stats.allocated_bytes[stat_type], -block->size);
+    });
     if (block->size >= CachingAllocatorConfig::max_split_size())
       update_stat(stats.oversize_allocations, -1);
@@ -916,15 +930,18 @@ class DeviceCachingAllocator {
       net_change_inactive_split_size += block->size;
     }
-    StatTypes stat_types;
+    StatTypes stat_types = {false};
     stat_types[static_cast<size_t>(StatType::AGGREGATE)] = true;
     stat_types[static_cast<size_t>(get_stat_type_for_pool(pool))] = true;
-    update_stat_array(
-        stats.inactive_split, net_change_inactive_split_blocks, stat_types);
-    update_stat_array(
-        stats.inactive_split_bytes, net_change_inactive_split_size, stat_types);
-    update_stat_array(stats.active, -1, stat_types);
-    update_stat_array(stats.active_bytes, -original_block_size, stat_types);
+    for_each_selected_stat_type(stat_types, [&](size_t stat_type) {
+      update_stat(
+          stats.inactive_split[stat_type], net_change_inactive_split_blocks);
+      update_stat(
+          stats.inactive_split_bytes[stat_type],
+          net_change_inactive_split_size);
+      update_stat(stats.active[stat_type], -1);
+      update_stat(stats.active_bytes[stat_type], -original_block_size);
+    });
   }
   /** combine previously split blocks. returns the size of the subsumed block,
@@ -1089,8 +1106,10 @@ class DeviceCachingAllocator {
     total_allocated_memory += size;
     p.block = new Block(p.device(), p.stream(), size, p.pool, (char*)ptr);
-    update_stat_array(stats.segment, 1, p.stat_types);
-    update_stat_array(stats.reserved_bytes, size, p.stat_types);
+    for_each_selected_stat_type(p.stat_types, [&](size_t stat_type) {
+      update_stat(stats.segment[stat_type], 1);
+      update_stat(stats.reserved_bytes[stat_type], size);
+    });
     if (size >= CachingAllocatorConfig::max_split_size())
       update_stat(stats.oversize_segments, 1);
@@ -1179,11 +1198,13 @@ class DeviceCachingAllocator {
       pool->owner_PrivatePool->cudaMalloc_count--;
     }
-    StatTypes stat_types;
+    StatTypes stat_types = {false};
     stat_types[static_cast<size_t>(StatType::AGGREGATE)] = true;
     stat_types[static_cast<size_t>(get_stat_type_for_pool(*pool))] = true;
-    update_stat_array(stats.segment, -1, stat_types);
-    update_stat_array(stats.reserved_bytes, -block->size, stat_types);
+    for_each_selected_stat_type(stat_types, [&](size_t stat_type) {
+      update_stat(stats.segment[stat_type], -1);
+      update_stat(stats.reserved_bytes[stat_type], -block->size);
+    });
     if (block->size >= CachingAllocatorConfig::max_split_size())
       update_stat(stats.oversize_segments, -1);