Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/53617

I'm trying to make `pytest test/*.py` work; right now, it fails during test collection. This change removes a few of the easier-to-fix pytest collection problems, one way or another. Two problems remain: the default dtype is trashed on entry to test_torch.py and test_cuda.py. I'll try to fix those in a follow-up.

Signed-off-by: Edward Z. Yang <ezyang@fb.com>

Test Plan: Imported from OSS

Reviewed By: mruberry

Differential Revision: D26918377

Pulled By: ezyang

fbshipit-source-id: 42069786882657e1e3ee974acb3ec48115f16210
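For context on the remaining collection failures: pytest imports every test module while collecting, so a module that mutates global state at import time leaks that state into every module collected after it. A minimal sketch of the default-dtype problem mentioned above (a standalone illustration, not code from this PR):

import torch

# Module A (collected first) flips the process-wide default dtype at import time.
torch.set_default_dtype(torch.float64)

# Module B (collected later) inherits the leaked state: tensors created at
# class-definition time (like `x = torch.randn(...)` in the file below)
# now come out float64 instead of the usual float32.
assert torch.randn(2).dtype == torch.float64

# Restoring the default after the import-time mutation avoids the leak.
torch.set_default_dtype(torch.float32)
assert torch.randn(2).dtype == torch.float32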
69 lines · 1.9 KiB · Python
import collections
import unittest

import torch
from torch.testing._internal.common_utils import (
    TestCase, run_tests, TEST_WITH_ASAN)

try:
    import psutil
    HAS_PSUTIL = True
except ImportError:
    HAS_PSUTIL = False

device = torch.device('cpu')


class Network(torch.nn.Module):
    maxp1 = torch.nn.MaxPool2d(1, 1)

    def forward(self, x):
        return self.maxp1(x)


@unittest.skipIf(not HAS_PSUTIL, "Requires psutil to run")
@unittest.skipIf(TEST_WITH_ASAN, "Cannot test with ASAN")
class TestOpenMP_ParallelFor(TestCase):
    batch = 20
    channels = 1
    side_dim = 80
    x = torch.randn([batch, channels, side_dim, side_dim], device=device)
    model = Network()

    def func(self, runs):
        p = psutil.Process()
        # warm up for 5 runs, then things should be stable for the last 5
        last_rss = collections.deque(maxlen=5)
        for n in range(10):
            for i in range(runs):
                self.model(self.x)
            last_rss.append(p.memory_info().rss)
        return last_rss

    def func_rss(self, runs):
        last_rss = list(self.func(runs))
        # Check that the sequence is not strictly increasing
        is_increasing = True
        for idx in range(len(last_rss)):
            if idx == 0:
                continue
            is_increasing = is_increasing and (last_rss[idx] > last_rss[idx - 1])
        self.assertTrue(not is_increasing,
                        msg='memory usage is increasing, {}'.format(str(last_rss)))

    def test_one_thread(self):
        """Make sure there is no memory leak with one thread: issue gh-32284
        """
        torch.set_num_threads(1)
        self.func_rss(300)

    def test_n_threads(self):
        """Make sure there is no memory leak with many threads
        """
        ncores = min(5, psutil.cpu_count(logical=False))
        torch.set_num_threads(ncores)
        self.func_rss(300)


if __name__ == '__main__':
    run_tests()
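Since the commit's goal is clean collection, a quick sanity check is that this module imports without side effects, because pytest collection is essentially an import. A hedged sketch (the path test/test_openmp.py is an assumption about where this file lives; assumes torch is installed and the working directory is the repo root):

import importlib.util

# Collection-time check: pytest "collects" a test file by importing it, so an
# import that raises here would also break `pytest test/*.py` collection.
spec = importlib.util.spec_from_file_location(
    "test_openmp", "test/test_openmp.py")  # assumed path, see note above
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)  # raises on any module-level failure

# Once the import succeeds, the test classes are discoverable.
print([name for name in dir(module) if name.startswith("Test")])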