Revert "[Profiler] Fix lost C call events problem in Python 3.12.0-3.12.4 (#155446)"

This reverts commit da94023b0205bf98c3da366f2f86e0a443f4db17.

Reverted https://github.com/pytorch/pytorch/pull/155446 on behalf of https://github.com/ZainRizvi due to Sorry but this is breaking internally. @sraikund16 can you please help validate the fix? (See D78845227 for details). You can follow the instructions here: https://fburl.com/fixing-ghfirst-reverts ([comment](https://github.com/pytorch/pytorch/pull/155446#issuecomment-3115072504))
This commit is contained in:
PyTorch MergeBot
2025-07-24 21:46:00 +00:00
parent e20736bf1d
commit b533f12120
3 changed files with 35 additions and 306 deletions

View File

@ -1,68 +0,0 @@
# Owner(s): ["oncall: profiler"]
import json
import sys
import time
from torch.profiler import profile, ProfilerActivity
from torch.testing._internal.common_utils import (
run_tests,
skipIfPythonVersionMismatch,
TemporaryFileName,
TestCase,
)
class TestPythonTracer(TestCase):
    """Checks of the profiler's Python tracer that only run on Python 3.12."""

    @skipIfPythonVersionMismatch(lambda major, minor, micro: major == 3 and minor == 12)
    def test_method_with_c_function(self):
        """Profile ``sorted`` with a key that calls a C function through a method.

        The key function invokes a classmethod wrapping the builtin ``repr`` and
        then sleeps 1.2 s, once per element. The test passes when the exported
        chrome trace contains a ``python_function`` event named
        ``<built-in function sorted>`` whose ``dur`` is at least 2000000
        (presumably microseconds, i.e. the two sleeps -- confirm against the
        trace format).
        """

        class A:
            # A bound method whose underlying function is implemented in C.
            method_with_c_function = classmethod(repr)

        def get_key(x):
            A().method_with_c_function()
            time.sleep(1.2)
            return len(x)

        names = ["Alice", "Bob"]
        with profile(
            activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], with_stack=True
        ) as prof:
            sorted(names, key=get_key)

        with TemporaryFileName(mode="w+") as fname:
            prof.export_chrome_trace(fname)
            with open(fname) as f:
                events = json.load(f)["traceEvents"]

        found = False
        for event in events:
            if (
                event.get("cat", "") == "python_function"
                and event.get("name", "") == "<built-in function sorted>"
            ):
                duration = event.get("dur", 0)
                if duration >= 2000000:
                    found = True
                    break
        self.assertTrue(found)

    @skipIfPythonVersionMismatch(lambda major, minor, micro: major == 3 and minor == 12)
    def test_monitoring_callback(self):
        """Check who owns ``sys.monitoring`` tool id 2 during and after profiling.

        While the profiler is active the tool name must be "PyTorch Profiler"
        when the interpreter's micro version is below 5, and ``None`` otherwise.
        After the profiler exits the tool id must be unclaimed again.
        """
        vi = sys.version_info
        from sys import monitoring

        with profile(
            activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], with_stack=True
        ):
            name = monitoring.get_tool(2)
            if vi.micro < 5:
                self.assertEqual(name, "PyTorch Profiler")
            else:
                self.assertEqual(name, None)

        name = monitoring.get_tool(2)
        self.assertEqual(name, None)
# Run the test suite when this file is executed directly.
if __name__ == "__main__":
    run_tests()

View File

@ -679,17 +679,6 @@ struct ThreadLocalResults {
// ============================================================================
// == Tracing implementation ==================================================
// ============================================================================
// Non-zero when the Python headers this file is compiled against report
// version 3.12.x. Guards the sys.monitoring based workaround below.
#define IS_PYTHON_3_12 (PY_MAJOR_VERSION == 3 && PY_MINOR_VERSION == 12)
#if IS_PYTHON_3_12
// forward declarations
// Both names are needed before PythonTracer is defined because the class
// declares c_call_callback as a friend.
struct _PyEventHandler;
static PyObject* c_call_callback(
_PyEventHandler* self,
PyObject* const* args,
size_t nargsf,
PyObject* kwnames);
#endif
class PythonTracer final : public python_tracer::PythonTracerBase {
public:
PythonTracer(torch::profiler::impl::RecordQueue* queue);
@ -728,6 +717,8 @@ class PythonTracer final : public python_tracer::PythonTracerBase {
const std::vector<PyThreadState*> interpreterThreads() const;
PyObject* get_callable_from_frame(PyFrameObject* frame);
std::atomic<bool> active_lock_{false};
bool active_{false};
@ -739,211 +730,8 @@ class PythonTracer final : public python_tracer::PythonTracerBase {
std::vector<StartFrame> start_frames_;
std::deque<ThreadLocalResults> thread_local_results_;
ValueCache value_cache_;
#if IS_PYTHON_3_12
friend PyObject* c_call_callback(
_PyEventHandler* self,
PyObject* const* args,
size_t nargsf,
PyObject* kwnames);
#endif
};
#if IS_PYTHON_3_12
// Tool id handed to every sys.monitoring call made below (use_tool_id,
// get_tool, register_callback, set_events, free_tool_id).
// NOTE(review): presumably mirrors sys.monitoring.PROFILER_ID -- confirm.
#define PROFILER_ID 2
// Bit position of the CALL event; always used as
// `1 << PY_MONITORING_EVENT_CALL` to build the event mask.
// NOTE(review): value is hard-coded rather than taken from a CPython header --
// confirm it matches the interpreter's definition.
#define PY_MONITORING_EVENT_CALL 4
// Reports whether the running interpreter needs the sys.monitoring workaround.
//
// The decision is made once, on first call, by inspecting the string returned
// by Py_GetVersion(): character 5 is taken as the first digit of the micro
// version (this code is only compiled when the headers are 3.12, so the string
// is expected to start with "3.12."). The answer is true when that digit is
// '0', or when it is at most '4' and is immediately followed by a space, i.e.
// the micro version is a single digit.
static bool should_compensate_c_call_events() {
  static const bool needs_workaround = [] {
    const char* runtime_version = Py_GetVersion();
    const char micro_digit = runtime_version[5];
    if (micro_digit == '0') {
      return true;
    }
    // A following space means the micro version has exactly one digit, which
    // keeps two-digit versions such as "3.12.10" out of the affected range.
    return micro_digit <= '4' && runtime_version[6] == ' ';
  }();
  return needs_workaround;
}
// Instance layout of the callable registered with sys.monitoring: a plain
// Python object header followed by a vectorcall slot. The type object locates
// the slot through tp_vectorcall_offset, and registerMonitoringCallback()
// stores c_call_callback in it.
struct _PyEventHandler {
PyObject_HEAD
// Function invoked when an instance is called.
vectorcallfunc vectorcall;
};
static PyTypeObject _PyEventHandler_Type = {
PyVarObject_HEAD_INIT(&PyType_Type, 0)
"torch.profiler.python_tracer_event_handler",
sizeof(_PyEventHandler),
.tp_dealloc = (destructor)PyObject_Free,
.tp_vectorcall_offset = offsetof(_PyEventHandler, vectorcall),
.tp_call = PyVectorcall_Call,
.tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE |
Py_TPFLAGS_HAVE_VECTORCALL | Py_TPFLAGS_DISALLOW_INSTANTIATION,
};
// Callback registered with sys.monitoring for the CALL event.
//
// Records a C-call event with the active PythonTracer when the object being
// called is a bound method whose underlying function is a C function. In every
// other non-error case it does nothing and returns None.
//
// The logic follows sys_profile_call_or_return in
// https://github.com/python/cpython/blob/v3.12.5/Python/legacy_tracing.c
// On 3.12.0-3.12.4 the call event of a method wrapping a C function is
// missing, see
// https://github.com/python/cpython/commit/257c413cd16ddabcedde413288d0bb93bf872da7
// All other cases have already been handled by legacy tracing, so only this
// one is compensated here. The error branches behave the same as CPython's.
//
// Returns a new reference to None, or nullptr on error.
static PyObject* c_call_callback(
    _PyEventHandler* self,
    PyObject* const* args,
    size_t nargsf,
    PyObject* kwnames) {
  PyThreadState* thread_state = PyThreadState_GET();
  if (thread_state->c_profilefunc != PythonTracer::pyProfileFn) {
    // Some other profile function (or none) is installed on this thread, so
    // the event is not ours to record.
    Py_RETURN_NONE;
  }

  // NOTE(review): index 2 is presumably the `callable` argument of the CALL
  // event callback signature -- confirm against the sys.monitoring docs.
  PyObject* target = args[2];
  if (Py_TYPE(target) != &PyMethod_Type) {
    Py_RETURN_NONE;
  }

  PyObject* underlying = PyMethod_GET_FUNCTION(target);
  if (underlying == nullptr) {
    return nullptr;
  }
  if (!PyCFunction_Check(underlying)) {
    Py_RETURN_NONE;
  }

  PyFrameObject* current_frame = PyEval_GetFrame();
  if (current_frame == nullptr) {
    PyErr_SetString(
        PyExc_SystemError, "Missing frame when calling profile function.");
    return nullptr;
  }

  // Hold the frame for the duration of the recording call.
  Py_INCREF(current_frame);
  auto* trace_context =
      reinterpret_cast<TraceContext*>(thread_state->c_profileobj);
  auto& results = *trace_context->thread_local_results_;
  results.active_tracer_->recordCCall(results, current_frame, underlying);
  Py_DECREF(current_frame);

  Py_RETURN_NONE;
}
static void registerMonitoringCallback() {
if (!should_compensate_c_call_events()) {
return;
}
auto sys_module = THPObjectPtr(PyImport_ImportModule("sys"));
if (!sys_module) {
TORCH_WARN("Failed to import sys module.");
PyErr_Clear();
return;
}
auto monitoring =
THPObjectPtr(PyObject_GetAttrString(sys_module, "monitoring"));
if (!monitoring) {
TORCH_WARN("Failed to get monitoring from sys module.");
PyErr_Clear();
return;
}
auto result = THPObjectPtr(PyObject_CallMethod(
monitoring, "use_tool_id", "is", PROFILER_ID, "PyTorch Profiler"));
if (!result) {
TORCH_WARN("Failed to call sys.monitoring.use_tool_id");
PyErr_Clear();
return;
}
auto handler = THPObjectPtr(PyObject_NEW(PyObject, &_PyEventHandler_Type));
if (!handler) {
TORCH_WARN("Failed to create _PyEventHandler object.");
PyErr_Clear();
return;
}
reinterpret_cast<_PyEventHandler*>(handler.get())->vectorcall =
(vectorcallfunc)c_call_callback;
result = THPObjectPtr(PyObject_CallMethod(
monitoring,
"register_callback",
"iiO",
PROFILER_ID,
1 << PY_MONITORING_EVENT_CALL,
handler.get()));
if (!result) {
TORCH_WARN("Failed to call sys.monitoring.register_callback.");
PyErr_Clear();
return;
}
result = THPObjectPtr(PyObject_CallMethod(
monitoring,
"set_events",
"ii",
PROFILER_ID,
1 << PY_MONITORING_EVENT_CALL));
if (!result) {
TORCH_WARN("Failed to call sys.monitoring.set_events.");
PyErr_Clear();
return;
}
}
// Removes the sys.monitoring hook installed by registerMonitoringCallback().
//
// Does nothing unless should_compensate_c_call_events() is true and tool id
// PROFILER_ID is currently registered under the name "PyTorch Profiler".
// In that case it replaces the CALL-event callback with None, disables all
// events for the tool and releases the tool id. Each step that fails emits a
// warning, clears the pending Python error and aborts the remaining steps.
static void unregisterMonitoringCallback() {
  if (!should_compensate_c_call_events()) {
    return;
  }
  auto sys_module = THPObjectPtr(PyImport_ImportModule("sys"));
  if (!sys_module) {
    TORCH_WARN("Failed to import sys module.");
    PyErr_Clear();
    return;
  }
  auto monitoring =
      THPObjectPtr(PyObject_GetAttrString(sys_module, "monitoring"));
  if (!monitoring) {
    TORCH_WARN("Failed to get monitoring from sys module.");
    PyErr_Clear();
    return;
  }
  auto tool_name = THPObjectPtr(
      PyObject_CallMethod(monitoring, "get_tool", "i", PROFILER_ID));
  if (!tool_name) {
    // Name the call that actually failed.
    TORCH_WARN("Failed to call sys.monitoring.get_tool.");
    PyErr_Clear();
    return;
  }
  // A non-string result means the tool id is not claimed under a name, and a
  // different name means it belongs to someone else; in both cases there is
  // nothing of ours to undo.
  if (!THPUtils_checkString(tool_name)) {
    return;
  }
  if (THPUtils_unpackStringView(tool_name) != "PyTorch Profiler") {
    return;
  }
  // The "O" format borrows its argument, so Py_None can be passed directly.
  auto result = THPObjectPtr(PyObject_CallMethod(
      monitoring,
      "register_callback",
      "iiO",
      PROFILER_ID,
      1 << PY_MONITORING_EVENT_CALL,
      Py_None));
  if (!result) {
    TORCH_WARN("Failed to call sys.monitoring.register_callback.");
    PyErr_Clear();
    return;
  }
  result = THPObjectPtr(
      PyObject_CallMethod(monitoring, "set_events", "ii", PROFILER_ID, 0));
  if (!result) {
    TORCH_WARN("Failed to call sys.monitoring.set_events.");
    PyErr_Clear();
    return;
  }
  result = THPObjectPtr(
      PyObject_CallMethod(monitoring, "free_tool_id", "i", PROFILER_ID));
  if (!result) {
    TORCH_WARN("Failed to call sys.monitoring.free_tool_id.");
    PyErr_Clear();
    return;
  }
}
#endif
const std::vector<PyThreadState*> PythonTracer::interpreterThreads() const {
pybind11::gil_scoped_acquire gil;
std::vector<PyThreadState*> out;
@ -1009,6 +797,16 @@ PythonTracer::PythonTracer(torch::profiler::impl::RecordQueue* queue)
for (auto it = current_stack.rbegin(); it != current_stack.rend(); it++) {
recordPyCall(thread_local_results_.back(), it->get(), true);
PyFrameObject* frame = it->get();
PyObject* callable = get_callable_from_frame(frame);
if (callable) {
// If the frame has a callable, record it as a C call since
// PyEval_GetFrame only gets the python frame. We need to record this C
// call so that when exiting the profiler we don't have a mismatched C
// call.
recordCCall(thread_local_results_.back(), it->get(), callable, true);
}
auto frame_refcount = Py_REFCNT(it->get());
// We hold one reference in `current_stack`, and the interpreter holds
@ -1021,9 +819,6 @@ PythonTracer::PythonTracer(torch::profiler::impl::RecordQueue* queue)
// cannot be round tripped via `sys.settrace(sys.gettrace())`
PyEval_SetProfile(PythonTracer::pyProfileFn, (PyObject*)ctx);
}
#if IS_PYTHON_3_12
registerMonitoringCallback();
#endif
}
void PythonTracer::stop() {
@ -1036,10 +831,6 @@ void PythonTracer::stop() {
}
}
#if IS_PYTHON_3_12
unregisterMonitoringCallback();
#endif
auto lock_returned = active_lock_.compare_exchange_strong(active_, false);
active_ = false;
SOFT_ASSERT(lock_returned, "Failed to return python tracer lock.");
@ -1063,9 +854,6 @@ void PythonTracer::restart() {
PyEval_SetProfile(PythonTracer::pyProfileFn, (PyObject*)ctx);
}
}
#if IS_PYTHON_3_12
registerMonitoringCallback();
#endif
}
// NOLINTNEXTLINE(bugprone-exception-escape)
@ -1149,6 +937,26 @@ void PythonTracer::recordCCall(
queue_->getSubqueue()->emplace_py_call(key, c10::getApproximateTime());
}
// Looks up, in the frame's globals, the object bound to the name of the code
// object the frame is executing.
//
// Returns a NEW reference that the caller must release, or nullptr when
// `frame` is null, its code object or globals cannot be obtained, or the
// globals contain no entry under that name.
//
// NOTE(review): the lookup is by name only, so it finds nothing for code whose
// name is not a module-level global and may return an unrelated object that
// happens to share the name -- confirm this is acceptable for callers.
PyObject* PythonTracer::get_callable_from_frame(PyFrameObject* frame) {
  if (frame == nullptr) {
    return nullptr;
  }
  // Get the code object associated with the frame
  auto code = THPCodeObjectPtr(PyFrame_GetCode(frame));
  if (code == nullptr) {
    return nullptr;
  }
  // PyFrame_GetGlobals returns a strong reference; own it so it is released
  // on every exit path instead of leaking once per call.
  auto globals = THPObjectPtr(PyFrame_GetGlobals(frame));
  if (!globals) {
    // Do not leave a pending exception behind for a best-effort lookup.
    PyErr_Clear();
    return nullptr;
  }
  auto name = THPUtils_unpackStringView(code->co_name).data();
  // PyDict_GetItemString returns a borrowed reference (or nullptr if absent).
  PyObject* func = PyDict_GetItemString(globals.get(), name);
  if (func) {
    Py_INCREF(func); // Make sure the returned function has a reference
  }
  return func; // Returns a PyObject* (the function)
}
// ============================================================================
// == Post processing =========================================================
// ============================================================================
@ -1250,7 +1058,9 @@ class PostProcess {
state.exits_.top().t_ < enter.enter_t_) {
auto& exit = state.exits_.top();
auto& tstack = stacks[exit.python_tid_];
pop(tstack, exit.t_);
if (!tstack.empty()) {
pop(tstack, exit.t_);
}
state.exits_.pop();
}
out.push_back(Result::create(

View File

@ -5764,16 +5764,3 @@ def recover_orig_fp32_precision(fn):
torch.backends.cuda.matmul.fp32_precision = old_cuda_matmul_p
return recover()(fn)
def skipIfPythonVersionMismatch(predicate):
    """Build a decorator that skips a test method on non-matching Python versions.

    Args:
        predicate: callable taking ``(major, minor, micro)`` of the running
            interpreter and returning a truthy value when the test should run.

    Returns:
        A decorator for test methods. The wrapped method runs normally when
        ``predicate`` accepts the interpreter version and raises
        ``unittest.SkipTest("Python version mismatch")`` otherwise. The
        predicate is evaluated each time the wrapped method is called.
    """
    vi = sys.version_info

    def dec_fn(fn):
        @wraps(fn)
        def wrap_fn(self, *args, **kwargs):
            if predicate(vi.major, vi.minor, vi.micro):
                return fn(self, *args, **kwargs)
            else:
                raise unittest.SkipTest("Python version mismatch")

        return wrap_fn

    return dec_fn