diff --git a/test/dynamo/test_structured_trace.py b/test/dynamo/test_structured_trace.py
index fe5d099e7334..dd7dec83c95a 100644
--- a/test/dynamo/test_structured_trace.py
+++ b/test/dynamo/test_structured_trace.py
@@ -260,6 +260,7 @@ class StructuredTraceTest(TestCase):
 {"artifact": {"name": "aotautograd_cache_miss", "encoding": "json"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
 {"dynamo_cpp_guards_str": {}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
 {"compilation_metrics": "METRICS", "frame_id": 0, "frame_compile_id": 0, "attempt": 0}
+{"compilation_metrics_runtime": "METRICS", "frame_id": 0, "frame_compile_id": 0}
 """,  # noqa: B950
         )
diff --git a/torch/_dynamo/utils.py b/torch/_dynamo/utils.py
index 7ebaa8dcaac8..fd89d63fbc84 100644
--- a/torch/_dynamo/utils.py
+++ b/torch/_dynamo/utils.py
@@ -596,7 +596,7 @@ def dynamo_timed(
     dynamo_compile_column_us: Optional[str] = None,
     dynamo_compile_runtime_column_us: Optional[str] = None,
     compile_id: Optional[CompileId] = None,
-    is_forward: Optional[bool] = None,
+    is_backward: Optional[bool] = None,
     log_waitcounter: bool = False,
 ) -> Generator[Any, None, None]:
     """
@@ -638,7 +638,7 @@ def dynamo_timed(
     - compile_id: In the typical case, this parameter should not be needed. Use to
       supply the compile_id for those cases where we want to log a compile_id where
       it's not naturally available, e.g., for runtime autotuning.
-    - is_forward: Optionally set an is_forward field for those logging destinations
-      that support it.
+    - is_backward: Specify forward/backward directly when not available in a
+      CompileContext, e.g., during runtime autotuning.
     - log_waitcounter: If set, we'll log a waitcounter of the form "pytorch.dynamo_timed.{key}"
     """
@@ -664,8 +665,8 @@
         event_metadata.update(metadata)
         if fn_name:
             event_metadata.update({"fn_name": fn_name})
-        if is_forward is not None:
-            event_metadata.update({"is_backward": not is_forward})
+        if is_backward is not None:
+            event_metadata.update({"is_backward": is_backward})

     chromium_log: ChromiumEventLogger = get_chromium_event_logger()
     start_ns = time.time_ns()
@@ -707,7 +708,7 @@
                 extra={
                     "compile_id": compile_id,
                     "is_runtime": True,
-                    "is_forward": is_forward,
+                    "is_forward": not is_backward,
                 },
             )
             cumulative_time_spent_ns[event_name] += time_spent_ns
diff --git a/torch/_inductor/async_compile.py b/torch/_inductor/async_compile.py
index 293151a907da..69cb3c37fddb 100644
--- a/torch/_inductor/async_compile.py
+++ b/torch/_inductor/async_compile.py
@@ -41,6 +41,7 @@ from torch._inductor.runtime.compile_tasks import (
     _worker_compile_triton,
 )
 from torch._inductor.utils import clear_on_fresh_inductor_cache
+from torch._inductor.virtualized import V
 from torch.hub import _Faketqdm, tqdm
 from torch.utils._ordered_set import OrderedSet
 from torch.utils._triton import has_triton_package
@@ -300,6 +301,10 @@ class AsyncCompile:
         )
         is_parallel = self.use_process_pool()
         set_feature_use("parallel_compile_post_warmup", is_parallel)
+
+        compile_id = torch._guards.CompileContext.current_compile_id()
+        is_backward = getattr(V.graph, "is_backward", False)
+
         if is_parallel:
             # We want to support changing these env vars after (and while) the
             # process pool is running, so pass them to the subprocess to reset.
@@ -322,6 +327,7 @@ class AsyncCompile:
                 # Now that we've compiled, we should clear the future
                 # so it can't be used again
                 CompiledTritonKernels.remove_future(source_code)
+                kernel.set_compile_info(compile_id, is_backward)
                 kernel.precompile(
                     warm_cache_only=False, reload_kernel=reload_kernel_in_parent
                 )
@@ -343,6 +349,7 @@ class AsyncCompile:
             start_ns = time_ns()
             _set_triton_ptxas_path()
             kernel = load_kernel()
+            kernel.set_compile_info(compile_id, is_backward)
             kernel.precompile(warm_cache_only=False)
             elapsed_us = (time_ns() - start_ns) // 1000
             get_metrics_context().add_top_n(
diff --git a/torch/_inductor/runtime/triton_heuristics.py b/torch/_inductor/runtime/triton_heuristics.py
index 5bdb21939f17..ced8b95ca2be 100644
--- a/torch/_inductor/runtime/triton_heuristics.py
+++ b/torch/_inductor/runtime/triton_heuristics.py
@@ -75,6 +75,8 @@ class NoTritonConfigsError(RuntimeError):
 if TYPE_CHECKING:
     from collections.abc import Container, Hashable, Sequence

+    from torch._guards import CompileId
+
     LauncherType = Any

@@ -258,6 +260,16 @@ class CachingAutotuner(KernelInterface):

         self.triton_interpret = os.environ.get("TRITON_INTERPRET", "0") == "1"

+        # Compile-time info included in runtime logging
+        self.compile_id: Optional[CompileId] = None
+        self.is_backward = False
+
+    def set_compile_info(
+        self, compile_id: Optional[CompileId], is_backward: bool
+    ) -> None:
+        self.compile_id = compile_id
+        self.is_backward = is_backward
+
     def precompile(
         self,
         warm_cache_only=False,
@@ -731,8 +743,9 @@ class CachingAutotuner(KernelInterface):
             "CachingAutotuner.benchmark_all_configs",
             log_pt2_compile_event=True,
             metadata={"kernel_name": self.inductor_meta.get("kernel_name")},
-            # TODO(masnesral): Enable this when we figure out how to get the CompileId:
-            # dynamo_compile_runtime_column_us="runtime_triton_autotune_time_us",
+            dynamo_compile_runtime_column_us="runtime_triton_autotune_time_us",
+            compile_id=self.compile_id,
+            is_backward=self.is_backward,
         ):
             timings = {
                 launcher: self.bench(launcher, *args, **kwargs)
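Taken together, the diff captures the CompileId and forward/backward flag at compile time (in AsyncCompile, where a CompileContext is still current) and stashes them on the autotuner via set_compile_info, so that autotuning performed later at run time can be timed and attributed to the right compile. Below is a minimal sketch of that pattern, not the torch internals: ToyAutotuner, timed, and RUNTIME_METRICS are hypothetical stand-ins.

```python
import time
from contextlib import contextmanager
from typing import Any, Dict, List, Optional

# Hypothetical metrics sink standing in for dynamo's runtime compile metrics.
RUNTIME_METRICS: List[Dict[str, Any]] = []


@contextmanager
def timed(event: str, compile_id: Optional[str], is_backward: bool):
    # Record wall time for a runtime event, tagged with compile-time info.
    start = time.time_ns()
    try:
        yield
    finally:
        RUNTIME_METRICS.append(
            {
                "event": event,
                "compile_id": compile_id,
                "is_backward": is_backward,
                "time_us": (time.time_ns() - start) // 1000,
            }
        )


class ToyAutotuner:
    def __init__(self) -> None:
        # Filled in at compile time; consumed when the kernel is autotuned later.
        self.compile_id: Optional[str] = None
        self.is_backward = False

    def set_compile_info(self, compile_id: Optional[str], is_backward: bool) -> None:
        self.compile_id = compile_id
        self.is_backward = is_backward

    def benchmark_all_configs(self) -> None:
        # At run time, attach the cached compile-time identity to the timing log.
        with timed("triton_autotune", self.compile_id, self.is_backward):
            time.sleep(0.001)  # stand-in for benchmarking candidate configs


tuner = ToyAutotuner()
tuner.set_compile_info("0/0", is_backward=False)  # compile time
tuner.benchmark_all_configs()                     # run time
print(RUNTIME_METRICS)
```

The design choice mirrored here is that the autotuner object itself carries the compile-time identity rather than looking it up when benchmarking, since autotuning may run long after compilation, when no CompileContext is available (the situation the removed TODO in benchmark_all_configs referred to).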