diff --git a/test/inductor/test_async_compile.py b/test/inductor/test_async_compile.py index a0e1ffef0dca..5a61ea851eae 100644 --- a/test/inductor/test_async_compile.py +++ b/test/inductor/test_async_compile.py @@ -50,13 +50,21 @@ class TestAsyncCompile(TestCase): with config.patch(worker_start_method="subprocess", compile_threads=8): async_compile = AsyncCompile() AsyncCompile.wait_pool_ready() - # Working around bug in wait_pool_ready() - async_compile._ready_future.result(timeout=120) with self.assertRaises(SubprocException): async_compile.triton( "fake_kernel_name", source_code="This definitely doesn't exist" ).result() + @requires_gpu() + @requires_triton() + def test_wait_pool_ready(self): + shutdown_compile_workers() + + with config.patch(worker_start_method="subprocess", compile_threads=8): + AsyncCompile.wait_pool_ready() + self.assertTrue(AsyncCompile._ready_future.done()) + self.assertTrue(AsyncCompile.use_process_pool()) + @requires_gpu() @requires_triton() @patch("torch._inductor.runtime.coordinate_descent_tuner.CoordescTuner.autotune") diff --git a/torch/_inductor/async_compile.py b/torch/_inductor/async_compile.py index 97be0539a498..1f3e2f1eabf6 100644 --- a/torch/_inductor/async_compile.py +++ b/torch/_inductor/async_compile.py @@ -148,6 +148,7 @@ def shutdown_compile_workers() -> None: """Shut down all outstanding compile-worker pools.""" for pool in _pool_set: pool.shutdown() + AsyncCompile._ready_future = None after_fork() @@ -308,8 +309,8 @@ class AsyncCompile: @classmethod def wait_pool_ready(cls, timeout=120) -> None: - if cls.use_process_pool(): - assert cls._ready_future is not None + cls.use_process_pool() + if cls._ready_future is not None: cls._ready_future.result(timeout=timeout) @classmethod