mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
torch/monitor: add pybind (#69567)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/69567 This exposes torch.monitor events and stats via pybind11 to the underlying C++ implementation. * The registration interface is a tad different since it takes a lambda function in Python where as in C++ it's a full class. * This has a small amount of changes to the counter interfaces since there's no way to create an initializer list at runtime so they now also take a vector. * Only double based stats are provided in Python since it's intended more for high level stats where float imprecision shouldn't be an issue. This can be changed down the line if need arises. ``` events = [] def handler(event): events.append(event) handle = register_event_handler(handler) log_event(Event(type="torch.monitor.TestEvent", timestamp=datetime.now(), metadata={"foo": 1.0})) ``` D32969391 is now included in this diff. This cleans up the naming for events. type is now name, message is gone, and metadata is renamed data. Test Plan: buck test //caffe2/test:monitor //caffe2/test/cpp/monitor:monitor Reviewed By: kiukchung Differential Revision: D32924141 fbshipit-source-id: 563304c2e3261a4754e40cca39fc64c5a04b43e8
This commit is contained in:
committed by
Facebook GitHub Bot
parent
90ef54f8ea
commit
bfe1abd3b5
@ -69,6 +69,7 @@ Features described in this documentation are classified by release status:
|
||||
torch.hub <hub>
|
||||
torch.jit <jit>
|
||||
torch.linalg <linalg>
|
||||
torch.monitor <monitor>
|
||||
torch.special <special>
|
||||
torch.overrides
|
||||
torch.package <package>
|
||||
|
40
docs/source/monitor.rst
Normal file
40
docs/source/monitor.rst
Normal file
@ -0,0 +1,40 @@
|
||||
torch.monitor
|
||||
=============
|
||||
|
||||
.. warning::
|
||||
|
||||
This module is a prototype release, and its interfaces and functionality may
|
||||
change without warning in future PyTorch releases.
|
||||
|
||||
``torch.monitor`` provides an interface for logging events and counters from
|
||||
PyTorch.
|
||||
|
||||
|
||||
API Reference
|
||||
-------------
|
||||
|
||||
.. automodule:: torch.monitor
|
||||
|
||||
.. autoclass:: torch.monitor.Aggregation
|
||||
:members:
|
||||
|
||||
.. autoclass:: torch.monitor.Stat
|
||||
:members:
|
||||
|
||||
.. autoclass:: torch.monitor.IntervalStat
|
||||
:members:
|
||||
|
||||
.. autoclass:: torch.monitor.FixedCountStat
|
||||
:members:
|
||||
|
||||
.. autoclass:: torch.monitor.Event
|
||||
:members:
|
||||
|
||||
.. autoclass:: torch.monitor.PythonEventHandler
|
||||
:members:
|
||||
|
||||
.. autofunction:: torch.monitor.log_event
|
||||
|
||||
.. autofunction:: torch.monitor.register_event_handler
|
||||
|
||||
.. autofunction:: torch.monitor.unregister_event_handler
|
@ -10,7 +10,7 @@ using namespace torch::monitor;
|
||||
TEST(MonitorTest, CounterDouble) {
|
||||
FixedCountStat<double> a{
|
||||
"a",
|
||||
{MEAN, COUNT},
|
||||
{Aggregation::MEAN, Aggregation::COUNT},
|
||||
2,
|
||||
};
|
||||
a.add(5.0);
|
||||
@ -19,9 +19,9 @@ TEST(MonitorTest, CounterDouble) {
|
||||
ASSERT_EQ(a.count(), 0);
|
||||
|
||||
auto stats = a.get();
|
||||
std::unordered_map<Aggregation, double> want = {
|
||||
{MEAN, 5.5},
|
||||
{COUNT, 2.0},
|
||||
std::unordered_map<Aggregation, double, AggregationHash> want = {
|
||||
{Aggregation::MEAN, 5.5},
|
||||
{Aggregation::COUNT, 2.0},
|
||||
};
|
||||
ASSERT_EQ(stats, want);
|
||||
}
|
||||
@ -29,14 +29,14 @@ TEST(MonitorTest, CounterDouble) {
|
||||
TEST(MonitorTest, CounterInt64Sum) {
|
||||
FixedCountStat<int64_t> a{
|
||||
"a",
|
||||
{SUM},
|
||||
{Aggregation::SUM},
|
||||
2,
|
||||
};
|
||||
a.add(5);
|
||||
a.add(6);
|
||||
auto stats = a.get();
|
||||
std::unordered_map<Aggregation, int64_t> want = {
|
||||
{SUM, 11},
|
||||
std::unordered_map<Aggregation, int64_t, AggregationHash> want = {
|
||||
{Aggregation::SUM, 11},
|
||||
};
|
||||
ASSERT_EQ(stats, want);
|
||||
}
|
||||
@ -44,14 +44,14 @@ TEST(MonitorTest, CounterInt64Sum) {
|
||||
TEST(MonitorTest, CounterInt64Value) {
|
||||
FixedCountStat<int64_t> a{
|
||||
"a",
|
||||
{VALUE},
|
||||
{Aggregation::VALUE},
|
||||
2,
|
||||
};
|
||||
a.add(5);
|
||||
a.add(6);
|
||||
auto stats = a.get();
|
||||
std::unordered_map<Aggregation, int64_t> want = {
|
||||
{VALUE, 6},
|
||||
std::unordered_map<Aggregation, int64_t, AggregationHash> want = {
|
||||
{Aggregation::VALUE, 6},
|
||||
};
|
||||
ASSERT_EQ(stats, want);
|
||||
}
|
||||
@ -59,14 +59,14 @@ TEST(MonitorTest, CounterInt64Value) {
|
||||
TEST(MonitorTest, CounterInt64Mean) {
|
||||
FixedCountStat<int64_t> a{
|
||||
"a",
|
||||
{MEAN},
|
||||
{Aggregation::MEAN},
|
||||
2,
|
||||
};
|
||||
{
|
||||
// zero samples case
|
||||
auto stats = a.get();
|
||||
std::unordered_map<Aggregation, int64_t> want = {
|
||||
{MEAN, 0},
|
||||
std::unordered_map<Aggregation, int64_t, AggregationHash> want = {
|
||||
{Aggregation::MEAN, 0},
|
||||
};
|
||||
ASSERT_EQ(stats, want);
|
||||
}
|
||||
@ -76,8 +76,8 @@ TEST(MonitorTest, CounterInt64Mean) {
|
||||
|
||||
{
|
||||
auto stats = a.get();
|
||||
std::unordered_map<Aggregation, int64_t> want = {
|
||||
{MEAN, 5},
|
||||
std::unordered_map<Aggregation, int64_t, AggregationHash> want = {
|
||||
{Aggregation::MEAN, 5},
|
||||
};
|
||||
ASSERT_EQ(stats, want);
|
||||
}
|
||||
@ -86,7 +86,7 @@ TEST(MonitorTest, CounterInt64Mean) {
|
||||
TEST(MonitorTest, CounterInt64Count) {
|
||||
FixedCountStat<int64_t> a{
|
||||
"a",
|
||||
{COUNT},
|
||||
{Aggregation::COUNT},
|
||||
2,
|
||||
};
|
||||
ASSERT_EQ(a.count(), 0);
|
||||
@ -96,8 +96,8 @@ TEST(MonitorTest, CounterInt64Count) {
|
||||
ASSERT_EQ(a.count(), 0);
|
||||
|
||||
auto stats = a.get();
|
||||
std::unordered_map<Aggregation, int64_t> want = {
|
||||
{COUNT, 2},
|
||||
std::unordered_map<Aggregation, int64_t, AggregationHash> want = {
|
||||
{Aggregation::COUNT, 2},
|
||||
};
|
||||
ASSERT_EQ(stats, want);
|
||||
}
|
||||
@ -105,14 +105,14 @@ TEST(MonitorTest, CounterInt64Count) {
|
||||
TEST(MonitorTest, CounterInt64MinMax) {
|
||||
FixedCountStat<int64_t> a{
|
||||
"a",
|
||||
{MIN, MAX},
|
||||
{Aggregation::MIN, Aggregation::MAX},
|
||||
6,
|
||||
};
|
||||
{
|
||||
auto stats = a.get();
|
||||
std::unordered_map<Aggregation, int64_t> want = {
|
||||
{MAX, 0},
|
||||
{MIN, 0},
|
||||
std::unordered_map<Aggregation, int64_t, AggregationHash> want = {
|
||||
{Aggregation::MAX, 0},
|
||||
{Aggregation::MIN, 0},
|
||||
};
|
||||
ASSERT_EQ(stats, want);
|
||||
}
|
||||
@ -125,9 +125,9 @@ TEST(MonitorTest, CounterInt64MinMax) {
|
||||
a.add(2);
|
||||
{
|
||||
auto stats = a.get();
|
||||
std::unordered_map<Aggregation, int64_t> want = {
|
||||
{MAX, 9},
|
||||
{MIN, -6},
|
||||
std::unordered_map<Aggregation, int64_t, AggregationHash> want = {
|
||||
{Aggregation::MAX, 9},
|
||||
{Aggregation::MIN, -6},
|
||||
};
|
||||
ASSERT_EQ(stats, want);
|
||||
}
|
||||
@ -136,7 +136,7 @@ TEST(MonitorTest, CounterInt64MinMax) {
|
||||
TEST(MonitorTest, CounterInt64WindowSize) {
|
||||
FixedCountStat<int64_t> a{
|
||||
"a",
|
||||
{COUNT, SUM},
|
||||
{Aggregation::COUNT, Aggregation::SUM},
|
||||
/*windowSize=*/3,
|
||||
};
|
||||
a.add(1);
|
||||
@ -149,9 +149,9 @@ TEST(MonitorTest, CounterInt64WindowSize) {
|
||||
ASSERT_EQ(a.count(), 1);
|
||||
|
||||
auto stats = a.get();
|
||||
std::unordered_map<Aggregation, int64_t> want = {
|
||||
{COUNT, 3},
|
||||
{SUM, 6},
|
||||
std::unordered_map<Aggregation, int64_t, AggregationHash> want = {
|
||||
{Aggregation::COUNT, 3},
|
||||
{Aggregation::SUM, 6},
|
||||
};
|
||||
ASSERT_EQ(stats, want);
|
||||
}
|
||||
@ -197,7 +197,7 @@ TEST(MonitorTest, IntervalStat) {
|
||||
|
||||
IntervalStat<int64_t> a{
|
||||
"a",
|
||||
{COUNT, SUM},
|
||||
{Aggregation::COUNT, Aggregation::SUM},
|
||||
std::chrono::milliseconds(1),
|
||||
};
|
||||
ASSERT_EQ(guard.handler->events.size(), 0);
|
||||
@ -218,7 +218,7 @@ TEST(MonitorTest, IntervalStatEvent) {
|
||||
|
||||
TestIntervalStat<int64_t> a{
|
||||
"a",
|
||||
{COUNT, SUM},
|
||||
{Aggregation::COUNT, Aggregation::SUM},
|
||||
std::chrono::milliseconds(1),
|
||||
};
|
||||
ASSERT_EQ(guard.handler->events.size(), 0);
|
||||
@ -236,14 +236,13 @@ TEST(MonitorTest, IntervalStatEvent) {
|
||||
|
||||
ASSERT_EQ(guard.handler->events.size(), 1);
|
||||
Event e = guard.handler->events.at(0);
|
||||
ASSERT_EQ(e.type, "torch.monitor.Stat");
|
||||
ASSERT_EQ(e.message, "a");
|
||||
ASSERT_EQ(e.name, "torch.monitor.Stat");
|
||||
ASSERT_NE(e.timestamp, std::chrono::system_clock::time_point{});
|
||||
std::unordered_map<std::string, metadata_value_t> metadata{
|
||||
std::unordered_map<std::string, data_value_t> data{
|
||||
{"a.sum", 3L},
|
||||
{"a.count", 2L},
|
||||
};
|
||||
ASSERT_EQ(e.metadata, metadata);
|
||||
ASSERT_EQ(e.data, data);
|
||||
}
|
||||
|
||||
TEST(MonitorTest, IntervalStatEventDestruction) {
|
||||
@ -252,7 +251,7 @@ TEST(MonitorTest, IntervalStatEventDestruction) {
|
||||
{
|
||||
TestIntervalStat<int64_t> a{
|
||||
"a",
|
||||
{COUNT, SUM},
|
||||
{Aggregation::COUNT, Aggregation::SUM},
|
||||
std::chrono::hours(10),
|
||||
};
|
||||
a.add(1);
|
||||
@ -262,14 +261,13 @@ TEST(MonitorTest, IntervalStatEventDestruction) {
|
||||
ASSERT_EQ(guard.handler->events.size(), 1);
|
||||
|
||||
Event e = guard.handler->events.at(0);
|
||||
ASSERT_EQ(e.type, "torch.monitor.Stat");
|
||||
ASSERT_EQ(e.message, "a");
|
||||
ASSERT_EQ(e.name, "torch.monitor.Stat");
|
||||
ASSERT_NE(e.timestamp, std::chrono::system_clock::time_point{});
|
||||
std::unordered_map<std::string, metadata_value_t> metadata{
|
||||
std::unordered_map<std::string, data_value_t> data{
|
||||
{"a.sum", 1L},
|
||||
{"a.count", 1L},
|
||||
};
|
||||
ASSERT_EQ(e.metadata, metadata);
|
||||
ASSERT_EQ(e.data, data);
|
||||
}
|
||||
|
||||
TEST(MonitorTest, FixedCountStatEvent) {
|
||||
@ -277,7 +275,7 @@ TEST(MonitorTest, FixedCountStatEvent) {
|
||||
|
||||
FixedCountStat<int64_t> a{
|
||||
"a",
|
||||
{COUNT, SUM},
|
||||
{Aggregation::COUNT, Aggregation::SUM},
|
||||
3,
|
||||
};
|
||||
ASSERT_EQ(guard.handler->events.size(), 0);
|
||||
@ -293,14 +291,13 @@ TEST(MonitorTest, FixedCountStatEvent) {
|
||||
ASSERT_EQ(guard.handler->events.size(), 1);
|
||||
|
||||
Event e = guard.handler->events.at(0);
|
||||
ASSERT_EQ(e.type, "torch.monitor.Stat");
|
||||
ASSERT_EQ(e.message, "a");
|
||||
ASSERT_EQ(e.name, "torch.monitor.Stat");
|
||||
ASSERT_NE(e.timestamp, std::chrono::system_clock::time_point{});
|
||||
std::unordered_map<std::string, metadata_value_t> metadata{
|
||||
std::unordered_map<std::string, data_value_t> data{
|
||||
{"a.sum", 4L},
|
||||
{"a.count", 3L},
|
||||
};
|
||||
ASSERT_EQ(e.metadata, metadata);
|
||||
ASSERT_EQ(e.data, data);
|
||||
}
|
||||
|
||||
TEST(MonitorTest, FixedCountStatEventDestruction) {
|
||||
@ -309,7 +306,7 @@ TEST(MonitorTest, FixedCountStatEventDestruction) {
|
||||
{
|
||||
FixedCountStat<int64_t> a{
|
||||
"a",
|
||||
{COUNT, SUM},
|
||||
{Aggregation::COUNT, Aggregation::SUM},
|
||||
3,
|
||||
};
|
||||
ASSERT_EQ(guard.handler->events.size(), 0);
|
||||
@ -320,12 +317,11 @@ TEST(MonitorTest, FixedCountStatEventDestruction) {
|
||||
ASSERT_EQ(guard.handler->events.size(), 1);
|
||||
|
||||
Event e = guard.handler->events.at(0);
|
||||
ASSERT_EQ(e.type, "torch.monitor.Stat");
|
||||
ASSERT_EQ(e.message, "a");
|
||||
ASSERT_EQ(e.name, "torch.monitor.Stat");
|
||||
ASSERT_NE(e.timestamp, std::chrono::system_clock::time_point{});
|
||||
std::unordered_map<std::string, metadata_value_t> metadata{
|
||||
std::unordered_map<std::string, data_value_t> data{
|
||||
{"a.sum", 1L},
|
||||
{"a.count", 1L},
|
||||
};
|
||||
ASSERT_EQ(e.metadata, metadata);
|
||||
ASSERT_EQ(e.data, data);
|
||||
}
|
||||
|
@ -14,13 +14,12 @@ struct AggregatingEventHandler : public EventHandler {
|
||||
|
||||
TEST(EventsTest, EventHandler) {
|
||||
Event e;
|
||||
e.type = "test";
|
||||
e.message = "test message";
|
||||
e.name = "test";
|
||||
e.timestamp = std::chrono::system_clock::now();
|
||||
e.metadata["string"] = "asdf";
|
||||
e.metadata["double"] = 1234.5678;
|
||||
e.metadata["int"] = 1234L;
|
||||
e.metadata["bool"] = true;
|
||||
e.data["string"] = "asdf";
|
||||
e.data["double"] = 1234.5678;
|
||||
e.data["int"] = 1234L;
|
||||
e.data["bool"] = true;
|
||||
|
||||
// log to nothing
|
||||
logEvent(e);
|
||||
|
92
test/test_monitor.py
Normal file
92
test/test_monitor.py
Normal file
@ -0,0 +1,92 @@
|
||||
from torch.testing._internal.common_utils import (
|
||||
TestCase, run_tests,
|
||||
)
|
||||
|
||||
from datetime import timedelta, datetime
|
||||
import time
|
||||
|
||||
from torch.monitor import (
|
||||
Aggregation,
|
||||
FixedCountStat,
|
||||
IntervalStat,
|
||||
Event,
|
||||
log_event,
|
||||
register_event_handler,
|
||||
unregister_event_handler,
|
||||
)
|
||||
|
||||
class TestMonitor(TestCase):
|
||||
def test_interval_stat(self) -> None:
|
||||
events = []
|
||||
|
||||
def handler(event):
|
||||
events.append(event)
|
||||
|
||||
handle = register_event_handler(handler)
|
||||
s = IntervalStat(
|
||||
"asdf",
|
||||
(Aggregation.SUM, Aggregation.COUNT),
|
||||
timedelta(milliseconds=1),
|
||||
)
|
||||
s.add(2)
|
||||
time.sleep(0.002)
|
||||
s.add(3)
|
||||
self.assertEqual(s.name, "asdf")
|
||||
self.assertGreaterEqual(len(events), 1)
|
||||
unregister_event_handler(handle)
|
||||
|
||||
def test_fixed_count_stat(self) -> None:
|
||||
s = FixedCountStat(
|
||||
"asdf",
|
||||
(Aggregation.SUM, Aggregation.COUNT),
|
||||
3,
|
||||
)
|
||||
s.add(1)
|
||||
s.add(2)
|
||||
name = s.name
|
||||
self.assertEqual(name, "asdf")
|
||||
self.assertEqual(s.count, 2)
|
||||
s.add(3)
|
||||
self.assertEqual(s.count, 0)
|
||||
self.assertEqual(s.get(), {Aggregation.SUM: 6.0, Aggregation.COUNT: 3})
|
||||
|
||||
def test_log_event(self) -> None:
|
||||
e = Event(
|
||||
name="torch.monitor.TestEvent",
|
||||
timestamp=datetime.now(),
|
||||
data={
|
||||
"str": "a string",
|
||||
"float": 1234.0,
|
||||
"int": 1234,
|
||||
},
|
||||
)
|
||||
self.assertEqual(e.name, "torch.monitor.TestEvent")
|
||||
self.assertIsNotNone(e.timestamp)
|
||||
self.assertIsNotNone(e.data)
|
||||
log_event(e)
|
||||
|
||||
def test_event_handler(self) -> None:
|
||||
events = []
|
||||
|
||||
def handler(event: Event) -> None:
|
||||
events.append(event)
|
||||
|
||||
handle = register_event_handler(handler)
|
||||
e = Event(
|
||||
name="torch.monitor.TestEvent",
|
||||
timestamp=datetime.now(),
|
||||
data={},
|
||||
)
|
||||
log_event(e)
|
||||
self.assertEqual(len(events), 1)
|
||||
self.assertEqual(events[0], e)
|
||||
log_event(e)
|
||||
self.assertEqual(len(events), 2)
|
||||
|
||||
unregister_event_handler(handle)
|
||||
log_event(e)
|
||||
self.assertEqual(len(events), 2)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
run_tests()
|
@ -856,6 +856,7 @@ libtorch_python_core_sources = [
|
||||
"torch/csrc/jit/python/python_tree_views.cpp",
|
||||
"torch/csrc/jit/runtime/static/init.cpp",
|
||||
"torch/csrc/jit/tensorexpr/tensorexpr_init.cpp",
|
||||
"torch/csrc/monitor/python_init.cpp",
|
||||
"torch/csrc/multiprocessing/init.cpp",
|
||||
"torch/csrc/onnx/init.cpp",
|
||||
"torch/csrc/serialization.cpp",
|
||||
|
50
torch/_C/_monitor.pyi
Normal file
50
torch/_C/_monitor.pyi
Normal file
@ -0,0 +1,50 @@
|
||||
# Defined in torch/csrc/monitor/python_init.cpp
|
||||
|
||||
from typing import List, Dict, Callable, Union
|
||||
from enum import Enum
|
||||
import datetime
|
||||
|
||||
class Aggregation(Enum):
|
||||
VALUE = ...
|
||||
MEAN = ...
|
||||
COUNT = ...
|
||||
SUM = ...
|
||||
MAX = ...
|
||||
MIN = ...
|
||||
|
||||
class Stat:
|
||||
name: str
|
||||
count: int
|
||||
def add(self, v: float) -> None: ...
|
||||
def get(self) -> Dict[Aggregation, float]: ...
|
||||
|
||||
class IntervalStat(Stat):
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
aggregations: List[Aggregation],
|
||||
window_size: datetime.timedelta,
|
||||
) -> None: ...
|
||||
|
||||
class FixedCountStat(Stat):
|
||||
def __init__(
|
||||
self, name: str, aggregations: List[Aggregation], window_size: int
|
||||
) -> None: ...
|
||||
|
||||
class Event:
|
||||
name: str
|
||||
timestamp: datetime.datetime
|
||||
data: Dict[str, Union[int, float, bool, str]]
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
timestamp: datetime.datetime,
|
||||
data: Dict[str, Union[int, float, bool, str]],
|
||||
) -> None: ...
|
||||
|
||||
def log_event(e: Event) -> None: ...
|
||||
|
||||
class PythonEventHandler: ...
|
||||
|
||||
def register_event_handler(handler: Callable[[Event], None]) -> PythonEventHandler: ...
|
||||
def unregister_event_handler(handle: PythonEventHandler) -> None: ...
|
@ -61,6 +61,7 @@
|
||||
#include <torch/csrc/jit/python/python_tracer.h>
|
||||
#include <torch/csrc/jit/python/init.h>
|
||||
#include <torch/csrc/jit/python/python_ir.h>
|
||||
#include <torch/csrc/monitor/python_init.h>
|
||||
#include <torch/csrc/onnx/init.h>
|
||||
#include <torch/csrc/utils/init.h>
|
||||
#include <torch/csrc/utils/crash_handler.h>
|
||||
@ -832,6 +833,7 @@ PyObject* initModule() {
|
||||
// init.
|
||||
torch::onnx::initONNXBindings(module);
|
||||
torch::jit::initJITBindings(module);
|
||||
torch::monitor::initMonitorBindings(module);
|
||||
torch::impl::dispatch::initDispatchBindings(module);
|
||||
torch::throughput_benchmark::initThroughputBenchmarkBindings(module);
|
||||
torch::crash_handler::initCrashHandlerBindings(module);
|
||||
|
@ -9,22 +9,23 @@ namespace monitor {
|
||||
|
||||
const char* aggregationName(Aggregation agg) {
|
||||
switch (agg) {
|
||||
case NONE:
|
||||
case Aggregation::NONE:
|
||||
return "none";
|
||||
case VALUE:
|
||||
case Aggregation::VALUE:
|
||||
return "value";
|
||||
case MEAN:
|
||||
case Aggregation::MEAN:
|
||||
return "mean";
|
||||
case COUNT:
|
||||
case Aggregation::COUNT:
|
||||
return "count";
|
||||
case SUM:
|
||||
case Aggregation::SUM:
|
||||
return "sum";
|
||||
case MAX:
|
||||
case Aggregation::MAX:
|
||||
return "max";
|
||||
case MIN:
|
||||
case Aggregation::MIN:
|
||||
return "min";
|
||||
default:
|
||||
throw std::runtime_error("unknown aggregation: " + std::to_string(agg));
|
||||
throw std::runtime_error(
|
||||
"unknown aggregation: " + std::to_string(static_cast<int>(agg)));
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -6,6 +6,8 @@
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
|
||||
#include <c10/macros/Macros.h>
|
||||
|
||||
#include <torch/csrc/monitor/events.h>
|
||||
|
||||
namespace torch {
|
||||
@ -15,7 +17,7 @@ constexpr int NUM_AGGREGATIONS = 7;
|
||||
|
||||
// Aggregation is the list of possible aggregations for Stats.
|
||||
// These use bitwise flags so they can be efficiently stored.
|
||||
enum Aggregation {
|
||||
enum class C10_API_ENUM Aggregation {
|
||||
// NONE means no aggregations are set.
|
||||
NONE = 0,
|
||||
// VALUE exports the most recently set value.
|
||||
@ -35,29 +37,36 @@ enum Aggregation {
|
||||
MIN = 6,
|
||||
};
|
||||
|
||||
struct TORCH_API AggregationHash {
|
||||
template <typename T>
|
||||
std::size_t operator()(T t) const {
|
||||
return static_cast<std::size_t>(t);
|
||||
}
|
||||
};
|
||||
|
||||
// aggregationName returns the human readable name corresponding to the
|
||||
// aggregation.
|
||||
const char* aggregationName(Aggregation agg);
|
||||
TORCH_API const char* aggregationName(Aggregation agg);
|
||||
|
||||
template <typename T>
|
||||
class Stat;
|
||||
|
||||
namespace {
|
||||
inline std::bitset<NUM_AGGREGATIONS> merge(
|
||||
std::initializer_list<Aggregation>& list) {
|
||||
template <typename T>
|
||||
inline std::bitset<NUM_AGGREGATIONS> merge(T& list) {
|
||||
std::bitset<NUM_AGGREGATIONS> a;
|
||||
for (Aggregation b : list) {
|
||||
a.set(b);
|
||||
a.set(static_cast<int>(b));
|
||||
}
|
||||
return a;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
namespace detail {
|
||||
void registerStat(Stat<double>* stat);
|
||||
void registerStat(Stat<int64_t>* stat);
|
||||
void unregisterStat(Stat<double>* stat);
|
||||
void unregisterStat(Stat<int64_t>* stat);
|
||||
void TORCH_API registerStat(Stat<double>* stat);
|
||||
void TORCH_API registerStat(Stat<int64_t>* stat);
|
||||
void TORCH_API unregisterStat(Stat<double>* stat);
|
||||
void TORCH_API unregisterStat(Stat<int64_t>* stat);
|
||||
} // namespace detail
|
||||
|
||||
// Stat is a base class for stats. These stats are used to compute summary
|
||||
@ -82,7 +91,7 @@ class Stat {
|
||||
};
|
||||
|
||||
public:
|
||||
Stat(std::string name, std::initializer_list<Aggregation> aggregations)
|
||||
Stat(std::string name, std::vector<Aggregation> aggregations)
|
||||
: name_(std::move(name)), aggregations_(merge(aggregations)) {
|
||||
detail::registerStat(this);
|
||||
}
|
||||
@ -101,19 +110,20 @@ class Stat {
|
||||
std::lock_guard<std::mutex> guard(mu_);
|
||||
maybeLogLocked();
|
||||
|
||||
if (aggregations_.test(VALUE)) {
|
||||
if (aggregations_.test(static_cast<int>(Aggregation::VALUE))) {
|
||||
current_.value = v;
|
||||
}
|
||||
if (aggregations_.test(MEAN) || aggregations_.test(SUM)) {
|
||||
if (aggregations_.test(static_cast<int>(Aggregation::MEAN)) ||
|
||||
aggregations_.test(static_cast<int>(Aggregation::SUM))) {
|
||||
current_.sum += v;
|
||||
}
|
||||
|
||||
if (aggregations_.test(MAX)) {
|
||||
if (aggregations_.test(static_cast<int>(Aggregation::MAX))) {
|
||||
if (current_.max < v || current_.count == 0) {
|
||||
current_.max = v;
|
||||
}
|
||||
}
|
||||
if (aggregations_.test(MIN)) {
|
||||
if (aggregations_.test(static_cast<int>(Aggregation::MIN))) {
|
||||
if (current_.min > v || current_.count == 0) {
|
||||
current_.min = v;
|
||||
}
|
||||
@ -134,7 +144,7 @@ class Stat {
|
||||
return current_.count;
|
||||
}
|
||||
|
||||
std::unordered_map<Aggregation, T> get() noexcept {
|
||||
std::unordered_map<Aggregation, T, AggregationHash> get() noexcept {
|
||||
std::lock_guard<std::mutex> guard(mu_);
|
||||
return getLocked();
|
||||
}
|
||||
@ -152,48 +162,48 @@ class Stat {
|
||||
}
|
||||
|
||||
Event e;
|
||||
e.type = "torch.monitor.Stat";
|
||||
e.message = name_;
|
||||
e.name = "torch.monitor.Stat";
|
||||
e.timestamp = std::chrono::system_clock::now();
|
||||
|
||||
auto stats = getLocked();
|
||||
e.metadata.reserve(stats.size());
|
||||
e.data.reserve(stats.size());
|
||||
for (auto& kv : stats) {
|
||||
std::stringstream key;
|
||||
key << name_;
|
||||
key << ".";
|
||||
key << aggregationName(kv.first);
|
||||
e.metadata[key.str()] = kv.second;
|
||||
e.data[key.str()] = kv.second;
|
||||
}
|
||||
|
||||
logEvent(e);
|
||||
}
|
||||
|
||||
std::unordered_map<Aggregation, T> getLocked() const noexcept {
|
||||
std::unordered_map<Aggregation, T> out;
|
||||
std::unordered_map<Aggregation, T, AggregationHash> getLocked()
|
||||
const noexcept {
|
||||
std::unordered_map<Aggregation, T, AggregationHash> out;
|
||||
out.reserve(aggregations_.count());
|
||||
|
||||
if (aggregations_.test(VALUE)) {
|
||||
out.emplace(VALUE, prev_.value);
|
||||
if (aggregations_.test(static_cast<int>(Aggregation::VALUE))) {
|
||||
out.emplace(Aggregation::VALUE, prev_.value);
|
||||
}
|
||||
if (aggregations_.test(MEAN)) {
|
||||
if (aggregations_.test(static_cast<int>(Aggregation::MEAN))) {
|
||||
if (prev_.count == 0) {
|
||||
out.emplace(MEAN, 0);
|
||||
out.emplace(Aggregation::MEAN, 0);
|
||||
} else {
|
||||
out.emplace(MEAN, prev_.sum / prev_.count);
|
||||
out.emplace(Aggregation::MEAN, prev_.sum / prev_.count);
|
||||
}
|
||||
}
|
||||
if (aggregations_.test(COUNT)) {
|
||||
out.emplace(COUNT, prev_.count);
|
||||
if (aggregations_.test(static_cast<int>(Aggregation::COUNT))) {
|
||||
out.emplace(Aggregation::COUNT, prev_.count);
|
||||
}
|
||||
if (aggregations_.test(SUM)) {
|
||||
out.emplace(SUM, prev_.sum);
|
||||
if (aggregations_.test(static_cast<int>(Aggregation::SUM))) {
|
||||
out.emplace(Aggregation::SUM, prev_.sum);
|
||||
}
|
||||
if (aggregations_.test(MAX)) {
|
||||
out.emplace(MAX, prev_.max);
|
||||
if (aggregations_.test(static_cast<int>(Aggregation::MAX))) {
|
||||
out.emplace(Aggregation::MAX, prev_.max);
|
||||
}
|
||||
if (aggregations_.test(MIN)) {
|
||||
out.emplace(MIN, prev_.min);
|
||||
if (aggregations_.test(static_cast<int>(Aggregation::MIN))) {
|
||||
out.emplace(Aggregation::MIN, prev_.min);
|
||||
}
|
||||
|
||||
return out;
|
||||
@ -219,6 +229,12 @@ class IntervalStat : public Stat<T> {
|
||||
std::chrono::milliseconds windowSize)
|
||||
: Stat<T>(std::move(name), aggregations), windowSize_(windowSize) {}
|
||||
|
||||
IntervalStat(
|
||||
std::string name,
|
||||
std::vector<Aggregation> aggregations,
|
||||
std::chrono::milliseconds windowSize)
|
||||
: Stat<T>(std::move(name), aggregations), windowSize_(windowSize) {}
|
||||
|
||||
protected:
|
||||
virtual uint64_t currentWindowId() const {
|
||||
auto now = std::chrono::steady_clock::now().time_since_epoch();
|
||||
@ -251,6 +267,12 @@ class FixedCountStat : public Stat<T> {
|
||||
int64_t windowSize)
|
||||
: Stat<T>(std::move(name), aggregations), windowSize_(windowSize) {}
|
||||
|
||||
FixedCountStat(
|
||||
std::string name,
|
||||
std::vector<Aggregation> aggregations,
|
||||
int64_t windowSize)
|
||||
: Stat<T>(std::move(name), aggregations), windowSize_(windowSize) {}
|
||||
|
||||
private:
|
||||
void maybeLogLocked() override {
|
||||
if (Stat<T>::current_.count >= windowSize_) {
|
||||
|
@ -5,42 +5,39 @@
|
||||
#include <unordered_map>
|
||||
|
||||
#include <c10/util/variant.h>
|
||||
#include <c10/macros/Macros.h>
|
||||
|
||||
namespace torch {
|
||||
namespace monitor {
|
||||
|
||||
// metadata_value_t is the type for Event metadata values.
|
||||
using metadata_value_t = c10::variant<std::string, double, int64_t, bool>;
|
||||
// data_value_t is the type for Event data values.
|
||||
using data_value_t = c10::variant<std::string, double, int64_t, bool>;
|
||||
|
||||
// Event represents a single event that can be logged out to an external
|
||||
// tracker. This does acquire a lock on logging so should be used relatively
|
||||
// infrequently to avoid performance issues.
|
||||
struct Event {
|
||||
// type is the type of the event. This is a static string that's used to
|
||||
struct TORCH_API Event {
|
||||
// name is the name of the event. This is a static string that's used to
|
||||
// differentiate between event types for programmatic access. The type should
|
||||
// be in the format of a fully qualified Python-style class name.
|
||||
// Ex: torch.monitor.MonitorEvent
|
||||
std::string type;
|
||||
|
||||
// message is a human readable name. This is optional for machine intended
|
||||
// stats.
|
||||
std::string message;
|
||||
std::string name;
|
||||
|
||||
// timestamp is a timestamp relative to the Unix epoch time.
|
||||
std::chrono::system_clock::time_point timestamp;
|
||||
|
||||
// metadata contains rich information about the event. The contents are event
|
||||
// data contains rich information about the event. The contents are event
|
||||
// specific so you should check the type to ensure it's what you expect before
|
||||
// accessing the metadata.
|
||||
// accessing the data.
|
||||
//
|
||||
// NOTE: these events are not versioned and it's up to the consumer of the
|
||||
// events to check the fields to ensure backwards compatibility.
|
||||
std::unordered_map<std::string, metadata_value_t> metadata;
|
||||
std::unordered_map<std::string, data_value_t> data;
|
||||
};
|
||||
|
||||
inline bool operator==(const Event& lhs, const Event& rhs) {
|
||||
return lhs.type == rhs.type && lhs.message == rhs.message &&
|
||||
lhs.timestamp == rhs.timestamp && lhs.metadata == rhs.metadata;
|
||||
TORCH_API inline bool operator==(const Event& lhs, const Event& rhs) {
|
||||
return lhs.name == rhs.name && lhs.timestamp == rhs.timestamp &&
|
||||
lhs.data == rhs.data;
|
||||
}
|
||||
|
||||
// EventHandler represents an abstract event handler that can be registered to
|
||||
@ -49,7 +46,7 @@ inline bool operator==(const Event& lhs, const Event& rhs) {
|
||||
//
|
||||
// NOTE: The handlers should avoid any IO, blocking calls or heavy computation
|
||||
// as this may block the main thread and cause performance issues.
|
||||
class EventHandler {
|
||||
class TORCH_API EventHandler {
|
||||
public:
|
||||
virtual ~EventHandler() = default;
|
||||
|
||||
@ -60,16 +57,16 @@ class EventHandler {
|
||||
|
||||
// logEvent calls each registered event handler with the event. This method can
|
||||
// be called from concurrently from multiple threads.
|
||||
void logEvent(const Event& e);
|
||||
TORCH_API void logEvent(const Event& e);
|
||||
|
||||
// registerEventHandler registers an EventHandler so it receives any logged
|
||||
// events. Typically an EventHandler will be registered during program
|
||||
// setup and unregistered at the end.
|
||||
void registerEventHandler(std::shared_ptr<EventHandler> p);
|
||||
TORCH_API void registerEventHandler(std::shared_ptr<EventHandler> p);
|
||||
|
||||
// unregisterEventHandler unregisters the event handler pointed to by the
|
||||
// shared_ptr.
|
||||
void unregisterEventHandler(const std::shared_ptr<EventHandler>& p);
|
||||
TORCH_API void unregisterEventHandler(const std::shared_ptr<EventHandler>& p);
|
||||
|
||||
} // namespace monitor
|
||||
} // namespace torch
|
||||
|
154
torch/csrc/monitor/python_init.cpp
Normal file
154
torch/csrc/monitor/python_init.cpp
Normal file
@ -0,0 +1,154 @@
|
||||
#include <utility>
|
||||
|
||||
#include <torch/csrc/utils/pybind.h>
|
||||
#include <torch/csrc/utils/python_arg_parser.h>
|
||||
#include <torch/csrc/utils/python_numbers.h>
|
||||
#include <torch/csrc/utils/python_strings.h>
|
||||
|
||||
#include <pybind11/chrono.h>
|
||||
#include <pybind11/functional.h>
|
||||
#include <pybind11/operators.h>
|
||||
#include <pybind11/stl.h>
|
||||
|
||||
#include <torch/csrc/monitor/counters.h>
|
||||
#include <torch/csrc/monitor/events.h>
|
||||
|
||||
namespace pybind11 {
|
||||
namespace detail {
|
||||
template <>
|
||||
struct type_caster<torch::monitor::data_value_t> {
|
||||
public:
|
||||
PYBIND11_TYPE_CASTER(torch::monitor::data_value_t, _("data_value_t"));
|
||||
|
||||
// Python -> C++
|
||||
bool load(handle src, bool) {
|
||||
PyObject* source = src.ptr();
|
||||
if (THPUtils_checkLong(source)) {
|
||||
this->value = THPUtils_unpackLong(source);
|
||||
} else if (THPUtils_checkDouble(source)) {
|
||||
this->value = THPUtils_unpackDouble(source);
|
||||
} else if (THPUtils_checkString(source)) {
|
||||
this->value = THPUtils_unpackString(source);
|
||||
} else if (PyBool_Check(source)) {
|
||||
this->value = THPUtils_unpackBool(source);
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
return !PyErr_Occurred();
|
||||
}
|
||||
|
||||
// C++ -> Python
|
||||
static handle cast(
|
||||
torch::monitor::data_value_t src,
|
||||
return_value_policy /* policy */,
|
||||
handle /* parent */) {
|
||||
if (c10::holds_alternative<double>(src)) {
|
||||
return PyFloat_FromDouble(c10::get<double>(src));
|
||||
} else if (c10::holds_alternative<int64_t>(src)) {
|
||||
return THPUtils_packInt64(c10::get<int64_t>(src));
|
||||
} else if (c10::holds_alternative<bool>(src)) {
|
||||
if (c10::get<bool>(src)) {
|
||||
Py_RETURN_TRUE;
|
||||
} else {
|
||||
Py_RETURN_FALSE;
|
||||
}
|
||||
} else if (c10::holds_alternative<std::string>(src)) {
|
||||
std::string str = c10::get<std::string>(src);
|
||||
return THPUtils_packString(str);
|
||||
}
|
||||
throw std::runtime_error("unknown data_value_t type");
|
||||
}
|
||||
};
|
||||
} // namespace detail
|
||||
} // namespace pybind11
|
||||
|
||||
namespace torch {
|
||||
namespace monitor {
|
||||
|
||||
namespace {
|
||||
class PythonEventHandler : public EventHandler {
|
||||
public:
|
||||
explicit PythonEventHandler(std::function<void(const Event&)> handler)
|
||||
: handler_(std::move(handler)) {}
|
||||
|
||||
void handle(const Event& e) override {
|
||||
handler_(e);
|
||||
}
|
||||
|
||||
private:
|
||||
std::function<void(const Event&)> handler_;
|
||||
};
|
||||
} // namespace
|
||||
|
||||
void initMonitorBindings(PyObject* module) {
|
||||
auto rootModule = py::handle(module).cast<py::module>();
|
||||
|
||||
auto m = rootModule.def_submodule("_monitor");
|
||||
|
||||
py::enum_<Aggregation>(m, "Aggregation")
|
||||
.value("VALUE", Aggregation::NONE)
|
||||
.value("MEAN", Aggregation::MEAN)
|
||||
.value("COUNT", Aggregation::COUNT)
|
||||
.value("SUM", Aggregation::SUM)
|
||||
.value("MAX", Aggregation::MAX)
|
||||
.value("MIN", Aggregation::MIN)
|
||||
.export_values();
|
||||
|
||||
py::class_<Stat<double>>(m, "Stat")
|
||||
.def("add", &Stat<double>::add)
|
||||
.def("get", &Stat<double>::get)
|
||||
.def_property_readonly("name", &Stat<double>::name)
|
||||
.def_property_readonly("count", &Stat<double>::count);
|
||||
|
||||
py::class_<IntervalStat<double>, Stat<double>>(m, "IntervalStat")
|
||||
.def(py::init<
|
||||
std::string,
|
||||
std::vector<Aggregation>,
|
||||
std::chrono::milliseconds>());
|
||||
|
||||
py::class_<FixedCountStat<double>, Stat<double>>(m, "FixedCountStat")
|
||||
.def(py::init<std::string, std::vector<Aggregation>, int64_t>());
|
||||
|
||||
py::class_<Event>(m, "Event")
|
||||
.def(
|
||||
py::init([](const std::string& name,
|
||||
std::chrono::system_clock::time_point timestamp,
|
||||
std::unordered_map<std::string, data_value_t> data) {
|
||||
Event e;
|
||||
e.name = name;
|
||||
e.timestamp = timestamp;
|
||||
e.data = data;
|
||||
return e;
|
||||
}),
|
||||
py::arg("name"),
|
||||
py::arg("timestamp"),
|
||||
py::arg("data"))
|
||||
.def_readwrite("name", &Event::name)
|
||||
.def_readwrite("timestamp", &Event::timestamp)
|
||||
.def_readwrite("data", &Event::data);
|
||||
|
||||
m.def("log_event", &logEvent);
|
||||
|
||||
py::class_<data_value_t> dataClass(m, "data_value_t");
|
||||
|
||||
py::implicitly_convertible<std::string, data_value_t>();
|
||||
py::implicitly_convertible<double, data_value_t>();
|
||||
py::implicitly_convertible<int64_t, data_value_t>();
|
||||
py::implicitly_convertible<bool, data_value_t>();
|
||||
|
||||
py::class_<PythonEventHandler, std::shared_ptr<PythonEventHandler>>
|
||||
eventHandlerClass(m, "PythonEventHandler");
|
||||
m.def("register_event_handler", [](std::function<void(const Event&)> f) {
|
||||
auto handler = std::make_shared<PythonEventHandler>(f);
|
||||
registerEventHandler(handler);
|
||||
return handler;
|
||||
});
|
||||
m.def(
|
||||
"unregister_event_handler",
|
||||
[](std::shared_ptr<PythonEventHandler> handler) {
|
||||
unregisterEventHandler(handler);
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace monitor
|
||||
} // namespace torch
|
11
torch/csrc/monitor/python_init.h
Normal file
11
torch/csrc/monitor/python_init.h
Normal file
@ -0,0 +1,11 @@
|
||||
#pragma once
|
||||
|
||||
#include <torch/csrc/utils/pybind.h>
|
||||
|
||||
namespace torch {
|
||||
namespace monitor {
|
||||
|
||||
void initMonitorBindings(PyObject* module);
|
||||
|
||||
}
|
||||
} // namespace torch
|
1
torch/monitor/__init__.py
Normal file
1
torch/monitor/__init__.py
Normal file
@ -0,0 +1 @@
|
||||
from torch._C._monitor import * # noqa: F403
|
Reference in New Issue
Block a user