Improve subproc autotuning implementation (#149700)

Summary: The primary change is to update the autotune-in-a-subproc implementation to avoid using multiprocessing spawn. Spawn (re)executes the top-level script in the subproc, which can be problematic. The approach here is similar to Triton parallel compile: we Popen a subproc on a controlled entry point and communicate over pipes. That change drove a lot of refactoring in the TuningProcess class, so I took the opportunity to simplify some things, rename some methods, etc.
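
To illustrate the pattern, here's a minimal, self-contained sketch (not the PR's actual code; the real parent and child halves are in `autotune_process.py` and the new `__autotune_main__.py` in the diff below). The parent launches a child with `Popen`, passes it two raw pipe fds, and exchanges pickled requests/results; `None` is the shutdown sentinel and exceptions travel back as results:

```python
import os
import pickle
import subprocess
import sys

# One pipe per direction; the child inherits the raw fds via pass_fds
# (which preserves the fd numbers in the child process).
child_read_fd, parent_write_fd = os.pipe()
parent_read_fd, child_write_fd = os.pipe()

# Stand-in for the dedicated entry-point script: a tiny work loop that
# unpickles a zero-arg callable, runs it, and pickles back the result
# (or the exception). None tells the child to exit.
child_src = f"""
import os, pickle
rp = os.fdopen({child_read_fd}, "rb")
wp = os.fdopen({child_write_fd}, "wb")
while True:
    job = pickle.load(rp)
    if job is None:
        break
    try:
        result = job()
    except Exception as e:
        result = e
    pickle.dump(result, wp)
    wp.flush()
"""

proc = subprocess.Popen(
    [sys.executable, "-c", child_src],
    pass_fds=(child_read_fd, child_write_fd),
)
os.close(child_read_fd)   # the parent keeps only its own pipe ends
os.close(child_write_fd)

write_pipe = os.fdopen(parent_write_fd, "wb")
read_pipe = os.fdopen(parent_read_fd, "rb")

pickle.dump(os.getpid, write_pipe)  # any picklable zero-arg callable
write_pipe.flush()
print(pickle.load(read_pipe))       # prints the child's pid

pickle.dump(None, write_pipe)       # ask the child to shut down
write_pipe.flush()
proc.wait()
```

Unlike spawn, nothing here re-executes the parent's top-level script; the child runs only the code we hand it.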

One other notable change is around the timeout / kill approach. After a timeout, we were previously attempting to stop the subproc in three steps (graceful shutdown; SIGTERM if the graceful shutdown fails; SIGKILL if SIGTERM fails). I'll argue that's not useful: 1) The graceful shutdown is never going to work unless the subproc happens to have just completed its task and is ready to receive the next command. 2) If we're going to kill the subproc, let's just take the most aggressive approach and move on as quickly as possible to restarting it, rather than waiting to see if gentler shutdown attempts succeeded. The only downside I can find is maybe a little log spew, e.g., `ResourceWarning: subprocess 2987680 is still running`.
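
Condensed from the new `TuningProcess.get()` in the diff below, the timeout path now amounts to a single selector wait and an immediate SIGKILL, with the restart handled lazily by the next `put()`:

```python
try:
    # Wait (with a timeout) for the read end of the pipe to become readable.
    if not self.selector.select(timeout):
        raise TimeoutError(f"Timeout in autotune subprocess {self.process.pid}")
    result = TuningProcess.recv(self.read_pipe)
except TimeoutError:
    self.kill()  # straight to SIGKILL; put() restarts the subprocess if needed
    raise
```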

List of changes:
* Switched the autotuning subprocess from multiprocessing spawn to `subprocess.Popen`.
* Introduced a new entry point, `__autotune_main__.py`.
* Renamed some `TuningProcess` methods. For example, `shutdown` makes more sense than `terminate`, since the latter implies a forced kill.
* Simplified the implementation around the benchmarking timeout and how we kill the subproc after a timeout.
* Deprecated the now-unused timeout configs in `_inductor/config.py`.
* Moved `get_ld_library_path` helper to a common utils file.
* Added more unit tests for subproc crashes / timeouts / exceptions, etc.

Test plan:
* New unit tests
* Also ran internally with all combinations of build mode (`opt` and `dev-nosan`) and launch method (`buck run` vs. executing the `.par` file directly).
* Made sure the functionality to parallelize autotuning across different GPUs is working (it wasn't clear to me that it was previously behaving as intended).

Differential Revision: [D71976971](https://our.internmc.facebook.com/intern/diff/D71976971)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/149700
Approved by: https://github.com/aorenste, https://github.com/jansel, https://github.com/eellison
Author: Sam Larsen
Date: 2025-03-27 13:57:43 -07:00
Committed by: PyTorch MergeBot
Parent: 8b04364914
Commit: 266bd22b44
6 changed files with 356 additions and 314 deletions

View File: test/inductor/test_max_autotune.py

@@ -3,6 +3,7 @@ import contextlib
import json
import math
import os
import random
import tempfile
import unittest
from typing import Callable, Optional
@@ -15,8 +16,9 @@ from torch._dynamo.testing import rand_strided, reset_rng_state
from torch._dynamo.utils import same
from torch._inductor import config
from torch._inductor.autotune_process import (
BenchmarkRequest,
_TestBenchmarkRequest,
CUDA_VISIBLE_DEVICES,
TuningProcess,
TuningProcessPool,
)
from torch._inductor.graph import GraphLowering
@@ -1196,35 +1198,6 @@ class TestMaxAutotuneRemoteCache(TestCase):
self.assertEqual(global_stats.autotune_remote, Stats(2, 3, 2))
class _TestBenchmarkRequest(BenchmarkRequest):
def __init__(
self, value: float, multi_device: bool, parent_visible_devices: Optional[str]
) -> None:
self.value = value
self.multi_device = multi_device
self.parent_visible_devices = parent_visible_devices
def benchmark(
self, *input_tensors: torch.Tensor, output_tensor: Optional[torch.Tensor] = None
) -> float:
# Verify that the visible devices env var is set correctly. If multi-device
# auto-tuning is disabled, the visible devices should be unmanipulated from
# the parent process. If multi-device auto-tuning is enabled, the visible
# devices should be a _single_ valid device number. Note that we can't perform
# this validation directly from the test body because benchmarks execute in a
# separate process. If the check fails, however, the test will detect the
# failure by virtue of not receiving the expected result back.
visible_devices = os.environ.get(CUDA_VISIBLE_DEVICES)
if not self.multi_device:
assert visible_devices == self.parent_visible_devices
else:
assert self.parent_visible_devices is not None
valid_devices = self.parent_visible_devices.split(",")
assert visible_devices in valid_devices
return self.value
class _TestTritonTemplateCaller(TritonTemplateCaller):
def __init__(self, bmreq: _TestBenchmarkRequest):
self.bmreq = bmreq
@@ -1234,63 +1207,141 @@ class _TestTritonTemplateCaller(TritonTemplateCaller):
class TestTuningProcess(TestCase):
def check_healthy(self, p: TuningProcess, device: Optional[int] = None):
result = random.random()
bmreq = _TestBenchmarkRequest(result, device=device)
p.put(bmreq.benchmark)
self.assertEqual(p.get(), result)
def test_tuning_subproc_timeout(self):
p = TuningProcess(None)
bmreq = _TestBenchmarkRequest(0, sleep=120)
p.put(bmreq.benchmark)
with self.assertRaises(TimeoutError):
p.get(timeout=1.0)
# Make sure the TuningProcess is still usable after a timeout.
self.check_healthy(p)
p.shutdown()
def test_tuning_subproc_exception(self):
p = TuningProcess(None)
bmreq = _TestBenchmarkRequest(0, exc=RuntimeError("Fail"))
p.put(bmreq.benchmark)
with self.assertRaises(RuntimeError):
p.get()
# Make sure the TuningProcess is still usable after an exception.
self.check_healthy(p)
p.shutdown()
def test_tuning_subproc_crash(self):
p = TuningProcess(None)
bmreq = _TestBenchmarkRequest(0, crash=True)
p.put(bmreq.benchmark)
with self.assertRaises(EOFError):
p.get()
# Make sure the TuningProcess is still usable after a crash.
self.check_healthy(p)
p.shutdown()
def test_tuning_subproc_killed(self):
p = TuningProcess(None)
p.kill()
self.check_healthy(p)
p.shutdown()
def test_visible_devices(self):
device_list = TuningProcessPool.get_device_list()
for device in device_list:
p = TuningProcess(device)
self.check_healthy(p, device=device)
p.shutdown()
class TestTuningProcessPool(TestCase):
# Use only one device/subprocess so we test the process restarts
# and is usable after a crash.
@config.patch({"autotune_multi_device": False})
def test_tuning_pool_crash(self):
# Use only one device/subprocess so we test the process restarts
# and is usable after a "crash".
with config.patch({"autotune_multi_device": False}):
tuning_pool = TuningProcessPool()
tuning_pool.initialize()
tuning_pool = TuningProcessPool()
tuning_pool.initialize()
# First force the tuning process to "crash" by setting a bogus
# string for the expected visible devices.
bmreq = _TestBenchmarkRequest(3.14, False, "invalid")
choice = _TestTritonTemplateCaller(bmreq)
# First force the tuning process to crash.
bmreq = _TestBenchmarkRequest(0, crash=True)
choice = _TestTritonTemplateCaller(bmreq)
timings = tuning_pool.benchmark([choice])
self.assertTrue(choice in timings)
self.assertEqual(timings[choice], float("inf"))
# Then send another request and make sure the sub-process
# has restarted and is operational.
bmreq = _TestBenchmarkRequest(3.14)
choice = _TestTritonTemplateCaller(bmreq)
timings = tuning_pool.benchmark([choice])
self.assertTrue(choice in timings)
self.assertEqual(timings[choice], bmreq.result)
tuning_pool.shutdown()
@config.patch({"autotune_multi_device": False})
def test_tuning_pool_timeout(self):
tuning_pool = TuningProcessPool()
tuning_pool.initialize()
# First force the tuning process to timeout.
bmreq = _TestBenchmarkRequest(0, sleep=120)
choice = _TestTritonTemplateCaller(bmreq)
with config.patch({"max_autotune_subproc_result_timeout_seconds": 1.0}):
timings = tuning_pool.benchmark([choice])
self.assertTrue(choice in timings)
self.assertEqual(timings[choice], float("inf"))
self.assertTrue(choice in timings)
self.assertEqual(timings[choice], float("inf"))
# Then send another request and make sure the sub-process
# has restarted and is operational. 'valid_devices' expected
# to be None because autotune_multi_device is off.
choice.bmreq.parent_visible_devices = os.environ.get(CUDA_VISIBLE_DEVICES)
# Then send another request and make sure the sub-process
# has restarted and is operational.
bmreq = _TestBenchmarkRequest(3.14)
choice = _TestTritonTemplateCaller(bmreq)
timings = tuning_pool.benchmark([choice])
self.assertTrue(choice in timings)
self.assertEqual(timings[choice], bmreq.value)
timings = tuning_pool.benchmark([choice])
self.assertTrue(choice in timings)
self.assertEqual(timings[choice], bmreq.result)
tuning_pool.terminate()
tuning_pool.shutdown()
# XPU has to enable XPU_VISIBLE_DEVICES to control device visibility.
@skipIfXpu
@config.patch({"autotune_multi_device": True})
def test_tuning_pool_multiple_devices(self):
with config.patch({"autotune_multi_device": True}):
# Adapt the test to the available devices (and whether CUDA_VISIBLE_DEVICES
# is already set in the environment); use a subset of the available devices
# to ensure only the subset are visible to the sub-processes.
if CUDA_VISIBLE_DEVICES in os.environ:
visible_devices = os.environ[CUDA_VISIBLE_DEVICES].split(",")
else:
visible_devices = [str(d) for d in range(torch.cuda.device_count())]
parent_visible_devices = ",".join(visible_devices[-2:])
os.environ[CUDA_VISIBLE_DEVICES] = parent_visible_devices
# Adapt the test to the available devices (and whether CUDA_VISIBLE_DEVICES
# is already set in the environment); use a subset of the available devices
# to ensure only the subset are visible to the sub-processes.
if CUDA_VISIBLE_DEVICES in os.environ:
visible_devices = os.environ[CUDA_VISIBLE_DEVICES].split(",")
else:
visible_devices = [str(d) for d in range(torch.cuda.device_count())]
cuda_visible_devices = ",".join(visible_devices[-2:])
with unittest.mock.patch.dict(
os.environ, {CUDA_VISIBLE_DEVICES: cuda_visible_devices}
):
tuning_pool = TuningProcessPool()
tuning_pool.initialize()
choice1 = _TestTritonTemplateCaller(
_TestBenchmarkRequest(3.14, True, parent_visible_devices),
)
choice2 = _TestTritonTemplateCaller(
_TestBenchmarkRequest(2.718, True, parent_visible_devices),
)
choice1 = _TestTritonTemplateCaller(_TestBenchmarkRequest(3.14))
choice2 = _TestTritonTemplateCaller(_TestBenchmarkRequest(2.718))
timings = tuning_pool.benchmark([choice1, choice2])
self.assertEqual(timings[choice1], choice1.bmreq.value)
self.assertEqual(timings[choice2], choice2.bmreq.value)
timings = tuning_pool.benchmark([choice1, choice2])
self.assertEqual(timings[choice1], choice1.bmreq.result)
self.assertEqual(timings[choice2], choice2.bmreq.result)
tuning_pool.terminate()
tuning_pool.shutdown()
@instantiate_parametrized_tests

View File: torch/_inductor/__autotune_main__.py

@@ -0,0 +1,33 @@
import argparse
import logging
import os
from torch._inductor.autotune_process import TuningProcess
from torch._inductor.compile_worker.utils import _async_compile_initializer
log = logging.getLogger(__name__)
def main() -> None:
parser = argparse.ArgumentParser()
parser.add_argument("--parent", type=int)
parser.add_argument("--read-fd", type=int)
parser.add_argument("--write-fd", type=int)
args = parser.parse_args()
read_pipe = os.fdopen(args.read_fd, "rb")
write_pipe = os.fdopen(args.write_fd, "wb")
try:
# Ensures the subprocess exits if the parent crashes:
_async_compile_initializer(args.parent)
TuningProcess.process_main(read_pipe, write_pipe)
except Exception:
log.exception("Uncaught exception in autotune subprocess")
finally:
read_pipe.close()
write_pipe.close()
if __name__ == "__main__":
main()

View File: torch/_inductor/autotune_process.py

@@ -1,23 +1,26 @@
# mypy: allow-untyped-defs
from __future__ import annotations
import contextlib
import atexit
import ctypes
import dataclasses
import functools
import logging
import os
import pickle
import queue
import selectors
import subprocess
import sys
import time
import warnings
from collections.abc import Iterable, Sequence
from concurrent.futures import ThreadPoolExecutor
from ctypes import byref, c_size_t, c_void_p, CDLL
from typing import Any, Callable, Optional, TYPE_CHECKING, Union
from typing import Any, Callable, IO, Optional, TYPE_CHECKING, Union
import torch
import torch._inductor.async_compile # noqa: F401 required to warm up AsyncCompile pools
from torch import multiprocessing
from torch._dynamo.device_interface import get_interface_for_device
from torch._dynamo.testing import rand_strided
from torch._inductor import ir
@@ -28,14 +31,12 @@ from torch._inductor.codecache import (
get_hash,
PyCodeCache,
)
from torch._inductor.utils import get_gpu_type, is_gpu
from torch._inductor.utils import get_gpu_type, get_ld_library_path, is_gpu
from torch._logging import getArtifactLogger
from torch.utils._ordered_set import OrderedSet
if TYPE_CHECKING:
from multiprocessing.process import BaseProcess
from multiprocessing.queues import Queue
from types import ModuleType
from torch._inductor.select_algorithm import TritonTemplateCaller
@@ -49,236 +50,183 @@ from .virtualized import V
CUDA_VISIBLE_DEVICES = "CUDA_VISIBLE_DEVICES"
EXIT_HANDLER_REGISTERED = False
autotuning_log = getArtifactLogger(__name__, "autotuning")
log = logging.getLogger(__name__)
# Used to synchronize between parent and child processes
class Ping:
pass
class Pong:
pass
class NonzeroWorkspaceNotSupportedError(Exception):
pass
@contextlib.contextmanager
def set_cuda_visible_device(device: Optional[int]):
"""
Context manager to set the CUDA_VISIBLE_DEVICES environment variable to the
specified single device. If device is None, don't manipulate the environment.
"""
if device is None:
yield
return
current = os.environ.get(CUDA_VISIBLE_DEVICES)
os.environ[CUDA_VISIBLE_DEVICES] = str(device)
try:
yield
finally:
if current is None:
del os.environ[CUDA_VISIBLE_DEVICES]
else:
os.environ[CUDA_VISIBLE_DEVICES] = current
@dataclasses.dataclass
class TuningProcess:
"""
Abstraction for launching a helper process to benchmark kernels. Spawns
the parent process and uses multiprocessing queues to send benchmark
requests and return results.
Class to launch and interact with a benchmarking subprocess.
"""
device: Optional[int] = None
process: Optional[BaseProcess] = None
request_queue: Optional[Queue[Any]] = None
response_queue: Optional[Queue[Any]] = None
@staticmethod
def process_main(
request_queue: Queue[Any],
response_queue: Queue[Any],
) -> None:
def process_main(read_pipe: IO[bytes], write_pipe: IO[bytes]) -> None:
"""
Entry point for the child process.
"""
autotuning_log.debug(
"Entering TuningProcess child. Visible devices = %s",
"Started autotune subprocess %s. Visible devices: %s",
os.getpid(),
os.environ.get(CUDA_VISIBLE_DEVICES),
)
def workloop():
while True:
job = TuningProcess.recv(read_pipe)
if job is None:
# None is a sentinel for the child to shut down
break
try:
result = job()
except Exception as e:
result = e
TuningProcess.send(result, write_pipe)
try:
TuningProcess.workloop(request_queue, response_queue)
except Exception:
autotuning_log.exception("Exception in TuningProcess")
workloop()
except EOFError:
# The parent closed the pipe
pass
@staticmethod
def workloop(request_queue: Queue[Any], response_queue: Queue[Any]) -> None:
"""
Work loop for the benchmarking subprocess.
"""
while True:
obj = request_queue.get()
def send(obj: Any, write_pipe: IO[bytes]) -> None:
pickle.dump(obj, write_pipe)
write_pipe.flush()
if obj is None:
break # None is a sentinel for the child to terminate
elif isinstance(obj, Ping):
response_queue.put(Pong())
elif isinstance(obj, BenchmarkRequest):
response_queue.put(obj.benchmark())
else:
raise RuntimeError(f"Invalid request type {type(obj)}")
@staticmethod
def recv(read_pipe: IO[bytes]) -> Any:
return pickle.load(read_pipe)
def valid(self) -> bool:
def __init__(self, device: Optional[int]):
self.device = device
self.start()
def start(self):
"""
True if the sub-process has been initialized.
Start the benchmarking subprocess.
"""
return (
self.process is not None
and self.request_queue is not None
and self.response_queue is not None
entry = os.path.join(os.path.dirname(__file__), "__autotune_main__.py")
subproc_read_fd, write_fd = os.pipe()
read_fd, subproc_write_fd = os.pipe()
self.write_pipe = os.fdopen(write_fd, "wb")
self.read_pipe = os.fdopen(read_fd, "rb")
self.selector = selectors.DefaultSelector()
self.selector.register(self.read_pipe, selectors.EVENT_READ)
cmd = [
sys.executable,
entry,
f"--parent={os.getpid()}",
f"--read-fd={str(subproc_read_fd)}",
f"--write-fd={str(subproc_write_fd)}",
]
extra_env = {
# We need to set the PYTHONPATH so the subprocess can find torch.
"PYTHONPATH": os.pathsep.join(sys.path),
# We shouldn't be using the Triton async compile subprocess pool,
# but as a precaution set the env var that disables its creation.
"TORCH_WARM_POOL": "0",
# Some internal usages need a modified LD_LIBRARY_PATH.
"LD_LIBRARY_PATH": get_ld_library_path(),
}
if self.device is not None:
extra_env[CUDA_VISIBLE_DEVICES] = str(self.device)
self.process = subprocess.Popen(
cmd,
env={**os.environ, **extra_env},
pass_fds=(subproc_read_fd, subproc_write_fd),
)
os.close(subproc_read_fd)
os.close(subproc_write_fd)
def clear(self) -> None:
self.running = True
def alive(self) -> bool:
"""
Reset to an uninitialized state.
True if the subprocess is still running.
"""
self.process = self.request_queue = self.response_queue = None
return self.running and self.process.poll() is None
def initialize(self) -> None:
"""
Create child process, request/response queues, and do the warm up.
Set the environment to make only the provided GPU device visible
to the process.
"""
if self.valid():
return
# cuda runtime does not work with "fork", use "spawn" to start processes.
ctx = multiprocessing.get_context("spawn")
self.request_queue = ctx.Queue()
self.response_queue = ctx.Queue()
self.process = ctx.Process(
target=self.process_main,
args=(
self.request_queue,
self.response_queue,
),
)
assert self.process is not None
with set_cuda_visible_device(self.device):
self.process.start()
def put(self, obj: Any) -> None:
def put(self, req: Any) -> None:
"""
Push a work item to the child process.
"""
# In case of a prior crash, ensure the subprocess is running
self.initialize()
assert self.request_queue is not None
self.request_queue.put(obj)
if not self.alive():
self.start()
TuningProcess.send(req, self.write_pipe)
def get(
self, result_timeout=120.0, graceful_timeout=3.0, terminate_timeout=1.0
) -> Any:
def get(self, timeout: float = 120.0) -> Any:
"""
Get a response from the child process. Raises queue.Empty on timeout
or if the process dies.
This method is (so far) only used by TuningProcessPool, where torch._inductor.config entries are being used
to populate the timeouts:
Arguments:
@param result_timeout: Timeout in seconds, defaults to 120.0 or to
config.max_autotune_subproc_result_timeout_seconds when called by TuningProcessPool
@param graceful_timeout: Timeout in seconds to allow graceful shutdown (SIGTERM is sent after this time).
Defaults to 3.0 or to config.max_autotune_subproc_graceful_timeout_seconds
@param terminate_timeout: Timeout in seconds after SIGTERM, until we send SIGKILL if the process
remains alive. Defaults to 1.0 or to
config.max_autotune_subproc_terminate_timeout_seconds.
Returns:
A response from the child process (Any type)
Get a response from the child process. Raises TimeoutError on timeout;
raises EOFError if the subprocess crashes.
"""
assert self.process is not None
assert self.response_queue is not None
while True:
try:
remaining_timeout = result_timeout
res = None
while remaining_timeout is not None and remaining_timeout >= 1.0:
remaining_timeout -= 0.5
try:
res = self.response_queue.get(timeout=0.5)
break
except queue.Empty:
if not self.process.is_alive():
raise # is being caught a few lines below
if res is None:
res = self.response_queue.get(timeout=remaining_timeout)
return res
except queue.Empty:
status = self.process.exitcode
if status is None:
self.kill(
graceful_timeout=graceful_timeout,
terminate_timeout=terminate_timeout,
)
else:
# child process crashed
self.clear()
raise
try:
if not self.selector.select(timeout):
raise TimeoutError(f"Timeout in autotune subprocess {self.process.pid}")
result = TuningProcess.recv(self.read_pipe)
except TimeoutError:
self.kill()
raise
except EOFError:
# The subprocess crashed
self.close()
raise
except Exception:
autotuning_log.exception(
"Unexpected exception in autotune subprocess %s", self.process.pid
)
self.kill()
raise
def terminate(self) -> None:
if isinstance(result, Exception):
raise result
return result
def shutdown(self, wait: bool = True) -> None:
"""
Signal the child process to terminate.
Signal the child process to shut down gracefully.
"""
if self.valid():
assert self.process is not None
assert self.request_queue is not None
self.request_queue.put(None)
if self.alive():
TuningProcess.send(None, self.write_pipe)
if wait:
self.wait()
def wait(self) -> None:
"""
Wait for the child process to exit.
"""
if self.process is not None:
self.process.join()
self.clear()
if self.alive():
self.process.wait()
self.close()
def kill(self, graceful_timeout=5.0, terminate_timeout=1.0) -> None:
# Tries to kill the process, using a graceful_timeout in which the process
# is allowed to exit gracefully. If the process is still alive,
# it will be terminated. If that is not sufficient to end it
# within terminate_timeout seconds, it will be killed.
if self.process is not None:
self.terminate()
self.process.join(timeout=graceful_timeout)
if self.process.is_alive():
autotuning_log.warning(
"Sending SIGTERM to process with PID %d",
self.process.pid,
)
self.process.terminate()
self.process.join(timeout=terminate_timeout)
if self.process.is_alive():
autotuning_log.error(
"Sending SIGKILL to process with PID %d",
self.process.pid,
)
self.process.kill() # This should definitely end the process
self.clear()
def close(self) -> None:
"""
Close resources.
"""
self.selector.close()
self.read_pipe.close()
self.write_pipe.close()
self.running = False
def kill(self) -> None:
"""
Send a SIGKILL to the child process.
"""
if self.alive():
autotuning_log.error(
"Sending SIGKILL to autotune subprocess %d",
self.process.pid,
)
self.process.kill()
self.close()
@dataclasses.dataclass
class TuningProcessPool:
"""
Maintains a pool of TuningProcesses to benchmark kernels in parallel
@@ -286,8 +234,9 @@ class TuningProcessPool:
set the sub-process environment to make only that device visible.
"""
processes: Optional[queue.Queue[TuningProcess]] = None
executor: Optional[ThreadPoolExecutor] = None
def __init__(self) -> None:
self.processes: Optional[queue.Queue[TuningProcess]] = None
self.executor: Optional[ThreadPoolExecutor] = None
def initialize(self) -> None:
"""
@@ -298,35 +247,21 @@
return
devices = self.get_device_list()
log.debug("Sub-process autotune device list: %s", devices)
autotuning_log.debug("Sub-process autotune device list: %s", devices)
# Launch the child processes and push a msg to "warm up"
# Launch the child processes.
self.processes = queue.Queue()
for device in devices:
p = TuningProcess(device=device)
p.initialize()
p.put(Ping())
self.processes.put(p)
# Wait for the initialization to finish
for p in self.processes.queue:
assert isinstance(p.get(result_timeout=None), Pong)
# Use a thread pool to manage distributing work to the subprocesses.
# Threads block on an available process, so it makes sense to match
# the number of threads with the number of devices.
self.executor = ThreadPoolExecutor(max_workers=len(devices))
# Register the exit handler for the parent process so it will terminate
# the child processes.
global EXIT_HANDLER_REGISTERED
if not EXIT_HANDLER_REGISTERED:
EXIT_HANDLER_REGISTERED = True
import atexit
atexit.register(self.terminate)
def get_device_list(self) -> Sequence[Optional[int]]:
@staticmethod
def get_device_list() -> Sequence[Optional[int]]:
"""
Gather the list of devices to be used in the pool.
"""
@@ -346,9 +281,9 @@
return list(range(count))
def terminate(self) -> None:
def shutdown(self) -> None:
"""
Signal all child processes to terminate.
Signal all child processes to exit.
"""
if self.executor is not None:
self.executor.shutdown()
@@ -356,7 +291,7 @@
if self.processes is not None:
for p in self.processes.queue:
p.terminate()
p.shutdown(wait=False)
for p in self.processes.queue:
p.wait()
self.processes = None
@@ -371,19 +306,24 @@
assert self.processes is not None
process = self.processes.get()
process.put(choice.bmreq)
process.put(choice.bmreq.benchmark)
try:
return process.get(
config.max_autotune_subproc_result_timeout_seconds,
config.max_autotune_subproc_graceful_timeout_seconds,
config.max_autotune_subproc_terminate_timeout_seconds,
)
except queue.Empty:
except TimeoutError:
warnings.warn(
f"Timed out benchmarking choice '{choice}'. It will be ignored. "
"Please debug the root cause in case the choice can bring perf gains."
)
# Set to INF so this choice will be ignored
return float("inf")
except Exception:
warnings.warn(
f"Failed to benchmark choice '{choice}'. It will be ignored. "
"Please debug the root cause in case the choice can bring perf gains."
)
# set to INF so this choice will be ignored
# Set to INF so this choice will be ignored
return float("inf")
finally:
self.processes.put(process)
@@ -406,6 +346,7 @@ class TuningProcessPool:
tuning_pool = TuningProcessPool()
atexit.register(tuning_pool.shutdown)
LayoutOrBuffer = Union[ir.Layout, ir.Buffer]
@@ -563,21 +504,38 @@ class BenchmarkRequest:
return out
class TestBenchmarkRequest(BenchmarkRequest):
class _TestBenchmarkRequest(BenchmarkRequest):
"""
Supports unit testing. Defined in this file so that the TuningProcess
sub-process knows how to unpickle these objects.
Supports unit testing. Defined in this file instead of the test file so the
TuningProcess sub-process can unpickle these objects.
"""
def __init__(self, value: Optional[float] = None) -> None:
self.value = value
def __init__(
self,
result: float = 0.0,
device: Optional[int] = None,
sleep: Optional[float] = None,
exc: Optional[Exception] = None,
crash: bool = False,
):
self.result = result
self.device = device
self.sleep = sleep
self.exc = exc
self.crash = crash
def benchmark(
self, *input_tensors: torch.Tensor, output_tensor: Optional[torch.Tensor] = None
) -> float:
if self.value is None:
raise Exception("Failed to run") # noqa: TRY002
return self.value
if self.device is not None:
assert os.environ.get(CUDA_VISIBLE_DEVICES, None) == str(self.device)
if self.sleep:
time.sleep(self.sleep)
if self.exc:
raise self.exc
if self.crash:
sys.exit(1)
return self.result
class GPUDeviceBenchmarkMixin:

View File: torch/_inductor/compile_worker/subproc_pool.py

@@ -20,8 +20,8 @@ from typing_extensions import Never, ParamSpec
# justknobs, e.g., in the Triton compiler. For internal, the import installs
# functionality to destroy singletons before forking and re-enable them after.
import torch._thread_safe_fork # noqa: F401
from torch._inductor import config
from torch._inductor.compile_worker.utils import _async_compile_initializer
from torch._inductor.utils import get_ld_library_path
log = logging.getLogger(__name__)
@@ -57,19 +57,6 @@ def _recv_msg(read_pipe: IO[bytes]) -> tuple[int, bytes]:
return job_id, data
def _get_ld_library_path() -> str:
path = os.environ.get("LD_LIBRARY_PATH", "")
if config.is_fbcode():
from libfb.py.parutil import get_runtime_path
runtime_path = get_runtime_path()
if runtime_path:
lib_path = os.path.join(runtime_path, "runtime", "lib")
path = os.pathsep.join([lib_path, path]) if path else lib_path
return path
class _SubprocExceptionInfo:
"""
Carries exception info from subprocesses across the wire. traceback
@@ -150,7 +137,7 @@ class SubprocPool:
# creates the SubprocPool in the first place.
"TORCH_WARM_POOL": "0",
# Some internal usages need a modified LD_LIBRARY_PATH.
"LD_LIBRARY_PATH": _get_ld_library_path(),
"LD_LIBRARY_PATH": get_ld_library_path(),
},
pass_fds=(subproc_read_fd, subproc_write_fd),
)

View File: torch/_inductor/config.py

@@ -422,12 +422,12 @@ autotune_in_subproc = os.environ.get("TORCHINDUCTOR_AUTOTUNE_IN_SUBPROC") == "1"
# The following three timeouts are applicable if autotune_in_subproc is True:
# Max time that a a valid benchmark result may take during autotuning
# Max time that a valid benchmark result may take during autotuning
max_autotune_subproc_result_timeout_seconds = 60.0
# Additional time we allow subprocesses to terminate gracefully after the timeout until we send a SIGTERM
max_autotune_subproc_graceful_timeout_seconds = 1.0
# Additional time that we grant after a SIGTERM until we do a hard SIGKILL of subprocesses
max_autotune_subproc_terminate_timeout_seconds = 2.0
# DEPRECATED. This setting is ignored.
max_autotune_subproc_graceful_timeout_seconds = 0.0
# DEPRECATED. This setting is ignored.
max_autotune_subproc_terminate_timeout_seconds = 0.0
# If autotuning in subprocess, whether to use multiple devices
autotune_multi_device = os.environ.get("TORCHINDUCTOR_AUTOTUNE_MULTI_DEVICE") == "1"

View File: torch/_inductor/utils.py

@@ -2811,3 +2811,16 @@ def is_cudagraph_unsafe_op(node: Operation) -> bool:
return True
return False
def get_ld_library_path() -> str:
path = os.environ.get("LD_LIBRARY_PATH", "")
if config.is_fbcode():
from libfb.py.parutil import get_runtime_path
runtime_path = get_runtime_path()
if runtime_path:
lib_path = os.path.join(runtime_path, "runtime", "lib")
path = os.pathsep.join([lib_path, path]) if path else lib_path
return path