import os
from collections import defaultdict
from collections.abc import Callable
from dataclasses import dataclass
from typing import Any, Optional

import matplotlib.pyplot as plt

import torch
from torch._inductor.runtime.benchmarking import benchmarker


def benchmark_kernel_in_milliseconds(func: Callable, *args, **kwargs) -> float:
    # Warmup runs so that one-time costs (compilation, caching) are not timed.
    for _ in range(5):
        func(*args, **kwargs)
    # Raise instead of silently recompiling while the kernel is being timed.
    with torch.compiler.set_stance("fail_on_recompile"):
        return benchmarker.benchmark_gpu(lambda: func(*args, **kwargs))


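# Hypothetical usage sketch (the kernel and tensor below are illustrative, not
# part of this file): the warmup loop above triggers compilation before the
# "fail_on_recompile" stance takes effect.
#
#   fn = torch.compile(torch.nn.functional.gelu, mode="max-autotune-no-cudagraphs")
#   x = torch.randn(4096, 4096, device="cuda", dtype=torch.bfloat16)
#   ms = benchmark_kernel_in_milliseconds(fn, x)

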
@dataclass
class Performance:
    # Benchmark setting, usually the shape of the input tensor
    setting: str

    # Latency in milliseconds
    latency: float

    # Number of bytes of memory accessed by the kernel
    memory_bytes: float

    # Memory bandwidth in GB/s
    memory_bandwidth: float = 0.0

    # Compute intensity in FLOPs/byte
    compute_intensity: float = 0.0

    def __post_init__(self):
        # bytes / seconds, scaled to GB/s
        self.memory_bandwidth = self.memory_bytes / (self.latency / 1000) / 1e9

    def __str__(self):
        return f"setting: {self.setting}, latency: {self.latency} ms, memory bandwidth: {self.memory_bandwidth} GB/s"


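# Sanity check on the bandwidth formula above, with illustrative (not measured)
# numbers: memory_bytes = 1_073_741_824 (1 GiB) moved at latency = 1.0 ms gives
# 1_073_741_824 / (1.0 / 1000) / 1e9 ≈ 1074 GB/s.

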
class BenchmarkKernel:
    def __init__(self, compile_mode: str = "max-autotune-no-cudagraphs"):
        self.name = self.__class__.__name__
        self.available_backends: list[str] = []
        self.compile_mode: str = compile_mode

        # mapping from backend to list of performance results
        self.profiling_results: defaultdict[str, list[Performance]] = defaultdict(list)

    def get_memory_bytes(self, args, kwargs) -> int:
        # Return the number of bytes of memory the kernel needs to access
        raise NotImplementedError

    def get_shapes(self) -> tuple[tuple[int, ...], ...]:
        # Return the input shapes to benchmark the kernel on
        raise NotImplementedError

    # Each backend method below takes (args, kwargs) and returns a zero-argument
    # callable that runs the kernel on that backend.
    def eager(self, args, kwargs) -> Any:
        raise NotImplementedError

    def compiled(self, args, kwargs) -> Any:
        raise NotImplementedError

    def helion(self, args, kwargs) -> Any:
        raise NotImplementedError

    def quack(self, args, kwargs) -> Any:
        raise NotImplementedError

    def liger(self, args, kwargs) -> Any:
        raise NotImplementedError

    def triton(self, args, kwargs) -> Any:
        raise NotImplementedError

    def benchmark(self):
        raise NotImplementedError

    def clone_inputs(self, args, kwargs) -> Any:
        # Clone inputs per backend so that in-place mutations and autograd state
        # from one backend do not leak into another.
        args_ref = [
            arg.clone().detach().requires_grad_(arg.requires_grad) for arg in args
        ]

        kwargs_ref = (
            {
                k: (
                    v.clone().detach().requires_grad_(v.requires_grad)
                    if isinstance(v, torch.Tensor)
                    else v
                )
                for k, v in kwargs.items()
            }
            if kwargs
            else kwargs
        )

        return args_ref, kwargs_ref

    def check_accuracy(self, args, kwargs) -> None:
        res = {}
        for backend in self.available_backends:
            args_ref, kwargs_ref = self.clone_inputs(args, kwargs)
            res[backend] = getattr(self, backend)(args_ref, kwargs_ref)()
        gold = res["eager"]
        for backend in self.available_backends:
            if backend == "eager":
                continue
            try:
                torch.testing.assert_close(res[backend], gold)
                for t, gold_t in zip(res[backend], gold):
                    if t.requires_grad:
                        torch.testing.assert_close(t.grad, gold_t.grad)
                print(
                    f"Accuracy check \033[92m✓ succeeded\033[0m for {backend} backend on {self.name} kernel"
                )
            except Exception as e:
                print(
                    f"Accuracy check \033[91m✗ failed\033[0m for {backend} backend on {self.name} kernel. Error: {e}"
                )

    def benchmark_single_shape(
        self, args, kwargs=None, should_check_accuracy=True, setting: str = ""
    ):
        for backend in self.available_backends:
            args_ref, kwargs_ref = self.clone_inputs(args, kwargs)
            try:
                avg_time = benchmark_kernel_in_milliseconds(
                    getattr(self, backend)(args_ref, kwargs_ref)
                )
            except Exception as e:
                print(
                    f"Failed to run {backend} backend on {self.name} kernel for {setting} due to {e}"
                )
                # Drop the failing backend so later shapes and plots skip it.
                self.available_backends.remove(backend)  # noqa: B909
                continue
            mem_bytes = self.get_memory_bytes(args_ref, kwargs_ref)
            perf = Performance(setting, avg_time, mem_bytes)
            print(f"{self.name} kernel on {backend} backend. {perf}")
            self.profiling_results[backend].append(perf)

        if should_check_accuracy:
            self.check_accuracy(args, kwargs)

    def visualize(self) -> None:
        visualize_comparison(
            self.profiling_results,
            title=f"{self.name}",
            output_path=f"{self.name}_bench",
        )
        return


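# Minimal sketch of a concrete kernel benchmark (illustrative only; "ExampleAddKernel"
# and its shapes/dtypes are assumptions, not part of the suite). It shows the contract
# the harness relies on: backend methods return zero-argument callables, and
# get_memory_bytes reports the bytes each call reads and writes.
class ExampleAddKernel(BenchmarkKernel):
    def __init__(self, compile_mode: str = "max-autotune-no-cudagraphs"):
        super().__init__(compile_mode)
        self.available_backends = ["eager", "compiled"]

    def get_shapes(self) -> tuple[tuple[int, ...], ...]:
        return ((4096, 4096), (8192, 8192))

    def get_memory_bytes(self, args, kwargs) -> int:
        a, b = args
        # read a, read b, write an output of the same size as a
        return 3 * a.numel() * a.element_size()

    def eager(self, args, kwargs) -> Any:
        a, b = args
        return lambda: torch.add(a, b)

    def compiled(self, args, kwargs) -> Any:
        a, b = args
        compiled_add = torch.compile(torch.add, mode=self.compile_mode)
        return lambda: compiled_add(a, b)

    def benchmark(self):
        for shape in self.get_shapes():
            args = tuple(
                torch.randn(shape, device="cuda", dtype=torch.bfloat16)
                for _ in range(2)
            )
            self.benchmark_single_shape(args, setting=f"shape: {list(shape)}")

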
def get_backend_colors() -> dict[str, str]:
    """Get consistent color scheme for different backends."""
    return {
        "eager": "#1f77b4",  # blue
        "compiled": "#ff7f0e",  # orange
        "quack": "#2ca02c",  # green
        "liger": "#d62728",  # red
        "helion": "#9467bd",  # purple
        "triton": "#8c564b",  # brown
        "cutlass": "#e377c2",  # pink
        "flash_attn": "#7f7f7f",  # gray
        "default": "#000000",  # black
    }


def visualize_comparison(
    profiling_results: dict[str, list[Performance]],
    title: Optional[str] = None,
    output_path: Optional[str] = None,
) -> None:
    """
    Create a single memory-bandwidth comparison plot from profiling results.

    Args:
        profiling_results: Dict mapping backend names to lists of Performance objects
        title: Plot title (optional)
        output_path: Filename (without extension) to save the plot under pics/ (optional)
    """
    # Get backend colors
    backend_colors = get_backend_colors()

    # Extract settings from the eager backend, which runs all settings
    all_settings = []
    for perf in profiling_results["eager"]:
        all_settings.append(perf.setting)

    # Create a single plot
    fig, ax = plt.subplots(1, 1, figsize=(12, 8))

    for backend in profiling_results:
        backend_perfs = profiling_results[backend]
        perf_dict = {perf.setting: perf for perf in backend_perfs}

        x_vals = []
        y_vals = []
        for i, setting in enumerate(all_settings):
            if setting in perf_dict:
                x_vals.append(i)
                y_vals.append(perf_dict[setting].memory_bandwidth)

        if x_vals:  # Only plot if we have data
            color = backend_colors.get(backend, backend_colors["default"])
            ax.plot(
                x_vals,
                y_vals,
                "o-",
                label=backend,
                color=color,
                linewidth=2,
                markersize=8,
                alpha=0.8,
            )

    # Configure the plot
    ax.set_title(title or "Memory Bandwidth Comparison", fontsize=16)
    ax.set_xlabel("Shape", fontsize=12)
    ax.set_ylabel("memory bandwidth (GB/s)", fontsize=12)
    ax.set_xticks(range(len(all_settings)))
    ax.set_xticklabels(
        [
            s.replace("shape: ", "").replace("[", "").replace("]", "")
            for s in all_settings
        ],
        rotation=45,
        ha="right",
    )
    ax.legend(fontsize=10)
    ax.grid(True, alpha=0.3)

    plt.tight_layout()

    # Save the plot as a PNG under pics/ if an output path is provided
    if output_path:
        os.makedirs("pics", exist_ok=True)
        full_path = os.path.join("pics", output_path + ".png")
        plt.savefig(full_path, dpi=300, bbox_inches="tight", facecolor="white")

    plt.close()
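

# Hypothetical driver for the ExampleAddKernel sketch above (requires a CUDA
# device): run every shape, check accuracy against eager, and write the
# bandwidth comparison plot to pics/ExampleAddKernel_bench.png.
if __name__ == "__main__":
    kernel = ExampleAddKernel()
    kernel.benchmark()
    kernel.visualize()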