diff --git a/test/profiler/test_cpp_thread.cpp b/test/profiler/test_cpp_thread.cpp index ce60d9c816c6..58792313a90b 100644 --- a/test/profiler/test_cpp_thread.cpp +++ b/test/profiler/test_cpp_thread.cpp @@ -47,6 +47,8 @@ void start_threads(int thread_count, int iteration_count, bool attach) { static std::atomic barrier = 0; barrier = 0; + static std::atomic another_barrier = 0; + another_barrier = 0; thread_local bool enabled_in_main_thread = false; std::vector threads; @@ -78,6 +80,14 @@ void start_threads(int thread_count, int iteration_count, bool attach) { } ProfilerEventHandler::Handler->emulateTraining(iteration, id); + + // We need another barrier here to ensure that the main thread doesn't + // stop the profiler while other threads are still using it. This fixes + // https://github.com/pytorch/pytorch/issues/132331 + ++another_barrier; + while (another_barrier % thread_count) { + std::this_thread::yield(); + } } }); }