mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Summary: Resubmit #20698 which got messed up. The idea is that when PyTorch is used in a custom build environment (e.g. Facebook), it's useful to track usage of various APIs centrally. This PR introduces a simple, very lightweight mechanism to do so: only the first invocation of a trigger point is logged. This is significantly more lightweight than #18235, and thus we can afford to put logging in, e.g., TensorImpl. Also adds an initial list of trigger points. Trigger points are added in such a way that no static initialization triggers them, i.e. just linking with libtorch.so will not cause any logging. Further suggestions of what to log are welcome. Pull Request resolved: https://github.com/pytorch/pytorch/pull/20745 Differential Revision: D15429196 Pulled By: dzhulgakov fbshipit-source-id: a5e41a709a65b7ebccc6b95f93854e583cf20aca
184 lines
5.6 KiB
C++
184 lines
5.6 KiB
C++
#include <torch/csrc/python_headers.h>
|
|
|
|
#include <torch/csrc/jit/export.h>
|
|
#include <torch/csrc/jit/passes/dead_code_elimination.h>
|
|
#include <torch/csrc/jit/passes/lower_tuples.h>
|
|
#include <torch/csrc/jit/pybind.h>
|
|
#include <torch/csrc/jit/python_tracer.h>
|
|
#include <torch/csrc/jit/tracer.h>
|
|
#include <torch/csrc/utils/python_strings.h>
|
|
|
|
#include <c10/util/Exception.h>
|
|
|
|
#include <sstream>
|
|
|
|
using namespace torch::autograd;
|
|
using namespace torch::jit;
|
|
using namespace torch::jit::tracer;
|
|
|
|
namespace torch {
|
|
namespace jit {
|
|
namespace tracer {
|
|
|
|
// Python interpreter retrieval routine adapted from
|
|
// https://stackoverflow.com/a/8706144
|
|
std::string getPythonInterpreterStackTrace() {
|
|
std::stringstream stack_trace;
|
|
AutoGIL gil;
|
|
PyFrameObject* frame = PyEval_GetFrame();
|
|
while (nullptr != frame) {
|
|
int line = PyCode_Addr2Line(frame->f_code, frame->f_lasti);
|
|
std::string filename = THPUtils_unpackString(frame->f_code->co_filename);
|
|
std::string funcname = THPUtils_unpackString(frame->f_code->co_name);
|
|
stack_trace << filename << "(" << line << "): " << funcname << "\n";
|
|
frame = frame->f_back;
|
|
}
|
|
return stack_trace.str();
|
|
}
|
|
|
|
std::shared_ptr<torch::jit::Graph> createGraphByTracing(
|
|
const py::function& func,
|
|
TypedStack trace_inputs,
|
|
const py::function& var_name_lookup_fn,
|
|
bool force_outplace,
|
|
const std::shared_ptr<script::Module>& self) {
|
|
C10_LOG_API_USAGE_ONCE("torch.tracer");
|
|
|
|
auto enter_info = tracer::enter(std::move(trace_inputs), self);
|
|
auto graph = enter_info.first->graph;
|
|
|
|
getTracingState()->lookup_var_name_fn =
|
|
[var_name_lookup_fn](const Variable& var) -> std::string {
|
|
AutoGIL ag;
|
|
return py::cast<std::string>(var_name_lookup_fn(var));
|
|
};
|
|
getTracingState()->force_outplace = force_outplace;
|
|
try {
|
|
size_t num_func_inputs = enter_info.second.size();
|
|
py::tuple py_inputs(num_func_inputs);
|
|
for (size_t i = 0; i < num_func_inputs; ++i) {
|
|
py_inputs[i] = py::cast(enter_info.second[i]);
|
|
}
|
|
auto out = func(*py_inputs);
|
|
if (out.ptr() == Py_None) {
|
|
AT_ERROR(
|
|
"The traced function didn't return any values! Side-effects are not "
|
|
"captured in traces, so it would be a no-op.");
|
|
}
|
|
tracer::exit({toIValue(out)});
|
|
EliminateDeadCode(graph);
|
|
LowerSimpleTuples(graph);
|
|
|
|
return graph;
|
|
} catch (...) {
|
|
tracer::abandon();
|
|
throw;
|
|
}
|
|
}
|
|
|
|
// Builds a PythonOp node for the Python object `pyobj` and appends it to the
// graph of the active tracing state.
//
// The node wraps `pyobj.apply` together with `arg_types` and `scalar_args`,
// has its source location recorded, and takes the current trace values of
// `inputs` as its inputs. No outputs are added here; presumably the caller
// attaches them afterwards (TODO confirm against call sites).
//
// Throws python_error if looking up `apply` on `pyobj` fails.
Node* preRecordPythonTrace(
    THPObjectPtr pyobj,
    const std::string& arg_types,
    at::ArrayRef<Variable> inputs,
    pyobj_list scalar_args) {
  THPObjectPtr apply_fn(PyObject_GetAttrString(pyobj.get(), "apply"));
  if (apply_fn.get() == nullptr) {
    throw python_error();
  }

  const auto& trace_graph = getTracingState()->graph;
  Node* node = trace_graph->createPythonOp(
      std::move(apply_fn), arg_types, std::move(scalar_args));
  recordSourceLocation(node);

  for (const auto& var : inputs) {
    node->addInput(getValueTrace(var));
  }

  // NB: Order matters. The node must be appended after its inputs are wired
  // up but before any outputs are added.
  trace_graph->appendNode(node);
  return node;
}
|
|
|
|
// Source-location hook (registered via setRecordSourceLocation in this file):
// stores the current Python interpreter stack trace on node `n` as its
// source range.
void pythonRecordSourceLocation(Node* n) {
  std::string python_stack = getPythonInterpreterStackTrace();
  n->setSourceRange(SourceRange(std::move(python_stack)));
}
|
|
|
|
void pythonWarn(const std::string& reason) {
|
|
AutoGIL gil;
|
|
auto warn_class = py::module::import("torch.jit").attr("TracerWarning");
|
|
PyErr_WarnEx(warn_class.ptr(), reason.c_str(), 1);
|
|
}
|
|
|
|
// Registers the Python-facing tracer bindings on `module`.
//
// Installs pythonRecordSourceLocation as the tracer's source-location hook,
// exposes the TracingState class (repr/str, scope push/pop, graph get/set),
// and defines the private helper functions (_tracer_enter, _tracer_exit,
// _get_tracing_state, ...) used from Python.
void initPythonTracerBindings(PyObject* module) {
  setRecordSourceLocation(pythonRecordSourceLocation);

  auto m = py::handle(module).cast<py::module>();
  py::class_<TracingState, std::shared_ptr<TracingState>>(
      m, "TracingState", py::dynamic_attr())
      // NB: no constructor; you have to get it from C++ code
      .def(
          "__repr__",
          [](const TracingState& s) {
            std::ostringstream ss;
            ss << "<TracingState " << static_cast<const void*>(&s) << ">";
            return ss.str();
          })
      .def(
          "__str__",
          [](const TracingState& s) -> std::string {
            std::ostringstream ss;
            ss << *s.graph;
            return ss.str();
          })
      .def(
          "push_scope",
          [](TracingState& s, const std::string& scope_name) {
            s.graph->push_scope(scope_name);
          })
      .def("pop_scope", [](TracingState& s) { s.graph->pop_scope(); })
      .def(
          "set_graph",
          [](TracingState& s, std::shared_ptr<Graph> g) {
            // Take ownership of the by-value argument instead of copying it
            // (avoids an extra atomic refcount increment/decrement).
            s.graph = std::move(g);
          })
      .def("graph", [](TracingState& s) { return s.graph; });

  m.def("_tracer_warn_use_python", []() { tracer::setWarn(pythonWarn); });
  m.def("_tracer_enter", [](py::args trace_inputs) {
    return tracer::enter(toTypedStack(trace_inputs));
  });
  m.def("_tracer_exit", [](py::tuple var_outputs) {
    tracer::exit(toStack(var_outputs));
  });
  m.def("_tracer_abandon", []() { tracer::abandon(); });
  m.def("_get_tracing_state", []() { return getTracingState(); });
  m.def("_set_tracing_state", [](std::shared_ptr<TracingState> state) {
    return setTracingState(std::move(state));
  });
  m.def("_get_value_trace", [](const Variable& var) {
    return getValueTrace(var);
  });
  m.def("_set_value_trace", [](const Variable& var, Value* value) {
    return setValueTrace(var, value);
  });
  m.def("_tracer_set_get_unique_name_fn", [](py::function func) {
    const auto& tracing_state = getTracingState();
    AT_ASSERT(tracing_state);
    tracing_state->lookup_var_name_fn =
        [func](const Variable& var) -> std::string {
      // Calling back into Python requires holding the GIL.
      AutoGIL ag;
      return py::cast<std::string>(func(var));
    };
  });
  m.def("_tracer_set_force_outplace", [](bool force_outplace) {
    const auto& tracing_state = getTracingState();
    AT_ASSERT(tracing_state);
    tracing_state->force_outplace = force_outplace;
  });
}
|
|
|
|
} // namespace tracer
|
|
} // namespace jit
|
|
} // namespace torch
|