Log autotune choices and benchmark result to scuba/chrome trace (#159496)

Summary: Report the kernel choices and benchmark data to better understand how kernels are selected and the performance gap between the best kernel (likely a CUDA kernel) and Triton kernels.

**Example**

Event: mm_template_autotuning
Column: autotune_choices

```json
{
  "num_choices": 52,
  "num_triton_choices": 19,
  "best_kernel": "cutlass_f6c25cf2",
  "best_kernel_desc": "cutlass3x_sm90_tensorop_gemm_f16_f16_f32_void_f16_128x256x64_2x1x1_0_tnn_align8_stream_k_warpspecialized_cooperative_epi_tma swizzle=8",
  "best_time": 0.6283040046691895,
  "best_triton_pos": 26,
  "best_triton_time": 0.6832960247993469,
  "best_triton_kernel": "triton_mm_17",
  "best_triton_kernel_desc": "ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=128, BLOCK_N=128, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=3, num_warps=4, num_consumer_groups=0, num_buffers_warp_spec=0"
}
```

Test Plan:
```
TORCHINDUCTOR_MAX_AUTOTUNE_REPORT_CHOICES_STATS=1 buck2 run //scripts/wychi:test_autotune_mm 2>&1 > /tmp/mylog.txt
```

Rollback Plan:

Differential Revision: D79235037

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159496
Approved by: https://github.com/masnesral
Committed by: PyTorch MergeBot
Parent: fd6a6658c3
Commit: b599d91738
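For orientation, here is a minimal sketch of how the new reporting could be exercised outside of the Buck test in the Test Plan. It assumes a CUDA build with Triton available; the toy matmul and shapes are illustrative, while the `TORCHINDUCTOR_MAX_AUTOTUNE_REPORT_CHOICES_STATS` toggle, the `mm_template_autotuning` event name, and the stderr line come from this commit.

```python
# Sketch only: trigger GEMM template autotuning so the new stats are emitted.
# Assumes a CUDA device with Triton; the model and shapes are illustrative.
import os

# Enable the reporting added by this PR before compiling anything.
os.environ["TORCHINDUCTOR_MAX_AUTOTUNE_REPORT_CHOICES_STATS"] = "1"

import torch


def mm(a, b):
    return a @ b


compiled = torch.compile(mm, mode="max-autotune")  # max-autotune enables template autotuning
a = torch.randn(2048, 2048, device="cuda", dtype=torch.float16)
b = torch.randn(2048, 2048, device="cuda", dtype=torch.float16)
compiled(a, b)
# stderr now contains "Autotune Choices Stats:\n{...}", and the same JSON payload
# is attached to the mm_template_autotuning event via the chromium event logger.
```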
@@ -27,7 +27,13 @@ import torch
 import torch._inductor.async_compile  # noqa: F401 required to warm up AsyncCompile pools
 from torch._dynamo.device_interface import get_interface_for_device
 from torch._dynamo.testing import rand_strided
-from torch._dynamo.utils import counters, dynamo_timed, identity, preserve_rng_state
+from torch._dynamo.utils import (
+    counters,
+    dynamo_timed,
+    get_chromium_event_logger,
+    identity,
+    preserve_rng_state,
+)
 from torch._inductor.utils import clear_on_fresh_cache
 from torch.utils._filelock import FileLock
 from torch.utils._ordered_set import OrderedSet
@@ -2339,7 +2345,12 @@ class AlgorithmSelectorCache(PersistentCache):
                 dynamo_compile_column_us="compile_time_autotune_time_us",
                 metadata=_autotune_metadata(input_nodes),
             ):
-                return benchmark(choices, hint_override=hint_override)
+                benchmark_results = benchmark(choices, hint_override=hint_override)
+                if config.max_autotune_report_choices_stats:
+                    _log_autotune_choices_stats(
+                        f"{name}_template_autotuning", benchmark_results
+                    )
+                return benchmark_results

         if config.autotune_in_subproc:
             # Initialize the suprocess pool so it will warmup early.
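The wrapper above now keeps the benchmark results and, when `config.max_autotune_report_choices_stats` is set, hands them to `_log_autotune_choices_stats` (added in the next hunk). As a hedged sketch, the flag can presumably also be flipped programmatically through Inductor's config patching instead of the environment variable; `max_autotune=True` here only serves to force template autotuning:

```python
# Sketch only: opting into the per-choice stats from user code instead of the
# TORCHINDUCTOR_MAX_AUTOTUNE_REPORT_CHOICES_STATS environment variable.
import torch
import torch._inductor.config as inductor_config

a = torch.randn(1024, 1024, device="cuda", dtype=torch.float16)
b = torch.randn(1024, 1024, device="cuda", dtype=torch.float16)

with inductor_config.patch(max_autotune=True, max_autotune_report_choices_stats=True):
    # Compilation (and hence autotuning) happens on the first call, inside the patch.
    torch.compile(lambda x, y: x @ y)(a, b)
```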
@@ -3396,5 +3407,50 @@ def _autotune_metadata(input_nodes):
     }


+def _log_autotune_choices_stats(
+    event_name: str, timings: dict[ChoiceCaller, float]
+) -> None:
+    """Helper function to extract autotune metadata from benchmark results."""
+    if not timings:
+        return None
+
+    metadata: dict[str, Union[int, float, str]] = {
+        "num_choices": len(timings),
+        "num_triton_choices": len(
+            [c for c in timings if isinstance(c, TritonTemplateCaller)]
+        ),
+    }
+
+    sorted_choices = sorted(timings, key=timings.__getitem__)
+    best_choice = sorted_choices[0]
+    metadata["best_kernel"] = best_choice.name
+    if best_choice.description:
+        metadata["best_kernel_desc"] = best_choice.description
+    metadata["best_time"] = timings[best_choice]
+
+    best_triton_pos = next(
+        (
+            i
+            for i, choice in enumerate(sorted_choices)
+            if isinstance(choice, TritonTemplateCaller)
+        ),
+        None,
+    )
+    if best_triton_pos is not None:
+        metadata["best_triton_pos"] = best_triton_pos
+        best_triton_kernel = sorted_choices[best_triton_pos]
+        if best_triton_pos != 0:
+            metadata["best_triton_time"] = timings[best_triton_kernel]
+            metadata["best_triton_kernel"] = best_triton_kernel.name
+            if best_triton_kernel.description:
+                metadata["best_triton_kernel_desc"] = best_triton_kernel.description
+
+    payload = json.dumps(metadata)
+    get_chromium_event_logger().add_event_data(
+        event_name, autotune_choices_stats=payload
+    )
+    sys.stderr.write(f"Autotune Choices Stats:\n{payload}\n")
+
+
 # ensure lowering is imported so that `extern_kernels.*` is populated
 from . import lowering  # noqa: F401
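To make the ranking in `_log_autotune_choices_stats` concrete, below is a self-contained toy version of the same logic. `FakeChoice` and `FakeTritonChoice` are hypothetical stand-ins for Inductor's `ChoiceCaller` and `TritonTemplateCaller`, the timing numbers are made up, and only the sorting and position bookkeeping mirrors the function above.

```python
# Toy illustration (not Inductor code): how the stats payload is derived from
# a timings dict. FakeChoice / FakeTritonChoice are hypothetical stand-ins.
import json
from dataclasses import dataclass


@dataclass(frozen=True)
class FakeChoice:
    name: str
    description: str = ""


@dataclass(frozen=True)
class FakeTritonChoice(FakeChoice):
    pass


# Made-up benchmark times, keyed by choice.
timings = {
    FakeChoice("cutlass_f6c25cf2"): 0.628,
    FakeTritonChoice("triton_mm_17", "BLOCK_M=128, BLOCK_N=128, num_warps=4"): 0.683,
    FakeChoice("extern_mm"): 0.701,
}

# Same ranking as the diff: sort choices by measured time, fastest first.
sorted_choices = sorted(timings, key=timings.__getitem__)
metadata = {
    "num_choices": len(timings),
    "num_triton_choices": sum(isinstance(c, FakeTritonChoice) for c in timings),
    "best_kernel": sorted_choices[0].name,
    "best_time": timings[sorted_choices[0]],
}

# Position of the fastest Triton choice in the overall ranking (None if absent).
best_triton_pos = next(
    (i for i, c in enumerate(sorted_choices) if isinstance(c, FakeTritonChoice)),
    None,
)
if best_triton_pos is not None:
    metadata["best_triton_pos"] = best_triton_pos
    if best_triton_pos != 0:  # only record the gap when Triton did not win outright
        triton = sorted_choices[best_triton_pos]
        metadata["best_triton_time"] = timings[triton]
        metadata["best_triton_kernel"] = triton.name

print(json.dumps(metadata, indent=2))
# best_kernel=cutlass_f6c25cf2, best_triton_pos=1, best_triton_kernel=triton_mm_17
```

Note that the `best_triton_*` fields are only populated when a non-Triton kernel wins outright, which is exactly the gap the commit message sets out to surface.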