mirror of
				https://github.com/pytorch/pytorch.git
				synced 2025-10-20 21:14:14 +08:00 
			
		
		
		
	Partially addresses #123062

Ran lintrunner on:
- `test/jit`

with command:

```bash
lintrunner -a --take UFMT --all-files
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/123623
Approved by: https://github.com/ezyang
		
			
				
	
	
		
			122 lines
		
	
	
		
			4.2 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			122 lines
		
	
	
		
			4.2 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| # Owner(s): ["oncall: jit"]
 | |
| 
 | |
| import os
 | |
| import sys
 | |
| 
 | |
| import torch
 | |
| 
 | |
# The test/ directory sits two directory levels above this file. Put it on
# sys.path so the helper modules stored there can be imported.
pytorch_test_dir = os.path.dirname(
    os.path.dirname(
        os.path.realpath(__file__),
    ),
)
sys.path.append(pytorch_test_dir)
 | |
| from torch.testing._internal.jit_utils import JitTestCase
 | |
| 
 | |
if __name__ == "__main__":
    # This module only defines test cases. The message below tells the user to
    # run them through the top-level test/test_jit.py driver instead.
    usage_message = (
        "This test file is not meant to be run directly, use:\n\n"
        "\tpython test/test_jit.py TESTNAME\n\n"
        "instead."
    )
    raise RuntimeError(usage_message)
 | |
| 
 | |
| 
 | |
class TestLogging(JitTestCase):
    """Tests for the stat-counter API in ``torch.jit._logging`` and for the
    JIT logging-option setter/getter exposed on ``torch._C``.

    Any test that installs a logger, or changes the JIT logging option,
    restores the previous value in a ``finally`` block. This keeps state from
    leaking into other tests.
    """

    def test_bump_numeric_counter(self):
        """Counters bumped from a scripted ``forward`` accumulate across calls."""

        class ModuleThatLogs(torch.jit.ScriptModule):
            @torch.jit.script_method
            def forward(self, x):
                # One "foo" increment per entry along dim 0.
                for _ in range(x.size(0)):
                    x += 1.0
                    torch.jit._logging.add_stat_value("foo", 1)

                if bool(x.sum() > 0.0):
                    torch.jit._logging.add_stat_value("positive", 1)
                else:
                    torch.jit._logging.add_stat_value("negative", 1)
                return x

        logger = torch.jit._logging.LockingLogger()
        # set_logger hands back the previously installed logger so it can be
        # reinstated afterwards.
        old_logger = torch.jit._logging.set_logger(logger)
        try:
            mtl = ModuleThatLogs()
            for _ in range(5):
                mtl(torch.rand(3, 4, 5))

            # 5 calls x 3 loop iterations each (x.size(0) == 3).
            self.assertEqual(logger.get_counter_val("foo"), 15)
            # torch.rand samples from [0, 1) and every element is then
            # incremented, so the sum is always positive.
            self.assertEqual(logger.get_counter_val("positive"), 5)
        finally:
            torch.jit._logging.set_logger(old_logger)

    def test_trace_numeric_counter(self):
        """A stat bump inside a traced function is recorded when it is called."""

        def foo(x):
            torch.jit._logging.add_stat_value("foo", 1)
            return x + 1.0

        # Tracing runs before the test logger is installed, so only the call
        # below can contribute to the test logger's counters.
        traced = torch.jit.trace(foo, torch.rand(3, 4))
        logger = torch.jit._logging.LockingLogger()
        old_logger = torch.jit._logging.set_logger(logger)
        try:
            traced(torch.rand(3, 4))

            self.assertEqual(logger.get_counter_val("foo"), 1)
        finally:
            torch.jit._logging.set_logger(old_logger)

    def test_time_measurement_counter(self):
        """Elapsed time between two time points is recorded as a counter.

        Unlike the ``_script`` variant below, ``forward`` here is not
        decorated with ``@torch.jit.script_method``.
        """

        class ModuleThatTimes(torch.jit.ScriptModule):
            def forward(self, x):
                tp_start = torch.jit._logging.time_point()
                for _ in range(30):
                    x += 1.0
                tp_end = torch.jit._logging.time_point()
                torch.jit._logging.add_stat_value("mytimer", tp_end - tp_start)
                return x

        mtm = ModuleThatTimes()
        logger = torch.jit._logging.LockingLogger()
        old_logger = torch.jit._logging.set_logger(logger)
        try:
            mtm(torch.rand(3, 4))
            self.assertGreater(logger.get_counter_val("mytimer"), 0)
        finally:
            torch.jit._logging.set_logger(old_logger)

    def test_time_measurement_counter_script(self):
        """Same as ``test_time_measurement_counter`` with a scripted ``forward``."""

        class ModuleThatTimes(torch.jit.ScriptModule):
            @torch.jit.script_method
            def forward(self, x):
                tp_start = torch.jit._logging.time_point()
                for _ in range(30):
                    x += 1.0
                tp_end = torch.jit._logging.time_point()
                torch.jit._logging.add_stat_value("mytimer", tp_end - tp_start)
                return x

        mtm = ModuleThatTimes()
        logger = torch.jit._logging.LockingLogger()
        old_logger = torch.jit._logging.set_logger(logger)
        try:
            mtm(torch.rand(3, 4))
            self.assertGreater(logger.get_counter_val("mytimer"), 0)
        finally:
            torch.jit._logging.set_logger(old_logger)

    def test_counter_aggregation(self):
        """AVG aggregation reports the mean of recorded values, not their sum."""

        def foo(x):
            for _ in range(3):
                torch.jit._logging.add_stat_value("foo", 1)
            return x + 1.0

        traced = torch.jit.trace(foo, torch.rand(3, 4))
        logger = torch.jit._logging.LockingLogger()
        logger.set_aggregation_type("foo", torch.jit._logging.AggregationType.AVG)
        old_logger = torch.jit._logging.set_logger(logger)
        try:
            traced(torch.rand(3, 4))

            # Three values of 1 are recorded. Their average is 1, whereas the
            # default summing behaviour would report 3.
            self.assertEqual(logger.get_counter_val("foo"), 1)
        finally:
            torch.jit._logging.set_logger(old_logger)

    def test_logging_levels_set(self):
        """The JIT logging option round-trips through its setter and getter."""
        # The option is set through a module-level function, not on a per-test
        # object, so it outlives this test. Save it and restore it in
        # ``finally`` so later tests see the original value.
        old_option = torch._C._jit_get_logging_option()
        try:
            torch._C._jit_set_logging_option("foo")
            self.assertEqual("foo", torch._C._jit_get_logging_option())
        finally:
            torch._C._jit_set_logging_option(old_option)
 |