[inductor] Quiesce Triton compile worker pool after each dynamo compile (#156187)

For internal usages, keeping the Triton compile worker pool active for the lifetime of the process has caused some challenges, e.g., it slows down and muddies profiling due to the huge number of threads on a box (N threads = 8 ranks * 32 subprocs * M threads started by torch), and each subproc can use more than 1GB of memory. This PR adds the ability to shut down worker subprocs after each dynamo compile when using the SubprocPool implementation. The idea is to leave the main sidecar process running, but signal it to tear down its internal ProcessPoolExecutor when compile is finished. Restarting the ProcessPoolExecutor is relatively fast (e.g., ~500ms) because the ProcessPoolExecutor forks from the sidecar. Changes:
* Do not start the ProcessPoolExecutor automatically when compile_fx is imported. Instead, start only the sidecar process. The sidecar process imports torch, so it is still slow to start.
* Introduce wakeup() and quiesce() calls in the SubprocPool implementation to start and stop the ProcessPoolExecutor.
* Add a context manager to automatically quiesce() at the end of dynamo compilation (see the sketch after this list).
* Signal a wakeup() in compile_fx only when the example inputs include cuda (or xpu) tensors.
* Add a killswitch so we can turn off quiescing.
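
To make the flow concrete, here is a minimal sketch of the intended per-compile lifecycle. It is an illustration only: it assumes this PR is applied, uses the names introduced in the diff (async_compile_pool_manager, AsyncCompile.wakeup/quiesce, config.quiesce_async_compile_pool), and the wrapper function and its argument are hypothetical.

```python
# Illustrative sketch only (assumes this PR is applied); `run_inductor` is a
# hypothetical stand-in for the work done for a single dynamo/inductor compile.
from torch._inductor.async_compile import AsyncCompile, async_compile_pool_manager

def compile_one_frame(run_inductor):
    # Dynamo wraps each compile in this context manager, so quiesce() runs in a
    # `finally` even if compilation raises. With config.quiesce_async_compile_pool
    # enabled (e.g. TORCHINDUCTOR_QUIESCE_ASYNC_COMPILE_POOL=1), quiesce() signals
    # the sidecar to shut down its internal ProcessPoolExecutor but stay alive.
    with async_compile_pool_manager():
        # compile_fx signals a wakeup when the example inputs include cuda/xpu
        # tensors; the sidecar then forks a fresh ProcessPoolExecutor (~500ms).
        AsyncCompile.wakeup()
        return run_inductor()
```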

Testing:
For correctness, the stacked change at https://github.com/pytorch/pytorch/pull/156534 enables the feature for OSS so it's exercised in CI.

For performance, because of recent compile-time variance (see https://github.com/pytorch/pytorch/issues/152566), it's pretty hard to tell whether there's a regression:

* Training: https://hud.pytorch.org/benchmark/compilers?dashboard=torchinductor&startTime=Tue%2C%2017%20Jun%202025%2021%3A32%3A04%20GMT&stopTime=Tue%2C%2024%20Jun%202025%2021%3A32%3A04%20GMT&granularity=hour&mode=training&dtype=amp&deviceName=cuda%20(h100)&lBranch=gh/masnesral/210/head&lCommit=1b7315031c3bfad66a1a01700167a9ca1a2ae5f1&rBranch=main&rCommit=eab45643f22e58ee12d95d8b0162d51ca0a50801
* Inference: https://hud.pytorch.org/benchmark/compilers?dashboard=torchinductor&startTime=Tue%2C%2017%20Jun%202025%2021%3A32%3A04%20GMT&stopTime=Tue%2C%2024%20Jun%202025%2021%3A32%3A04%20GMT&granularity=hour&mode=inference&dtype=bfloat16&deviceName=cuda%20(h100)&lBranch=gh/masnesral/210/head&lCommit=1b7315031c3bfad66a1a01700167a9ca1a2ae5f1&rBranch=main&rCommit=eab45643f22e58ee12d95d8b0162d51ca0a50801

The wins (mostly for inference) don't make sense, but I'm also skeptical of the losses (mostly for training); I can't repro any of the slowdowns locally. Furthermore, check out the benchmarking results for the stacked diff, which actually enables the quiescing functionality for OSS. That should only slow down compile, since stopping and restarting the workers can only add overhead, yet the results are somehow better:

* Training: https://hud.pytorch.org/benchmark/compilers?dashboard=torchinductor&startTime=Tue%2C%2017%20Jun%202025%2021%3A32%3A04%20GMT&stopTime=Tue%2C%2024%20Jun%202025%2021%3A32%3A04%20GMT&granularity=hour&mode=training&dtype=amp&deviceName=cuda%20(h100)&lBranch=gh/masnesral/214/head&lCommit=41943253882a019b8ceafcd2bf4cd6acbe0cbca9&rBranch=main&rCommit=eab45643f22e58ee12d95d8b0162d51ca0a50801
* Inference: https://hud.pytorch.org/benchmark/compilers?dashboard=torchinductor&startTime=Tue%2C%2017%20Jun%202025%2021%3A32%3A04%20GMT&stopTime=Tue%2C%2024%20Jun%202025%2021%3A32%3A04%20GMT&granularity=hour&mode=inference&dtype=bfloat16&deviceName=cuda%20(h100)&lBranch=gh/masnesral/214/head&lCommit=41943253882a019b8ceafcd2bf4cd6acbe0cbca9&rBranch=main&rCommit=eab45643f22e58ee12d95d8b0162d51ca0a50801

Pull Request resolved: https://github.com/pytorch/pytorch/pull/156187
Approved by: https://github.com/aorenste, https://github.com/jansel
Sam Larsen
2025-07-07 17:17:55 -07:00
committed by PyTorch MergeBot
parent 178fe7aa98
commit 7a41f20794
9 changed files with 201 additions and 65 deletions
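
Beyond the quiesce/wakeup API, the diff below also reworks the pipe protocol between the parent SubprocPool and the sidecar's SubprocMain: every message now carries a MsgHeader in addition to the job id. The following is a rough, self-contained sketch of that framing and of the sidecar's dispatch loop; the helper names (pack, sidecar_loop, run_job, start_pool, stop_pool) are illustrative, not from the diff.

```python
import struct
from enum import IntEnum

# Message headers used on the pipe (values match the diff below).
class MsgHeader(IntEnum):
    ERROR = 0
    SHUTDOWN = 1
    QUIESCE = 2
    WAKEUP = 3
    JOB = 4

def pack(header: MsgHeader, job_id: int = -1, payload: bytes = b"") -> bytes:
    # Fixed-size (header, job_id, length) triple followed by `length` payload
    # bytes; only JOB messages carry a payload.
    return struct.pack("nnn", int(header), job_id, len(payload)) + payload

def sidecar_loop(recv, run_job, start_pool, stop_pool):
    # recv() returns (header, job_id, payload) decoded with the framing above.
    while True:
        header, job_id, payload = recv()
        if header == MsgHeader.JOB:
            run_job(job_id, payload)   # submit to the internal ProcessPoolExecutor
        elif header == MsgHeader.WAKEUP:
            start_pool()               # (re)create the ProcessPoolExecutor
        elif header == MsgHeader.QUIESCE:
            stop_pool()                # tear down the executor; sidecar stays alive
        else:                          # SHUTDOWN or a read error
            stop_pool()
            return
```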

View File

@@ -29,8 +29,7 @@ class TestAsyncCompile(TestCase):
with config.patch("worker_start_method", method):
shutdown_compile_workers()
pool = AsyncCompile.process_pool()
pool.ready_future.result(timeout=120)
AsyncCompile.wait_pool_ready()
with fresh_cache():
compiled_fn = torch.compile(fn)

View File

@@ -53,6 +53,19 @@ class TestCompileWorker(TestCase):
finally:
pool.shutdown()
@skipIfWindows(msg="pass_fds not supported on Windows.")
def test_quiesce(self):
pool = SubprocPool(2)
try:
a = pool.submit(operator.add, 100, 1)
pool.quiesce()
pool.wakeup()
b = pool.submit(operator.sub, 100, 1)
self.assertEqual(a.result(), 101)
self.assertEqual(b.result(), 99)
finally:
pool.shutdown()
if __name__ == "__main__":
from torch._inductor.test_case import run_tests

View File

@@ -20,6 +20,12 @@ from torch._dynamo.utils import dynamo_timed
def inductor(*args, **kwargs):
with dynamo_timed("inductor_import", log_pt2_compile_event=True):
# do import here to avoid loading inductor into memory when it is not used
# The AsyncCompile subproc pool can be slow to start, so warm it up as early
# as possible.
from torch._inductor.async_compile import maybe_warm_pool
maybe_warm_pool()
from torch._inductor.compile_fx import compile_fx
return compile_fx(*args, **kwargs)

View File

@@ -735,6 +735,7 @@ def _compile(
# in the case of normal and exception code paths
convert_frame_box: Optional[ConvertFrameBox] = None,
) -> ConvertFrameReturn:
from torch._inductor.async_compile import async_compile_pool_manager
from torch.fx.experimental.validator import (
bisect,
BisectValidationException,
@@ -1015,6 +1016,7 @@ def _compile(
with (
_use_lazy_graph_module(config.use_lazy_graph_module),
compile_context(CompileContext(compile_id)),
async_compile_pool_manager(),
chromium_event_timed(
"dynamo", reset_event_log_on_exit=True, log_pt2_compile_event=True
),

View File

@@ -2286,6 +2286,8 @@ To fix this, your tensor subclass must implement the dunder method __force_to_sa
@staticmethod
def _backward_impl(ctx, all_args):
from torch._inductor.async_compile import async_compile_pool_manager
# compiled autograd reimplements this function at proxy_call_aot_backward
assert not backward_state_indices, (
"BackwardState requires CompiledAutograd"
@@ -2327,6 +2329,7 @@ To fix this, your tensor subclass must implement the dunder method __force_to_sa
with (
tracing(saved_context),
compile_context(saved_compile_context),
async_compile_pool_manager(),
context(),
track_graph_compiling(aot_config, "backward"),
metrics_context,

View File

@@ -2,6 +2,7 @@
from __future__ import annotations
import atexit
import contextlib
import functools
import json
import logging
@@ -216,7 +217,25 @@ class CompiledTritonKernels:
del CompiledTritonKernels._cache[key]
@contextlib.contextmanager
def async_compile_pool_manager():
"""
Context manager to quiesce the subproc pool at the end of compilation, i.e.,
when dynamo is done.
"""
try:
yield
finally:
AsyncCompile.quiesce()
class AsyncCompile:
"""
Utilities to compile in thread pools or subprocess pools (in the case of Triton).
"""
_ready_future: Optional[Future[Any]] = None
def __init__(self) -> None:
pass
@@ -235,6 +254,7 @@ class AsyncCompile:
@functools.lru_cache(1)
def process_pool() -> AnyPool:
assert get_compile_threads() > 1
AsyncCompile._ready_future = None
log.info(
"Creating '%s' pool with %d workers",
config.worker_start_method,
@@ -262,8 +282,6 @@ class AsyncCompile:
# kill the worker thread that sends the shutdown message to the workers...
multiprocessing.util.Finalize(None, pool.shutdown, exitpriority=sys.maxsize)
# Set an attribute we can check to see if the pool is ready.
pool.ready_future = pool.submit(AsyncCompile._get_ready) # type: ignore[union-attr]
_pool_set.add(pool)
return pool
@@ -272,20 +290,61 @@ class AsyncCompile:
if get_compile_threads() <= 1:
return
_compile_start()
# Pool is initialized on first access
# Pool is created on first access. Note for a SubprocPool, the sidecar process starts,
# but its ProcessPoolExecutor does not initialize until a wakeup() call or the first
# job is submitted.
cls.process_pool()
_compile_end()
@classmethod
def wait_pool_ready(cls, timeout=120) -> None:
if cls.use_process_pool():
assert cls._ready_future is not None
cls._ready_future.result(timeout=timeout)
@classmethod
def submit(cls, task: Callable[..., Any]) -> Any:
if get_compile_threads() <= 1:
return task()
return cls.pool().submit(task)
def use_process_pool(self):
return (
get_compile_threads() > 1 and self.process_pool().ready_future.done() # type: ignore[union-attr]
)
@classmethod
def use_process_pool(cls):
if get_compile_threads() <= 1:
return False
# Create a dummy job to check if the pool is ready. Submit it here instead of at
# pool creation so we don't launch the full pool of worker subprocesses until
# we're sure they're needed.
if not cls._ready_future:
cls._ready_future = cls.process_pool().submit(cls._get_ready)
return cls._ready_future.done()
@classmethod
def quiesce(cls) -> None:
"""
If using a SubprocPool, signal the sidecar process to shut down its
ProcessPoolExecutor.
"""
# Don't inadvertently create a process pool if it doesn't already exist:
if not cls.process_pool.cache_info().currsize:
return
if config.quiesce_async_compile_pool:
pool = cls.process_pool()
if isinstance(pool, SubprocPool):
pool.quiesce()
@classmethod
def wakeup(cls) -> None:
"""
If using a SubprocPool, signal the sidecar process to start up its
ProcessPoolExecutor.
"""
if not cls.use_process_pool():
return
pool = cls.process_pool()
if isinstance(pool, SubprocPool):
pool.wakeup()
def triton(self, kernel_name: str, source_code: str, device_str: str = "cuda"):
"""
@@ -517,18 +576,24 @@ class AsyncCompile:
pbar.update(1)
if (
os.environ.get("TORCH_TNT_IN_USE", "0") == "1"
or os.environ.get("TORCH_WARM_POOL", "1") != "1"
# The subprocess pool is only used for the Triton backend
or not has_triton_package()
# Skip for fbcode. We have internal reports of usages inside multiprocessing
# pools that lead to a multiplicative number of compile subprocesses.
or config.is_fbcode()
):
pass
else:
def maybe_warm_pool() -> None:
if (
os.environ.get("TORCH_TNT_IN_USE", "0") == "1"
or os.environ.get("TORCH_WARM_POOL", "1") != "1"
# The subprocess pool is only used for the Triton backend
or not has_triton_package()
# Skip for fbcode. We have internal reports of usages inside multiprocessing
# pools that lead to a multiplicative number of compile subprocesses.
or config.is_fbcode()
):
return
AsyncCompile.warm_pool()
# TODO: This starts the SubprocPool's internal process pool as early as possible at
# the expense of creating a bunch of worker processes that might not be needed. We
# could start them lazily if we're willing to lose a small amount of compile time.
AsyncCompile.wakeup()
# On exit give the workers a chance to clean themselves up. Without this the
# resource_tracker can complain about leaked semaphores coming from the

View File

@@ -22,7 +22,7 @@ from typing import Any, Callable, Optional, TYPE_CHECKING, TypeVar, Union
from typing_extensions import Never, override, ParamSpec, Protocol, TypedDict, Unpack
from unittest import mock
import torch._inductor.async_compile # noqa: F401 required to warm up AsyncCompile pools
import torch._inductor.async_compile
import torch.fx
import torch.utils._pytree as pytree
from functorch.compile import min_cut_rematerialization_partition
@@ -2037,6 +2037,12 @@ def compile_fx(
NB: This function TAKES OWNERSHIP of the input ``model_`` and can potentially
mutate it! Make a copy if you need to preserve the original GraphModule.
"""
# Wake up the AsyncCompile subproc pool as early as possible (if there's cuda).
if any(
isinstance(e, torch.Tensor) and e.device.type in ("cuda", "xpu")
for e in example_inputs_
):
torch._inductor.async_compile.AsyncCompile.wakeup()
# Some arguments trigger a recursive call to compile_fx. Handle these
# short circuits first, before anything else

View File

@@ -13,7 +13,7 @@ import traceback
import typing
from concurrent.futures import Future, ProcessPoolExecutor
from concurrent.futures.process import BrokenProcessPool
from enum import Enum
from enum import Enum, IntEnum
from typing import Any, Callable, IO, Optional, TypeVar
from typing_extensions import Never, ParamSpec
@@ -36,31 +36,42 @@ _P = ParamSpec("_P")
_T = TypeVar("_T")
def _pack_msg(job_id: int, length: int) -> bytes:
return struct.pack("nn", job_id, length)
class MsgHeader(IntEnum):
ERROR = 0
SHUTDOWN = 1
QUIESCE = 2
WAKEUP = 3
JOB = 4
def _unpack_msg(data: bytes) -> tuple[int, int]:
def _pack_msg(msg_header: MsgHeader, job_id: int, length: int) -> bytes:
return struct.pack("nnn", int(msg_header), job_id, length)
def _unpack_msg(data: bytes) -> tuple[MsgHeader, int, int]:
if not data:
return -1, -1
return struct.unpack("nn", data)
return MsgHeader.ERROR, -1, -1
msg_header, job_id, length = struct.unpack("nnn", data)
return MsgHeader(msg_header), job_id, length
msg_bytes = len(_pack_msg(0, 0))
msg_bytes = len(_pack_msg(MsgHeader.JOB, 0, 0))
def _send_msg(write_pipe: IO[bytes], job_id: int, job_data: bytes = b"") -> None:
length = len(job_data)
write_pipe.write(_pack_msg(job_id, length))
def _send_msg(
write_pipe: IO[bytes], msg_header: MsgHeader, job_id: int = -1, data: bytes = b""
) -> None:
length = len(data)
write_pipe.write(_pack_msg(msg_header, job_id, length))
if length > 0:
write_pipe.write(job_data)
write_pipe.write(data)
write_pipe.flush()
def _recv_msg(read_pipe: IO[bytes]) -> tuple[int, bytes]:
job_id, length = _unpack_msg(read_pipe.read(msg_bytes))
def _recv_msg(read_pipe: IO[bytes]) -> tuple[MsgHeader, int, bytes]:
msg_header, job_id, length = _unpack_msg(read_pipe.read(msg_bytes))
data = read_pipe.read(length) if length > 0 else b""
return job_id, data
return msg_header, job_id, data
class _SubprocExceptionInfo:
@@ -147,9 +158,7 @@ class SubprocPool:
"PYTHONPATH": os.environ.get(
"TORCH_CUSTOM_PYTHONPATH", os.pathsep.join(sys.path)
),
# We don't want to re-warm the pool when the subprocess imports
# torch._inductor.codecache since the warming process is what
# creates the SubprocPool in the first place.
# Safeguard against creating a SubprocPool in the subprocess.
"TORCH_WARM_POOL": "0",
# Some internal usages need a modified LD_LIBRARY_PATH.
"LD_LIBRARY_PATH": get_ld_library_path(),
@@ -182,24 +191,28 @@ class SubprocPool:
job_id = next(self.job_id_count)
self.pending_futures[job_id] = future = Future()
future.set_running_or_notify_cancel()
self._send(MsgHeader.JOB, job_id, job_data)
return future
def _send(self, msg_header: MsgHeader, job_id: int = -1, data: bytes = b"") -> None:
with self.write_lock:
if not self.running:
raise RuntimeError("submit() on closed pool")
_send_msg(self.write_pipe, job_id, job_data)
return future
raise RuntimeError("Attempting to use a closed pool")
_send_msg(self.write_pipe, msg_header, job_id, data)
def _read_thread(self) -> None:
while True:
data = b""
job_id = -1
try:
job_id, data = _recv_msg(self.read_pipe)
msg_header, job_id, data = _recv_msg(self.read_pipe)
except Exception:
# Something went wrong during the read. There's no way we have a
# valid job_id.
# valid msg.
log.exception("failure in subproc_pool._recv_msg")
job_id = -1
msg_header = MsgHeader.ERROR
if job_id < 0:
if msg_header != MsgHeader.JOB:
# read_pipe returned None or got exception
if self.running:
log.warning("SubprocPool unclean exit")
@@ -232,13 +245,19 @@ class SubprocPool:
self.pending_futures[job_id].set_result(result)
del self.pending_futures[job_id]
def quiesce(self) -> None:
self._send(MsgHeader.QUIESCE)
def wakeup(self) -> None:
self._send(MsgHeader.WAKEUP)
def shutdown(self) -> None:
try:
with self.write_lock:
if not self.running:
return
self.running = False
_send_msg(self.write_pipe, -1)
_send_msg(self.write_pipe, MsgHeader.SHUTDOWN)
self.write_pipe.close()
self.process.wait(300)
except OSError as e:
@@ -268,37 +287,36 @@ class SubprocMain:
self.write_pipe = write_pipe
self.write_lock = threading.Lock()
self.nprocs = nprocs
self.pool = self._new_pool(nprocs, True)
self.pool: Optional[ProcessPoolExecutor] = None
self.running = True
def _new_pool(self, nprocs: int, warm: bool) -> ProcessPoolExecutor:
pool = TrackedProcessPoolExecutor(
nprocs,
mp_context=multiprocessing.get_context(self.kind.value),
initializer=functools.partial(_async_compile_initializer, os.getpid()),
)
multiprocessing.util.Finalize(None, pool.shutdown, exitpriority=sys.maxsize)
if warm:
_warm_process_pool(pool, nprocs)
return pool
def main(self) -> None:
while True:
job_id, data = _recv_msg(self.read_pipe)
if job_id < 0:
msg_header, job_id, data = _recv_msg(self.read_pipe)
if msg_header == MsgHeader.JOB:
self.submit(job_id, data)
elif msg_header == MsgHeader.WAKEUP:
self._start_pool()
elif msg_header == MsgHeader.QUIESCE:
self._quiesce()
else:
return self._shutdown()
self.submit(job_id, data)
def _quiesce(self) -> None:
if self.pool is not None:
self.pool.shutdown(wait=False)
self.pool = None
def _shutdown(self) -> None:
with self.write_lock:
self.running = False
try:
_send_msg(self.write_pipe, -1)
_send_msg(self.write_pipe, MsgHeader.SHUTDOWN)
self.write_pipe.close()
except BrokenPipeError:
pass # parent process already shutdown
self.read_pipe.close()
self.pool.shutdown()
self._quiesce()
def submit(self, job_id: int, data: bytes) -> None:
while self.running:
@@ -309,7 +327,7 @@ class SubprocMain:
# If any subprocess in the pool crashes, we get a BrokenProcessPool
# exception and the whole pool becomes unusable. Handle crashes by
# recreating the pool and resubmitting.
self.pool = self._new_pool(self.nprocs, False)
self.pool = None
def _submit_inner(self, job_id: int, data: bytes) -> None:
def callback(fut: Future[Any]) -> None:
@@ -323,14 +341,31 @@ class SubprocMain:
assert isinstance(result, bytes)
with self.write_lock:
if self.running:
_send_msg(self.write_pipe, job_id, result)
_send_msg(self.write_pipe, MsgHeader.JOB, job_id, result)
return
self._start_pool()
assert self.pool is not None
future = self.pool.submit(
functools.partial(SubprocMain.do_job, self.pickler, data)
)
future.add_done_callback(callback)
def _start_pool(self) -> None:
if self.pool is not None:
return
self.pool = TrackedProcessPoolExecutor(
self.nprocs,
mp_context=multiprocessing.get_context(self.kind.value),
initializer=functools.partial(_async_compile_initializer, os.getpid()),
)
multiprocessing.util.Finalize(
None, self.pool.shutdown, exitpriority=sys.maxsize
)
_warm_process_pool(self.pool, self.nprocs)
@staticmethod
def do_job(pickler: SubprocPickler, data: bytes) -> bytes:
# do the pickle/unpickle in the sub-subproc

View File

@@ -811,6 +811,13 @@ def decide_compile_threads() -> int:
# TODO: Set directly after internal rollout.
compile_threads: Optional[int] = None if is_fbcode() else decide_compile_threads()
# Whether to quiesce the Triton-compile subprocess pool at the end of each compilation.
quiesce_async_compile_pool: bool = Config(
justknob="pytorch/inductor:quiesce_async_compile_pool",
env_name_force="TORCHINDUCTOR_QUIESCE_ASYNC_COMPILE_POOL",
default=False,
)
# Whether or not to enable statically launching CUDA kernels
# compiled by triton (instead of using triton's own launcher)
use_static_cuda_launcher: bool = static_cuda_launcher_default()