Experimental logging/counters API (#18235)

Summary:
This PR defines a generic counters API that users can use to add monitoring to, for example, a production service. We expose both counters for runtime internals and a TorchScript API for creating user-defined counters. Synopsis of the API:

- `torch/csrc/jit/script/logging.h` specifies the externally-facing API in C++
- `torch/jit/_logging.py` specifies the Python API

We use an interface, `LoggerBase`, to define the interactions between users and a logging backend. Implementing a subclass of `LoggerBase` allows the user to handle these events in a custom way, such as logging into a DB or calling into an infra-specific counters API.
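
A rough sketch of how a backend gets selected and installed from the Python side (the names come from `torch/jit/_logging.py` below; the swap-and-restore pattern mirrors the tests in this PR):

```python
import torch
import torch.jit._logging

# Install a concrete backend -- here the bundled LockingLogger -- and restore
# the previous logger afterwards. Custom backends are implemented in C++ by
# subclassing logging::LoggerBase.
logger = torch.jit._logging.LockingLogger()
old_logger = torch.jit._logging.set_logger(logger)
try:
    pass  # run scripted/traced code that emits stats here
finally:
    torch.jit._logging.set_logger(old_logger)
```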

From the frontend perspective, we can create log events in two ways:
1. We provide an `add_stat_value(name, val)` function. This calls into the logger backend with a key/value pair. For example, we might call `add_stat_value('foo', 1)` to bump an event counter.
2. We provide a `time_point()` function to record a timestamp in nanoseconds. Used together with `add_stat_value`, this lets us record wall-clock durations of runtime events, as sketched below.
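
A minimal sketch combining both calls inside a script method (`TimedModule`, `'my_counter'`, and `'my_time_ns'` are illustrative names, not part of the API):

```python
import torch
import torch.jit._logging

class TimedModule(torch.jit.ScriptModule):
    @torch.jit.script_method
    def forward(self, x):
        tp_start = torch.jit._logging.time_point()
        torch.jit._logging.add_stat_value('my_counter', 1)
        y = x + 1.0
        tp_end = torch.jit._logging.time_point()
        # Wall-clock duration of this invocation, in nanoseconds.
        torch.jit._logging.add_stat_value('my_time_ns', tp_end - tp_start)
        return y
```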

Examples of frontend usage can be found in the `TestLogging` class in `test_jit.py`.

We provide a trivial `LockingLogger` implementation as an example and for testing purposes; it is likely not ready for production use. It demonstrates that a backend implementing the API can, for example, support per-counter aggregation types and report the aggregated stats back through `get_counter_val()`, as sketched below.
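
A sketch of how the aggregation type changes what gets reported (this mirrors `test_counter_aggregation` in the diff below):

```python
import torch
import torch.jit._logging

def foo(x):
    for i in range(3):
        torch.jit._logging.add_stat_value('foo', 1)
    return x + 1.0

traced = torch.jit.trace(foo, torch.rand(3, 4))

logger = torch.jit._logging.LockingLogger()
# Report the mean of the recorded values rather than their running sum.
logger.set_aggregation_type('foo', torch.jit._logging.AggregationType.AVG)
old_logger = torch.jit._logging.set_logger(logger)
try:
    traced(torch.rand(3, 4))
    print(logger.get_counter_val('foo'))  # 1 under AVG; 3 under the default SUM
finally:
    torch.jit._logging.set_logger(old_logger)
```
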
Pull Request resolved: https://github.com/pytorch/pytorch/pull/18235

Differential Revision: D14545060

Pulled By: jamesr66a

fbshipit-source-id: 04099543a1898cfdd411511e46e03d5dce9b4881
Author: James Reed
Date: 2019-03-29 17:06:08 -07:00
Committed by: Facebook Github Bot
Parent: e2fd1d966f
Commit: 85f36014e2
13 changed files with 365 additions and 7 deletions

.gitignore

@@ -203,7 +203,6 @@ docs/dev
*.sst
*.ldb
LOCK
LOG*
CURRENT
MANIFEST-*

@@ -86,6 +86,8 @@ namespace c10 {
_(prim, CreateObject) \
_(prim, SetAttr) \
_(prim, GetAttr) \
_(prim, AddStatValue) \
_(prim, TimePoint) \
_(aten, append) \
_(aten, item) \
_(aten, format) \

@@ -1,6 +1,7 @@
from __future__ import division
import torch
import torch.jit
import torch.jit._logging
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.parallel as dp
@@ -13874,6 +13875,106 @@ class TestClassType(JitTestCase):
        self.assertEqual(y, f2.y)


class TestLogging(JitTestCase):
    def test_bump_numeric_counter(self):
        class ModuleThatLogs(torch.jit.ScriptModule):
            @torch.jit.script_method
            def forward(self, x):
                for i in range(x.size(0)):
                    x += 1.0
                    torch.jit._logging.add_stat_value('foo', 1)

                if bool(x.sum() > 0.0):
                    torch.jit._logging.add_stat_value('positive', 1)
                else:
                    torch.jit._logging.add_stat_value('negative', 1)
                return x

        logger = torch.jit._logging.LockingLogger()
        old_logger = torch.jit._logging.set_logger(logger)
        try:
            mtl = ModuleThatLogs()
            for i in range(5):
                mtl(torch.rand(3, 4, 5))

            self.assertEqual(logger.get_counter_val('foo'), 15)
            self.assertEqual(logger.get_counter_val('positive'), 5)
        finally:
            torch.jit._logging.set_logger(old_logger)

    def test_trace_numeric_counter(self):
        def foo(x):
            torch.jit._logging.add_stat_value('foo', 1)
            return x + 1.0

        traced = torch.jit.trace(foo, torch.rand(3, 4))
        logger = torch.jit._logging.LockingLogger()
        old_logger = torch.jit._logging.set_logger(logger)
        try:
            traced(torch.rand(3, 4))
            self.assertEqual(logger.get_counter_val('foo'), 1)
        finally:
            torch.jit._logging.set_logger(old_logger)

    def test_time_measurement_counter(self):
        class ModuleThatTimes(torch.jit.ScriptModule):
            def forward(self, x):
                tp_start = torch.jit._logging.time_point()
                for i in range(30):
                    x += 1.0
                tp_end = torch.jit._logging.time_point()
                torch.jit._logging.add_stat_value('mytimer', tp_end - tp_start)
                return x

        mtm = ModuleThatTimes()
        logger = torch.jit._logging.LockingLogger()
        old_logger = torch.jit._logging.set_logger(logger)
        try:
            mtm(torch.rand(3, 4))
            self.assertGreater(logger.get_counter_val('mytimer'), 0)
        finally:
            torch.jit._logging.set_logger(old_logger)

    def test_time_measurement_counter_script(self):
        class ModuleThatTimes(torch.jit.ScriptModule):
            @torch.jit.script_method
            def forward(self, x):
                tp_start = torch.jit._logging.time_point()
                for i in range(30):
                    x += 1.0
                tp_end = torch.jit._logging.time_point()
                torch.jit._logging.add_stat_value('mytimer', tp_end - tp_start)
                return x

        mtm = ModuleThatTimes()
        logger = torch.jit._logging.LockingLogger()
        old_logger = torch.jit._logging.set_logger(logger)
        try:
            mtm(torch.rand(3, 4))
            self.assertGreater(logger.get_counter_val('mytimer'), 0)
        finally:
            torch.jit._logging.set_logger(old_logger)

    def test_counter_aggregation(self):
        def foo(x):
            for i in range(3):
                torch.jit._logging.add_stat_value('foo', 1)
            return x + 1.0

        traced = torch.jit.trace(foo, torch.rand(3, 4))
        logger = torch.jit._logging.LockingLogger()
        logger.set_aggregation_type('foo', torch.jit._logging.AggregationType.AVG)
        old_logger = torch.jit._logging.set_logger(logger)
        try:
            traced(torch.rand(3, 4))
            self.assertEqual(logger.get_counter_val('foo'), 1)
        finally:
            torch.jit._logging.set_logger(old_logger)


for test in autograd_method_tests():
    add_autograd_test(*test)

@@ -95,6 +95,7 @@ libtorch_sources = [
"torch/csrc/jit/scope.cpp",
"torch/csrc/jit/script/compiler.cpp",
"torch/csrc/jit/script/edit_distance.cpp",
"torch/csrc/jit/script/logging.cpp",
"torch/csrc/jit/script/final_returns.cpp",
"torch/csrc/jit/script/schema_type_parser.cpp",
"torch/csrc/jit/script/script_type_parser.cpp",

@@ -185,6 +185,7 @@ set(TORCH_SRCS
${TORCH_SRC_DIR}/csrc/jit/script/builtin_functions.cpp
${TORCH_SRC_DIR}/csrc/jit/script/edit_distance.cpp
${TORCH_SRC_DIR}/csrc/jit/script/lexer.cpp
${TORCH_SRC_DIR}/csrc/jit/script/logging.cpp
${TORCH_SRC_DIR}/csrc/jit/script/module.cpp
${TORCH_SRC_DIR}/csrc/jit/tracer.cpp
${TORCH_SRC_DIR}/csrc/jit/hooks_for_testing.cpp

@@ -32,6 +32,7 @@
#include <torch/csrc/autograd/edge.h>
#include <torch/csrc/autograd/function.h>
#include <torch/csrc/jit/script/compiler.h>
#include <torch/csrc/jit/script/logging.h>
#include <cstdint>
#include <iterator>
@@ -362,7 +363,10 @@ struct GraphExecutorImpl {
optimize(optimize),
num_inputs(this->graph->inputs().size()),
num_flat_inputs(countFlatInputs(graph)),
num_outputs(this->graph->outputs().size()) {}
num_outputs(this->graph->outputs().size()) {
logging::getLogger()->addStatValue(
logging::runtime_counters::GRAPH_EXECUTORS_CONSTRUCTED, 1.0);
}
// entry point where execution begins
void run(Stack& stack) {
@@ -373,6 +377,9 @@
" inputs, but got only ",
stack.size());
logging::getLogger()->addStatValue(
logging::runtime_counters::GRAPH_EXECUTOR_INVOCATIONS, 1.0);
if (tracer::isTracing()) {
return runTraced(stack);
}
@@ -441,10 +448,15 @@
{
std::lock_guard<std::mutex> lock(compile_mutex);
auto it = plan_cache.find(spec);
if (it != plan_cache.end())
if (it != plan_cache.end()) {
logging::getLogger()->addStatValue(
logging::runtime_counters::EXECUTION_PLAN_CACHE_HIT, 1.0);
return it->second;
}
auto plan = compileSpec(spec);
auto r = plan_cache.emplace(std::move(spec), std::move(plan));
logging::getLogger()->addStatValue(
logging::runtime_counters::EXECUTION_PLAN_CACHE_MISS, 1.0);
return r.first->second;
}
}

@@ -1,19 +1,20 @@
#include <torch/csrc/jit/interpreter.h>

#include <ATen/core/ivalue.h>
#include <c10/core/thread_pool.h>
#include <c10/util/Exception.h>
#include <torch/csrc/autograd/edge.h>
#include <torch/csrc/autograd/function.h>
#include <torch/csrc/autograd/generated/variable_factories.h>
#include <torch/csrc/autograd/grad_mode.h>
#include <torch/csrc/autograd/profiler.h>
#include <torch/csrc/autograd/variable.h>
#include <torch/csrc/jit/constants.h>
#include <torch/csrc/jit/graph_executor.h>
#include <torch/csrc/jit/ir.h>
#include <torch/csrc/jit/operator.h>
#include <torch/csrc/jit/script/jit_exception.h>
#include <torch/csrc/jit/script/logging.h>

#include <exception>
#include <iostream>

@@ -845,6 +845,8 @@ bool Node::hasSideEffects() const {
case prim::RaiseException:
case prim::SetAttr:
case aten::warn:
case prim::AddStatValue:
case prim::TimePoint:
return true;
}
return false;

@@ -8,6 +8,7 @@
#include <torch/csrc/jit/fuser/interface.h>
#include <torch/csrc/jit/graph_executor.h>
#include <torch/csrc/jit/ir.h>
#include <torch/csrc/jit/script/logging.h>
#include <torch/csrc/jit/operator.h>
#include <torch/csrc/jit/script/jit_exception.h>
@@ -887,7 +888,46 @@
userObj->setSlot(slot, std::move(v));
return 0;
};
})});
})
});
RegisterOperators logging_operators({
Operator("prim::AddStatValue(str key, int val) -> ()", [](Stack& stack) {
auto val = pop(stack).toInt();
auto key = pop(stack).toString();
auto schema = parseSchema("prim::AddStatValue(str key, int val) -> ()");
// TODO: remove this custom tracing code once the custom op bugfix lands
if (jit::tracer::isTracing()) {
const auto& graph = tracer::getTracingState()->graph;
Node* node = graph->create(prim::AddStatValue, /*num_outputs=*/0);
tracer::recordSourceLocation(node);
node->addInput(insertConstant(*graph, key));
tracer::addInputs(node, "val", val);
graph->insertNode(node);
}
torch::jit::logging::getLogger()->addStatValue(*key, val);
return 0;
}),
Operator("prim::TimePoint() -> int", [](Stack& stack) {
auto schema = parseSchema("prim::TimePoint() -> int");
Node* node = nullptr;
// TODO: remove this custom tracing code once the custom op bugfix lands
if (jit::tracer::isTracing()) {
const auto& graph = tracer::getTracingState()->graph;
// Assign to the outer `node` (rather than shadowing it) so the traced output
// can be attached to it below.
node = graph->create(prim::TimePoint, /*num_outputs=*/0);
tracer::recordSourceLocation(node);
graph->insertNode(node);
}
auto output = autograd::profiler::getTime();
push(stack, output);
if (jit::tracer::isTracing()) {
jit::tracer::addOutput(node, output);
}
return 0;
})
});
// define implementations for primitive number ops
#define DEFINE_GENERIC_OP(aten_op, int_op, float_op, int_result, float_result) \

@@ -16,7 +16,9 @@
#include <torch/csrc/jit/passes/python_print.h>
#include <torch/csrc/jit/pybind_utils.h>
#include <torch/csrc/jit/python_tracer.h>
#include <torch/csrc/jit/script/logging.h>
#include <torch/csrc/jit/script/parser.h>
#include <torch/csrc/jit/tracer.h>
#include <torch/csrc/api/include/torch/ordered_dict.h>
@@ -27,6 +29,7 @@
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <pybind11/stl_bind.h>
#include <chrono>
#include <cstddef>
#include <memory>
#include <sstream>
@@ -1101,6 +1104,29 @@ void initJitScriptBindings(PyObject* module) {
.def("run", [](testing::FileCheck& f, const Graph& g) {
return f.run(g);
});
m.def("_logging_set_logger", [](logging::LoggerBase* logger) {
return logging::setLogger(logger);
}, py::return_value_policy::reference);
py::class_<logging::LoggerBase, std::shared_ptr<logging::LoggerBase>>(
m, "LoggerBase");
py::enum_<logging::LockingLogger::AggregationType>(m, "AggregationType")
.value("SUM", logging::LockingLogger::AggregationType::SUM)
.value("AVG", logging::LockingLogger::AggregationType::AVG)
.export_values();
py::class_<
logging::LockingLogger,
logging::LoggerBase,
std::shared_ptr<logging::LockingLogger>>(m, "LockingLogger")
.def(py::init<>())
.def("set_aggregation_type", &logging::LockingLogger::setAggregationType)
.def("get_counter_val", &logging::LockingLogger::getCounterValue);
py::class_<
logging::NoopLogger,
logging::LoggerBase,
std::shared_ptr<logging::NoopLogger>>(m, "NoopLogger")
.def(py::init<>());
}
} // namespace script
} // namespace jit

@@ -0,0 +1,73 @@
#include "torch/csrc/jit/script/logging.h"
#include <atomic>
#include <mutex>
#include <unordered_map>
namespace torch {
namespace jit {
namespace logging {
// TODO: multi-scale histogram for this thing
void LockingLogger::addStatValue(const std::string& stat_name, int64_t val) {
std::unique_lock<std::mutex> lk(m);
auto& raw_counter = raw_counters[stat_name];
raw_counter.sum += val;
raw_counter.count++;
}
TORCH_API int64_t LockingLogger::getCounterValue(const std::string& name) const {
std::unique_lock<std::mutex> lk(m);
if (!raw_counters.count(name)) {
return 0;
}
AggregationType type = agg_types.count(name) ? agg_types.at(name)
: AggregationType::SUM;
const auto &raw_counter = raw_counters.at(name);
switch (type) {
case AggregationType::SUM: {
return raw_counter.sum;
} break;
case AggregationType::AVG: {
return raw_counter.sum / raw_counter.count;
} break;
}
throw std::runtime_error("Unknown aggregation type!");
}
void LockingLogger::setAggregationType(
const std::string& stat_name,
AggregationType type) {
agg_types[stat_name] = type;
}
std::atomic<LoggerBase*> global_logger{new NoopLogger()};
LoggerBase* getLogger() {
return global_logger.load();
}
LoggerBase *setLogger(LoggerBase* logger) {
LoggerBase *previous = global_logger.load();
while (!global_logger.compare_exchange_strong(previous, logger)) {
previous = global_logger.load();
}
return previous;
}
JITTimePoint timePoint() {
return JITTimePoint{std::chrono::high_resolution_clock::now()};
}
void recordDurationSince(const std::string& name, JITTimePoint tp) {
auto end = std::chrono::high_resolution_clock::now();
// Measurement in nanoseconds.
auto nanoseconds = std::chrono::duration<double>(end - tp.point).count() * 1e9;
logging::getLogger()->addStatValue(name, nanoseconds);
}
} // namespace logging
} // namespace jit
} // namespace torch

@@ -0,0 +1,90 @@
#pragma once
#include <chrono>
#include <memory>
#include <mutex>
#include <string>
#include <unordered_map>
#include <vector>
#include <torch/csrc/WindowsTorchApiMacro.h>
namespace torch {
namespace jit {
namespace logging {
class LoggerBase {
public:
TORCH_API virtual void addStatValue(
const std::string& stat_name,
int64_t val) = 0;
virtual ~LoggerBase() {}
};
TORCH_API LoggerBase* getLogger();
TORCH_API LoggerBase* setLogger(LoggerBase* logger);
// No-op logger. This is the default and is meant to incur almost no runtime
// overhead.
class NoopLogger : public LoggerBase {
public:
void addStatValue(const std::string& stat_name, int64_t val) override {}
~NoopLogger() {}
};
// Trivial locking logger. Pass in an instance of this to setLogger() to use it.
// This keeps track of the sum of all statistics.
//
// NOTE: this is not written in a scalable way and should probably only be used
// in the single-threaded case or for testing.
class LockingLogger : public LoggerBase {
public:
TORCH_API void addStatValue(const std::string& stat_name, int64_t val) override;
TORCH_API virtual int64_t getCounterValue(const std::string& name) const;
enum class AggregationType { SUM, AVG };
TORCH_API void setAggregationType(
const std::string& stat_name,
AggregationType type);
~LockingLogger() {}
private:
mutable std::mutex m;
struct RawCounter {
RawCounter() : sum(0), count(0) {}
int64_t sum;
size_t count;
};
std::unordered_map<std::string, RawCounter> raw_counters;
std::unordered_map<std::string, AggregationType> agg_types;
};
// Make this struct so the timer internals are opaque to the user.
struct JITTimePoint {
std::chrono::time_point<std::chrono::high_resolution_clock> point;
};
TORCH_API JITTimePoint timePoint();
TORCH_API void recordDurationSince(const std::string& name, JITTimePoint tp);
namespace runtime_counters {
constexpr const char* GRAPH_EXECUTORS_CONSTRUCTED =
"pytorch_runtime.graph_executors_constructed";
constexpr const char* GRAPH_EXECUTOR_INVOCATIONS =
"pytorch_runtime.graph_executor_invocations";
constexpr const char* EXECUTION_PLAN_CACHE_HIT =
"pytorch_runtime.execution_plan_cache_hit";
constexpr const char* EXECUTION_PLAN_CACHE_MISS =
"pytorch_runtime.execution_plan_cache_miss";
inline std::vector<const char*> allRuntimeCounters() {
return {GRAPH_EXECUTORS_CONSTRUCTED,
GRAPH_EXECUTOR_INVOCATIONS,
EXECUTION_PLAN_CACHE_HIT,
EXECUTION_PLAN_CACHE_MISS};
}
} // namespace runtime_counters
} // namespace logging
} // namespace jit
} // namespace torch

torch/jit/_logging.py

@@ -0,0 +1,10 @@
import torch
add_stat_value = torch.ops.prim.AddStatValue
set_logger = torch._C._logging_set_logger
LockingLogger = torch._C.LockingLogger
AggregationType = torch._C.AggregationType
NoopLogger = torch._C.NoopLogger
time_point = torch.ops.prim.TimePoint