mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
[logging] Set compile_id in the CachingAutotuner during compilation so we have it for dynamo_timed logging (#148693)
Summary: This is a simpler alternative to https://github.com/pytorch/pytorch/pull/146455, where we can stick the compileId (and forward/backward bool) in the CachingAutotuner so that we have it for logging `benchmark_all_configs`. Recall that the first attempt put the compileId in the inductor_meta and that interfered with caching. Test Plan: `python benchmarks/dynamo/torchbench.py --performance --training --amp --backend inductor --device cuda --print-compilation-time --repeat 5 --cold-start-latency --only nanogpt` * tlparse: https://fburl.com/e71yn6uc * dynamo_compile: https://fburl.com/scuba/dynamo_compile/sandbox/4ageghhv * pt2_compile_events: https://fburl.com/scuba/pt2_compile_events/4fgv1itq Pull Request resolved: https://github.com/pytorch/pytorch/pull/148693 Approved by: https://github.com/eellison
This commit is contained in:
committed by
PyTorch MergeBot
parent
5b8da17681
commit
73c8068cf8
@ -260,6 +260,7 @@ class StructuredTraceTest(TestCase):
|
||||
{"artifact": {"name": "aotautograd_cache_miss", "encoding": "json"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
|
||||
{"dynamo_cpp_guards_str": {}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
|
||||
{"compilation_metrics": "METRICS", "frame_id": 0, "frame_compile_id": 0, "attempt": 0}
|
||||
{"compilation_metrics_runtime": "METRICS", "frame_id": 0, "frame_compile_id": 0}
|
||||
""", # noqa: B950
|
||||
)
|
||||
|
||||
|
@ -596,7 +596,7 @@ def dynamo_timed(
|
||||
dynamo_compile_column_us: Optional[str] = None,
|
||||
dynamo_compile_runtime_column_us: Optional[str] = None,
|
||||
compile_id: Optional[CompileId] = None,
|
||||
is_forward: Optional[bool] = None,
|
||||
is_backward: Optional[bool] = None,
|
||||
log_waitcounter: bool = False,
|
||||
) -> Generator[Any, None, None]:
|
||||
"""
|
||||
@ -638,7 +638,8 @@ def dynamo_timed(
|
||||
- compile_id: In the typical case, this parameter should not be needed. Use to
|
||||
supply the compile_id for those cases where we want to log a compile_id where
|
||||
it's not naturally available, e.g., for runtime autotuning.
|
||||
- is_forward: Optionally set an is_forward field for those logging destinations
|
||||
- is_backward: Specify forward/backward directly when not available in a
|
||||
CompileContext, e.g., during runtime autotuning.
|
||||
that support it.
|
||||
- log_waitcounter: If set, we'll log a waitcounter of the form "pytorch.dynamo_timed.{key}"
|
||||
"""
|
||||
@ -664,8 +665,8 @@ def dynamo_timed(
|
||||
event_metadata.update(metadata)
|
||||
if fn_name:
|
||||
event_metadata.update({"fn_name": fn_name})
|
||||
if is_forward is not None:
|
||||
event_metadata.update({"is_backward": not is_forward})
|
||||
if is_backward is not None:
|
||||
event_metadata.update({"is_backward": is_backward})
|
||||
|
||||
chromium_log: ChromiumEventLogger = get_chromium_event_logger()
|
||||
start_ns = time.time_ns()
|
||||
@ -707,7 +708,7 @@ def dynamo_timed(
|
||||
extra={
|
||||
"compile_id": compile_id,
|
||||
"is_runtime": True,
|
||||
"is_forward": is_forward,
|
||||
"is_forward": not is_backward,
|
||||
},
|
||||
)
|
||||
cumulative_time_spent_ns[event_name] += time_spent_ns
|
||||
|
@ -41,6 +41,7 @@ from torch._inductor.runtime.compile_tasks import (
|
||||
_worker_compile_triton,
|
||||
)
|
||||
from torch._inductor.utils import clear_on_fresh_inductor_cache
|
||||
from torch._inductor.virtualized import V
|
||||
from torch.hub import _Faketqdm, tqdm
|
||||
from torch.utils._ordered_set import OrderedSet
|
||||
from torch.utils._triton import has_triton_package
|
||||
@ -300,6 +301,10 @@ class AsyncCompile:
|
||||
)
|
||||
is_parallel = self.use_process_pool()
|
||||
set_feature_use("parallel_compile_post_warmup", is_parallel)
|
||||
|
||||
compile_id = torch._guards.CompileContext.current_compile_id()
|
||||
is_backward = getattr(V.graph, "is_backward", False)
|
||||
|
||||
if is_parallel:
|
||||
# We want to support changing these env vars after (and while) the
|
||||
# process pool is running, so pass them to the subprocess to reset.
|
||||
@ -322,6 +327,7 @@ class AsyncCompile:
|
||||
# Now that we've compiled, we should clear the future
|
||||
# so it can't be used again
|
||||
CompiledTritonKernels.remove_future(source_code)
|
||||
kernel.set_compile_info(compile_id, is_backward)
|
||||
kernel.precompile(
|
||||
warm_cache_only=False, reload_kernel=reload_kernel_in_parent
|
||||
)
|
||||
@ -343,6 +349,7 @@ class AsyncCompile:
|
||||
start_ns = time_ns()
|
||||
_set_triton_ptxas_path()
|
||||
kernel = load_kernel()
|
||||
kernel.set_compile_info(compile_id, is_backward)
|
||||
kernel.precompile(warm_cache_only=False)
|
||||
elapsed_us = (time_ns() - start_ns) // 1000
|
||||
get_metrics_context().add_top_n(
|
||||
|
@ -75,6 +75,8 @@ class NoTritonConfigsError(RuntimeError):
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Container, Hashable, Sequence
|
||||
|
||||
from torch._guards import CompileId
|
||||
|
||||
LauncherType = Any
|
||||
|
||||
|
||||
@ -258,6 +260,16 @@ class CachingAutotuner(KernelInterface):
|
||||
|
||||
self.triton_interpret = os.environ.get("TRITON_INTERPRET", "0") == "1"
|
||||
|
||||
# Compile-time info included in runtime logginging
|
||||
self.compile_id: Optional[CompileId] = None
|
||||
self.is_backward = False
|
||||
|
||||
def set_compile_info(
|
||||
self, compile_id: Optional[CompileId], is_backward: bool
|
||||
) -> None:
|
||||
self.compile_id = compile_id
|
||||
self.is_backward = is_backward
|
||||
|
||||
def precompile(
|
||||
self,
|
||||
warm_cache_only=False,
|
||||
@ -731,8 +743,9 @@ class CachingAutotuner(KernelInterface):
|
||||
"CachingAutotuner.benchmark_all_configs",
|
||||
log_pt2_compile_event=True,
|
||||
metadata={"kernel_name": self.inductor_meta.get("kernel_name")},
|
||||
# TODO(masnesral): Enable this when we figure out how to get the CompileId:
|
||||
# dynamo_compile_runtime_column_us="runtime_triton_autotune_time_us",
|
||||
dynamo_compile_runtime_column_us="runtime_triton_autotune_time_us",
|
||||
compile_id=self.compile_id,
|
||||
is_backward=self.is_backward,
|
||||
):
|
||||
timings = {
|
||||
launcher: self.bench(launcher, *args, **kwargs)
|
||||
|
Reference in New Issue
Block a user