[PyTorch][CUDA Caching Allocator] Export sync-stream-and-free-HBM counter in memory_stats for performance debugging (#120050)

Differential Revision: D53734057

Pull Request resolved: https://github.com/pytorch/pytorch/pull/120050
Approved by: https://github.com/xw285cornell
Author: Levy Zhao
Date: 2024-02-27 04:34:53 +00:00
Committed by: PyTorch MergeBot
parent a1c641f118
commit b6139b1e57
4 changed files with 41 additions and 0 deletions
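
For reference, a minimal sketch of how the new counters surface once this change is in place (the key names are taken from the diffs below; the device index is a placeholder):

    import torch

    device = 0  # placeholder device index
    stats = torch.cuda.memory_stats(device)

    # The counters are exported as flat integer keys, next to existing ones
    # such as "num_alloc_retries" and "num_ooms".
    print(stats.get("num_sync_all_streams", -1))  # synchronize_and_free_events() calls
    print(stats.get("num_device_alloc", -1))      # cudaMalloc / cuMemMap calls
    print(stats.get("num_device_free", -1))       # cudaFree / cuMemUnmap calls

Using .get() with a default mirrors the test below and keeps the snippet safe on builds that predate this change.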

c10/cuda/CUDACachingAllocator.cpp

@@ -1413,6 +1413,9 @@ class DeviceCachingAllocator {
stats.num_alloc_retries = 0;
stats.num_ooms = 0;
stats.num_sync_all_streams = 0;
stats.num_device_alloc = 0;
stats.num_device_free = 0;
reset_accumulated_stat(stats.oversize_allocations);
reset_accumulated_stat(stats.oversize_segments);
}
@@ -2022,6 +2025,7 @@ class DeviceCachingAllocator {
update_stat(stats.reserved_bytes[stat_type], mapped_range.size);
});
stats.num_device_alloc++;
record_trace(
TraceEntry::SEGMENT_MAP,
int64_t(mapped_range.ptr),
@@ -2442,6 +2446,7 @@ class DeviceCachingAllocator {
// p.block came from new, not cudaMalloc. It should not be nullptr here.
TORCH_INTERNAL_ASSERT(p.block != nullptr && p.block->ptr != nullptr);
stats.num_device_alloc++;
record_trace(
TraceEntry::SEGMENT_ALLOC,
int64_t(p.block->ptr),
@@ -2547,6 +2552,7 @@ class DeviceCachingAllocator {
Block* block,
const std::shared_ptr<GatheredContext>& context) {
TORCH_INTERNAL_ASSERT(!block->expandable_segment_);
stats.num_device_free++;
record_trace(
TraceEntry::SEGMENT_FREE,
int64_t(block->ptr),
@@ -2629,6 +2635,7 @@ class DeviceCachingAllocator {
update_stat(stats.reserved_bytes[stat_type], -unmapped.size);
});
stats.num_device_free++;
record_trace(
TraceEntry::SEGMENT_UNMAP,
int64_t(unmapped.ptr),
@@ -2672,6 +2679,7 @@ class DeviceCachingAllocator {
void synchronize_and_free_events(
const std::shared_ptr<GatheredContext>& context) {
// Synchronize on outstanding events and then free associated blocks.
stats.num_sync_all_streams++;
// This function syncs, so capture should not be underway. Might as well
// make sure capture-deferred end of life events get processed too.

c10/cuda/CUDACachingAllocator.h

@@ -102,6 +102,17 @@ struct DeviceStats {
// COUNT: total number of oversize blocks requiring malloc
Stat oversize_segments;
// COUNT: total number of synchronize_and_free_events() calls
int64_t num_sync_all_streams = 0;
// COUNT: total number of CUDA allocation calls. This includes both cuMemMap
// and cudaMalloc.
int64_t num_device_alloc = 0;
// COUNT: total number of CUDA free calls. This includes both cuMemUnmap
// and cudaFree.
int64_t num_device_free = 0;
// SIZE: maximum block size that is allowed to be split.
int64_t max_split_size = 0;
};
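
Because the new counters are zeroed in the same accumulated-stats reset path as num_alloc_retries and num_ooms (first hunk above), they can be used to gauge how much driver-level alloc/free and sync traffic a region of code generates. A sketch, assuming torch.cuda.reset_accumulated_memory_stats() reaches that reset path; the loop is a stand-in for the code being profiled:

    import torch

    device = 0  # placeholder device index
    torch.cuda.reset_accumulated_memory_stats(device)

    # Placeholder workload: allocate and drop tensors of varying sizes.
    for i in range(10):
        x = torch.empty(1024 * 1024 * (i + 1), device=f"cuda:{device}")
        del x

    stats = torch.cuda.memory_stats(device)
    # Counts much larger than expected suggest poor block reuse by the caching
    # allocator; a growing num_sync_all_streams means full-stream syncs are
    # happening while cached blocks are released back to the device.
    print(stats["num_device_alloc"], stats["num_device_free"], stats["num_sync_all_streams"])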

test/test_cuda_multigpu.py

@@ -156,8 +156,27 @@ class TestCudaMultiGPU(TestCase):
last_r_arr[0] = new_r
max_r_arr[0] = new_max_r
stat_key_n_sync = "num_sync_all_streams"
stat_key_n_alloc = "num_device_alloc"
stat_key_n_free = "num_device_free"
if empty_cache:
    num_sync_1 = torch.cuda.memory_stats(device).get(stat_key_n_sync, -1)
    self.assertGreaterEqual(num_sync_1, 0)
    num_alloc_1 = torch.cuda.memory_stats(device).get(stat_key_n_alloc, -1)
    # if current memory usage is greater than zero we must have
    # allocated something
    self.assertGreaterEqual(num_alloc_1, 0 if new_m == 0 else 1)
    num_free_1 = torch.cuda.memory_stats(device).get(stat_key_n_free, -1)
    self.assertGreaterEqual(num_free_1, 0)
    # empty_cache forces a call to release_cached_blocks
    torch.cuda.empty_cache()

    num_sync_2 = torch.cuda.memory_stats(device).get(stat_key_n_sync, -1)
    self.assertEqual(num_sync_1 + 1, num_sync_2)
    num_alloc_2 = torch.cuda.memory_stats(device).get(stat_key_n_alloc, -1)
    self.assertGreaterEqual(num_alloc_2, num_alloc_1)
    num_free_2 = torch.cuda.memory_stats(device).get(stat_key_n_free, -1)
    self.assertGreaterEqual(num_free_2, num_free_1)

    new_r = torch.cuda.memory_reserved(device)
    new_max_r = torch.cuda.max_memory_reserved(device)
    self.assertLessEqual(new_r, last_r_arr[0])

torch/csrc/cuda/Module.cpp

@@ -584,6 +584,9 @@ PyObject* THCPModule_memoryStats(PyObject* _unused, PyObject* arg) {
result["num_alloc_retries"] = stats.num_alloc_retries;
result["num_ooms"] = stats.num_ooms;
result["max_split_size"] = stats.max_split_size;
result["num_sync_all_streams"] = stats.num_sync_all_streams;
result["num_device_alloc"] = stats.num_device_alloc;
result["num_device_free"] = stats.num_device_free;
result["allocation"] = statArrayToDict(stats.allocation);
result["segment"] = statArrayToDict(stats.segment);
result["active"] = statArrayToDict(stats.active);
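
The binding above exports the three counters as plain integers rather than per-pool Stat arrays, so they appear directly in the flat dict returned by torch.cuda.memory_stats(). A quick way to watch them move, mirroring the test above (device index is a placeholder; the exact +1 on num_sync_all_streams assumes nothing else releases cached blocks concurrently):

    import torch

    device = 0  # placeholder device index
    x = torch.empty(4 * 1024 * 1024, device=f"cuda:{device}")
    del x  # the block goes back to the cache; no cudaFree yet

    before = torch.cuda.memory_stats(device)
    torch.cuda.empty_cache()  # release_cached_blocks -> synchronize_and_free_events
    after = torch.cuda.memory_stats(device)

    assert after["num_sync_all_streams"] == before["num_sync_all_streams"] + 1
    assert after["num_device_free"] >= before["num_device_free"]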