mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
[inductor] log kernel autotuning result to a csv (#164191)
Example output: https://gist.github.com/shunting314/2d646c6b6cd9a79fff7a35ffee82baed ``` for each model: for each triton kernel: for each triton config: the csv contains a line for the latency and pointer to find the kernel module in the file system ``` Would use this to try to come up with heuristics to pick a single config. Pull Request resolved: https://github.com/pytorch/pytorch/pull/164191 Approved by: https://github.com/jansel, https://github.com/mlazos
This commit is contained in:
committed by
PyTorch MergeBot
parent
1a5d023a5b
commit
ffda8e5ddf
@ -7,7 +7,7 @@ import os
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
from functools import lru_cache
|
||||
from typing import Callable, cast, Optional, TYPE_CHECKING, Union
|
||||
from typing import Callable, Optional, TYPE_CHECKING, Union
|
||||
|
||||
from torch._inductor import config
|
||||
from torch._inductor.utils import get_benchmark_name
|
||||
@ -16,6 +16,7 @@ from torch.utils._ordered_set import OrderedSet
|
||||
|
||||
# Prevent circular import
|
||||
if TYPE_CHECKING:
|
||||
from torch._inductor.runtime.triton_compat import Config
|
||||
from torch._inductor.scheduler import BaseSchedulerNode
|
||||
|
||||
# counter for tracking how many kernels have been generated
|
||||
@ -153,8 +154,8 @@ class MetricTable:
|
||||
bn = get_benchmark_name()
|
||||
# assert bn is not None
|
||||
row = [bn] + [row_dict[column_name] for column_name in self.column_names]
|
||||
assert all(isinstance(i, str) for i in row)
|
||||
self._write_row(cast(list[str], row))
|
||||
assert all(isinstance(i, (str, float, type(None))) for i in row)
|
||||
self._write_row(row)
|
||||
|
||||
def output_filename(self) -> str:
|
||||
return f"metric_table_{self.table_name}.csv"
|
||||
@ -165,7 +166,7 @@ class MetricTable:
|
||||
writer = csv.writer(fd, lineterminator="\n")
|
||||
writer.writerow(["model_name"] + self.column_names)
|
||||
|
||||
def _write_row(self, row: list[str]) -> None:
|
||||
def _write_row(self, row: list[str | float | None]) -> None:
|
||||
filename = self.output_filename()
|
||||
if self.num_rows_added == 0 and not os.path.exists(filename):
|
||||
self.write_header()
|
||||
@ -452,3 +453,27 @@ def is_metric_table_enabled(name: str) -> bool:
|
||||
def get_metric_table(name: str) -> MetricTable:
|
||||
assert name in REGISTERED_METRIC_TABLES, f"Metric table {name} is not defined"
|
||||
return REGISTERED_METRIC_TABLES[name]
|
||||
|
||||
|
||||
MetricTable.register_table(
|
||||
"kernel_autotune",
|
||||
[
|
||||
"kernel_path",
|
||||
"kernel_name",
|
||||
"triton_config",
|
||||
"latency_ms",
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
def log_kernel_autotune_result(
|
||||
kernel_path: str, kernel_name: str, config: Config, latency: float
|
||||
) -> None:
|
||||
get_metric_table("kernel_autotune").add_row(
|
||||
lambda: {
|
||||
"kernel_path": kernel_path,
|
||||
"kernel_name": kernel_name,
|
||||
"triton_config": str(config),
|
||||
"latency_ms": latency,
|
||||
}
|
||||
)
|
||||
|
@ -32,6 +32,7 @@ from typing import (
|
||||
import torch
|
||||
from torch._dynamo.utils import set_feature_use
|
||||
from torch._environment import is_fbcode
|
||||
from torch._inductor import metrics
|
||||
from torch._prims_common import compute_required_storage_length
|
||||
from torch.utils._ordered_set import OrderedSet
|
||||
|
||||
@ -1088,6 +1089,18 @@ class CachingAutotuner(KernelInterface):
|
||||
k.shared,
|
||||
)
|
||||
|
||||
if metrics.is_metric_table_enabled("kernel_autotune"):
|
||||
if self.fn.fn is None:
|
||||
self.fn = self._reload_kernel().fn
|
||||
|
||||
kernel_path = self.fn.fn.__code__.co_filename
|
||||
kernel_name = self.fn.__name__
|
||||
|
||||
for k, v in timings.items():
|
||||
metrics.log_kernel_autotune_result(
|
||||
kernel_path, kernel_name, k.config, v
|
||||
)
|
||||
|
||||
self.reset_to_zero_args(*args, **kwargs)
|
||||
return timings
|
||||
|
||||
|
Reference in New Issue
Block a user