Mirror of https://github.com/pytorch/pytorch.git
Fixes https://github.com/pytorch/pytorch/issues/132331

We need another barrier here to ensure that the main thread doesn't stop the profiler while other threads are still using it (and crash). I can reliably reproduce the issue with `pytest -v test/profiler/test_cpp_thread.py -k test_profile_memory --flake-finder`.

### Testing

`pytest -v test/profiler/test_cpp_thread.py --flake-finder` all passes.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/136304
Approved by: https://github.com/briancoutinho
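To make the fix concrete, here is a minimal standalone sketch of the counter-based spin barrier the test relies on (illustrative only: `kThreads`, `kIterations`, and the `printf` are not part of the test). Each thread increments a shared atomic once per iteration and spins until the count is a multiple of the thread count, i.e. until every thread has arrived; only then may the controlling thread stop the profiler, because no thread can still be inside the profiled region:

```cpp
#include <atomic>
#include <cstdio>
#include <thread>
#include <vector>

int main() {
  constexpr int kThreads = 4;     // illustrative
  constexpr int kIterations = 3;  // illustrative
  std::atomic<int> barrier{0};

  std::vector<std::thread> threads;
  for (int id = 0; id < kThreads; id++) {
    threads.emplace_back([&barrier, id]() {
      for (int iteration = 0; iteration < kIterations; iteration++) {
        // ... per-thread profiled work would happen here ...

        // Counter-based barrier, as in the test: the count is a multiple of
        // kThreads exactly when every thread has finished this iteration.
        ++barrier;
        while (barrier % kThreads) {
          std::this_thread::yield();
        }
        if (id == 0) {
          // Safe point: in the real test, only here may the main thread
          // stop the profiler.
          printf("iteration %d: all threads arrived\n", iteration);
        }
      }
    });
  }
  for (auto& t : threads) {
    t.join();
  }
}
```

The modulo trick lets the same counter be reused every iteration without being reset.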
120 lines · 3.7 KiB · C++
#include <torch/csrc/autograd/profiler_kineto.h> // @manual

#include <torch/torch.h>

#include <atomic>
#include <cstdio>
#include <memory>
#include <string>
#include <thread>
#include <vector>

using namespace torch::autograd::profiler;

// Print status messages in bright blue (ANSI escape 94) so they stand out
// in the test output.
void blueprint(const std::string& text) {
  printf("\33[94m%s\33[0m\n", text.c_str());
}

/**
 * We're emulating a C++ training engine calling into Python, letting the
 * Python side control how profiling is done.
 */
class ProfilerEventHandler
    : public std::enable_shared_from_this<ProfilerEventHandler> {
 public:
  static std::shared_ptr<ProfilerEventHandler> Handler;
  static void Register(const std::shared_ptr<ProfilerEventHandler>& handler) {
    Handler = handler;
  }

 public:
  virtual ~ProfilerEventHandler() {}
  virtual void onIterationStart(int) {}
  virtual void emulateTraining(int, int) {}
};
std::shared_ptr<ProfilerEventHandler> ProfilerEventHandler::Handler;

// Pybind11 trampoline: PYBIND11_OVERRIDE forwards each virtual call to a
// Python override when one exists and falls back to the C++ default
// implementation otherwise, reacquiring the GIL for the Python call.
class ProfilerEventHandlerTrampoline : public ProfilerEventHandler {
 public:
  virtual void onIterationStart(int iteration) override {
    PYBIND11_OVERRIDE(void, ProfilerEventHandler, onIterationStart, iteration);
  }
  virtual void emulateTraining(int iteration, int thread_id) override {
    PYBIND11_OVERRIDE(
        void, ProfilerEventHandler, emulateTraining, iteration, thread_id);
  }
};

/**
 * This is the entry point for the C++ training engine.
 */
void start_threads(int thread_count, int iteration_count, bool attach) {
  blueprint("start_cpp_threads called");

  static std::atomic<int> barrier = 0;
  barrier = 0;
  static std::atomic<int> another_barrier = 0;
  another_barrier = 0;
  // Each thread caches the profiler state it last observed in the main
  // thread, so it only toggles its own profiler on a state change.
  thread_local bool enabled_in_main_thread = false;

  std::vector<std::thread> threads;
  for (int id = 0; id < thread_count; id++) {
    blueprint("starting thread " + std::to_string(id));
    threads.emplace_back([thread_count, iteration_count, id, attach]() {
      for (int iteration = 0; iteration < iteration_count; iteration++) {
        if (id == 0) {
          ProfilerEventHandler::Handler->onIterationStart(iteration);
        }

        // This barrier makes sure profiling is turned on in every child
        // thread once the main thread has enabled it in onIterationStart.
        ++barrier;
        while (barrier % thread_count) {
          std::this_thread::yield();
        }

        if (id > 0 && attach) {
          bool enabled = isProfilerEnabledInMainThread();
          if (enabled != enabled_in_main_thread) {
            if (enabled) {
              enableProfilerInChildThread();
            } else {
              disableProfilerInChildThread();
            }
            enabled_in_main_thread = enabled;
          }
        }

        ProfilerEventHandler::Handler->emulateTraining(iteration, id);

        // We need another barrier here to ensure that the main thread doesn't
        // stop the profiler while other threads are still using it. This fixes
        // https://github.com/pytorch/pytorch/issues/132331
        ++another_barrier;
        while (another_barrier % thread_count) {
          std::this_thread::yield();
        }
      }
    });
  }
  for (auto& t : threads) {
    t.join();
  }
}

PYBIND11_MODULE(profiler_test_cpp_thread_lib, m) {
  py::class_<
      ProfilerEventHandler,
      ProfilerEventHandlerTrampoline,
      std::shared_ptr<ProfilerEventHandler>>(m, "ProfilerEventHandler")
      .def(py::init<>())
      .def_static("Register", &ProfilerEventHandler::Register)
      .def(
          "onIterationStart",
          &ProfilerEventHandler::onIterationStart,
          py::call_guard<py::gil_scoped_release>())
      .def(
          "emulateTraining",
          &ProfilerEventHandler::emulateTraining,
          py::call_guard<py::gil_scoped_release>());

  // Release the GIL while start_threads runs so the spawned C++ threads can
  // call back into Python; the trampoline reacquires the GIL per call.
  m.def(
      "start_threads",
      &start_threads,
      py::call_guard<py::gil_scoped_release>());
}
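For reference, a hypothetical all-C++ driver for this module could look like the sketch below (`EchoHandler` and `main` are illustrative and not part of the test; the real driver is a Python `ProfilerEventHandler` subclass registered from test/profiler/test_cpp_thread.py, which uses the hooks above to control profiling from Python):

```cpp
// Illustrative native handler; the real test registers a Python subclass.
class EchoHandler : public ProfilerEventHandler {
 public:
  void onIterationStart(int iteration) override {
    blueprint("iteration " + std::to_string(iteration));
  }
  void emulateTraining(int /*iteration*/, int /*thread_id*/) override {
    // Stand-in for real work; the Python test runs profiled tensor ops here.
    auto x = torch::ones({64, 64});
    auto y = torch::mm(x, x);
    (void)y;
  }
};

int main() {
  ProfilerEventHandler::Register(std::make_shared<EchoHandler>());
  start_threads(/*thread_count=*/4, /*iteration_count=*/10, /*attach=*/true);
  return 0;
}
```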