Add UT for torch.accelerator memory-related API (#155200)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/155200
Approved by: https://github.com/albanD
ghstack dependencies: #138222, #152932
Author: Yu, Guangye
Date: 2025-08-08 15:17:59 +00:00
Committed by: PyTorch MergeBot
parent 84f7e88aef
commit da1f608ca3
3 changed files with 151 additions and 0 deletions
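For reference, a minimal sketch (not part of the diff) of the device-agnostic memory API these tests exercise; it assumes a PyTorch build with an accelerator backend (CUDA, XPU, ...) available:

import torch

acc = torch.accelerator.current_accelerator()  # active backend device, e.g. cuda or xpu
tmp = torch.randn(256, device=acc)  # 256 float32 elements -> 1024 bytes

torch.accelerator.memory_allocated()      # bytes currently allocated by tensors
torch.accelerator.memory_reserved()       # bytes held by the caching allocator
torch.accelerator.max_memory_allocated()  # peak allocation since the last reset
torch.accelerator.memory_stats()          # dict keyed as "<stat>.<pool>.<metric>"

del tmp
torch.accelerator.empty_cache()                     # release unused cached blocks
torch.accelerator.reset_peak_memory_stats()         # peaks snap back to current usage
torch.accelerator.reset_accumulated_memory_stats()  # zero the "allocated"/"freed" counters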


@@ -1,5 +1,6 @@
# Owner(s): ["module: tests"]
import gc
import sys
import unittest
@@ -156,6 +157,83 @@ class TestAccelerator(TestCase):
        ):
            event1.elapsed_time(event2)

    @unittest.skipIf(TEST_MPS, "MPS doesn't support torch.accelerator memory API!")
    def test_memory_stats(self):
        # Ensure that the device allocator is initialized
        acc = torch.accelerator.current_accelerator()
        tmp = torch.randn(100, device=acc)
        del tmp
        gc.collect()
        self.assertTrue(torch._C._accelerator_isAllocatorInitialized())
        torch.accelerator.empty_cache()
        # memory_stats() keys follow the pattern "<stat>.<pool>.<metric>"
        pool_type = ["all", "small_pool", "large_pool"]
        metric_type = ["peak", "current", "allocated", "freed"]
        stats_type = [
            "allocated_bytes",
            "reserved_bytes",
            "active_bytes",
            "requested_bytes",
        ]
        mem_stats = torch.accelerator.memory_stats()
        expected_stats = [
            f"{st}.{pt}.{mt}"
            for st in stats_type
            for pt in pool_type
            for mt in metric_type
        ]
        missing_stats = [stat for stat in expected_stats if stat not in mem_stats]
        self.assertEqual(
            len(missing_stats),
            0,
            f"Missing expected memory statistics: {missing_stats}",
        )
        prev_allocated = torch.accelerator.memory_allocated()
        prev_reserved = torch.accelerator.memory_reserved()
        prev_max_allocated = torch.accelerator.max_memory_allocated()
        prev_max_reserved = torch.accelerator.max_memory_reserved()
        self.assertGreaterEqual(prev_allocated, 0)
        self.assertGreaterEqual(prev_reserved, 0)
        # The allocation above must have left a nonzero peak
        self.assertGreater(prev_max_allocated, 0)
        self.assertGreater(prev_max_reserved, 0)
        tmp = torch.ones(256, device=acc)
        self.assertGreater(torch.accelerator.memory_allocated(), prev_allocated)
        self.assertGreaterEqual(torch.accelerator.memory_reserved(), prev_reserved)
        del tmp
        gc.collect()
        torch.accelerator.empty_cache()
        torch.accelerator.reset_peak_memory_stats()
        # Usage returns to the baseline once the cache is emptied
        self.assertEqual(torch.accelerator.memory_allocated(), prev_allocated)
        self.assertEqual(torch.accelerator.memory_reserved(), prev_reserved)
        torch.accelerator.reset_accumulated_memory_stats()
        prev_max_allocated = torch.accelerator.max_memory_allocated()
        prev_max_reserved = torch.accelerator.max_memory_reserved()
        # Record the active-memory baseline, then allocate 1kB (256 float32 x 4 bytes)
        prev_active_current = torch.accelerator.memory_stats()[
            "active_bytes.all.current"
        ]
        tmp = torch.randn(256, device=acc)
        # The current active memory should have grown by exactly 1kB
        self.assertEqual(
            torch.accelerator.memory_stats()["active_bytes.all.current"],
            1024 + prev_active_current,
        )
        self.assertEqual(torch.accelerator.memory_stats()["active_bytes.all.freed"], 0)
        del tmp
        gc.collect()
        torch.accelerator.empty_cache()
        self.assertEqual(
            torch.accelerator.memory_stats()["active_bytes.all.current"],
            prev_active_current,
        )
        self.assertEqual(
            torch.accelerator.memory_stats()["active_bytes.all.freed"], 1024
        )
        torch.accelerator.reset_peak_memory_stats()
        self.assertEqual(torch.accelerator.max_memory_allocated(), prev_max_allocated)
        self.assertEqual(torch.accelerator.max_memory_reserved(), prev_max_reserved)


if __name__ == "__main__":
    run_tests()
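Note: the comprehension above expects 4 stat types x 3 pool types x 4 metrics = 48 keys, e.g. "allocated_bytes.all.peak" or "active_bytes.large_pool.current", and the 1kB expectation follows from 256 float32 elements at 4 bytes each.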


@@ -373,6 +373,42 @@ print(t.is_pinned())
        torch.cuda.caching_allocator_delete(mem)
        self.assertEqual(torch.cuda.memory_allocated(), prev)

    def test_memory_stats(self):
        gc.collect()
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats()
        torch.cuda.reset_accumulated_memory_stats()
        prev_allocated = torch.accelerator.memory_allocated()
        prev_reserved = torch.accelerator.memory_reserved()
        prev_max_allocated = torch.accelerator.max_memory_allocated()
        prev_max_reserved = torch.accelerator.max_memory_reserved()
        # The peaks were just reset, so they must match the current readings
        self.assertEqual(prev_allocated, prev_max_allocated)
        self.assertEqual(prev_reserved, prev_max_reserved)
        # Record the active-memory baseline, then allocate 1kB (256 float32 x 4 bytes)
        prev_active_current = torch.accelerator.memory_stats()[
            "active_bytes.all.current"
        ]
        tmp = torch.randn(256, device="cuda")
        # The current active memory should have grown by exactly 1kB
        self.assertEqual(
            torch.accelerator.memory_stats()["active_bytes.all.current"],
            1024 + prev_active_current,
        )
        self.assertEqual(torch.accelerator.memory_stats()["active_bytes.all.freed"], 0)
        del tmp
        gc.collect()
        torch.accelerator.empty_cache()
        self.assertEqual(
            torch.accelerator.memory_stats()["active_bytes.all.current"],
            prev_active_current,
        )
        self.assertEqual(
            torch.accelerator.memory_stats()["active_bytes.all.freed"], 1024
        )
        torch.accelerator.reset_peak_memory_stats()
        self.assertEqual(torch.accelerator.max_memory_allocated(), prev_max_allocated)
        self.assertEqual(torch.accelerator.max_memory_reserved(), prev_max_reserved)

    def test_check_error(self):
        # Assert this call doesn't raise.
        torch.cuda.check_error(0)
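This CUDA variant (and the XPU one below) resets the peak and accumulated counters up front through the backend-specific torch.cuda.* / torch.xpu.* calls, then reads everything back through the generic torch.accelerator facade, which is why the current and peak readings must agree before any new allocation.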


@@ -1,5 +1,6 @@
# Owner(s): ["module: intel"]
import gc
import re
import subprocess
import sys
@@ -520,6 +521,42 @@ if __name__ == "__main__":
        )
        del a

    def test_memory_stats(self):
        gc.collect()
        torch.xpu.empty_cache()
        torch.xpu.reset_peak_memory_stats()
        torch.xpu.reset_accumulated_memory_stats()
        prev_allocated = torch.accelerator.memory_allocated()
        prev_reserved = torch.accelerator.memory_reserved()
        prev_max_allocated = torch.accelerator.max_memory_allocated()
        prev_max_reserved = torch.accelerator.max_memory_reserved()
        # The peaks were just reset, so they must match the current readings
        self.assertEqual(prev_allocated, prev_max_allocated)
        self.assertEqual(prev_reserved, prev_max_reserved)
        # Record the active-memory baseline, then allocate 1kB (256 float32 x 4 bytes)
        prev_active_current = torch.accelerator.memory_stats()[
            "active_bytes.all.current"
        ]
        tmp = torch.randn(256, device="xpu")
        # The current active memory should have grown by exactly 1kB
        self.assertEqual(
            torch.accelerator.memory_stats()["active_bytes.all.current"],
            1024 + prev_active_current,
        )
        self.assertEqual(torch.accelerator.memory_stats()["active_bytes.all.freed"], 0)
        del tmp
        gc.collect()
        torch.accelerator.empty_cache()
        self.assertEqual(
            torch.accelerator.memory_stats()["active_bytes.all.current"],
            prev_active_current,
        )
        self.assertEqual(
            torch.accelerator.memory_stats()["active_bytes.all.freed"], 1024
        )
        torch.accelerator.reset_peak_memory_stats()
        self.assertEqual(torch.accelerator.max_memory_allocated(), prev_max_allocated)
        self.assertEqual(torch.accelerator.max_memory_reserved(), prev_max_reserved)

    @skipXPUIf(
        int(torch.version.xpu) < 20250000,
        "Test requires SYCL compiler version 2025.0.0 or newer.",