#include #include #include #include namespace torch { namespace monitor { const char* aggregationName(Aggregation agg) { switch (agg) { case Aggregation::NONE: return "none"; case Aggregation::VALUE: return "value"; case Aggregation::MEAN: return "mean"; case Aggregation::COUNT: return "count"; case Aggregation::SUM: return "sum"; case Aggregation::MAX: return "max"; case Aggregation::MIN: return "min"; default: throw std::runtime_error( "unknown aggregation: " + std::to_string(static_cast(agg))); } } namespace { struct Stats { std::mutex mu; std::unordered_set*> doubles; std::unordered_set*> int64s; }; Stats& stats() { static Stats stats; return stats; } } // namespace namespace detail { void registerStat(Stat* stat) { std::lock_guard guard(stats().mu); stats().doubles.insert(stat); } void registerStat(Stat* stat) { std::lock_guard guard(stats().mu); stats().int64s.insert(stat); } void unregisterStat(Stat* stat) { std::lock_guard guard(stats().mu); stats().doubles.erase(stat); } void unregisterStat(Stat* stat) { std::lock_guard guard(stats().mu); stats().int64s.erase(stat); } } // namespace detail } // namespace monitor } // namespace torch