[Profiler] Add GC Events to Python Stack Tracer (#161209)

Summary:
Adds Python Garbage Collection to Kineto Traces and Profiler FunctionEvents. Create custom cpp callback in profiler_python.cpp. Then define a python function with cpp and register that callback for all python garbage collection. We don't worry about thread safety in this case because we are only doing init/teardown for main thread while holding GIL.

Currently we are hiding this behind experimental config because python tracing tends to be unstable especially when adding any new feature. If this is found to not add too much overhead we can set this to on by default. NOTE: To enable this you need both with_stack=True and the experimental config on!

Test Plan:
Ran trace with GC induced and saw it on trace

Also added a test

Rollback Plan:

Differential Revision: D80491146

Pull Request resolved: https://github.com/pytorch/pytorch/pull/161209
Approved by: https://github.com/ngimel
This commit is contained in:
Shivam Raikundalia
2025-08-22 22:11:25 +00:00
committed by PyTorch MergeBot
parent c8bb0e4720
commit 3373b074f5
9 changed files with 228 additions and 3 deletions

View File

@ -2336,6 +2336,74 @@ assert KinetoStepTracker.current_step() == initial_step + 2 * niters
events = main_with_thread_fn(profile_all_threads)
verify_events(events)
@skipIfTorchDynamo("profiler gets ignored if dynamo activated")
@unittest.skipIf(not kineto_available(), "Kineto is required")
def test_python_gc_event(self):
activities = [ProfilerActivity.CPU]
def payload():
x = torch.randn(10, 10)
y = torch.randn(10, 10)
with record_function("pre_gc"):
torch.mm(x, y)
gc.collect()
with record_function("post_gc"):
torch.mm(x, y)
def validate_json(prof, gc_collection_on):
with TemporaryFileName(mode="w+") as fname:
prof.export_chrome_trace(fname)
with open(fname) as f:
events = json.load(f)["traceEvents"]
# Find required events
if gc_collection_on:
pre_gc = next(
(e for e in events if e["name"] == "pre_gc"), None
)
post_gc = next(
(e for e in events if e["name"] == "post_gc"), None
)
python_gc_events = [
e for e in events if e["name"] == "Python GC"
]
# Assert all required events are present
self.assertIsNotNone(pre_gc, "pre_gc event is missing")
self.assertIsNotNone(post_gc, "post_gc event is missing")
self.assertTrue(
len(python_gc_events) > 0, "No Python GC events found"
)
# Calculate boundaries
pre_gc_end = pre_gc["ts"] + pre_gc.get("dur", 0)
post_gc_start = post_gc["ts"]
# Assert each Python GC event is correctly placed
for python_gc in python_gc_events:
python_gc_start = python_gc["ts"]
python_gc_end = python_gc["ts"] + python_gc.get("dur", 0)
self.assertTrue(
python_gc_start > pre_gc_end
and python_gc_end < post_gc_start,
f"Python GC event at {python_gc_start} is not correctly placed.",
)
else:
python_gc_events = [
e for e in events if e["name"] == "Python GC"
]
self.assertTrue(
len(python_gc_events) == 0,
"Python GC event found when flag off",
)
for gc_flag in [True, False]:
with profile(
activities=activities,
experimental_config=torch._C._profiler._ExperimentalConfig(
record_python_gc_info=gc_flag
),
with_stack=True,
) as prof:
payload()
validate_json(prof, gc_flag)
class SimpleNet(nn.Module):
def __init__(self) -> None:

View File

@ -704,7 +704,7 @@ class PythonTracer final : public python_tracer::PythonTracerBase {
PyFrameObject* frame,
int what,
PyObject* arg);
void register_gc_callback() override;
void stop() override;
void restart() override;
std::vector<std::shared_ptr<Result>> getEvents(
@ -723,6 +723,8 @@ class PythonTracer final : public python_tracer::PythonTracerBase {
PyFrameObject* frame,
bool is_startup_frame);
static PyObject* gc_event_callback(PyObject* self, PyObject* args);
void recordCCall(
ThreadLocalResults& tls,
PyFrameObject* frame,
@ -733,6 +735,7 @@ class PythonTracer final : public python_tracer::PythonTracerBase {
std::atomic<bool> active_lock_{false};
bool active_{false};
bool gc_callback_registered_{false};
torch::profiler::impl::RecordQueue* queue_;
PyInterpreterState* interpreter_{nullptr};
@ -973,6 +976,27 @@ const std::vector<PyThreadState*> PythonTracer::interpreterThreads() const {
return out;
}
// we are only registering on main thread while holding GIL so this should be
// safe
static PyObject* py_gc_callback = nullptr;
// The C function to be called by Python's GC
PyObject* PythonTracer::gc_event_callback(PyObject* self, PyObject* args) {
const char* phase;
PyObject* info;
if (!PyArg_ParseTuple(args, "sO", &phase, &info)) {
return nullptr;
}
PythonTracer* instance =
reinterpret_cast<PythonTracer*>(PyCapsule_GetPointer(self, nullptr));
if (!instance) {
PyErr_SetString(PyExc_RuntimeError, "Invalid tracer instance");
return nullptr;
}
instance->queue_->getSubqueue()->emplace_gc_call(
phase, c10::getApproximateTime());
Py_RETURN_NONE;
}
PythonTracer::PythonTracer(torch::profiler::impl::RecordQueue* queue)
: queue_(queue),
@ -1045,8 +1069,74 @@ PythonTracer::PythonTracer(torch::profiler::impl::RecordQueue* queue)
#endif
}
void unregister_gc_callback() {
PyGILState_STATE gstate = PyGILState_Ensure();
PyObject* gc_module = PyImport_ImportModule("gc");
if (!gc_module) {
PyErr_Print();
PyGILState_Release(gstate);
return;
}
PyObject* callbacks = PyObject_GetAttrString(gc_module, "callbacks");
if (!callbacks || !PyList_Check(callbacks)) {
PyErr_Print();
Py_XDECREF(gc_module);
Py_XDECREF(callbacks);
PyGILState_Release(gstate);
return;
}
Py_ssize_t idx = PySequence_Index(callbacks, py_gc_callback);
if (idx >= 0) {
PySequence_DelItem(callbacks, idx);
} else {
// Not found, maybe already removed
}
Py_DECREF(callbacks);
Py_DECREF(gc_module);
Py_XDECREF(py_gc_callback);
py_gc_callback = nullptr;
PyGILState_Release(gstate);
}
void PythonTracer::register_gc_callback() {
PyGILState_STATE gstate = PyGILState_Ensure();
PyObject* gc_module = PyImport_ImportModule("gc");
if (!gc_module) {
PyErr_Print();
PyGILState_Release(gstate);
return;
}
PyObject* callbacks = PyObject_GetAttrString(gc_module, "callbacks");
if (!callbacks || !PyList_Check(callbacks)) {
PyErr_Print();
Py_XDECREF(gc_module);
Py_XDECREF(callbacks);
PyGILState_Release(gstate);
return;
}
static PyMethodDef method_def = {
"gc_event_callback",
(PyCFunction)gc_event_callback,
METH_VARARGS,
nullptr};
PyObject* capsule = PyCapsule_New(this, nullptr, nullptr);
py_gc_callback = PyCFunction_New(&method_def, capsule);
Py_DECREF(capsule); // PyCFunction_New increments refcount
if (PyList_Append(callbacks, py_gc_callback) < 0) {
PyErr_Print();
}
gc_callback_registered_ = true;
Py_DECREF(callbacks);
Py_DECREF(gc_module);
PyGILState_Release(gstate);
}
void PythonTracer::stop() {
gil_and_restore_thread gil;
if (gc_callback_registered_) {
unregister_gc_callback();
gc_callback_registered_ = false;
}
if (active_) {
for (const auto thread_state : interpreterThreads()) {
if (thread_state->c_profilefunc == &PythonTracer::pyProfileFn) {

View File

@ -613,6 +613,7 @@ std::string Result::name() const {
ATTRIBUTE(OutOfMemory, std::string("[OutOfMemory]")),
ATTRIBUTE(PyCall, toString(e)),
ATTRIBUTE(PyCCall, std::string(e.function_name_.str())),
ATTRIBUTE(PythonGC, std::string("Python GC")),
[](const auto& e) -> std::string { return e.name_; }));
}
@ -631,6 +632,7 @@ libkineto::ActivityType Result::kinetoType() const {
ATTRIBUTE(OutOfMemory, libkineto::ActivityType::CPU_INSTANT_EVENT),
ATTRIBUTE(PyCall, libkineto::ActivityType::PYTHON_FUNCTION),
ATTRIBUTE(PyCCall, libkineto::ActivityType::PYTHON_FUNCTION),
ATTRIBUTE(PythonGC, libkineto::ActivityType::PYTHON_FUNCTION),
ATTRIBUTE(Kineto, e.activity_type_)));
}
@ -650,6 +652,7 @@ int64_t Result::endTimeNS() const {
ATTRIBUTE(Allocation, start_time_ns_),
ATTRIBUTE(OutOfMemory, start_time_ns_),
ATTRIBUTE(Kineto, start_time_ns_ + e.duration_ns_),
ATTRIBUTE(PythonGC, start_time_ns_ + e.duration_ns_),
[&](const auto& e) -> int64_t { return e.end_time_ns_; }));
// In rare cases we're willing to tolerate ops which are missing an end time
@ -700,6 +703,9 @@ RecordQueue::RecordQueue(
activities_{std::move(activities)} {
if (tracePython()) {
python_tracer_ = python_tracer::PythonTracerBase::make(this);
if (getPythonGcEvents()) {
python_tracer_->register_gc_callback();
}
}
}
@ -707,6 +713,10 @@ bool RecordQueue::tracePython() const {
return config_.with_stack && activities_.count(ActivityType::CPU);
}
bool RecordQueue::getPythonGcEvents() const {
return config_.experimental_config.record_python_gc_info;
}
ThreadLocalSubqueue* RecordQueue::getSubqueue() {
// In the most common case, a thread will want to write to the same sub-queue
// that it wrote to last call. The only time that isn't true is if:
@ -1488,6 +1498,31 @@ RecordQueue::getRecords(
queue.allocations_.clear();
materialize(queue.ooms_);
std::optional<int64_t> pending_start;
for (auto& e : queue.pythongc_) {
if (e.first.find("start") != std::string::npos) {
pending_start = e.second;
} else if (e.first.find("stop") != std::string::npos) {
if (pending_start.has_value()) {
out.emplace_back(Result::create(
/*start_time_ns_=*/converter(pending_start.value()),
/*start_tid_=*/queue.tid(),
/*kineto_info_=*/queue.kineto_info(),
/*extra_fields_=*/
// NOLINTNEXTLINE
ExtraFields<EventType::PythonGC>{
e.first,
converter(e.second) - converter(pending_start.value())}));
pending_start.reset();
} else {
// Handle the case where "stop" is found without a matching "start"
// For example, you might want to log a warning or take other action:
LOG(WARNING) << R"("stop" event found without a matching "start": )"
<< e.first;
}
}
}
for (auto& i : queue.py_calls_) {
python_enters.push_back(
{i.first, queue.tid(), queue.kineto_info(), converter(i.second)});

View File

@ -34,7 +34,8 @@ enum class EventType : uint8_t {
OutOfMemory,
PyCall,
PyCCall,
Kineto
Kineto,
PythonGC
};
// ============================================================================
@ -191,6 +192,12 @@ struct ExtraFields<EventType::Backend> {
jit_modules_t jit_modules_;
};
template <>
struct ExtraFields<EventType::PythonGC> {
std::string phase;
int64_t duration_ns_;
};
template <>
struct ExtraFields<EventType::Vulkan> {
using raw_event_t = std::pair<c10::approx_time_t, vulkan_id_t>;
@ -415,7 +422,8 @@ struct TORCH_API Result : public std::enable_shared_from_this<Result> {
ExtraFields<EventType::OutOfMemory>,
ExtraFields<EventType::PyCall>,
ExtraFields<EventType::PyCCall>,
ExtraFields<EventType::Kineto>>
ExtraFields<EventType::Kineto>,
ExtraFields<EventType::PythonGC>>
extra_fields_;
std::weak_ptr<Result> parent_;
@ -549,6 +557,11 @@ class TORCH_API ThreadLocalSubqueue {
py_calls_.emplace_back(std::forward<Args>(args)...);
}
template <class... Args>
void emplace_gc_call(Args&&... args) {
pythongc_.emplace_back(std::forward<Args>(args)...);
}
uint64_t tid() const {
return tid_;
}
@ -639,6 +652,9 @@ class TORCH_API ThreadLocalSubqueue {
std::pair<python_tracer::TraceKey, c10::approx_time_t>,
BlockSize>
py_calls_;
// gc with_stack (Python)
AppendOnlyList<std::pair<std::string, c10::approx_time_t>, BlockSize>
pythongc_;
};
class TORCH_API RecordQueue {
@ -646,6 +662,7 @@ class TORCH_API RecordQueue {
RecordQueue(ProfilerConfig config, std::set<ActivityType> activities);
bool tracePython() const;
bool getPythonGcEvents() const;
ThreadLocalSubqueue* getSubqueue();
void stop();
void restart();

View File

@ -21,6 +21,7 @@ ExperimentalConfig::ExperimentalConfig(
bool disable_external_correlation,
bool profile_all_threads,
bool capture_overload_names,
bool record_python_gc_info,
std::string custom_profiler_config,
bool adjust_timestamps)
: profiler_metrics{std::move(profiler_metrics)},
@ -32,6 +33,7 @@ ExperimentalConfig::ExperimentalConfig(
disable_external_correlation{disable_external_correlation},
profile_all_threads{profile_all_threads},
capture_overload_names{capture_overload_names},
record_python_gc_info{record_python_gc_info},
custom_profiler_config(std::move(custom_profiler_config)),
adjust_timestamps{adjust_timestamps} {}

View File

@ -62,6 +62,7 @@ struct TORCH_API ExperimentalConfig {
bool disable_external_correlation = false,
bool profile_all_threads = false,
bool capture_overload_names = false,
bool record_python_gc_info = false,
std::string custom_profiler_config = "",
bool adjust_timestamps = false);
explicit operator bool() const;
@ -102,6 +103,12 @@ struct TORCH_API ExperimentalConfig {
* function schema and stored in the profile */
bool capture_overload_names;
/*
* Controls whether or not python gc info is recorded. This is used to
* determine if gc collect is slowing down your profile.
*/
bool record_python_gc_info;
/*
* A custom_profiler_config option is introduced to allow custom backends
* to apply custom configurations as needed.

View File

@ -11,6 +11,7 @@ struct NoOpPythonTracer : public PythonTracerBase {
void stop() override {}
void restart() override {}
void register_gc_callback() override {}
std::vector<std::shared_ptr<Result>> getEvents(
std::function<c10::time_t(c10::approx_time_t)>,
std::vector<CompressedEvent>&,

View File

@ -48,6 +48,7 @@ struct TORCH_API PythonTracerBase {
virtual void stop() = 0;
virtual void restart() = 0;
virtual void register_gc_callback() = 0;
virtual std::vector<std::shared_ptr<Result>> getEvents(
std::function<c10::time_t(c10::approx_time_t)> time_converter,
std::vector<CompressedEvent>& enters,

View File

@ -341,6 +341,7 @@ void initPythonBindings(PyObject* module) {
bool /* disable_external_correlation*/,
bool /* profile_all_threads */,
bool /* capture_overload_names */,
bool /* record_python_gc_info */,
std::string /* custom_profiler_config*/
>(),
"An experimental config for Kineto features. Please note that"
@ -360,6 +361,7 @@ void initPythonBindings(PyObject* module) {
" disable_external_correlation (bool) : whether to disable external correlation\n"
" profile_all_threads (bool) : whether to profile all threads\n"
" capture_overload_names (bool) : whether to include ATen overload names in the profile\n"
" record_python_gc_info (bool) : adds python gc events to profile\n"
" custom_profiler_config (string) : Used to pass some configurations to the custom profiler backend.\n",
py::arg("profiler_metrics") = std::vector<std::string>(),
py::arg("profiler_measure_per_kernel") = false,
@ -370,6 +372,7 @@ void initPythonBindings(PyObject* module) {
py::arg("disable_external_correlation") = false,
py::arg("profile_all_threads") = false,
py::arg("capture_overload_names") = false,
py::arg("record_python_gc_info") = false,
py::arg("custom_profiler_config") = "")
.def(py::pickle(
[](const ExperimentalConfig& p) { // __getstate__
@ -393,6 +396,7 @@ void initPythonBindings(PyObject* module) {
p.disable_external_correlation,
p.profile_all_threads,
p.capture_overload_names,
p.record_python_gc_info,
p.custom_profiler_config,
p.performance_events);
},