diff --git a/test/strobelight/examples/compile_time_profile_example.py b/test/strobelight/examples/compile_time_profile_example.py
index 7e46ea62647a..d442ef1d5043 100644
--- a/test/strobelight/examples/compile_time_profile_example.py
+++ b/test/strobelight/examples/compile_time_profile_example.py
@@ -7,6 +7,12 @@ if __name__ == "__main__":
     # You can pass TORCH_COMPILE_STROBELIGHT=True instead.
     StrobelightCompileTimeProfiler.enable()
 
+    # Use the setting below to select which frames are profiled.
+    StrobelightCompileTimeProfiler.frame_id_filter = "1/.*"
+    # StrobelightCompileTimeProfiler.frame_id_filter='0/.*'
+    # StrobelightCompileTimeProfiler.frame_id_filter='.*'
+    # The filter can also be set via the COMPILE_STROBELIGHT_FRAME_FILTER env variable.
+
     def fn(x, y, z):
         return x * y + z
 
@@ -18,6 +24,14 @@ if __name__ == "__main__":
 
     # Strobelight will be called only 3 times because dynamo will be disabled after
     # 3rd iteration.
+    # Frame 0/0
     for i in range(3):
         torch._dynamo.reset()
         work(i)
+
+    @torch.compile(fullgraph=True)
+    def func4(x):
+        return x * x
+
+    # Frame 1/0
+    func4(torch.rand(10))
diff --git a/torch/_strobelight/compile_time_profiler.py b/torch/_strobelight/compile_time_profiler.py
index 81ebef2df6b1..2677b75cbbe0 100644
--- a/torch/_strobelight/compile_time_profiler.py
+++ b/torch/_strobelight/compile_time_profiler.py
@@ -3,6 +3,7 @@
 import json
 import logging
 import os
+import re
 import subprocess
 from datetime import datetime
 from socket import gethostname
@@ -84,6 +85,10 @@ class StrobelightCompileTimeProfiler:
     ignored_profile_runs: int = 0
     inside_profile_compile_time: bool = False
     enabled: bool = False
+
+    # A regex selecting which frames to profile (applied with re.match). ex: "1/.*"
+    frame_id_filter: Optional[str] = os.environ.get("COMPILE_STROBELIGHT_FRAME_FILTER")
+
     # A unique identifier that is used as the run_user_name in the strobelight profile to
     # associate all compile time profiles together.
     identifier: Optional[str] = None
@@ -103,6 +108,13 @@ class StrobelightCompileTimeProfiler:
         float(os.environ.get("COMPILE_STROBELIGHT_SAMPLE_RATE", 1e7))
     )
 
+    @classmethod
+    def get_frame(cls) -> str:
+        """Return the current CompileContext trace id as a string."""
+        from torch._guards import CompileContext
+
+        return str(CompileContext.current_trace_id())
+
     @classmethod
     def enable(cls, profiler_class: Any = StrobelightCLIFunctionProfiler) -> None:
         if cls.enabled:
@@ -164,25 +176,43 @@ class StrobelightCompileTimeProfiler:
     def profile_compile_time(
         cls, func: Any, phase_name: str, *args: Any, **kwargs: Any
     ) -> Any:
-        if not cls.enabled:
+        def skip() -> Any:
             return func(*args, **kwargs)
 
+        if not cls.enabled:
+            return skip()
+
         if cls.profiler is None:
             logger.error("profiler is not set")
             return
 
+        frame_id = cls.get_frame()
+
         if cls.inside_profile_compile_time:
             cls.ignored_profile_runs += 1
             logger.info(
-                "profile_compile_time is requested for phase: %s while already in running phase: %s, recursive call ignored",
+                "profile_compile_time is requested for phase: %s, frame %s, while already in running phase: %s, "
+                "frame %s, recursive call ignored",
                 phase_name,
+                frame_id,
                 cls.current_phase,
+                frame_id,
             )
-            return func(*args, **kwargs)
+            return skip()
+
+        if cls.frame_id_filter is not None:
+            should_run = re.match(cls.frame_id_filter, frame_id) is not None
+            if not should_run:
+                logger.info(
+                    "profiling frame %s is skipped due to frame_id_filter %s",
+                    frame_id,
+                    cls.frame_id_filter,
+                )
+                return skip()
 
         cls.inside_profile_compile_time = True
         cls.current_phase = phase_name
-
+        logger.info("profiling frame %s", frame_id)
         work_result = cls.profiler.profile(func, *args, **kwargs)
 
         if cls.profiler.profile_result is not None:
diff --git a/torch/_utils_internal.py b/torch/_utils_internal.py
index 14b6df2ce94e..c6788e44bbad 100644
--- a/torch/_utils_internal.py
+++ b/torch/_utils_internal.py
@@ -91,6 +91,8 @@ def compile_time_strobelight_meta(
             if "skip" in kwargs and isinstance(skip := kwargs["skip"], int):
                 kwargs["skip"] = skip + 1
 
+            # Redundant with the check inside profile_compile_time, but returning early
+            # keeps that frame out of stack traces when profiling is not enabled.
             if not StrobelightCompileTimeProfiler.enabled:
                 return function(*args, **kwargs)
 