Compare commits

...

25 Commits

Author SHA1 Message Date
a6c524d35c Update
[ghstack-poisoned]
2025-11-04 10:05:49 -08:00
ca72eacb1e Update (base update)
[ghstack-poisoned]
2025-11-04 10:05:49 -08:00
df73ff1ebc Update
[ghstack-poisoned]
2025-11-03 15:25:05 -08:00
7c77a377b5 Update (base update)
[ghstack-poisoned]
2025-11-03 15:25:05 -08:00
12a7b7968f Update
[ghstack-poisoned]
2025-11-03 13:23:29 -08:00
27186dfec2 Update (base update)
[ghstack-poisoned]
2025-11-03 13:23:29 -08:00
f39f0c8fc5 Update
[ghstack-poisoned]
2025-11-03 13:14:26 -08:00
152d6ad71f Update (base update)
[ghstack-poisoned]
2025-11-03 13:14:26 -08:00
dcced84559 Update
[ghstack-poisoned]
2025-11-03 13:04:02 -08:00
a9c03eb296 Update (base update)
[ghstack-poisoned]
2025-11-03 13:04:02 -08:00
1fb0237a0e Update
[ghstack-poisoned]
2025-11-03 13:00:18 -08:00
868bb4dc61 Update (base update)
[ghstack-poisoned]
2025-11-03 13:00:18 -08:00
34516c3cef Update
[ghstack-poisoned]
2025-10-31 15:21:21 -07:00
41c4fd6dfc Update (base update)
[ghstack-poisoned]
2025-10-31 13:32:39 -07:00
1e4be63fa5 Update
[ghstack-poisoned]
2025-10-31 13:32:39 -07:00
b0e9fa5d5f Update (base update)
[ghstack-poisoned]
2025-10-31 13:24:03 -07:00
24117fb269 Update
[ghstack-poisoned]
2025-10-31 13:24:03 -07:00
6f52759003 Update (base update)
[ghstack-poisoned]
2025-10-31 13:20:48 -07:00
dc6a588c7b Update
[ghstack-poisoned]
2025-10-31 13:20:48 -07:00
77b2bbbb06 Update (base update)
[ghstack-poisoned]
2025-10-31 13:00:39 -07:00
c9efe6cb9c Update
[ghstack-poisoned]
2025-10-31 13:00:39 -07:00
6e21eea645 Update (base update)
[ghstack-poisoned]
2025-10-31 12:32:21 -07:00
7a8918bae9 Update
[ghstack-poisoned]
2025-10-31 12:32:21 -07:00
ef183266ae Update (base update)
[ghstack-poisoned]
2025-10-29 16:00:23 -07:00
d82f7024e7 Update
[ghstack-poisoned]
2025-10-29 16:00:23 -07:00
7 changed files with 38 additions and 40 deletions

View File

@ -4,6 +4,7 @@ import os
import tempfile
from threading import Event
import torch._inductor.config as config
from torch._inductor.compile_worker.subproc_pool import (
raise_testexc,
SubprocException,
@ -16,9 +17,12 @@ from torch.testing._internal.inductor_utils import HAS_CPU
class TestCompileWorker(TestCase):
# Factory hook used by the tests in this class to build the pool under test.
# Passes `size` straight through as SubprocPool's first argument; the
# TestCompileWorkerWithTimer subclass overrides it to also pass quiesce=True.
def make_pool(self, size):
return SubprocPool(size)
@skipIfWindows(msg="pass_fds not supported on Windows.")
def test_basic_jobs(self):
pool = SubprocPool(2)
pool = self.make_pool(2)
try:
a = pool.submit(operator.add, 100, 1)
b = pool.submit(operator.sub, 100, 1)
@ -29,7 +33,7 @@ class TestCompileWorker(TestCase):
@skipIfWindows(msg="pass_fds not supported on Windows.")
def test_exception(self):
pool = SubprocPool(2)
pool = self.make_pool(2)
try:
a = pool.submit(raise_testexc)
with self.assertRaisesRegex(
@ -42,7 +46,7 @@ class TestCompileWorker(TestCase):
@skipIfWindows(msg="pass_fds not supported on Windows.")
def test_crash(self):
pool = SubprocPool(2)
pool = self.make_pool(2)
try:
with self.assertRaises(Exception):
a = pool.submit(os._exit, 1)
@ -58,7 +62,7 @@ class TestCompileWorker(TestCase):
@skipIfWindows(msg="pass_fds not supported on Windows.")
def test_quiesce(self):
pool = SubprocPool(2)
pool = self.make_pool(2)
try:
a = pool.submit(operator.add, 100, 1)
pool.quiesce()
@ -75,7 +79,7 @@ class TestCompileWorker(TestCase):
os.environ["ROLE_RANK"] = "0"
with tempfile.NamedTemporaryFile(delete=True) as temp_log:
os.environ["TORCHINDUCTOR_WORKER_LOGPATH"] = temp_log.name
pool = SubprocPool(2)
pool = self.make_pool(2)
try:
pool.submit(operator.add, 100, 1)
self.assertEqual(os.path.exists(temp_log.name), True)
@ -83,6 +87,12 @@ class TestCompileWorker(TestCase):
pool.shutdown()
# Inherits the TestCompileWorker tests and runs them against pools built with
# quiesce=True. The quiesce_async_compile_time config is patched to 0.1 for
# this class so the pool's timer is constructed with that short duration.
@config.patch("quiesce_async_compile_time", 0.1)
class TestCompileWorkerWithTimer(TestCompileWorker):
# Override of the base-class pool factory: same size, plus quiesce=True.
def make_pool(self, size):
return SubprocPool(size, quiesce=True)
class TestTimer(TestCase):
def test_basics(self):
done = Event()

View File

@ -1282,7 +1282,6 @@ def _compile(
# in the case of normal and exception code paths
convert_frame_box: Optional[ConvertFrameBox] = None,
) -> ConvertFrameReturn:
from torch._inductor.async_compile import async_compile_pool_manager
from torch.fx.experimental.validator import (
BisectValidationException,
ValidationException,
@ -1476,7 +1475,6 @@ def _compile(
with (
_use_lazy_graph_module(config.use_lazy_graph_module),
compile_context(CompileContext(compile_id)),
async_compile_pool_manager(),
chromium_event_timed(
"dynamo", reset_event_log_on_exit=True, log_pt2_compile_event=True
),

View File

@ -2365,8 +2365,6 @@ To fix this, your tensor subclass must implement the dunder method __force_to_sa
@staticmethod
def _backward_impl(ctx, all_args):
from torch._inductor.async_compile import async_compile_pool_manager
# compiled autograd reimplements this function at proxy_call_aot_backward
assert not backward_state_indices, (
"BackwardState requires CompiledAutograd"
@ -2446,7 +2444,6 @@ To fix this, your tensor subclass must implement the dunder method __force_to_sa
with (
tracing(saved_context),
compile_context(saved_compile_context),
async_compile_pool_manager(),
context(),
track_graph_compiling(aot_config, "backward"),
metrics_context,

View File

@ -2,7 +2,6 @@
from __future__ import annotations
import atexit
import contextlib
import functools
import json
import logging
@ -228,18 +227,6 @@ class CompiledTritonKernels:
del CompiledTritonKernels._cache[key]
@contextlib.contextmanager
def async_compile_pool_manager():
"""
Context manager to quiesce the subproc pool at the end of compilation, i.e.,
when dynamo is done.
"""
try:
# Nothing is yielded to the caller; the managed body simply runs here.
yield
finally:
# In `finally` so AsyncCompile.quiesce() is called whether the managed
# body returns normally or raises.
AsyncCompile.quiesce()
class AsyncCompile:
"""
Utilities to compile in thread pools or subprocess pools (in the case of Triton).
@ -275,7 +262,9 @@ class AsyncCompile:
pool: AnyPool
if config.worker_start_method == "subprocess":
# Wrapper around ProcessPoolExecutor forks in a new process we control
pool = SubprocPool(get_compile_threads())
pool = SubprocPool(
get_compile_threads(), quiesce=config.quiesce_async_compile_pool
)
else:
if config.worker_start_method == "spawn":
# Avoid creating pools in the spawned subprocs themselves:
@ -331,20 +320,6 @@ class AsyncCompile:
cls._ready_future = cls.process_pool().submit(cls._get_ready)
return cls._ready_future.done()
@classmethod
def quiesce(cls) -> None:
"""
If using a SubprocPool, signal the sidecar process to shut down its
ProcessPoolExecutor.
"""
# Don't inadvertently create a process pool if it doesn't already exist:
# (a zero cache_info().currsize means process_pool() has no cached result
# yet -- presumably process_pool is wrapped in a functools cache; confirm.)
if not cls.process_pool.cache_info().currsize:
return
# Only acts when the quiesce_async_compile_pool config flag is set, and
# only on SubprocPool instances; any other pool type is left untouched.
if config.quiesce_async_compile_pool:
pool = cls.process_pool()
if isinstance(pool, SubprocPool):
pool.quiesce()
@classmethod
def wakeup(cls) -> None:
"""

View File

@ -23,6 +23,7 @@ from typing_extensions import Never, ParamSpec
import torch._thread_safe_fork # noqa: F401
from torch._inductor import config
from torch._inductor.codecache import torch_key
from torch._inductor.compile_worker.timer import Timer
from torch._inductor.compile_worker.tracked_process_pool import (
TrackedProcessPoolExecutor,
)
@ -131,6 +132,7 @@ class SubprocPool:
nprocs: int,
pickler: Optional[SubprocPickler] = None,
kind: SubprocKind = SubprocKind.FORK,
quiesce: bool = False,
) -> None:
entry = os.path.join(os.path.dirname(__file__), "__main__.py")
self.pickler = pickler or SubprocPickler()
@ -215,6 +217,13 @@ class SubprocPool:
"pytorch.wait_counter.subproc_pool.first_job"
).guard()
if quiesce:
self.timer: Optional[Timer] = Timer(
config.quiesce_async_compile_time, self.quiesce
)
else:
self.timer = None
# Start thread last to ensure all member variables are initialized
# before any access.
self.read_thread.start()
@ -287,6 +296,8 @@ class SubprocPool:
with self.futures_lock:
if not self.running:
return
if self.timer:
self.timer.record_call()
if isinstance(result, _SubprocExceptionInfo):
# An exception occurred in the submitted job
self.pending_futures[job_id].set_exception(
@ -321,6 +332,8 @@ class SubprocPool:
with self.write_lock:
if not self.running:
return
if self.timer:
self.timer.quit()
self.running = False
self.running_waitcounter.__exit__()
_send_msg(self.write_pipe, MsgHeader.SHUTDOWN)

View File

@ -17,7 +17,7 @@ class Timer:
self.background_thread: Optional[Thread] = None
self.last_called: Optional[float] = None
self.duration = duration
self.sleep_time = 60
self.sleep_time = duration / 2
self.call = call
self.exit = False

View File

@ -957,7 +957,12 @@ compile_threads: Optional[int] = None if is_fbcode() else decide_compile_threads
quiesce_async_compile_pool: bool = Config(
justknob="pytorch/inductor:quiesce_async_compile_pool",
env_name_force="TORCHINDUCTOR_QUIESCE_ASYNC_COMPILE_POOL",
default=False,
default=True,
)
# Time in seconds to wait before quiescing. Annotated as float rather than int
# because fractional values are valid: the compile-worker tests patch this to
# 0.1, and the pool's Timer derives its sleep interval as duration / 2.
quiesce_async_compile_time: float = Config(
    default=60,
)
# Whether or not to enable statically launching CUDA kernels