Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 12:54:11 +08:00)
[Profiler] Add GC Events to Python Stack Tracer (#161209)
Summary: Adds Python garbage-collection events to Kineto traces and profiler FunctionEvents. A custom C++ callback is created in profiler_python.cpp, wrapped in a Python-callable function, and registered for all Python garbage collections via gc.callbacks. Thread safety is not a concern here because init/teardown happens only on the main thread while holding the GIL. The feature is currently hidden behind an experimental config flag, since Python tracing tends to be unstable when new features are added; if the overhead proves acceptable, it can be enabled by default. NOTE: enabling this requires both with_stack=True and the experimental config flag.

Test Plan: Ran a trace with an induced GC and saw it in the trace. Also added a test.

Rollback Plan:

Differential Revision: D80491146

Pull Request resolved: https://github.com/pytorch/pytorch/pull/161209
Approved by: https://github.com/ngimel
Committed by: PyTorch MergeBot
Parent: c8bb0e4720
Commit: 3373b074f5
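
Usage note (not part of the diff): a minimal sketch of how to turn the new events on, following the test added below. It assumes the record_python_gc_info flag introduced by this PR; both with_stack=True and the experimental config are required, otherwise no "Python GC" events are emitted.

import gc
import torch
from torch.profiler import profile, ProfilerActivity

# Both with_stack=True and record_python_gc_info=True are needed for GC events.
config = torch._C._profiler._ExperimentalConfig(record_python_gc_info=True)
with profile(
    activities=[ProfilerActivity.CPU],
    with_stack=True,
    experimental_config=config,
) as prof:
    x = torch.randn(10, 10)
    torch.mm(x, x)
    gc.collect()  # induce a collection so a "Python GC" span is recorded

prof.export_chrome_trace("trace_with_gc.json")  # output path is a placeholder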
@@ -2336,6 +2336,74 @@ assert KinetoStepTracker.current_step() == initial_step + 2 * niters
        events = main_with_thread_fn(profile_all_threads)
        verify_events(events)

    @skipIfTorchDynamo("profiler gets ignored if dynamo activated")
    @unittest.skipIf(not kineto_available(), "Kineto is required")
    def test_python_gc_event(self):
        activities = [ProfilerActivity.CPU]

        def payload():
            x = torch.randn(10, 10)
            y = torch.randn(10, 10)
            with record_function("pre_gc"):
                torch.mm(x, y)
            gc.collect()
            with record_function("post_gc"):
                torch.mm(x, y)

        def validate_json(prof, gc_collection_on):
            with TemporaryFileName(mode="w+") as fname:
                prof.export_chrome_trace(fname)
                with open(fname) as f:
                    events = json.load(f)["traceEvents"]
                # Find required events
                if gc_collection_on:
                    pre_gc = next(
                        (e for e in events if e["name"] == "pre_gc"), None
                    )
                    post_gc = next(
                        (e for e in events if e["name"] == "post_gc"), None
                    )
                    python_gc_events = [
                        e for e in events if e["name"] == "Python GC"
                    ]
                    # Assert all required events are present
                    self.assertIsNotNone(pre_gc, "pre_gc event is missing")
                    self.assertIsNotNone(post_gc, "post_gc event is missing")
                    self.assertTrue(
                        len(python_gc_events) > 0, "No Python GC events found"
                    )
                    # Calculate boundaries
                    pre_gc_end = pre_gc["ts"] + pre_gc.get("dur", 0)
                    post_gc_start = post_gc["ts"]
                    # Assert each Python GC event is correctly placed
                    for python_gc in python_gc_events:
                        python_gc_start = python_gc["ts"]
                        python_gc_end = python_gc["ts"] + python_gc.get("dur", 0)
                        self.assertTrue(
                            python_gc_start > pre_gc_end
                            and python_gc_end < post_gc_start,
                            f"Python GC event at {python_gc_start} is not correctly placed.",
                        )
                else:
                    python_gc_events = [
                        e for e in events if e["name"] == "Python GC"
                    ]
                    self.assertTrue(
                        len(python_gc_events) == 0,
                        "Python GC event found when flag off",
                    )

        for gc_flag in [True, False]:
            with profile(
                activities=activities,
                experimental_config=torch._C._profiler._ExperimentalConfig(
                    record_python_gc_info=gc_flag
                ),
                with_stack=True,
            ) as prof:
                payload()
            validate_json(prof, gc_flag)


class SimpleNet(nn.Module):
    def __init__(self) -> None:
@@ -704,7 +704,7 @@ class PythonTracer final : public python_tracer::PythonTracerBase {
      PyFrameObject* frame,
      int what,
      PyObject* arg);

  void register_gc_callback() override;
  void stop() override;
  void restart() override;
  std::vector<std::shared_ptr<Result>> getEvents(
@@ -723,6 +723,8 @@ class PythonTracer final : public python_tracer::PythonTracerBase {
      PyFrameObject* frame,
      bool is_startup_frame);

  static PyObject* gc_event_callback(PyObject* self, PyObject* args);

  void recordCCall(
      ThreadLocalResults& tls,
      PyFrameObject* frame,
@@ -733,6 +735,7 @@ class PythonTracer final : public python_tracer::PythonTracerBase {

  std::atomic<bool> active_lock_{false};
  bool active_{false};
  bool gc_callback_registered_{false};

  torch::profiler::impl::RecordQueue* queue_;
  PyInterpreterState* interpreter_{nullptr};
@@ -973,6 +976,27 @@ const std::vector<PyThreadState*> PythonTracer::interpreterThreads() const {
  return out;
}

// we are only registering on main thread while holding GIL so this should be
// safe
static PyObject* py_gc_callback = nullptr;
// The C function to be called by Python's GC
PyObject* PythonTracer::gc_event_callback(PyObject* self, PyObject* args) {
  const char* phase;
  PyObject* info;
  if (!PyArg_ParseTuple(args, "sO", &phase, &info)) {
    return nullptr;
  }
  PythonTracer* instance =
      reinterpret_cast<PythonTracer*>(PyCapsule_GetPointer(self, nullptr));
  if (!instance) {
    PyErr_SetString(PyExc_RuntimeError, "Invalid tracer instance");
    return nullptr;
  }
  instance->queue_->getSubqueue()->emplace_gc_call(
      phase, c10::getApproximateTime());
  Py_RETURN_NONE;
}

PythonTracer::PythonTracer(torch::profiler::impl::RecordQueue* queue)
    : queue_(queue),

@@ -1045,8 +1069,74 @@ PythonTracer::PythonTracer(torch::profiler::impl::RecordQueue* queue)
#endif
}

void unregister_gc_callback() {
  PyGILState_STATE gstate = PyGILState_Ensure();
  PyObject* gc_module = PyImport_ImportModule("gc");
  if (!gc_module) {
    PyErr_Print();
    PyGILState_Release(gstate);
    return;
  }
  PyObject* callbacks = PyObject_GetAttrString(gc_module, "callbacks");
  if (!callbacks || !PyList_Check(callbacks)) {
    PyErr_Print();
    Py_XDECREF(gc_module);
    Py_XDECREF(callbacks);
    PyGILState_Release(gstate);
    return;
  }
  Py_ssize_t idx = PySequence_Index(callbacks, py_gc_callback);
  if (idx >= 0) {
    PySequence_DelItem(callbacks, idx);
  } else {
    // Not found, maybe already removed
  }
  Py_DECREF(callbacks);
  Py_DECREF(gc_module);
  Py_XDECREF(py_gc_callback);
  py_gc_callback = nullptr;
  PyGILState_Release(gstate);
}

void PythonTracer::register_gc_callback() {
  PyGILState_STATE gstate = PyGILState_Ensure();
  PyObject* gc_module = PyImport_ImportModule("gc");
  if (!gc_module) {
    PyErr_Print();
    PyGILState_Release(gstate);
    return;
  }
  PyObject* callbacks = PyObject_GetAttrString(gc_module, "callbacks");
  if (!callbacks || !PyList_Check(callbacks)) {
    PyErr_Print();
    Py_XDECREF(gc_module);
    Py_XDECREF(callbacks);
    PyGILState_Release(gstate);
    return;
  }
  static PyMethodDef method_def = {
      "gc_event_callback",
      (PyCFunction)gc_event_callback,
      METH_VARARGS,
      nullptr};
  PyObject* capsule = PyCapsule_New(this, nullptr, nullptr);
  py_gc_callback = PyCFunction_New(&method_def, capsule);
  Py_DECREF(capsule); // PyCFunction_New increments refcount
  if (PyList_Append(callbacks, py_gc_callback) < 0) {
    PyErr_Print();
  }
  gc_callback_registered_ = true;
  Py_DECREF(callbacks);
  Py_DECREF(gc_module);
  PyGILState_Release(gstate);
}

void PythonTracer::stop() {
  gil_and_restore_thread gil;
  if (gc_callback_registered_) {
    unregister_gc_callback();
    gc_callback_registered_ = false;
  }
  if (active_) {
    for (const auto thread_state : interpreterThreads()) {
      if (thread_state->c_profilefunc == &PythonTracer::pyProfileFn) {
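Aside (not part of the diff): register_gc_callback() above hooks into CPython's gc.callbacks list, which calls every registered callable with phase "start" before a collection and "stop" after it; the RecordQueue changes further down pair those two phases into a single "Python GC" span. A pure-Python sketch of the same mechanism, with hypothetical names, for reference:

import gc
import time

_gc_spans = []          # hypothetical sink, stands in for the profiler's subqueue
_pending_start = None

def _gc_event_callback(phase, info):
    # CPython passes phase="start"/"stop" plus an info dict (generation, etc.).
    global _pending_start
    if phase == "start":
        _pending_start = time.perf_counter_ns()
    elif phase == "stop" and _pending_start is not None:
        _gc_spans.append((_pending_start, time.perf_counter_ns(), dict(info)))
        _pending_start = None

gc.callbacks.append(_gc_event_callback)   # what register_gc_callback() does in C++
gc.collect()                               # fires the start/stop callbacks
gc.callbacks.remove(_gc_event_callback)    # what unregister_gc_callback() does
print(_gc_spans)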
@@ -613,6 +613,7 @@ std::string Result::name() const {
      ATTRIBUTE(OutOfMemory, std::string("[OutOfMemory]")),
      ATTRIBUTE(PyCall, toString(e)),
      ATTRIBUTE(PyCCall, std::string(e.function_name_.str())),
      ATTRIBUTE(PythonGC, std::string("Python GC")),
      [](const auto& e) -> std::string { return e.name_; }));
}

@@ -631,6 +632,7 @@ libkineto::ActivityType Result::kinetoType() const {
      ATTRIBUTE(OutOfMemory, libkineto::ActivityType::CPU_INSTANT_EVENT),
      ATTRIBUTE(PyCall, libkineto::ActivityType::PYTHON_FUNCTION),
      ATTRIBUTE(PyCCall, libkineto::ActivityType::PYTHON_FUNCTION),
      ATTRIBUTE(PythonGC, libkineto::ActivityType::PYTHON_FUNCTION),
      ATTRIBUTE(Kineto, e.activity_type_)));
}

@@ -650,6 +652,7 @@ int64_t Result::endTimeNS() const {
      ATTRIBUTE(Allocation, start_time_ns_),
      ATTRIBUTE(OutOfMemory, start_time_ns_),
      ATTRIBUTE(Kineto, start_time_ns_ + e.duration_ns_),
      ATTRIBUTE(PythonGC, start_time_ns_ + e.duration_ns_),
      [&](const auto& e) -> int64_t { return e.end_time_ns_; }));

  // In rare cases we're willing to tolerate ops which are missing an end time
@@ -700,6 +703,9 @@ RecordQueue::RecordQueue(
      activities_{std::move(activities)} {
  if (tracePython()) {
    python_tracer_ = python_tracer::PythonTracerBase::make(this);
    if (getPythonGcEvents()) {
      python_tracer_->register_gc_callback();
    }
  }
}

@@ -707,6 +713,10 @@ bool RecordQueue::tracePython() const {
  return config_.with_stack && activities_.count(ActivityType::CPU);
}

bool RecordQueue::getPythonGcEvents() const {
  return config_.experimental_config.record_python_gc_info;
}

ThreadLocalSubqueue* RecordQueue::getSubqueue() {
  // In the most common case, a thread will want to write to the same sub-queue
  // that it wrote to last call. The only time that isn't true is if:
@@ -1488,6 +1498,31 @@ RecordQueue::getRecords(
    queue.allocations_.clear();
    materialize(queue.ooms_);

    std::optional<int64_t> pending_start;
    for (auto& e : queue.pythongc_) {
      if (e.first.find("start") != std::string::npos) {
        pending_start = e.second;
      } else if (e.first.find("stop") != std::string::npos) {
        if (pending_start.has_value()) {
          out.emplace_back(Result::create(
              /*start_time_ns_=*/converter(pending_start.value()),
              /*start_tid_=*/queue.tid(),
              /*kineto_info_=*/queue.kineto_info(),
              /*extra_fields_=*/
              // NOLINTNEXTLINE
              ExtraFields<EventType::PythonGC>{
                  e.first,
                  converter(e.second) - converter(pending_start.value())}));
          pending_start.reset();
        } else {
          // Handle the case where "stop" is found without a matching "start"
          // For example, you might want to log a warning or take other action:
          LOG(WARNING) << R"("stop" event found without a matching "start": )"
                       << e.first;
        }
      }
    }

    for (auto& i : queue.py_calls_) {
      python_enters.push_back(
          {i.first, queue.tid(), queue.kineto_info(), converter(i.second)});
@@ -34,7 +34,8 @@ enum class EventType : uint8_t {
  OutOfMemory,
  PyCall,
  PyCCall,
  Kineto,
  PythonGC
};

// ============================================================================
@@ -191,6 +192,12 @@ struct ExtraFields<EventType::Backend> {
  jit_modules_t jit_modules_;
};

template <>
struct ExtraFields<EventType::PythonGC> {
  std::string phase;
  int64_t duration_ns_;
};

template <>
struct ExtraFields<EventType::Vulkan> {
  using raw_event_t = std::pair<c10::approx_time_t, vulkan_id_t>;
@@ -415,7 +422,8 @@ struct TORCH_API Result : public std::enable_shared_from_this<Result> {
      ExtraFields<EventType::OutOfMemory>,
      ExtraFields<EventType::PyCall>,
      ExtraFields<EventType::PyCCall>,
      ExtraFields<EventType::Kineto>,
      ExtraFields<EventType::PythonGC>>
      extra_fields_;

  std::weak_ptr<Result> parent_;
@@ -549,6 +557,11 @@ class TORCH_API ThreadLocalSubqueue {
    py_calls_.emplace_back(std::forward<Args>(args)...);
  }

  template <class... Args>
  void emplace_gc_call(Args&&... args) {
    pythongc_.emplace_back(std::forward<Args>(args)...);
  }

  uint64_t tid() const {
    return tid_;
  }
@@ -639,6 +652,9 @@ class TORCH_API ThreadLocalSubqueue {
      std::pair<python_tracer::TraceKey, c10::approx_time_t>,
      BlockSize>
      py_calls_;
  // gc with_stack (Python)
  AppendOnlyList<std::pair<std::string, c10::approx_time_t>, BlockSize>
      pythongc_;
};

class TORCH_API RecordQueue {
@@ -646,6 +662,7 @@ class TORCH_API RecordQueue {
  RecordQueue(ProfilerConfig config, std::set<ActivityType> activities);

  bool tracePython() const;
  bool getPythonGcEvents() const;
  ThreadLocalSubqueue* getSubqueue();
  void stop();
  void restart();
@@ -21,6 +21,7 @@ ExperimentalConfig::ExperimentalConfig(
    bool disable_external_correlation,
    bool profile_all_threads,
    bool capture_overload_names,
    bool record_python_gc_info,
    std::string custom_profiler_config,
    bool adjust_timestamps)
    : profiler_metrics{std::move(profiler_metrics)},
@@ -32,6 +33,7 @@ ExperimentalConfig::ExperimentalConfig(
      disable_external_correlation{disable_external_correlation},
      profile_all_threads{profile_all_threads},
      capture_overload_names{capture_overload_names},
      record_python_gc_info{record_python_gc_info},
      custom_profiler_config(std::move(custom_profiler_config)),
      adjust_timestamps{adjust_timestamps} {}
@@ -62,6 +62,7 @@ struct TORCH_API ExperimentalConfig {
      bool disable_external_correlation = false,
      bool profile_all_threads = false,
      bool capture_overload_names = false,
      bool record_python_gc_info = false,
      std::string custom_profiler_config = "",
      bool adjust_timestamps = false);
  explicit operator bool() const;
@@ -102,6 +103,12 @@ struct TORCH_API ExperimentalConfig {
   * function schema and stored in the profile */
  bool capture_overload_names;

  /*
   * Controls whether or not python gc info is recorded. This is used to
   * determine if gc collect is slowing down your profile.
   */
  bool record_python_gc_info;

  /*
   * A custom_profiler_config option is introduced to allow custom backends
   * to apply custom configurations as needed.
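Aside (not part of the diff): the flag's stated purpose is to tell whether gc.collect() calls are slowing a workload down. A hedged post-processing sketch, assuming the new events surface through the profiler's events() list under the "Python GC" name defined above:

# prof is a finished torch.profiler.profile(...) run with with_stack=True and
# record_python_gc_info=True (see the usage sketch near the top of this page).
gc_events = [e for e in prof.events() if e.name == "Python GC"]
total_gc_us = sum(e.cpu_time_total for e in gc_events)  # FunctionEvent times are in microseconds
print(f"{len(gc_events)} collections, {total_gc_us:.0f} us spent in Python GC")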
@@ -11,6 +11,7 @@ struct NoOpPythonTracer : public PythonTracerBase {

  void stop() override {}
  void restart() override {}
  void register_gc_callback() override {}
  std::vector<std::shared_ptr<Result>> getEvents(
      std::function<c10::time_t(c10::approx_time_t)>,
      std::vector<CompressedEvent>&,
@@ -48,6 +48,7 @@ struct TORCH_API PythonTracerBase {

  virtual void stop() = 0;
  virtual void restart() = 0;
  virtual void register_gc_callback() = 0;
  virtual std::vector<std::shared_ptr<Result>> getEvents(
      std::function<c10::time_t(c10::approx_time_t)> time_converter,
      std::vector<CompressedEvent>& enters,
@@ -341,6 +341,7 @@ void initPythonBindings(PyObject* module) {
              bool /* disable_external_correlation*/,
              bool /* profile_all_threads */,
              bool /* capture_overload_names */,
              bool /* record_python_gc_info */,
              std::string /* custom_profiler_config*/
              >(),
          "An experimental config for Kineto features. Please note that"
@@ -360,6 +361,7 @@ void initPythonBindings(PyObject* module) {
          " disable_external_correlation (bool) : whether to disable external correlation\n"
          " profile_all_threads (bool) : whether to profile all threads\n"
          " capture_overload_names (bool) : whether to include ATen overload names in the profile\n"
          " record_python_gc_info (bool) : adds python gc events to profile\n"
          " custom_profiler_config (string) : Used to pass some configurations to the custom profiler backend.\n",
          py::arg("profiler_metrics") = std::vector<std::string>(),
          py::arg("profiler_measure_per_kernel") = false,
@@ -370,6 +372,7 @@ void initPythonBindings(PyObject* module) {
          py::arg("disable_external_correlation") = false,
          py::arg("profile_all_threads") = false,
          py::arg("capture_overload_names") = false,
          py::arg("record_python_gc_info") = false,
          py::arg("custom_profiler_config") = "")
      .def(py::pickle(
          [](const ExperimentalConfig& p) { // __getstate__
@@ -393,6 +396,7 @@ void initPythonBindings(PyObject* module) {
                p.disable_external_correlation,
                p.profile_all_threads,
                p.capture_overload_names,
                p.record_python_gc_info,
                p.custom_profiler_config,
                p.performance_events);
          },