mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Revert "Add UT for torch.accelerator memory-related API (#155200)"
This reverts commit 4604f0482c2b4a3001b62e5bc5085149a9bb053c. Reverted https://github.com/pytorch/pytorch/pull/155200 on behalf of https://github.com/jithunnair-amd due to Broke ROCm periodic runs on MI300 e.g. https://github.com/pytorch/pytorch/actions/runs/16764977800/job/47470050573 ([comment](https://github.com/pytorch/pytorch/pull/138222#issuecomment-3164941815))
This commit is contained in:
@ -1,6 +1,5 @@
|
||||
# Owner(s): ["module: tests"]
|
||||
|
||||
import gc
|
||||
import sys
|
||||
import unittest
|
||||
|
||||
@ -157,83 +156,6 @@ class TestAccelerator(TestCase):
|
||||
):
|
||||
event1.elapsed_time(event2)
|
||||
|
||||
@unittest.skipIf(TEST_MPS, "MPS doesn't support torch.accelerator memory API!")
|
||||
def test_memory_stats(self):
|
||||
# Ensure that device allocator is initialized
|
||||
acc = torch.accelerator.current_accelerator()
|
||||
tmp = torch.randn(100, device=acc)
|
||||
del tmp
|
||||
gc.collect()
|
||||
self.assertTrue(torch._C._accelerator_isAllocatorInitialized())
|
||||
torch.accelerator.empty_cache()
|
||||
|
||||
pool_type = ["all", "small_pool", "large_pool"]
|
||||
metric_type = ["peak", "current", "allocated", "freed"]
|
||||
stats_type = [
|
||||
"allocated_bytes",
|
||||
"reserved_bytes",
|
||||
"active_bytes",
|
||||
"requested_bytes",
|
||||
]
|
||||
mem_stats = torch.accelerator.memory_stats()
|
||||
expected_stats = [
|
||||
f"{st}.{pt}.{mt}"
|
||||
for st in stats_type
|
||||
for pt in pool_type
|
||||
for mt in metric_type
|
||||
]
|
||||
missing_stats = [stat for stat in expected_stats if stat not in mem_stats]
|
||||
self.assertEqual(
|
||||
len(missing_stats),
|
||||
0,
|
||||
f"Missing expected memory statistics: {missing_stats}",
|
||||
)
|
||||
|
||||
prev_allocated = torch.accelerator.memory_allocated()
|
||||
prev_reserved = torch.accelerator.memory_reserved()
|
||||
prev_max_allocated = torch.accelerator.max_memory_allocated()
|
||||
prev_max_reserved = torch.accelerator.max_memory_reserved()
|
||||
self.assertGreaterEqual(prev_allocated, 0)
|
||||
self.assertGreaterEqual(prev_reserved, 0)
|
||||
self.assertGreater(prev_max_allocated, 0)
|
||||
self.assertGreater(prev_max_reserved, 0)
|
||||
tmp = torch.ones(256, device=acc)
|
||||
self.assertGreater(torch.accelerator.memory_allocated(), prev_allocated)
|
||||
self.assertGreaterEqual(torch.accelerator.memory_reserved(), prev_reserved)
|
||||
del tmp
|
||||
gc.collect()
|
||||
torch.accelerator.empty_cache()
|
||||
torch.accelerator.reset_peak_memory_stats()
|
||||
self.assertEqual(torch.accelerator.memory_allocated(), prev_allocated)
|
||||
self.assertEqual(torch.accelerator.memory_reserved(), prev_reserved)
|
||||
torch.accelerator.reset_accumulated_memory_stats()
|
||||
prev_max_allocated = torch.accelerator.max_memory_allocated()
|
||||
prev_max_reserved = torch.accelerator.max_memory_reserved()
|
||||
# Activate 1kB memory
|
||||
prev_active_current = torch.accelerator.memory_stats()[
|
||||
"active_bytes.all.current"
|
||||
]
|
||||
tmp = torch.randn(256, device=acc)
|
||||
# Detect if the current active memory is 1kB
|
||||
self.assertEqual(
|
||||
torch.accelerator.memory_stats()["active_bytes.all.current"],
|
||||
1024 + prev_active_current,
|
||||
)
|
||||
self.assertEqual(torch.accelerator.memory_stats()["active_bytes.all.freed"], 0)
|
||||
del tmp
|
||||
gc.collect()
|
||||
torch.accelerator.empty_cache()
|
||||
self.assertEqual(
|
||||
torch.accelerator.memory_stats()["active_bytes.all.current"],
|
||||
prev_active_current,
|
||||
)
|
||||
self.assertEqual(
|
||||
torch.accelerator.memory_stats()["active_bytes.all.freed"], 1024
|
||||
)
|
||||
torch.accelerator.reset_peak_memory_stats()
|
||||
self.assertEqual(torch.accelerator.max_memory_allocated(), prev_max_allocated)
|
||||
self.assertEqual(torch.accelerator.max_memory_reserved(), prev_max_reserved)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_tests()
|
||||
|
Reference in New Issue
Block a user