mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
- We silently run skipped tests and then raise a skip message with the error message (if any).
- Instead of raising expectedFailure, we raise a skip message with the error message (if any).

We log the skip messages in CI, so this will let us read the logs and do some basic triaging of the failure messages.

Test Plan:
- Existing tests. I hope that there are no tests that cause each other to fail.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/117401
Approved by: https://github.com/voznesenskym
ghstack dependencies: #117391, #117400
161 lines
4.5 KiB
Python
161 lines
4.5 KiB
Python
# Owner(s): ["oncall: r2p"]
|
|
|
|
from torch.testing._internal.common_utils import (
|
|
TestCase, run_tests, skipIfTorchDynamo,
|
|
)
|
|
|
|
from datetime import timedelta, datetime
|
|
import tempfile
|
|
import time
|
|
|
|
from torch.monitor import (
|
|
Aggregation,
|
|
Event,
|
|
log_event,
|
|
register_event_handler,
|
|
unregister_event_handler,
|
|
Stat,
|
|
TensorboardEventHandler,
|
|
)
|
|
|
|
class TestMonitor(TestCase):
|
|
def test_interval_stat(self) -> None:
|
|
events = []
|
|
|
|
def handler(event):
|
|
events.append(event)
|
|
|
|
handle = register_event_handler(handler)
|
|
s = Stat(
|
|
"asdf",
|
|
(Aggregation.SUM, Aggregation.COUNT),
|
|
timedelta(milliseconds=1),
|
|
)
|
|
self.assertEqual(s.name, "asdf")
|
|
|
|
s.add(2)
|
|
for _ in range(100):
|
|
# NOTE: different platforms sleep may be inaccurate so we loop
|
|
# instead (i.e. win)
|
|
time.sleep(1 / 1000) # ms
|
|
s.add(3)
|
|
if len(events) >= 1:
|
|
break
|
|
self.assertGreaterEqual(len(events), 1)
|
|
unregister_event_handler(handle)
|
|
|
|
def test_fixed_count_stat(self) -> None:
|
|
s = Stat(
|
|
"asdf",
|
|
(Aggregation.SUM, Aggregation.COUNT),
|
|
timedelta(hours=100),
|
|
3,
|
|
)
|
|
s.add(1)
|
|
s.add(2)
|
|
name = s.name
|
|
self.assertEqual(name, "asdf")
|
|
self.assertEqual(s.count, 2)
|
|
s.add(3)
|
|
self.assertEqual(s.count, 0)
|
|
self.assertEqual(s.get(), {Aggregation.SUM: 6.0, Aggregation.COUNT: 3})
|
|
|
|
def test_log_event(self) -> None:
|
|
e = Event(
|
|
name="torch.monitor.TestEvent",
|
|
timestamp=datetime.now(),
|
|
data={
|
|
"str": "a string",
|
|
"float": 1234.0,
|
|
"int": 1234,
|
|
},
|
|
)
|
|
self.assertEqual(e.name, "torch.monitor.TestEvent")
|
|
self.assertIsNotNone(e.timestamp)
|
|
self.assertIsNotNone(e.data)
|
|
log_event(e)
|
|
|
|
@skipIfTorchDynamo("Really weird error")
|
|
def test_event_handler(self) -> None:
|
|
events = []
|
|
|
|
def handler(event: Event) -> None:
|
|
events.append(event)
|
|
|
|
handle = register_event_handler(handler)
|
|
e = Event(
|
|
name="torch.monitor.TestEvent",
|
|
timestamp=datetime.now(),
|
|
data={},
|
|
)
|
|
log_event(e)
|
|
self.assertEqual(len(events), 1)
|
|
self.assertEqual(events[0], e)
|
|
log_event(e)
|
|
self.assertEqual(len(events), 2)
|
|
|
|
unregister_event_handler(handle)
|
|
log_event(e)
|
|
self.assertEqual(len(events), 2)
|
|
|
|
|
|
@skipIfTorchDynamo("Really weird error")
|
|
class TestMonitorTensorboard(TestCase):
|
|
def setUp(self):
|
|
global SummaryWriter, event_multiplexer
|
|
try:
|
|
from torch.utils.tensorboard import SummaryWriter
|
|
from tensorboard.backend.event_processing import (
|
|
plugin_event_multiplexer as event_multiplexer,
|
|
)
|
|
except ImportError:
|
|
return self.skipTest("Skip the test since TensorBoard is not installed")
|
|
self.temp_dirs = []
|
|
|
|
def create_summary_writer(self):
|
|
temp_dir = tempfile.TemporaryDirectory() # noqa: P201
|
|
self.temp_dirs.append(temp_dir)
|
|
return SummaryWriter(temp_dir.name)
|
|
|
|
def tearDown(self):
|
|
# Remove directories created by SummaryWriter
|
|
for temp_dir in self.temp_dirs:
|
|
temp_dir.cleanup()
|
|
|
|
def test_event_handler(self):
|
|
with self.create_summary_writer() as w:
|
|
handle = register_event_handler(TensorboardEventHandler(w))
|
|
|
|
s = Stat(
|
|
"asdf",
|
|
(Aggregation.SUM, Aggregation.COUNT),
|
|
timedelta(hours=1),
|
|
5,
|
|
)
|
|
for i in range(10):
|
|
s.add(i)
|
|
self.assertEqual(s.count, 0)
|
|
|
|
unregister_event_handler(handle)
|
|
|
|
mul = event_multiplexer.EventMultiplexer()
|
|
mul.AddRunsFromDirectory(self.temp_dirs[-1].name)
|
|
mul.Reload()
|
|
scalar_dict = mul.PluginRunToTagToContent("scalars")
|
|
raw_result = {
|
|
tag: mul.Tensors(run, tag)
|
|
for run, run_dict in scalar_dict.items()
|
|
for tag in run_dict
|
|
}
|
|
scalars = {
|
|
tag: [e.tensor_proto.float_val[0] for e in events] for tag, events in raw_result.items()
|
|
}
|
|
self.assertEqual(scalars, {
|
|
"asdf.sum": [10],
|
|
"asdf.count": [5],
|
|
})
|
|
|
|
|
|
if __name__ == "__main__":
    # Executed as a script: hand control to PyTorch's shared test runner.
    run_tests()
|