diff --git a/torch/_inductor/metrics.py b/torch/_inductor/metrics.py
index 116550be70e7..5e2b5a9f90d9 100644
--- a/torch/_inductor/metrics.py
+++ b/torch/_inductor/metrics.py
@@ -7,7 +7,7 @@ import os
 import re
 from dataclasses import dataclass
 from functools import lru_cache
-from typing import Callable, cast, Optional, TYPE_CHECKING, Union
+from typing import Callable, Optional, TYPE_CHECKING, Union
 
 from torch._inductor import config
 from torch._inductor.utils import get_benchmark_name
@@ -16,6 +16,7 @@ from torch.utils._ordered_set import OrderedSet
 
 # Prevent circular import
 if TYPE_CHECKING:
+    from torch._inductor.runtime.triton_compat import Config
     from torch._inductor.scheduler import BaseSchedulerNode
 
 # counter for tracking how many kernels have been generated
@@ -153,8 +154,8 @@ class MetricTable:
         bn = get_benchmark_name()
         # assert bn is not None
         row = [bn] + [row_dict[column_name] for column_name in self.column_names]
-        assert all(isinstance(i, str) for i in row)
-        self._write_row(cast(list[str], row))
+        assert all(isinstance(i, (str, float, type(None))) for i in row)
+        self._write_row(row)
 
     def output_filename(self) -> str:
         return f"metric_table_{self.table_name}.csv"
@@ -165,7 +166,7 @@ class MetricTable:
             writer = csv.writer(fd, lineterminator="\n")
             writer.writerow(["model_name"] + self.column_names)
 
-    def _write_row(self, row: list[str]) -> None:
+    def _write_row(self, row: list[str | float | None]) -> None:
         filename = self.output_filename()
         if self.num_rows_added == 0 and not os.path.exists(filename):
             self.write_header()
@@ -452,3 +453,27 @@ def is_metric_table_enabled(name: str) -> bool:
 def get_metric_table(name: str) -> MetricTable:
     assert name in REGISTERED_METRIC_TABLES, f"Metric table {name} is not defined"
     return REGISTERED_METRIC_TABLES[name]
+
+
+MetricTable.register_table(
+    "kernel_autotune",
+    [
+        "kernel_path",
+        "kernel_name",
+        "triton_config",
+        "latency_ms",
+    ],
+)
+
+
+def log_kernel_autotune_result(
+    kernel_path: str, kernel_name: str, config: Config, latency: float
+) -> None:
+    get_metric_table("kernel_autotune").add_row(
+        lambda: {
+            "kernel_path": kernel_path,
+            "kernel_name": kernel_name,
+            "triton_config": str(config),
+            "latency_ms": latency,
+        }
+    )
diff --git a/torch/_inductor/runtime/triton_heuristics.py b/torch/_inductor/runtime/triton_heuristics.py
index 2f90424082a0..fa266d8764e5 100644
--- a/torch/_inductor/runtime/triton_heuristics.py
+++ b/torch/_inductor/runtime/triton_heuristics.py
@@ -32,6 +32,7 @@ from typing import (
 import torch
 from torch._dynamo.utils import set_feature_use
 from torch._environment import is_fbcode
+from torch._inductor import metrics
 from torch._prims_common import compute_required_storage_length
 from torch.utils._ordered_set import OrderedSet
 
@@ -1088,6 +1089,18 @@ class CachingAutotuner(KernelInterface):
                     k.shared,
                 )
 
+        if metrics.is_metric_table_enabled("kernel_autotune"):
+            if self.fn.fn is None:
+                self.fn = self._reload_kernel().fn
+
+            kernel_path = self.fn.fn.__code__.co_filename
+            kernel_name = self.fn.__name__
+
+            for k, v in timings.items():
+                metrics.log_kernel_autotune_result(
+                    kernel_path, kernel_name, k.config, v
+                )
+
         self.reset_to_zero_args(*args, **kwargs)
         return timings
 