diff --git a/c10/cuda/CUDACachingAllocator.cpp b/c10/cuda/CUDACachingAllocator.cpp index ef0172fd855e..0409e6340595 100644 --- a/c10/cuda/CUDACachingAllocator.cpp +++ b/c10/cuda/CUDACachingAllocator.cpp @@ -1413,6 +1413,9 @@ class DeviceCachingAllocator { stats.num_alloc_retries = 0; stats.num_ooms = 0; + stats.num_sync_all_streams = 0; + stats.num_device_alloc = 0; + stats.num_device_free = 0; reset_accumulated_stat(stats.oversize_allocations); reset_accumulated_stat(stats.oversize_segments); } @@ -2022,6 +2025,7 @@ class DeviceCachingAllocator { update_stat(stats.reserved_bytes[stat_type], mapped_range.size); }); + stats.num_device_alloc++; record_trace( TraceEntry::SEGMENT_MAP, int64_t(mapped_range.ptr), @@ -2442,6 +2446,7 @@ class DeviceCachingAllocator { // p.block came from new, not cudaMalloc. It should not be nullptr here. TORCH_INTERNAL_ASSERT(p.block != nullptr && p.block->ptr != nullptr); + stats.num_device_alloc++; record_trace( TraceEntry::SEGMENT_ALLOC, int64_t(p.block->ptr), @@ -2547,6 +2552,7 @@ class DeviceCachingAllocator { Block* block, const std::shared_ptr<GatheredContext>& context) { TORCH_INTERNAL_ASSERT(!block->expandable_segment_); + stats.num_device_free++; record_trace( TraceEntry::SEGMENT_FREE, int64_t(block->ptr), @@ -2629,6 +2635,7 @@ class DeviceCachingAllocator { update_stat(stats.reserved_bytes[stat_type], -unmapped.size); }); + stats.num_device_free++; record_trace( TraceEntry::SEGMENT_UNMAP, int64_t(unmapped.ptr), @@ -2672,6 +2679,7 @@ class DeviceCachingAllocator { void synchronize_and_free_events( const std::shared_ptr<GatheredContext>& context) { // Synchronize on outstanding events and then free associated blocks. + stats.num_sync_all_streams++; // This function syncs, so capture should not be underway. Might as well // make sure capture-deferred end of life events get processed too. 
diff --git a/c10/cuda/CUDACachingAllocator.h b/c10/cuda/CUDACachingAllocator.h index 906b9650cb85..7f83db3796b4 100644 --- a/c10/cuda/CUDACachingAllocator.h +++ b/c10/cuda/CUDACachingAllocator.h @@ -102,6 +102,17 @@ struct DeviceStats { // COUNT: total number of oversize blocks requiring malloc Stat oversize_segments; + // COUNT: total number of synchronize_and_free_events() calls + int64_t num_sync_all_streams = 0; + + // COUNT: total number of CUDA allocation calls. This includes both cuMemMap + // and cudaMalloc. + int64_t num_device_alloc = 0; + + // COUNT: total number of CUDA free calls. This includes both cuMemUnmap + // and cudaFree. + int64_t num_device_free = 0; + // SIZE: maximum block size that is allowed to be split. int64_t max_split_size = 0; }; diff --git a/test/test_cuda_multigpu.py b/test/test_cuda_multigpu.py index ee3c6a927345..77e8d5693c35 100644 --- a/test/test_cuda_multigpu.py +++ b/test/test_cuda_multigpu.py @@ -156,8 +156,27 @@ class TestCudaMultiGPU(TestCase): last_r_arr[0] = new_r max_r_arr[0] = new_max_r + stat_key_n_sync = "num_sync_all_streams" + stat_key_n_alloc = "num_device_alloc" + stat_key_n_free = "num_device_free" if empty_cache: + num_sync_1 = torch.cuda.memory_stats(device).get(stat_key_n_sync, -1) + self.assertGreaterEqual(num_sync_1, 0) + num_alloc_1 = torch.cuda.memory_stats(device).get(stat_key_n_alloc, -1) + # if current memory usage is greater than zero we must have + # allocated something + self.assertGreaterEqual(num_alloc_1, 0 if new_m == 0 else 1) + num_free_1 = torch.cuda.memory_stats(device).get(stat_key_n_free, -1) + self.assertGreaterEqual(num_free_1, 0) + # empty_cache will enforce the call of release_cached_blocks torch.cuda.empty_cache() + num_sync_2 = torch.cuda.memory_stats(device).get(stat_key_n_sync, -1) + self.assertEqual(num_sync_1 + 1, num_sync_2) + num_alloc_2 = torch.cuda.memory_stats(device).get(stat_key_n_alloc, -1) + self.assertGreaterEqual(num_alloc_2, num_alloc_1) + num_free_2 = 
torch.cuda.memory_stats(device).get(stat_key_n_free, -1) + self.assertGreaterEqual(num_free_2, num_free_1) + new_r = torch.cuda.memory_reserved(device) new_max_r = torch.cuda.max_memory_reserved(device) self.assertLessEqual(new_r, last_r_arr[0]) diff --git a/torch/csrc/cuda/Module.cpp b/torch/csrc/cuda/Module.cpp index c4942ad9b8ac..0ce243c300c3 100644 --- a/torch/csrc/cuda/Module.cpp +++ b/torch/csrc/cuda/Module.cpp @@ -584,6 +584,9 @@ PyObject* THCPModule_memoryStats(PyObject* _unused, PyObject* arg) { result["num_alloc_retries"] = stats.num_alloc_retries; result["num_ooms"] = stats.num_ooms; result["max_split_size"] = stats.max_split_size; + result["num_sync_all_streams"] = stats.num_sync_all_streams; + result["num_device_alloc"] = stats.num_device_alloc; + result["num_device_free"] = stats.num_device_free; result["allocation"] = statArrayToDict(stats.allocation); result["segment"] = statArrayToDict(stats.segment); result["active"] = statArrayToDict(stats.active);