torch/monitor: merge Interval and FixedCount stats (#72009)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/72009

This simplifies the Stat interface by merging IntervalStat and FixedCountStat into a single Stat with a fixed window duration and an optional maximum number of samples per window. This preserves the original intention of having comparably sized windows (for statistical purposes) while also keeping a consistent output bandwidth.

Test Plan:
```
buck test //caffe2/test:monitor //caffe2/test/cpp/monitor:monitor
```

Reviewed By: kiukchung

Differential Revision: D33822956

fbshipit-source-id: a74782492421be613a1a8b14341b6fb2e8eeb8b4
(cherry picked from commit 293b94e0b4646521ffe047e5222c4bba7e688464)
This commit is contained in:
Tristan Rice
2022-01-30 15:17:44 -08:00
committed by PyTorch MergeBot
parent a18cfb790d
commit 6208c2800e
6 changed files with 158 additions and 217 deletions

View File

@ -30,13 +30,6 @@ API Reference
.. autoclass:: torch.monitor.Stat
:members:
.. autoclass:: torch.monitor.IntervalStat
:members: +add, count, name
:special-members: __init__
.. autoclass:: torch.monitor.FixedCountStat
:members: +add, count, name
:special-members: __init__
.. autoclass:: torch.monitor.data_value_t

View File

@ -8,9 +8,10 @@
using namespace torch::monitor;
TEST(MonitorTest, CounterDouble) {
FixedCountStat<double> a{
Stat<double> a{
"a",
{Aggregation::MEAN, Aggregation::COUNT},
std::chrono::milliseconds(100000),
2,
};
a.add(5.0);
@ -27,9 +28,10 @@ TEST(MonitorTest, CounterDouble) {
}
TEST(MonitorTest, CounterInt64Sum) {
FixedCountStat<int64_t> a{
Stat<int64_t> a{
"a",
{Aggregation::SUM},
std::chrono::milliseconds(100000),
2,
};
a.add(5);
@ -42,9 +44,10 @@ TEST(MonitorTest, CounterInt64Sum) {
}
TEST(MonitorTest, CounterInt64Value) {
FixedCountStat<int64_t> a{
Stat<int64_t> a{
"a",
{Aggregation::VALUE},
std::chrono::milliseconds(100000),
2,
};
a.add(5);
@ -57,9 +60,10 @@ TEST(MonitorTest, CounterInt64Value) {
}
TEST(MonitorTest, CounterInt64Mean) {
FixedCountStat<int64_t> a{
Stat<int64_t> a{
"a",
{Aggregation::MEAN},
std::chrono::milliseconds(100000),
2,
};
{
@ -84,9 +88,10 @@ TEST(MonitorTest, CounterInt64Mean) {
}
TEST(MonitorTest, CounterInt64Count) {
FixedCountStat<int64_t> a{
Stat<int64_t> a{
"a",
{Aggregation::COUNT},
std::chrono::milliseconds(100000),
2,
};
ASSERT_EQ(a.count(), 0);
@ -103,9 +108,10 @@ TEST(MonitorTest, CounterInt64Count) {
}
TEST(MonitorTest, CounterInt64MinMax) {
FixedCountStat<int64_t> a{
Stat<int64_t> a{
"a",
{Aggregation::MIN, Aggregation::MAX},
std::chrono::milliseconds(100000),
6,
};
{
@ -134,9 +140,10 @@ TEST(MonitorTest, CounterInt64MinMax) {
}
TEST(MonitorTest, CounterInt64WindowSize) {
FixedCountStat<int64_t> a{
Stat<int64_t> a{
"a",
{Aggregation::COUNT, Aggregation::SUM},
std::chrono::milliseconds(100000),
/*windowSize=*/3,
};
a.add(1);
@ -145,8 +152,34 @@ TEST(MonitorTest, CounterInt64WindowSize) {
a.add(3);
ASSERT_EQ(a.count(), 0);
// after logging max for window, should be zero
a.add(4);
ASSERT_EQ(a.count(), 1);
ASSERT_EQ(a.count(), 0);
auto stats = a.get();
std::unordered_map<Aggregation, int64_t, AggregationHash> want = {
{Aggregation::COUNT, 3},
{Aggregation::SUM, 6},
};
ASSERT_EQ(stats, want);
}
TEST(MonitorTest, CounterInt64WindowSizeHuge) {
Stat<int64_t> a{
"a",
{Aggregation::COUNT, Aggregation::SUM},
std::chrono::hours(24 * 365 * 10), // 10 years
/*windowSize=*/3,
};
a.add(1);
a.add(2);
ASSERT_EQ(a.count(), 2);
a.add(3);
ASSERT_EQ(a.count(), 0);
// after logging max for window, should be zero
a.add(4);
ASSERT_EQ(a.count(), 0);
auto stats = a.get();
std::unordered_map<Aggregation, int64_t, AggregationHash> want = {
@ -157,14 +190,15 @@ TEST(MonitorTest, CounterInt64WindowSize) {
}
template <typename T>
struct TestIntervalStat : public IntervalStat<T> {
uint64_t mockWindowId{0};
struct TestStat : public Stat<T> {
uint64_t mockWindowId{1};
TestIntervalStat(
TestStat(
std::string name,
std::initializer_list<Aggregation> aggregations,
std::chrono::milliseconds windowSize)
: IntervalStat<T>(name, aggregations, windowSize) {}
std::chrono::milliseconds windowSize,
int64_t maxSamples = std::numeric_limits<int64_t>::max())
: Stat<T>(name, aggregations, windowSize, maxSamples) {}
uint64_t currentWindowId() const override {
return mockWindowId;
@ -192,10 +226,10 @@ struct HandlerGuard {
}
};
TEST(MonitorTest, IntervalStat) {
TEST(MonitorTest, Stat) {
HandlerGuard<AggregatingEventHandler> guard;
IntervalStat<int64_t> a{
Stat<int64_t> a{
"a",
{Aggregation::COUNT, Aggregation::SUM},
std::chrono::milliseconds(1),
@ -213,10 +247,10 @@ TEST(MonitorTest, IntervalStat) {
ASSERT_LE(guard.handler->events.size(), 2);
}
TEST(MonitorTest, IntervalStatEvent) {
TEST(MonitorTest, StatEvent) {
HandlerGuard<AggregatingEventHandler> guard;
TestIntervalStat<int64_t> a{
TestStat<int64_t> a{
"a",
{Aggregation::COUNT, Aggregation::SUM},
std::chrono::milliseconds(1),
@ -245,11 +279,11 @@ TEST(MonitorTest, IntervalStatEvent) {
ASSERT_EQ(e.data, data);
}
TEST(MonitorTest, IntervalStatEventDestruction) {
TEST(MonitorTest, StatEventDestruction) {
HandlerGuard<AggregatingEventHandler> guard;
{
TestIntervalStat<int64_t> a{
TestStat<int64_t> a{
"a",
{Aggregation::COUNT, Aggregation::SUM},
std::chrono::hours(10),
@ -269,59 +303,3 @@ TEST(MonitorTest, IntervalStatEventDestruction) {
};
ASSERT_EQ(e.data, data);
}
TEST(MonitorTest, FixedCountStatEvent) {
HandlerGuard<AggregatingEventHandler> guard;
FixedCountStat<int64_t> a{
"a",
{Aggregation::COUNT, Aggregation::SUM},
3,
};
ASSERT_EQ(guard.handler->events.size(), 0);
a.add(1);
ASSERT_EQ(a.count(), 1);
a.add(2);
ASSERT_EQ(a.count(), 2);
ASSERT_EQ(guard.handler->events.size(), 0);
a.add(1);
ASSERT_EQ(a.count(), 0);
ASSERT_EQ(guard.handler->events.size(), 1);
Event e = guard.handler->events.at(0);
ASSERT_EQ(e.name, "torch.monitor.Stat");
ASSERT_NE(e.timestamp, std::chrono::system_clock::time_point{});
std::unordered_map<std::string, data_value_t> data{
{"a.sum", 4L},
{"a.count", 3L},
};
ASSERT_EQ(e.data, data);
}
TEST(MonitorTest, FixedCountStatEventDestruction) {
HandlerGuard<AggregatingEventHandler> guard;
{
FixedCountStat<int64_t> a{
"a",
{Aggregation::COUNT, Aggregation::SUM},
3,
};
ASSERT_EQ(guard.handler->events.size(), 0);
a.add(1);
ASSERT_EQ(a.count(), 1);
ASSERT_EQ(guard.handler->events.size(), 0);
}
ASSERT_EQ(guard.handler->events.size(), 1);
Event e = guard.handler->events.at(0);
ASSERT_EQ(e.name, "torch.monitor.Stat");
ASSERT_NE(e.timestamp, std::chrono::system_clock::time_point{});
std::unordered_map<std::string, data_value_t> data{
{"a.sum", 1L},
{"a.count", 1L},
};
ASSERT_EQ(e.data, data);
}

View File

@ -10,8 +10,6 @@ import time
from torch.monitor import (
Aggregation,
FixedCountStat,
IntervalStat,
Event,
log_event,
register_event_handler,
@ -28,12 +26,11 @@ class TestMonitor(TestCase):
events.append(event)
handle = register_event_handler(handler)
s = IntervalStat(
s = Stat(
"asdf",
(Aggregation.SUM, Aggregation.COUNT),
timedelta(milliseconds=1),
)
self.assertIsInstance(s, Stat)
self.assertEqual(s.name, "asdf")
s.add(2)
@ -48,12 +45,12 @@ class TestMonitor(TestCase):
unregister_event_handler(handle)
def test_fixed_count_stat(self) -> None:
s = FixedCountStat(
s = Stat(
"asdf",
(Aggregation.SUM, Aggregation.COUNT),
timedelta(hours=100),
3,
)
self.assertIsInstance(s, Stat)
s.add(1)
s.add(2)
name = s.name
@ -126,10 +123,11 @@ class TestMonitorTensorboard(TestCase):
with self.create_summary_writer() as w:
handle = register_event_handler(TensorboardEventHandler(w))
s = FixedCountStat(
s = Stat(
"asdf",
(Aggregation.SUM, Aggregation.COUNT),
2,
timedelta(hours=1),
5,
)
for i in range(10):
s.add(i)
@ -150,8 +148,8 @@ class TestMonitorTensorboard(TestCase):
tag: [e.tensor_proto.float_val[0] for e in events] for tag, events in raw_result.items()
}
self.assertEqual(scalars, {
"asdf.sum": [1, 5, 9, 13, 17],
"asdf.count": [2, 2, 2, 2, 2],
"asdf.sum": [10],
"asdf.count": [5],
})

View File

@ -15,22 +15,13 @@ class Aggregation(Enum):
class Stat:
    # Type stub for the native ``Stat`` binding: aggregates values passed to
    # ``add`` over a window of ``window_size`` duration.
    name: str
    count: int
    def __init__(
        self,
        name: str,
        aggregations: List[Aggregation],
        # The native constructor takes a millisecond duration, which callers
        # supply as a ``datetime.timedelta`` (not an integer count).
        window_size: datetime.timedelta,
        # Optional cap on samples per window. The native default is the
        # largest 64-bit signed integer, so the stub leaves the value elided
        # rather than advertising a sentinel such as -1.
        max_samples: int = ...,
    ) -> None: ...
    def add(self, v: float) -> None: ...
    def get(self) -> Dict[Aggregation, float]: ...
class IntervalStat(Stat):
def __init__(
self,
name: str,
aggregations: List[Aggregation],
window_size: datetime.timedelta,
) -> None: ...
class FixedCountStat(Stat):
def __init__(
self, name: str, aggregations: List[Aggregation], window_size: int
) -> None: ...
class Event:
name: str
timestamp: datetime.datetime

View File

@ -69,10 +69,19 @@ void TORCH_API unregisterStat(Stat<double>* stat);
void TORCH_API unregisterStat(Stat<int64_t>* stat);
} // namespace detail
// Stat is a base class for stats. These stats are used to compute summary
// statistics in a performant way over repeating intervals. When the window
// closes the stats are logged via the event handlers as a `torch.monitor.Stat`
// event.
// Stat is used to compute summary statistics in a performant way over fixed
// intervals. Stat logs the statistics as an Event once every `windowSize`
// duration. When the window closes the stats are logged via the event handlers
// as a `torch.monitor.Stat` event.
//
// `windowSize` should be set to something relatively high to avoid a huge
// number of events being logged. Ex: 60s. Stat uses millisecond precision.
//
// If maxSamples is set, the stat will cap the number of samples per window by
// discarding `add` calls once `maxSamples` adds have occurred. If it's not set,
// all `add` calls during the window will be included.
// This is an optional field to make aggregations more directly comparable
// across windows when the number of samples might vary.
//
// Stats support double and int64_t data types depending on what needs to be
// logged and needs to be templatized with one of them.
@ -91,8 +100,27 @@ class Stat {
};
public:
Stat(std::string name, std::vector<Aggregation> aggregations)
: name_(std::move(name)), aggregations_(merge(aggregations)) {
// Constructs a Stat from a braced list of aggregations and registers it via
// detail::registerStat.
//
// name:         identifier stored in name_.
// aggregations: aggregations to compute; passed through merge() before being
//               stored in aggregations_.
// windowSize:   duration of one window, in milliseconds.
// maxSamples:   cap on the sample count per window; defaults to the largest
//               int64_t value.
Stat(
std::string name,
std::initializer_list<Aggregation> aggregations,
std::chrono::milliseconds windowSize,
int64_t maxSamples = std::numeric_limits<int64_t>::max())
: name_(std::move(name)),
aggregations_(merge(aggregations)),
windowSize_(windowSize),
maxSamples_(maxSamples) {
detail::registerStat(this);
}
// Constructs a Stat from a vector of aggregations and registers it via
// detail::registerStat. Same parameters as the initializer_list overload;
// this is the signature the Python binding's py::init uses.
//
// name:         identifier stored in name_.
// aggregations: aggregations to compute; passed through merge() before being
//               stored in aggregations_.
// windowSize:   duration of one window, in milliseconds.
// maxSamples:   cap on the sample count per window; defaults to the largest
//               int64_t value.
Stat(
std::string name,
std::vector<Aggregation> aggregations,
std::chrono::milliseconds windowSize,
int64_t maxSamples = std::numeric_limits<int64_t>::max())
: name_(std::move(name)),
aggregations_(merge(aggregations)),
windowSize_(windowSize),
maxSamples_(maxSamples) {
detail::registerStat(this);
}
@ -110,6 +138,10 @@ class Stat {
std::lock_guard<std::mutex> guard(mu_);
maybeLogLocked();
if (alreadyLogged()) {
return;
}
if (aggregations_.test(static_cast<int>(Aggregation::VALUE))) {
current_.value = v;
}
@ -150,7 +182,29 @@ class Stat {
}
protected:
virtual void maybeLogLocked() = 0;
// Returns the id of the window that contains the current time: the number of
// whole `windowSize_` intervals in the steady_clock time since its epoch
// (truncated to milliseconds), plus one.
//
// Virtual so that a subclass can substitute its own window id (the test
// helper does this to control window rollover).
//
// NOTE(review): this divides by windowSize_; a zero-length window would be a
// division by zero. Nothing visible here validates windowSize_ -- confirm.
virtual uint64_t currentWindowId() const {
std::chrono::milliseconds now =
std::chrono::duration_cast<std::chrono::milliseconds>(
std::chrono::steady_clock::now().time_since_epoch());
// always returns a currentWindowId of at least 1 to avoid 0 window issues
return (now / windowSize_) + 1;
}
private:
// Returns true when the current window id equals lastLoggedWindowId_, i.e.
// the current window has already been logged.
bool alreadyLogged() {
return lastLoggedWindowId_ == currentWindowId();
}
// Logs the accumulated values when either the window id has changed since
// windowId_ was last recorded or current_.count has reached maxSamples_,
// unless the current window has already been logged.
//
// The "Locked" suffix presumably means the caller must already hold mu_
// (add() takes the lock before calling this) -- confirm for other callers.
void maybeLogLocked() {
auto windowId = currentWindowId();
bool shouldLog = windowId_ != windowId || current_.count >= maxSamples_;
if (shouldLog && !alreadyLogged()) {
logLocked();
// Order matters: record the window that was just logged (the old
// windowId_) before advancing windowId_ to the current window. If the log
// was triggered by maxSamples_ within the current window, the two ids are
// equal and alreadyLogged() becomes true for the rest of that window.
lastLoggedWindowId_ = windowId_;
windowId_ = windowId;
}
}
void logLocked() {
prev_ = current_;
@ -215,72 +269,11 @@ class Stat {
std::mutex mu_;
Values current_;
Values prev_;
};
// IntervalStat is a Stat that logs the stat once every `windowSize` duration.
// This should be set to something relatively high to avoid a huge number of
// events being logged. Ex: 60s.
template <typename T>
class IntervalStat : public Stat<T> {
public:
IntervalStat(
std::string name,
std::initializer_list<Aggregation> aggregations,
std::chrono::milliseconds windowSize)
: Stat<T>(std::move(name), aggregations), windowSize_(windowSize) {}
IntervalStat(
std::string name,
std::vector<Aggregation> aggregations,
std::chrono::milliseconds windowSize)
: Stat<T>(std::move(name), aggregations), windowSize_(windowSize) {}
protected:
virtual uint64_t currentWindowId() const {
auto now = std::chrono::steady_clock::now().time_since_epoch();
return now / windowSize_;
}
private:
void maybeLogLocked() override {
auto windowId = currentWindowId();
if (windowId_ != windowId) {
Stat<T>::logLocked();
windowId_ = windowId;
}
}
uint64_t windowId_{0};
uint64_t lastLoggedWindowId_{0};
const std::chrono::milliseconds windowSize_;
};
// FixedCountStat is a Stat that logs the stat every `windowSize` number of add
// calls. For high performance stats this window size should be fairly large to
// ensure that the event logging frequency is in the range of 1s to 60s under
// normal usage. Core stats should err on the side of less frequent logging.
template <typename T>
class FixedCountStat : public Stat<T> {
public:
FixedCountStat(
std::string name,
std::initializer_list<Aggregation> aggregations,
int64_t windowSize)
: Stat<T>(std::move(name), aggregations), windowSize_(windowSize) {}
FixedCountStat(
std::string name,
std::vector<Aggregation> aggregations,
int64_t windowSize)
: Stat<T>(std::move(name), aggregations), windowSize_(windowSize) {}
private:
void maybeLogLocked() override {
if (Stat<T>::current_.count >= windowSize_) {
Stat<T>::logLocked();
}
}
const int64_t windowSize_;
const int64_t maxSamples_;
};
} // namespace monitor
} // namespace torch

View File

@ -133,8 +133,37 @@ void initMonitorBindings(PyObject* module) {
m,
"Stat",
R"DOC(
Parent class for all aggregating stat implementations.
Stat is used to compute summary statistics in a performant way over
fixed intervals. Stat logs the statistics as an Event once every
``window_size`` duration. When the window closes the stats are logged
via the event handlers as a ``torch.monitor.Stat`` event.
``window_size`` should be set to something relatively high to avoid a
huge number of events being logged. Ex: 60s. Stat uses millisecond
precision.
If ``max_samples`` is set, the stat will cap the number of samples per
window by discarding `add` calls once ``max_samples`` adds have
occurred. If it's not set, all ``add`` calls during the window will be
included. This is an optional field to make aggregations more directly
comparable across windows when the number of samples might vary.
When the Stat is destructed it will log any remaining data even if the
window hasn't elapsed.
)DOC")
.def(
py::init<
std::string,
std::vector<Aggregation>,
std::chrono::milliseconds,
int64_t>(),
py::arg("name"),
py::arg("aggregations"),
py::arg("window_size"),
py::arg("max_samples") = std::numeric_limits<int64_t>::max(),
R"DOC(
Constructs the ``Stat``.
)DOC")
.def(
"add",
&Stat<double>::add,
@ -165,47 +194,6 @@ void initMonitorBindings(PyObject* module) {
once the event has been logged.
)DOC");
py::class_<IntervalStat<double>, Stat<double>>(
m,
"IntervalStat",
R"DOC(
IntervalStat is a Stat that logs once every ``window_size`` duration. This
should be set to something relatively high to avoid a huge number of
events being logged. Ex: 60s.
The stat will be logged as an event on the next ``add`` call after the
window ends.
)DOC")
.def(
py::init<
std::string,
std::vector<Aggregation>,
std::chrono::milliseconds>(),
py::arg("name"),
py::arg("aggregations"),
py::arg("window_size"),
R"DOC(
Constructs the ``IntervalStat``.
)DOC");
py::class_<FixedCountStat<double>, Stat<double>>(
m,
"FixedCountStat",
R"DOC(
FixedCountStat is a Stat that logs every ``window_size`` number of
``add`` calls. For high performance stats this window size should be
fairly large to ensure that the event logging frequency is in the range
of 1s to 60s under normal usage. Core stats should error on the side of
logging less frequently.
)DOC")
.def(
py::init<std::string, std::vector<Aggregation>, int64_t>(),
py::arg("name"),
py::arg("aggregations"),
py::arg("window_size"),
R"DOC(
Constructs the ``FixedCountStat``.
)DOC");
py::class_<Event>(
m,
"Event",