Enable strobelight profiling of specific compile frame ids using COMPILE_STROBELIGHT_FRAME_FILTER (#147549)

Running `python test/strobelight/examples/compile_time_profile_example.py` produces:
```
strobelight_compile_time_profiler, line 123, 2025-02-20 14:08:08,409, INFO: compile time strobelight profiling enabled
strobelight_compile_time_profiler, line 159, 2025-02-20 14:08:08,409, INFO: Unique sample tag for this run is: 2025-02-20-14:08:081656673devgpu005.nha1.facebook.com
strobelight_compile_time_profiler, line 160, 2025-02-20 14:08:09,124, INFO: URL to access the strobelight profile at the end of the run: https://fburl.com/scuba/pyperf_experimental/on_demand/9felqj0i

strobelight_compile_time_profiler, line 205, 2025-02-20 14:08:12,436, INFO: profiling frame 0/0 is skipped due to frame_id_filter 1/.*
strobelight_compile_time_profiler, line 205, 2025-02-20 14:08:15,553, INFO: profiling frame 0/0 is skipped due to frame_id_filter 1/.*
strobelight_compile_time_profiler, line 205, 2025-02-20 14:08:16,170, INFO: profiling frame 0/0 is skipped due to frame_id_filter 1/.*
strobelight_compile_time_profiler, line 214, 2025-02-20 14:08:16,877, INFO: profiling frame 1/0
strobelight_function_profiler, line 247, 2025-02-20 14:08:19,416, INFO: strobelight run id is: 4015948658689996
strobelight_function_profiler, line 249, 2025-02-20 14:08:21,546, INFO: strobelight profiling running
strobelight_function_profiler, line 289, 2025-02-20 14:08:25,964, INFO: work function took 4.417063233006047 seconds
strobelight_function_profiler, line 230, 2025-02-20 14:08:28,310, INFO: strobelight profiling stopped
strobelight_function_profiler, line 221, 2025-02-20 14:08:44,308, INFO: Total samples: 119
strobelight_function_profiler, line 221, 2025-02-20 14:08:44,308, INFO: GraphProfiler (python stack): https://fburl.com/scuba/pyperf_experimental/on_demand/73h2f7ur
strobelight_function_profiler, line 221, 2025-02-20 14:08:44,308, INFO: Icicle view (python stack): https://fburl.com/scuba/pyperf_experimental/on_demand/zs06fi9e
strobelight_compile_time_profiler, line 167, 2025-02-20 14:08:44,308, INFO: 1 strobelight success runs out of 1 non-recursive compilation events.

```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/147549
Approved by: https://github.com/bobrenjc93
ghstack dependencies: #147547
This commit is contained in:
Laith Sakka
2025-02-20 14:33:18 -08:00
committed by PyTorch MergeBot
parent fc095a885c
commit 77d2780657
3 changed files with 49 additions and 4 deletions

View File

@ -7,6 +7,12 @@ if __name__ == "__main__":
# You can pass TORCH_COMPILE_STROBELIGHT=True instead.
StrobelightCompileTimeProfiler.enable()
# You can use the code below to select which frames are profiled.
StrobelightCompileTimeProfiler.frame_id_filter = "1/.*"
# StrobelightCompileTimeProfiler.frame_id_filter='0/.*'
# StrobelightCompileTimeProfiler.frame_id_filter='.*'
# You can also set the filter through the COMPILE_STROBELIGHT_FRAME_FILTER environment variable.
def fn(x, y, z):
    """Return ``x * y + z``: multiply the first two arguments, then add the third."""
    product = x * y
    return product + z
@ -18,6 +24,14 @@ if __name__ == "__main__":
# Strobelight will be called only 3 times because dynamo will be disabled after
# 3rd iteration.
# Frame 0/0
for i in range(3):
torch._dynamo.reset()
work(i)
# Second compiled function of the example: squares its input element-wise.
# It is compiled with fullgraph=True and corresponds to the "Frame 1/0"
# call below.
@torch.compile(fullgraph=True)
def func4(x):
    return x * x
# Frame 1/0
func4(torch.rand(10))

View File

@ -3,6 +3,7 @@
import json
import logging
import os
import re
import subprocess
from datetime import datetime
from socket import gethostname
@ -84,6 +85,10 @@ class StrobelightCompileTimeProfiler:
ignored_profile_runs: int = 0
inside_profile_compile_time: bool = False
enabled: bool = False
# A regex used to select which frames are profiled; frames whose id does not match are skipped. ex: "1/.*"
frame_id_filter: Optional[str] = os.environ.get("COMPILE_STROBELIGHT_FRAME_FILTER")
# A unique identifier that is used as the run_user_name in the strobelight profile to
# associate all compile time profiles together.
identifier: Optional[str] = None
@ -103,6 +108,12 @@ class StrobelightCompileTimeProfiler:
float(os.environ.get("COMPILE_STROBELIGHT_SAMPLE_RATE", 1e7))
)
@classmethod
def get_frame(cls) -> str:
from torch._guards import CompileContext
return (str)(CompileContext.current_trace_id())
@classmethod
def enable(cls, profiler_class: Any = StrobelightCLIFunctionProfiler) -> None:
if cls.enabled:
@ -164,25 +175,43 @@ class StrobelightCompileTimeProfiler:
def profile_compile_time(
cls, func: Any, phase_name: str, *args: Any, **kwargs: Any
) -> Any:
if not cls.enabled:
def skip() -> Any:
return func(*args, **kwargs)
if not cls.enabled:
return skip()
if cls.profiler is None:
logger.error("profiler is not set")
return
frame_id = cls.get_frame()
if cls.inside_profile_compile_time:
cls.ignored_profile_runs += 1
logger.info(
"profile_compile_time is requested for phase: %s while already in running phase: %s, recursive call ignored",
"profile_compile_time is requested for phase: %s, frame %s, while already in running phase: %s,"
"frame %s, recursive call ignored",
phase_name,
frame_id,
cls.current_phase,
frame_id,
)
return func(*args, **kwargs)
return skip()
if cls.frame_id_filter is not None:
should_run = re.match(cls.frame_id_filter, frame_id) is not None
if not should_run:
logger.info(
"profiling frame %s is skipped due to frame_id_filter %s",
frame_id,
cls.frame_id_filter,
)
return skip()
cls.inside_profile_compile_time = True
cls.current_phase = phase_name
logger.info("profiling frame %s", frame_id)
work_result = cls.profiler.profile(func, *args, **kwargs)
if cls.profiler.profile_result is not None:

View File

@ -91,6 +91,8 @@ def compile_time_strobelight_meta(
if "skip" in kwargs and isinstance(skip := kwargs["skip"], int):
kwargs["skip"] = skip + 1
# This is not needed but we have it here to avoid having profile_compile_time
# in stack traces when profiling is not enabled.
if not StrobelightCompileTimeProfiler.enabled:
return function(*args, **kwargs)