#pragma once #include #include #include #include #include #include #include namespace torch { namespace monitor { constexpr int NUM_AGGREGATIONS = 7; // Aggregation is the list of possible aggregations for Stats. // These use bitwise flags so they can be efficiently stored. enum class C10_API_ENUM Aggregation { // NONE means no aggregations are set. NONE = 0, // VALUE exports the most recently set value. VALUE = 1, // MEAN computes the mean of the set values within the window. Zero if no // values. MEAN = 2, // COUNT tracks the number of times a value is set within the window. COUNT = 3, // SUM computes the sum of the values set within the window. SUM = 4, // MIN computes the minimum of the values set within the window. Zero if no // values. MAX = 5, // MAX computes the maximum of the values set within the window. Zero if no // values. MIN = 6, }; struct TORCH_API AggregationHash { template std::size_t operator()(T t) const { return static_cast(t); } }; // aggregationName returns the human readable name corresponding to the // aggregation. TORCH_API const char* aggregationName(Aggregation agg); template class Stat; namespace { template inline std::bitset merge(T& list) { std::bitset a; for (Aggregation b : list) { a.set(static_cast(b)); } return a; } } // namespace namespace detail { void TORCH_API registerStat(Stat* stat); void TORCH_API registerStat(Stat* stat); void TORCH_API unregisterStat(Stat* stat); void TORCH_API unregisterStat(Stat* stat); } // namespace detail // Stat is used to compute summary statistics in a performant way over fixed // intervals. Stat logs the statistics as an Event once every `windowSize` // duration. When the window closes the stats are logged via the event handlers // as a `torch.monitor.Stat` event. // // `windowSize` should be set to something relatively high to avoid a huge // number of events being logged. Ex: 60s. Stat uses millisecond precision. // // If maxSamples is set, the stat will cap the number of samples per window by // discarding `add` calls once `maxSamples` adds have occurred. If it's not set, // all `add` calls during the window will be included. // This is an optional field to make aggregations more directly comparable // across windows when the number of samples might vary. // // Stats support double and int64_t data types depending on what needs to be // logged and needs to be templatized with one of them. // // When the Stat is destructed it will log any remaining data even if the window // hasn't elapsed. template class Stat { private: struct Values { T value{0}; T sum{0}; T min{0}; T max{0}; int64_t count{0}; }; public: Stat( std::string name, std::initializer_list aggregations, std::chrono::milliseconds windowSize, int64_t maxSamples = std::numeric_limits::max()) : name_(std::move(name)), aggregations_(merge(aggregations)), windowSize_(windowSize), maxSamples_(maxSamples) { detail::registerStat(this); } Stat( std::string name, std::vector aggregations, std::chrono::milliseconds windowSize, int64_t maxSamples = std::numeric_limits::max()) : name_(std::move(name)), aggregations_(merge(aggregations)), windowSize_(windowSize), maxSamples_(maxSamples) { detail::registerStat(this); } virtual ~Stat() { { // on destruction log if there's unlogged data std::lock_guard guard(mu_); logLocked(); } detail::unregisterStat(this); } // add adds the value v to the current window. void add(T v) { std::lock_guard guard(mu_); maybeLogLocked(); if (alreadyLogged()) { return; } if (aggregations_.test(static_cast(Aggregation::VALUE))) { current_.value = v; } if (aggregations_.test(static_cast(Aggregation::MEAN)) || aggregations_.test(static_cast(Aggregation::SUM))) { current_.sum += v; } if (aggregations_.test(static_cast(Aggregation::MAX))) { if (current_.max < v || current_.count == 0) { current_.max = v; } } if (aggregations_.test(static_cast(Aggregation::MIN))) { if (current_.min > v || current_.count == 0) { current_.min = v; } } current_.count += 1; maybeLogLocked(); } const std::string& name() const noexcept { return name_; } // count returns the number of items in the current open window. int64_t count() noexcept { std::lock_guard guard(mu_); return current_.count; } std::unordered_map get() noexcept { std::lock_guard guard(mu_); return getLocked(); } protected: virtual uint64_t currentWindowId() const { std::chrono::milliseconds now = std::chrono::duration_cast( std::chrono::steady_clock::now().time_since_epoch()); // always returns a currentWindowId of at least 1 to avoid 0 window issues return (now / windowSize_) + 1; } private: bool alreadyLogged() { return lastLoggedWindowId_ == currentWindowId(); } void maybeLogLocked() { auto windowId = currentWindowId(); bool shouldLog = windowId_ != windowId || current_.count >= maxSamples_; if (shouldLog && !alreadyLogged()) { logLocked(); lastLoggedWindowId_ = windowId_; windowId_ = windowId; } } void logLocked() { prev_ = current_; current_ = Values(); // don't log event if there's no data if (prev_.count == 0) { return; } Event e; e.name = "torch.monitor.Stat"; e.timestamp = std::chrono::system_clock::now(); auto stats = getLocked(); e.data.reserve(stats.size()); for (auto& kv : stats) { std::stringstream key; key << name_; key << "."; key << aggregationName(kv.first); e.data[key.str()] = kv.second; } logEvent(e); } std::unordered_map getLocked() const noexcept { std::unordered_map out; out.reserve(aggregations_.count()); if (aggregations_.test(static_cast(Aggregation::VALUE))) { out.emplace(Aggregation::VALUE, prev_.value); } if (aggregations_.test(static_cast(Aggregation::MEAN))) { if (prev_.count == 0) { out.emplace(Aggregation::MEAN, 0); } else { out.emplace(Aggregation::MEAN, prev_.sum / prev_.count); } } if (aggregations_.test(static_cast(Aggregation::COUNT))) { out.emplace(Aggregation::COUNT, prev_.count); } if (aggregations_.test(static_cast(Aggregation::SUM))) { out.emplace(Aggregation::SUM, prev_.sum); } if (aggregations_.test(static_cast(Aggregation::MAX))) { out.emplace(Aggregation::MAX, prev_.max); } if (aggregations_.test(static_cast(Aggregation::MIN))) { out.emplace(Aggregation::MIN, prev_.min); } return out; } const std::string name_; const std::bitset aggregations_; std::mutex mu_; Values current_; Values prev_; uint64_t windowId_{0}; uint64_t lastLoggedWindowId_{0}; const std::chrono::milliseconds windowSize_; const int64_t maxSamples_; }; } // namespace monitor } // namespace torch