mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
Summary: As GoogleTest `TEST` macro is non-compliant with it as well as `DEFINE_DISPATCH` All changes but the ones to `.clang-tidy` are generated using following script: ``` for i in `find . -type f -iname "*.c*" -or -iname "*.h"|xargs grep cppcoreguidelines-avoid-non-const-global-variables|cut -f1 -d:|sort|uniq`; do sed -i "/\/\/ NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)/d" $i; done ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/62008 Reviewed By: driazati, r-barnes Differential Revision: D29838584 Pulled By: malfet fbshipit-source-id: 1b2f8602c945bd4ce50a9bfdd204755556e31d13
73 lines
1.9 KiB
C++
73 lines
1.9 KiB
C++
#include <torch/csrc/jit/runtime/logging.h>
|
|
|
|
#include <atomic>
|
|
#include <mutex>
|
|
#include <unordered_map>
|
|
|
|
namespace torch {
|
|
namespace jit {
|
|
namespace logging {
|
|
|
|
// TODO: replace the plain sum/count counters below with a multi-scale histogram.
|
|
|
|
void LockingLogger::addStatValue(const std::string& stat_name, int64_t val) {
|
|
std::unique_lock<std::mutex> lk(m);
|
|
auto& raw_counter = raw_counters[stat_name];
|
|
raw_counter.sum += val;
|
|
raw_counter.count++;
|
|
}
|
|
|
|
int64_t LockingLogger::getCounterValue(const std::string& name) const {
|
|
std::unique_lock<std::mutex> lk(m);
|
|
if (!raw_counters.count(name)) {
|
|
return 0;
|
|
}
|
|
AggregationType type =
|
|
agg_types.count(name) ? agg_types.at(name) : AggregationType::SUM;
|
|
const auto& raw_counter = raw_counters.at(name);
|
|
switch (type) {
|
|
case AggregationType::SUM: {
|
|
return raw_counter.sum;
|
|
} break;
|
|
case AggregationType::AVG: {
|
|
return raw_counter.sum / raw_counter.count;
|
|
} break;
|
|
}
|
|
throw std::runtime_error("Unknown aggregation type!");
|
|
}
|
|
|
|
void LockingLogger::setAggregationType(
|
|
const std::string& stat_name,
|
|
AggregationType type) {
|
|
agg_types[stat_name] = type;
|
|
}
|
|
|
|
// Process-wide logger slot, starting out as a NoopLogger. Nothing in this
// file deletes the installed logger; setLogger returns the replaced pointer
// to its caller.
std::atomic<LoggerBase*> global_logger(new NoopLogger());
|
|
|
|
// Returns the logger currently installed in `global_logger`
// (sequentially-consistent load).
LoggerBase* getLogger() {
  LoggerBase* const current = global_logger.load(std::memory_order_seq_cst);
  return current;
}
|
|
|
|
// Atomically installs `logger` as the process-wide logger and returns the
// logger it replaced. This function does not delete the returned pointer;
// presumably the caller takes responsibility for it (TODO confirm).
LoggerBase* setLogger(LoggerBase* logger) {
  // std::atomic::exchange performs the read-and-replace as one atomic step
  // (seq_cst), so no compare-exchange retry loop is needed.
  return global_logger.exchange(logger);
}
|
|
|
|
// Captures "now" from std::chrono::high_resolution_clock so the elapsed time
// can later be logged with recordDurationSince.
JITTimePoint timePoint() {
  const auto now = std::chrono::high_resolution_clock::now();
  return JITTimePoint{now};
}
|
|
|
|
void recordDurationSince(const std::string& name, const JITTimePoint& tp) {
|
|
auto end = std::chrono::high_resolution_clock::now();
|
|
// Measurement in microseconds.
|
|
auto seconds = std::chrono::duration<double>(end - tp.point).count() * 1e9;
|
|
logging::getLogger()->addStatValue(name, seconds);
|
|
}
|
|
|
|
} // namespace logging
|
|
} // namespace jit
|
|
} // namespace torch
|