mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Pull Request resolved: https://github.com/pytorch/pytorch/pull/156312 Approved by: https://github.com/albanD
322 lines
12 KiB
Python
322 lines
12 KiB
Python
# mypy: disallow-untyped-defs
|
|
|
|
import functools
|
|
import logging
|
|
import os
|
|
import re
|
|
import subprocess
|
|
import time
|
|
from collections.abc import Sequence
|
|
from threading import Lock
|
|
from timeit import default_timer as timer
|
|
from typing import Any, Callable, Optional, TypeVar
|
|
from typing_extensions import ParamSpec
|
|
|
|
|
|
# Module logger with its own console handler; propagation is disabled so records
# are emitted only through this handler.
logger = logging.getLogger("strobelight_function_profiler")

console_handler = logging.StreamHandler()
formatter = logging.Formatter(
    "%(name)s, line %(lineno)d, %(asctime)s, %(levelname)s: %(message)s"
)
console_handler.setFormatter(formatter)

# getLogger returns the same object on every call, so re-executing this module
# (e.g. via importlib.reload) would otherwise stack another handler each time and
# print every record multiple times. Only attach when nothing is attached yet.
if not logger.handlers:
    logger.addHandler(console_handler)
logger.setLevel(logging.INFO)
logger.propagate = False
|
|
|
|
_P = ParamSpec("_P")
|
|
_R = TypeVar("_R")
|
|
|
|
|
|
class StrobelightCLIProfilerError(Exception):
    """Error raised when strobelight profiling fails at any stage."""
|
|
|
|
|
|
def _pid_namespace_link(pid: Optional[int] = None) -> str:
    """Return the target of ``/proc/<pid>/ns/pid``, e.g. ``pid:[4026531836]``.

    A falsy ``pid`` (None or 0) means the current process.
    """
    target_pid = pid if pid else os.getpid()
    return os.readlink(f"/proc/{target_pid}/ns/pid")
|
|
|
|
|
|
def _pid_namespace(pid: Optional[int] = None) -> int:
    """Return the numeric PID-namespace id of ``pid`` (current process if falsy)."""
    link = _pid_namespace_link(pid or os.getpid())
    # The link has the form "pid:[<id>]": keep what lies between "[" and the last char.
    id_start = link.find("[") + 1
    return int(link[id_start:-1])
|
|
|
|
|
|
def _command_to_string(command: Sequence[str]) -> str:
    """Render an argv sequence as one space-separated string (used for debug logs)."""
    separator = " "
    return separator.join(list(command))
|
|
|
|
|
|
class StrobelightCLIFunctionProfiler:
    """
    Note: this is a Meta only tool.

    StrobelightCLIFunctionProfiler can be used to profile a python function and
    generate a strobelight link with the results. It works on meta servers but
    does not require an fbcode target.
    When stop_at_error is False (the default), an error during profiling does not
    prevent the work function from running.

    Check function_profiler_example.py for an example.
    """

    # This lock is used to make sure only one thread is running the profiler at any point.
    _lock = Lock()

    def __init__(
        self,
        *,
        stop_at_error: bool = False,
        max_profile_duration_sec: int = 60 * 10,
        sample_each: float = 1e7,  # sample each sample_each cycles.
        run_user_name: str = "pytorch-strobelight-ondemand",
        timeout_wait_for_running_sec: int = 60,
        timeout_wait_for_finished_sec: int = 60,
        recorded_env_variables: Optional[list[str]] = None,
        sample_tags: Optional[list[str]] = None,
        stack_max_len: int = 127,
        async_stack_max_len: int = 127,
    ) -> None:
        # NOTE(review): run_user_name and the two timeout_* values are stored but no
        # method of this class reads them; recorded_env_variables, stack_max_len and
        # async_stack_max_len are accepted but ignored. Kept for API compatibility.
        self.stop_at_error = stop_at_error
        self.max_profile_duration_sec = max_profile_duration_sec
        self.sample_each = sample_each
        self.run_user_name = run_user_name
        self.timeout_wait_for_running_sec = timeout_wait_for_running_sec
        self.timeout_wait_for_finished_sec = timeout_wait_for_finished_sec
        # Results of the most recent run.
        # Tracks the strobelight run id of the most recent run
        self.current_run_id: Optional[int] = None
        self.profile_result: Optional[list[str]] = None
        self.sample_tags = sample_tags

    @staticmethod
    def _run_strobeclient(command: list[str]) -> tuple[int, str]:
        """Run a strobeclient command; return ``(returncode, stderr decoded as UTF-8)``.

        Every caller parses stderr only; stdout is captured and discarded.
        """
        logger.debug("running command: %s", _command_to_string(command))
        result = subprocess.run(command, capture_output=True)
        output = result.stderr.decode("utf-8")
        logger.debug("output:\n{%s}", output)
        return result.returncode, output

    def _run_async(self) -> None:
        """Start an asynchronous pyperf run on this process and record its run id.

        Raises:
            StrobelightCLIProfilerError: if strobeclient exits non-zero or its output
                contains no "INFO Run Id: <n>" line.
        """
        pid = os.getpid()
        namespace = _pid_namespace(pid)
        command = [
            "strobeclient",
            "run",
            "--profiler",
            "pyperf",
            "--event",
            "cycles",
            "--async",
            "--sample-interval",
            f"{int(self.sample_each)}",
            "--duration-ms",
            f"{int(self.max_profile_duration_sec * 1000)}",
            "--pid",
            f"{namespace}:{pid}",
        ]
        if self.sample_tags:
            command.extend(["--sample-tags", ",".join(self.sample_tags)])

        returncode, output = self._run_strobeclient(command)
        if returncode != 0:
            raise StrobelightCLIProfilerError(
                f"failed to start strobelight profiling, error in run_async:{output}"
            )

        if match := re.search(r"INFO Run Id: (-?\d+)", output):
            self.current_run_id = int(match.group(1))
            return

        raise StrobelightCLIProfilerError(
            f"failed to start strobelight profiling, unexpected result {output}"
        )

    def _wait_for_running(self, counter: int = 0) -> None:
        """Poll the run status until it is RUNNING.

        While the status is PREPARING, sleep 10 seconds and poll again. ``counter``
        is the number of polls already spent; polling stops once it would exceed 20.

        Raises:
            StrobelightCLIProfilerError: on a non-zero exit code, unparsable output,
                any status other than RUNNING/PREPARING, or too many retries.
        """
        command = ["strobeclient", "getRunStatus", "--run-id", f"{self.current_run_id}"]
        # Iterative form of the retry: polls for counter values counter..20 inclusive.
        for _attempt in range(counter, 21):
            returncode, output = self._run_strobeclient(command)
            if returncode != 0:
                raise StrobelightCLIProfilerError(
                    f"failed to start strobelight profiling, error in wait_for_running:{output}"
                )

            match = re.search("Profile run status: (.*)", output)
            if match is None:
                raise StrobelightCLIProfilerError(f"unexpected output\n: {output} ")

            current_status = match.group(1)
            if current_status == "RUNNING":
                return
            if current_status != "PREPARING":
                raise StrobelightCLIProfilerError(f"unexpected {current_status} phase")
            time.sleep(10)

        raise StrobelightCLIProfilerError("wait_for_running called more than 20 times")

    def _stop_run(self) -> None:
        """Ask strobeclient to stop the current run.

        Raises:
            StrobelightCLIProfilerError: on a non-zero exit code, unparsable output,
                or a status line that does not contain "Success!".
        """
        command = ["strobeclient", "stopRun", "--run-id", str(self.current_run_id)]
        returncode, output = self._run_strobeclient(command)
        if returncode != 0:
            raise StrobelightCLIProfilerError(
                f"failed to stop strobelight profiling, return code is not 0 :{output}"
            )

        match = re.search("INFO ::1:(.*)", output)
        if match is None:
            raise StrobelightCLIProfilerError(f"unexpected output\n: {output} ")

        current_status = match.group(1)
        if "Success!" not in current_status:
            raise StrobelightCLIProfilerError(
                f"failed to stop strobelight profiling, got {current_status} result"
            )

    def _get_results(self) -> None:
        """Wait for the run to finish processing, then collect its summary lines.

        While the status line reports PROCESSING, sleep 10 seconds and poll again.
        On success, every "Total samples", "GraphProfiler" and "Icicle view (python
        stack)" line is appended to ``self.profile_result`` and logged.

        Raises:
            StrobelightCLIProfilerError: on a non-zero exit code or a status line that
                is neither PROCESSING nor a SUCCESS.
        """
        command = ["strobeclient", "getRunStatus", "--run-id", str(self.current_run_id)]
        # Loop rather than recurse, so a long PROCESSING phase cannot exhaust the
        # interpreter's recursion limit.
        # NOTE(review): polling is unbounded; timeout_wait_for_finished_sec is not used.
        while True:
            returncode, output = self._run_strobeclient(command)
            if returncode != 0:
                raise StrobelightCLIProfilerError(
                    f"failed to extract profiling results, return code is not 0 : {output}"
                )

            match = re.search("INFO ::1:(.*)", output)
            if match is None:
                # No status line: parse whatever this output contains.
                break
            current_status = match.group(1)
            if "Profile run status: PROCESSING" in current_status:
                time.sleep(10)
                continue
            if "Profile run finished with SUCCESS" not in current_status:
                raise StrobelightCLIProfilerError(
                    f"failed to extract profiling results, unexpected response {output}"
                )
            break

        self.profile_result = []
        for item in re.findall(
            r"(Total samples(.*)|GraphProfiler(.*)|Icicle view \(python stack\)(.*))",
            output,
        ):
            # findall yields one tuple per match; item[0] is the whole outer group.
            # Append it as one entry (list += str would add it character by character).
            self.profile_result.append(item[0])
            logger.info(item[0])

    def _stop_strobelight_no_throw(
        self,
        collect_results: bool,
    ) -> None:
        """Stop the current run and optionally collect results; never raises.

        Any Exception is logged as a warning and swallowed.
        """
        try:
            # call stop run
            self._stop_run()
            logger.info("strobelight profiling stopped")

            logger.debug("collection stopped")

            if not collect_results:
                return

            self._get_results()
        except Exception:
            logger.warning("error during stop_strobelight", exc_info=True)

    def _start_strobelight(self) -> bool:
        """Start profiling and wait until it is running; never raises.

        Returns:
            True if strobelight started and is running. On any Exception the error is
            logged, a run that was already created is stopped, and False is returned.
        """
        strobelight_started = False
        try:
            self._run_async()
            strobelight_started = True
            logger.info("strobelight run id is: %s", self.current_run_id)
            self._wait_for_running()
            logger.info("strobelight profiling running")
            return True

        except Exception:
            logger.warning("error during start_strobelight:", exc_info=True)
            if strobelight_started:
                self._stop_strobelight_no_throw(collect_results=False)
            return False

    def profile(
        self, work_function: Callable[_P, _R], *args: _P.args, **kwargs: _P.kwargs
    ) -> Optional[_R]:
        """Run ``work_function(*args, **kwargs)`` under a strobelight profile.

        Only one profile may run at a time (class-level lock). If the lock is held by
        another profile, or strobelight fails to start, the work function still runs
        unprofiled. If ``stop_at_error`` is set, StrobelightCLIProfilerError is raised
        instead. Exceptions from the work function propagate after the run is stopped.

        Returns:
            The value returned by ``work_function``.
        """
        self.current_run_id = None
        self.profile_result = None

        if not StrobelightCLIFunctionProfiler._lock.acquire(False):
            # Another profile holds the lock: fall back to running unprofiled.
            if self.stop_at_error:
                raise StrobelightCLIProfilerError("concurrent runs not supported")

            logger.warning("concurrent runs not supported")
            return work_function(*args, **kwargs)

        # The lock is held from here on; release it on every exit path, including
        # exceptions raised by the work function.
        try:
            if not self._start_strobelight():
                if self.stop_at_error:
                    raise StrobelightCLIProfilerError(
                        "failed to start strobelight profiling"
                    )
                return work_function(*args, **kwargs)

            logger.debug("collection started")
            start = timer()
            try:
                result = work_function(*args, **kwargs)
            except Exception:
                logger.warning("work function throw exception", exc_info=True)
                self._stop_strobelight_no_throw(collect_results=False)
                raise
            total_time = timer() - start  # Time in seconds, e.g. 5.38091952400282
            logger.info("work function took %s seconds", total_time)
            self._stop_strobelight_no_throw(collect_results=True)
            return result
        finally:
            StrobelightCLIFunctionProfiler._lock.release()
|
|
|
|
|
|
def strobelight(
    profiler: Optional[StrobelightCLIFunctionProfiler] = None, **kwargs: Any
) -> Callable[[Callable[_P, _R]], Callable[_P, Optional[_R]]]:
    """Decorator factory that runs the decorated function through ``profiler.profile``.

    If no profiler is provided, one is created from ``kwargs`` (forwarded to the
    StrobelightCLIFunctionProfiler constructor). When a profiler is provided,
    ``kwargs`` is not used. A function can be annotated as::

        @strobelight()
        @strobelight(profiler=StrobelightCLIFunctionProfiler(stop_at_error=True, ...))
        @strobelight(stop_at_error=True, ...)
    """
    # Bind to a non-Optional local so the closure below type-checks without narrowing.
    active_profiler = (
        profiler if profiler is not None else StrobelightCLIFunctionProfiler(**kwargs)
    )

    def strobelight_inner(
        work_function: Callable[_P, _R],
    ) -> Callable[_P, Optional[_R]]:
        @functools.wraps(work_function)
        def wrapper_function(*args: _P.args, **call_kwargs: _P.kwargs) -> Optional[_R]:
            return active_profiler.profile(work_function, *args, **call_kwargs)

        return wrapper_function

    return strobelight_inner
|