[inductor] Quiesce Triton compile worker pool after each dynamo compile (#156187)

For internal usages, keeping the Triton compile worker pool active for the lifetime of the process has caused some challenges. For example, it slows down and muddies profiling because of the huge number of threads on a box: N threads = 8 ranks * 32 subprocs * M threads started by torch. Also, each subproc can use more than 1GB.

This PR adds the functionality to shut down worker subprocs after each dynamo compile when using the SubprocPool implementation. The idea is to leave the main sidecar process running, but signal it to tear down its internal ProcessPoolExecutor when compile is finished. Restarting the ProcessPoolExecutor is relatively fast, e.g., 500ms, because the ProcessPoolExecutor forks from the sidecar.

Changes:
* Do not start the ProcessPoolExecutor automatically when compile_fx is imported. Instead, start only the sidecar process. The sidecar process imports torch, so it is still slow to start.
* Introduce wakeup() and quiesce() calls in the implementation to start and stop the ProcessPoolExecutor.
* Add a context manager to automatically quiesce() at the end of dynamo compilation.
* Signal a wakeup() in compile_fx only when we have cuda devices.
* Add a killswitch so we can turn off quiescing.

Testing:
For correctness, the stacked change at https://github.com/pytorch/pytorch/pull/156534 enables the feature for OSS so it's exercised in CI.

For performance, because of recent compile-time variance (see https://github.com/pytorch/pytorch/issues/152566), it's pretty hard to glean whether there's a regression:
* Training: https://hud.pytorch.org/benchmark/compilers?dashboard=torchinductor&startTime=Tue%2C%2017%20Jun%202025%2021%3A32%3A04%20GMT&stopTime=Tue%2C%2024%20Jun%202025%2021%3A32%3A04%20GMT&granularity=hour&mode=training&dtype=amp&deviceName=cuda%20(h100)&lBranch=gh/masnesral/210/head&lCommit=1b7315031c3bfad66a1a01700167a9ca1a2ae5f1&rBranch=main&rCommit=eab45643f22e58ee12d95d8b0162d51ca0a50801
* Inference: https://hud.pytorch.org/benchmark/compilers?dashboard=torchinductor&startTime=Tue%2C%2017%20Jun%202025%2021%3A32%3A04%20GMT&stopTime=Tue%2C%2024%20Jun%202025%2021%3A32%3A04%20GMT&granularity=hour&mode=inference&dtype=bfloat16&deviceName=cuda%20(h100)&lBranch=gh/masnesral/210/head&lCommit=1b7315031c3bfad66a1a01700167a9ca1a2ae5f1&rBranch=main&rCommit=eab45643f22e58ee12d95d8b0162d51ca0a50801

The wins (mostly for inference) don't make sense, but I'm also skeptical of the losses (mostly for training); I can't repro any of the slowdowns locally. Furthermore, check out the benchmarking results for the stacked diff, which actually enables the quiescing functionality for OSS. That should only slow down compile, since there can only be overhead from stopping and starting the workers. But the results are somehow better:
* Training: https://hud.pytorch.org/benchmark/compilers?dashboard=torchinductor&startTime=Tue%2C%2017%20Jun%202025%2021%3A32%3A04%20GMT&stopTime=Tue%2C%2024%20Jun%202025%2021%3A32%3A04%20GMT&granularity=hour&mode=training&dtype=amp&deviceName=cuda%20(h100)&lBranch=gh/masnesral/214/head&lCommit=41943253882a019b8ceafcd2bf4cd6acbe0cbca9&rBranch=main&rCommit=eab45643f22e58ee12d95d8b0162d51ca0a50801
* Inference: https://hud.pytorch.org/benchmark/compilers?dashboard=torchinductor&startTime=Tue%2C%2017%20Jun%202025%2021%3A32%3A04%20GMT&stopTime=Tue%2C%2024%20Jun%202025%2021%3A32%3A04%20GMT&granularity=hour&mode=inference&dtype=bfloat16&deviceName=cuda%20(h100)&lBranch=gh/masnesral/214/head&lCommit=41943253882a019b8ceafcd2bf4cd6acbe0cbca9&rBranch=main&rCommit=eab45643f22e58ee12d95d8b0162d51ca0a50801

Pull Request resolved: https://github.com/pytorch/pytorch/pull/156187
Approved by: https://github.com/aorenste, https://github.com/jansel
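A minimal sketch of the wakeup/quiesce lifecycle described above (illustrative only; dynamo and compile_fx wire these calls up internally, as the diff below shows):

# Hedged sketch; torch.compile drives this automatically.
from torch._inductor.async_compile import AsyncCompile, async_compile_pool_manager

with async_compile_pool_manager():
    # compile_fx calls wakeup() when it sees cuda/xpu example inputs; the sidecar
    # process then forks its internal ProcessPoolExecutor (relatively fast, ~500ms).
    AsyncCompile.wakeup()
    ...  # Triton kernels compile in the worker subprocesses here
# On exit, the context manager calls AsyncCompile.quiesce(), which tears down the
# ProcessPoolExecutor in the sidecar but leaves the sidecar process itself running.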
Committed by: PyTorch MergeBot
Parent: 178fe7aa98
Commit: 7a41f20794
@@ -29,8 +29,7 @@ class TestAsyncCompile(TestCase):
        with config.patch("worker_start_method", method):
            shutdown_compile_workers()
            pool = AsyncCompile.process_pool()
            pool.ready_future.result(timeout=120)
            AsyncCompile.wait_pool_ready()

            with fresh_cache():
                compiled_fn = torch.compile(fn)

@@ -53,6 +53,19 @@ class TestCompileWorker(TestCase):
        finally:
            pool.shutdown()

    @skipIfWindows(msg="pass_fds not supported on Windows.")
    def test_quiesce(self):
        pool = SubprocPool(2)
        try:
            a = pool.submit(operator.add, 100, 1)
            pool.quiesce()
            pool.wakeup()
            b = pool.submit(operator.sub, 100, 1)
            self.assertEqual(a.result(), 101)
            self.assertEqual(b.result(), 99)
        finally:
            pool.shutdown()


if __name__ == "__main__":
    from torch._inductor.test_case import run_tests
@@ -20,6 +20,12 @@ from torch._dynamo.utils import dynamo_timed
def inductor(*args, **kwargs):
    with dynamo_timed("inductor_import", log_pt2_compile_event=True):
        # do import here to avoid loading inductor into memory when it is not used
        # The AsyncCompile subproc pool can be slow to start, so warm it up as early
        # as possible.
        from torch._inductor.async_compile import maybe_warm_pool

        maybe_warm_pool()

        from torch._inductor.compile_fx import compile_fx

    return compile_fx(*args, **kwargs)
@@ -735,6 +735,7 @@ def _compile(
    # in the case of normal and exception code paths
    convert_frame_box: Optional[ConvertFrameBox] = None,
) -> ConvertFrameReturn:
    from torch._inductor.async_compile import async_compile_pool_manager
    from torch.fx.experimental.validator import (
        bisect,
        BisectValidationException,

@@ -1015,6 +1016,7 @@ def _compile(
    with (
        _use_lazy_graph_module(config.use_lazy_graph_module),
        compile_context(CompileContext(compile_id)),
        async_compile_pool_manager(),
        chromium_event_timed(
            "dynamo", reset_event_log_on_exit=True, log_pt2_compile_event=True
        ),
@@ -2286,6 +2286,8 @@ To fix this, your tensor subclass must implement the dunder method __force_to_sa
        @staticmethod
        def _backward_impl(ctx, all_args):
            from torch._inductor.async_compile import async_compile_pool_manager

            # compiled autograd reimplements this function at proxy_call_aot_backward
            assert not backward_state_indices, (
                "BackwardState requires CompiledAutograd"

@@ -2327,6 +2329,7 @@ To fix this, your tensor subclass must implement the dunder method __force_to_sa
            with (
                tracing(saved_context),
                compile_context(saved_compile_context),
                async_compile_pool_manager(),
                context(),
                track_graph_compiling(aot_config, "backward"),
                metrics_context,
@@ -2,6 +2,7 @@
from __future__ import annotations

import atexit
import contextlib
import functools
import json
import logging

@@ -216,7 +217,25 @@ class CompiledTritonKernels:
            del CompiledTritonKernels._cache[key]


@contextlib.contextmanager
def async_compile_pool_manager():
    """
    Context manager to quiesce the subproc pool at the end of compilation, i.e.,
    when dynamo is done.
    """
    try:
        yield
    finally:
        AsyncCompile.quiesce()


class AsyncCompile:
    """
    Utilities to compile in thread pools or subprocess pools (in the case of Triton).
    """

    _ready_future: Optional[Future[Any]] = None

    def __init__(self) -> None:
        pass

@@ -235,6 +254,7 @@ class AsyncCompile:
    @functools.lru_cache(1)
    def process_pool() -> AnyPool:
        assert get_compile_threads() > 1
        AsyncCompile._ready_future = None
        log.info(
            "Creating '%s' pool with %d workers",
            config.worker_start_method,

@@ -262,8 +282,6 @@ class AsyncCompile:
        # kill the worker thread that sends the shutdown message to the workers...
        multiprocessing.util.Finalize(None, pool.shutdown, exitpriority=sys.maxsize)

        # Set an attribute we can check to see if the pool is ready.
        pool.ready_future = pool.submit(AsyncCompile._get_ready)  # type: ignore[union-attr]
        _pool_set.add(pool)
        return pool

@@ -272,20 +290,61 @@ class AsyncCompile:
        if get_compile_threads() <= 1:
            return
        _compile_start()
        # Pool is initialized on first access
        # Pool is created on first access. Note for a SubprocPool, the sidecar process starts,
        # but its ProcessPoolExecutor does not initialize until a wakeup() call or the first
        # job is submitted.
        cls.process_pool()
        _compile_end()

    @classmethod
    def wait_pool_ready(cls, timeout=120) -> None:
        if cls.use_process_pool():
            assert cls._ready_future is not None
            cls._ready_future.result(timeout=timeout)

    @classmethod
    def submit(cls, task: Callable[..., Any]) -> Any:
        if get_compile_threads() <= 1:
            return task()
        return cls.pool().submit(task)

    def use_process_pool(self):
        return (
            get_compile_threads() > 1 and self.process_pool().ready_future.done()  # type: ignore[union-attr]
        )

    @classmethod
    def use_process_pool(cls):
        if get_compile_threads() <= 1:
            return False

        # Create a dummy job to check if the pool is ready. Submit it here instead of at
        # pool creation so we don't launch the full pool of worker subprocesses until
        # we're sure they're needed.
        if not cls._ready_future:
            cls._ready_future = cls.process_pool().submit(cls._get_ready)
        return cls._ready_future.done()

    @classmethod
    def quiesce(cls) -> None:
        """
        If using a SubprocPool, signal the sidecar process to shut down its
        ProcessPoolExecutor.
        """
        # Don't inadvertently create a process pool if it doesn't already exist:
        if not cls.process_pool.cache_info().currsize:
            return
        if config.quiesce_async_compile_pool:
            pool = cls.process_pool()
            if isinstance(pool, SubprocPool):
                pool.quiesce()

    @classmethod
    def wakeup(cls) -> None:
        """
        If using a SubprocPool, signal the sidecar process to start up its
        ProcessPoolExecutor.
        """
        if not cls.use_process_pool():
            return
        pool = cls.process_pool()
        if isinstance(pool, SubprocPool):
            pool.wakeup()

    def triton(self, kernel_name: str, source_code: str, device_str: str = "cuda"):
        """

@@ -517,18 +576,24 @@ class AsyncCompile:
                pbar.update(1)


if (
    os.environ.get("TORCH_TNT_IN_USE", "0") == "1"
    or os.environ.get("TORCH_WARM_POOL", "1") != "1"
    # The subprocess pool is only used for the Triton backend
    or not has_triton_package()
    # Skip for fbcode. We have internal reports of usages inside multiprocessing
    # pools that lead a multiplicative number of compile subprocesses.
    or config.is_fbcode()
):
    pass
else:
def maybe_warm_pool() -> None:
    if (
        os.environ.get("TORCH_TNT_IN_USE", "0") == "1"
        or os.environ.get("TORCH_WARM_POOL", "1") != "1"
        # The subprocess pool is only used for the Triton backend
        or not has_triton_package()
        # Skip for fbcode. We have internal reports of usages inside multiprocessing
        # pools that lead a multiplicative number of compile subprocesses.
        or config.is_fbcode()
    ):
        return

    AsyncCompile.warm_pool()
    # TODO: This starts the SubprocPool's internal process pool as early as possible at
    # the expense of creating a bunch of worker processes that might not be needed. We
    # could start them lazily if we're willing to lose a small amount of compile time.
    AsyncCompile.wakeup()


# On exit give the workers a chance to clean themselves up. Without this the
# resource_tracker can complain about leaked semaphores coming from the
@@ -22,7 +22,7 @@ from typing import Any, Callable, Optional, TYPE_CHECKING, TypeVar, Union
from typing_extensions import Never, override, ParamSpec, Protocol, TypedDict, Unpack
from unittest import mock

import torch._inductor.async_compile  # noqa: F401 required to warm up AsyncCompile pools
import torch._inductor.async_compile
import torch.fx
import torch.utils._pytree as pytree
from functorch.compile import min_cut_rematerialization_partition

@@ -2037,6 +2037,12 @@ def compile_fx(
    NB: This function TAKES OWNERSHIP of the input ``model_`` and can potentially
    mutate it! Make a copy if you need to preserve the original GraphModule.
    """
    # Wake up the AsyncCompile subproc pool as early as possible (if there's cuda).
    if any(
        isinstance(e, torch.Tensor) and e.device.type in ("cuda", "xpu")
        for e in example_inputs_
    ):
        torch._inductor.async_compile.AsyncCompile.wakeup()

    # Some arguments trigger a recursive call to compile_fx. Handle these
    # short circuits first, before anything else
@@ -13,7 +13,7 @@ import traceback
import typing
from concurrent.futures import Future, ProcessPoolExecutor
from concurrent.futures.process import BrokenProcessPool
from enum import Enum
from enum import Enum, IntEnum
from typing import Any, Callable, IO, Optional, TypeVar
from typing_extensions import Never, ParamSpec

@@ -36,31 +36,42 @@ _P = ParamSpec("_P")
_T = TypeVar("_T")


def _pack_msg(job_id: int, length: int) -> bytes:
    return struct.pack("nn", job_id, length)
class MsgHeader(IntEnum):
    ERROR = 0
    SHUTDOWN = 1
    QUIESCE = 2
    WAKEUP = 3
    JOB = 4


def _unpack_msg(data: bytes) -> tuple[int, int]:
def _pack_msg(msg_header: MsgHeader, job_id: int, length: int) -> bytes:
    return struct.pack("nnn", int(msg_header), job_id, length)


def _unpack_msg(data: bytes) -> tuple[MsgHeader, int, int]:
    if not data:
        return -1, -1
    return struct.unpack("nn", data)
        return MsgHeader.ERROR, -1, -1
    msg_header, job_id, length = struct.unpack("nnn", data)
    return MsgHeader(msg_header), job_id, length


msg_bytes = len(_pack_msg(0, 0))
msg_bytes = len(_pack_msg(MsgHeader.JOB, 0, 0))


def _send_msg(write_pipe: IO[bytes], job_id: int, job_data: bytes = b"") -> None:
    length = len(job_data)
    write_pipe.write(_pack_msg(job_id, length))
def _send_msg(
    write_pipe: IO[bytes], msg_header: MsgHeader, job_id: int = -1, data: bytes = b""
) -> None:
    length = len(data)
    write_pipe.write(_pack_msg(msg_header, job_id, length))
    if length > 0:
        write_pipe.write(job_data)
        write_pipe.write(data)
    write_pipe.flush()


def _recv_msg(read_pipe: IO[bytes]) -> tuple[int, bytes]:
    job_id, length = _unpack_msg(read_pipe.read(msg_bytes))
def _recv_msg(read_pipe: IO[bytes]) -> tuple[MsgHeader, int, bytes]:
    msg_header, job_id, length = _unpack_msg(read_pipe.read(msg_bytes))
    data = read_pipe.read(length) if length > 0 else b""
    return job_id, data
    return msg_header, job_id, data


class _SubprocExceptionInfo:

@@ -147,9 +158,7 @@ class SubprocPool:
                "PYTHONPATH": os.environ.get(
                    "TORCH_CUSTOM_PYTHONPATH", os.pathsep.join(sys.path)
                ),
                # We don't want to re-warm the pool when the subprocess imports
                # torch._inductor.codecache since the warming process is what
                # creates the SubprocPool in the first place.
                # Safeguard against creating a SubprocPool in the subprocess.
                "TORCH_WARM_POOL": "0",
                # Some internal usages need a modified LD_LIBRARY_PATH.
                "LD_LIBRARY_PATH": get_ld_library_path(),

@@ -182,24 +191,28 @@ class SubprocPool:
        job_id = next(self.job_id_count)
        self.pending_futures[job_id] = future = Future()
        future.set_running_or_notify_cancel()
        self._send(MsgHeader.JOB, job_id, job_data)
        return future

    def _send(self, msg_header: MsgHeader, job_id: int = -1, data: bytes = b"") -> None:
        with self.write_lock:
            if not self.running:
                raise RuntimeError("submit() on closed pool")
            _send_msg(self.write_pipe, job_id, job_data)
        return future
                raise RuntimeError("Attempting to use a closed pool")
            _send_msg(self.write_pipe, msg_header, job_id, data)

    def _read_thread(self) -> None:
        while True:
            data = b""
            job_id = -1
            try:
                job_id, data = _recv_msg(self.read_pipe)
                msg_header, job_id, data = _recv_msg(self.read_pipe)
            except Exception:
                # Something went wrong during the read. There's no way we have a
                # valid job_id.
                # valid msg.
                log.exception("failure in subproc_pool._recv_msg")
                job_id = -1
                msg_header = MsgHeader.ERROR

            if job_id < 0:
            if msg_header != MsgHeader.JOB:
                # read_pipe returned None or got exception
                if self.running:
                    log.warning("SubprocPool unclean exit")

@@ -232,13 +245,19 @@ class SubprocPool:
                self.pending_futures[job_id].set_result(result)
            del self.pending_futures[job_id]

    def quiesce(self) -> None:
        self._send(MsgHeader.QUIESCE)

    def wakeup(self) -> None:
        self._send(MsgHeader.WAKEUP)

    def shutdown(self) -> None:
        try:
            with self.write_lock:
                if not self.running:
                    return
                self.running = False
                _send_msg(self.write_pipe, -1)
                _send_msg(self.write_pipe, MsgHeader.SHUTDOWN)
                self.write_pipe.close()
            self.process.wait(300)
        except OSError as e:

@@ -268,37 +287,36 @@ class SubprocMain:
        self.write_pipe = write_pipe
        self.write_lock = threading.Lock()
        self.nprocs = nprocs
        self.pool = self._new_pool(nprocs, True)
        self.pool: Optional[ProcessPoolExecutor] = None
        self.running = True

    def _new_pool(self, nprocs: int, warm: bool) -> ProcessPoolExecutor:
        pool = TrackedProcessPoolExecutor(
            nprocs,
            mp_context=multiprocessing.get_context(self.kind.value),
            initializer=functools.partial(_async_compile_initializer, os.getpid()),
        )
        multiprocessing.util.Finalize(None, pool.shutdown, exitpriority=sys.maxsize)
        if warm:
            _warm_process_pool(pool, nprocs)
        return pool

    def main(self) -> None:
        while True:
            job_id, data = _recv_msg(self.read_pipe)
            if job_id < 0:
            msg_header, job_id, data = _recv_msg(self.read_pipe)
            if msg_header == MsgHeader.JOB:
                self.submit(job_id, data)
            elif msg_header == MsgHeader.WAKEUP:
                self._start_pool()
            elif msg_header == MsgHeader.QUIESCE:
                self._quiesce()
            else:
                return self._shutdown()
            self.submit(job_id, data)

    def _quiesce(self) -> None:
        if self.pool is not None:
            self.pool.shutdown(wait=False)
            self.pool = None

    def _shutdown(self) -> None:
        with self.write_lock:
            self.running = False
            try:
                _send_msg(self.write_pipe, -1)
                _send_msg(self.write_pipe, MsgHeader.SHUTDOWN)
                self.write_pipe.close()
            except BrokenPipeError:
                pass  # parent process already shutdown
        self.read_pipe.close()
        self.pool.shutdown()
        self._quiesce()

    def submit(self, job_id: int, data: bytes) -> None:
        while self.running:

@@ -309,7 +327,7 @@ class SubprocMain:
                # If any subprocess in the pool crashes, we get a BrokenProcessPool
                # exception and the whole pool becomes unusable. Handle crashes by
                # recreating the pool and resubmitting.
                self.pool = self._new_pool(self.nprocs, False)
                self.pool = None

    def _submit_inner(self, job_id: int, data: bytes) -> None:
        def callback(fut: Future[Any]) -> None:

@@ -323,14 +341,31 @@ class SubprocMain:
            assert isinstance(result, bytes)
            with self.write_lock:
                if self.running:
                    _send_msg(self.write_pipe, job_id, result)
                    _send_msg(self.write_pipe, MsgHeader.JOB, job_id, result)
            return

        self._start_pool()
        assert self.pool is not None

        future = self.pool.submit(
            functools.partial(SubprocMain.do_job, self.pickler, data)
        )
        future.add_done_callback(callback)

    def _start_pool(self) -> None:
        if self.pool is not None:
            return

        self.pool = TrackedProcessPoolExecutor(
            self.nprocs,
            mp_context=multiprocessing.get_context(self.kind.value),
            initializer=functools.partial(_async_compile_initializer, os.getpid()),
        )
        multiprocessing.util.Finalize(
            None, self.pool.shutdown, exitpriority=sys.maxsize
        )
        _warm_process_pool(self.pool, self.nprocs)

    @staticmethod
    def do_job(pickler: SubprocPickler, data: bytes) -> bytes:
        # do the pickle/unpickle in the sub-subproc
@@ -811,6 +811,13 @@ def decide_compile_threads() -> int:
# TODO: Set directly after internal rollout.
compile_threads: Optional[int] = None if is_fbcode() else decide_compile_threads()

# Whether to quiesce the Triton-compile subprocess pool at the end of each compilation.
quiesce_async_compile_pool: bool = Config(
    justknob="pytorch/inductor:quiesce_async_compile_pool",
    env_name_force="TORCHINDUCTOR_QUIESCE_ASYNC_COMPILE_POOL",
    default=False,
)

# Whether or not to enable statically launching CUDA kernels
# compiled by triton (instead of using triton's own launcher)
use_static_cuda_launcher: bool = static_cuda_launcher_default()