[inductor] log kernel autotuning result to a csv (#164191)

Example output: https://gist.github.com/shunting314/2d646c6b6cd9a79fff7a35ffee82baed
```
for each model:
    for each triton kernel:
        for each triton config:
            the csv contains a line with the latency and a pointer to locate the kernel module in the file system
```

The plan is to use this data to come up with heuristics for picking a single config.
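
For reference, a minimal sketch of how such a CSV could be produced once this lands. It assumes the table is switched on through the `TORCHINDUCTOR_ENABLED_METRIC_TABLES` environment variable (the knob the existing metric tables use) and that compiling with `max-autotune` triggers the Triton config benchmarking; both are assumptions, not part of this PR:

```
# Sketch only: enable the new "kernel_autotune" metric table and trigger
# Triton autotuning via torch.compile. TORCHINDUCTOR_ENABLED_METRIC_TABLES
# is assumed to be the switch the inductor metric tables consult.
import os

os.environ["TORCHINDUCTOR_ENABLED_METRIC_TABLES"] = "kernel_autotune"

import torch


@torch.compile(mode="max-autotune")
def f(x, y):
    return (x @ y).relu().sum(dim=-1)


x = torch.randn(2048, 2048, device="cuda")
y = torch.randn(2048, 2048, device="cuda")
f(x, y)
# Rows are appended to metric_table_kernel_autotune.csv in the current
# working directory (per output_filename() in metrics.py below).
```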

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164191
Approved by: https://github.com/jansel, https://github.com/mlazos
Author: Shunting Zhang
Date: 2025-09-29 17:57:40 -07:00
Committed by: PyTorch MergeBot
Parent: 1a5d023a5b
Commit: ffda8e5ddf
2 changed files with 42 additions and 4 deletions


```
@@ -7,7 +7,7 @@ import os
 import re
 from dataclasses import dataclass
 from functools import lru_cache
-from typing import Callable, cast, Optional, TYPE_CHECKING, Union
+from typing import Callable, Optional, TYPE_CHECKING, Union
 
 from torch._inductor import config
 from torch._inductor.utils import get_benchmark_name
@@ -16,6 +16,7 @@ from torch.utils._ordered_set import OrderedSet
 # Prevent circular import
 if TYPE_CHECKING:
+    from torch._inductor.runtime.triton_compat import Config
     from torch._inductor.scheduler import BaseSchedulerNode
 
 
 # counter for tracking how many kernels have been generated
@@ -153,8 +154,8 @@ class MetricTable:
         bn = get_benchmark_name()
         # assert bn is not None
         row = [bn] + [row_dict[column_name] for column_name in self.column_names]
-        assert all(isinstance(i, str) for i in row)
-        self._write_row(cast(list[str], row))
+        assert all(isinstance(i, (str, float, type(None))) for i in row)
+        self._write_row(row)
 
     def output_filename(self) -> str:
         return f"metric_table_{self.table_name}.csv"
@@ -165,7 +166,7 @@ class MetricTable:
             writer = csv.writer(fd, lineterminator="\n")
             writer.writerow(["model_name"] + self.column_names)
 
-    def _write_row(self, row: list[str]) -> None:
+    def _write_row(self, row: list[str | float | None]) -> None:
         filename = self.output_filename()
         if self.num_rows_added == 0 and not os.path.exists(filename):
             self.write_header()
@@ -452,3 +453,27 @@ def is_metric_table_enabled(name: str) -> bool:
 def get_metric_table(name: str) -> MetricTable:
     assert name in REGISTERED_METRIC_TABLES, f"Metric table {name} is not defined"
     return REGISTERED_METRIC_TABLES[name]
+
+
+MetricTable.register_table(
+    "kernel_autotune",
+    [
+        "kernel_path",
+        "kernel_name",
+        "triton_config",
+        "latency_ms",
+    ],
+)
+
+
+def log_kernel_autotune_result(
+    kernel_path: str, kernel_name: str, config: Config, latency: float
+) -> None:
+    get_metric_table("kernel_autotune").add_row(
+        lambda: {
+            "kernel_path": kernel_path,
+            "kernel_name": kernel_name,
+            "triton_config": str(config),
+            "latency_ms": latency,
+        }
+    )
```
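
A minimal sketch of how the new helper could be exercised in isolation, assuming `torch._inductor.config.enabled_metric_tables` is the knob that `add_row` consults, and building a `triton.Config` by hand purely for illustration (the path and kernel name below are hypothetical):

```
# Sketch: exercise the new logging helper directly.
# enabled_metric_tables as the enabling knob is an assumption here.
import torch._inductor.config as inductor_config
from torch._inductor import metrics
from triton import Config  # Triton's autotuner config object

inductor_config.enabled_metric_tables = "kernel_autotune"

cfg = Config({"XBLOCK": 128, "RBLOCK": 8}, num_warps=4, num_stages=1)
metrics.log_kernel_autotune_result(
    kernel_path="/tmp/torchinductor_cache/xy/xyz123.py",  # hypothetical path
    kernel_name="triton_red_fused_sum_0",                 # hypothetical name
    config=cfg,
    latency=0.0123,
)
# A header plus one data row should now be in metric_table_kernel_autotune.csv.
```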


```
@@ -32,6 +32,7 @@ from typing import (
 import torch
 from torch._dynamo.utils import set_feature_use
 from torch._environment import is_fbcode
+from torch._inductor import metrics
 from torch._prims_common import compute_required_storage_length
 from torch.utils._ordered_set import OrderedSet
@@ -1088,6 +1089,18 @@ class CachingAutotuner(KernelInterface):
                     k.shared,
                 )
 
+        if metrics.is_metric_table_enabled("kernel_autotune"):
+            if self.fn.fn is None:
+                self.fn = self._reload_kernel().fn
+            kernel_path = self.fn.fn.__code__.co_filename
+            kernel_name = self.fn.__name__
+            for k, v in timings.items():
+                metrics.log_kernel_autotune_result(
+                    kernel_path, kernel_name, k.config, v
+                )
+
         self.reset_to_zero_args(*args, **kwargs)
         return timings
```
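
On the consumption side, a rough sketch of the kind of analysis the commit message hints at: pick the fastest logged config per kernel from the CSV. Only the column names (model_name, kernel_path, kernel_name, triton_config, latency_ms) come from the table registered above; the rest is illustrative:

```
# Sketch: find the best-measured Triton config per (model, kernel) from the CSV.
import csv

best: dict[tuple[str, str], tuple[float, str]] = {}
with open("metric_table_kernel_autotune.csv") as f:
    for row in csv.DictReader(f):
        key = (row["model_name"], row["kernel_name"])
        latency = float(row["latency_ms"])
        if key not in best or latency < best[key][0]:
            best[key] = (latency, row["triton_config"])
        # row["kernel_path"] points at the generated kernel module on disk,
        # useful for inspecting the Triton source of interesting kernels.

for (model, kernel), (latency, cfg) in sorted(best.items()):
    print(f"{model} {kernel}: {latency:.4f} ms  <- {cfg}")
```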