mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
Enable strobelight profiling specific compile frame ids using COMPILE_STROBELIGHT_FRAME_FILTER (#147549)
running python test/strobelight/examples/compile_time_profile_example.py ``` strobelight_compile_time_profiler, line 123, 2025-02-20 14:08:08,409, INFO: compile time strobelight profiling enabled strobelight_compile_time_profiler, line 159, 2025-02-20 14:08:08,409, INFO: Unique sample tag for this run is: 2025-02-20-14:08:081656673devgpu005.nha1.facebook.com strobelight_compile_time_profiler, line 160, 2025-02-20 14:08:09,124, INFO: URL to access the strobelight profile at the end of the run: https://fburl.com/scuba/pyperf_experimental/on_demand/9felqj0i strobelight_compile_time_profiler, line 205, 2025-02-20 14:08:12,436, INFO: profiling frame 0/0 is skipped due to frame_id_filter 1/.* strobelight_compile_time_profiler, line 205, 2025-02-20 14:08:15,553, INFO: profiling frame 0/0 is skipped due to frame_id_filter 1/.* strobelight_compile_time_profiler, line 205, 2025-02-20 14:08:16,170, INFO: profiling frame 0/0 is skipped due to frame_id_filter 1/.* strobelight_compile_time_profiler, line 214, 2025-02-20 14:08:16,877, INFO: profiling frame 1/0 strobelight_function_profiler, line 247, 2025-02-20 14:08:19,416, INFO: strobelight run id is: 4015948658689996 strobelight_function_profiler, line 249, 2025-02-20 14:08:21,546, INFO: strobelight profiling running strobelight_function_profiler, line 289, 2025-02-20 14:08:25,964, INFO: work function took 4.417063233006047 seconds strobelight_function_profiler, line 230, 2025-02-20 14:08:28,310, INFO: strobelight profiling stopped strobelight_function_profiler, line 221, 2025-02-20 14:08:44,308, INFO: Total samples: 119 strobelight_function_profiler, line 221, 2025-02-20 14:08:44,308, INFO: GraphProfiler (python stack): https://fburl.com/scuba/pyperf_experimental/on_demand/73h2f7ur strobelight_function_profiler, line 221, 2025-02-20 14:08:44,308, INFO: Icicle view (python stack): https://fburl.com/scuba/pyperf_experimental/on_demand/zs06fi9e strobelight_compile_time_profiler, line 167, 2025-02-20 14:08:44,308, INFO: 1 strobelight success runs out of 1 non-recursive compilation events. ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/147549 Approved by: https://github.com/bobrenjc93 ghstack dependencies: #147547
This commit is contained in:
committed by
PyTorch MergeBot
parent
fc095a885c
commit
77d2780657
@ -7,6 +7,12 @@ if __name__ == "__main__":
|
||||
# You can pass TORCH_COMPILE_STROBELIGHT=True instead.
|
||||
StrobelightCompileTimeProfiler.enable()
|
||||
|
||||
# You can use the code below to filter what frames to be profiled.
|
||||
StrobelightCompileTimeProfiler.frame_id_filter = "1/.*"
|
||||
# StrobelightCompileTimeProfiler.frame_id_filter='0/.*'
|
||||
# StrobelightCompileTimeProfiler.frame_id_filter='.*'
|
||||
# You can set env variable COMPILE_STROBELIGHT_FRAME_FILTER to set the filter also.
|
||||
|
||||
def fn(x, y, z):
|
||||
return x * y + z
|
||||
|
||||
@ -18,6 +24,14 @@ if __name__ == "__main__":
|
||||
|
||||
# Strobelight will be called only 3 times because dynamo will be disabled after
|
||||
# 3rd iteration.
|
||||
# Frame 0/0
|
||||
for i in range(3):
|
||||
torch._dynamo.reset()
|
||||
work(i)
|
||||
|
||||
@torch.compile(fullgraph=True)
|
||||
def func4(x):
|
||||
return x * x
|
||||
|
||||
# Frame 1/0
|
||||
func4(torch.rand(10))
|
||||
|
@ -3,6 +3,7 @@
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import subprocess
|
||||
from datetime import datetime
|
||||
from socket import gethostname
|
||||
@ -84,6 +85,10 @@ class StrobelightCompileTimeProfiler:
|
||||
ignored_profile_runs: int = 0
|
||||
inside_profile_compile_time: bool = False
|
||||
enabled: bool = False
|
||||
|
||||
# A regex that can be used to filter out what frames to profile. ex: "1/.*"
|
||||
frame_id_filter: Optional[str] = os.environ.get("COMPILE_STROBELIGHT_FRAME_FILTER")
|
||||
|
||||
# A unique identifier that is used as the run_user_name in the strobelight profile to
|
||||
# associate all compile time profiles together.
|
||||
identifier: Optional[str] = None
|
||||
@ -103,6 +108,12 @@ class StrobelightCompileTimeProfiler:
|
||||
float(os.environ.get("COMPILE_STROBELIGHT_SAMPLE_RATE", 1e7))
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_frame(cls) -> str:
|
||||
from torch._guards import CompileContext
|
||||
|
||||
return (str)(CompileContext.current_trace_id())
|
||||
|
||||
@classmethod
|
||||
def enable(cls, profiler_class: Any = StrobelightCLIFunctionProfiler) -> None:
|
||||
if cls.enabled:
|
||||
@ -164,25 +175,43 @@ class StrobelightCompileTimeProfiler:
|
||||
def profile_compile_time(
|
||||
cls, func: Any, phase_name: str, *args: Any, **kwargs: Any
|
||||
) -> Any:
|
||||
if not cls.enabled:
|
||||
def skip() -> Any:
|
||||
return func(*args, **kwargs)
|
||||
|
||||
if not cls.enabled:
|
||||
return skip()
|
||||
|
||||
if cls.profiler is None:
|
||||
logger.error("profiler is not set")
|
||||
return
|
||||
|
||||
frame_id = cls.get_frame()
|
||||
|
||||
if cls.inside_profile_compile_time:
|
||||
cls.ignored_profile_runs += 1
|
||||
logger.info(
|
||||
"profile_compile_time is requested for phase: %s while already in running phase: %s, recursive call ignored",
|
||||
"profile_compile_time is requested for phase: %s, frame %s, while already in running phase: %s,"
|
||||
"frame %s, recursive call ignored",
|
||||
phase_name,
|
||||
frame_id,
|
||||
cls.current_phase,
|
||||
frame_id,
|
||||
)
|
||||
return func(*args, **kwargs)
|
||||
return skip()
|
||||
|
||||
if cls.frame_id_filter is not None:
|
||||
should_run = re.match(cls.frame_id_filter, frame_id) is not None
|
||||
if not should_run:
|
||||
logger.info(
|
||||
"profiling frame %s is skipped due to frame_id_filter %s",
|
||||
frame_id,
|
||||
cls.frame_id_filter,
|
||||
)
|
||||
return skip()
|
||||
|
||||
cls.inside_profile_compile_time = True
|
||||
cls.current_phase = phase_name
|
||||
|
||||
logger.info("profiling frame %s", frame_id)
|
||||
work_result = cls.profiler.profile(func, *args, **kwargs)
|
||||
|
||||
if cls.profiler.profile_result is not None:
|
||||
|
@ -91,6 +91,8 @@ def compile_time_strobelight_meta(
|
||||
if "skip" in kwargs and isinstance(skip := kwargs["skip"], int):
|
||||
kwargs["skip"] = skip + 1
|
||||
|
||||
# This is not needed but we have it here to avoid having profile_compile_time
|
||||
# in stack traces when profiling is not enabled.
|
||||
if not StrobelightCompileTimeProfiler.enabled:
|
||||
return function(*args, **kwargs)
|
||||
|
||||
|
Reference in New Issue
Block a user