Fix xpu memory stats error (#135818)

# Motivation
Fixes https://github.com/pytorch/pytorch/issues/135726
After merging two free blocks, I mistakenly decreased the active memory stats by the merged block size; the decrement should use the original size of the block being freed.
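
To illustrate the bookkeeping issue, here is a minimal, self-contained C++ sketch (hypothetical `Block`/`Stats` types and a simulated merge, not the actual XPU allocator code): once a free neighbor has been merged in, `block.size` already includes the neighbor's bytes, so the stats must be decreased by the size captured before the merge.

#include <cstddef>
#include <iostream>

// Hypothetical, simplified stand-ins for the allocator's bookkeeping.
struct Block {
  std::size_t size; // grows when a free neighbor is merged in
};

struct Stats {
  std::size_t active_bytes = 0;
};

// Simulate freeing `block` while a free neighbor of `neighbor_size`
// bytes gets merged into it.
void free_block(Block& block, std::size_t neighbor_size, Stats& stats) {
  const std::size_t original_block_size = block.size; // capture before merging
  block.size += neighbor_size;                        // merge grows the block
  // Correct: decrease by the pre-merge size. Using block.size here would
  // subtract the neighbor's bytes a second time.
  stats.active_bytes -= original_block_size;
}

int main() {
  Stats stats;
  Block b{1024}; // 1kB block still active
  Block n{512};  // neighbor, already freed earlier
  stats.active_bytes = b.size;
  free_block(b, n.size, stats);            // free the 1kB block, merging the neighbor
  std::cout << stats.active_bytes << "\n"; // prints 0, as expected
}

The actual fix below applies the same idea: it caches `original_block_size` and `requested_size` before the merge loop runs and uses the cached values when decreasing the stats.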

# Additional Context
Add a unit test to guard against this scenario.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135818
Approved by: https://github.com/EikanWang
Author:    Yu, Guangye
Date:      2024-09-12 16:48:28 +00:00
Committer: PyTorch MergeBot
Parent:    1c04cbfba6
Commit:    e6b68359d7

2 changed files with 13 additions and 2 deletions


@@ -168,6 +168,8 @@ class DeviceCachingAllocator {
         !block->allocated && block->event_count == 0 &&
         block->stream_uses.empty());
+    size_t original_block_size = block->size;
+    size_t requested_size = block->requested_size;
     auto& pool = *block->pool;
     const std::array<Block*, 2> merge_candidates = {block->prev, block->next};
     for (Block* merge_candidate : merge_candidates) {
@@ -180,8 +182,8 @@ class DeviceCachingAllocator {
     StatTypes stat_types = get_stat_types_for_pool(pool);
     for_each_selected_stat_type(stat_types, [&](size_t stat_type) {
-      stats.active_bytes[stat_type].decrease(block->size);
-      stats.requested_bytes[stat_type].decrease(block->requested_size);
+      stats.active_bytes[stat_type].decrease(original_block_size);
+      stats.requested_bytes[stat_type].decrease(requested_size);
     });
   }


@@ -390,6 +390,15 @@ print(torch.xpu.device_count())
         self.assertEqual(torch.xpu.memory_allocated(), prev)
         torch.xpu.empty_cache()
         self.assertEqual(torch.xpu.memory_reserved(), 0)
+        torch.xpu.reset_accumulated_memory_stats()
+        # Allocate 1kB of memory (256 float32 elements)
+        a = torch.randn(256, device="xpu")
+        # Check that the current active memory is 1kB
+        self.assertEqual(torch.xpu.memory_stats()["active_bytes.all.current"], 1024)
+        self.assertEqual(torch.xpu.memory_stats()["active_bytes.all.freed"], 0)
+        del a
+        self.assertEqual(torch.xpu.memory_stats()["active_bytes.all.current"], 0)
+        self.assertEqual(torch.xpu.memory_stats()["active_bytes.all.freed"], 1024)
     @unittest.skipIf(not TEST_MULTIXPU, "only one GPU detected")
     def test_device_memory_allocated(self):