[PyTorch][CUDA Caching Allocator] Export sync-stream-and-free-HBM counter in memory_stats for performance debugging (#120050)

Differential Revision: D53734057

Pull Request resolved: https://github.com/pytorch/pytorch/pull/120050
Approved by: https://github.com/xw285cornell
Author: Levy Zhao
Date: 2024-02-27 04:34:53 +00:00
Committed by: PyTorch MergeBot
Commit: b6139b1e57 (parent a1c641f118)
4 changed files with 41 additions and 0 deletions
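
For orientation before the diff: the three new counters land in the dictionary returned by torch.cuda.memory_stats. A minimal sketch of reading them for performance debugging (assumes a CUDA build that contains this change; the .get() fallback of -1 mirrors the test below and guards against older builds that lack the keys):

    import torch

    device = torch.device("cuda:0")
    torch.randn(1024, 1024, device=device)  # force at least one device allocation

    stats = torch.cuda.memory_stats(device)
    # Counters added by this PR; -1 means the key is absent (older build).
    print("num_sync_all_streams:", stats.get("num_sync_all_streams", -1))
    print("num_device_alloc:", stats.get("num_device_alloc", -1))
    print("num_device_free:", stats.get("num_device_free", -1))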


@@ -1413,6 +1413,9 @@ class DeviceCachingAllocator {
     stats.num_alloc_retries = 0;
     stats.num_ooms = 0;
+    stats.num_sync_all_streams = 0;
+    stats.num_device_alloc = 0;
+    stats.num_device_free = 0;
     reset_accumulated_stat(stats.oversize_allocations);
     reset_accumulated_stat(stats.oversize_segments);
   }
@@ -2022,6 +2025,7 @@ class DeviceCachingAllocator {
       update_stat(stats.reserved_bytes[stat_type], mapped_range.size);
     });
+    stats.num_device_alloc++;
     record_trace(
         TraceEntry::SEGMENT_MAP,
         int64_t(mapped_range.ptr),
@@ -2442,6 +2446,7 @@ class DeviceCachingAllocator {
     // p.block came from new, not cudaMalloc. It should not be nullptr here.
     TORCH_INTERNAL_ASSERT(p.block != nullptr && p.block->ptr != nullptr);
+    stats.num_device_alloc++;
     record_trace(
         TraceEntry::SEGMENT_ALLOC,
         int64_t(p.block->ptr),
@@ -2547,6 +2552,7 @@ class DeviceCachingAllocator {
       Block* block,
       const std::shared_ptr<GatheredContext>& context) {
     TORCH_INTERNAL_ASSERT(!block->expandable_segment_);
+    stats.num_device_free++;
     record_trace(
         TraceEntry::SEGMENT_FREE,
         int64_t(block->ptr),
@@ -2629,6 +2635,7 @@ class DeviceCachingAllocator {
       update_stat(stats.reserved_bytes[stat_type], -unmapped.size);
     });
+    stats.num_device_free++;
     record_trace(
         TraceEntry::SEGMENT_UNMAP,
         int64_t(unmapped.ptr),
@@ -2672,6 +2679,7 @@ class DeviceCachingAllocator {
   void synchronize_and_free_events(
       const std::shared_ptr<GatheredContext>& context) {
     // Synchronize on outstanding events and then free associated blocks.
+    stats.num_sync_all_streams++;
     // This function syncs, so capture should not be underway. Might as well
     // make sure capture-deferred end of life events get processed too.


@@ -102,6 +102,17 @@ struct DeviceStats {
   // COUNT: total number of oversize blocks requiring malloc
   Stat oversize_segments;

+  // COUNT: total number of synchronize_and_free_events() calls
+  int64_t num_sync_all_streams = 0;
+
+  // COUNT: total number of CUDA allocation calls. This includes both cuMemMap
+  // and cudaMalloc.
+  int64_t num_device_alloc = 0;
+
+  // COUNT: total number of CUDA free calls. This includes both cuMemUnmap
+  // and cudaFree.
+  int64_t num_device_free = 0;
+
   // SIZE: maximum block size that is allowed to be split.
   int64_t max_split_size = 0;
 };
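
These counters make allocator-induced stalls measurable from Python: num_device_alloc/num_device_free growth during steady state means the caching allocator is falling through to the driver (cudaMalloc/cuMemMap and cudaFree/cuMemUnmap), and num_sync_all_streams growth means a synchronize-and-free path ran. A sketch of one way to use them; allocator_deltas is a hypothetical helper written for illustration, not a PyTorch API:

    import torch

    def allocator_deltas(fn, device="cuda"):
        # Hypothetical helper (not part of PyTorch): run fn() and report how
        # the new counters changed, to attribute slowdowns to allocator churn.
        keys = ("num_sync_all_streams", "num_device_alloc", "num_device_free")
        before = torch.cuda.memory_stats(device)
        fn()
        torch.cuda.synchronize(device)
        after = torch.cuda.memory_stats(device)
        return {k: after.get(k, 0) - before.get(k, 0) for k in keys}

    # After warmup, a steady-state step should report all zeros; nonzero values
    # point at driver alloc/free traffic or full-stream syncs inside the step.
    step = lambda: torch.randn(256, 256, device="cuda").sum().item()
    step()  # warmup populates the cache
    print(allocator_deltas(step))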


@@ -156,8 +156,27 @@ class TestCudaMultiGPU(TestCase):
         last_r_arr[0] = new_r
         max_r_arr[0] = new_max_r

+        stat_key_n_sync = "num_sync_all_streams"
+        stat_key_n_alloc = "num_device_alloc"
+        stat_key_n_free = "num_device_free"
         if empty_cache:
+            num_sync_1 = torch.cuda.memory_stats(device).get(stat_key_n_sync, -1)
+            self.assertGreaterEqual(num_sync_1, 0)
+            num_alloc_1 = torch.cuda.memory_stats(device).get(stat_key_n_alloc, -1)
+            # if current memory usage is greater than zero we must have
+            # allocated something
+            self.assertGreaterEqual(num_alloc_1, 0 if new_m == 0 else 1)
+            num_free_1 = torch.cuda.memory_stats(device).get(stat_key_n_free, -1)
+            self.assertGreaterEqual(num_free_1, 0)
+            # empty_cache will enforce the call of release_cached_blocks
             torch.cuda.empty_cache()
+            num_sync_2 = torch.cuda.memory_stats(device).get(stat_key_n_sync, -1)
+            self.assertEqual(num_sync_1 + 1, num_sync_2)
+            num_alloc_2 = torch.cuda.memory_stats(device).get(stat_key_n_alloc, -1)
+            self.assertGreaterEqual(num_alloc_2, num_alloc_1)
+            num_free_2 = torch.cuda.memory_stats(device).get(stat_key_n_free, -1)
+            self.assertGreaterEqual(num_free_2, num_free_1)
+
             new_r = torch.cuda.memory_reserved(device)
             new_max_r = torch.cuda.max_memory_reserved(device)
             self.assertLessEqual(new_r, last_r_arr[0])


@@ -584,6 +584,9 @@ PyObject* THCPModule_memoryStats(PyObject* _unused, PyObject* arg) {
   result["num_alloc_retries"] = stats.num_alloc_retries;
   result["num_ooms"] = stats.num_ooms;
   result["max_split_size"] = stats.max_split_size;
+  result["num_sync_all_streams"] = stats.num_sync_all_streams;
+  result["num_device_alloc"] = stats.num_device_alloc;
+  result["num_device_free"] = stats.num_device_free;
  result["allocation"] = statArrayToDict(stats.allocation);
  result["segment"] = statArrayToDict(stats.segment);
  result["active"] = statArrayToDict(stats.active);