# Owner(s): ["oncall: jit"]
# ruff: noqa: F841

import os
import sys

import torch


# Make the helper files in test/ importable
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)

from torch.testing._internal.common_utils import raise_on_run_directly
from torch.testing._internal.jit_utils import JitTestCase

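
# These tests exercise torch.jit._logging: add_stat_value() calls placed in
# scripted or traced code report to whichever logger is installed via
# set_logger(). Each test installs a LockingLogger and checks the counter
# values it aggregated.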
class TestLogging(JitTestCase):
    def test_bump_numeric_counter(self):
        class ModuleThatLogs(torch.jit.ScriptModule):
            @torch.jit.script_method
            def forward(self, x):
                for i in range(x.size(0)):
                    x += 1.0
                    torch.jit._logging.add_stat_value("foo", 1)

                if bool(x.sum() > 0.0):
                    torch.jit._logging.add_stat_value("positive", 1)
                else:
                    torch.jit._logging.add_stat_value("negative", 1)
                return x

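        # Swap in a fresh LockingLogger so this test sees only its own
        # counters; the previous logger is restored in the finally block.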
        logger = torch.jit._logging.LockingLogger()
        old_logger = torch.jit._logging.set_logger(logger)
        try:
            mtl = ModuleThatLogs()
            for i in range(5):
                mtl(torch.rand(3, 4, 5))

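            # 5 calls x 3 leading-dim iterations = 15 bumps of "foo"; x is
            # strictly positive after the increments, so "positive" is bumped
            # once per call and "negative" never.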
self.assertEqual(logger.get_counter_val("foo"), 15)
|
|
self.assertEqual(logger.get_counter_val("positive"), 5)
|
|
finally:
|
|
torch.jit._logging.set_logger(old_logger)
|
|
|
|
    def test_trace_numeric_counter(self):
        def foo(x):
            torch.jit._logging.add_stat_value("foo", 1)
            return x + 1.0

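        # The add_stat_value() call is captured in the traced graph, so it
        # fires once each time the traced function is executed.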
        traced = torch.jit.trace(foo, torch.rand(3, 4))
        logger = torch.jit._logging.LockingLogger()
        old_logger = torch.jit._logging.set_logger(logger)
        try:
            traced(torch.rand(3, 4))

            self.assertEqual(logger.get_counter_val("foo"), 1)
        finally:
            torch.jit._logging.set_logger(old_logger)

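    # Times a loop with time_point() markers. Note that forward() is not
    # decorated with @torch.jit.script_method here, so it executes as
    # regular Python; the _script variant below compiles it.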
    def test_time_measurement_counter(self):
        class ModuleThatTimes(torch.jit.ScriptModule):
            def forward(self, x):
                tp_start = torch.jit._logging.time_point()
                for i in range(30):
                    x += 1.0
                tp_end = torch.jit._logging.time_point()
                torch.jit._logging.add_stat_value("mytimer", tp_end - tp_start)
                return x

        mtm = ModuleThatTimes()
        logger = torch.jit._logging.LockingLogger()
        old_logger = torch.jit._logging.set_logger(logger)
        try:
            mtm(torch.rand(3, 4))
            self.assertGreater(logger.get_counter_val("mytimer"), 0)
        finally:
            torch.jit._logging.set_logger(old_logger)

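    # Same timer test, but with forward() compiled via @torch.jit.script_method
    # so the time_point() calls run inside TorchScript.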
    def test_time_measurement_counter_script(self):
        class ModuleThatTimes(torch.jit.ScriptModule):
            @torch.jit.script_method
            def forward(self, x):
                tp_start = torch.jit._logging.time_point()
                for i in range(30):
                    x += 1.0
                tp_end = torch.jit._logging.time_point()
                torch.jit._logging.add_stat_value("mytimer", tp_end - tp_start)
                return x

        mtm = ModuleThatTimes()
        logger = torch.jit._logging.LockingLogger()
        old_logger = torch.jit._logging.set_logger(logger)
        try:
            mtm(torch.rand(3, 4))
            self.assertGreater(logger.get_counter_val("mytimer"), 0)
        finally:
            torch.jit._logging.set_logger(old_logger)

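    # set_aggregation_type() controls how multiple samples recorded under the
    # same counter name are combined.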
    def test_counter_aggregation(self):
        def foo(x):
            for i in range(3):
                torch.jit._logging.add_stat_value("foo", 1)
            return x + 1.0

        traced = torch.jit.trace(foo, torch.rand(3, 4))
        logger = torch.jit._logging.LockingLogger()
        logger.set_aggregation_type("foo", torch.jit._logging.AggregationType.AVG)
        old_logger = torch.jit._logging.set_logger(logger)
        try:
            traced(torch.rand(3, 4))

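            # "foo" recorded three samples of 1; with AVG aggregation the
            # counter reports their average (1) rather than their sum (3).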
self.assertEqual(logger.get_counter_val("foo"), 1)
|
|
finally:
|
|
torch.jit._logging.set_logger(old_logger)
|
|
|
|
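    # Round-trip check for the JIT's logging option string (the runtime
    # counterpart of the PYTORCH_JIT_LOG_LEVEL environment variable).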
    def test_logging_levels_set(self):
        torch._C._jit_set_logging_option("foo")
        self.assertEqual("foo", torch._C._jit_get_logging_option())


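# This file is not meant to be run directly; raise_on_run_directly() raises a
# RuntimeError and prints the runner that actually includes these tests.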
if __name__ == "__main__":
    raise_on_run_directly("test/test_jit.py")