[logging] Set compile_id in the CachingAutotuner during compilation so we have it for dynamo_timed logging (#148693)

Summary: This is a simpler alternative to https://github.com/pytorch/pytorch/pull/146455: store the compileId (and a forward/backward bool) on the CachingAutotuner at compile time so it's available when logging `benchmark_all_configs`. Recall that the first attempt put the compileId in the inductor_meta, and that interfered with caching.
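
At a high level: `AsyncCompile` grabs the current `CompileId` and the forward/backward flag while the `CompileContext` is still alive, stashes them on the kernel via `set_compile_info`, and `benchmark_all_configs` later hands them to `dynamo_timed`. Here's a minimal, self-contained sketch of that pattern; the `CachingAutotunerSketch` class and `log_runtime_event` helper are illustrative stand-ins (only `set_compile_info` mirrors the method added in this diff):

```python
from typing import Optional


class CachingAutotunerSketch:
    """Toy stand-in for CachingAutotuner; only models how compile-time
    identifiers are stashed for later runtime logging."""

    def __init__(self) -> None:
        # Defaults cover kernels whose compile context isn't known
        # (e.g., loaded from a cache).
        self.compile_id: Optional[str] = None
        self.is_backward: bool = False

    def set_compile_info(self, compile_id: Optional[str], is_backward: bool) -> None:
        # Called at compile time, while the compile id is naturally available.
        self.compile_id = compile_id
        self.is_backward = is_backward

    def benchmark_all_configs(self) -> None:
        # At runtime autotuning there is no CompileContext to query, so the
        # stashed values are passed straight to the logging helper.
        log_runtime_event(
            "CachingAutotuner.benchmark_all_configs",
            compile_id=self.compile_id,
            is_backward=self.is_backward,
        )


def log_runtime_event(name: str, *, compile_id: Optional[str], is_backward: bool) -> None:
    # Illustrative stand-in for what dynamo_timed would log for a runtime event.
    print(f"{name}: compile_id={compile_id}, is_forward={not is_backward}")


# Compile time: capture the id while the CompileContext exists.
kernel = CachingAutotunerSketch()
kernel.set_compile_info("0/0", is_backward=False)
# Runtime: the id travels with the kernel object into autotuning.
kernel.benchmark_all_configs()
```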

Test Plan:
`python benchmarks/dynamo/torchbench.py --performance --training --amp --backend inductor --device cuda --print-compilation-time --repeat 5 --cold-start-latency --only nanogpt`
* tlparse: https://fburl.com/e71yn6uc
* dynamo_compile: https://fburl.com/scuba/dynamo_compile/sandbox/4ageghhv
* pt2_compile_events: https://fburl.com/scuba/pt2_compile_events/4fgv1itq

Pull Request resolved: https://github.com/pytorch/pytorch/pull/148693
Approved by: https://github.com/eellison
Author: Sam Larsen
Committed by: PyTorch MergeBot
Date: 2025-03-06 18:36:52 -08:00
Parent: 5b8da17681
Commit: 73c8068cf8
4 changed files with 29 additions and 7 deletions


@@ -260,6 +260,7 @@ class StructuredTraceTest(TestCase):
{"artifact": {"name": "aotautograd_cache_miss", "encoding": "json"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"dynamo_cpp_guards_str": {}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"compilation_metrics": "METRICS", "frame_id": 0, "frame_compile_id": 0, "attempt": 0}
{"compilation_metrics_runtime": "METRICS", "frame_id": 0, "frame_compile_id": 0}
""", # noqa: B950
)


@@ -596,7 +596,7 @@ def dynamo_timed(
dynamo_compile_column_us: Optional[str] = None,
dynamo_compile_runtime_column_us: Optional[str] = None,
compile_id: Optional[CompileId] = None,
-is_forward: Optional[bool] = None,
+is_backward: Optional[bool] = None,
log_waitcounter: bool = False,
) -> Generator[Any, None, None]:
"""
@@ -638,7 +638,8 @@ def dynamo_timed(
- compile_id: In the typical case, this parameter should not be needed. Use to
supply the compile_id for those cases where we want to log a compile_id where
it's not naturally available, e.g., for runtime autotuning.
-- is_forward: Optionally set an is_forward field for those logging destinations
-  that support it.
+- is_backward: Specify forward/backward directly when not available in a
+  CompileContext, e.g., during runtime autotuning.
- log_waitcounter: If set, we'll log a waitcounter of the form "pytorch.dynamo_timed.{key}"
"""
@@ -664,8 +665,8 @@ def dynamo_timed(
event_metadata.update(metadata)
if fn_name:
event_metadata.update({"fn_name": fn_name})
-if is_forward is not None:
-event_metadata.update({"is_backward": not is_forward})
+if is_backward is not None:
+event_metadata.update({"is_backward": is_backward})
chromium_log: ChromiumEventLogger = get_chromium_event_logger()
start_ns = time.time_ns()
@@ -707,7 +708,7 @@ def dynamo_timed(
extra={
"compile_id": compile_id,
"is_runtime": True,
"is_forward": is_forward,
"is_forward": not is_backward,
},
)
cumulative_time_spent_ns[event_name] += time_spent_ns


@@ -41,6 +41,7 @@ from torch._inductor.runtime.compile_tasks import (
_worker_compile_triton,
)
from torch._inductor.utils import clear_on_fresh_inductor_cache
+from torch._inductor.virtualized import V
from torch.hub import _Faketqdm, tqdm
from torch.utils._ordered_set import OrderedSet
from torch.utils._triton import has_triton_package
@@ -300,6 +301,10 @@ class AsyncCompile:
)
is_parallel = self.use_process_pool()
set_feature_use("parallel_compile_post_warmup", is_parallel)
+compile_id = torch._guards.CompileContext.current_compile_id()
+is_backward = getattr(V.graph, "is_backward", False)
if is_parallel:
# We want to support changing these env vars after (and while) the
# process pool is running, so pass them to the subprocess to reset.
@@ -322,6 +327,7 @@
# Now that we've compiled, we should clear the future
# so it can't be used again
CompiledTritonKernels.remove_future(source_code)
+kernel.set_compile_info(compile_id, is_backward)
kernel.precompile(
warm_cache_only=False, reload_kernel=reload_kernel_in_parent
)
@@ -343,6 +349,7 @@
start_ns = time_ns()
_set_triton_ptxas_path()
kernel = load_kernel()
+kernel.set_compile_info(compile_id, is_backward)
kernel.precompile(warm_cache_only=False)
elapsed_us = (time_ns() - start_ns) // 1000
get_metrics_context().add_top_n(


@@ -75,6 +75,8 @@ class NoTritonConfigsError(RuntimeError):
if TYPE_CHECKING:
from collections.abc import Container, Hashable, Sequence
+from torch._guards import CompileId
LauncherType = Any
@@ -258,6 +260,16 @@ class CachingAutotuner(KernelInterface):
self.triton_interpret = os.environ.get("TRITON_INTERPRET", "0") == "1"
+# Compile-time info included in runtime logging
+self.compile_id: Optional[CompileId] = None
+self.is_backward = False
+
+def set_compile_info(
+self, compile_id: Optional[CompileId], is_backward: bool
+) -> None:
+self.compile_id = compile_id
+self.is_backward = is_backward
def precompile(
self,
warm_cache_only=False,
@@ -731,8 +743,9 @@
"CachingAutotuner.benchmark_all_configs",
log_pt2_compile_event=True,
metadata={"kernel_name": self.inductor_meta.get("kernel_name")},
-# TODO(masnesral): Enable this when we figure out how to get the CompileId:
-# dynamo_compile_runtime_column_us="runtime_triton_autotune_time_us",
+dynamo_compile_runtime_column_us="runtime_triton_autotune_time_us",
+compile_id=self.compile_id,
+is_backward=self.is_backward,
):
timings = {
launcher: self.bench(launcher, *args, **kwargs)