mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Summary: This defines a generic counters API that users can utilize to provide monitoring functionality in, e.g., a production service. We expose both counters for runtime internals and a TorchScript API to create user-defined counters.

Synopsis of the API:
- `torch/csrc/jit/script/logging.h` specifies the externally facing API in C++.
- `torch/jit/_logging.py` specifies the Python API.

We use an interface, `LoggerBase`, to define the interactions between users and a logging backend. Implementing a subclass of `LoggerBase` allows the user to handle these events in a custom way, such as logging into a DB or calling into an infra-specific counters API.

From the frontend perspective, we can create log events in two ways:
1. We provide an `add_stat_value(name, val)` function. This calls into the Logger backend with a key/value pair. For example, we might call `add_stat_value('foo', 1)` to bump an event counter.
2. We provide a `time_point()` function to record a timestamp in nanoseconds. This can be used in conjunction with `add_stat_value` to record runtime wall-clock durations.

Examples of frontend usage can be found in `TestLogging` in `test_jit.py`.

We provide a trivial `LockingLogger` implementation as an example and for testing purposes. It is likely not ready for production usage. It demonstrates that a backend implementing the API can do things like specify aggregation types and report these aggregate stats via the `get_counters()` API.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/18235
Differential Revision: D14545060
Pulled By: jamesr66a
fbshipit-source-id: 04099543a1898cfdd411511e46e03d5dce9b4881
74 lines
1.9 KiB
C++
74 lines
1.9 KiB
C++
#include "torch/csrc/jit/script/logging.h"

#include <atomic>
#include <chrono>
#include <mutex>
#include <stdexcept>
#include <string>
#include <unordered_map>
|
|
|
|
namespace torch {
|
|
namespace jit {
|
|
namespace logging {
|
|
|
|
// TODO: add a multi-scale histogram aggregation for LockingLogger stats
// (currently only SUM and AVG aggregation are supported).
|
|
|
|
void LockingLogger::addStatValue(const std::string& stat_name, int64_t val) {
|
|
std::unique_lock<std::mutex> lk(m);
|
|
auto& raw_counter = raw_counters[stat_name];
|
|
raw_counter.sum += val;
|
|
raw_counter.count++;
|
|
}
|
|
|
|
// Returns the aggregated value of the stat `name`, or 0 if no sample has been
// recorded under that name.
//
// Aggregation is SUM unless setAggregationType() selected a different type
// for `name`. AVG is the truncating integer mean sum / count; a zero count
// yields 0.
//
// Throws std::runtime_error if the configured aggregation type is not one of
// the cases handled below.
TORCH_API int64_t LockingLogger::getCounterValue(const std::string& name) const {
  std::unique_lock<std::mutex> lk(m);

  // find() hashes the key once, instead of count() followed by at().
  const auto counter_it = raw_counters.find(name);
  if (counter_it == raw_counters.end()) {
    return 0;
  }
  const auto& raw_counter = counter_it->second;

  const auto type_it = agg_types.find(name);
  const AggregationType type =
      type_it != agg_types.end() ? type_it->second : AggregationType::SUM;

  switch (type) {
    case AggregationType::SUM:
      return raw_counter.sum;
    case AggregationType::AVG: {
      // Divide in signed 64-bit arithmetic. If `count` is an unsigned type
      // (e.g. size_t -- confirm in logging.h), a plain `sum / count` converts
      // a negative sum to unsigned and returns a garbage average.
      const auto count = static_cast<int64_t>(raw_counter.count);
      if (count == 0) {
        return 0;
      }
      return raw_counter.sum / count;
    }
  }
  throw std::runtime_error("Unknown aggregation type!");
}
|
|
|
|
void LockingLogger::setAggregationType(
|
|
const std::string& stat_name,
|
|
AggregationType type) {
|
|
agg_types[stat_name] = type;
|
|
}
|
|
|
|
|
|
// Process-wide logger read by getLogger() and replaced by setLogger().
// It starts out as a NoopLogger that this file never deletes; setLogger()
// hands the replaced pointer back to its caller.
std::atomic<LoggerBase*> global_logger(new NoopLogger());
|
|
|
|
// Returns the logger currently installed in `global_logger`, read with a
// sequentially consistent atomic load.
LoggerBase* getLogger() {
  LoggerBase* const current = global_logger.load(std::memory_order_seq_cst);
  return current;
}
|
|
|
|
// Atomically installs `logger` as the global logger and returns the logger
// it replaced. This function does not delete the returned pointer; what to
// do with it is up to the caller.
LoggerBase* setLogger(LoggerBase* logger) {
  // A single atomic exchange both stores `logger` and yields the prior
  // value, so no compare-and-swap retry loop is needed.
  return global_logger.exchange(logger);
}
|
|
|
|
// Captures "now" from std::chrono::high_resolution_clock, wrapped in a
// JITTimePoint for later use with recordDurationSince().
JITTimePoint timePoint() {
  const auto now = std::chrono::high_resolution_clock::now();
  return JITTimePoint{now};
}
|
|
|
|
void recordDurationSince(const std::string& name, JITTimePoint tp) {
|
|
auto end = std::chrono::high_resolution_clock::now();
|
|
// Measurement in microseconds.
|
|
auto seconds = std::chrono::duration<double>(end - tp.point).count() * 1e9;
|
|
logging::getLogger()->addStatValue(name, seconds);
|
|
}
|
|
|
|
} // namespace logging
|
|
} // namespace jit
|
|
} // namespace torch
|