# Mirror of https://github.com/pytorch/pytorch.git
# (synced 2025-10-20 21:14:14 +08:00)
#
# Partially addresses #123062.
# Ran lintrunner on:
# - `test/jit`
# with command:
#     lintrunner -a --take UFMT --all-files
# Pull Request resolved: https://github.com/pytorch/pytorch/pull/123623
# Approved by: https://github.com/ezyang
#
# 122 lines, 4.2 KiB, Python

# Owner(s): ["oncall: jit"]

import os
import sys
import torch
# Make the helper files in test/ importable. Starting from this file's real
# path (symlinks resolved), step up two directory levels and put that
# directory on sys.path.
pytorch_test_dir = os.path.dirname(os.path.realpath(__file__))
pytorch_test_dir = os.path.dirname(pytorch_test_dir)
sys.path.append(pytorch_test_dir)
from torch.testing._internal.jit_utils import JitTestCase
if __name__ == "__main__":
    # This module only defines test cases. Refuse to run standalone and
    # point the user at the test/test_jit.py entry point instead.
    raise RuntimeError(
        "".join(
            (
                "This test file is not meant to be run directly, use:\n\n",
                "\tpython test/test_jit.py TESTNAME\n\n",
                "instead.",
            )
        )
    )


class TestLogging(JitTestCase):
|
|
def test_bump_numeric_counter(self):
|
|
class ModuleThatLogs(torch.jit.ScriptModule):
|
|
@torch.jit.script_method
|
|
def forward(self, x):
|
|
for i in range(x.size(0)):
|
|
x += 1.0
|
|
torch.jit._logging.add_stat_value("foo", 1)
|
|
|
|
if bool(x.sum() > 0.0):
|
|
torch.jit._logging.add_stat_value("positive", 1)
|
|
else:
|
|
torch.jit._logging.add_stat_value("negative", 1)
|
|
return x
|
|
|
|
logger = torch.jit._logging.LockingLogger()
|
|
old_logger = torch.jit._logging.set_logger(logger)
|
|
try:
|
|
mtl = ModuleThatLogs()
|
|
for i in range(5):
|
|
mtl(torch.rand(3, 4, 5))
|
|
|
|
self.assertEqual(logger.get_counter_val("foo"), 15)
|
|
self.assertEqual(logger.get_counter_val("positive"), 5)
|
|
finally:
|
|
torch.jit._logging.set_logger(old_logger)
|
|
|
|
def test_trace_numeric_counter(self):
|
|
def foo(x):
|
|
torch.jit._logging.add_stat_value("foo", 1)
|
|
return x + 1.0
|
|
|
|
traced = torch.jit.trace(foo, torch.rand(3, 4))
|
|
logger = torch.jit._logging.LockingLogger()
|
|
old_logger = torch.jit._logging.set_logger(logger)
|
|
try:
|
|
traced(torch.rand(3, 4))
|
|
|
|
self.assertEqual(logger.get_counter_val("foo"), 1)
|
|
finally:
|
|
torch.jit._logging.set_logger(old_logger)
|
|
|
|
def test_time_measurement_counter(self):
|
|
class ModuleThatTimes(torch.jit.ScriptModule):
|
|
def forward(self, x):
|
|
tp_start = torch.jit._logging.time_point()
|
|
for i in range(30):
|
|
x += 1.0
|
|
tp_end = torch.jit._logging.time_point()
|
|
torch.jit._logging.add_stat_value("mytimer", tp_end - tp_start)
|
|
return x
|
|
|
|
mtm = ModuleThatTimes()
|
|
logger = torch.jit._logging.LockingLogger()
|
|
old_logger = torch.jit._logging.set_logger(logger)
|
|
try:
|
|
mtm(torch.rand(3, 4))
|
|
self.assertGreater(logger.get_counter_val("mytimer"), 0)
|
|
finally:
|
|
torch.jit._logging.set_logger(old_logger)
|
|
|
|
def test_time_measurement_counter_script(self):
|
|
class ModuleThatTimes(torch.jit.ScriptModule):
|
|
@torch.jit.script_method
|
|
def forward(self, x):
|
|
tp_start = torch.jit._logging.time_point()
|
|
for i in range(30):
|
|
x += 1.0
|
|
tp_end = torch.jit._logging.time_point()
|
|
torch.jit._logging.add_stat_value("mytimer", tp_end - tp_start)
|
|
return x
|
|
|
|
mtm = ModuleThatTimes()
|
|
logger = torch.jit._logging.LockingLogger()
|
|
old_logger = torch.jit._logging.set_logger(logger)
|
|
try:
|
|
mtm(torch.rand(3, 4))
|
|
self.assertGreater(logger.get_counter_val("mytimer"), 0)
|
|
finally:
|
|
torch.jit._logging.set_logger(old_logger)
|
|
|
|
def test_counter_aggregation(self):
|
|
def foo(x):
|
|
for i in range(3):
|
|
torch.jit._logging.add_stat_value("foo", 1)
|
|
return x + 1.0
|
|
|
|
traced = torch.jit.trace(foo, torch.rand(3, 4))
|
|
logger = torch.jit._logging.LockingLogger()
|
|
logger.set_aggregation_type("foo", torch.jit._logging.AggregationType.AVG)
|
|
old_logger = torch.jit._logging.set_logger(logger)
|
|
try:
|
|
traced(torch.rand(3, 4))
|
|
|
|
self.assertEqual(logger.get_counter_val("foo"), 1)
|
|
finally:
|
|
torch.jit._logging.set_logger(old_logger)
|
|
|
|
def test_logging_levels_set(self):
|
|
torch._C._jit_set_logging_option("foo")
|
|
self.assertEqual("foo", torch._C._jit_get_logging_option())
|