torch/monitor: add pybind (#69567)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/69567

This exposes torch.monitor events and stats to Python via pybind11 bindings to the underlying C++ implementation.

* The registration interface is slightly different: in Python it takes a callable (e.g. a lambda), whereas in C++ it's a full class.
* This makes a small change to the counter interfaces: since there's no way to create an initializer list at runtime, the stat constructors now also accept a vector of aggregations (see the stat snippet below).
* Only double-based stats are exposed in Python since the bindings are intended for high-level stats where floating-point imprecision shouldn't be an issue. This can be changed down the line if the need arises.

```python
from datetime import datetime

from torch.monitor import Event, log_event, register_event_handler

events = []

def handler(event):
    events.append(event)

handle = register_event_handler(handler)

log_event(Event(name="torch.monitor.TestEvent", timestamp=datetime.now(), data={"foo": 1.0}))
```
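
For the stat side (the runtime vector of aggregations and double-only stats noted above), here is a minimal sketch mirroring the `test/test_monitor.py` added in this diff; the stat name is just illustrative:

```python
from torch.monitor import Aggregation, FixedCountStat

# Aggregations are passed as a runtime list/tuple rather than an initializer list.
s = FixedCountStat("my_stat", [Aggregation.SUM, Aggregation.COUNT], 3)
s.add(1.0)
s.add(2.0)
s.add(3.0)  # the 3-sample window is now full, so the aggregates roll over
print(s.get())  # {Aggregation.SUM: 6.0, Aggregation.COUNT: 3}
```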

D32969391 is now included in this diff.
It cleans up the naming for events: `type` is now `name`, `message` is gone, and `metadata` is renamed to `data`.

Test Plan: buck test //caffe2/test:monitor //caffe2/test/cpp/monitor:monitor

Reviewed By: kiukchung

Differential Revision: D32924141

fbshipit-source-id: 563304c2e3261a4754e40cca39fc64c5a04b43e8
Tristan Rice
2022-01-12 13:33:49 -08:00
committed by Facebook GitHub Bot
parent 90ef54f8ea
commit bfe1abd3b5
14 changed files with 484 additions and 117 deletions


@ -69,6 +69,7 @@ Features described in this documentation are classified by release status:
torch.hub <hub>
torch.jit <jit>
torch.linalg <linalg>
torch.monitor <monitor>
torch.special <special>
torch.overrides
torch.package <package>

docs/source/monitor.rst (new file, 40 lines)

@ -0,0 +1,40 @@
torch.monitor
=============
.. warning::
This module is a prototype release, and its interfaces and functionality may
change without warning in future PyTorch releases.
``torch.monitor`` provides an interface for logging events and counters from
PyTorch.
API Reference
-------------
.. automodule:: torch.monitor
.. autoclass:: torch.monitor.Aggregation
:members:
.. autoclass:: torch.monitor.Stat
:members:
.. autoclass:: torch.monitor.IntervalStat
:members:
.. autoclass:: torch.monitor.FixedCountStat
:members:
.. autoclass:: torch.monitor.Event
:members:
.. autoclass:: torch.monitor.PythonEventHandler
:members:
.. autofunction:: torch.monitor.log_event
.. autofunction:: torch.monitor.register_event_handler
.. autofunction:: torch.monitor.unregister_event_handler


@ -10,7 +10,7 @@ using namespace torch::monitor;
TEST(MonitorTest, CounterDouble) {
FixedCountStat<double> a{
"a",
{MEAN, COUNT},
{Aggregation::MEAN, Aggregation::COUNT},
2,
};
a.add(5.0);
@ -19,9 +19,9 @@ TEST(MonitorTest, CounterDouble) {
ASSERT_EQ(a.count(), 0);
auto stats = a.get();
std::unordered_map<Aggregation, double> want = {
{MEAN, 5.5},
{COUNT, 2.0},
std::unordered_map<Aggregation, double, AggregationHash> want = {
{Aggregation::MEAN, 5.5},
{Aggregation::COUNT, 2.0},
};
ASSERT_EQ(stats, want);
}
@ -29,14 +29,14 @@ TEST(MonitorTest, CounterDouble) {
TEST(MonitorTest, CounterInt64Sum) {
FixedCountStat<int64_t> a{
"a",
{SUM},
{Aggregation::SUM},
2,
};
a.add(5);
a.add(6);
auto stats = a.get();
std::unordered_map<Aggregation, int64_t> want = {
{SUM, 11},
std::unordered_map<Aggregation, int64_t, AggregationHash> want = {
{Aggregation::SUM, 11},
};
ASSERT_EQ(stats, want);
}
@ -44,14 +44,14 @@ TEST(MonitorTest, CounterInt64Sum) {
TEST(MonitorTest, CounterInt64Value) {
FixedCountStat<int64_t> a{
"a",
{VALUE},
{Aggregation::VALUE},
2,
};
a.add(5);
a.add(6);
auto stats = a.get();
std::unordered_map<Aggregation, int64_t> want = {
{VALUE, 6},
std::unordered_map<Aggregation, int64_t, AggregationHash> want = {
{Aggregation::VALUE, 6},
};
ASSERT_EQ(stats, want);
}
@ -59,14 +59,14 @@ TEST(MonitorTest, CounterInt64Value) {
TEST(MonitorTest, CounterInt64Mean) {
FixedCountStat<int64_t> a{
"a",
{MEAN},
{Aggregation::MEAN},
2,
};
{
// zero samples case
auto stats = a.get();
std::unordered_map<Aggregation, int64_t> want = {
{MEAN, 0},
std::unordered_map<Aggregation, int64_t, AggregationHash> want = {
{Aggregation::MEAN, 0},
};
ASSERT_EQ(stats, want);
}
@ -76,8 +76,8 @@ TEST(MonitorTest, CounterInt64Mean) {
{
auto stats = a.get();
std::unordered_map<Aggregation, int64_t> want = {
{MEAN, 5},
std::unordered_map<Aggregation, int64_t, AggregationHash> want = {
{Aggregation::MEAN, 5},
};
ASSERT_EQ(stats, want);
}
@ -86,7 +86,7 @@ TEST(MonitorTest, CounterInt64Mean) {
TEST(MonitorTest, CounterInt64Count) {
FixedCountStat<int64_t> a{
"a",
{COUNT},
{Aggregation::COUNT},
2,
};
ASSERT_EQ(a.count(), 0);
@ -96,8 +96,8 @@ TEST(MonitorTest, CounterInt64Count) {
ASSERT_EQ(a.count(), 0);
auto stats = a.get();
std::unordered_map<Aggregation, int64_t> want = {
{COUNT, 2},
std::unordered_map<Aggregation, int64_t, AggregationHash> want = {
{Aggregation::COUNT, 2},
};
ASSERT_EQ(stats, want);
}
@ -105,14 +105,14 @@ TEST(MonitorTest, CounterInt64Count) {
TEST(MonitorTest, CounterInt64MinMax) {
FixedCountStat<int64_t> a{
"a",
{MIN, MAX},
{Aggregation::MIN, Aggregation::MAX},
6,
};
{
auto stats = a.get();
std::unordered_map<Aggregation, int64_t> want = {
{MAX, 0},
{MIN, 0},
std::unordered_map<Aggregation, int64_t, AggregationHash> want = {
{Aggregation::MAX, 0},
{Aggregation::MIN, 0},
};
ASSERT_EQ(stats, want);
}
@ -125,9 +125,9 @@ TEST(MonitorTest, CounterInt64MinMax) {
a.add(2);
{
auto stats = a.get();
std::unordered_map<Aggregation, int64_t> want = {
{MAX, 9},
{MIN, -6},
std::unordered_map<Aggregation, int64_t, AggregationHash> want = {
{Aggregation::MAX, 9},
{Aggregation::MIN, -6},
};
ASSERT_EQ(stats, want);
}
@ -136,7 +136,7 @@ TEST(MonitorTest, CounterInt64MinMax) {
TEST(MonitorTest, CounterInt64WindowSize) {
FixedCountStat<int64_t> a{
"a",
{COUNT, SUM},
{Aggregation::COUNT, Aggregation::SUM},
/*windowSize=*/3,
};
a.add(1);
@ -149,9 +149,9 @@ TEST(MonitorTest, CounterInt64WindowSize) {
ASSERT_EQ(a.count(), 1);
auto stats = a.get();
std::unordered_map<Aggregation, int64_t> want = {
{COUNT, 3},
{SUM, 6},
std::unordered_map<Aggregation, int64_t, AggregationHash> want = {
{Aggregation::COUNT, 3},
{Aggregation::SUM, 6},
};
ASSERT_EQ(stats, want);
}
@ -197,7 +197,7 @@ TEST(MonitorTest, IntervalStat) {
IntervalStat<int64_t> a{
"a",
{COUNT, SUM},
{Aggregation::COUNT, Aggregation::SUM},
std::chrono::milliseconds(1),
};
ASSERT_EQ(guard.handler->events.size(), 0);
@ -218,7 +218,7 @@ TEST(MonitorTest, IntervalStatEvent) {
TestIntervalStat<int64_t> a{
"a",
{COUNT, SUM},
{Aggregation::COUNT, Aggregation::SUM},
std::chrono::milliseconds(1),
};
ASSERT_EQ(guard.handler->events.size(), 0);
@ -236,14 +236,13 @@ TEST(MonitorTest, IntervalStatEvent) {
ASSERT_EQ(guard.handler->events.size(), 1);
Event e = guard.handler->events.at(0);
ASSERT_EQ(e.type, "torch.monitor.Stat");
ASSERT_EQ(e.message, "a");
ASSERT_EQ(e.name, "torch.monitor.Stat");
ASSERT_NE(e.timestamp, std::chrono::system_clock::time_point{});
std::unordered_map<std::string, metadata_value_t> metadata{
std::unordered_map<std::string, data_value_t> data{
{"a.sum", 3L},
{"a.count", 2L},
};
ASSERT_EQ(e.metadata, metadata);
ASSERT_EQ(e.data, data);
}
TEST(MonitorTest, IntervalStatEventDestruction) {
@ -252,7 +251,7 @@ TEST(MonitorTest, IntervalStatEventDestruction) {
{
TestIntervalStat<int64_t> a{
"a",
{COUNT, SUM},
{Aggregation::COUNT, Aggregation::SUM},
std::chrono::hours(10),
};
a.add(1);
@ -262,14 +261,13 @@ TEST(MonitorTest, IntervalStatEventDestruction) {
ASSERT_EQ(guard.handler->events.size(), 1);
Event e = guard.handler->events.at(0);
ASSERT_EQ(e.type, "torch.monitor.Stat");
ASSERT_EQ(e.message, "a");
ASSERT_EQ(e.name, "torch.monitor.Stat");
ASSERT_NE(e.timestamp, std::chrono::system_clock::time_point{});
std::unordered_map<std::string, metadata_value_t> metadata{
std::unordered_map<std::string, data_value_t> data{
{"a.sum", 1L},
{"a.count", 1L},
};
ASSERT_EQ(e.metadata, metadata);
ASSERT_EQ(e.data, data);
}
TEST(MonitorTest, FixedCountStatEvent) {
@ -277,7 +275,7 @@ TEST(MonitorTest, FixedCountStatEvent) {
FixedCountStat<int64_t> a{
"a",
{COUNT, SUM},
{Aggregation::COUNT, Aggregation::SUM},
3,
};
ASSERT_EQ(guard.handler->events.size(), 0);
@ -293,14 +291,13 @@ TEST(MonitorTest, FixedCountStatEvent) {
ASSERT_EQ(guard.handler->events.size(), 1);
Event e = guard.handler->events.at(0);
ASSERT_EQ(e.type, "torch.monitor.Stat");
ASSERT_EQ(e.message, "a");
ASSERT_EQ(e.name, "torch.monitor.Stat");
ASSERT_NE(e.timestamp, std::chrono::system_clock::time_point{});
std::unordered_map<std::string, metadata_value_t> metadata{
std::unordered_map<std::string, data_value_t> data{
{"a.sum", 4L},
{"a.count", 3L},
};
ASSERT_EQ(e.metadata, metadata);
ASSERT_EQ(e.data, data);
}
TEST(MonitorTest, FixedCountStatEventDestruction) {
@ -309,7 +306,7 @@ TEST(MonitorTest, FixedCountStatEventDestruction) {
{
FixedCountStat<int64_t> a{
"a",
{COUNT, SUM},
{Aggregation::COUNT, Aggregation::SUM},
3,
};
ASSERT_EQ(guard.handler->events.size(), 0);
@ -320,12 +317,11 @@ TEST(MonitorTest, FixedCountStatEventDestruction) {
ASSERT_EQ(guard.handler->events.size(), 1);
Event e = guard.handler->events.at(0);
ASSERT_EQ(e.type, "torch.monitor.Stat");
ASSERT_EQ(e.message, "a");
ASSERT_EQ(e.name, "torch.monitor.Stat");
ASSERT_NE(e.timestamp, std::chrono::system_clock::time_point{});
std::unordered_map<std::string, metadata_value_t> metadata{
std::unordered_map<std::string, data_value_t> data{
{"a.sum", 1L},
{"a.count", 1L},
};
ASSERT_EQ(e.metadata, metadata);
ASSERT_EQ(e.data, data);
}


@ -14,13 +14,12 @@ struct AggregatingEventHandler : public EventHandler {
TEST(EventsTest, EventHandler) {
Event e;
e.type = "test";
e.message = "test message";
e.name = "test";
e.timestamp = std::chrono::system_clock::now();
e.metadata["string"] = "asdf";
e.metadata["double"] = 1234.5678;
e.metadata["int"] = 1234L;
e.metadata["bool"] = true;
e.data["string"] = "asdf";
e.data["double"] = 1234.5678;
e.data["int"] = 1234L;
e.data["bool"] = true;
// log to nothing
logEvent(e);

test/test_monitor.py (new file, 92 lines)

@ -0,0 +1,92 @@
from torch.testing._internal.common_utils import (
TestCase, run_tests,
)
from datetime import timedelta, datetime
import time
from torch.monitor import (
Aggregation,
FixedCountStat,
IntervalStat,
Event,
log_event,
register_event_handler,
unregister_event_handler,
)
class TestMonitor(TestCase):
def test_interval_stat(self) -> None:
events = []
def handler(event):
events.append(event)
handle = register_event_handler(handler)
s = IntervalStat(
"asdf",
(Aggregation.SUM, Aggregation.COUNT),
timedelta(milliseconds=1),
)
s.add(2)
time.sleep(0.002)
s.add(3)
self.assertEqual(s.name, "asdf")
self.assertGreaterEqual(len(events), 1)
unregister_event_handler(handle)
def test_fixed_count_stat(self) -> None:
s = FixedCountStat(
"asdf",
(Aggregation.SUM, Aggregation.COUNT),
3,
)
s.add(1)
s.add(2)
name = s.name
self.assertEqual(name, "asdf")
self.assertEqual(s.count, 2)
s.add(3)
self.assertEqual(s.count, 0)
self.assertEqual(s.get(), {Aggregation.SUM: 6.0, Aggregation.COUNT: 3})
def test_log_event(self) -> None:
e = Event(
name="torch.monitor.TestEvent",
timestamp=datetime.now(),
data={
"str": "a string",
"float": 1234.0,
"int": 1234,
},
)
self.assertEqual(e.name, "torch.monitor.TestEvent")
self.assertIsNotNone(e.timestamp)
self.assertIsNotNone(e.data)
log_event(e)
def test_event_handler(self) -> None:
events = []
def handler(event: Event) -> None:
events.append(event)
handle = register_event_handler(handler)
e = Event(
name="torch.monitor.TestEvent",
timestamp=datetime.now(),
data={},
)
log_event(e)
self.assertEqual(len(events), 1)
self.assertEqual(events[0], e)
log_event(e)
self.assertEqual(len(events), 2)
unregister_event_handler(handle)
log_event(e)
self.assertEqual(len(events), 2)
if __name__ == '__main__':
run_tests()


@ -856,6 +856,7 @@ libtorch_python_core_sources = [
"torch/csrc/jit/python/python_tree_views.cpp",
"torch/csrc/jit/runtime/static/init.cpp",
"torch/csrc/jit/tensorexpr/tensorexpr_init.cpp",
"torch/csrc/monitor/python_init.cpp",
"torch/csrc/multiprocessing/init.cpp",
"torch/csrc/onnx/init.cpp",
"torch/csrc/serialization.cpp",

torch/_C/_monitor.pyi (new file, 50 lines)

@ -0,0 +1,50 @@
# Defined in torch/csrc/monitor/python_init.cpp
from typing import List, Dict, Callable, Union
from enum import Enum
import datetime
class Aggregation(Enum):
VALUE = ...
MEAN = ...
COUNT = ...
SUM = ...
MAX = ...
MIN = ...
class Stat:
name: str
count: int
def add(self, v: float) -> None: ...
def get(self) -> Dict[Aggregation, float]: ...
class IntervalStat(Stat):
def __init__(
self,
name: str,
aggregations: List[Aggregation],
window_size: datetime.timedelta,
) -> None: ...
class FixedCountStat(Stat):
def __init__(
self, name: str, aggregations: List[Aggregation], window_size: int
) -> None: ...
class Event:
name: str
timestamp: datetime.datetime
data: Dict[str, Union[int, float, bool, str]]
def __init__(
self,
name: str,
timestamp: datetime.datetime,
data: Dict[str, Union[int, float, bool, str]],
) -> None: ...
def log_event(e: Event) -> None: ...
class PythonEventHandler: ...
def register_event_handler(handler: Callable[[Event], None]) -> PythonEventHandler: ...
def unregister_event_handler(handle: PythonEventHandler) -> None: ...


@ -61,6 +61,7 @@
#include <torch/csrc/jit/python/python_tracer.h>
#include <torch/csrc/jit/python/init.h>
#include <torch/csrc/jit/python/python_ir.h>
#include <torch/csrc/monitor/python_init.h>
#include <torch/csrc/onnx/init.h>
#include <torch/csrc/utils/init.h>
#include <torch/csrc/utils/crash_handler.h>
@ -832,6 +833,7 @@ PyObject* initModule() {
// init.
torch::onnx::initONNXBindings(module);
torch::jit::initJITBindings(module);
torch::monitor::initMonitorBindings(module);
torch::impl::dispatch::initDispatchBindings(module);
torch::throughput_benchmark::initThroughputBenchmarkBindings(module);
torch::crash_handler::initCrashHandlerBindings(module);


@ -9,22 +9,23 @@ namespace monitor {
const char* aggregationName(Aggregation agg) {
switch (agg) {
case NONE:
case Aggregation::NONE:
return "none";
case VALUE:
case Aggregation::VALUE:
return "value";
case MEAN:
case Aggregation::MEAN:
return "mean";
case COUNT:
case Aggregation::COUNT:
return "count";
case SUM:
case Aggregation::SUM:
return "sum";
case MAX:
case Aggregation::MAX:
return "max";
case MIN:
case Aggregation::MIN:
return "min";
default:
throw std::runtime_error("unknown aggregation: " + std::to_string(agg));
throw std::runtime_error(
"unknown aggregation: " + std::to_string(static_cast<int>(agg)));
}
}


@ -6,6 +6,8 @@
#include <unordered_map>
#include <vector>
#include <c10/macros/Macros.h>
#include <torch/csrc/monitor/events.h>
namespace torch {
@ -15,7 +17,7 @@ constexpr int NUM_AGGREGATIONS = 7;
// Aggregation is the list of possible aggregations for Stats.
// These use bitwise flags so they can be efficiently stored.
enum Aggregation {
enum class C10_API_ENUM Aggregation {
// NONE means no aggregations are set.
NONE = 0,
// VALUE exports the most recently set value.
@ -35,29 +37,36 @@ enum Aggregation {
MIN = 6,
};
struct TORCH_API AggregationHash {
template <typename T>
std::size_t operator()(T t) const {
return static_cast<std::size_t>(t);
}
};
// aggregationName returns the human readable name corresponding to the
// aggregation.
const char* aggregationName(Aggregation agg);
TORCH_API const char* aggregationName(Aggregation agg);
template <typename T>
class Stat;
namespace {
inline std::bitset<NUM_AGGREGATIONS> merge(
std::initializer_list<Aggregation>& list) {
template <typename T>
inline std::bitset<NUM_AGGREGATIONS> merge(T& list) {
std::bitset<NUM_AGGREGATIONS> a;
for (Aggregation b : list) {
a.set(b);
a.set(static_cast<int>(b));
}
return a;
}
} // namespace
namespace detail {
void registerStat(Stat<double>* stat);
void registerStat(Stat<int64_t>* stat);
void unregisterStat(Stat<double>* stat);
void unregisterStat(Stat<int64_t>* stat);
void TORCH_API registerStat(Stat<double>* stat);
void TORCH_API registerStat(Stat<int64_t>* stat);
void TORCH_API unregisterStat(Stat<double>* stat);
void TORCH_API unregisterStat(Stat<int64_t>* stat);
} // namespace detail
// Stat is a base class for stats. These stats are used to compute summary
@ -82,7 +91,7 @@ class Stat {
};
public:
Stat(std::string name, std::initializer_list<Aggregation> aggregations)
Stat(std::string name, std::vector<Aggregation> aggregations)
: name_(std::move(name)), aggregations_(merge(aggregations)) {
detail::registerStat(this);
}
@ -101,19 +110,20 @@ class Stat {
std::lock_guard<std::mutex> guard(mu_);
maybeLogLocked();
if (aggregations_.test(VALUE)) {
if (aggregations_.test(static_cast<int>(Aggregation::VALUE))) {
current_.value = v;
}
if (aggregations_.test(MEAN) || aggregations_.test(SUM)) {
if (aggregations_.test(static_cast<int>(Aggregation::MEAN)) ||
aggregations_.test(static_cast<int>(Aggregation::SUM))) {
current_.sum += v;
}
if (aggregations_.test(MAX)) {
if (aggregations_.test(static_cast<int>(Aggregation::MAX))) {
if (current_.max < v || current_.count == 0) {
current_.max = v;
}
}
if (aggregations_.test(MIN)) {
if (aggregations_.test(static_cast<int>(Aggregation::MIN))) {
if (current_.min > v || current_.count == 0) {
current_.min = v;
}
@ -134,7 +144,7 @@ class Stat {
return current_.count;
}
std::unordered_map<Aggregation, T> get() noexcept {
std::unordered_map<Aggregation, T, AggregationHash> get() noexcept {
std::lock_guard<std::mutex> guard(mu_);
return getLocked();
}
@ -152,48 +162,48 @@ class Stat {
}
Event e;
e.type = "torch.monitor.Stat";
e.message = name_;
e.name = "torch.monitor.Stat";
e.timestamp = std::chrono::system_clock::now();
auto stats = getLocked();
e.metadata.reserve(stats.size());
e.data.reserve(stats.size());
for (auto& kv : stats) {
std::stringstream key;
key << name_;
key << ".";
key << aggregationName(kv.first);
e.metadata[key.str()] = kv.second;
e.data[key.str()] = kv.second;
}
logEvent(e);
}
std::unordered_map<Aggregation, T> getLocked() const noexcept {
std::unordered_map<Aggregation, T> out;
std::unordered_map<Aggregation, T, AggregationHash> getLocked()
const noexcept {
std::unordered_map<Aggregation, T, AggregationHash> out;
out.reserve(aggregations_.count());
if (aggregations_.test(VALUE)) {
out.emplace(VALUE, prev_.value);
if (aggregations_.test(static_cast<int>(Aggregation::VALUE))) {
out.emplace(Aggregation::VALUE, prev_.value);
}
if (aggregations_.test(MEAN)) {
if (aggregations_.test(static_cast<int>(Aggregation::MEAN))) {
if (prev_.count == 0) {
out.emplace(MEAN, 0);
out.emplace(Aggregation::MEAN, 0);
} else {
out.emplace(MEAN, prev_.sum / prev_.count);
out.emplace(Aggregation::MEAN, prev_.sum / prev_.count);
}
}
if (aggregations_.test(COUNT)) {
out.emplace(COUNT, prev_.count);
if (aggregations_.test(static_cast<int>(Aggregation::COUNT))) {
out.emplace(Aggregation::COUNT, prev_.count);
}
if (aggregations_.test(SUM)) {
out.emplace(SUM, prev_.sum);
if (aggregations_.test(static_cast<int>(Aggregation::SUM))) {
out.emplace(Aggregation::SUM, prev_.sum);
}
if (aggregations_.test(MAX)) {
out.emplace(MAX, prev_.max);
if (aggregations_.test(static_cast<int>(Aggregation::MAX))) {
out.emplace(Aggregation::MAX, prev_.max);
}
if (aggregations_.test(MIN)) {
out.emplace(MIN, prev_.min);
if (aggregations_.test(static_cast<int>(Aggregation::MIN))) {
out.emplace(Aggregation::MIN, prev_.min);
}
return out;
@ -219,6 +229,12 @@ class IntervalStat : public Stat<T> {
std::chrono::milliseconds windowSize)
: Stat<T>(std::move(name), aggregations), windowSize_(windowSize) {}
IntervalStat(
std::string name,
std::vector<Aggregation> aggregations,
std::chrono::milliseconds windowSize)
: Stat<T>(std::move(name), aggregations), windowSize_(windowSize) {}
protected:
virtual uint64_t currentWindowId() const {
auto now = std::chrono::steady_clock::now().time_since_epoch();
@ -251,6 +267,12 @@ class FixedCountStat : public Stat<T> {
int64_t windowSize)
: Stat<T>(std::move(name), aggregations), windowSize_(windowSize) {}
FixedCountStat(
std::string name,
std::vector<Aggregation> aggregations,
int64_t windowSize)
: Stat<T>(std::move(name), aggregations), windowSize_(windowSize) {}
private:
void maybeLogLocked() override {
if (Stat<T>::current_.count >= windowSize_) {


@ -5,42 +5,39 @@
#include <unordered_map>
#include <c10/util/variant.h>
#include <c10/macros/Macros.h>
namespace torch {
namespace monitor {
// metadata_value_t is the type for Event metadata values.
using metadata_value_t = c10::variant<std::string, double, int64_t, bool>;
// data_value_t is the type for Event data values.
using data_value_t = c10::variant<std::string, double, int64_t, bool>;
// Event represents a single event that can be logged out to an external
// tracker. This does acquire a lock on logging so should be used relatively
// infrequently to avoid performance issues.
struct Event {
// type is the type of the event. This is a static string that's used to
struct TORCH_API Event {
// name is the name of the event. This is a static string that's used to
// differentiate between event types for programmatic access. The type should
// be in the format of a fully qualified Python-style class name.
// Ex: torch.monitor.MonitorEvent
std::string type;
// message is a human readable name. This is optional for machine intended
// stats.
std::string message;
std::string name;
// timestamp is a timestamp relative to the Unix epoch time.
std::chrono::system_clock::time_point timestamp;
// metadata contains rich information about the event. The contents are event
// data contains rich information about the event. The contents are event
// specific so you should check the type to ensure it's what you expect before
// accessing the metadata.
// accessing the data.
//
// NOTE: these events are not versioned and it's up to the consumer of the
// events to check the fields to ensure backwards compatibility.
std::unordered_map<std::string, metadata_value_t> metadata;
std::unordered_map<std::string, data_value_t> data;
};
inline bool operator==(const Event& lhs, const Event& rhs) {
return lhs.type == rhs.type && lhs.message == rhs.message &&
lhs.timestamp == rhs.timestamp && lhs.metadata == rhs.metadata;
TORCH_API inline bool operator==(const Event& lhs, const Event& rhs) {
return lhs.name == rhs.name && lhs.timestamp == rhs.timestamp &&
lhs.data == rhs.data;
}
// EventHandler represents an abstract event handler that can be registered to
@ -49,7 +46,7 @@ inline bool operator==(const Event& lhs, const Event& rhs) {
//
// NOTE: The handlers should avoid any IO, blocking calls or heavy computation
// as this may block the main thread and cause performance issues.
class EventHandler {
class TORCH_API EventHandler {
public:
virtual ~EventHandler() = default;
@ -60,16 +57,16 @@ class EventHandler {
// logEvent calls each registered event handler with the event. This method can
// be called from concurrently from multiple threads.
void logEvent(const Event& e);
TORCH_API void logEvent(const Event& e);
// registerEventHandler registers an EventHandler so it receives any logged
// events. Typically an EventHandler will be registered during program
// setup and unregistered at the end.
void registerEventHandler(std::shared_ptr<EventHandler> p);
TORCH_API void registerEventHandler(std::shared_ptr<EventHandler> p);
// unregisterEventHandler unregisters the event handler pointed to by the
// shared_ptr.
void unregisterEventHandler(const std::shared_ptr<EventHandler>& p);
TORCH_API void unregisterEventHandler(const std::shared_ptr<EventHandler>& p);
} // namespace monitor
} // namespace torch


@ -0,0 +1,154 @@
#include <utility>
#include <torch/csrc/utils/pybind.h>
#include <torch/csrc/utils/python_arg_parser.h>
#include <torch/csrc/utils/python_numbers.h>
#include <torch/csrc/utils/python_strings.h>
#include <pybind11/chrono.h>
#include <pybind11/functional.h>
#include <pybind11/operators.h>
#include <pybind11/stl.h>
#include <torch/csrc/monitor/counters.h>
#include <torch/csrc/monitor/events.h>
namespace pybind11 {
namespace detail {
template <>
struct type_caster<torch::monitor::data_value_t> {
public:
PYBIND11_TYPE_CASTER(torch::monitor::data_value_t, _("data_value_t"));
// Python -> C++
bool load(handle src, bool) {
PyObject* source = src.ptr();
if (THPUtils_checkLong(source)) {
this->value = THPUtils_unpackLong(source);
} else if (THPUtils_checkDouble(source)) {
this->value = THPUtils_unpackDouble(source);
} else if (THPUtils_checkString(source)) {
this->value = THPUtils_unpackString(source);
} else if (PyBool_Check(source)) {
this->value = THPUtils_unpackBool(source);
} else {
return false;
}
return !PyErr_Occurred();
}
// C++ -> Python
static handle cast(
torch::monitor::data_value_t src,
return_value_policy /* policy */,
handle /* parent */) {
if (c10::holds_alternative<double>(src)) {
return PyFloat_FromDouble(c10::get<double>(src));
} else if (c10::holds_alternative<int64_t>(src)) {
return THPUtils_packInt64(c10::get<int64_t>(src));
} else if (c10::holds_alternative<bool>(src)) {
if (c10::get<bool>(src)) {
Py_RETURN_TRUE;
} else {
Py_RETURN_FALSE;
}
} else if (c10::holds_alternative<std::string>(src)) {
std::string str = c10::get<std::string>(src);
return THPUtils_packString(str);
}
throw std::runtime_error("unknown data_value_t type");
}
};
} // namespace detail
} // namespace pybind11
namespace torch {
namespace monitor {
namespace {
class PythonEventHandler : public EventHandler {
public:
explicit PythonEventHandler(std::function<void(const Event&)> handler)
: handler_(std::move(handler)) {}
void handle(const Event& e) override {
handler_(e);
}
private:
std::function<void(const Event&)> handler_;
};
} // namespace
void initMonitorBindings(PyObject* module) {
auto rootModule = py::handle(module).cast<py::module>();
auto m = rootModule.def_submodule("_monitor");
py::enum_<Aggregation>(m, "Aggregation")
.value("VALUE", Aggregation::NONE)
.value("MEAN", Aggregation::MEAN)
.value("COUNT", Aggregation::COUNT)
.value("SUM", Aggregation::SUM)
.value("MAX", Aggregation::MAX)
.value("MIN", Aggregation::MIN)
.export_values();
py::class_<Stat<double>>(m, "Stat")
.def("add", &Stat<double>::add)
.def("get", &Stat<double>::get)
.def_property_readonly("name", &Stat<double>::name)
.def_property_readonly("count", &Stat<double>::count);
py::class_<IntervalStat<double>, Stat<double>>(m, "IntervalStat")
.def(py::init<
std::string,
std::vector<Aggregation>,
std::chrono::milliseconds>());
py::class_<FixedCountStat<double>, Stat<double>>(m, "FixedCountStat")
.def(py::init<std::string, std::vector<Aggregation>, int64_t>());
py::class_<Event>(m, "Event")
.def(
py::init([](const std::string& name,
std::chrono::system_clock::time_point timestamp,
std::unordered_map<std::string, data_value_t> data) {
Event e;
e.name = name;
e.timestamp = timestamp;
e.data = data;
return e;
}),
py::arg("name"),
py::arg("timestamp"),
py::arg("data"))
.def_readwrite("name", &Event::name)
.def_readwrite("timestamp", &Event::timestamp)
.def_readwrite("data", &Event::data);
m.def("log_event", &logEvent);
py::class_<data_value_t> dataClass(m, "data_value_t");
py::implicitly_convertible<std::string, data_value_t>();
py::implicitly_convertible<double, data_value_t>();
py::implicitly_convertible<int64_t, data_value_t>();
py::implicitly_convertible<bool, data_value_t>();
py::class_<PythonEventHandler, std::shared_ptr<PythonEventHandler>>
eventHandlerClass(m, "PythonEventHandler");
m.def("register_event_handler", [](std::function<void(const Event&)> f) {
auto handler = std::make_shared<PythonEventHandler>(f);
registerEventHandler(handler);
return handler;
});
m.def(
"unregister_event_handler",
[](std::shared_ptr<PythonEventHandler> handler) {
unregisterEventHandler(handler);
});
}
} // namespace monitor
} // namespace torch


@ -0,0 +1,11 @@
#pragma once
#include <torch/csrc/utils/pybind.h>
namespace torch {
namespace monitor {
void initMonitorBindings(PyObject* module);
}
} // namespace torch


@ -0,0 +1 @@
from torch._C._monitor import * # noqa: F403