async_compile: Fix the wait method to actually wait (#161561)

The wait inside this method never actually triggered. The method is only used
in 2 tests, and they pass, so this is not a serious concern.

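Note that this change both introduced and fixed a latent bug: if we had called
wait (so ready_future was already finished) and then called
shutdown_compile_workers, jobs would crash with this change because of the
stale, finished ready_future. shutdown_compile_workers now resets ready_future.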

However, we only call wait in tests, so that bug is not a practical concern.

The other behaviour change is that, previously, if you called shutdown, I
believe we could potentially block on your first triton compile after that,
until the pool was ready. With this change we should correctly switch to
direct mode until the pool becomes ready on a later warmup.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/161561
Approved by: https://github.com/masnesral
ghstack dependencies: #161452
This commit is contained in:
clr
2025-08-26 15:27:10 -07:00
committed by PyTorch MergeBot
parent 0d6597138c
commit 40f46b09c7
2 changed files with 13 additions and 4 deletions

View File

@ -50,13 +50,21 @@ class TestAsyncCompile(TestCase):
with config.patch(worker_start_method="subprocess", compile_threads=8):
async_compile = AsyncCompile()
AsyncCompile.wait_pool_ready()
# Working around bug in wait_pool_ready()
async_compile._ready_future.result(timeout=120)
with self.assertRaises(SubprocException):
async_compile.triton(
"fake_kernel_name", source_code="This definitely doesn't exist"
).result()
@requires_gpu()
@requires_triton()
def test_wait_pool_ready(self):
shutdown_compile_workers()
with config.patch(worker_start_method="subprocess", compile_threads=8):
AsyncCompile.wait_pool_ready()
self.assertTrue(AsyncCompile._ready_future.done())
self.assertTrue(AsyncCompile.use_process_pool())
@requires_gpu()
@requires_triton()
@patch("torch._inductor.runtime.coordinate_descent_tuner.CoordescTuner.autotune")

View File

@ -148,6 +148,7 @@ def shutdown_compile_workers() -> None:
"""Shut down all outstanding compile-worker pools."""
for pool in _pool_set:
pool.shutdown()
AsyncCompile._ready_future = None
after_fork()
@ -308,8 +309,8 @@ class AsyncCompile:
@classmethod
def wait_pool_ready(cls, timeout=120) -> None:
if cls.use_process_pool():
assert cls._ready_future is not None
cls.use_process_pool()
if cls._ready_future is not None:
cls._ready_future.result(timeout=timeout)
@classmethod