mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
running python test/strobelight/examples/compile_time_profile_example.py ``` strobelight_compile_time_profiler, line 123, 2025-02-20 14:08:08,409, INFO: compile time strobelight profiling enabled strobelight_compile_time_profiler, line 159, 2025-02-20 14:08:08,409, INFO: Unique sample tag for this run is: 2025-02-20-14:08:081656673devgpu005.nha1.facebook.com strobelight_compile_time_profiler, line 160, 2025-02-20 14:08:09,124, INFO: URL to access the strobelight profile at the end of the run: https://fburl.com/scuba/pyperf_experimental/on_demand/9felqj0i strobelight_compile_time_profiler, line 205, 2025-02-20 14:08:12,436, INFO: profiling frame 0/0 is skipped due to frame_id_filter 1/.* strobelight_compile_time_profiler, line 205, 2025-02-20 14:08:15,553, INFO: profiling frame 0/0 is skipped due to frame_id_filter 1/.* strobelight_compile_time_profiler, line 205, 2025-02-20 14:08:16,170, INFO: profiling frame 0/0 is skipped due to frame_id_filter 1/.* strobelight_compile_time_profiler, line 214, 2025-02-20 14:08:16,877, INFO: profiling frame 1/0 strobelight_function_profiler, line 247, 2025-02-20 14:08:19,416, INFO: strobelight run id is: 4015948658689996 strobelight_function_profiler, line 249, 2025-02-20 14:08:21,546, INFO: strobelight profiling running strobelight_function_profiler, line 289, 2025-02-20 14:08:25,964, INFO: work function took 4.417063233006047 seconds strobelight_function_profiler, line 230, 2025-02-20 14:08:28,310, INFO: strobelight profiling stopped strobelight_function_profiler, line 221, 2025-02-20 14:08:44,308, INFO: Total samples: 119 strobelight_function_profiler, line 221, 2025-02-20 14:08:44,308, INFO: GraphProfiler (python stack): https://fburl.com/scuba/pyperf_experimental/on_demand/73h2f7ur strobelight_function_profiler, line 221, 2025-02-20 14:08:44,308, INFO: Icicle view (python stack): https://fburl.com/scuba/pyperf_experimental/on_demand/zs06fi9e strobelight_compile_time_profiler, line 167, 2025-02-20 14:08:44,308, INFO: 1 
strobelight success runs out of 1 non-recursive compilation events. ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/147549 Approved by: https://github.com/bobrenjc93 ghstack dependencies: #147547
225 lines
7.3 KiB
Python
225 lines
7.3 KiB
Python
# mypy: disallow-untyped-defs
|
|
|
|
import json
|
|
import logging
|
|
import os
|
|
import re
|
|
import subprocess
|
|
from datetime import datetime
|
|
from socket import gethostname
|
|
from typing import Any, Optional
|
|
|
|
from torch._strobelight.cli_function_profiler import StrobelightCLIFunctionProfiler
|
|
|
|
|
|
# Module logger: emits INFO and above through its own stream handler with a
# format that includes the source line number, and does not propagate records
# to ancestor loggers (so they are not duplicated by root handlers).
formatter = logging.Formatter(
    fmt="%(name)s, line %(lineno)d, %(asctime)s, %(levelname)s: %(message)s"
)
console_handler = logging.StreamHandler()
console_handler.setFormatter(formatter)

logger = logging.getLogger("strobelight_compile_time_profiler")
logger.setLevel(logging.INFO)
logger.addHandler(console_handler)
logger.propagate = False
|
|
|
|
|
|
def get_fburl(url: str) -> str:
    """Try to shorten *url* with the ``fburl`` CLI; fall back to *url* itself.

    Shortening is best-effort. The long URL is returned unchanged when any of
    these happens:

    - the command cannot be run (e.g. ``fburl`` is not installed);
    - it exits with a non-zero status;
    - it raises for any other reason;
    - it prints nothing but whitespace.

    Args:
        url: The URL to shorten.

    Returns:
        The shortened URL with surrounding whitespace removed, or ``url`` when
        shortening did not succeed.
    """
    short_url = url
    # Attempt to shorten the URL
    try:
        result = subprocess.run(
            ["fburl", url], capture_output=True, stdin=subprocess.DEVNULL
        )
        if result.returncode == 0:
            # CLI output may carry a trailing newline; strip it so the URL is
            # clean when logged or reused. Keep the long URL if nothing useful
            # was printed.
            shortened = result.stdout.decode("utf-8").strip()
            if shortened:
                short_url = shortened
    except Exception as e:
        # Deliberately best-effort: any failure just means we keep the long URL.
        logger.warning("URL shortening failed: %s, using long URL", repr(e))
    return short_url
|
|
|
|
|
|
def get_strobelight_url(identifier: str) -> str:
    """Return a (possibly shortened) Scuba GraphProfiler URL for one run.

    The Scuba query state selects samples whose ``sample_tags`` contain
    *identifier*. It is serialized with ``json.dumps`` into the ``drillstate``
    query parameter, and the resulting long URL is passed to ``get_fburl``.

    NOTE(review): the JSON is appended without percent-encoding; this appears
    to rely on the shortener/browser tolerating raw characters — confirm.
    """
    tag_constraint = {
        "column": "sample_tags",
        "op": "all",
        "value": [f'["{identifier}"]'],
    }
    stack_dimension = {
        "dim": "py_async_stack",
        "op": "edge",
        "param": "0",
        "anchor": "0",
    }
    # Key order is significant: json.dumps preserves insertion order, so it
    # determines the exact URL text.
    drill_state = {
        "aggregateList": [],
        "aggregation_field": "async_stack_complete",
        "b_constraints": [[]],
        "c_constraints": [[]],
        "cols": ["namespace_id", "namespace_process_id"],
        "compare": "none",
        "constraints": [[tag_constraint]],
        "derivedCols": [],
        "end": "now",
        "enumCols": [],
        "filterMode": "DEFAULT",
        "hideEmptyColumns": "false",
        "ignoreGroupByInComparison": "false",
        "is_timeseries": "false",
        "mappedCols": [],
        "metric": "count",
        "modifiers": [],
        "order": "weight",
        "order_desc": "true",
        "param_dimensions": [stack_dimension],
        "purposes": [],
        "return_remainder": "false",
        "samplingRatio": "1",
        "should_pivot": "false",
        "start": "-30 days",
        "timezone": "America/Los_Angeles",
        "top": 10000,
    }
    prefix = (
        "https://www.internalfb.com/intern/scuba/query/"
        "?dataset=pyperf_experimental/on_demand&drillstate="
    )
    suffix = "&view=GraphProfilerView&&normalized=1726332703&pool=uber"
    return get_fburl(f"{prefix}{json.dumps(drill_state)}{suffix}")
|
|
|
|
|
|
class StrobelightCompileTimeProfiler:
|
|
success_profile_count: int = 0
|
|
failed_profile_count: int = 0
|
|
ignored_profile_runs: int = 0
|
|
inside_profile_compile_time: bool = False
|
|
enabled: bool = False
|
|
|
|
# A regex that can be used to filter out what frames to profile. ex: "1/.*"
|
|
frame_id_filter: Optional[str] = os.environ.get("COMPILE_STROBELIGHT_FRAME_FILTER")
|
|
|
|
# A unique identifier that is used as the run_user_name in the strobelight profile to
|
|
# associate all compile time profiles together.
|
|
identifier: Optional[str] = None
|
|
|
|
current_phase: Optional[str] = None
|
|
|
|
profiler: Optional[Any] = None
|
|
|
|
max_stack_length: int = int(
|
|
os.environ.get("COMPILE_STROBELIGHT_MAX_STACK_LENGTH", 500)
|
|
)
|
|
max_profile_time: int = int(
|
|
os.environ.get("COMPILE_STROBELIGHT_MAX_PROFILE_TIME", 60 * 30)
|
|
)
|
|
# Collect sample each x cycles.
|
|
sample_each: int = int(
|
|
float(os.environ.get("COMPILE_STROBELIGHT_SAMPLE_RATE", 1e7))
|
|
)
|
|
|
|
@classmethod
|
|
def get_frame(cls) -> str:
|
|
from torch._guards import CompileContext
|
|
|
|
return (str)(CompileContext.current_trace_id())
|
|
|
|
@classmethod
|
|
def enable(cls, profiler_class: Any = StrobelightCLIFunctionProfiler) -> None:
|
|
if cls.enabled:
|
|
logger.info("compile time strobelight profiling already enabled")
|
|
return
|
|
|
|
logger.info("compile time strobelight profiling enabled")
|
|
|
|
if profiler_class is StrobelightCLIFunctionProfiler:
|
|
import shutil
|
|
|
|
if not shutil.which("strobeclient"):
|
|
logger.info(
|
|
"strobeclient not found, cant enable compile time strobelight profiling, seems"
|
|
"like you are not on a FB machine."
|
|
)
|
|
return
|
|
|
|
cls.enabled = True
|
|
cls._cls_init()
|
|
# profiler_class should have public API similar to that of StrobelightCLIFunctionProfiler.
|
|
# we have pass different functionProfilerClass for meta-internal fbcode targets.
|
|
# NB: the actual implementation in Meta is at
|
|
# fbcode/caffe2/fb/strobelight/function_profiler.py
|
|
cls.profiler = profiler_class(
|
|
sample_each=cls.sample_each,
|
|
max_profile_duration_sec=cls.max_profile_time,
|
|
stack_max_len=cls.max_stack_length,
|
|
async_stack_max_len=cls.max_stack_length,
|
|
run_user_name="pt2-profiler/"
|
|
+ os.environ.get("USER", os.environ.get("USERNAME", "")),
|
|
sample_tags={cls.identifier},
|
|
)
|
|
|
|
@classmethod
|
|
def _cls_init(cls) -> None:
|
|
cls.identifier = "{date}{pid}{hostname}".format(
|
|
date=datetime.now().strftime("%Y-%m-%d-%H:%M:%S"),
|
|
pid=os.getpid(),
|
|
hostname=gethostname(),
|
|
)
|
|
|
|
logger.info("Unique sample tag for this run is: %s", cls.identifier)
|
|
logger.info(
|
|
"URL to access the strobelight profile at the end of the run: %s",
|
|
get_strobelight_url(cls.identifier),
|
|
)
|
|
|
|
@classmethod
|
|
def _log_stats(cls) -> None:
|
|
logger.info(
|
|
"%s strobelight success runs out of %s non-recursive compilation events.",
|
|
cls.success_profile_count,
|
|
cls.success_profile_count + cls.failed_profile_count,
|
|
)
|
|
|
|
# TODO use threadlevel meta data to tags to record phases.
|
|
@classmethod
|
|
def profile_compile_time(
|
|
cls, func: Any, phase_name: str, *args: Any, **kwargs: Any
|
|
) -> Any:
|
|
def skip() -> Any:
|
|
return func(*args, **kwargs)
|
|
|
|
if not cls.enabled:
|
|
return skip()
|
|
|
|
if cls.profiler is None:
|
|
logger.error("profiler is not set")
|
|
return
|
|
|
|
frame_id = cls.get_frame()
|
|
|
|
if cls.inside_profile_compile_time:
|
|
cls.ignored_profile_runs += 1
|
|
logger.info(
|
|
"profile_compile_time is requested for phase: %s, frame %s, while already in running phase: %s,"
|
|
"frame %s, recursive call ignored",
|
|
phase_name,
|
|
frame_id,
|
|
cls.current_phase,
|
|
frame_id,
|
|
)
|
|
return skip()
|
|
|
|
if cls.frame_id_filter is not None:
|
|
should_run = re.match(cls.frame_id_filter, frame_id) is not None
|
|
if not should_run:
|
|
logger.info(
|
|
"profiling frame %s is skipped due to frame_id_filter %s",
|
|
frame_id,
|
|
cls.frame_id_filter,
|
|
)
|
|
return skip()
|
|
|
|
cls.inside_profile_compile_time = True
|
|
cls.current_phase = phase_name
|
|
logger.info("profiling frame %s", frame_id)
|
|
work_result = cls.profiler.profile(func, *args, **kwargs)
|
|
|
|
if cls.profiler.profile_result is not None:
|
|
cls.success_profile_count += 1
|
|
else:
|
|
cls.failed_profile_count += 1
|
|
|
|
cls._log_stats()
|
|
cls.inside_profile_compile_time = False
|
|
return work_result
|