Revert "Add UT for torch.accelerator memory-related API (#155200)"

This reverts commit 4604f0482c2b4a3001b62e5bc5085149a9bb053c.

Reverted https://github.com/pytorch/pytorch/pull/155200 on behalf of https://github.com/jithunnair-amd due to Broke ROCm periodic runs on MI300 e.g. https://github.com/pytorch/pytorch/actions/runs/16764977800/job/47470050573 ([comment](https://github.com/pytorch/pytorch/pull/138222#issuecomment-3164941815))
This commit is contained in:
PyTorch MergeBot
2025-08-07 16:34:36 +00:00
parent 90b78ee50f
commit c4e64467b5
3 changed files with 0 additions and 151 deletions

View File

@@ -1,6 +1,5 @@
# Owner(s): ["module: tests"]
import gc
import sys
import unittest
@@ -157,83 +156,6 @@ class TestAccelerator(TestCase):
):
event1.elapsed_time(event2)
@unittest.skipIf(TEST_MPS, "MPS doesn't support torch.accelerator memory API!")
def test_memory_stats(self):
    """Exercise the torch.accelerator memory-statistics API end to end.

    Covers: stat-key completeness of memory_stats(), allocated/reserved
    accounting around an allocation, peak/accumulated stat resets, and
    active-byte tracking for a 1kB tensor.
    """
    # Allocate and free a throwaway tensor so the device allocator
    # is guaranteed to be initialized before we query it.
    device = torch.accelerator.current_accelerator()
    scratch = torch.randn(100, device=device)
    del scratch
    gc.collect()
    self.assertTrue(torch._C._accelerator_isAllocatorInitialized())
    torch.accelerator.empty_cache()
    # Every "<stat>.<pool>.<metric>" combination must be reported.
    stat_names = [
        "allocated_bytes",
        "reserved_bytes",
        "active_bytes",
        "requested_bytes",
    ]
    pool_names = ["all", "small_pool", "large_pool"]
    metric_names = ["peak", "current", "allocated", "freed"]
    stats = torch.accelerator.memory_stats()
    wanted = [
        f"{st}.{pt}.{mt}"
        for st in stat_names
        for pt in pool_names
        for mt in metric_names
    ]
    absent = [key for key in wanted if key not in stats]
    self.assertEqual(
        len(absent),
        0,
        f"Missing expected memory statistics: {absent}",
    )
    # Snapshot baseline counters; peaks must already be positive because
    # of the warm-up allocation above.
    baseline_allocated = torch.accelerator.memory_allocated()
    baseline_reserved = torch.accelerator.memory_reserved()
    peak_allocated = torch.accelerator.max_memory_allocated()
    peak_reserved = torch.accelerator.max_memory_reserved()
    self.assertGreaterEqual(baseline_allocated, 0)
    self.assertGreaterEqual(baseline_reserved, 0)
    self.assertGreater(peak_allocated, 0)
    self.assertGreater(peak_reserved, 0)
    # A live allocation must raise the allocated counter above baseline.
    scratch = torch.ones(256, device=device)
    self.assertGreater(torch.accelerator.memory_allocated(), baseline_allocated)
    self.assertGreaterEqual(torch.accelerator.memory_reserved(), baseline_reserved)
    del scratch
    gc.collect()
    torch.accelerator.empty_cache()
    torch.accelerator.reset_peak_memory_stats()
    # After freeing and emptying the cache, counters return to baseline.
    self.assertEqual(torch.accelerator.memory_allocated(), baseline_allocated)
    self.assertEqual(torch.accelerator.memory_reserved(), baseline_reserved)
    torch.accelerator.reset_accumulated_memory_stats()
    peak_allocated = torch.accelerator.max_memory_allocated()
    peak_reserved = torch.accelerator.max_memory_reserved()
    # Activate 1kB of memory (256 float32 values) and verify the active
    # byte counter grows by exactly that amount.
    active_before = torch.accelerator.memory_stats()[
        "active_bytes.all.current"
    ]
    scratch = torch.randn(256, device=device)
    self.assertEqual(
        torch.accelerator.memory_stats()["active_bytes.all.current"],
        1024 + active_before,
    )
    self.assertEqual(torch.accelerator.memory_stats()["active_bytes.all.freed"], 0)
    del scratch
    gc.collect()
    torch.accelerator.empty_cache()
    # Freeing the tensor drops active bytes back and credits 1kB as freed.
    self.assertEqual(
        torch.accelerator.memory_stats()["active_bytes.all.current"],
        active_before,
    )
    self.assertEqual(
        torch.accelerator.memory_stats()["active_bytes.all.freed"], 1024
    )
    torch.accelerator.reset_peak_memory_stats()
    # Peak counters were reset while nothing extra was live, so they
    # must match the snapshots taken right after the accumulated reset.
    self.assertEqual(torch.accelerator.max_memory_allocated(), peak_allocated)
    self.assertEqual(torch.accelerator.max_memory_reserved(), peak_reserved)
# Script entry point: delegate to the shared test runner (run_tests is
# imported elsewhere in this file from torch's internal test utilities).
if __name__ == "__main__":
    run_tests()