mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
async_compile: Fix the wait method to actually wait (#161561)
Previously, the wait in this method never triggered. It is used in 2 tests and they pass, so this is not a serious concern. Note that I did introduce and fix a latent bug: if we called shutdown_compile_workers, jobs would crash with this change because ready_future was already finished if we had called wait. However, we only call wait in tests, so that bug is fine. The other behaviour change is that if you called shutdown, I believe we could potentially block on your first Triton compile after that, until the pool was ready. This should now correctly switch to direct mode until the pool is ready on later warmups. Pull Request resolved: https://github.com/pytorch/pytorch/pull/161561 Approved by: https://github.com/masnesral ghstack dependencies: #161452
This commit is contained in:
@ -50,13 +50,21 @@ class TestAsyncCompile(TestCase):
|
||||
with config.patch(worker_start_method="subprocess", compile_threads=8):
|
||||
async_compile = AsyncCompile()
|
||||
AsyncCompile.wait_pool_ready()
|
||||
# Working around bug in wait_pool_ready()
|
||||
async_compile._ready_future.result(timeout=120)
|
||||
with self.assertRaises(SubprocException):
|
||||
async_compile.triton(
|
||||
"fake_kernel_name", source_code="This definitely doesn't exist"
|
||||
).result()
|
||||
|
||||
@requires_gpu()
|
||||
@requires_triton()
|
||||
def test_wait_pool_ready(self):
|
||||
shutdown_compile_workers()
|
||||
|
||||
with config.patch(worker_start_method="subprocess", compile_threads=8):
|
||||
AsyncCompile.wait_pool_ready()
|
||||
self.assertTrue(AsyncCompile._ready_future.done())
|
||||
self.assertTrue(AsyncCompile.use_process_pool())
|
||||
|
||||
@requires_gpu()
|
||||
@requires_triton()
|
||||
@patch("torch._inductor.runtime.coordinate_descent_tuner.CoordescTuner.autotune")
|
||||
|
@ -148,6 +148,7 @@ def shutdown_compile_workers() -> None:
|
||||
"""Shut down all outstanding compile-worker pools."""
|
||||
for pool in _pool_set:
|
||||
pool.shutdown()
|
||||
AsyncCompile._ready_future = None
|
||||
after_fork()
|
||||
|
||||
|
||||
@ -308,8 +309,8 @@ class AsyncCompile:
|
||||
|
||||
@classmethod
|
||||
def wait_pool_ready(cls, timeout=120) -> None:
|
||||
if cls.use_process_pool():
|
||||
assert cls._ready_future is not None
|
||||
cls.use_process_pool()
|
||||
if cls._ready_future is not None:
|
||||
cls._ready_future.result(timeout=timeout)
|
||||
|
||||
@classmethod
|
||||
|
Reference in New Issue
Block a user