Compare commits

...

9 Commits

SHA1 Message Date
e428407c84 more stuffs 2025-06-25 20:02:29 -07:00
28b74d81a0 export fixes 2025-06-25 16:27:42 -07:00
11d45cecf3 exporting model 2025-06-25 15:38:59 -07:00
90031acf23 final stuff 2025-06-14 11:39:51 -07:00
e445a5a943 add clear_feedback_savers 2025-06-14 11:37:47 -07:00
19446879c3 TMP shapes 2025-06-14 11:37:47 -07:00
20b3e08a1b enable stuff 2025-06-14 11:37:47 -07:00
cbcf7c26ee remove breakpoint 2025-06-14 11:37:47 -07:00
7b5ea8b998 TMP update autotune configs 2025-06-14 11:37:47 -07:00
8 changed files with 1737 additions and 14 deletions

benchmarks/dynamo/shapes.py (new file, 131 lines)
View File

@@ -0,0 +1,131 @@
from torch._inductor.select_algorithm import add_feedback_saver, clear_feedback_savers
import torch
from microbenchmarks.operator_inp_utils import OperatorInputsLoader
from torch.utils._ordered_set import OrderedSet
aten = torch.ops.aten
loader = OperatorInputsLoader.get_huggingface_loader()
from triton.testing import do_bench
import csv
from torch._inductor import config
torch.set_grad_enabled(False)
config.fx_graph_cache = False
config.force_disable_caches = True
def zip_dicts(dict1, dict2, d1_default, d2_default):
"""
Zip two dictionaries together, replacing missing keys with default values.
Args:
dict1 (dict): The first dictionary.
dict2 (dict): The second dictionary.
d1_default (Any): the default value for the first dictionary
d2_default (Any): the default value for the second dictionary
Yields:
tuple: A tuple containing the key, the value from dict1 (or d1_default if missing),
and the value from dict2 (or d2_default if missing).
"""
# Find the union of all keys
all_keys = OrderedSet(dict1.keys()) | OrderedSet(dict2.keys())
# Iterate over all keys
for key in all_keys:
# Get the values from both dictionaries, or default if missing
value1 = dict1.get(key)
value2 = dict2.get(key)
yield (
key,
value1 if value1 is not None else d1_default,
value2 if value2 is not None else d2_default,
)
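# Illustrative behavior of zip_dicts above (values are arbitrary): keys missing from one
# side fall back to that side's default.
#     list(zip_dicts({"a": 1, "b": 2}, {"b": 3, "c": 4}, d1_default=0, d2_default=-1))
#     -> [("a", 1, -1), ("b", 2, 3), ("c", 0, 4)]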
def compare_op():
for op in [aten.mm.default, aten.addmm.default]:
for dtype in [torch.bfloat16, torch.float16]:
with open(f"{op}_{dtype}_benchmark_results.csv", "w", newline='') as file2:
file2.write("M,K,N,BLOCK_M,BLOCK_K,BLOCK_N,NUM_STAGES,NUM_WARPS,GROUP_M,do_bench_time\n")
with open(f"new_old_config_compare_{op}_{dtype}.csv", 'w', newline='') as file:
def feedback_saver(timings, name, input_nodes, choices, profiled_time):
with open(f"{op}_{dtype}_benchmark_results.csv", "a", newline='') as file2:
if name == "addmm":
M, K, N = input_nodes[1].layout.size[0], input_nodes[1].layout.size[1], input_nodes[2].layout.size[1]
elif name == "mm":
M, K, N = input_nodes[0].layout.size[0], input_nodes[0].layout.size[1], input_nodes[1].layout.size[1]
else:
raise Exception(f"Unknown op {name}")
file2.write("--------------------\n")
file2.write(f"{name},{M},{K},{N}\n")
for choice, db_time in timings.items():
if not isinstance(choice, torch._inductor.select_algorithm.TritonTemplateCaller):
continue
BLOCK_M, BLOCK_K, BLOCK_N = tuple(map(int, choice.log_info['tile_shape'].strip('()').split(',')))
line = ",".join(map(str, [M, K, N, BLOCK_M, BLOCK_K, BLOCK_N, choice.log_info['num_stages'], choice.log_info['num_warps'], choice.log_info['GROUP_M'], db_time]))
file2.write(line + "\n")
file2.flush()
add_feedback_saver(feedback_saver)
writer = csv.writer(file)
writer.writerow(['M', 'K', 'N', 'Old_Time', 'New_Time'])
for i, (args, kwargs) in enumerate(loader.get_inputs_for_operator(op, dtype=dtype, device="cuda")):
torch._dynamo.reset()
try:
inp_t = args[1]
weight_t = args[2]
except IndexError:  # mm inputs have no bias argument; operands are args[0] and args[1]
inp_t = args[0]
weight_t = args[1]
if len(inp_t.shape) != 2:
continue
# don't know why these zero-element inputs show up; skip them
if inp_t.numel() == 0:
continue
print(f"{inp_t.shape[0]}_{inp_t.shape[1]}_{weight_t.shape[1]}")
speeds = []
M, K, N = inp_t.shape[0], inp_t.shape[1], weight_t.shape[1]
for new_configs in [False, True]:
with open(f"{op}_{dtype}_benchmark_results.csv", "a") as file2:
if new_configs:
file2.write("New Configs\n")
else:
file2.write("Old Configs\n")
torch._dynamo.reset()
context = config.patch({
"fx_graph_cache": False,
"force_disable_caches": True,
"new_configs": new_configs,
"max_autotune_gemm_backends": "TRITON",
})
with context:
#in1 = torch.zeros((M, K)).cuda().to(dtype=dtype)
in2 = torch.zeros((K, N)).cuda().to(dtype=dtype)
if op == aten.addmm.default:
in3 = torch.zeros((M, N)).cuda().to(dtype=dtype)
def fn(inp):
return inp @ in2 + in3
else:
def fn(inp):
return inp @ in2
mod = torch.compile(fn, mode="max-autotune-no-cudagraphs", fullgraph=True, dynamic=False)
speeds.append(do_bench(lambda: mod(inp_t)))
writer.writerow([M, K, N, speeds[0], speeds[1]])
file.flush()
clear_feedback_savers()
# compare_op("new_old_config_compare_addmm_float16.csv", aten.addmm.default, dtype=torch.float16)
# compare_op("new_old_config_compare_addmm_bfloat16.csv", aten.addmm.default, dtype=torch.bfloat16)
# compare_op("new_old_config_compare_addmm_float32.csv", aten.addmm.default, dtype=torch.float32)
compare_op()

View File

@@ -905,6 +905,8 @@ profile_bandwidth_with_do_bench_using_profiling = (
disable_cpp_codegen = False
new_configs: bool = True
# Freezing will attempt to inline weights as constants in optimization
# and run constant folding and other optimizations on them. After freezing, weights
# can no longer be updated.

View File

@@ -0,0 +1,384 @@
"""
Neural network model for predicting triton kernel performance.
This module provides functionality to load and use a pre-trained neural network
for predicting the performance of triton kernels.
"""
import copy
import os
import time
from collections.abc import Sequence
from typing import Any
import numpy as np
import pandas as pd # type: ignore[import-untyped]
from pyre_extensions import assert_is_instance # type: ignore[import-untyped]
import torch
import torch.nn as nn
from torch._inductor.kernel_lut import TritonGEMMConfig
from torch.optim.lr_scheduler import StepLR
# Default model path - can be overridden by environment variable
script_dir = os.path.dirname(__file__)
DEFAULT_MODEL_PATH = os.path.join(script_dir, "triton_h100_from_arm_108.pkl")
MODEL_PATH = os.environ.get("TRITON_KERNEL_SELECTION_MODEL_PATH", DEFAULT_MODEL_PATH)
import logging
log = logging.getLogger(__name__)
class NeuralNetwork(nn.Module):
"""
Multilayer perceptron with a single output.
It is designed for modeling runtime when there is a constant overhead of
`kernel_overhead` and the non-overhead runtime tends to be easier to model
on a log scale (e.g. doubling a dimension involved in a matrix
multiplication results in runtime roughly doubling.)
"""
def __init__(
self,
n_inputs: int,
hidden_layer_widths: Sequence[int],
kernel_overhead: float = 0.00541,
) -> None:
"""
Args:
n_inputs: Number of inputs
hidden_layer_widths: Hidden layer widths
kernel_overhead: Overhead of the kernel, assumed to be constant. The
default of 0.00541 is the lowest runtime seen in Triton H100 data.
"""
super().__init__()
self.n_inputs = n_inputs
self.kernel_overhead = kernel_overhead
self.log_kernel_overhead: float = torch.log(
torch.tensor(kernel_overhead)
).item()
all_layer_widths = list(hidden_layer_widths) + [1]
all_input_widths = [n_inputs] + list(hidden_layer_widths)
layers: list[nn.Module] = []
for n_in, n_out in zip(all_input_widths, all_layer_widths, strict=True):
layers.append(nn.Linear(n_in, n_out))
layers.append(nn.BatchNorm1d(n_out))
layers.append(nn.ReLU())
self.linear_relu_stack = nn.Sequential(*layers[:-2])
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Predict as log(exp(inputs) + self.kernel_overhead).
Works well for predicting log(runtime) when runtime contains a constant
overhead of `kernel_overhead`. (The log specification means that this
wouldn't be trivially modeled with a bias term.)
Probably could have fit the overhead rather than hard-coding it by
having `self.kernel_overhead` be a tunable parameter or by having exp
and log layers.
"""
log_base_pred = self.linear_relu_stack(x)
log_overhead_tsr = torch.full_like(
input=log_base_pred, fill_value=self.log_kernel_overhead
)
return torch.logsumexp(
torch.stack([log_base_pred, log_overhead_tsr], dim=-1), dim=-1
)
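# Illustrative smoke test of the architecture above (not part of the original module;
# the trained weights are loaded from MODEL_PATH below):
#     net = NeuralNetwork(n_inputs=12, hidden_layer_widths=[256] * 6).eval()
#     with torch.no_grad():
#         log_runtime = net(torch.randn(4, 12))  # four standardized feature rows
#     runtime = torch.exp(log_runtime)           # decode from log space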
def get_nn_x(
df: pd.DataFrame, mean: torch.Tensor | None = None, std: torch.Tensor | None = None
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Standardize the data and convert it to a tensor."""
x_df = df[
[
"dtype_size",
"dim_m",
"dim_n",
"dim_k",
"total_gb",
"total_gflop",
"flops_per_byte",
"config_block_k",
"config_block_m",
"config_block_n",
"config_num_stages",
"config_num_warps",
]
].copy()
for col in x_df.columns:
x_df[col] = np.log(x_df[col])
x_tens = torch.from_numpy(x_df.astype(float).to_numpy()).to(device="cuda")
if mean is None:
mean = torch.from_numpy(assert_is_instance(x_df.mean(), pd.Series).to_numpy()).to(device="cuda")
if std is None:
std = torch.from_numpy(assert_is_instance(x_df.std(), pd.Series).to_numpy()).to(device="cuda")
x_tens -= mean
x_tens /= std
return x_tens.to(torch.float32), mean, std
def get_total_gb_feature(df: pd.DataFrame) -> pd.Series:
"""
Calculate the total gigabytes feature from the dataframe.
Args:
df: DataFrame containing the necessary columns for calculation
Returns:
Series containing the calculated total gigabytes
"""
# Calculate memory access in bytes
m, n, k = df["dim_m"], df["dim_n"], df["dim_k"]
dtype_size = df["dtype_size"] / 8 # Convert bits to bytes
# A: m×k, B: k×n, C: m×n
return ((m * k + k * n + m * n) * dtype_size) / 1e9 # Convert to GB
def get_total_gflop_feature(df: pd.DataFrame) -> pd.Series:
"""
Calculate the total gigaflops feature from the dataframe.
Args:
df: DataFrame containing the necessary columns for calculation
Returns:
Series containing the calculated total gigaflops
"""
# For matrix multiplication, flops = 2 * m * n * k
m, n, k = df["dim_m"], df["dim_n"], df["dim_k"]
return (2 * m * n * k) / 1e9 # Convert to GFLOP
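# Worked example for the two features above (illustrative): a square fp16 GEMM with
# dim_m = dim_n = dim_k = 4096 and dtype_size = 16 bits moves (3 * 4096**2) * 2 bytes
# ~= 0.1007 total_gb and performs 2 * 4096**3 / 1e9 ~= 137.44 total_gflop, so
# flops_per_byte ~= 1365 (= 4096 / 3), the arithmetic intensity the model sees.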
class ModelWrapper:
"""
Wrapper for the neural network model that handles encoding inputs and decoding outputs.
This class provides methods to prepare inputs for the model and interpret its outputs,
handling the necessary standardization and feature engineering.
"""
def __init__(self) -> None:
"""Initialize the model wrapper with the pre-trained model and standardization parameters."""
start_time = time.time()
self.model = NeuralNetwork(
n_inputs=12, hidden_layer_widths=[2**8 for _ in range(6)]
)
self.model.load_state_dict(torch.load(MODEL_PATH))
self.model.eval()
# export the model.
end_time = time.time()
log.info("NN Kernel Prediction Model loaded.")
log.info("Took: %s seconds", end_time - start_time)
# Mean values for standardizing input features
self.mean_for_standardization = torch.tensor(
[
2.78275084,
8.23996746,
7.27791873,
7.92035942,
-2.39558163,
3.40679233,
5.80237395,
3.95781827,
4.19478321,
4.19098234,
0.9045909,
1.28331208,
]
)
# Standard deviation values for standardizing input features
self.std_for_standardization = torch.tensor(
[
0.08322756,
2.31893439,
1.65605574,
2.15447078,
2.19682881,
2.99600806,
1.24328795,
0.92352521,
0.93849802,
0.93872011,
0.57455891,
0.5837217,
]
)
def vec(
self, m: int, n: int, k: int, dsize: int, config: Any
) -> tuple[int, int, int, int, int, int, int, int, int]:
"""
Convert matrix multiplication parameters and config to a feature vector.
Args:
m: First dimension of matrix multiplication
n: Second dimension of matrix multiplication
k: Third dimension of matrix multiplication
dsize: Data size in bits (e.g., 16 for float16, 32 for float32)
config: Configuration object containing kernel parameters
Returns:
Tuple containing the extracted features
"""
kwargs = config.all_kwargs()
return (
int(m),
int(n),
int(k),
int(dsize),
int(kwargs["BLOCK_M"]),
int(kwargs["BLOCK_N"]),
int(kwargs["BLOCK_K"]),
int(kwargs["num_stages"]),
int(kwargs["num_warps"]),
)
@staticmethod
def vec_params(
m: int, n: int, k: int, dsize: int, params: TritonGEMMConfig
) -> tuple[int, int, int, int, int, int, int, int, int]:
"""
Convert matrix multiplication parameters and config to a feature vector.
Args:
m: First dimension of matrix multiplication
n: Second dimension of matrix multiplication
k: Third dimension of matrix multiplication
dsize: Data size in bits (e.g., 16 for float16, 32 for float32)
config: Configuration object containing kernel parameters
Returns:
Tuple containing the extracted features
"""
return (
int(m),
int(n),
int(k),
int(dsize),
int(params.block_m),
int(params.block_n),
int(params.block_k),
int(params.num_stages),
int(params.num_warps),
)
def encode(
self, m: int, n: int, k: int, dtype: torch.dtype, configs: list[Any]
) -> torch.Tensor:
"""
Encode the matrix multiplication parameters and configs as input tensors for the model.
Args:
m: First dimension of matrix multiplication
n: Second dimension of matrix multiplication
k: Third dimension of matrix multiplication
dtype: Data type of the matrices
configs: List of configuration objects
Returns:
Tensor containing the encoded inputs ready for the model
Raises:
ValueError: If the dtype is not supported
"""
# Determine data size based on dtype
if dtype == torch.bfloat16 or dtype == torch.float16:
dsize = 16
elif dtype == torch.float32:
dsize = 32
else:
raise ValueError(f"Unsupported dtype: {dtype}. Add support for this dtype.")
# Create feature dataframe
df = pd.DataFrame(
columns=[
"dim_m",
"dim_n",
"dim_k",
"dtype_size",
"config_block_m",
"config_block_n",
"config_block_k",
"config_num_stages",
"config_num_warps",
],
data=[self.vec(m, n, k, dsize, config) for config in configs],
)
# Calculate derived features
df["total_gb"] = get_total_gb_feature(df=df).astype(np.float32)
df["total_gflop"] = get_total_gflop_feature(df=df).astype(np.float32)
df["flops_per_byte"] = df["total_gflop"] / df["total_gb"]
# Reorder columns to match expected model input
df = df[
[
"dtype_size",
"dim_m",
"dim_n",
"dim_k",
"total_gb",
"total_gflop",
"flops_per_byte",
"config_block_k",
"config_block_m",
"config_block_n",
"config_num_stages",
"config_num_warps",
]
]
# Standardize the input
inp, _, _ = get_nn_x(
df=df, mean=self.mean_for_standardization, std=self.std_for_standardization
)
return inp
def inference(self, inp_tensor: torch.Tensor) -> torch.Tensor:
"""
Run inference on the model with the given input tensor.
Args:
inp_tensor: Input tensor for the model
Returns:
Output tensor from the model
"""
with torch.no_grad():
return self.model(inp_tensor)
def decode(self, ret_tensor: torch.Tensor) -> torch.Tensor:
"""
Decode the model output tensor.
Args:
ret_tensor: Output tensor from the model
Returns:
Decoded tensor representing runtime predictions
"""
return ret_tensor
# Create a singleton instance of the model wrapper
import functools
@functools.lru_cache
def get_model() -> ModelWrapper:
return ModelWrapper()
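# Illustrative end-to-end use of the wrapper above (`configs` is a hypothetical list of
# Triton autotune configs exposing .all_kwargs()):
#     wrapper = get_model()
#     inp = wrapper.encode(m=1024, n=1024, k=1024, dtype=torch.float16, configs=configs)
#     log_runtimes = wrapper.decode(wrapper.inference(inp))
#     runtimes = torch.exp(log_runtimes)  # predicted runtime per config, same order as configs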

View File

@@ -1,10 +1,14 @@
# mypy: allow-untyped-defs
import timeit
import functools
import logging
from typing import Any, Optional
import sympy
from torch._inductor.kernel.gemm_modeling import get_nn_x, NeuralNetwork, get_total_gb_feature, get_total_gflop_feature
import numpy as np
import torch
from torch._dynamo.utils import counters
from torch._inductor.autoheuristic.autoheuristic import AutoHeuristicSelectAlgorithm
@@ -25,6 +29,7 @@ from ..codegen.rocm.ck_tile_universal_gemm_template import CKTileGemmTemplate
from ..codegen.rocm.ck_universal_gemm_template import CKGemmTemplate
from ..codegen.subgraph import SubgraphTemplate
from ..ir import FlexibleLayout, is_triton
import pandas as pd
from ..lowering import (
add_layout_constraint,
constrain_to_fx_strides,
@@ -63,6 +68,7 @@ from .mm_common import (
scale_mm_epilogue,
scaled_mm_options,
)
from ..template_heuristics import CUDAConfigHeuristic
try:
@@ -659,7 +665,147 @@ def decomposeK(a, b, k_splits):
result = torch.bmm(a_reshaped, b_reshaped, out_dtype=torch.float32)
reduced_buf = torch.sum(result, 0)
return reduced_buf.to(a.dtype)
def get_model():
fname = '/home/gabeferns/manifold/triton_h100_from_arm_108.pkl'
import sys
sys.path.append('/home/santorella/fbsource/fbcode/scripts/santorella/gemm_modeling')
import time
start_time = time.time()
model = NeuralNetwork(n_inputs=12, hidden_layer_widths=[2**8 for _ in range(6)])
model.load_state_dict(torch.load(fname))
model.to("cuda")
model.eval()
end_time = time.time()
print("model loaded!")
print(f"took: {end_time - start_time} seconds")
return model
class ModelWrapper:
def __init__(self):
self.model = get_model()
self.mean_for_standardization = torch.tensor([
2.78275084,
8.23996746,
7.27791873,
7.92035942,
-2.39558163,
3.40679233,
5.80237395,
3.95781827,
4.19478321,
4.19098234,
0.9045909,
1.28331208,
], device="cuda")
self.std_for_standardization = torch.tensor([
0.08322756,
2.31893439,
1.65605574,
2.15447078,
2.19682881,
2.99600806,
1.24328795,
0.92352521,
0.93849802,
0.93872011,
0.57455891,
0.5837217,
], device="cuda")
def vec(self, m:int, n:int, k:int, dsize:int, config) -> tuple[int, int, int, int, int, int, int, int, int]:
kwargs = config.all_kwargs()
ret = (
m,
n,
k,
dsize,
kwargs["BLOCK_M"],
kwargs["BLOCK_N"],
kwargs["BLOCK_K"],
kwargs["num_stages"],
kwargs["num_warps"]
)
# for v in ret:
# if type(v) not in [int, sympy.Expr]:
# breakpoint()
return (
int(m),
int(n),
int(k),
int(dsize),
int(kwargs["BLOCK_M"]),
int(kwargs["BLOCK_N"]),
int(kwargs["BLOCK_K"]),
int(kwargs["num_stages"]),
int(kwargs["num_warps"])
)
def encode(self, m: int, n: int, k: int, dtype: torch.dtype, configs) -> torch.Tensor:
# encodes the triton autotune config as the vector expected by the model
if dtype == torch.bfloat16 or dtype == torch.float16:
dsize = 16
elif dtype == torch.float32:
dsize = 32
else:
raise Exception("missing dtype in encode, add")
df = pd.DataFrame(
columns = [
"dim_m",
"dim_n",
"dim_k",
"dtype_size",
"config_block_m",
"config_block_n",
"config_block_k",
"config_num_stages",
"config_num_warps",
],
data = [self.vec(m, n, k, dsize, config) for config in configs]
)
df["total_gb"] = get_total_gb_feature(df=df).astype(np.float32)
df["total_gflop"] = get_total_gflop_feature(df=df).astype(np.float32)
df["flops_per_byte"] = df["total_gflop"] / df["total_gb"]
df = df[[
'dtype_size',
'dim_m',
'dim_n',
'dim_k',
'total_gb',
'total_gflop',
'flops_per_byte',
'config_block_k',
'config_block_m',
'config_block_n',
'config_num_stages',
'config_num_warps'
]]
inp, _, _ = get_nn_x(df=df, mean=self.mean_for_standardization, std=self.std_for_standardization)
return inp
def inference(self, inp_tensor: torch.Tensor) -> torch.Tensor:
from torch.export import Dim, export
inp_tensor = inp_tensor.to("cuda")
batch = Dim("batch")
example_args = (inp_tensor,)
with torch.no_grad():
exported_program: torch.export.ExportedProgram = export(
self.model, args=example_args, dynamic_shapes={"x": {0: batch}}
)
print(exported_program)
# torch._inductor.aoti_compile_and_package(
# exported_program,
# package_path="mm_model.pt2",
# )
aoti = torch._inductor.aoti_compile_and_package(exported_program, package_path="aoti_mm_model.pt2")
loaded = torch._inductor.aoti_load_package("aoti_mm_model.pt2")
print('reloaded')
breakpoint()
return self.model(inp_tensor)
def decode(self, ret_tensor: torch.Tensor) -> torch.Tensor:
# returns the runtime
# could just run exp in here
return ret_tensor
wrappedmodel = ModelWrapper()
@register_lowering(aten.mm, type_promotion_kind=None)
def tuned_mm(mat1, mat2, *, layout=None):
@@ -667,6 +813,7 @@ def tuned_mm(mat1, mat2, *, layout=None):
Lowering for autotuning aten.mm with different backends (Aten, Triton, CUTLASS, etc.)
"""
m, n, k, layout, mat1, mat2 = mm_args(mat1, mat2, layout=layout)
#m, n, k = 1024, 4096, 1024
device_type = ir.get_device_type(mat1)
name = "mm"
@@ -689,6 +836,7 @@ def tuned_mm(mat1, mat2, *, layout=None):
)
# options to tune from
print(f"using aten? {use_aten_gemm_kernels()}")
choices = (
[aten_mm.bind((mat1, mat2), aten_layout)] if use_aten_gemm_kernels() else []
)
@@ -698,13 +846,43 @@
persistent_mm_configs = V.choices.get_persistent_mm_configs(device_type)
extra_mm_configs = V.choices.get_extra_mm_configs(device_type)
import time
if is_nonzero and use_triton_template(layout):
for config in mm_configs(
m,
n,
k,
**mm_config_kwargs(device_type, _is_large_block_for_cpu),
import os
if (
os.environ.get("TORCHINDUCTOR_NEW_CONFIGS", "0") == "1"
or torch._inductor.config.new_configs
):
exhaustive_configs = CUDAConfigHeuristic().get_exhaustive_mm_configs()
config_list = list(exhaustive_configs(m, n, k))
start_time = time.time()
t: torch.Tensor = wrappedmodel.encode(m, n, k, mat1.get_dtype(), config_list)
end_time = time.time()
print(f"Encoding exhaustive configs took {end_time - start_time:.4f} seconds")
start_time = time.time()
res = torch.exp(wrappedmodel.inference(t))
end_time = time.time()
total_time = end_time - start_time
print(f"running inference on exhaustive configs took {total_time:.4f} seconds, {total_time / len(config_list):.4f} seconds per config")
timings = list(zip(res.flatten().tolist(), config_list))
timings.sort(key=lambda x: x[0])
def print_timings(entries):
for timing, config in entries:
kw = config.kwargs
print(f"{timing}, Config(M: {kw['BLOCK_M']}, N: {kw['BLOCK_N']}, K: {kw['BLOCK_K']}, num_stages: {config.num_stages}, num_warps: {config.num_warps})")
top20 = timings[:20]
print(f"Top 20 predicted configs on M:{m} K:{k} N:{n}: ")
print_timings(top20)
prelim_configs = [cfg for _, cfg in top20]
else:
print(f"Running original configs on M:{m} K:{k} N:{n}... ")
prelim_configs = mm_configs(
m,
n,
k,
**mm_config_kwargs(device_type, _is_large_block_for_cpu),
)
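# A minimal way to exercise the model-driven branch above (a sketch, assuming a CUDA
# device with Triton; the env var / config flag are the ones introduced in this change):
#     os.environ["TORCHINDUCTOR_NEW_CONFIGS"] = "1"  # or torch._inductor.config.new_configs = True
#     fn = torch.compile(lambda x, y: x @ y, mode="max-autotune-no-cudagraphs", dynamic=False)
#     fn(torch.randn(1024, 2048, device="cuda", dtype=torch.float16),
#        torch.randn(2048, 4096, device="cuda", dtype=torch.float16))
# tuned_mm then ranks the exhaustive configs with the NN model and autotunes only the top 20.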
for config in prelim_configs:
mm_template.maybe_append_choice(
choices,
input_nodes=(mat1, mat2),

View File

@@ -0,0 +1,532 @@
import json
import logging
import typing
from collections import OrderedDict
from dataclasses import asdict, dataclass, field, fields
from functools import lru_cache
from typing import Any, get_origin, Optional, TYPE_CHECKING, TypeVar, Union
if TYPE_CHECKING:
from triton import Config as TritonConfig
try:
from typing_extensions import Self
except ImportError:
from typing import Self
import torch
from torch.utils._ordered_set import OrderedSet
# Set up logging for kernel LUT
logger = logging.getLogger(__name__)
T = TypeVar("T", bound="JSONSerializable")
LeafType = Union[
None, bool, int, float, str, OrderedDict[str, Any], torch.dtype, list[Any]
]
JSONType = Union[T, LeafType]
@dataclass(kw_only=True)
class JSONSerializable:
"""
This class implements a system similar to Pydantic Models for validating and serializing dataclasses.
"""
# Incrementing the version invalidates all LUT entries, e.g. after a major perf
# update or a change to the ontology.
version: int = 1
_is_leaf: bool = False
@classmethod
def from_dict(cls, inp: OrderedDict[str, Any] | str) -> Self:
"""
Construct the object from its dictionary representation.
"""
try:
ret = OrderedDict()
if isinstance(inp, str):
if cls._is_leaf:
return cls.parse(inp)
else:
raise NotImplementedError(
f"String representation not implemented for base {cls.__name__}"
)
for k, v in inp.items():
v_type = cls.__dataclass_fields__[k].type
if get_origin(v_type) is OrderedDict:
k1_type, v1_type = typing.get_args(v_type)
if isinstance(k1_type, type) and issubclass(
k1_type, JSONSerializable
):
def kp(tmpk: Any) -> Any:
return k1_type.from_dict(tmpk)
k_process = kp
else:
def k_process(tmpk: Any) -> Any:
return tmpk
if isinstance(v1_type, type) and issubclass(
v1_type, JSONSerializable
):
def vp(tmpv: Any) -> Any:
return v1_type.from_dict(tmpv)
v_process = vp
else:
def v_process(tmpv: Any) -> Any:
return tmpv
v_new: Any = OrderedDict(
(k_process(key), v_process(val)) for key, val in v.items()
)
elif get_origin(v_type) is list:
elem_type = typing.get_args(v_type)[0]
if isinstance(elem_type, type) and issubclass(
elem_type, JSONSerializable
):
v_new = [elem_type.from_dict(x) for x in v]
else:
v_new = v
elif isinstance(v_type, type) and issubclass(v_type, JSONSerializable):
v_new = v_type.from_dict(v)
else:
v_new = v
ret[k] = v_new
return cls(**ret) # type: ignore[arg-type]
except Exception as e:
logger.error("Failed to deserialize %s from dict: %s", cls.__name__, e)
raise ValueError(f"Malformed data for {cls.__name__}: {e}") from e
def to_dict(self) -> OrderedDict[str, Any]:
"""
Convert the object to a dictionary representation.
The result round-trips through json.dumps and json.loads.
"""
# get the fields of the dataclass
field_list = fields(self)
# filter out the _ fields
field_list = [field for field in field_list if not field.name.startswith("_")]
# ensure the fields are sorted for consistent serialization
field_list.sort(key=lambda x: x.name)
ret: OrderedDict[str, Any] = OrderedDict()
for field_obj in field_list:
field_val = getattr(self, field_obj.name)
if isinstance(field_val, JSONSerializable):
if field_val._is_leaf:
ret[field_obj.name] = str(field_val)
else:
ret[field_obj.name] = field_val.to_dict()
elif isinstance(field_val, list):
if len(field_val) == 0:
ret[field_obj.name] = []
elif isinstance(field_val[0], JSONSerializable):
if field_val[0]._is_leaf:
ret[field_obj.name] = [str(x) for x in field_val]
else:
ret[field_obj.name] = [x.to_dict() for x in field_val]
else:
ret[field_obj.name] = field_val
elif isinstance(field_val, OrderedDict):
tmp: OrderedDict[Any, Any] = OrderedDict()
for k, v in field_val.items():
if isinstance(v, JSONSerializable):
if v._is_leaf:
new_v: Any = str(v)
else:
new_v = v.to_dict()
else:
new_v = v
if isinstance(k, JSONSerializable):
if k._is_leaf:
new_k: Any = str(k)
else:
new_k = k.to_dict()
else:
new_k = k
tmp[new_k] = new_v
ret[field_obj.name] = tmp
else:
ret[field_obj.name] = field_val
return ret
def __str__(self) -> str:
"""
Return a string representation of the object.
"""
return json.dumps(self.to_dict())
@classmethod
def parse(cls, string: str) -> Self:
"""
Parse the string representation of the object. Only required for leaf nodes.
"""
raise NotImplementedError(
f"String representation not implemented for base {cls.__name__}"
)
@dataclass(kw_only=True)
class TritonGEMMConfig(JSONSerializable):
_is_leaf: bool = True
name: str
grid: int
block_m: int
block_n: int
block_k: int
group_m: int
num_stages: int
num_warps: int
EVEN_K: bool = False
ALLOW_TF32: bool = False
USE_FAST_ACCUM: bool = False
ACC_TYPE: str = "tl.float32"
def __hash__(self) -> int:
return hash(
(
self.name,
self.grid,
self.block_m,
self.block_n,
self.block_k,
self.group_m,
self.num_stages,
self.num_warps,
self.EVEN_K,
self.ALLOW_TF32,
self.USE_FAST_ACCUM,
self.ACC_TYPE,
)
)
@classmethod
def parse(cls, string: str) -> Self:
d = json.loads(string, object_pairs_hook=OrderedDict)
# validate types, yay python :P
if "name" not in d:
raise KeyError("Missing required field: name")
if not isinstance(d["name"], str):
raise TypeError(f"name must be a string, got {type(d['name'])}")
if "grid" not in d:
raise KeyError("Missing required field: grid")
if not isinstance(d["grid"], int):
raise TypeError(f"grid must be an int, got {type(d['grid'])}")
if "block_m" not in d:
raise KeyError("Missing required field: block_m")
if not isinstance(d["block_m"], int):
raise TypeError(f"block_m must be an int, got {type(d['block_m'])}")
if "block_n" not in d:
raise KeyError("Missing required field: block_n")
if not isinstance(d["block_n"], int):
raise TypeError(f"block_n must be an int, got {type(d['block_n'])}")
if "block_k" not in d:
raise KeyError("Missing required field: block_k")
if not isinstance(d["block_k"], int):
raise TypeError(f"block_k must be an int, got {type(d['block_k'])}")
if "group_m" not in d:
raise KeyError("Missing required field: group_m")
if not isinstance(d["group_m"], int):
raise TypeError(f"group_m must be an int, got {type(d['group_m'])}")
if "num_stages" not in d:
raise KeyError("Missing required field: num_stages")
if not isinstance(d["num_stages"], int):
raise TypeError(f"num_stages must be an int, got {type(d['num_stages'])}")
if "num_warps" not in d:
raise KeyError("Missing required field: num_warps")
if not isinstance(d["num_warps"], int):
raise TypeError(f"num_warps must be an int, got {type(d['num_warps'])}")
if "EVEN_K" in d and not isinstance(d["EVEN_K"], bool):
raise TypeError(f"EVEN_K must be a bool, got {type(d['EVEN_K'])}")
if "ALLOW_TF32" in d and not isinstance(d["ALLOW_TF32"], bool):
raise TypeError(f"ALLOW_TF32 must be a bool, got {type(d['ALLOW_TF32'])}")
if "USE_FAST_ACCUM" in d and not isinstance(d["USE_FAST_ACCUM"], bool):
raise TypeError(
f"USE_FAST_ACCUM must be a bool, got {type(d['USE_FAST_ACCUM'])}"
)
if "ACC_TYPE" in d and not isinstance(d["ACC_TYPE"], str):
raise TypeError(f"ACC_TYPE must be a string, got {type(d['ACC_TYPE'])}")
return cls(**d)
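# Illustrative round trip for the leaf node above (field values are arbitrary):
#     cfg = TritonGEMMConfig(name="mm_128x128x64", grid=1, block_m=128, block_n=128,
#                            block_k=64, group_m=8, num_stages=3, num_warps=8)
#     assert TritonGEMMConfig.parse(str(cfg)) == cfg  # str() emits the JSON form parse() consumes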
@dataclass(kw_only=True)
class MMProblem(JSONSerializable):
_is_leaf: bool = True
B: int
M: int
M_dtype: torch.dtype
N: int
K: int
K_dtype: torch.dtype
out_dtype: torch.dtype
out_size: tuple[int, int, int]
out_stride: tuple[int, int, int]
def __hash__(self) -> int:
return hash(
(
self.B,
self.M,
self.M_dtype,
self.N,
self.K_dtype,
self.K,
self.out_dtype,
self.out_size,
self.out_stride,
)
)
def __str__(self) -> str:
"""
Return a string representation of the object.
"""
d = asdict(self)
d["M_dtype"] = str(d["M_dtype"]).split(".")[-1]
d["K_dtype"] = str(d["K_dtype"]).split(".")[-1]
d["out_dtype"] = str(d["out_dtype"]).split(".")[-1]
d["out_size"] = list(d["out_size"])
d["out_stride"] = list(d["out_stride"])
d = OrderedDict((k, v) for k, v in d.items() if not k.startswith("_"))
return json.dumps(d)
@classmethod
def parse(cls, string: str) -> Self:
d = json.loads(string, object_pairs_hook=OrderedDict)
# validate types, yay python :P
if "B" not in d:
raise KeyError("Missing required field: B")
if not isinstance(d["B"], int):
raise TypeError(f"B must be an int, got {type(d['B'])}")
if "M" not in d:
raise KeyError("Missing required field: M")
if not isinstance(d["M"], int):
raise TypeError(f"M must be an int, got {type(d['M'])}")
if "N" not in d:
raise KeyError("Missing required field: N")
if not isinstance(d["N"], int):
raise TypeError(f"N must be an int, got {type(d['N'])}")
if "K" not in d:
raise KeyError("Missing required field: K")
if not isinstance(d["K"], int):
raise TypeError(f"K must be an int, got {type(d['K'])}")
if "M_dtype" not in d:
raise KeyError("Missing required field: M_dtype")
if not isinstance(d["M_dtype"], str):
raise TypeError(f"M_dtype must be a string, got {type(d['M_dtype'])}")
if "K_dtype" not in d:
raise KeyError("Missing required field: K_dtype")
if not isinstance(d["K_dtype"], str):
raise TypeError(f"K_dtype must be a string, got {type(d['K_dtype'])}")
if "out_dtype" not in d:
raise KeyError("Missing required field: out_dtype")
if not isinstance(d["out_dtype"], str):
raise TypeError(f"out_dtype must be a string, got {type(d['out_dtype'])}")
if "out_size" not in d:
raise KeyError("Missing required field: out_size")
if not isinstance(d["out_size"], list):
raise TypeError(f"out_size must be a list, got {type(d['out_size'])}")
if "out_stride" not in d:
raise KeyError("Missing required field: out_stride")
if not isinstance(d["out_stride"], list):
raise TypeError(f"out_stride must be a list, got {type(d['out_stride'])}")
# Validate torch dtype strings
try:
d["M_dtype"] = getattr(torch, d["M_dtype"])
except AttributeError:
raise ValueError(f"Invalid torch dtype: {d['M_dtype']}") from None
try:
d["K_dtype"] = getattr(torch, d["K_dtype"])
except AttributeError:
raise ValueError(f"Invalid torch dtype: {d['K_dtype']}") from None
try:
d["out_dtype"] = getattr(torch, d["out_dtype"])
except AttributeError:
raise ValueError(f"Invalid torch dtype: {d['out_dtype']}") from None
d["out_size"] = tuple(d["out_size"])
d["out_stride"] = tuple(d["out_stride"])
return cls(**d)
@dataclass(kw_only=True)
class Solution(JSONSerializable):
# like mm or addmm
name: str
# mapping
config: list[TritonGEMMConfig]
@dataclass(kw_only=True)
class Operation(JSONSerializable):
name: str
solution: OrderedDict[MMProblem, Solution]
@dataclass(kw_only=True)
class Hardware(JSONSerializable):
# like gfx942:sramecc+:xnack-
operation: OrderedDict[str, Operation]
@dataclass(kw_only=True)
class Table(JSONSerializable):
hardware: OrderedDict[str, Hardware]
_set_cache: OrderedDict[
tuple[str, str, MMProblem], OrderedSet[TritonGEMMConfig]
] = field(default_factory=OrderedDict)
def serialize(self) -> str:
foo = self.to_dict()
return json.dumps(foo, indent=2)
@classmethod
def deserialize(cls, s: str) -> Optional[Self]:
try:
return cls.from_dict(json.loads(s, object_pairs_hook=OrderedDict))
except (json.JSONDecodeError, TypeError, ValueError) as e:
logger.error("Failed to deserialize table: %s", e)
return None
def lookup(
self, hardware: str, op_name: str, problem: MMProblem
) -> Optional[list[TritonGEMMConfig]]:
"""
Look up the list of best TritonGEMMConfigs for a given problem.
"""
if hardware not in self.hardware:
return None
tmp = self.hardware[hardware].operation
if op_name not in tmp:
return None
tmp = tmp[op_name].solution
if problem not in tmp:
return None
return tmp[problem].config
def lookup_set(
self, hardware: str, op_name: str, problem: MMProblem
) -> Optional[OrderedSet[TritonGEMMConfig]]:
"""
Membership checks are faster on a set, so cache one set per (hardware, op_name, problem) key for runtime use.
"""
if (hardware, op_name, problem) in self._set_cache:
return self._set_cache[(hardware, op_name, problem)]
problem_list = self.lookup(hardware, op_name, problem)
problem_set = OrderedSet(problem_list) if problem_list is not None else None
if problem_set is None:
return None
self._set_cache[(hardware, op_name, problem)] = problem_set
return problem_set
def filter(
self,
hardware: str,
op_name: str,
problem: MMProblem,
to_filter: list[TritonGEMMConfig],
) -> Optional[list[TritonGEMMConfig]]:
"""
Filter a list of TritonGEMMConfig for a given problem.
"""
problem_set = self.lookup_set(hardware, op_name, problem)
if problem_set is None:
return None
ret = [x for x in to_filter if x in problem_set]
if len(ret) == 0:
return None
return ret
def convert_triton_configs_to_gemm_configs(
triton_configs: list["TritonConfig"], name_prefix: str = "triton_config"
) -> list[TritonGEMMConfig]:
"""
Convert a list of triton.runtime.autotuner.Config objects to TritonGEMMConfig objects.
Args:
triton_configs: List of triton.runtime.autotuner.Config objects
name_prefix: Prefix for generated config names (default: "triton_config")
Returns:
List of TritonGEMMConfig objects
"""
gemm_configs = []
for i, config in enumerate(triton_configs):
# Extract kwargs which contain the block sizes
kwargs = getattr(config, "kwargs", {})
# Handle case where kwargs is None
if kwargs is None:
kwargs = {}
# Extract required parameters from kwargs
block_m = kwargs.get("BLOCK_M", 64) # Default fallback values
block_n = kwargs.get("BLOCK_N", 64)
block_k = kwargs.get("BLOCK_K", 32)
group_m = kwargs.get("GROUP_M", 8)
# Extract other parameters directly from config object
num_stages = getattr(config, "num_stages", 2)
num_warps = getattr(config, "num_warps", 4)
# Extract optional parameters with defaults
even_k = kwargs.get("EVEN_K", False)
allow_tf32 = kwargs.get("ALLOW_TF32", False)
use_fast_accum = kwargs.get("USE_FAST_ACCUM", False)
acc_type = kwargs.get("ACC_TYPE", "tl.float32")
# Generate a unique name for this config
config_name = f"{name_prefix}_{i}"
# Create TritonGEMMConfig object
gemm_config = TritonGEMMConfig(
name=config_name,
grid=1, # Default grid value, can be adjusted based on requirements
block_m=block_m,
block_n=block_n,
block_k=block_k,
group_m=group_m,
num_stages=num_stages,
num_warps=num_warps,
EVEN_K=even_k,
ALLOW_TF32=allow_tf32,
USE_FAST_ACCUM=use_fast_accum,
ACC_TYPE=acc_type,
)
gemm_configs.append(gemm_config)
return gemm_configs
@lru_cache
def get_table(path: str) -> Optional[Table]:
"""Load a table from a file path."""
try:
with open(path) as f:
return Table.deserialize(f.read())
except OSError as e:
logger.error("Failed to read table from %s: %s", path, e)
return None
def get_table_safe(path: str) -> Optional[Table]:
"""Safely load a table from a file path without caching."""
try:
with open(path) as f:
return Table.deserialize(f.read())
except OSError as e:
logger.error("Failed to read table from %s: %s", path, e)
return None
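# Illustrative lookup flow for the table above (the path and hardware key are made up):
#     table = get_table("/path/to/gemm_lut.json")  # cached; get_table_safe skips the lru_cache
#     if table is not None:
#         best = table.lookup("h100", "mm", problem)                       # list[TritonGEMMConfig] or None
#         pruned = table.filter("h100", "mm", problem, candidate_configs)  # intersect with the LUT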

View File

@@ -0,0 +1,383 @@
"""
Neural network model for predicting triton kernel performance.
This module provides functionality to load and use a pre-trained neural network
for predicting the performance of triton kernels.
"""
import copy
import os
import time
from collections.abc import Sequence
from typing import Any
import numpy as np
import pandas as pd # type: ignore[import-untyped]
from pyre_extensions import assert_is_instance # type: ignore[import-untyped]
import torch
import torch.nn as nn
from torch._inductor.kernel_lut import TritonGEMMConfig
from torch.optim.lr_scheduler import StepLR
# Default model path - can be overridden by environment variable
script_dir = os.path.dirname(__file__)
DEFAULT_MODEL_PATH = os.path.join(script_dir, "mm_model.pt2")
MODEL_PATH = os.environ.get("TRITON_KERNEL_SELECTION_MODEL_PATH", DEFAULT_MODEL_PATH)
import logging
log = logging.getLogger(__name__)
class NeuralNetwork(nn.Module):
"""
Multilayer perceptron with a single output.
It is designed for modeling runtime when there is a constant overhead of
`kernel_overhead` and the non-overhead runtime tends to be easier to model
on a log scale (e.g. doubling a dimension involved in a matrix
multiplication results in runtime roughly doubling.)
"""
def __init__(
self,
n_inputs: int,
hidden_layer_widths: Sequence[int],
kernel_overhead: float = 0.00541,
) -> None:
"""
Args:
n_inputs: Number of inputs
hidden_layer_widths: Hidden layer widths
kernel_overhead: Overhead of the kernel, assumed to be constant. The
default of 0.00541 is the lowest runtime seen in Triton H100 data.
"""
super().__init__()
self.n_inputs = n_inputs
self.kernel_overhead = kernel_overhead
self.log_kernel_overhead: float = torch.log(
torch.tensor(kernel_overhead)
).item()
all_layer_widths = list(hidden_layer_widths) + [1]
all_input_widths = [n_inputs] + list(hidden_layer_widths)
layers: list[nn.Module] = []
for n_in, n_out in zip(all_input_widths, all_layer_widths, strict=True):
layers.append(nn.Linear(n_in, n_out))
layers.append(nn.BatchNorm1d(n_out))
layers.append(nn.ReLU())
self.linear_relu_stack = nn.Sequential(*layers[:-2])
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Predict as log(exp(inputs) + self.kernel_overhead).
Works well for predicting log(runtime) when runtime contains a constant
overhead of `kernel_overhead`. (The log specification means that this
wouldn't be trivially modeled with a bias term.)
Probably could have fit the overhead rather than hard-coding it by
having `self.kernel_overhead` be a tunable parameter or by having exp
and log layers.
"""
log_base_pred = self.linear_relu_stack(x)
log_overhead_tsr = torch.full_like(
input=log_base_pred, fill_value=self.log_kernel_overhead
)
return torch.logsumexp(
torch.stack([log_base_pred, log_overhead_tsr], dim=-1), dim=-1
)
def get_nn_x(
df: pd.DataFrame, mean: torch.Tensor | None = None, std: torch.Tensor | None = None
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Standardize the data and convert it to a tensor."""
x_df = df[
[
"dtype_size",
"dim_m",
"dim_n",
"dim_k",
"total_gb",
"total_gflop",
"flops_per_byte",
"config_block_k",
"config_block_m",
"config_block_n",
"config_num_stages",
"config_num_warps",
]
].copy()
for col in x_df.columns:
x_df[col] = np.log(x_df[col])
x_tens = torch.from_numpy(x_df.astype(float).to_numpy())
if mean is None:
mean = torch.from_numpy(assert_is_instance(x_df.mean(), pd.Series).to_numpy())
if std is None:
std = torch.from_numpy(assert_is_instance(x_df.std(), pd.Series).to_numpy())
x_tens -= mean
x_tens /= std
return x_tens.to(torch.float32), mean, std
def get_total_gb_feature(df: pd.DataFrame) -> pd.Series:
"""
Calculate the total gigabytes feature from the dataframe.
Args:
df: DataFrame containing the necessary columns for calculation
Returns:
Series containing the calculated total gigabytes
"""
# Calculate memory access in bytes
m, n, k = df["dim_m"], df["dim_n"], df["dim_k"]
dtype_size = df["dtype_size"] / 8 # Convert bits to bytes
# A: m×k, B: k×n, C: m×n
return ((m * k + k * n + m * n) * dtype_size) / 1e9 # Convert to GB
def get_total_gflop_feature(df: pd.DataFrame) -> pd.Series:
"""
Calculate the total gigaflops feature from the dataframe.
Args:
df: DataFrame containing the necessary columns for calculation
Returns:
Series containing the calculated total gigaflops
"""
# For matrix multiplication, flops = 2 * m * n * k
m, n, k = df["dim_m"], df["dim_n"], df["dim_k"]
return (2 * m * n * k) / 1e9 # Convert to GFLOP
class ModelWrapper:
"""
Wrapper for the neural network model that handles encoding inputs and decoding outputs.
This class provides methods to prepare inputs for the model and interpret its outputs,
handling the necessary standardization and feature engineering.
"""
def __init__(self) -> None:
"""Initialize the model wrapper with the pre-trained model and standardization parameters."""
start_time = time.time()
self.model = NeuralNetwork(
n_inputs=12, hidden_layer_widths=[2**8 for _ in range(6)]
)
self.model = torch.export.load(MODEL_PATH)
end_time = time.time()
log.info("NN Kernel Prediction Model loaded.")
log.info("Took: %s seconds", end_time - start_time)
# Mean values for standardizing input features
self.mean_for_standardization = torch.tensor(
[
2.78275084,
8.23996746,
7.27791873,
7.92035942,
-2.39558163,
3.40679233,
5.80237395,
3.95781827,
4.19478321,
4.19098234,
0.9045909,
1.28331208,
]
)
# Standard deviation values for standardizing input features
self.std_for_standardization = torch.tensor(
[
0.08322756,
2.31893439,
1.65605574,
2.15447078,
2.19682881,
2.99600806,
1.24328795,
0.92352521,
0.93849802,
0.93872011,
0.57455891,
0.5837217,
]
)
def vec(
self, m: int, n: int, k: int, dsize: int, config: Any
) -> tuple[int, int, int, int, int, int, int, int, int]:
"""
Convert matrix multiplication parameters and config to a feature vector.
Args:
m: First dimension of matrix multiplication
n: Second dimension of matrix multiplication
k: Third dimension of matrix multiplication
dsize: Data size in bits (e.g., 16 for float16, 32 for float32)
config: Configuration object containing kernel parameters
Returns:
Tuple containing the extracted features
"""
kwargs = config.all_kwargs()
return (
int(m),
int(n),
int(k),
int(dsize),
int(kwargs["BLOCK_M"]),
int(kwargs["BLOCK_N"]),
int(kwargs["BLOCK_K"]),
int(kwargs["num_stages"]),
int(kwargs["num_warps"]),
)
@staticmethod
def vec_params(
m: int, n: int, k: int, dsize: int, params: TritonGEMMConfig
) -> tuple[int, int, int, int, int, int, int, int, int]:
"""
Convert matrix multiplication parameters and config to a feature vector.
Args:
m: First dimension of matrix multiplication
n: Second dimension of matrix multiplication
k: Third dimension of matrix multiplication
dsize: Data size in bits (e.g., 16 for float16, 32 for float32)
config: Configuration object containing kernel parameters
Returns:
Tuple containing the extracted features
"""
return (
int(m),
int(n),
int(k),
int(dsize),
int(params.block_m),
int(params.block_n),
int(params.block_k),
int(params.num_stages),
int(params.num_warps),
)
def encode(
self, m: int, n: int, k: int, dtype: torch.dtype, configs: list[Any]
) -> torch.Tensor:
"""
Encode the matrix multiplication parameters and configs as input tensors for the model.
Args:
m: First dimension of matrix multiplication
n: Second dimension of matrix multiplication
k: Third dimension of matrix multiplication
dtype: Data type of the matrices
configs: List of configuration objects
Returns:
Tensor containing the encoded inputs ready for the model
Raises:
ValueError: If the dtype is not supported
"""
# Determine data size based on dtype
if dtype == torch.bfloat16 or dtype == torch.float16:
dsize = 16
elif dtype == torch.float32:
dsize = 32
else:
raise ValueError(f"Unsupported dtype: {dtype}. Add support for this dtype.")
# Create feature dataframe
df = pd.DataFrame(
columns=[
"dim_m",
"dim_n",
"dim_k",
"dtype_size",
"config_block_m",
"config_block_n",
"config_block_k",
"config_num_stages",
"config_num_warps",
],
data=[self.vec(m, n, k, dsize, config) for config in configs],
)
# Calculate derived features
df["total_gb"] = get_total_gb_feature(df=df).astype(np.float32)
df["total_gflop"] = get_total_gflop_feature(df=df).astype(np.float32)
df["flops_per_byte"] = df["total_gflop"] / df["total_gb"]
# Reorder columns to match expected model input
df = df[
[
"dtype_size",
"dim_m",
"dim_n",
"dim_k",
"total_gb",
"total_gflop",
"flops_per_byte",
"config_block_k",
"config_block_m",
"config_block_n",
"config_num_stages",
"config_num_warps",
]
]
# Standardize the input
inp, _, _ = get_nn_x(
df=df, mean=self.mean_for_standardization, std=self.std_for_standardization
)
return inp
def inference(self, inp_tensor: torch.Tensor) -> torch.Tensor:
"""
Run inference on the model with the given input tensor.
Args:
inp_tensor: Input tensor for the model
Returns:
Output tensor from the model
"""
with torch.no_grad():
breakpoint()
return self.model.forward(inp_tensor)
def decode(self, ret_tensor: torch.Tensor) -> torch.Tensor:
"""
Decode the model output tensor.
Args:
ret_tensor: Output tensor from the model
Returns:
Decoded tensor representing runtime predictions
"""
return ret_tensor
# Create a singleton instance of the model wrapper
import functools
@functools.lru_cache
def get_model() -> ModelWrapper:
return ModelWrapper()

View File

@@ -3044,6 +3044,9 @@ class AlgorithmSelectorCache(PersistentCache):
def add_feedback_saver(self, fn: FeedbackFunction):
self.feedback_saver_fns.append(fn)
def clear_feedback_savers(self):
self.feedback_saver_fns = []
_ALGORITHM_SELECTOR_CACHE: Optional[AlgorithmSelectorCache] = None
@@ -3071,6 +3074,11 @@ def add_feedback_saver(
if _ALGORITHM_SELECTOR_CACHE is None:
_ALGORITHM_SELECTOR_CACHE = AlgorithmSelectorCache()
_ALGORITHM_SELECTOR_CACHE.add_feedback_saver(fn)
def clear_feedback_savers():
global _ALGORITHM_SELECTOR_CACHE
if _ALGORITHM_SELECTOR_CACHE is None:
_ALGORITHM_SELECTOR_CACHE = AlgorithmSelectorCache()
_ALGORITHM_SELECTOR_CACHE.clear_feedback_savers()
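# A minimal sketch of how the two hooks above pair up (mirrors the usage in
# benchmarks/dynamo/shapes.py from this change; the saver body is illustrative):
#     def saver(timings, name, input_nodes, choices, profiled_time):
#         ...  # record the autotune timings per choice
#     add_feedback_saver(saver)
#     ...  # compile and run workloads under max-autotune
#     clear_feedback_savers()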
def realize_inputs(*args):

View File

@@ -19,7 +19,7 @@ if TYPE_CHECKING:
from triton import Config as TritonConfig
@dataclasses.dataclass
@dataclasses.dataclass(frozen=True)
class BaseConfig:
"""
Base Gemm configuration used for most backends (CPU, CUDA)
@@ -32,7 +32,7 @@ class BaseConfig:
num_warps: int
@dataclasses.dataclass
@dataclasses.dataclass(frozen=True)
class GemmConfig(BaseConfig):
"""
Gemm configuration used for most backends (CPU, CUDA)
@@ -44,7 +44,7 @@ class GemmConfig(BaseConfig):
ConvConfig = BaseConfig
@dataclasses.dataclass
@dataclasses.dataclass(frozen=True)
class ROCmGemmConfig(GemmConfig):
"""
ROCm subclass for GEMMs, with AMD backend specific tuneable kernargs
@@ -55,7 +55,7 @@ class ROCmGemmConfig(GemmConfig):
kpack: int = 2
@dataclasses.dataclass
@dataclasses.dataclass(frozen=True)
class ROCmConvConfig(ConvConfig):
"""
ROCm subclass for Conv, with AMD backend specific tuneable kernargs
@@ -85,16 +85,92 @@ class BaseHeuristicSingleton(type):
return cls._instances[cls]
import os
class BaseConfigHeuristic(metaclass=BaseHeuristicSingleton):
"""
Base class for mm_configs, device specific triton kernels config inherit from here
"""
def __init__(self) -> None:
# List of dictionaries to store the kernel configs. Configs that evaluate to true
# will be utilised on the target platform. The configs are as follows:
# (BLOCK_M, BLOCK_N, BLOCK_K, num_stages, num_warps)
self.mm_configs: list[BaseConfig] = [
@property
def mm_configs(self) -> list[BaseConfig]:
# if (
# os.environ.get("TORCHINDUCTOR_NEW_CONFIGS", "0") == "1"
# or config.new_configs
# ):
# return [
# GemmConfig(128, 16, 128, 5, 8), # new
# GemmConfig(64, 16, 128, 4, 4), #new
# GemmConfig(64, 128, 128, 4, 8),
# GemmConfig(128, 32, 32, 4, 4),
# GemmConfig(64, 128, 64, 4, 4), # new
# GemmConfig(64, 128, 128, 3, 4), # new
# GemmConfig(128, 16, 128, 5, 8), # new
# # GemmConfig(64, 128, 256, 4, 8),
# GemmConfig(64, 64, 128, 5, 4),
# GemmConfig(128, 32, 16, 5, 2),
# GemmConfig(32, 128, 128, 4, 4),
# GemmConfig(64, 64, 32, 5, 4),
# GemmConfig(64, 128, 128, 3, 8),
# GemmConfig(32, 64, 64, 3, 4),
# # GemmConfig(64, 128, 256, 3, 8),
# GemmConfig(64, 16, 32, 4, 2),
# GemmConfig(128, 64, 64, 5, 4),
# # GemmConfig(32, 128, 256, 4, 8),
# GemmConfig(128, 32, 32, 5, 4),
# GemmConfig(256, 32, 16, 3, 2),
# GemmConfig(16, 128, 64, 1, 4),
# GemmConfig(256, 64, 16, 1, 4),
# GemmConfig(64, 128, 64, 4, 4),
# ]
# # return [
# # GemmConfig(64, 128, 128, 4, 8),
# # GemmConfig(128, 32, 32, 4, 4),
# # GemmConfig(64, 128, 256, 4, 8),
# # GemmConfig(64, 64, 128, 5, 4),
# # GemmConfig(128, 32, 16, 5, 2),
# # #GemmConfig(32, 256, 128, 4, 8),
# # GemmConfig(32, 128, 128, 4, 4),
# # GemmConfig(64, 64, 32, 5, 4),
# # GemmConfig(64, 128, 128, 3, 8),
# # GemmConfig(32, 64, 64, 3, 4),
# # GemmConfig(64, 128, 256, 3, 8),
# # GemmConfig(64, 16, 32, 4, 2),
# # GemmConfig(128, 64, 64, 5, 4),
# # GemmConfig(32, 128, 256, 4, 8),
# # GemmConfig(128, 32, 32, 5, 4),
# # # GemmConfig(64, 128, 128, 5, 4),
# # GemmConfig(256, 32, 16, 3, 2),
# # GemmConfig(16, 128, 64, 1, 4),
# # GemmConfig(256, 64, 16, 1, 4),
# # GemmConfig(64, 128, 64, 4, 4),
# # ]
# # return [
# # # GemmConfig(16, 16, 128, 5, 1),
# # # GemmConfig(16, 16, 256, 4, 1),
# # GemmConfig(64, 16, 128, 4, 4),
# # GemmConfig(64, 16, 256, 4, 4),
# # GemmConfig(64, 32, 128, 4, 4),
# # GemmConfig(64, 32, 128, 5, 8),
# # # GemmConfig(63, 32, 256, 1, 8),
# # GemmConfig(64, 64, 128, 4, 4),
# # GemmConfig(64, 128, 64, 4, 4),
# # GemmConfig(64, 128, 128, 3, 4),
# # GemmConfig(128, 16, 128, 5, 8),
# # GemmConfig(128, 128, 32, 5, 8),
# # GemmConfig(128, 128, 64, 3, 4),
# # GemmConfig(128, 128, 64, 3, 8),
# # GemmConfig(128, 128, 64, 4, 4),
# # GemmConfig(128, 128, 64, 4, 8),
# # GemmConfig(128, 256, 32, 5, 8),
# # GemmConfig(128, 256, 64, 3, 8),
# # GemmConfig(128, 256, 64, 4, 8),
# # #GemmConfig(128, 256, 64, 5, 8),
# # GemmConfig(256, 128, 32, 5, 8)
# # ]
# else:
return [
GemmConfig(32, 32, 16, 1, 2),
GemmConfig(32, 32, 128, 2, 4),
GemmConfig(32, 64, 32, 5, 8),
@@ -115,6 +191,35 @@ class BaseConfigHeuristic(metaclass=BaseHeuristicSingleton):
GemmConfig(128, 128, 64, 3, 4),
GemmConfig(128, 128, 64, 5, 8),
]
# TODO more results on compile time overhead
# TODO compare to exhaustive results
# TODO ask about cudagraph benchmarking
def __init__(self) -> None:
# List of dictionaries to store the kernel configs. Configs that evaluate to true
# will be utilised on the target platform. The configs are as follows:
# (BLOCK_M, BLOCK_N, BLOCK_K, num_stages, num_warps)
# self.mm_configs: list[BaseConfig] = [
# GemmConfig(32, 32, 16, 1, 2),
# GemmConfig(32, 32, 128, 2, 4),
# GemmConfig(32, 64, 32, 5, 8),
# GemmConfig(64, 32, 32, 5, 8),
# GemmConfig(64, 32, 128, 5, 4),
# GemmConfig(64, 64, 16, 2, 4),
# GemmConfig(64, 64, 32, 2, 4),
# GemmConfig(64, 64, 64, 3, 8),
# GemmConfig(64, 64, 128, 5, 4),
# GemmConfig(64, 128, 32, 3, 4),
# GemmConfig(64, 128, 32, 4, 8),
# GemmConfig(64, 128, 64, 3, 4),
# GemmConfig(64, 128, 128, 4, 4),
# GemmConfig(128, 64, 32, 3, 4),
# GemmConfig(128, 64, 32, 4, 8),
# GemmConfig(128, 128, 32, 2, 8),
# GemmConfig(128, 128, 32, 3, 4),
# GemmConfig(128, 128, 64, 3, 4),
# GemmConfig(128, 128, 64, 5, 8),
# ]
# Exhaustive search for mm configs
self.exhaustive_configs: list[BaseConfig] = [