mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Improve subproc autotuning implementation (#149700)
Summary: The primary change is to update the autotune-in-a-subproc implementation to avoid using multiprocessing spawn. Spawn (re)executes the toplevel script in the subproc, which can be problematic. The approach here is similar to Triton parallel compile: we Popen a subproc on a controlled entry point and communicate over pipes. That change drove a lot of refactoring in the TuningProcess class, so I took the opportunity to simplify some things, rename some methods, etc. One other notable change is around the timeout / kill approach. After a timeout, we were previously attempting to stop the subproc in three steps (graceful shutdown, sigkill if graceful fails, sigterm if sigkill fails). I'm gonna argue think that's not useful: 1) The graceful shutdown is never going to work unless the subproc happens to have just completed its task and is ready to receive the next command. 2) If we're going to kill the subproc, let's just take the most aggressive approach and move on as quickly as possible to restarting it rather than waiting to see if previous shutdown attempts succeeded. The only downside that I can find find is maybe a little log spew?, e.g., ` ResourceWarning: subprocess 2987680 is still running` List of changes: * Use Popen instead of spawn for the autotuning subprocess. * Introduced a new entry point `__autotune_main__.py` * Renamed some TuningProcess methods. For example `shutdown` makes more sense than `terminate` because the latter implies a forced kill. * Simplified the implementation around benchmarking timeout and how we kill the subproc after a timeout. * Deprecated the unused timeout configs in `_inductor/config.py` * Moved `get_ld_library_path` helper to a common utils file. * Added more unit tests for subproc crashes / timeouts / exceptions, etc. Test plan: * New unit tests * Also ran internally with all combinations of: build mode `opt` and `dev-nosan`, and `buck run` vs. executing the `.par` file directly. * Made sure the functionality to parallelize autotuning across different GPUs is working (it wasn't clear to me this was behaving the way we wanted it to). Differential Revision: [D71976971](https://our.internmc.facebook.com/intern/diff/D71976971) Pull Request resolved: https://github.com/pytorch/pytorch/pull/149700 Approved by: https://github.com/aorenste, https://github.com/jansel, https://github.com/eellison
This commit is contained in:
committed by
PyTorch MergeBot
parent
8b04364914
commit
266bd22b44
@ -3,6 +3,7 @@ import contextlib
|
||||
import json
|
||||
import math
|
||||
import os
|
||||
import random
|
||||
import tempfile
|
||||
import unittest
|
||||
from typing import Callable, Optional
|
||||
@ -15,8 +16,9 @@ from torch._dynamo.testing import rand_strided, reset_rng_state
|
||||
from torch._dynamo.utils import same
|
||||
from torch._inductor import config
|
||||
from torch._inductor.autotune_process import (
|
||||
BenchmarkRequest,
|
||||
_TestBenchmarkRequest,
|
||||
CUDA_VISIBLE_DEVICES,
|
||||
TuningProcess,
|
||||
TuningProcessPool,
|
||||
)
|
||||
from torch._inductor.graph import GraphLowering
|
||||
@ -1196,35 +1198,6 @@ class TestMaxAutotuneRemoteCache(TestCase):
|
||||
self.assertEqual(global_stats.autotune_remote, Stats(2, 3, 2))
|
||||
|
||||
|
||||
class _TestBenchmarkRequest(BenchmarkRequest):
|
||||
def __init__(
|
||||
self, value: float, multi_device: bool, parent_visible_devices: Optional[str]
|
||||
) -> None:
|
||||
self.value = value
|
||||
self.multi_device = multi_device
|
||||
self.parent_visible_devices = parent_visible_devices
|
||||
|
||||
def benchmark(
|
||||
self, *input_tensors: torch.Tensor, output_tensor: Optional[torch.Tensor] = None
|
||||
) -> float:
|
||||
# Verify that the visible devices env var is set correctly. If multi-device
|
||||
# auto-tuning is disabled, the visible devices should be unmanipulated from
|
||||
# the parent process. If multi-device auto-tuning is enabled, the visible
|
||||
# devices should be a _single_ valid device number. Note that we can't perform
|
||||
# this validation directly from the test body because benchmarks execute in a
|
||||
# separate process. If the check fails, however, the test will detect the
|
||||
# failure by virtue of not receiving the expected result back.
|
||||
visible_devices = os.environ.get(CUDA_VISIBLE_DEVICES)
|
||||
if not self.multi_device:
|
||||
assert visible_devices == self.parent_visible_devices
|
||||
else:
|
||||
assert self.parent_visible_devices is not None
|
||||
valid_devices = self.parent_visible_devices.split(",")
|
||||
assert visible_devices in valid_devices
|
||||
|
||||
return self.value
|
||||
|
||||
|
||||
class _TestTritonTemplateCaller(TritonTemplateCaller):
|
||||
def __init__(self, bmreq: _TestBenchmarkRequest):
|
||||
self.bmreq = bmreq
|
||||
@ -1234,63 +1207,141 @@ class _TestTritonTemplateCaller(TritonTemplateCaller):
|
||||
|
||||
|
||||
class TestTuningProcess(TestCase):
|
||||
def check_healthy(self, p: TuningProcess, device: Optional[int] = None):
|
||||
result = random.random()
|
||||
bmreq = _TestBenchmarkRequest(result, device=device)
|
||||
p.put(bmreq.benchmark)
|
||||
self.assertEqual(p.get(), result)
|
||||
|
||||
def test_tuning_subproc_timeout(self):
|
||||
p = TuningProcess(None)
|
||||
|
||||
bmreq = _TestBenchmarkRequest(0, sleep=120)
|
||||
p.put(bmreq.benchmark)
|
||||
with self.assertRaises(TimeoutError):
|
||||
p.get(timeout=1.0)
|
||||
|
||||
# Make sure the TuningProcess is still usable after a timeout.
|
||||
self.check_healthy(p)
|
||||
p.shutdown()
|
||||
|
||||
def test_tuning_subproc_exception(self):
|
||||
p = TuningProcess(None)
|
||||
|
||||
bmreq = _TestBenchmarkRequest(0, exc=RuntimeError("Fail"))
|
||||
p.put(bmreq.benchmark)
|
||||
with self.assertRaises(RuntimeError):
|
||||
p.get()
|
||||
|
||||
# Make sure the TuningProcess is still usable after an exception.
|
||||
self.check_healthy(p)
|
||||
p.shutdown()
|
||||
|
||||
def test_tuning_subproc_crash(self):
|
||||
p = TuningProcess(None)
|
||||
|
||||
bmreq = _TestBenchmarkRequest(0, crash=True)
|
||||
p.put(bmreq.benchmark)
|
||||
with self.assertRaises(EOFError):
|
||||
p.get()
|
||||
|
||||
# Make sure the TuningProcess is still usable after a crash.
|
||||
self.check_healthy(p)
|
||||
p.shutdown()
|
||||
|
||||
def test_tuning_subproc_killed(self):
|
||||
p = TuningProcess(None)
|
||||
p.kill()
|
||||
self.check_healthy(p)
|
||||
p.shutdown()
|
||||
|
||||
def test_visible_devices(self):
|
||||
device_list = TuningProcessPool.get_device_list()
|
||||
for device in device_list:
|
||||
p = TuningProcess(device)
|
||||
self.check_healthy(p, device=device)
|
||||
p.shutdown()
|
||||
|
||||
|
||||
class TestTuningProcessPool(TestCase):
|
||||
# Use only one device/subprocess so we test the process restarts
|
||||
# and is usable after a crash.
|
||||
@config.patch({"autotune_multi_device": False})
|
||||
def test_tuning_pool_crash(self):
|
||||
# Use only one device/subprocess so we test the process restarts
|
||||
# and is usable after a "crash".
|
||||
with config.patch({"autotune_multi_device": False}):
|
||||
tuning_pool = TuningProcessPool()
|
||||
tuning_pool.initialize()
|
||||
tuning_pool = TuningProcessPool()
|
||||
tuning_pool.initialize()
|
||||
|
||||
# First force the tuning process to "crash" by setting a bogus
|
||||
# string for the expected visible devices.
|
||||
bmreq = _TestBenchmarkRequest(3.14, False, "invalid")
|
||||
choice = _TestTritonTemplateCaller(bmreq)
|
||||
# First force the tuning process to crash.
|
||||
bmreq = _TestBenchmarkRequest(0, crash=True)
|
||||
choice = _TestTritonTemplateCaller(bmreq)
|
||||
|
||||
timings = tuning_pool.benchmark([choice])
|
||||
self.assertTrue(choice in timings)
|
||||
self.assertEqual(timings[choice], float("inf"))
|
||||
|
||||
# Then send another request and make sure the sub-process
|
||||
# has restarted and is operational.
|
||||
bmreq = _TestBenchmarkRequest(3.14)
|
||||
choice = _TestTritonTemplateCaller(bmreq)
|
||||
|
||||
timings = tuning_pool.benchmark([choice])
|
||||
self.assertTrue(choice in timings)
|
||||
self.assertEqual(timings[choice], bmreq.result)
|
||||
|
||||
tuning_pool.shutdown()
|
||||
|
||||
@config.patch({"autotune_multi_device": False})
|
||||
def test_tuning_pool_timeout(self):
|
||||
tuning_pool = TuningProcessPool()
|
||||
tuning_pool.initialize()
|
||||
|
||||
# First force the tuning process to timeout.
|
||||
bmreq = _TestBenchmarkRequest(0, sleep=120)
|
||||
choice = _TestTritonTemplateCaller(bmreq)
|
||||
|
||||
with config.patch({"max_autotune_subproc_result_timeout_seconds": 1.0}):
|
||||
timings = tuning_pool.benchmark([choice])
|
||||
self.assertTrue(choice in timings)
|
||||
self.assertEqual(timings[choice], float("inf"))
|
||||
self.assertTrue(choice in timings)
|
||||
self.assertEqual(timings[choice], float("inf"))
|
||||
|
||||
# Then send another request and make sure the sub-process
|
||||
# has restarted and is operational. 'valid_devices' expected
|
||||
# to be None because autotune_multi_device is off.
|
||||
choice.bmreq.parent_visible_devices = os.environ.get(CUDA_VISIBLE_DEVICES)
|
||||
# Then send another request and make sure the sub-process
|
||||
# has restarted and is operational.
|
||||
bmreq = _TestBenchmarkRequest(3.14)
|
||||
choice = _TestTritonTemplateCaller(bmreq)
|
||||
|
||||
timings = tuning_pool.benchmark([choice])
|
||||
self.assertTrue(choice in timings)
|
||||
self.assertEqual(timings[choice], bmreq.value)
|
||||
timings = tuning_pool.benchmark([choice])
|
||||
self.assertTrue(choice in timings)
|
||||
self.assertEqual(timings[choice], bmreq.result)
|
||||
|
||||
tuning_pool.terminate()
|
||||
tuning_pool.shutdown()
|
||||
|
||||
# XPU have to enable XPU_VISIBLE_DEVICES to control devices visibility.
|
||||
@skipIfXpu
|
||||
@config.patch({"autotune_multi_device": True})
|
||||
def test_tuning_pool_multiple_devices(self):
|
||||
with config.patch({"autotune_multi_device": True}):
|
||||
# Adapt the test to the available devices (and whether CUDA_VISIBLE_DEVICES
|
||||
# is already set in the environment); use a subset of the available devices
|
||||
# to ensure only the subset are visible to the sub-processes.
|
||||
if CUDA_VISIBLE_DEVICES in os.environ:
|
||||
visible_devices = os.environ[CUDA_VISIBLE_DEVICES].split(",")
|
||||
else:
|
||||
visible_devices = [str(d) for d in range(torch.cuda.device_count())]
|
||||
|
||||
parent_visible_devices = ",".join(visible_devices[-2:])
|
||||
os.environ[CUDA_VISIBLE_DEVICES] = parent_visible_devices
|
||||
# Adapt the test to the available devices (and whether CUDA_VISIBLE_DEVICES
|
||||
# is already set in the environment); use a subset of the available devices
|
||||
# to ensure only the subset are visible to the sub-processes.
|
||||
if CUDA_VISIBLE_DEVICES in os.environ:
|
||||
visible_devices = os.environ[CUDA_VISIBLE_DEVICES].split(",")
|
||||
else:
|
||||
visible_devices = [str(d) for d in range(torch.cuda.device_count())]
|
||||
|
||||
cuda_visible_devices = ",".join(visible_devices[-2:])
|
||||
with unittest.mock.patch.dict(
|
||||
os.environ, {CUDA_VISIBLE_DEVICES: cuda_visible_devices}
|
||||
):
|
||||
tuning_pool = TuningProcessPool()
|
||||
tuning_pool.initialize()
|
||||
|
||||
choice1 = _TestTritonTemplateCaller(
|
||||
_TestBenchmarkRequest(3.14, True, parent_visible_devices),
|
||||
)
|
||||
choice2 = _TestTritonTemplateCaller(
|
||||
_TestBenchmarkRequest(2.718, True, parent_visible_devices),
|
||||
)
|
||||
choice1 = _TestTritonTemplateCaller(_TestBenchmarkRequest(3.14))
|
||||
choice2 = _TestTritonTemplateCaller(_TestBenchmarkRequest(2.718))
|
||||
|
||||
timings = tuning_pool.benchmark([choice1, choice2])
|
||||
self.assertEqual(timings[choice1], choice1.bmreq.value)
|
||||
self.assertEqual(timings[choice2], choice2.bmreq.value)
|
||||
timings = tuning_pool.benchmark([choice1, choice2])
|
||||
self.assertEqual(timings[choice1], choice1.bmreq.result)
|
||||
self.assertEqual(timings[choice2], choice2.bmreq.result)
|
||||
|
||||
tuning_pool.terminate()
|
||||
tuning_pool.shutdown()
|
||||
|
||||
|
||||
@instantiate_parametrized_tests
|
||||
|
33
torch/_inductor/__autotune_main__.py
Normal file
33
torch/_inductor/__autotune_main__.py
Normal file
@ -0,0 +1,33 @@
|
||||
import argparse
|
||||
import logging
|
||||
import os
|
||||
|
||||
from torch._inductor.autotune_process import TuningProcess
|
||||
from torch._inductor.compile_worker.utils import _async_compile_initializer
|
||||
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def main() -> None:
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--parent", type=int)
|
||||
parser.add_argument("--read-fd", type=int)
|
||||
parser.add_argument("--write-fd", type=int)
|
||||
args = parser.parse_args()
|
||||
read_pipe = os.fdopen(args.read_fd, "rb")
|
||||
write_pipe = os.fdopen(args.write_fd, "wb")
|
||||
|
||||
try:
|
||||
# Ensures the subprocess exits if the parent crashes:
|
||||
_async_compile_initializer(args.parent)
|
||||
TuningProcess.process_main(read_pipe, write_pipe)
|
||||
except Exception:
|
||||
log.exception("Uncaught exception in autotune subprocess")
|
||||
finally:
|
||||
read_pipe.close()
|
||||
write_pipe.close()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
@ -1,23 +1,26 @@
|
||||
# mypy: allow-untyped-defs
|
||||
from __future__ import annotations
|
||||
|
||||
import contextlib
|
||||
import atexit
|
||||
import ctypes
|
||||
import dataclasses
|
||||
import functools
|
||||
import logging
|
||||
import os
|
||||
import pickle
|
||||
import queue
|
||||
import selectors
|
||||
import subprocess
|
||||
import sys
|
||||
import time
|
||||
import warnings
|
||||
from collections.abc import Iterable, Sequence
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from ctypes import byref, c_size_t, c_void_p, CDLL
|
||||
from typing import Any, Callable, Optional, TYPE_CHECKING, Union
|
||||
from typing import Any, Callable, IO, Optional, TYPE_CHECKING, Union
|
||||
|
||||
import torch
|
||||
import torch._inductor.async_compile # noqa: F401 required to warm up AsyncCompile pools
|
||||
from torch import multiprocessing
|
||||
from torch._dynamo.device_interface import get_interface_for_device
|
||||
from torch._dynamo.testing import rand_strided
|
||||
from torch._inductor import ir
|
||||
@ -28,14 +31,12 @@ from torch._inductor.codecache import (
|
||||
get_hash,
|
||||
PyCodeCache,
|
||||
)
|
||||
from torch._inductor.utils import get_gpu_type, is_gpu
|
||||
from torch._inductor.utils import get_gpu_type, get_ld_library_path, is_gpu
|
||||
from torch._logging import getArtifactLogger
|
||||
from torch.utils._ordered_set import OrderedSet
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from multiprocessing.process import BaseProcess
|
||||
from multiprocessing.queues import Queue
|
||||
from types import ModuleType
|
||||
|
||||
from torch._inductor.select_algorithm import TritonTemplateCaller
|
||||
@ -49,236 +50,183 @@ from .virtualized import V
|
||||
|
||||
|
||||
CUDA_VISIBLE_DEVICES = "CUDA_VISIBLE_DEVICES"
|
||||
EXIT_HANDLER_REGISTERED = False
|
||||
|
||||
autotuning_log = getArtifactLogger(__name__, "autotuning")
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Used to synchronize between parent and child processes
|
||||
class Ping:
|
||||
pass
|
||||
|
||||
|
||||
class Pong:
|
||||
pass
|
||||
|
||||
|
||||
class NonzeroWorkspaceNotSupportedError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def set_cuda_visible_device(device: Optional[int]):
|
||||
"""
|
||||
Context manager to set the CUDA_VISIBLE_DEVICES environment variable to the
|
||||
specified single device. If device is None, don't manipulate the environment.
|
||||
"""
|
||||
if device is None:
|
||||
yield
|
||||
return
|
||||
|
||||
current = os.environ.get(CUDA_VISIBLE_DEVICES)
|
||||
os.environ[CUDA_VISIBLE_DEVICES] = str(device)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
if current is None:
|
||||
del os.environ[CUDA_VISIBLE_DEVICES]
|
||||
else:
|
||||
os.environ[CUDA_VISIBLE_DEVICES] = current
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class TuningProcess:
|
||||
"""
|
||||
Abstraction for launching a helper process to benchmark kernels. Spawns
|
||||
the parent process and uses multiprocessing queues to send benchmark
|
||||
requests and return results.
|
||||
Class to launch and interact with a benchmarking subprocess.
|
||||
"""
|
||||
|
||||
device: Optional[int] = None
|
||||
process: Optional[BaseProcess] = None
|
||||
request_queue: Optional[Queue[Any]] = None
|
||||
response_queue: Optional[Queue[Any]] = None
|
||||
|
||||
@staticmethod
|
||||
def process_main(
|
||||
request_queue: Queue[Any],
|
||||
response_queue: Queue[Any],
|
||||
) -> None:
|
||||
def process_main(read_pipe: IO[bytes], write_pipe: IO[bytes]) -> None:
|
||||
"""
|
||||
Entry point for the child process.
|
||||
"""
|
||||
autotuning_log.debug(
|
||||
"Entering TuningProcess child. Visible devices = %s",
|
||||
"Started autotune subprocess %s. Visible devices: %s",
|
||||
os.getpid(),
|
||||
os.environ.get(CUDA_VISIBLE_DEVICES),
|
||||
)
|
||||
|
||||
def workloop():
|
||||
while True:
|
||||
job = TuningProcess.recv(read_pipe)
|
||||
if job is None:
|
||||
# None is a sentinel for the child to shut down
|
||||
break
|
||||
try:
|
||||
result = job()
|
||||
except Exception as e:
|
||||
result = e
|
||||
TuningProcess.send(result, write_pipe)
|
||||
|
||||
try:
|
||||
TuningProcess.workloop(request_queue, response_queue)
|
||||
except Exception:
|
||||
autotuning_log.exception("Exception in TuningProcess")
|
||||
workloop()
|
||||
except EOFError:
|
||||
# The parent closed the pipe
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def workloop(request_queue: Queue[Any], response_queue: Queue[Any]) -> None:
|
||||
"""
|
||||
Work loop for the benchmarking subprocess.
|
||||
"""
|
||||
while True:
|
||||
obj = request_queue.get()
|
||||
def send(obj: Any, write_pipe: IO[bytes]) -> None:
|
||||
pickle.dump(obj, write_pipe)
|
||||
write_pipe.flush()
|
||||
|
||||
if obj is None:
|
||||
break # None is a sentinel for the child to terminate
|
||||
elif isinstance(obj, Ping):
|
||||
response_queue.put(Pong())
|
||||
elif isinstance(obj, BenchmarkRequest):
|
||||
response_queue.put(obj.benchmark())
|
||||
else:
|
||||
raise RuntimeError(f"Invalid request type {type(obj)}")
|
||||
@staticmethod
|
||||
def recv(read_pipe: IO[bytes]) -> Any:
|
||||
return pickle.load(read_pipe)
|
||||
|
||||
def valid(self) -> bool:
|
||||
def __init__(self, device: Optional[int]):
|
||||
self.device = device
|
||||
self.start()
|
||||
|
||||
def start(self):
|
||||
"""
|
||||
True if the sub-process has been initialized.
|
||||
Start the benchmarking subprocess.
|
||||
"""
|
||||
return (
|
||||
self.process is not None
|
||||
and self.request_queue is not None
|
||||
and self.response_queue is not None
|
||||
entry = os.path.join(os.path.dirname(__file__), "__autotune_main__.py")
|
||||
|
||||
subproc_read_fd, write_fd = os.pipe()
|
||||
read_fd, subproc_write_fd = os.pipe()
|
||||
self.write_pipe = os.fdopen(write_fd, "wb")
|
||||
self.read_pipe = os.fdopen(read_fd, "rb")
|
||||
|
||||
self.selector = selectors.DefaultSelector()
|
||||
self.selector.register(self.read_pipe, selectors.EVENT_READ)
|
||||
|
||||
cmd = [
|
||||
sys.executable,
|
||||
entry,
|
||||
f"--parent={os.getpid()}",
|
||||
f"--read-fd={str(subproc_read_fd)}",
|
||||
f"--write-fd={str(subproc_write_fd)}",
|
||||
]
|
||||
extra_env = {
|
||||
# We need to set the PYTHONPATH so the subprocess can find torch.
|
||||
"PYTHONPATH": os.pathsep.join(sys.path),
|
||||
# We shouldn't be using the Triton async compile subprocess pool,
|
||||
# but as a precaution set the env var that disables its creation.
|
||||
"TORCH_WARM_POOL": "0",
|
||||
# Some internal usages need a modified LD_LIBRARY_PATH.
|
||||
"LD_LIBRARY_PATH": get_ld_library_path(),
|
||||
}
|
||||
if self.device is not None:
|
||||
extra_env[CUDA_VISIBLE_DEVICES] = str(self.device)
|
||||
self.process = subprocess.Popen(
|
||||
cmd,
|
||||
env={**os.environ, **extra_env},
|
||||
pass_fds=(subproc_read_fd, subproc_write_fd),
|
||||
)
|
||||
os.close(subproc_read_fd)
|
||||
os.close(subproc_write_fd)
|
||||
|
||||
def clear(self) -> None:
|
||||
self.running = True
|
||||
|
||||
def alive(self) -> bool:
|
||||
"""
|
||||
Reset to an uninitialized state.
|
||||
True if the subprocess is still running.
|
||||
"""
|
||||
self.process = self.request_queue = self.response_queue = None
|
||||
return self.running and self.process.poll() is None
|
||||
|
||||
def initialize(self) -> None:
|
||||
"""
|
||||
Create child process, request/response queues, and do the warm up.
|
||||
Set the environment to make only the provided GPU device visible
|
||||
to the process.
|
||||
"""
|
||||
if self.valid():
|
||||
return
|
||||
|
||||
# cuda runtime does not work with "fork", use "spawn" to start processes.
|
||||
ctx = multiprocessing.get_context("spawn")
|
||||
self.request_queue = ctx.Queue()
|
||||
self.response_queue = ctx.Queue()
|
||||
|
||||
self.process = ctx.Process(
|
||||
target=self.process_main,
|
||||
args=(
|
||||
self.request_queue,
|
||||
self.response_queue,
|
||||
),
|
||||
)
|
||||
assert self.process is not None
|
||||
with set_cuda_visible_device(self.device):
|
||||
self.process.start()
|
||||
|
||||
def put(self, obj: Any) -> None:
|
||||
def put(self, req: Any) -> None:
|
||||
"""
|
||||
Push a work item to the child process.
|
||||
"""
|
||||
# In case of a prior crash, ensure the subprocess is running
|
||||
self.initialize()
|
||||
assert self.request_queue is not None
|
||||
self.request_queue.put(obj)
|
||||
if not self.alive():
|
||||
self.start()
|
||||
TuningProcess.send(req, self.write_pipe)
|
||||
|
||||
def get(
|
||||
self, result_timeout=120.0, graceful_timeout=3.0, terminate_timeout=1.0
|
||||
) -> Any:
|
||||
def get(self, timeout: float = 120.0) -> Any:
|
||||
"""
|
||||
Get a response from the child process. Raises queue.Empty on timeout
|
||||
or if the process dies.
|
||||
|
||||
This method is (so far) only used by TuningProcessPool, where torch._inductor.config entries are being used
|
||||
to populate the timeouts:
|
||||
|
||||
Arguments:
|
||||
|
||||
@param result_timeout: Timeout in seconds, defaults to 120.0 or to
|
||||
config.max_autotune_subproc_result_timeout_seconds when called by TuningProcessPool
|
||||
@param graceful_timeout: Timeout in seconds to allow graceful shutdown (SIGTERM is sent after this time).
|
||||
Defaults to 3.0 or to config.max_autotune_subproc_graceful_timeout_seconds
|
||||
@param terminate_timeout: Timeout in seconds after SIGTERM, until we send SIGKILL if the process
|
||||
remains alive. Defaults to 1.0 or to
|
||||
config.max_autotune_subproc_terminate_timeout_seconds.
|
||||
Returns:
|
||||
A response from the child process (Any type)
|
||||
Get a response from the child process. Raises TimeoutError on timeout;
|
||||
raises EOFError if the subprocess crashes.
|
||||
"""
|
||||
assert self.process is not None
|
||||
assert self.response_queue is not None
|
||||
while True:
|
||||
try:
|
||||
remaining_timeout = result_timeout
|
||||
res = None
|
||||
while remaining_timeout is not None and remaining_timeout >= 1.0:
|
||||
remaining_timeout -= 0.5
|
||||
try:
|
||||
res = self.response_queue.get(timeout=0.5)
|
||||
break
|
||||
except queue.Empty:
|
||||
if not self.process.is_alive():
|
||||
raise # is being caught a few lines below
|
||||
if res is None:
|
||||
res = self.response_queue.get(timeout=remaining_timeout)
|
||||
return res
|
||||
except queue.Empty:
|
||||
status = self.process.exitcode
|
||||
if status is None:
|
||||
self.kill(
|
||||
graceful_timeout=graceful_timeout,
|
||||
terminate_timeout=terminate_timeout,
|
||||
)
|
||||
else:
|
||||
# child process crashed
|
||||
self.clear()
|
||||
raise
|
||||
try:
|
||||
if not self.selector.select(timeout):
|
||||
raise TimeoutError(f"Timeout in autotune subprocess {self.process.pid}")
|
||||
result = TuningProcess.recv(self.read_pipe)
|
||||
except TimeoutError:
|
||||
self.kill()
|
||||
raise
|
||||
except EOFError:
|
||||
# The subprocess crashed
|
||||
self.close()
|
||||
raise
|
||||
except Exception:
|
||||
autotuning_log.exception(
|
||||
"Unexpected exception in autotune subprocess %s", self.process.pid
|
||||
)
|
||||
self.kill()
|
||||
raise
|
||||
|
||||
def terminate(self) -> None:
|
||||
if isinstance(result, Exception):
|
||||
raise result
|
||||
return result
|
||||
|
||||
def shutdown(self, wait: bool = True) -> None:
|
||||
"""
|
||||
Signal the child process to terminate.
|
||||
Signal the child process to shut down gracefully.
|
||||
"""
|
||||
if self.valid():
|
||||
assert self.process is not None
|
||||
assert self.request_queue is not None
|
||||
self.request_queue.put(None)
|
||||
if self.alive():
|
||||
TuningProcess.send(None, self.write_pipe)
|
||||
if wait:
|
||||
self.wait()
|
||||
|
||||
def wait(self) -> None:
|
||||
"""
|
||||
Wait for the child process to exit.
|
||||
"""
|
||||
if self.process is not None:
|
||||
self.process.join()
|
||||
self.clear()
|
||||
if self.alive():
|
||||
self.process.wait()
|
||||
self.close()
|
||||
|
||||
def kill(self, graceful_timeout=5.0, terminate_timeout=1.0) -> None:
|
||||
# Tries to kill the process, using a graceful_timeout in which the process
|
||||
# is allowed to exit gracefully. If the process is still alive,
|
||||
# it will be terminated. If that is not sufficient to end it
|
||||
# within terminate_timeout seconds, it will be killed.
|
||||
if self.process is not None:
|
||||
self.terminate()
|
||||
self.process.join(timeout=graceful_timeout)
|
||||
if self.process.is_alive():
|
||||
autotuning_log.warning(
|
||||
"Sending SIGTERM to process with PID %d",
|
||||
self.process.pid,
|
||||
)
|
||||
self.process.terminate()
|
||||
self.process.join(timeout=terminate_timeout)
|
||||
if self.process.is_alive():
|
||||
autotuning_log.error(
|
||||
"Sending SIGKILL to process with PID %d",
|
||||
self.process.pid,
|
||||
)
|
||||
self.process.kill() # This should definitely end the process
|
||||
self.clear()
|
||||
def close(self) -> None:
|
||||
"""
|
||||
Close resources.
|
||||
"""
|
||||
self.selector.close()
|
||||
self.read_pipe.close()
|
||||
self.write_pipe.close()
|
||||
self.running = False
|
||||
|
||||
def kill(self) -> None:
|
||||
"""
|
||||
Send a SIGKILL to the child process.
|
||||
"""
|
||||
if self.alive():
|
||||
autotuning_log.error(
|
||||
"Sending SIGKILL to autotune subprocess %d",
|
||||
self.process.pid,
|
||||
)
|
||||
self.process.kill()
|
||||
self.close()
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class TuningProcessPool:
|
||||
"""
|
||||
Maintains a pool of TuningProcesses to benchmark kernels in parallel
|
||||
@ -286,8 +234,9 @@ class TuningProcessPool:
|
||||
set the sub-process environment to make only that device visible.
|
||||
"""
|
||||
|
||||
processes: Optional[queue.Queue[TuningProcess]] = None
|
||||
executor: Optional[ThreadPoolExecutor] = None
|
||||
def __init__(self) -> None:
|
||||
self.processes: Optional[queue.Queue[TuningProcess]] = None
|
||||
self.executor: Optional[ThreadPoolExecutor] = None
|
||||
|
||||
def initialize(self) -> None:
|
||||
"""
|
||||
@ -298,35 +247,21 @@ class TuningProcessPool:
|
||||
return
|
||||
|
||||
devices = self.get_device_list()
|
||||
log.debug("Sub-process autotune device list: %s", devices)
|
||||
autotuning_log.debug("Sub-process autotune device list: %s", devices)
|
||||
|
||||
# Launch the child processes and push a msg to "warm up"
|
||||
# Launch the child processes.
|
||||
self.processes = queue.Queue()
|
||||
for device in devices:
|
||||
p = TuningProcess(device=device)
|
||||
p.initialize()
|
||||
p.put(Ping())
|
||||
self.processes.put(p)
|
||||
|
||||
# Wait for the initialization to finish
|
||||
for p in self.processes.queue:
|
||||
assert isinstance(p.get(result_timeout=None), Pong)
|
||||
|
||||
# Use a thread pool to manage distributing work to the subprocesses.
|
||||
# Threads block on an available process, so it makes sense to match
|
||||
# the number of threads with the number of devices.
|
||||
self.executor = ThreadPoolExecutor(max_workers=len(devices))
|
||||
|
||||
# Register the exit handler for the parent process so it will terminate
|
||||
# the child processes.
|
||||
global EXIT_HANDLER_REGISTERED
|
||||
if not EXIT_HANDLER_REGISTERED:
|
||||
EXIT_HANDLER_REGISTERED = True
|
||||
import atexit
|
||||
|
||||
atexit.register(self.terminate)
|
||||
|
||||
def get_device_list(self) -> Sequence[Optional[int]]:
|
||||
@staticmethod
|
||||
def get_device_list() -> Sequence[Optional[int]]:
|
||||
"""
|
||||
Gather the list of devices to be used in the pool.
|
||||
"""
|
||||
@ -346,9 +281,9 @@ class TuningProcessPool:
|
||||
|
||||
return list(range(count))
|
||||
|
||||
def terminate(self) -> None:
|
||||
def shutdown(self) -> None:
|
||||
"""
|
||||
Signal all child processes to terminate.
|
||||
Signal all child processes to exit.
|
||||
"""
|
||||
if self.executor is not None:
|
||||
self.executor.shutdown()
|
||||
@ -356,7 +291,7 @@ class TuningProcessPool:
|
||||
|
||||
if self.processes is not None:
|
||||
for p in self.processes.queue:
|
||||
p.terminate()
|
||||
p.shutdown(wait=False)
|
||||
for p in self.processes.queue:
|
||||
p.wait()
|
||||
self.processes = None
|
||||
@ -371,19 +306,24 @@ class TuningProcessPool:
|
||||
assert self.processes is not None
|
||||
|
||||
process = self.processes.get()
|
||||
process.put(choice.bmreq)
|
||||
process.put(choice.bmreq.benchmark)
|
||||
try:
|
||||
return process.get(
|
||||
config.max_autotune_subproc_result_timeout_seconds,
|
||||
config.max_autotune_subproc_graceful_timeout_seconds,
|
||||
config.max_autotune_subproc_terminate_timeout_seconds,
|
||||
)
|
||||
except queue.Empty:
|
||||
except TimeoutError:
|
||||
warnings.warn(
|
||||
f"Timed out benchmarking choice '{choice}'. It will be ignored. "
|
||||
"Please debug the root cause in case the choice can bring perf gains."
|
||||
)
|
||||
# Set to INF so this choice will be ignored
|
||||
return float("inf")
|
||||
except Exception:
|
||||
warnings.warn(
|
||||
f"Failed to benchmark choice '{choice}'. It will be ignored. "
|
||||
"Please debug the root cause in case the choice can bring perf gains."
|
||||
)
|
||||
# set to INF so this choice will be ignored
|
||||
# Set to INF so this choice will be ignored
|
||||
return float("inf")
|
||||
finally:
|
||||
self.processes.put(process)
|
||||
@ -406,6 +346,7 @@ class TuningProcessPool:
|
||||
|
||||
|
||||
tuning_pool = TuningProcessPool()
|
||||
atexit.register(tuning_pool.shutdown)
|
||||
|
||||
|
||||
LayoutOrBuffer = Union[ir.Layout, ir.Buffer]
|
||||
@ -563,21 +504,38 @@ class BenchmarkRequest:
|
||||
return out
|
||||
|
||||
|
||||
class TestBenchmarkRequest(BenchmarkRequest):
|
||||
class _TestBenchmarkRequest(BenchmarkRequest):
|
||||
"""
|
||||
Supports unit testing. Defined in this file so that the TuningProcess
|
||||
sub-process knows how to unpickle these objects.
|
||||
Supports unit testing. Defined in this file instead of the test file so the
|
||||
TuningProcess sub-process can unpickle these objects.
|
||||
"""
|
||||
|
||||
def __init__(self, value: Optional[float] = None) -> None:
|
||||
self.value = value
|
||||
def __init__(
|
||||
self,
|
||||
result: float = 0.0,
|
||||
device: Optional[int] = None,
|
||||
sleep: Optional[float] = None,
|
||||
exc: Optional[Exception] = None,
|
||||
crash: bool = False,
|
||||
):
|
||||
self.result = result
|
||||
self.device = device
|
||||
self.sleep = sleep
|
||||
self.exc = exc
|
||||
self.crash = crash
|
||||
|
||||
def benchmark(
|
||||
self, *input_tensors: torch.Tensor, output_tensor: Optional[torch.Tensor] = None
|
||||
) -> float:
|
||||
if self.value is None:
|
||||
raise Exception("Failed to run") # noqa: TRY002
|
||||
return self.value
|
||||
if self.device is not None:
|
||||
assert os.environ.get(CUDA_VISIBLE_DEVICES, None) == str(self.device)
|
||||
if self.sleep:
|
||||
time.sleep(self.sleep)
|
||||
if self.exc:
|
||||
raise self.exc
|
||||
if self.crash:
|
||||
sys.exit(1)
|
||||
return self.result
|
||||
|
||||
|
||||
class GPUDeviceBenchmarkMixin:
|
||||
|
@ -20,8 +20,8 @@ from typing_extensions import Never, ParamSpec
|
||||
# justknobs, e.g., in the Triton compiler. For internal, the import installs
|
||||
# functionality to destroy singletons before forking and re-enable them after.
|
||||
import torch._thread_safe_fork # noqa: F401
|
||||
from torch._inductor import config
|
||||
from torch._inductor.compile_worker.utils import _async_compile_initializer
|
||||
from torch._inductor.utils import get_ld_library_path
|
||||
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
@ -57,19 +57,6 @@ def _recv_msg(read_pipe: IO[bytes]) -> tuple[int, bytes]:
|
||||
return job_id, data
|
||||
|
||||
|
||||
def _get_ld_library_path() -> str:
|
||||
path = os.environ.get("LD_LIBRARY_PATH", "")
|
||||
if config.is_fbcode():
|
||||
from libfb.py.parutil import get_runtime_path
|
||||
|
||||
runtime_path = get_runtime_path()
|
||||
if runtime_path:
|
||||
lib_path = os.path.join(runtime_path, "runtime", "lib")
|
||||
path = os.pathsep.join([lib_path, path]) if path else lib_path
|
||||
|
||||
return path
|
||||
|
||||
|
||||
class _SubprocExceptionInfo:
|
||||
"""
|
||||
Carries exception info from subprocesses across the wire. traceback
|
||||
@ -150,7 +137,7 @@ class SubprocPool:
|
||||
# creates the SubprocPool in the first place.
|
||||
"TORCH_WARM_POOL": "0",
|
||||
# Some internal usages need a modified LD_LIBRARY_PATH.
|
||||
"LD_LIBRARY_PATH": _get_ld_library_path(),
|
||||
"LD_LIBRARY_PATH": get_ld_library_path(),
|
||||
},
|
||||
pass_fds=(subproc_read_fd, subproc_write_fd),
|
||||
)
|
||||
|
@ -422,12 +422,12 @@ autotune_in_subproc = os.environ.get("TORCHINDUCTOR_AUTOTUNE_IN_SUBPROC") == "1"
|
||||
|
||||
# The following three timeouts are applicable if autotune_in_subproc is True:
|
||||
|
||||
# Max time that a a valid benchmark result may take during autotuning
|
||||
# Max time that a valid benchmark result may take during autotuning
|
||||
max_autotune_subproc_result_timeout_seconds = 60.0
|
||||
# Additional time we allow subprocesses to terminate gracefully after the timeout until we send a SIGTERM
|
||||
max_autotune_subproc_graceful_timeout_seconds = 1.0
|
||||
# Additional time that we grant after a SIGTERM until we do a hard SIGKILL of subprocesses
|
||||
max_autotune_subproc_terminate_timeout_seconds = 2.0
|
||||
# DEPRECATED. This setting is ignored.
|
||||
max_autotune_subproc_graceful_timeout_seconds = 0.0
|
||||
# DEPRECATED. This setting is ignored.
|
||||
max_autotune_subproc_terminate_timeout_seconds = 0.0
|
||||
|
||||
# If autotuning in subprocess, whether to use multiple devices
|
||||
autotune_multi_device = os.environ.get("TORCHINDUCTOR_AUTOTUNE_MULTI_DEVICE") == "1"
|
||||
|
@ -2811,3 +2811,16 @@ def is_cudagraph_unsafe_op(node: Operation) -> bool:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def get_ld_library_path() -> str:
|
||||
path = os.environ.get("LD_LIBRARY_PATH", "")
|
||||
if config.is_fbcode():
|
||||
from libfb.py.parutil import get_runtime_path
|
||||
|
||||
runtime_path = get_runtime_path()
|
||||
if runtime_path:
|
||||
lib_path = os.path.join(runtime_path, "runtime", "lib")
|
||||
path = os.pathsep.join([lib_path, path]) if path else lib_path
|
||||
|
||||
return path
|
||||
|
Reference in New Issue
Block a user