mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Revert "[Profiler] Fix lost C call events problem in Python 3.12.0-3.12.4 (#155446)"
This reverts commit da94023b0205bf98c3da366f2f86e0a443f4db17. Reverted https://github.com/pytorch/pytorch/pull/155446 on behalf of https://github.com/ZainRizvi due to Sorry but this is breaking internally. @sraikund16 can you please help validate the fix? (See D78845227 for details). You can follow the instructions here: https://fburl.com/fixing-ghfirst-reverts ([comment](https://github.com/pytorch/pytorch/pull/155446#issuecomment-3115072504))
This commit is contained in:
@ -1,68 +0,0 @@
|
||||
# Owner(s): ["oncall: profiler"]
|
||||
|
||||
import json
|
||||
import sys
|
||||
import time
|
||||
|
||||
from torch.profiler import profile, ProfilerActivity
|
||||
from torch.testing._internal.common_utils import (
|
||||
run_tests,
|
||||
skipIfPythonVersionMismatch,
|
||||
TemporaryFileName,
|
||||
TestCase,
|
||||
)
|
||||
|
||||
|
||||
class TestPythonTracer(TestCase):
|
||||
@skipIfPythonVersionMismatch(lambda major, minor, micro: major == 3 and minor == 12)
|
||||
def test_method_with_c_function(self):
|
||||
class A:
|
||||
method_with_c_function = classmethod(repr)
|
||||
|
||||
def get_key(x):
|
||||
A().method_with_c_function()
|
||||
time.sleep(1.2)
|
||||
return len(x)
|
||||
|
||||
names = ["Alice", "Bob"]
|
||||
|
||||
with profile(
|
||||
activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], with_stack=True
|
||||
) as prof:
|
||||
sorted(names, key=get_key)
|
||||
|
||||
with TemporaryFileName(mode="w+") as fname:
|
||||
prof.export_chrome_trace(fname)
|
||||
with open(fname) as f:
|
||||
events = json.load(f)["traceEvents"]
|
||||
found = False
|
||||
for event in events:
|
||||
if (
|
||||
event.get("cat", "") == "python_function"
|
||||
and event.get("name", "") == "<built-in function sorted>"
|
||||
):
|
||||
duration = event.get("dur", 0)
|
||||
if duration >= 2000000:
|
||||
found = True
|
||||
break
|
||||
self.assertTrue(found)
|
||||
|
||||
@skipIfPythonVersionMismatch(lambda major, minor, micro: major == 3 and minor == 12)
|
||||
def test_monitoring_callback(self):
|
||||
vi = sys.version_info
|
||||
from sys import monitoring
|
||||
|
||||
with profile(
|
||||
activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], with_stack=True
|
||||
):
|
||||
name = monitoring.get_tool(2)
|
||||
if vi.micro < 5:
|
||||
self.assertEqual(name, "PyTorch Profiler")
|
||||
else:
|
||||
self.assertEqual(name, None)
|
||||
name = monitoring.get_tool(2)
|
||||
self.assertEqual(name, None)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_tests()
|
@ -679,17 +679,6 @@ struct ThreadLocalResults {
|
||||
// ============================================================================
|
||||
// == Tracing implementation ==================================================
|
||||
// ============================================================================
|
||||
#define IS_PYTHON_3_12 (PY_MAJOR_VERSION == 3 && PY_MINOR_VERSION == 12)
|
||||
#if IS_PYTHON_3_12
|
||||
// forward declarations
|
||||
struct _PyEventHandler;
|
||||
static PyObject* c_call_callback(
|
||||
_PyEventHandler* self,
|
||||
PyObject* const* args,
|
||||
size_t nargsf,
|
||||
PyObject* kwnames);
|
||||
#endif
|
||||
|
||||
class PythonTracer final : public python_tracer::PythonTracerBase {
|
||||
public:
|
||||
PythonTracer(torch::profiler::impl::RecordQueue* queue);
|
||||
@ -728,6 +717,8 @@ class PythonTracer final : public python_tracer::PythonTracerBase {
|
||||
|
||||
const std::vector<PyThreadState*> interpreterThreads() const;
|
||||
|
||||
PyObject* get_callable_from_frame(PyFrameObject* frame);
|
||||
|
||||
std::atomic<bool> active_lock_{false};
|
||||
bool active_{false};
|
||||
|
||||
@ -739,211 +730,8 @@ class PythonTracer final : public python_tracer::PythonTracerBase {
|
||||
std::vector<StartFrame> start_frames_;
|
||||
std::deque<ThreadLocalResults> thread_local_results_;
|
||||
ValueCache value_cache_;
|
||||
|
||||
#if IS_PYTHON_3_12
|
||||
friend PyObject* c_call_callback(
|
||||
_PyEventHandler* self,
|
||||
PyObject* const* args,
|
||||
size_t nargsf,
|
||||
PyObject* kwnames);
|
||||
#endif
|
||||
};
|
||||
|
||||
#if IS_PYTHON_3_12
|
||||
#define PROFILER_ID 2
|
||||
#define PY_MONITORING_EVENT_CALL 4
|
||||
|
||||
static bool should_compensate_c_call_events() {
|
||||
static const bool result = []() {
|
||||
const char* version = Py_GetVersion();
|
||||
const char micro = version[5];
|
||||
return micro == '0' || (micro <= '4' && version[6] == ' ');
|
||||
}();
|
||||
return result;
|
||||
}
|
||||
|
||||
struct _PyEventHandler {
|
||||
PyObject_HEAD
|
||||
vectorcallfunc vectorcall;
|
||||
};
|
||||
|
||||
static PyTypeObject _PyEventHandler_Type = {
|
||||
PyVarObject_HEAD_INIT(&PyType_Type, 0)
|
||||
"torch.profiler.python_tracer_event_handler",
|
||||
sizeof(_PyEventHandler),
|
||||
.tp_dealloc = (destructor)PyObject_Free,
|
||||
.tp_vectorcall_offset = offsetof(_PyEventHandler, vectorcall),
|
||||
.tp_call = PyVectorcall_Call,
|
||||
.tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE |
|
||||
Py_TPFLAGS_HAVE_VECTORCALL | Py_TPFLAGS_DISALLOW_INSTANTIATION,
|
||||
};
|
||||
|
||||
static PyObject* c_call_callback(
|
||||
_PyEventHandler* self,
|
||||
PyObject* const* args,
|
||||
size_t nargsf,
|
||||
PyObject* kwnames) {
|
||||
// The logic of this function is based on sys_defile_call_or_return defined
|
||||
// in https://github.com/python/cpython/blob/v3.12.5/Python/legacy_tracing.c
|
||||
|
||||
PyThreadState* tstate = PyThreadState_GET();
|
||||
if (tstate->c_profilefunc != PythonTracer::pyProfileFn) {
|
||||
// We don't care this case if tstate->c_profilefunc is not pyProfileFn,
|
||||
// just return normally.
|
||||
Py_RETURN_NONE;
|
||||
}
|
||||
|
||||
PyObject* callable = args[2];
|
||||
if (Py_TYPE(callable) == &PyMethod_Type) {
|
||||
// The call event of a method with c function is missing on 3.12.0-3.12.4.
|
||||
// See
|
||||
// https://github.com/python/cpython/commit/257c413cd16ddabcedde413288d0bb93bf872da7
|
||||
// Other cases have already be handled by the legacy_tracing, so we only
|
||||
// need to handle this case.
|
||||
// The exception branches keep the same behavior as CPython.
|
||||
PyObject* func = PyMethod_GET_FUNCTION(callable);
|
||||
if (!func) {
|
||||
return NULL;
|
||||
}
|
||||
if (PyCFunction_Check(func)) {
|
||||
PyFrameObject* frame = PyEval_GetFrame();
|
||||
if (!frame) {
|
||||
PyErr_SetString(
|
||||
PyExc_SystemError, "Missing frame when calling profile function.");
|
||||
return NULL;
|
||||
}
|
||||
Py_INCREF(frame);
|
||||
auto& local_results =
|
||||
*reinterpret_cast<TraceContext*>(tstate->c_profileobj)
|
||||
->thread_local_results_;
|
||||
local_results.active_tracer_->recordCCall(local_results, frame, func);
|
||||
Py_DECREF(frame);
|
||||
}
|
||||
}
|
||||
Py_RETURN_NONE;
|
||||
}
|
||||
|
||||
static void registerMonitoringCallback() {
|
||||
if (!should_compensate_c_call_events()) {
|
||||
return;
|
||||
}
|
||||
|
||||
auto sys_module = THPObjectPtr(PyImport_ImportModule("sys"));
|
||||
if (!sys_module) {
|
||||
TORCH_WARN("Failed to import sys module.");
|
||||
PyErr_Clear();
|
||||
return;
|
||||
}
|
||||
auto monitoring =
|
||||
THPObjectPtr(PyObject_GetAttrString(sys_module, "monitoring"));
|
||||
if (!monitoring) {
|
||||
TORCH_WARN("Failed to get monitoring from sys module.");
|
||||
PyErr_Clear();
|
||||
return;
|
||||
}
|
||||
auto result = THPObjectPtr(PyObject_CallMethod(
|
||||
monitoring, "use_tool_id", "is", PROFILER_ID, "PyTorch Profiler"));
|
||||
if (!result) {
|
||||
TORCH_WARN("Failed to call sys.monitoring.use_tool_id");
|
||||
PyErr_Clear();
|
||||
return;
|
||||
}
|
||||
auto handler = THPObjectPtr(PyObject_NEW(PyObject, &_PyEventHandler_Type));
|
||||
if (!handler) {
|
||||
TORCH_WARN("Failed to create _PyEventHandler object.");
|
||||
PyErr_Clear();
|
||||
return;
|
||||
}
|
||||
reinterpret_cast<_PyEventHandler*>(handler.get())->vectorcall =
|
||||
(vectorcallfunc)c_call_callback;
|
||||
result = THPObjectPtr(PyObject_CallMethod(
|
||||
monitoring,
|
||||
"register_callback",
|
||||
"iiO",
|
||||
PROFILER_ID,
|
||||
1 << PY_MONITORING_EVENT_CALL,
|
||||
handler.get()));
|
||||
if (!result) {
|
||||
TORCH_WARN("Failed to call sys.monitoring.register_callback.");
|
||||
PyErr_Clear();
|
||||
return;
|
||||
}
|
||||
result = THPObjectPtr(PyObject_CallMethod(
|
||||
monitoring,
|
||||
"set_events",
|
||||
"ii",
|
||||
PROFILER_ID,
|
||||
1 << PY_MONITORING_EVENT_CALL));
|
||||
if (!result) {
|
||||
TORCH_WARN("Failed to call sys.monitoring.set_events.");
|
||||
PyErr_Clear();
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
static void unregisterMonitoringCallback() {
|
||||
if (!should_compensate_c_call_events()) {
|
||||
return;
|
||||
}
|
||||
|
||||
auto sys_module = THPObjectPtr(PyImport_ImportModule("sys"));
|
||||
if (!sys_module) {
|
||||
TORCH_WARN("Failed to import sys module.");
|
||||
PyErr_Clear();
|
||||
return;
|
||||
}
|
||||
auto monitoring =
|
||||
THPObjectPtr(PyObject_GetAttrString(sys_module, "monitoring"));
|
||||
if (!monitoring) {
|
||||
TORCH_WARN("Failed to get monitoring from sys module.");
|
||||
PyErr_Clear();
|
||||
return;
|
||||
}
|
||||
auto tool_name = THPObjectPtr(
|
||||
PyObject_CallMethod(monitoring, "get_tool", "i", PROFILER_ID));
|
||||
if (!tool_name) {
|
||||
TORCH_WARN("Failed to call sys.monitoring.use_tool_id");
|
||||
PyErr_Clear();
|
||||
return;
|
||||
}
|
||||
if (!THPUtils_checkString(tool_name)) {
|
||||
return;
|
||||
}
|
||||
const char* str = THPUtils_unpackStringView(tool_name).data();
|
||||
if (strcmp(str, "PyTorch Profiler") != 0) {
|
||||
return;
|
||||
}
|
||||
auto none = THPObjectPtr(Py_None);
|
||||
Py_INCREF(Py_None);
|
||||
auto result = THPObjectPtr(PyObject_CallMethod(
|
||||
monitoring,
|
||||
"register_callback",
|
||||
"iiO",
|
||||
PROFILER_ID,
|
||||
1 << PY_MONITORING_EVENT_CALL,
|
||||
none.get()));
|
||||
if (!result) {
|
||||
TORCH_WARN("Failed to call sys.monitoring.register_callback.");
|
||||
PyErr_Clear();
|
||||
return;
|
||||
}
|
||||
result = THPObjectPtr(
|
||||
PyObject_CallMethod(monitoring, "set_events", "ii", PROFILER_ID, 0));
|
||||
if (!result) {
|
||||
TORCH_WARN("Failed to call sys.monitoring.set_events.");
|
||||
PyErr_Clear();
|
||||
return;
|
||||
}
|
||||
result = THPObjectPtr(
|
||||
PyObject_CallMethod(monitoring, "free_tool_id", "i", PROFILER_ID));
|
||||
if (!result) {
|
||||
TORCH_WARN("Failed to call sys.monitoring.free_tool_id.");
|
||||
PyErr_Clear();
|
||||
return;
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
const std::vector<PyThreadState*> PythonTracer::interpreterThreads() const {
|
||||
pybind11::gil_scoped_acquire gil;
|
||||
std::vector<PyThreadState*> out;
|
||||
@ -1009,6 +797,16 @@ PythonTracer::PythonTracer(torch::profiler::impl::RecordQueue* queue)
|
||||
|
||||
for (auto it = current_stack.rbegin(); it != current_stack.rend(); it++) {
|
||||
recordPyCall(thread_local_results_.back(), it->get(), true);
|
||||
PyFrameObject* frame = it->get();
|
||||
PyObject* callable = get_callable_from_frame(frame);
|
||||
if (callable) {
|
||||
// If the frame has a callable, record it as a C call since
|
||||
// PyEval_GetFrame only gets the python frame. We need to record this C
|
||||
// call so that when exiting the profiler we don't have a mismatched C
|
||||
// call.
|
||||
recordCCall(thread_local_results_.back(), it->get(), callable, true);
|
||||
}
|
||||
|
||||
auto frame_refcount = Py_REFCNT(it->get());
|
||||
|
||||
// We hold one reference in `current_stack`, and the interpreter holds
|
||||
@ -1021,9 +819,6 @@ PythonTracer::PythonTracer(torch::profiler::impl::RecordQueue* queue)
|
||||
// cannot be round tripped via `sys.settrace(sys.gettrace())`
|
||||
PyEval_SetProfile(PythonTracer::pyProfileFn, (PyObject*)ctx);
|
||||
}
|
||||
#if IS_PYTHON_3_12
|
||||
registerMonitoringCallback();
|
||||
#endif
|
||||
}
|
||||
|
||||
void PythonTracer::stop() {
|
||||
@ -1036,10 +831,6 @@ void PythonTracer::stop() {
|
||||
}
|
||||
}
|
||||
|
||||
#if IS_PYTHON_3_12
|
||||
unregisterMonitoringCallback();
|
||||
#endif
|
||||
|
||||
auto lock_returned = active_lock_.compare_exchange_strong(active_, false);
|
||||
active_ = false;
|
||||
SOFT_ASSERT(lock_returned, "Failed to return python tracer lock.");
|
||||
@ -1063,9 +854,6 @@ void PythonTracer::restart() {
|
||||
PyEval_SetProfile(PythonTracer::pyProfileFn, (PyObject*)ctx);
|
||||
}
|
||||
}
|
||||
#if IS_PYTHON_3_12
|
||||
registerMonitoringCallback();
|
||||
#endif
|
||||
}
|
||||
|
||||
// NOLINTNEXTLINE(bugprone-exception-escape)
|
||||
@ -1149,6 +937,26 @@ void PythonTracer::recordCCall(
|
||||
queue_->getSubqueue()->emplace_py_call(key, c10::getApproximateTime());
|
||||
}
|
||||
|
||||
PyObject* PythonTracer::get_callable_from_frame(PyFrameObject* frame) {
|
||||
if (frame == nullptr) {
|
||||
return nullptr;
|
||||
}
|
||||
// Get the code object associated with the frame
|
||||
auto code = THPCodeObjectPtr(PyFrame_GetCode(frame));
|
||||
if (code == nullptr) {
|
||||
return nullptr;
|
||||
}
|
||||
// Get the function name (if needed)
|
||||
auto name = THPUtils_unpackStringView(code->co_name).data();
|
||||
// To get the function object, you will need to look in the globals or the
|
||||
// frame's f_globals
|
||||
PyObject* func = PyDict_GetItemString(PyFrame_GetGlobals(frame), name);
|
||||
if (func) {
|
||||
Py_INCREF(func); // Make sure the returned function has a reference
|
||||
}
|
||||
return func; // Returns a PyObject* (the function)
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// == Post processing =========================================================
|
||||
// ============================================================================
|
||||
@ -1250,7 +1058,9 @@ class PostProcess {
|
||||
state.exits_.top().t_ < enter.enter_t_) {
|
||||
auto& exit = state.exits_.top();
|
||||
auto& tstack = stacks[exit.python_tid_];
|
||||
pop(tstack, exit.t_);
|
||||
if (!tstack.empty()) {
|
||||
pop(tstack, exit.t_);
|
||||
}
|
||||
state.exits_.pop();
|
||||
}
|
||||
out.push_back(Result::create(
|
||||
|
@ -5764,16 +5764,3 @@ def recover_orig_fp32_precision(fn):
|
||||
torch.backends.cuda.matmul.fp32_precision = old_cuda_matmul_p
|
||||
|
||||
return recover()(fn)
|
||||
|
||||
def skipIfPythonVersionMismatch(predicate):
|
||||
vi = sys.version_info
|
||||
|
||||
def dec_fn(fn):
|
||||
@wraps(fn)
|
||||
def wrap_fn(self, *args, **kwargs):
|
||||
if predicate(vi.major, vi.minor, vi.micro):
|
||||
return fn(self, *args, **kwargs)
|
||||
else:
|
||||
raise unittest.SkipTest("Python version mismatch")
|
||||
return wrap_fn
|
||||
return dec_fn
|
||||
|
Reference in New Issue
Block a user