From 40f46b09c7c0b1338178bb3f92a8f7d1243165a6 Mon Sep 17 00:00:00 2001
From: clr
Date: Tue, 26 Aug 2025 15:27:10 -0700
Subject: [PATCH] async_compile: Fix the wait method to actually wait (#161561)

The wait in this method never actually triggered. It is used in two tests,
and they pass, so there is no serious concern.

Note that I introduced and then fixed a latent bug here: with this change,
if shutdown_compile_workers had been called, jobs would crash on a later
wait because _ready_future had already finished. However, we only call
wait in tests, so that bug is benign.

The other behaviour change: if you called shutdown, I believe we could
previously block on your first triton compile after that until the pool
was ready. With this change we should correctly switch to direct mode
until the pool is ready on a later warmup.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/161561
Approved by: https://github.com/masnesral
ghstack dependencies: #161452
---
 test/inductor/test_async_compile.py | 12 ++++++++++--
 torch/_inductor/async_compile.py    |  5 +++--
 2 files changed, 13 insertions(+), 4 deletions(-)

diff --git a/test/inductor/test_async_compile.py b/test/inductor/test_async_compile.py
index a0e1ffef0dca..5a61ea851eae 100644
--- a/test/inductor/test_async_compile.py
+++ b/test/inductor/test_async_compile.py
@@ -50,13 +50,21 @@ class TestAsyncCompile(TestCase):
         with config.patch(worker_start_method="subprocess", compile_threads=8):
             async_compile = AsyncCompile()
             AsyncCompile.wait_pool_ready()
-            # Working around bug in wait_pool_ready()
-            async_compile._ready_future.result(timeout=120)
             with self.assertRaises(SubprocException):
                 async_compile.triton(
                     "fake_kernel_name", source_code="This definitely doesn't exist"
                 ).result()
 
+    @requires_gpu()
+    @requires_triton()
+    def test_wait_pool_ready(self):
+        shutdown_compile_workers()
+
+        with config.patch(worker_start_method="subprocess", compile_threads=8):
+            AsyncCompile.wait_pool_ready()
+            self.assertTrue(AsyncCompile._ready_future.done())
+            self.assertTrue(AsyncCompile.use_process_pool())
+
     @requires_gpu()
     @requires_triton()
     @patch("torch._inductor.runtime.coordinate_descent_tuner.CoordescTuner.autotune")
diff --git a/torch/_inductor/async_compile.py b/torch/_inductor/async_compile.py
index 97be0539a498..1f3e2f1eabf6 100644
--- a/torch/_inductor/async_compile.py
+++ b/torch/_inductor/async_compile.py
@@ -148,6 +148,7 @@ def shutdown_compile_workers() -> None:
     """Shut down all outstanding compile-worker pools."""
     for pool in _pool_set:
         pool.shutdown()
+    AsyncCompile._ready_future = None
 
     after_fork()
 
@@ -308,8 +309,8 @@ class AsyncCompile:
 
     @classmethod
     def wait_pool_ready(cls, timeout=120) -> None:
-        if cls.use_process_pool():
-            assert cls._ready_future is not None
+        cls.use_process_pool()
+        if cls._ready_future is not None:
             cls._ready_future.result(timeout=timeout)
 
     @classmethod
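
To make the control-flow change easier to follow, here is a minimal,
self-contained sketch (all names here -- FakeAsyncCompile, warm_pool --
are hypothetical stand-ins, not the real torch._inductor.async_compile
API) of why the old "if cls.use_process_pool():" guard never waited and
what the new flow does instead:

    import threading
    import time
    from concurrent.futures import Future
    from typing import Optional


    class FakeAsyncCompile:
        # Hypothetical stand-in for AsyncCompile; only models the state
        # this patch touches.
        _ready_future: Optional[Future] = None

        @classmethod
        def warm_pool(cls) -> None:
            # Simulate an asynchronous pool warmup that finishes later.
            cls._ready_future = Future()
            fut = cls._ready_future
            threading.Thread(
                target=lambda: (time.sleep(0.1), fut.set_result(True)),
                daemon=True,
            ).start()

        @classmethod
        def use_process_pool(cls) -> bool:
            # Reports True only once warmup is done; False mid-warmup.
            return cls._ready_future is not None and cls._ready_future.done()

        @classmethod
        def wait_pool_ready(cls, timeout: float = 120) -> None:
            # Old guard -- "if cls.use_process_pool(): wait" -- skipped the
            # wait exactly when the pool was still warming up, so the wait
            # never triggered. New flow: poke the pool check, then block on
            # any pending ready future.
            cls.use_process_pool()
            if cls._ready_future is not None:
                cls._ready_future.result(timeout=timeout)


    FakeAsyncCompile.warm_pool()
    FakeAsyncCompile.wait_pool_ready()  # blocks until warmup completes
    assert FakeAsyncCompile.use_process_pool()

The design point the sketch illustrates: wait_pool_ready must not gate the
wait on the return value of use_process_pool(), since that check reports
False precisely while warmup is still in flight.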