mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
Use our thread-safe getenv wrappers. Pull Request resolved: https://github.com/pytorch/pytorch/pull/140200 Approved by: https://github.com/kwen2501, https://github.com/eqy
190 lines
5.0 KiB
C++
190 lines
5.0 KiB
C++
#include <cstdlib>
|
|
#include <iomanip>
|
|
#include <iostream>
|
|
#include <sstream>
|
|
#include <string>
|
|
#include <unordered_map>
|
|
#include <vector>
|
|
|
|
#include <ATen/core/function.h>
|
|
#include <c10/util/Exception.h>
|
|
#include <c10/util/StringUtil.h>
|
|
#include <c10/util/env.h>
|
|
#include <torch/csrc/jit/api/function_impl.h>
|
|
#include <torch/csrc/jit/frontend/error_report.h>
|
|
#include <torch/csrc/jit/ir/ir.h>
|
|
#include <torch/csrc/jit/jit_log.h>
|
|
#include <torch/csrc/jit/serialization/python_print.h>
|
|
|
|
namespace torch::jit {
|
|
|
|
class JitLoggingConfig {
|
|
public:
|
|
static JitLoggingConfig& getInstance() {
|
|
static JitLoggingConfig instance;
|
|
return instance;
|
|
}
|
|
JitLoggingConfig(JitLoggingConfig const&) = delete;
|
|
void operator=(JitLoggingConfig const&) = delete;
|
|
|
|
private:
|
|
std::string logging_levels;
|
|
std::unordered_map<std::string, size_t> files_to_levels;
|
|
std::ostream* out;
|
|
|
|
JitLoggingConfig() : out(&std::cerr) {
|
|
const auto jit_log_level = c10::utils::get_env("PYTORCH_JIT_LOG_LEVEL");
|
|
if (jit_log_level.has_value()) {
|
|
logging_levels = jit_log_level.value();
|
|
}
|
|
|
|
parse();
|
|
}
|
|
void parse();
|
|
|
|
public:
|
|
std::string getLoggingLevels() const {
|
|
return this->logging_levels;
|
|
}
|
|
void setLoggingLevels(std::string levels) {
|
|
this->logging_levels = std::move(levels);
|
|
parse();
|
|
}
|
|
|
|
const std::unordered_map<std::string, size_t>& getFilesToLevels() const {
|
|
return this->files_to_levels;
|
|
}
|
|
|
|
void setOutputStream(std::ostream& out_stream) {
|
|
this->out = &out_stream;
|
|
}
|
|
|
|
std::ostream& getOutputStream() {
|
|
return *(this->out);
|
|
}
|
|
};
|
|
|
|
std::string get_jit_logging_levels() {
|
|
return JitLoggingConfig::getInstance().getLoggingLevels();
|
|
}
|
|
|
|
/// Installs `level` as the new JIT logging level specification; the
/// per-file level map is rebuilt immediately.
void set_jit_logging_levels(std::string level) {
  auto& config = JitLoggingConfig::getInstance();
  config.setLoggingLevels(std::move(level));
}
|
|
|
|
/// Sends subsequent JIT log output to `stream`. The configuration keeps only
/// a pointer to it, so `stream` must stay alive while it is the sink.
void set_jit_logging_output_stream(std::ostream& stream) {
  auto& config = JitLoggingConfig::getInstance();
  config.setOutputStream(stream);
}
|
|
|
|
std::ostream& get_jit_logging_output_stream() {
|
|
return JitLoggingConfig::getInstance().getOutputStream();
|
|
}
|
|
|
|
// gets a string representation of a node header
|
|
// (e.g. outputs, a node kind and outputs)
|
|
std::string getHeader(const Node* node) {
|
|
std::stringstream ss;
|
|
node->print(ss, 0, {}, false, false, false, false);
|
|
return ss.str();
|
|
}
|
|
|
|
// Rebuilds files_to_levels from the raw specification in logging_levels.
//
// The specification is a ':'-separated list of entries; empty entries are
// skipped. For each entry:
//  - the level is one past the index of the last '>'. When every '>' is
//    leading, that equals the number of '>' characters; with no '>' it is 0.
//  - the key is the text after the last '>' and up to (excluding) the last
//    '.', so a trailing file extension is dropped.
// An implicit "function" entry (level 0) is always parsed first. Because
// unordered_map::insert keeps the first value stored for a key, later
// duplicates are ignored.
// NOTE(review): as a consequence, a user entry such as ">>function" cannot
// raise the level of the implicit "function" key. Confirm this is intended.
void JitLoggingConfig::parse() {
  std::stringstream spec;
  spec << "function:" << this->logging_levels;

  files_to_levels.clear();
  for (std::string entry; std::getline(spec, entry, ':');) {
    if (entry.empty()) {
      continue;
    }

    const auto last_arrow = entry.find_last_of('>');
    const size_t name_start =
        (last_arrow == std::string::npos) ? 0 : last_arrow + 1;
    const size_t level = name_start;

    const auto last_dot = entry.find_last_of('.');
    const size_t name_stop =
        (last_dot == std::string::npos) ? entry.size() : last_dot;

    // If a '.' precedes the last '>', the unsigned subtraction wraps and
    // substr() simply takes the rest of the entry.
    files_to_levels.insert(
        {entry.substr(name_start, name_stop - name_start), level});
  }
}
|
|
|
|
bool is_enabled(const char* cfname, JitLoggingLevels level) {
|
|
const auto& files_to_levels =
|
|
JitLoggingConfig::getInstance().getFilesToLevels();
|
|
std::string fname{cfname};
|
|
fname = c10::detail::StripBasename(fname);
|
|
const auto end_index = fname.find_last_of('.') == std::string::npos
|
|
? fname.size()
|
|
: fname.find_last_of('.');
|
|
const auto fname_no_ext = fname.substr(0, end_index);
|
|
|
|
const auto it = files_to_levels.find(fname_no_ext);
|
|
if (it == files_to_levels.end()) {
|
|
return false;
|
|
}
|
|
|
|
return level <= static_cast<JitLoggingLevels>(it->second);
|
|
}
|
|
|
|
// Unfortunately, in `GraphExecutor` where `log_function` is invoked
|
|
// we won't have access to an original function, so we have to construct
|
|
// a dummy function to give to PythonPrint
|
|
std::string log_function(const std::shared_ptr<torch::jit::Graph>& graph) {
|
|
torch::jit::GraphFunction func("source_dump", graph, nullptr);
|
|
std::vector<at::IValue> constants;
|
|
PrintDepsTable deps;
|
|
PythonPrint pp(constants, deps);
|
|
pp.printFunction(func);
|
|
return pp.str();
|
|
}
|
|
|
|
// Returns `in_str` with `prefix` prepended to every line.
//
// Each line of the result, including the last one, is terminated by '\n'.
// An empty input produces an empty result, and a single trailing newline in
// the input does not create an extra prefixed empty line.
std::string jit_log_prefix(
    const std::string& prefix,
    const std::string& in_str) {
  std::istringstream lines(in_str);
  std::string result;
  for (std::string current; std::getline(lines, current);) {
    result += prefix;
    result += current;
    result += '\n';
  }
  return result;
}
|
|
|
|
// Prefixes every line of `in_str` with a tag of the form
// "[<LEVEL> <basename of fn>:<line>] ". The line number is zero-padded to at
// least three characters.
std::string jit_log_prefix(
    JitLoggingLevels level,
    const char* fn,
    int l,
    const std::string& in_str) {
  std::ostringstream tag;
  tag << '[' << level << ' '
      << c10::detail::StripBasename(std::string(fn)) << ':'
      << std::setfill('0') << std::setw(3) << l << "] ";

  return jit_log_prefix(tag.str(), in_str);
}
|
|
|
|
std::ostream& operator<<(std::ostream& out, JitLoggingLevels level) {
|
|
switch (level) {
|
|
case JitLoggingLevels::GRAPH_DUMP:
|
|
out << "DUMP";
|
|
break;
|
|
case JitLoggingLevels::GRAPH_UPDATE:
|
|
out << "UPDATE";
|
|
break;
|
|
case JitLoggingLevels::GRAPH_DEBUG:
|
|
out << "DEBUG";
|
|
break;
|
|
default:
|
|
TORCH_INTERNAL_ASSERT(false, "Invalid level");
|
|
}
|
|
|
|
return out;
|
|
}
|
|
|
|
} // namespace torch::jit
|