mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-22 22:25:10 +08:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/72009 This simplifies the Stats interface by merging IntervalStat and FixedCountStat into a single Stat with a specific window duration and an optional maximum number of samples per window. This preserves the original intention of having comparably sized windows (for statistical purposes) while also providing a consistent output bandwidth. Test Plan: `buck test //caffe2/test:monitor //caffe2/test/cpp/monitor:monitor` Reviewed By: kiukchung Differential Revision: D33822956 fbshipit-source-id: a74782492421be613a1a8b14341b6fb2e8eeb8b4 (cherry picked from commit 293b94e0b4646521ffe047e5222c4bba7e688464)
306 lines
6.6 KiB
C++
306 lines
6.6 KiB
C++
#include <gtest/gtest.h>
|
|
|
|
#include <thread>
|
|
|
|
#include <torch/csrc/monitor/counters.h>
|
|
#include <torch/csrc/monitor/events.h>
|
|
|
|
using namespace torch::monitor;
|
|
|
|
TEST(MonitorTest, CounterDouble) {
  using Snapshot = std::unordered_map<Aggregation, double, AggregationHash>;

  // maxSamples is 2, so the second sample closes the window and the count of
  // pending samples drops back to zero.
  Stat<double> counter{
      "a",
      {Aggregation::MEAN, Aggregation::COUNT},
      std::chrono::milliseconds(100000),
      2,
  };

  counter.add(5.0);
  ASSERT_EQ(counter.count(), 1);
  counter.add(6.0);
  ASSERT_EQ(counter.count(), 0);

  // Aggregates of the just-closed window: mean of {5.0, 6.0} and its size.
  const auto actual = counter.get();
  const Snapshot expected = {
      {Aggregation::MEAN, 5.5},
      {Aggregation::COUNT, 2.0},
  };
  ASSERT_EQ(actual, expected);
}
|
|
|
|
TEST(MonitorTest, CounterInt64Sum) {
  using Snapshot = std::unordered_map<Aggregation, int64_t, AggregationHash>;

  // maxSamples is 2, so the two samples below fill and close the window.
  Stat<int64_t> counter{
      "a",
      {Aggregation::SUM},
      std::chrono::milliseconds(100000),
      2,
  };
  counter.add(5);
  counter.add(6);

  const auto actual = counter.get();
  const Snapshot expected = {
      {Aggregation::SUM, 11},
  };
  ASSERT_EQ(actual, expected);
}
|
|
|
|
TEST(MonitorTest, CounterInt64Value) {
  using Snapshot = std::unordered_map<Aggregation, int64_t, AggregationHash>;

  Stat<int64_t> counter{
      "a",
      {Aggregation::VALUE},
      std::chrono::milliseconds(100000),
      2,
  };
  counter.add(5);
  counter.add(6);

  // VALUE is expected to report the most recent sample, not an aggregate.
  const auto actual = counter.get();
  const Snapshot expected = {
      {Aggregation::VALUE, 6},
  };
  ASSERT_EQ(actual, expected);
}
|
|
|
|
TEST(MonitorTest, CounterInt64Mean) {
  using Snapshot = std::unordered_map<Aggregation, int64_t, AggregationHash>;

  Stat<int64_t> counter{
      "a",
      {Aggregation::MEAN},
      std::chrono::milliseconds(100000),
      2,
  };

  // Before any samples are recorded the mean is expected to read as 0 rather
  // than dividing by zero.
  const auto beforeSamples = counter.get();
  const Snapshot emptyExpected = {
      {Aggregation::MEAN, 0},
  };
  ASSERT_EQ(beforeSamples, emptyExpected);

  counter.add(0);
  counter.add(10);

  // Integer mean of {0, 10} once the two-sample window has closed.
  const auto afterSamples = counter.get();
  const Snapshot filledExpected = {
      {Aggregation::MEAN, 5},
  };
  ASSERT_EQ(afterSamples, filledExpected);
}
|
|
|
|
TEST(MonitorTest, CounterInt64Count) {
  using Snapshot = std::unordered_map<Aggregation, int64_t, AggregationHash>;

  Stat<int64_t> counter{
      "a",
      {Aggregation::COUNT},
      std::chrono::milliseconds(100000),
      2,
  };

  // count() tracks samples in the open window; the second sample reaches
  // maxSamples, closes the window, and resets it to zero.
  ASSERT_EQ(counter.count(), 0);
  counter.add(0);
  ASSERT_EQ(counter.count(), 1);
  counter.add(10);
  ASSERT_EQ(counter.count(), 0);

  const auto actual = counter.get();
  const Snapshot expected = {
      {Aggregation::COUNT, 2},
  };
  ASSERT_EQ(actual, expected);
}
|
|
|
|
TEST(MonitorTest, CounterInt64MinMax) {
  using Snapshot = std::unordered_map<Aggregation, int64_t, AggregationHash>;

  // maxSamples is 6: exactly the number of samples added below.
  Stat<int64_t> counter{
      "a",
      {Aggregation::MIN, Aggregation::MAX},
      std::chrono::milliseconds(100000),
      6,
  };

  // Before any samples are recorded both extremes are expected to read as 0.
  const auto beforeSamples = counter.get();
  const Snapshot zeros = {
      {Aggregation::MAX, 0},
      {Aggregation::MIN, 0},
  };
  ASSERT_EQ(beforeSamples, zeros);

  counter.add(0);
  counter.add(5);
  counter.add(-5);
  counter.add(-6);
  counter.add(9);
  counter.add(2);

  // Mixed-sign samples: extremes are -6 and 9.
  const auto afterSamples = counter.get();
  const Snapshot extremes = {
      {Aggregation::MAX, 9},
      {Aggregation::MIN, -6},
  };
  ASSERT_EQ(afterSamples, extremes);
}
|
|
|
|
TEST(MonitorTest, CounterInt64WindowSize) {
  // The 100s window duration is far longer than this test runs, so the window
  // is closed by reaching maxSamples (the 4th constructor argument), not by
  // time.
  Stat<int64_t> a{
      "a",
      {Aggregation::COUNT, Aggregation::SUM},
      std::chrono::milliseconds(100000),
      /*maxSamples=*/3,
  };
  a.add(1);
  a.add(2);
  ASSERT_EQ(a.count(), 2);
  a.add(3);
  ASSERT_EQ(a.count(), 0);

  // Once the window has hit maxSamples, further samples in the same window
  // are dropped: count() stays at zero and the reported aggregates below
  // (COUNT 3, SUM 6) do not include the 4.
  a.add(4);
  ASSERT_EQ(a.count(), 0);

  auto stats = a.get();
  std::unordered_map<Aggregation, int64_t, AggregationHash> want = {
      {Aggregation::COUNT, 3},
      {Aggregation::SUM, 6},
  };
  ASSERT_EQ(stats, want);
}
|
|
|
|
TEST(MonitorTest, CounterInt64WindowSizeHuge) {
  // Same scenario as CounterInt64WindowSize, but with a ~10 year window
  // duration (converted to std::chrono::milliseconds). This checks that a very
  // large duration does not change the maxSamples behavior.
  Stat<int64_t> a{
      "a",
      {Aggregation::COUNT, Aggregation::SUM},
      std::chrono::hours(24 * 365 * 10), // 10 years
      /*maxSamples=*/3,
  };
  a.add(1);
  a.add(2);
  ASSERT_EQ(a.count(), 2);
  a.add(3);
  ASSERT_EQ(a.count(), 0);

  // Once the window has hit maxSamples, further samples in the same window
  // are dropped: count() stays at zero and the reported aggregates below
  // (COUNT 3, SUM 6) do not include the 4.
  a.add(4);
  ASSERT_EQ(a.count(), 0);

  auto stats = a.get();
  std::unordered_map<Aggregation, int64_t, AggregationHash> want = {
      {Aggregation::COUNT, 3},
      {Aggregation::SUM, 6},
  };
  ASSERT_EQ(stats, want);
}
|
|
|
|
template <typename T>
|
|
struct TestStat : public Stat<T> {
|
|
uint64_t mockWindowId{1};
|
|
|
|
TestStat(
|
|
std::string name,
|
|
std::initializer_list<Aggregation> aggregations,
|
|
std::chrono::milliseconds windowSize,
|
|
int64_t maxSamples = std::numeric_limits<int64_t>::max())
|
|
: Stat<T>(name, aggregations, windowSize, maxSamples) {}
|
|
|
|
uint64_t currentWindowId() const override {
|
|
return mockWindowId;
|
|
}
|
|
};
|
|
|
|
struct AggregatingEventHandler : public EventHandler {
|
|
std::vector<Event> events;
|
|
|
|
void handle(const Event& e) override {
|
|
events.emplace_back(e);
|
|
}
|
|
};
|
|
|
|
/// Scoped registration of an event handler for the duration of a test.
///
/// Constructs a fresh `T` and passes it to registerEventHandler(); the
/// destructor passes the same pointer to unregisterEventHandler().
///
/// Copying is disabled. Every copy would share `handler`, so each copy's
/// destructor would unregister the same handler again. Because the destructor
/// is user-declared, no implicit move exists either, so an attempted move
/// would also fall back to this deleted copy.
template <typename T>
struct HandlerGuard {
  std::shared_ptr<T> handler;

  HandlerGuard() : handler(std::make_shared<T>()) {
    registerEventHandler(handler);
  }

  HandlerGuard(const HandlerGuard&) = delete;
  HandlerGuard& operator=(const HandlerGuard&) = delete;

  ~HandlerGuard() {
    unregisterEventHandler(handler);
  }
};
|
|
|
|
TEST(MonitorTest, Stat) {
  // Uses the real clock with a 1ms window. Sleeping 2ms between samples should
  // place them in different windows (assuming window ids advance with
  // wall-clock time), but exact timing is not guaranteed, so the assertions
  // below only bound counts and event totals.
  HandlerGuard<AggregatingEventHandler> guard;
  const auto& received = guard.handler->events;

  Stat<int64_t> counter{
      "a",
      {Aggregation::COUNT, Aggregation::SUM},
      std::chrono::milliseconds(1),
  };
  ASSERT_EQ(received.size(), 0);

  counter.add(1);
  ASSERT_LE(counter.count(), 1);

  std::this_thread::sleep_for(std::chrono::milliseconds(2));
  counter.add(2);
  ASSERT_LE(counter.count(), 1);

  ASSERT_GE(received.size(), 1);
  ASSERT_LE(received.size(), 2);
}
|
|
|
|
TEST(MonitorTest, StatEvent) {
  // TestStat overrides currentWindowId(), so window rollover here is driven
  // by mockWindowId rather than by the 1ms wall-clock window.
  HandlerGuard<AggregatingEventHandler> guard;

  TestStat<int64_t> a{
      "a",
      {Aggregation::COUNT, Aggregation::SUM},
      std::chrono::milliseconds(1),
  };
  ASSERT_EQ(guard.handler->events.size(), 0);

  a.add(1);
  ASSERT_EQ(a.count(), 1);
  a.add(2);
  ASSERT_EQ(a.count(), 2);
  ASSERT_EQ(guard.handler->events.size(), 0);

  // Moving to a new window id makes the next add() flush the previous window
  // (samples 1 and 2) as a single event.
  a.mockWindowId = 100;

  a.add(3);
  ASSERT_LE(a.count(), 1);

  ASSERT_EQ(guard.handler->events.size(), 1);
  Event e = guard.handler->events.at(0);
  ASSERT_EQ(e.name, "torch.monitor.Stat");
  ASSERT_NE(e.timestamp, std::chrono::system_clock::time_point{});
  // Spell the expected values as int64_t rather than `L` literals. `3L` is
  // `long`, which is a different type from int64_t on LLP64 platforms (e.g.
  // Windows, where long is 32 bits). There it would not portably select the
  // same data_value_t alternative as the int64_t aggregates a Stat<int64_t>
  // presumably stores.
  std::unordered_map<std::string, data_value_t> data{
      {"a.sum", int64_t{3}},
      {"a.count", int64_t{2}},
  };
  ASSERT_EQ(e.data, data);
}
|
|
|
|
TEST(MonitorTest, StatEventDestruction) {
  // A 10 hour window never rolls over during the test, so the only event is
  // expected to come from the Stat being destroyed with pending samples.
  HandlerGuard<AggregatingEventHandler> guard;

  {
    TestStat<int64_t> a{
        "a",
        {Aggregation::COUNT, Aggregation::SUM},
        std::chrono::hours(10),
    };
    a.add(1);
    ASSERT_EQ(a.count(), 1);
    ASSERT_EQ(guard.handler->events.size(), 0);
  }
  ASSERT_EQ(guard.handler->events.size(), 1);

  Event e = guard.handler->events.at(0);
  ASSERT_EQ(e.name, "torch.monitor.Stat");
  ASSERT_NE(e.timestamp, std::chrono::system_clock::time_point{});
  // int64_t instead of `1L`: `long` is not int64_t on LLP64 platforms (e.g.
  // Windows), so an `L` literal would not portably match the int64_t
  // aggregates a Stat<int64_t> presumably stores in data_value_t.
  std::unordered_map<std::string, data_value_t> data{
      {"a.sum", int64_t{1}},
      {"a.count", int64_t{1}},
  };
  ASSERT_EQ(e.data, data);
}
|