mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
torch/monitor: merge Interval and FixedCount stats (#72009)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/72009 This simplifies the Stats interface by merging IntervalStat and FixedCountStat into a single Stat w/ a specific window size duration and an optional max samples per window. This allows for the original intention of having comparably sized windows (for statistical purposes) while also having a consistent output bandwidth. Test Plan: ``` buck test //caffe2/test:monitor //caffe2/test/cpp/monitor:monitor ``` Reviewed By: kiukchung Differential Revision: D33822956 fbshipit-source-id: a74782492421be613a1a8b14341b6fb2e8eeb8b4 (cherry picked from commit 293b94e0b4646521ffe047e5222c4bba7e688464)
This commit is contained in:
committed by
PyTorch MergeBot
parent
a18cfb790d
commit
6208c2800e
@ -30,13 +30,6 @@ API Reference
|
|||||||
|
|
||||||
.. autoclass:: torch.monitor.Stat
|
.. autoclass:: torch.monitor.Stat
|
||||||
:members:
|
:members:
|
||||||
|
|
||||||
.. autoclass:: torch.monitor.IntervalStat
|
|
||||||
:members: +add, count, name
|
|
||||||
:special-members: __init__
|
|
||||||
|
|
||||||
.. autoclass:: torch.monitor.FixedCountStat
|
|
||||||
:members: +add, count, name
|
|
||||||
:special-members: __init__
|
:special-members: __init__
|
||||||
|
|
||||||
.. autoclass:: torch.monitor.data_value_t
|
.. autoclass:: torch.monitor.data_value_t
|
||||||
|
@ -8,9 +8,10 @@
|
|||||||
using namespace torch::monitor;
|
using namespace torch::monitor;
|
||||||
|
|
||||||
TEST(MonitorTest, CounterDouble) {
|
TEST(MonitorTest, CounterDouble) {
|
||||||
FixedCountStat<double> a{
|
Stat<double> a{
|
||||||
"a",
|
"a",
|
||||||
{Aggregation::MEAN, Aggregation::COUNT},
|
{Aggregation::MEAN, Aggregation::COUNT},
|
||||||
|
std::chrono::milliseconds(100000),
|
||||||
2,
|
2,
|
||||||
};
|
};
|
||||||
a.add(5.0);
|
a.add(5.0);
|
||||||
@ -27,9 +28,10 @@ TEST(MonitorTest, CounterDouble) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
TEST(MonitorTest, CounterInt64Sum) {
|
TEST(MonitorTest, CounterInt64Sum) {
|
||||||
FixedCountStat<int64_t> a{
|
Stat<int64_t> a{
|
||||||
"a",
|
"a",
|
||||||
{Aggregation::SUM},
|
{Aggregation::SUM},
|
||||||
|
std::chrono::milliseconds(100000),
|
||||||
2,
|
2,
|
||||||
};
|
};
|
||||||
a.add(5);
|
a.add(5);
|
||||||
@ -42,9 +44,10 @@ TEST(MonitorTest, CounterInt64Sum) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
TEST(MonitorTest, CounterInt64Value) {
|
TEST(MonitorTest, CounterInt64Value) {
|
||||||
FixedCountStat<int64_t> a{
|
Stat<int64_t> a{
|
||||||
"a",
|
"a",
|
||||||
{Aggregation::VALUE},
|
{Aggregation::VALUE},
|
||||||
|
std::chrono::milliseconds(100000),
|
||||||
2,
|
2,
|
||||||
};
|
};
|
||||||
a.add(5);
|
a.add(5);
|
||||||
@ -57,9 +60,10 @@ TEST(MonitorTest, CounterInt64Value) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
TEST(MonitorTest, CounterInt64Mean) {
|
TEST(MonitorTest, CounterInt64Mean) {
|
||||||
FixedCountStat<int64_t> a{
|
Stat<int64_t> a{
|
||||||
"a",
|
"a",
|
||||||
{Aggregation::MEAN},
|
{Aggregation::MEAN},
|
||||||
|
std::chrono::milliseconds(100000),
|
||||||
2,
|
2,
|
||||||
};
|
};
|
||||||
{
|
{
|
||||||
@ -84,9 +88,10 @@ TEST(MonitorTest, CounterInt64Mean) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
TEST(MonitorTest, CounterInt64Count) {
|
TEST(MonitorTest, CounterInt64Count) {
|
||||||
FixedCountStat<int64_t> a{
|
Stat<int64_t> a{
|
||||||
"a",
|
"a",
|
||||||
{Aggregation::COUNT},
|
{Aggregation::COUNT},
|
||||||
|
std::chrono::milliseconds(100000),
|
||||||
2,
|
2,
|
||||||
};
|
};
|
||||||
ASSERT_EQ(a.count(), 0);
|
ASSERT_EQ(a.count(), 0);
|
||||||
@ -103,9 +108,10 @@ TEST(MonitorTest, CounterInt64Count) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
TEST(MonitorTest, CounterInt64MinMax) {
|
TEST(MonitorTest, CounterInt64MinMax) {
|
||||||
FixedCountStat<int64_t> a{
|
Stat<int64_t> a{
|
||||||
"a",
|
"a",
|
||||||
{Aggregation::MIN, Aggregation::MAX},
|
{Aggregation::MIN, Aggregation::MAX},
|
||||||
|
std::chrono::milliseconds(100000),
|
||||||
6,
|
6,
|
||||||
};
|
};
|
||||||
{
|
{
|
||||||
@ -134,9 +140,10 @@ TEST(MonitorTest, CounterInt64MinMax) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
TEST(MonitorTest, CounterInt64WindowSize) {
|
TEST(MonitorTest, CounterInt64WindowSize) {
|
||||||
FixedCountStat<int64_t> a{
|
Stat<int64_t> a{
|
||||||
"a",
|
"a",
|
||||||
{Aggregation::COUNT, Aggregation::SUM},
|
{Aggregation::COUNT, Aggregation::SUM},
|
||||||
|
std::chrono::milliseconds(100000),
|
||||||
/*windowSize=*/3,
|
/*windowSize=*/3,
|
||||||
};
|
};
|
||||||
a.add(1);
|
a.add(1);
|
||||||
@ -145,8 +152,34 @@ TEST(MonitorTest, CounterInt64WindowSize) {
|
|||||||
a.add(3);
|
a.add(3);
|
||||||
ASSERT_EQ(a.count(), 0);
|
ASSERT_EQ(a.count(), 0);
|
||||||
|
|
||||||
|
// after logging max for window, should be zero
|
||||||
a.add(4);
|
a.add(4);
|
||||||
ASSERT_EQ(a.count(), 1);
|
ASSERT_EQ(a.count(), 0);
|
||||||
|
|
||||||
|
auto stats = a.get();
|
||||||
|
std::unordered_map<Aggregation, int64_t, AggregationHash> want = {
|
||||||
|
{Aggregation::COUNT, 3},
|
||||||
|
{Aggregation::SUM, 6},
|
||||||
|
};
|
||||||
|
ASSERT_EQ(stats, want);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(MonitorTest, CounterInt64WindowSizeHuge) {
|
||||||
|
Stat<int64_t> a{
|
||||||
|
"a",
|
||||||
|
{Aggregation::COUNT, Aggregation::SUM},
|
||||||
|
std::chrono::hours(24 * 365 * 10), // 10 years
|
||||||
|
/*windowSize=*/3,
|
||||||
|
};
|
||||||
|
a.add(1);
|
||||||
|
a.add(2);
|
||||||
|
ASSERT_EQ(a.count(), 2);
|
||||||
|
a.add(3);
|
||||||
|
ASSERT_EQ(a.count(), 0);
|
||||||
|
|
||||||
|
// after logging max for window, should be zero
|
||||||
|
a.add(4);
|
||||||
|
ASSERT_EQ(a.count(), 0);
|
||||||
|
|
||||||
auto stats = a.get();
|
auto stats = a.get();
|
||||||
std::unordered_map<Aggregation, int64_t, AggregationHash> want = {
|
std::unordered_map<Aggregation, int64_t, AggregationHash> want = {
|
||||||
@ -157,14 +190,15 @@ TEST(MonitorTest, CounterInt64WindowSize) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
struct TestIntervalStat : public IntervalStat<T> {
|
struct TestStat : public Stat<T> {
|
||||||
uint64_t mockWindowId{0};
|
uint64_t mockWindowId{1};
|
||||||
|
|
||||||
TestIntervalStat(
|
TestStat(
|
||||||
std::string name,
|
std::string name,
|
||||||
std::initializer_list<Aggregation> aggregations,
|
std::initializer_list<Aggregation> aggregations,
|
||||||
std::chrono::milliseconds windowSize)
|
std::chrono::milliseconds windowSize,
|
||||||
: IntervalStat<T>(name, aggregations, windowSize) {}
|
int64_t maxSamples = std::numeric_limits<int64_t>::max())
|
||||||
|
: Stat<T>(name, aggregations, windowSize, maxSamples) {}
|
||||||
|
|
||||||
uint64_t currentWindowId() const override {
|
uint64_t currentWindowId() const override {
|
||||||
return mockWindowId;
|
return mockWindowId;
|
||||||
@ -192,10 +226,10 @@ struct HandlerGuard {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
TEST(MonitorTest, IntervalStat) {
|
TEST(MonitorTest, Stat) {
|
||||||
HandlerGuard<AggregatingEventHandler> guard;
|
HandlerGuard<AggregatingEventHandler> guard;
|
||||||
|
|
||||||
IntervalStat<int64_t> a{
|
Stat<int64_t> a{
|
||||||
"a",
|
"a",
|
||||||
{Aggregation::COUNT, Aggregation::SUM},
|
{Aggregation::COUNT, Aggregation::SUM},
|
||||||
std::chrono::milliseconds(1),
|
std::chrono::milliseconds(1),
|
||||||
@ -213,10 +247,10 @@ TEST(MonitorTest, IntervalStat) {
|
|||||||
ASSERT_LE(guard.handler->events.size(), 2);
|
ASSERT_LE(guard.handler->events.size(), 2);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(MonitorTest, IntervalStatEvent) {
|
TEST(MonitorTest, StatEvent) {
|
||||||
HandlerGuard<AggregatingEventHandler> guard;
|
HandlerGuard<AggregatingEventHandler> guard;
|
||||||
|
|
||||||
TestIntervalStat<int64_t> a{
|
TestStat<int64_t> a{
|
||||||
"a",
|
"a",
|
||||||
{Aggregation::COUNT, Aggregation::SUM},
|
{Aggregation::COUNT, Aggregation::SUM},
|
||||||
std::chrono::milliseconds(1),
|
std::chrono::milliseconds(1),
|
||||||
@ -245,11 +279,11 @@ TEST(MonitorTest, IntervalStatEvent) {
|
|||||||
ASSERT_EQ(e.data, data);
|
ASSERT_EQ(e.data, data);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(MonitorTest, IntervalStatEventDestruction) {
|
TEST(MonitorTest, StatEventDestruction) {
|
||||||
HandlerGuard<AggregatingEventHandler> guard;
|
HandlerGuard<AggregatingEventHandler> guard;
|
||||||
|
|
||||||
{
|
{
|
||||||
TestIntervalStat<int64_t> a{
|
TestStat<int64_t> a{
|
||||||
"a",
|
"a",
|
||||||
{Aggregation::COUNT, Aggregation::SUM},
|
{Aggregation::COUNT, Aggregation::SUM},
|
||||||
std::chrono::hours(10),
|
std::chrono::hours(10),
|
||||||
@ -269,59 +303,3 @@ TEST(MonitorTest, IntervalStatEventDestruction) {
|
|||||||
};
|
};
|
||||||
ASSERT_EQ(e.data, data);
|
ASSERT_EQ(e.data, data);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(MonitorTest, FixedCountStatEvent) {
|
|
||||||
HandlerGuard<AggregatingEventHandler> guard;
|
|
||||||
|
|
||||||
FixedCountStat<int64_t> a{
|
|
||||||
"a",
|
|
||||||
{Aggregation::COUNT, Aggregation::SUM},
|
|
||||||
3,
|
|
||||||
};
|
|
||||||
ASSERT_EQ(guard.handler->events.size(), 0);
|
|
||||||
|
|
||||||
a.add(1);
|
|
||||||
ASSERT_EQ(a.count(), 1);
|
|
||||||
a.add(2);
|
|
||||||
ASSERT_EQ(a.count(), 2);
|
|
||||||
ASSERT_EQ(guard.handler->events.size(), 0);
|
|
||||||
|
|
||||||
a.add(1);
|
|
||||||
ASSERT_EQ(a.count(), 0);
|
|
||||||
ASSERT_EQ(guard.handler->events.size(), 1);
|
|
||||||
|
|
||||||
Event e = guard.handler->events.at(0);
|
|
||||||
ASSERT_EQ(e.name, "torch.monitor.Stat");
|
|
||||||
ASSERT_NE(e.timestamp, std::chrono::system_clock::time_point{});
|
|
||||||
std::unordered_map<std::string, data_value_t> data{
|
|
||||||
{"a.sum", 4L},
|
|
||||||
{"a.count", 3L},
|
|
||||||
};
|
|
||||||
ASSERT_EQ(e.data, data);
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST(MonitorTest, FixedCountStatEventDestruction) {
|
|
||||||
HandlerGuard<AggregatingEventHandler> guard;
|
|
||||||
|
|
||||||
{
|
|
||||||
FixedCountStat<int64_t> a{
|
|
||||||
"a",
|
|
||||||
{Aggregation::COUNT, Aggregation::SUM},
|
|
||||||
3,
|
|
||||||
};
|
|
||||||
ASSERT_EQ(guard.handler->events.size(), 0);
|
|
||||||
a.add(1);
|
|
||||||
ASSERT_EQ(a.count(), 1);
|
|
||||||
ASSERT_EQ(guard.handler->events.size(), 0);
|
|
||||||
}
|
|
||||||
ASSERT_EQ(guard.handler->events.size(), 1);
|
|
||||||
|
|
||||||
Event e = guard.handler->events.at(0);
|
|
||||||
ASSERT_EQ(e.name, "torch.monitor.Stat");
|
|
||||||
ASSERT_NE(e.timestamp, std::chrono::system_clock::time_point{});
|
|
||||||
std::unordered_map<std::string, data_value_t> data{
|
|
||||||
{"a.sum", 1L},
|
|
||||||
{"a.count", 1L},
|
|
||||||
};
|
|
||||||
ASSERT_EQ(e.data, data);
|
|
||||||
}
|
|
||||||
|
@ -10,8 +10,6 @@ import time
|
|||||||
|
|
||||||
from torch.monitor import (
|
from torch.monitor import (
|
||||||
Aggregation,
|
Aggregation,
|
||||||
FixedCountStat,
|
|
||||||
IntervalStat,
|
|
||||||
Event,
|
Event,
|
||||||
log_event,
|
log_event,
|
||||||
register_event_handler,
|
register_event_handler,
|
||||||
@ -28,12 +26,11 @@ class TestMonitor(TestCase):
|
|||||||
events.append(event)
|
events.append(event)
|
||||||
|
|
||||||
handle = register_event_handler(handler)
|
handle = register_event_handler(handler)
|
||||||
s = IntervalStat(
|
s = Stat(
|
||||||
"asdf",
|
"asdf",
|
||||||
(Aggregation.SUM, Aggregation.COUNT),
|
(Aggregation.SUM, Aggregation.COUNT),
|
||||||
timedelta(milliseconds=1),
|
timedelta(milliseconds=1),
|
||||||
)
|
)
|
||||||
self.assertIsInstance(s, Stat)
|
|
||||||
self.assertEqual(s.name, "asdf")
|
self.assertEqual(s.name, "asdf")
|
||||||
|
|
||||||
s.add(2)
|
s.add(2)
|
||||||
@ -48,12 +45,12 @@ class TestMonitor(TestCase):
|
|||||||
unregister_event_handler(handle)
|
unregister_event_handler(handle)
|
||||||
|
|
||||||
def test_fixed_count_stat(self) -> None:
|
def test_fixed_count_stat(self) -> None:
|
||||||
s = FixedCountStat(
|
s = Stat(
|
||||||
"asdf",
|
"asdf",
|
||||||
(Aggregation.SUM, Aggregation.COUNT),
|
(Aggregation.SUM, Aggregation.COUNT),
|
||||||
|
timedelta(hours=100),
|
||||||
3,
|
3,
|
||||||
)
|
)
|
||||||
self.assertIsInstance(s, Stat)
|
|
||||||
s.add(1)
|
s.add(1)
|
||||||
s.add(2)
|
s.add(2)
|
||||||
name = s.name
|
name = s.name
|
||||||
@ -126,10 +123,11 @@ class TestMonitorTensorboard(TestCase):
|
|||||||
with self.create_summary_writer() as w:
|
with self.create_summary_writer() as w:
|
||||||
handle = register_event_handler(TensorboardEventHandler(w))
|
handle = register_event_handler(TensorboardEventHandler(w))
|
||||||
|
|
||||||
s = FixedCountStat(
|
s = Stat(
|
||||||
"asdf",
|
"asdf",
|
||||||
(Aggregation.SUM, Aggregation.COUNT),
|
(Aggregation.SUM, Aggregation.COUNT),
|
||||||
2,
|
timedelta(hours=1),
|
||||||
|
5,
|
||||||
)
|
)
|
||||||
for i in range(10):
|
for i in range(10):
|
||||||
s.add(i)
|
s.add(i)
|
||||||
@ -150,8 +148,8 @@ class TestMonitorTensorboard(TestCase):
|
|||||||
tag: [e.tensor_proto.float_val[0] for e in events] for tag, events in raw_result.items()
|
tag: [e.tensor_proto.float_val[0] for e in events] for tag, events in raw_result.items()
|
||||||
}
|
}
|
||||||
self.assertEqual(scalars, {
|
self.assertEqual(scalars, {
|
||||||
"asdf.sum": [1, 5, 9, 13, 17],
|
"asdf.sum": [10],
|
||||||
"asdf.count": [2, 2, 2, 2, 2],
|
"asdf.count": [5],
|
||||||
})
|
})
|
||||||
|
|
||||||
|
|
||||||
|
@ -15,22 +15,13 @@ class Aggregation(Enum):
|
|||||||
class Stat:
|
class Stat:
|
||||||
name: str
|
name: str
|
||||||
count: int
|
count: int
|
||||||
|
def __init__(
|
||||||
|
self, name: str, aggregations: List[Aggregation], window_size: int,
|
||||||
|
max_samples: int = -1,
|
||||||
|
) -> None: ...
|
||||||
def add(self, v: float) -> None: ...
|
def add(self, v: float) -> None: ...
|
||||||
def get(self) -> Dict[Aggregation, float]: ...
|
def get(self) -> Dict[Aggregation, float]: ...
|
||||||
|
|
||||||
class IntervalStat(Stat):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
name: str,
|
|
||||||
aggregations: List[Aggregation],
|
|
||||||
window_size: datetime.timedelta,
|
|
||||||
) -> None: ...
|
|
||||||
|
|
||||||
class FixedCountStat(Stat):
|
|
||||||
def __init__(
|
|
||||||
self, name: str, aggregations: List[Aggregation], window_size: int
|
|
||||||
) -> None: ...
|
|
||||||
|
|
||||||
class Event:
|
class Event:
|
||||||
name: str
|
name: str
|
||||||
timestamp: datetime.datetime
|
timestamp: datetime.datetime
|
||||||
|
@ -69,10 +69,19 @@ void TORCH_API unregisterStat(Stat<double>* stat);
|
|||||||
void TORCH_API unregisterStat(Stat<int64_t>* stat);
|
void TORCH_API unregisterStat(Stat<int64_t>* stat);
|
||||||
} // namespace detail
|
} // namespace detail
|
||||||
|
|
||||||
// Stat is a base class for stats. These stats are used to compute summary
|
// Stat is used to compute summary statistics in a performant way over fixed
|
||||||
// statistics in a performant way over repeating intervals. When the window
|
// intervals. Stat logs the statistics as an Event once every `windowSize`
|
||||||
// closes the stats are logged via the event handlers as a `torch.monitor.Stat`
|
// duration. When the window closes the stats are logged via the event handlers
|
||||||
// event.
|
// as a `torch.monitor.Stat` event.
|
||||||
|
//
|
||||||
|
// `windowSize` should be set to something relatively high to avoid a huge
|
||||||
|
// number of events being logged. Ex: 60s. Stat uses millisecond precision.
|
||||||
|
//
|
||||||
|
// If maxSamples is set, the stat will cap the number of samples per window by
|
||||||
|
// discarding `add` calls once `maxSamples` adds have occurred. If it's not set,
|
||||||
|
// all `add` calls during the window will be included.
|
||||||
|
// This is an optional field to make aggregations more directly comparable
|
||||||
|
// across windows when the number of samples might vary.
|
||||||
//
|
//
|
||||||
// Stats support double and int64_t data types depending on what needs to be
|
// Stats support double and int64_t data types depending on what needs to be
|
||||||
// logged and needs to be templatized with one of them.
|
// logged and needs to be templatized with one of them.
|
||||||
@ -91,8 +100,27 @@ class Stat {
|
|||||||
};
|
};
|
||||||
|
|
||||||
public:
|
public:
|
||||||
Stat(std::string name, std::vector<Aggregation> aggregations)
|
Stat(
|
||||||
: name_(std::move(name)), aggregations_(merge(aggregations)) {
|
std::string name,
|
||||||
|
std::initializer_list<Aggregation> aggregations,
|
||||||
|
std::chrono::milliseconds windowSize,
|
||||||
|
int64_t maxSamples = std::numeric_limits<int64_t>::max())
|
||||||
|
: name_(std::move(name)),
|
||||||
|
aggregations_(merge(aggregations)),
|
||||||
|
windowSize_(windowSize),
|
||||||
|
maxSamples_(maxSamples) {
|
||||||
|
detail::registerStat(this);
|
||||||
|
}
|
||||||
|
|
||||||
|
Stat(
|
||||||
|
std::string name,
|
||||||
|
std::vector<Aggregation> aggregations,
|
||||||
|
std::chrono::milliseconds windowSize,
|
||||||
|
int64_t maxSamples = std::numeric_limits<int64_t>::max())
|
||||||
|
: name_(std::move(name)),
|
||||||
|
aggregations_(merge(aggregations)),
|
||||||
|
windowSize_(windowSize),
|
||||||
|
maxSamples_(maxSamples) {
|
||||||
detail::registerStat(this);
|
detail::registerStat(this);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -110,6 +138,10 @@ class Stat {
|
|||||||
std::lock_guard<std::mutex> guard(mu_);
|
std::lock_guard<std::mutex> guard(mu_);
|
||||||
maybeLogLocked();
|
maybeLogLocked();
|
||||||
|
|
||||||
|
if (alreadyLogged()) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
if (aggregations_.test(static_cast<int>(Aggregation::VALUE))) {
|
if (aggregations_.test(static_cast<int>(Aggregation::VALUE))) {
|
||||||
current_.value = v;
|
current_.value = v;
|
||||||
}
|
}
|
||||||
@ -150,7 +182,29 @@ class Stat {
|
|||||||
}
|
}
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
virtual void maybeLogLocked() = 0;
|
virtual uint64_t currentWindowId() const {
|
||||||
|
std::chrono::milliseconds now =
|
||||||
|
std::chrono::duration_cast<std::chrono::milliseconds>(
|
||||||
|
std::chrono::steady_clock::now().time_since_epoch());
|
||||||
|
|
||||||
|
// always returns a currentWindowId of at least 1 to avoid 0 window issues
|
||||||
|
return (now / windowSize_) + 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
bool alreadyLogged() {
|
||||||
|
return lastLoggedWindowId_ == currentWindowId();
|
||||||
|
}
|
||||||
|
|
||||||
|
void maybeLogLocked() {
|
||||||
|
auto windowId = currentWindowId();
|
||||||
|
bool shouldLog = windowId_ != windowId || current_.count >= maxSamples_;
|
||||||
|
if (shouldLog && !alreadyLogged()) {
|
||||||
|
logLocked();
|
||||||
|
lastLoggedWindowId_ = windowId_;
|
||||||
|
windowId_ = windowId;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
void logLocked() {
|
void logLocked() {
|
||||||
prev_ = current_;
|
prev_ = current_;
|
||||||
@ -215,72 +269,11 @@ class Stat {
|
|||||||
std::mutex mu_;
|
std::mutex mu_;
|
||||||
Values current_;
|
Values current_;
|
||||||
Values prev_;
|
Values prev_;
|
||||||
};
|
|
||||||
|
|
||||||
// IntervalStat is a Stat that logs the stat once every `windowSize` duration.
|
|
||||||
// This should be set to something relatively high to avoid a huge number of
|
|
||||||
// events being logged. Ex: 60s.
|
|
||||||
template <typename T>
|
|
||||||
class IntervalStat : public Stat<T> {
|
|
||||||
public:
|
|
||||||
IntervalStat(
|
|
||||||
std::string name,
|
|
||||||
std::initializer_list<Aggregation> aggregations,
|
|
||||||
std::chrono::milliseconds windowSize)
|
|
||||||
: Stat<T>(std::move(name), aggregations), windowSize_(windowSize) {}
|
|
||||||
|
|
||||||
IntervalStat(
|
|
||||||
std::string name,
|
|
||||||
std::vector<Aggregation> aggregations,
|
|
||||||
std::chrono::milliseconds windowSize)
|
|
||||||
: Stat<T>(std::move(name), aggregations), windowSize_(windowSize) {}
|
|
||||||
|
|
||||||
protected:
|
|
||||||
virtual uint64_t currentWindowId() const {
|
|
||||||
auto now = std::chrono::steady_clock::now().time_since_epoch();
|
|
||||||
return now / windowSize_;
|
|
||||||
}
|
|
||||||
|
|
||||||
private:
|
|
||||||
void maybeLogLocked() override {
|
|
||||||
auto windowId = currentWindowId();
|
|
||||||
if (windowId_ != windowId) {
|
|
||||||
Stat<T>::logLocked();
|
|
||||||
windowId_ = windowId;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
uint64_t windowId_{0};
|
uint64_t windowId_{0};
|
||||||
|
uint64_t lastLoggedWindowId_{0};
|
||||||
const std::chrono::milliseconds windowSize_;
|
const std::chrono::milliseconds windowSize_;
|
||||||
};
|
const int64_t maxSamples_;
|
||||||
|
|
||||||
// FixedCountStat is a Stat that logs the stat every `windowSize` number of add
|
|
||||||
// calls. For high performance stats this window size should be fairly large to
|
|
||||||
// ensure that the event logging frequency is in the range of 1s to 60s under
|
|
||||||
// normal usage. Core stats should error on the side of less frequent.
|
|
||||||
template <typename T>
|
|
||||||
class FixedCountStat : public Stat<T> {
|
|
||||||
public:
|
|
||||||
FixedCountStat(
|
|
||||||
std::string name,
|
|
||||||
std::initializer_list<Aggregation> aggregations,
|
|
||||||
int64_t windowSize)
|
|
||||||
: Stat<T>(std::move(name), aggregations), windowSize_(windowSize) {}
|
|
||||||
|
|
||||||
FixedCountStat(
|
|
||||||
std::string name,
|
|
||||||
std::vector<Aggregation> aggregations,
|
|
||||||
int64_t windowSize)
|
|
||||||
: Stat<T>(std::move(name), aggregations), windowSize_(windowSize) {}
|
|
||||||
|
|
||||||
private:
|
|
||||||
void maybeLogLocked() override {
|
|
||||||
if (Stat<T>::current_.count >= windowSize_) {
|
|
||||||
Stat<T>::logLocked();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
const int64_t windowSize_;
|
|
||||||
};
|
};
|
||||||
} // namespace monitor
|
} // namespace monitor
|
||||||
} // namespace torch
|
} // namespace torch
|
||||||
|
@ -133,7 +133,36 @@ void initMonitorBindings(PyObject* module) {
|
|||||||
m,
|
m,
|
||||||
"Stat",
|
"Stat",
|
||||||
R"DOC(
|
R"DOC(
|
||||||
Parent class for all aggregating stat implementations.
|
Stat is used to compute summary statistics in a performant way over
|
||||||
|
fixed intervals. Stat logs the statistics as an Event once every
|
||||||
|
``window_size`` duration. When the window closes the stats are logged
|
||||||
|
via the event handlers as a ``torch.monitor.Stat`` event.
|
||||||
|
|
||||||
|
``window_size`` should be set to something relatively high to avoid a
|
||||||
|
huge number of events being logged. Ex: 60s. Stat uses millisecond
|
||||||
|
precision.
|
||||||
|
|
||||||
|
If ``max_samples`` is set, the stat will cap the number of samples per
|
||||||
|
window by discarding `add` calls once ``max_samples`` adds have
|
||||||
|
occurred. If it's not set, all ``add`` calls during the window will be
|
||||||
|
included. This is an optional field to make aggregations more directly
|
||||||
|
comparable across windows when the number of samples might vary.
|
||||||
|
|
||||||
|
When the Stat is destructed it will log any remaining data even if the
|
||||||
|
window hasn't elapsed.
|
||||||
|
)DOC")
|
||||||
|
.def(
|
||||||
|
py::init<
|
||||||
|
std::string,
|
||||||
|
std::vector<Aggregation>,
|
||||||
|
std::chrono::milliseconds,
|
||||||
|
int64_t>(),
|
||||||
|
py::arg("name"),
|
||||||
|
py::arg("aggregations"),
|
||||||
|
py::arg("window_size"),
|
||||||
|
py::arg("max_samples") = std::numeric_limits<int64_t>::max(),
|
||||||
|
R"DOC(
|
||||||
|
Constructs the ``Stat``.
|
||||||
)DOC")
|
)DOC")
|
||||||
.def(
|
.def(
|
||||||
"add",
|
"add",
|
||||||
@ -165,47 +194,6 @@ void initMonitorBindings(PyObject* module) {
|
|||||||
once the event has been logged.
|
once the event has been logged.
|
||||||
)DOC");
|
)DOC");
|
||||||
|
|
||||||
py::class_<IntervalStat<double>, Stat<double>>(
|
|
||||||
m,
|
|
||||||
"IntervalStat",
|
|
||||||
R"DOC(
|
|
||||||
IntervalStat is a Stat that logs once every ``window_size`` duration. This
|
|
||||||
should be set to something relatively high to avoid a huge number of
|
|
||||||
events being logged. Ex: 60s.
|
|
||||||
The stat will be logged as an event on the next ``add`` call after the
|
|
||||||
window ends.
|
|
||||||
)DOC")
|
|
||||||
.def(
|
|
||||||
py::init<
|
|
||||||
std::string,
|
|
||||||
std::vector<Aggregation>,
|
|
||||||
std::chrono::milliseconds>(),
|
|
||||||
py::arg("name"),
|
|
||||||
py::arg("aggregations"),
|
|
||||||
py::arg("window_size"),
|
|
||||||
R"DOC(
|
|
||||||
Constructs the ``IntervalStat``.
|
|
||||||
)DOC");
|
|
||||||
|
|
||||||
py::class_<FixedCountStat<double>, Stat<double>>(
|
|
||||||
m,
|
|
||||||
"FixedCountStat",
|
|
||||||
R"DOC(
|
|
||||||
FixedCountStat is a Stat that logs every ``window_size`` number of
|
|
||||||
``add`` calls. For high performance stats this window size should be
|
|
||||||
fairly large to ensure that the event logging frequency is in the range
|
|
||||||
of 1s to 60s under normal usage. Core stats should error on the side of
|
|
||||||
logging less frequently.
|
|
||||||
)DOC")
|
|
||||||
.def(
|
|
||||||
py::init<std::string, std::vector<Aggregation>, int64_t>(),
|
|
||||||
py::arg("name"),
|
|
||||||
py::arg("aggregations"),
|
|
||||||
py::arg("window_size"),
|
|
||||||
R"DOC(
|
|
||||||
Constructs the ``FixedCountStat``.
|
|
||||||
)DOC");
|
|
||||||
|
|
||||||
py::class_<Event>(
|
py::class_<Event>(
|
||||||
m,
|
m,
|
||||||
"Event",
|
"Event",
|
||||||
|
Reference in New Issue
Block a user