pytorch/torch/csrc/jit/script/logging.cpp
James Reed 85f36014e2 Experimental logging/counters API (#18235)
Summary:
This defines a generic counters API that users can use to provide monitoring functionality in, for example, a production service. We expose both counters for runtime internals and a TorchScript API for creating user-defined counters. Synopsis of the API:

- `torch/csrc/jit/script/logging.h` specifies the externally-facing API in C++
- `torch/jit/_logging.py` specifies the Python API

We use an interface, `LoggerBase`, to define the interactions between users and a logging backend. Implementing a subclass of `LoggerBase` allows the user to handle these events in a custom way, such as logging into a DB or calling into an infra-specific counters API.
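
As a rough sketch of what such a subclass might look like, the block below pairs a hypothetical stand-in for `LoggerBase` with a hypothetical `RecordingLogger` that appends every event to an in-memory list, standing in for a DB write or an infra-specific counters call. The real declaration lives in `logging.h` and is not reproduced here; the only assumption is the `addStatValue(name, val)` hook that `logging.cpp` calls.

```cpp
#include <cassert>
#include <cstdint>
#include <mutex>
#include <string>
#include <utility>
#include <vector>

// Hypothetical stand-in for the LoggerBase interface: one virtual hook per event.
class LoggerBase {
 public:
  virtual void addStatValue(const std::string& stat_name, int64_t val) = 0;
  virtual ~LoggerBase() = default;
};

// Hypothetical backend: keeps every (name, value) event instead of aggregating.
// A production backend might forward each event to a database or metrics service.
class RecordingLogger : public LoggerBase {
 public:
  void addStatValue(const std::string& stat_name, int64_t val) override {
    std::lock_guard<std::mutex> lk(m_);  // events may arrive from several threads
    events_.emplace_back(stat_name, val);
  }

  std::vector<std::pair<std::string, int64_t>> events() const {
    std::lock_guard<std::mutex> lk(m_);
    return events_;  // return a copy so callers never observe a concurrent write
  }

 private:
  mutable std::mutex m_;
  std::vector<std::pair<std::string, int64_t>> events_;
};
```

Code that holds only a `LoggerBase*` can report `addStatValue("foo", 1)` without knowing which backend is installed.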

From the frontend perspective, we can create log events in two ways:
1. We provide an `add_stat_value(name, val)` function. This calls into the Logger backend with a key/value pair. For example, we might call `add_stat_value('foo', 1)` to bump an event counter.
2. We provide a `time_point()` function to record a timestamp in nanoseconds. This can be used in conjunction with `add_stat_value` to record runtime wall clock durations.
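
The timing pattern from item 2 can be sketched in C++ roughly as follows. `g_stats`, `addStatValue`, `timePointNs`, and `timed` are hypothetical illustrative names, not part of the PyTorch API, and the in-memory map stands in for a real backend. The sketch uses `std::chrono::steady_clock`, which is guaranteed to be monotonic, whereas `logging.cpp` uses `high_resolution_clock`.

```cpp
#include <cassert>
#include <chrono>
#include <cstdint>
#include <string>
#include <unordered_map>

// Hypothetical stand-in for the backend: sums every value reported under a name.
std::unordered_map<std::string, int64_t> g_stats;

void addStatValue(const std::string& name, int64_t val) {
  g_stats[name] += val;
}

// Hypothetical analogue of time_point(): a monotonic timestamp in nanoseconds.
int64_t timePointNs() {
  return std::chrono::duration_cast<std::chrono::nanoseconds>(
             std::chrono::steady_clock::now().time_since_epoch())
      .count();
}

// Time an arbitrary piece of work and report its wall-clock duration.
template <typename Fn>
void timed(const std::string& name, Fn&& work) {
  int64_t start = timePointNs();
  work();
  addStatValue(name, timePointNs() - start);  // duration in nanoseconds
}
```

Taking both timestamps from the same monotonic clock keeps the reported duration non-negative even if the system time is adjusted mid-measurement.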

Examples of frontend usage can be found in `test_jit.py TestLogging`.

We provide a trivial `LockingLogger` implementation as an example and for testing purposes. It is likely not ready for production usage. It demonstrates that a backend implementing the API can do things like specify aggregation types and report these aggregate stats via the `get_counters()` API.
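
To make the two aggregation types concrete, here is a sketch of the reduction that `LockingLogger::getCounterValue()` performs, without the locking and map lookup. `RawCounter` and `aggregate` are hypothetical names. Only the `sum`/`count` fields and the `SUM`/`AVG` enumerators come from `logging.cpp`, and the zero-count guard is this sketch's own addition.

```cpp
#include <cassert>
#include <cstdint>

// Mirrors the {sum, count} pair that LockingLogger keeps per stat name.
struct RawCounter {
  int64_t sum = 0;
  int64_t count = 0;
  void add(int64_t val) {
    sum += val;
    ++count;
  }
};

enum class AggregationType { SUM, AVG };

// Same reduction as getCounterValue(): SUM reports the total, AVG the mean.
int64_t aggregate(const RawCounter& c, AggregationType type) {
  switch (type) {
    case AggregationType::SUM:
      return c.sum;
    case AggregationType::AVG:
      return c.count == 0 ? 0 : c.sum / c.count;  // integer division truncates
  }
  return 0;  // unreachable for valid enum values
}
```

Because AVG uses integer division, a stat that received the values 1 and 2 reports an average of 1, not 1.5. A backend that needs fractional averages would have to expose the raw sum and count or report a scaled value.
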
Pull Request resolved: https://github.com/pytorch/pytorch/pull/18235

Differential Revision: D14545060

Pulled By: jamesr66a

fbshipit-source-id: 04099543a1898cfdd411511e46e03d5dce9b4881
2019-03-29 17:14:03 -07:00

#include "torch/csrc/jit/script/logging.h"
#include <atomic>
#include <chrono>
#include <mutex>
#include <stdexcept>
#include <string>
#include <unordered_map>
namespace torch {
namespace jit {
namespace logging {

// TODO: multi-scale histogram for this thing
void LockingLogger::addStatValue(const std::string& stat_name, int64_t val) {
  std::unique_lock<std::mutex> lk(m);
  auto& raw_counter = raw_counters[stat_name];
  raw_counter.sum += val;
  raw_counter.count++;
}

TORCH_API int64_t LockingLogger::getCounterValue(const std::string& name) const {
  std::unique_lock<std::mutex> lk(m);
  if (!raw_counters.count(name)) {
    return 0;
  }
  AggregationType type = agg_types.count(name) ? agg_types.at(name)
                                               : AggregationType::SUM;
  const auto& raw_counter = raw_counters.at(name);
  switch (type) {
    case AggregationType::SUM:
      return raw_counter.sum;
    case AggregationType::AVG:
      // count is at least 1 here: every stored counter was created by addStatValue().
      return raw_counter.sum / raw_counter.count;
  }
  throw std::runtime_error("Unknown aggregation type!");
}
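
void LockingLogger::setAggregationType(
    const std::string& stat_name,
    AggregationType type) {
  // Guard agg_types with the same mutex that getCounterValue() holds while reading it.
  std::unique_lock<std::mutex> lk(m);
  agg_types[stat_name] = type;
}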

std::atomic<LoggerBase*> global_logger{new NoopLogger()};

LoggerBase* getLogger() {
  return global_logger.load();
}

LoggerBase* setLogger(LoggerBase* logger) {
  // Atomically install the new logger and return the one it replaced.
  return global_logger.exchange(logger);
}

JITTimePoint timePoint() {
  return JITTimePoint{std::chrono::high_resolution_clock::now()};
}

void recordDurationSince(const std::string& name, JITTimePoint tp) {
  auto end = std::chrono::high_resolution_clock::now();
  // Measurement in nanoseconds.
  auto nanos =
      std::chrono::duration_cast<std::chrono::nanoseconds>(end - tp.point)
          .count();
  logging::getLogger()->addStatValue(name, nanos);
}
} // namespace logging
} // namespace jit
} // namespace torch
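
Because `setLogger` returns the logger it replaced, a caller can install a backend temporarily and restore the previous one afterwards. The sketch below assumes hypothetical stand-ins for `LoggerBase` and `NoopLogger` (the default here is a static object rather than the heap-allocated `NoopLogger` above). `ScopedLogger` and `CountingLogger` are illustrative names, not part of PyTorch. Its `setLogger` uses `std::atomic::exchange`, which gives the same swap-and-return-previous behavior as `setLogger` above.

```cpp
#include <atomic>
#include <cassert>
#include <cstdint>
#include <string>

// Hypothetical stand-ins for the interface and the default no-op backend.
class LoggerBase {
 public:
  virtual void addStatValue(const std::string& stat_name, int64_t val) = 0;
  virtual ~LoggerBase() = default;
};

class NoopLogger : public LoggerBase {
 public:
  void addStatValue(const std::string&, int64_t) override {}
};

NoopLogger g_noop;
std::atomic<LoggerBase*> global_logger{&g_noop};

LoggerBase* getLogger() { return global_logger.load(); }

// Atomically install a new logger and hand back the one it replaced.
LoggerBase* setLogger(LoggerBase* logger) { return global_logger.exchange(logger); }

// Hypothetical RAII helper: installs a logger for the lifetime of a scope,
// then restores whatever was installed before (handy in tests).
class ScopedLogger {
 public:
  explicit ScopedLogger(LoggerBase* logger) : previous_(setLogger(logger)) {}
  ~ScopedLogger() { setLogger(previous_); }
  ScopedLogger(const ScopedLogger&) = delete;
  ScopedLogger& operator=(const ScopedLogger&) = delete;

 private:
  LoggerBase* previous_;
};

// Hypothetical backend that counts how many events it received.
class CountingLogger : public LoggerBase {
 public:
  void addStatValue(const std::string&, int64_t) override { ++events; }
  int events = 0;
};
```

Note that `setLogger` neither deletes the logger it replaces nor takes ownership of the new one, so the caller must keep an installed logger alive for as long as it remains installed.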