mirror of
				https://github.com/pytorch/pytorch.git
				synced 2025-10-20 21:14:14 +08:00 
			
		
		
		
	Add UT for torch.accelerator memory-related API (#155200)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/155200 Approved by: https://github.com/albanD ghstack dependencies: #138222, #152932
This commit is contained in:
		
				
					committed by
					
						 PyTorch MergeBot
						PyTorch MergeBot
					
				
			
			
				
	
			
			
			
						parent
						
							15f1173e5d
						
					
				
				
					commit
					4604f0482c
				
			| @ -1,5 +1,6 @@ | ||||
| # Owner(s): ["module: tests"] | ||||
|  | ||||
| import gc | ||||
| import sys | ||||
| import unittest | ||||
|  | ||||
| @ -156,6 +157,83 @@ class TestAccelerator(TestCase): | ||||
|         ): | ||||
|             event1.elapsed_time(event2) | ||||
|  | ||||
|     @unittest.skipIf(TEST_MPS, "MPS doesn't support torch.accelerator memory API!") | ||||
|     def test_memory_stats(self): | ||||
|         # Ensure that device allocator is initialized | ||||
|         acc = torch.accelerator.current_accelerator() | ||||
|         tmp = torch.randn(100, device=acc) | ||||
|         del tmp | ||||
|         gc.collect() | ||||
|         self.assertTrue(torch._C._accelerator_isAllocatorInitialized()) | ||||
|         torch.accelerator.empty_cache() | ||||
|  | ||||
|         pool_type = ["all", "small_pool", "large_pool"] | ||||
|         metric_type = ["peak", "current", "allocated", "freed"] | ||||
|         stats_type = [ | ||||
|             "allocated_bytes", | ||||
|             "reserved_bytes", | ||||
|             "active_bytes", | ||||
|             "requested_bytes", | ||||
|         ] | ||||
|         mem_stats = torch.accelerator.memory_stats() | ||||
|         expected_stats = [ | ||||
|             f"{st}.{pt}.{mt}" | ||||
|             for st in stats_type | ||||
|             for pt in pool_type | ||||
|             for mt in metric_type | ||||
|         ] | ||||
|         missing_stats = [stat for stat in expected_stats if stat not in mem_stats] | ||||
|         self.assertEqual( | ||||
|             len(missing_stats), | ||||
|             0, | ||||
|             f"Missing expected memory statistics: {missing_stats}", | ||||
|         ) | ||||
|  | ||||
|         prev_allocated = torch.accelerator.memory_allocated() | ||||
|         prev_reserved = torch.accelerator.memory_reserved() | ||||
|         prev_max_allocated = torch.accelerator.max_memory_allocated() | ||||
|         prev_max_reserved = torch.accelerator.max_memory_reserved() | ||||
|         self.assertGreaterEqual(prev_allocated, 0) | ||||
|         self.assertGreaterEqual(prev_reserved, 0) | ||||
|         self.assertGreater(prev_max_allocated, 0) | ||||
|         self.assertGreater(prev_max_reserved, 0) | ||||
|         tmp = torch.ones(256, device=acc) | ||||
|         self.assertGreater(torch.accelerator.memory_allocated(), prev_allocated) | ||||
|         self.assertGreaterEqual(torch.accelerator.memory_reserved(), prev_reserved) | ||||
|         del tmp | ||||
|         gc.collect() | ||||
|         torch.accelerator.empty_cache() | ||||
|         torch.accelerator.reset_peak_memory_stats() | ||||
|         self.assertEqual(torch.accelerator.memory_allocated(), prev_allocated) | ||||
|         self.assertEqual(torch.accelerator.memory_reserved(), prev_reserved) | ||||
|         torch.accelerator.reset_accumulated_memory_stats() | ||||
|         prev_max_allocated = torch.accelerator.max_memory_allocated() | ||||
|         prev_max_reserved = torch.accelerator.max_memory_reserved() | ||||
|         # Activate 1kB memory | ||||
|         prev_active_current = torch.accelerator.memory_stats()[ | ||||
|             "active_bytes.all.current" | ||||
|         ] | ||||
|         tmp = torch.randn(256, device=acc) | ||||
|         # Detect if the current active memory is 1kB | ||||
|         self.assertEqual( | ||||
|             torch.accelerator.memory_stats()["active_bytes.all.current"], | ||||
|             1024 + prev_active_current, | ||||
|         ) | ||||
|         self.assertEqual(torch.accelerator.memory_stats()["active_bytes.all.freed"], 0) | ||||
|         del tmp | ||||
|         gc.collect() | ||||
|         torch.accelerator.empty_cache() | ||||
|         self.assertEqual( | ||||
|             torch.accelerator.memory_stats()["active_bytes.all.current"], | ||||
|             prev_active_current, | ||||
|         ) | ||||
|         self.assertEqual( | ||||
|             torch.accelerator.memory_stats()["active_bytes.all.freed"], 1024 | ||||
|         ) | ||||
|         torch.accelerator.reset_peak_memory_stats() | ||||
|         self.assertEqual(torch.accelerator.max_memory_allocated(), prev_max_allocated) | ||||
|         self.assertEqual(torch.accelerator.max_memory_reserved(), prev_max_reserved) | ||||
|  | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|     run_tests() | ||||
|  | ||||
| @ -373,6 +373,42 @@ print(t.is_pinned()) | ||||
|                 torch.cuda.caching_allocator_delete(mem) | ||||
|                 self.assertEqual(torch.cuda.memory_allocated(), prev) | ||||
|  | ||||
|     def test_memory_stats(self): | ||||
|         gc.collect() | ||||
|         torch.cuda.empty_cache() | ||||
|         torch.cuda.reset_peak_memory_stats() | ||||
|         torch.cuda.reset_accumulated_memory_stats() | ||||
|         prev_allocated = torch.accelerator.memory_allocated() | ||||
|         prev_reserved = torch.accelerator.memory_reserved() | ||||
|         prev_max_allocated = torch.accelerator.max_memory_allocated() | ||||
|         prev_max_reserved = torch.accelerator.max_memory_reserved() | ||||
|         self.assertEqual(prev_allocated, prev_max_allocated) | ||||
|         self.assertEqual(prev_reserved, prev_max_reserved) | ||||
|         # Activate 1kB memory | ||||
|         prev_active_current = torch.accelerator.memory_stats()[ | ||||
|             "active_bytes.all.current" | ||||
|         ] | ||||
|         tmp = torch.randn(256, device="cuda") | ||||
|         # Detect if the current active memory is 1kB | ||||
|         self.assertEqual( | ||||
|             torch.accelerator.memory_stats()["active_bytes.all.current"], | ||||
|             1024 + prev_active_current, | ||||
|         ) | ||||
|         self.assertEqual(torch.accelerator.memory_stats()["active_bytes.all.freed"], 0) | ||||
|         del tmp | ||||
|         gc.collect() | ||||
|         torch.accelerator.empty_cache() | ||||
|         self.assertEqual( | ||||
|             torch.accelerator.memory_stats()["active_bytes.all.current"], | ||||
|             prev_active_current, | ||||
|         ) | ||||
|         self.assertEqual( | ||||
|             torch.accelerator.memory_stats()["active_bytes.all.freed"], 1024 | ||||
|         ) | ||||
|         torch.accelerator.reset_peak_memory_stats() | ||||
|         self.assertEqual(torch.accelerator.max_memory_allocated(), prev_max_allocated) | ||||
|         self.assertEqual(torch.accelerator.max_memory_reserved(), prev_max_reserved) | ||||
|  | ||||
|     def test_check_error(self): | ||||
|         # Assert this call doesn't raise. | ||||
|         torch.cuda.check_error(0) | ||||
|  | ||||
| @ -1,5 +1,6 @@ | ||||
| # Owner(s): ["module: intel"] | ||||
|  | ||||
| import gc | ||||
| import re | ||||
| import subprocess | ||||
| import sys | ||||
| @ -520,6 +521,42 @@ if __name__ == "__main__": | ||||
|         ) | ||||
|         del a | ||||
|  | ||||
|     def test_memory_stats(self): | ||||
|         gc.collect() | ||||
|         torch.xpu.empty_cache() | ||||
|         torch.xpu.reset_peak_memory_stats() | ||||
|         torch.xpu.reset_accumulated_memory_stats() | ||||
|         prev_allocated = torch.accelerator.memory_allocated() | ||||
|         prev_reserved = torch.accelerator.memory_reserved() | ||||
|         prev_max_allocated = torch.accelerator.max_memory_allocated() | ||||
|         prev_max_reserved = torch.accelerator.max_memory_reserved() | ||||
|         self.assertEqual(prev_allocated, prev_max_allocated) | ||||
|         self.assertEqual(prev_reserved, prev_max_reserved) | ||||
|         # Activate 1kB memory | ||||
|         prev_active_current = torch.accelerator.memory_stats()[ | ||||
|             "active_bytes.all.current" | ||||
|         ] | ||||
|         tmp = torch.randn(256, device="xpu") | ||||
|         # Detect if the current active memory is 1kB | ||||
|         self.assertEqual( | ||||
|             torch.accelerator.memory_stats()["active_bytes.all.current"], | ||||
|             1024 + prev_active_current, | ||||
|         ) | ||||
|         self.assertEqual(torch.accelerator.memory_stats()["active_bytes.all.freed"], 0) | ||||
|         del tmp | ||||
|         gc.collect() | ||||
|         torch.accelerator.empty_cache() | ||||
|         self.assertEqual( | ||||
|             torch.accelerator.memory_stats()["active_bytes.all.current"], | ||||
|             prev_active_current, | ||||
|         ) | ||||
|         self.assertEqual( | ||||
|             torch.accelerator.memory_stats()["active_bytes.all.freed"], 1024 | ||||
|         ) | ||||
|         torch.accelerator.reset_peak_memory_stats() | ||||
|         self.assertEqual(torch.accelerator.max_memory_allocated(), prev_max_allocated) | ||||
|         self.assertEqual(torch.accelerator.max_memory_reserved(), prev_max_reserved) | ||||
|  | ||||
|     @skipXPUIf( | ||||
|         int(torch.version.xpu) < 20250000, | ||||
|         "Test requires SYCL compiler version 2025.0.0 or newer.", | ||||
|  | ||||
		Reference in New Issue
	
	Block a user