Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-21 05:34:18 +08:00)
[inductor] log kernel autotuning result to a csv (#164191)
Example output: https://gist.github.com/shunting314/2d646c6b6cd9a79fff7a35ffee82baed

```
for each model:
  for each triton kernel:
    for each triton config:
      the csv contains a line for the latency and a pointer to find the kernel module in the file system
```

This will be used to try to come up with heuristics for picking a single config.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164191
Approved by: https://github.com/jansel, https://github.com/mlazos
Commit: ffda8e5ddf (parent: 1a5d023a5b), committed by PyTorch MergeBot
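For context, here is a minimal sketch of how the new table could be exercised end to end. It is not part of this PR: the toy model is an illustration, and it assumes the existing `TORCHINDUCTOR_ENABLED_METRIC_TABLES` environment variable is what gates metric-table output (the CSV filename comes from `output_filename()` in the diff below).

```python
# Hypothetical driver script (not part of this PR): enable the kernel_autotune
# metric table, compile a toy function with autotuning, and let Inductor emit
# metric_table_kernel_autotune.csv.
import os

# Assumption: metric tables are switched on via this env var, which is read by
# torch._inductor.config; set it before importing torch.
os.environ["TORCHINDUCTOR_ENABLED_METRIC_TABLES"] = "kernel_autotune"

import torch


def toy(x, y):
    return torch.relu(x @ y).sum(dim=-1)


compiled = torch.compile(toy, mode="max-autotune")
a = torch.randn(512, 512, device="cuda")
b = torch.randn(512, 512, device="cuda")
compiled(a, b)
# One CSV row per (kernel, triton config) pair should now be on disk.
```

The changes below register the new table and logging helper in the metrics module, and add the logging call in the autotuner.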
```diff
@@ -7,7 +7,7 @@ import os
 import re
 from dataclasses import dataclass
 from functools import lru_cache
-from typing import Callable, cast, Optional, TYPE_CHECKING, Union
+from typing import Callable, Optional, TYPE_CHECKING, Union
 
 from torch._inductor import config
 from torch._inductor.utils import get_benchmark_name
@@ -16,6 +16,7 @@ from torch.utils._ordered_set import OrderedSet
 
 # Prevent circular import
 if TYPE_CHECKING:
+    from torch._inductor.runtime.triton_compat import Config
     from torch._inductor.scheduler import BaseSchedulerNode
 
 # counter for tracking how many kernels have been generated
@@ -153,8 +154,8 @@ class MetricTable:
         bn = get_benchmark_name()
         # assert bn is not None
         row = [bn] + [row_dict[column_name] for column_name in self.column_names]
-        assert all(isinstance(i, str) for i in row)
-        self._write_row(cast(list[str], row))
+        assert all(isinstance(i, (str, float, type(None))) for i in row)
+        self._write_row(row)
 
     def output_filename(self) -> str:
         return f"metric_table_{self.table_name}.csv"
@@ -165,7 +166,7 @@ class MetricTable:
         writer = csv.writer(fd, lineterminator="\n")
         writer.writerow(["model_name"] + self.column_names)
 
-    def _write_row(self, row: list[str]) -> None:
+    def _write_row(self, row: list[str | float | None]) -> None:
         filename = self.output_filename()
         if self.num_rows_added == 0 and not os.path.exists(filename):
             self.write_header()
@@ -452,3 +453,27 @@ def is_metric_table_enabled(name: str) -> bool:
 def get_metric_table(name: str) -> MetricTable:
     assert name in REGISTERED_METRIC_TABLES, f"Metric table {name} is not defined"
     return REGISTERED_METRIC_TABLES[name]
+
+
+MetricTable.register_table(
+    "kernel_autotune",
+    [
+        "kernel_path",
+        "kernel_name",
+        "triton_config",
+        "latency_ms",
+    ],
+)
+
+
+def log_kernel_autotune_result(
+    kernel_path: str, kernel_name: str, config: Config, latency: float
+) -> None:
+    get_metric_table("kernel_autotune").add_row(
+        lambda: {
+            "kernel_path": kernel_path,
+            "kernel_name": kernel_name,
+            "triton_config": str(config),
+            "latency_ms": latency,
+        }
+    )
```
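The hunks above (this appears to be torch/_inductor/metrics.py) register the table and the `log_kernel_autotune_result` helper. Since the stated goal is to derive heuristics for picking a single config, a hypothetical post-processing pass over the emitted CSV might look like the sketch below. The column names come from the diff; the file location and the "lowest latency wins" criterion are assumptions.

```python
# Hypothetical analysis sketch (not part of this PR): read the emitted CSV and,
# for every (model, kernel) pair, report the Triton config with the lowest latency.
import csv

best: dict[tuple[str, str], tuple[float, str]] = {}
with open("metric_table_kernel_autotune.csv") as fd:
    for row in csv.DictReader(fd):
        key = (row["model_name"], row["kernel_name"])
        latency = float(row["latency_ms"])
        if key not in best or latency < best[key][0]:
            best[key] = (latency, row["triton_config"])

for (model, kernel), (latency, cfg) in sorted(best.items()):
    print(f"{model} / {kernel}: {latency:.4f} ms with {cfg}")
```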
```diff
@@ -32,6 +32,7 @@ from typing import (
 import torch
 from torch._dynamo.utils import set_feature_use
 from torch._environment import is_fbcode
+from torch._inductor import metrics
 from torch._prims_common import compute_required_storage_length
 from torch.utils._ordered_set import OrderedSet
 
@@ -1088,6 +1089,18 @@ class CachingAutotuner(KernelInterface):
                     k.shared,
                 )
 
+        if metrics.is_metric_table_enabled("kernel_autotune"):
+            if self.fn.fn is None:
+                self.fn = self._reload_kernel().fn
+
+            kernel_path = self.fn.fn.__code__.co_filename
+            kernel_name = self.fn.__name__
+
+            for k, v in timings.items():
+                metrics.log_kernel_autotune_result(
+                    kernel_path, kernel_name, k.config, v
+                )
+
         self.reset_to_zero_args(*args, **kwargs)
         return timings
 
```
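This call site (it looks like CachingAutotuner in torch/_inductor/runtime/triton_heuristics.py) guards the reload: `self.fn.fn` can be `None` when the compiled kernel source has already been dropped, so `_reload_kernel()` is called before reading `__code__.co_filename`. That filename becomes `kernel_path`, i.e. the pointer to the generated Triton module mentioned in the commit message. A small hypothetical follow-up, given a path copied from a CSV row:

```python
# Hypothetical snippet (not part of this PR): kernel_path in the CSV points at the
# generated Triton module on disk, so the timed kernel source can be inspected directly.
from pathlib import Path

# Placeholder value; copy a real kernel_path from metric_table_kernel_autotune.csv.
kernel_path = "/path/to/inductor_cache/triton_kernel.py"

print(Path(kernel_path).read_text()[:800])  # first part of the generated kernel module
```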