Bug fixes in profiling allocator (#45993)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/45993

Fixes bugs exposed via updated test and validation code.
Also enabled this test to run on CI instead of only as a mobile test.
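
For context, the core fix replaces MemBlock::operator<, which previously defined "less than" as "ends at or before the other block's start". Under that comparator any two overlapping blocks compare equivalent, and the ordering is not a strict weak ordering, so std::set cannot reliably hold and compare blocks. A minimal sketch of the pitfall (hypothetical offsets, not from this diff):

#include <cstdint>
#include <iostream>
#include <set>

struct Block {
  uint64_t start, end; // half-open interval [start, end)
  // Old-style comparator: "less" means "ends before the other starts".
  bool operator<(const Block& other) const { return end <= other.start; }
};

int main() {
  std::set<Block> s;
  s.insert({0, 16});
  // {8, 24} overlaps {0, 16}, so the two compare equivalent and the
  // insert is silently rejected; the set can never hold both blocks.
  bool inserted = s.insert({8, 24}).second;
  std::cout << std::boolalpha << inserted << "\n"; // prints: false
}

The new comparator orders blocks by start_offset, which is a proper strict weak ordering, and overlap is then checked explicitly against neighboring blocks.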

Test Plan:
cpu_profiling_allocator_test

Imported from OSS

Reviewed By: dzhulgakov

Differential Revision: D24172599

fbshipit-source-id: da0d2e1d1dec87b476bf39a1c2a2ffa0e4b5df66
Author: Kimish Patel
Date: 2020-10-14 22:40:57 -07:00
Committed by: Facebook GitHub Bot
Parent: 419dafe791
Commit: 4aaad88790

3 changed files with 99 additions and 52 deletions


@@ -30,6 +30,7 @@ list(APPEND ATen_CPU_TEST_SRCS
${CMAKE_CURRENT_SOURCE_DIR}/math_kernel_test.cpp
${CMAKE_CURRENT_SOURCE_DIR}/memory_overlapping_test.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpu_generator_test.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpu_profiling_allocator_test.cpp
${CMAKE_CURRENT_SOURCE_DIR}/pow_test.cpp
${CMAKE_CURRENT_SOURCE_DIR}/variant_test.cpp
${CMAKE_CURRENT_SOURCE_DIR}/reduce_ops_test.cpp


@@ -1,7 +1,9 @@
#include <gtest/gtest.h>
#include <c10/core/CPUAllocator.h>
#include <c10/mobile/CPUProfilingAllocator.h>
#include <ATen/ATen.h>
#include <ATen/Context.h>
at::Tensor run_with_control_flow(
at::Tensor input,
@@ -37,14 +39,14 @@ TEST(CPUAllocationPlanTest, with_control_flow) {
// 23, 16, 14, 14
// Flattened shape = 23, 3136
at::Tensor linear_weight = at::rand({32, 3136});
at::Tensor output;
at::Tensor output, ref_output;
std::vector<void*> pointers;
auto valid_allocation_plan = [&]() {
c10::AllocationPlan plan;
{
c10::WithProfileAllocationsGuard profile_guard(&plan);
output = run_with_control_flow(
ref_output = run_with_control_flow(
a, conv_weight, linear_weight, true, pointers);
}
};
@@ -55,7 +57,7 @@ TEST(CPUAllocationPlanTest, with_control_flow) {
c10::AllocationPlan plan;
{
c10::WithProfileAllocationsGuard profile_guard(&plan);
output =
ref_output =
run_with_control_flow(a, conv_weight, linear_weight, record_mode, pointers);
}
bool success{true};
@@ -84,14 +86,14 @@ TEST(CPUAllocationPlanTest, with_profiling_alloc) {
// 23, 16, 14, 14
// Flattened shape = 23, 3136
at::Tensor linear_weight = at::rand({32, 3136});
at::Tensor output;
at::Tensor output, ref_output;
std::vector<void*> pointers;
auto valid_allocation_plan = [&]() {
c10::AllocationPlan plan;
{
c10::WithProfileAllocationsGuard profile_guard(&plan);
output = run_with_control_flow(
ref_output = run_with_control_flow(
a, conv_weight, linear_weight, false, pointers);
}
};
@@ -105,7 +107,7 @@ TEST(CPUAllocationPlanTest, with_profiling_alloc) {
c10::AllocationPlan plan;
{
c10::WithProfileAllocationsGuard profile_guard(&plan);
output = run_with_control_flow(
ref_output = run_with_control_flow(
a,
conv_weight,
linear_weight,
@@ -145,11 +147,15 @@ TEST(CPUAllocationPlanTest, with_profiling_alloc) {
// When control flow conditions are the same between profiling and evaluation,
// the profiling allocator should not throw.
ASSERT_NO_THROW(validate_allocation_plan(true, true, false));
ASSERT_TRUE(ref_output.equal(output));
ASSERT_NO_THROW(validate_allocation_plan(false, false, false));
ASSERT_TRUE(ref_output.equal(output));
// Furthermore, the profiling allocator should return the same pointers
// back for the intermediate tensors.
ASSERT_NO_THROW(validate_allocation_plan(true, true, true));
ASSERT_TRUE(ref_output.equal(output));
ASSERT_NO_THROW(validate_allocation_plan(false, false, true));
ASSERT_TRUE(ref_output.equal(output));
// When control flow conditions are different between profiling and evaluation,
// the profiling allocator should throw.
@@ -158,10 +164,13 @@ TEST(CPUAllocationPlanTest, with_profiling_alloc) {
}
int main(int argc, char* argv[]) {
// At the moment the caching allocator is only exposed through the mobile CPU allocator.
#ifdef C10_MOBILE
// Set the priority high to make sure no other allocator is used instead of this one.
c10::SetCPUAllocator(c10::GetDefaultMobileCPUAllocator(), /*priority*/ 100);
// Need to disable mkldnn for this test since it allocates memory
// via the raw_allocate interface, which requires the context pointer and raw
// pointer to be the same. This is not true for the mobile allocator.
at::globalContext().setUserEnabledMkldnn(false);
::testing::InitGoogleTest(&argc, argv);
at::manual_seed(42);
return RUN_ALL_TESTS();
#endif /* C10_MOBILE */
}


@@ -12,28 +12,10 @@ struct MemBlock {
uint64_t start_offset, end_offset;
MemBlock(uint64_t s, uint64_t e) : start_offset(s), end_offset(e) {}
bool operator<(const MemBlock& other) const {
return end_offset <= other.start_offset;
return start_offset < other.start_offset;
}
};
bool validate_allocation_plan(
const std::vector<uint64_t>& allocation_sizes,
const std::vector<uint64_t>& allocation_offsets) {
std::set<MemBlock> allocations;
for (uint64_t i = 0; i < allocation_sizes.size(); ++i) {
// Skip allocations not managed by AllocationPlan
if (allocation_offsets[i] == std::numeric_limits<uint64_t>::max()) {
continue;
}
auto start_offset = allocation_offsets[i];
auto end_offset = allocation_offsets[i] + allocation_sizes[i];
if (!allocations.emplace(start_offset, end_offset).second) {
return false;
}
}
return true;
}
enum class EventType {
Allocate = 0,
Free,
@@ -49,6 +31,58 @@ struct MemEvent {
time(t), allocation_id(id), size(s), type(e) {}
};
bool overlaps(const MemBlock& a, const MemBlock& b) {
// two blocks don't overlap if one ends at or before the other starts:
// |---a--------|       |--------b--------|
// start_a    end_a <= start_b         end_b
return
!((a.end_offset <= b.start_offset) || (b.end_offset <= a.start_offset));
}
bool validate_allocation_plan(
const std::vector<MemEvent>& alloc_events,
const std::vector<uint64_t>& allocation_offsets) {
std::set<MemBlock> allocations;
for (const auto& event : alloc_events) {
auto alloc_id = event.allocation_id;
// Skip allocations not managed by AllocationPlan
if (allocation_offsets[alloc_id] == std::numeric_limits<uint64_t>::max()) {
continue;
}
auto start_offset = allocation_offsets[alloc_id];
auto end_offset = allocation_offsets[alloc_id] + event.size;
MemBlock mem_block(start_offset, end_offset);
if (event.type == EventType::Allocate) {
auto it = allocations.lower_bound(mem_block);
if (it != allocations.end()) {
auto next_block = *it;
if (overlaps(next_block, mem_block)) {
return false;
}
}
if (it != allocations.begin()) {
auto prev_block = *(--it);
if (overlaps(prev_block, mem_block)) {
return false;
}
}
allocations.emplace(mem_block);
} else if (event.type == EventType::Free) {
auto it = allocations.find(mem_block);
TORCH_CHECK(
it != allocations.end(),
"ProfilingAllocator: Allocate event "
"must have preceded deallocate event.");
TORCH_CHECK((*it).end_offset == end_offset,
"End offset of allocation being freed must match the one recorded.");
allocations.erase(it);
} else {
TORCH_CHECK(false, "ProfilingAllocator: Invalid event type.");
}
}
return true;
}
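
The validation above keeps live blocks in a std::set ordered by start_offset and, on each Allocate event, checks only the lower_bound neighbor and its predecessor. That is sufficient because the live blocks of a valid plan are pairwise disjoint and therefore totally ordered by start offset. A minimal standalone sketch of the same neighbor check (names illustrative, not from this diff):

#include <cstdint>
#include <iterator>
#include <set>

struct Block {
  uint64_t start, end; // half-open interval [start, end)
  bool operator<(const Block& other) const { return start < other.start; }
};

bool overlaps(const Block& a, const Block& b) {
  return !(a.end <= b.start || b.end <= a.start);
}

// Returns true if `b` can go live without overlapping any existing block.
bool can_insert(const std::set<Block>& live, const Block& b) {
  auto it = live.lower_bound(b); // first block with start >= b.start
  if (it != live.end() && overlaps(*it, b)) return false;
  if (it != live.begin() && overlaps(*std::prev(it), b)) return false;
  return true;
}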
std::vector<MemEvent> create_and_sort_mem_events(
const std::vector<uint64_t>& allocation_sizes,
const std::vector<uint64_t>& allocation_lifetimes) {
@@ -106,7 +140,6 @@ std::vector<uint64_t> formulate_greedy_allocation_plan(
// Merging should always be done recursively until no more mergeable
// chunks can be found.
// After the last free we should have only one entry left in these maps.
ska::flat_hash_map<uint64_t, uint64_t> allocated_offset_to_size;
std::vector<uint64_t> allocation_offsets(
allocation_sizes.size(), std::numeric_limits<uint64_t>::max());
@@ -122,7 +155,6 @@ std::vector<uint64_t> formulate_greedy_allocation_plan(
// allocate a new one.
alloc_offset = max_offset;
max_offset += mem_event.size;
allocated_offset_to_size.emplace(alloc_offset, mem_event.size);
} else {
// If we have found a block of the size we want
// 1. change the block by allocating out of it.
@@ -130,7 +162,6 @@ std::vector<uint64_t> formulate_greedy_allocation_plan(
// 1.2 Erase the reverse map entries
// 2. If block still has space left insert the remainder back in map.
// Including reverse map entries.
// 3. Insert the allocated block in allocated_offset_to_size.
alloc_offset = it->second;
new_offset = alloc_offset + mem_event.size;
new_size = it->first - mem_event.size;
@@ -138,11 +169,10 @@ std::vector<uint64_t> formulate_greedy_allocation_plan(
free_start_offset_to_size_iter.erase(alloc_offset);
free_end_offset_to_size_iter.erase(alloc_offset + it->first);
if (new_size > 0) {
auto ref_it = free_size_to_offset.emplace(new_offset, new_size).first;
auto ref_it = free_size_to_offset.emplace(new_size, new_offset).first;
free_start_offset_to_size_iter.emplace(new_offset, ref_it);
free_end_offset_to_size_iter.emplace(new_offset + new_size, ref_it);
}
allocated_offset_to_size.emplace(alloc_offset, mem_event.size);
}
allocation_offsets[mem_event.allocation_id] = alloc_offset;
} else {
@@ -161,23 +191,30 @@ std::vector<uint64_t> formulate_greedy_allocation_plan(
auto freed_size = mem_event.size;
auto end_offset = freed_offset + freed_size;
// Merge when another free block exists at the end of this block
auto end_it = free_end_offset_to_size_iter.find(end_offset);
if (end_it != free_end_offset_to_size_iter.end()) {
auto size_to_end_offset_iter = end_it->second;
freed_size += size_to_end_offset_iter->first;
free_size_to_offset.erase(size_to_end_offset_iter);
free_end_offset_to_size_iter.erase(end_it);
auto end_it = free_start_offset_to_size_iter.find(end_offset);
if (end_it != free_start_offset_to_size_iter.end()) {
auto merge_block_iter = end_it->second;
auto merge_block_size = merge_block_iter->first;
freed_size += merge_block_size;
free_size_to_offset.erase(merge_block_iter);
free_start_offset_to_size_iter.erase(end_it);
// If the block is being merged then also remove it from
// free_end_offset_to_size_iter
free_end_offset_to_size_iter.erase(end_offset + merge_block_size);
}
// Merge when the freed block exists at the end of another free block
auto start_it = free_start_offset_to_size_iter.find(freed_offset);
if (start_it != free_start_offset_to_size_iter.end()) {
auto size_to_start_offset_iter = start_it->second;
freed_size += size_to_start_offset_iter->first;
freed_offset -= size_to_start_offset_iter->first;
free_size_to_offset.erase(size_to_start_offset_iter);
free_start_offset_to_size_iter.erase(start_it);
auto start_it = free_end_offset_to_size_iter.find(freed_offset);
if (start_it != free_end_offset_to_size_iter.end()) {
auto merge_block_iter = start_it->second;
auto merge_block_size = merge_block_iter->first;
freed_size += merge_block_size;
freed_offset -= merge_block_size;
free_size_to_offset.erase(merge_block_iter);
free_end_offset_to_size_iter.erase(start_it);
// If the block is being merged then also remove it from
// free_start_offset_to_size_iter
free_start_offset_to_size_iter.erase(freed_offset);
}
allocated_offset_to_size.erase(freed_offset);
auto freed_block_it =
free_size_to_offset.emplace(freed_size, freed_offset).first;
free_start_offset_to_size_iter.emplace(freed_offset, freed_block_it);
@@ -185,8 +222,8 @@ std::vector<uint64_t> formulate_greedy_allocation_plan(
freed_offset + freed_size, freed_block_it);
}
}
TORCH_CHECK(validate_allocation_plan(allocation_sizes, allocation_offsets),
"Allocation plan invalid.");
TORCH_CHECK(validate_allocation_plan(mem_events, allocation_offsets),
"ProfilingAllocator: Allocation plan invalid.");
return allocation_offsets;
}
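
The hunks above also fix the direction of the free-block merge lookups: when a block [start, end) is freed, the right-merge candidate is the free block that starts at end (so it must be found through the start-offset index), and the left-merge candidate is the free block that ends at start (the end-offset index); the removed lines had these two lookups swapped. A simplified sketch of the corrected coalescing, using a single start-offset map instead of the three indices the allocator maintains:

#include <cstdint>
#include <map>

// Free blocks tracked as start_offset -> size.
std::map<uint64_t, uint64_t> free_blocks;

void release(uint64_t start, uint64_t size) {
  uint64_t end = start + size;
  // Merge with the free block that STARTS at our end offset (right neighbor).
  auto right = free_blocks.find(end);
  if (right != free_blocks.end()) {
    size += right->second;
    free_blocks.erase(right);
  }
  // Merge with the free block that ENDS at our start offset (left neighbor).
  auto left = free_blocks.lower_bound(start);
  if (left != free_blocks.begin()) {
    --left;
    if (left->first + left->second == start) {
      start = left->first;
      size += left->second;
      free_blocks.erase(left);
    }
  }
  free_blocks[start] = size;
}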
@@ -207,7 +244,7 @@ void AllocationPlanner::record_allocation(
allocation_plan_->allocation_sizes.push_back(size);
allocation_plan_->allocation_lifetimes.push_back(
std::numeric_limits<uint64_t>::max());
allocation_ptr_to_id_.emplace(ptr, allocation_id_);
allocation_ptr_to_id_[ptr] = allocation_id_;
allocation_id_++;
}
@@ -244,7 +281,7 @@ bool AllocationPlanner::validate_allocation(
return false;
}
allocation_ptr_to_id_.emplace(ptr, allocation_id_);
allocation_ptr_to_id_[ptr] = allocation_id_;
allocation_id_++;
return true;
}
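
The emplace-to-operator[] change in record_allocation and validate_allocation matters because map emplace is a no-op when the key already exists, and the allocator can hand out the same pointer value again after a free, which would leave a stale allocation id behind. A minimal sketch of the difference (illustrative only):

#include <cstdint>
#include <iostream>
#include <unordered_map>

int main() {
  std::unordered_map<const void*, uint64_t> ptr_to_id;
  int dummy;
  const void* p = &dummy; // stands in for a reused allocation address

  ptr_to_id.emplace(p, 0); // first allocation observed at p
  ptr_to_id.emplace(p, 1); // no-op: key exists, id stays 0 (the old bug)
  std::cout << ptr_to_id[p] << "\n"; // prints: 0

  ptr_to_id[p] = 1;        // operator[] overwrites, matching the fix
  std::cout << ptr_to_id[p] << "\n"; // prints: 1
}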
@@ -313,7 +350,7 @@ void* CPUProfilingAllocator::allocate(const size_t bytes) {
void* ptr =
reinterpret_cast<uint8_t*>(blob_) +
plan_->allocation_offsets[allocation_id_];
TORCH_CHECK(allocation_ptr_to_id_.emplace(ptr, allocation_id_).second);
allocation_ptr_to_id_[ptr] = allocation_id_;
allocation_id_++;
return ptr;
}