mirror of
https://github.com/pytorch/pytorch.git
synced 2025-11-02 14:34:54 +08:00
Compare commits
9 Commits
ciflow/tru
...
exclamafor
| Author | SHA1 | Date |
|---|---|---|
| | e428407c84 | |
| | 28b74d81a0 | |
| | 11d45cecf3 | |
| | 90031acf23 | |
| | e445a5a943 | |
| | 19446879c3 | |
| | 20b3e08a1b | |
| | cbcf7c26ee | |
| | 7b5ea8b998 | |
benchmarks/dynamo/shapes.py (new file, 131 lines)
@@ -0,0 +1,131 @@
import csv

import torch
from torch._inductor import config
from torch._inductor.select_algorithm import add_feedback_saver, clear_feedback_savers
from torch.utils._ordered_set import OrderedSet
from triton.testing import do_bench

from microbenchmarks.operator_inp_utils import OperatorInputsLoader

aten = torch.ops.aten
loader = OperatorInputsLoader.get_huggingface_loader()

torch.set_grad_enabled(False)
config.fx_graph_cache = False
config.force_disable_caches = True


def zip_dicts(dict1, dict2, d1_default, d2_default):
    """
    Zip two dictionaries together, replacing missing keys with default values.

    Args:
        dict1 (dict): The first dictionary.
        dict2 (dict): The second dictionary.
        d1_default (Any): The default value for the first dictionary.
        d2_default (Any): The default value for the second dictionary.

    Yields:
        tuple: A tuple containing the key, the value from dict1 (or d1_default if missing),
        and the value from dict2 (or d2_default if missing).
    """
    # Find the union of all keys
    all_keys = OrderedSet(dict1.keys()) | OrderedSet(dict2.keys())

    # Iterate over all keys
    for key in all_keys:
        # Get the values from both dictionaries, or default if missing
        value1 = dict1.get(key)
        value2 = dict2.get(key)

        yield (
            key,
            value1 if value1 is not None else d1_default,
            value2 if value2 is not None else d2_default,
        )
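For illustration, here is how `zip_dicts` behaves on two small config-timing maps (the values are made up):

```python
old_times = {"cfg_a": 1.2, "cfg_b": 0.9}
new_times = {"cfg_b": 0.8, "cfg_c": 1.5}
# Keys are unioned; a missing entry falls back to the corresponding default.
print(list(zip_dicts(old_times, new_times, d1_default=None, d2_default=None)))
# [('cfg_a', 1.2, None), ('cfg_b', 0.9, 0.8), ('cfg_c', None, 1.5)]
```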


def compare_op():
    for op in [aten.mm.default, aten.addmm.default]:
        for dtype in [torch.bfloat16, torch.float16]:
            with open(f"{op}_{dtype}_benchmark_results.csv", "w", newline='') as file2:
                file2.write("M,K,N,BLOCK_M,BLOCK_K,BLOCK_N,NUM_STAGES,NUM_WARPS,GROUP_M,do_bench_time\n")

            with open(f"new_old_config_compare_{op}_{dtype}.csv", 'w', newline='') as file:
                def feedback_saver(timings, name, input_nodes, choices, profiled_time):
                    with open(f"{op}_{dtype}_benchmark_results.csv", "a", newline='') as file2:
                        if name == "addmm":
                            M, K, N = input_nodes[1].layout.size[0], input_nodes[1].layout.size[1], input_nodes[2].layout.size[1]
                        elif name == "mm":
                            M, K, N = input_nodes[0].layout.size[0], input_nodes[0].layout.size[1], input_nodes[1].layout.size[1]
                        else:
                            raise Exception(f"Unknown op {name}")

                        file2.write("--------------------\n")
                        file2.write(f"{name},{M},{K},{N}\n")
                        for choice, db_time in timings.items():
                            if not isinstance(choice, torch._inductor.select_algorithm.TritonTemplateCaller):
                                continue
                            BLOCK_M, BLOCK_K, BLOCK_N = tuple(map(int, choice.log_info['tile_shape'].strip('()').split(',')))
                            line = ",".join(map(str, [M, K, N, BLOCK_M, BLOCK_K, BLOCK_N, choice.log_info['num_stages'], choice.log_info['num_warps'], choice.log_info['GROUP_M'], db_time]))
                            file2.write(line + "\n")
                        file2.flush()

                add_feedback_saver(feedback_saver)

                writer = csv.writer(file)
                writer.writerow(['M', 'K', 'N', 'Old_Time', 'New_Time'])
                for i, (args, kwargs) in enumerate(loader.get_inputs_for_operator(op, dtype=dtype, device="cuda")):
                    torch._dynamo.reset()

                    try:
                        inp_t = args[1]
                        weight_t = args[2]
                    except IndexError:
                        inp_t = args[0]
                        weight_t = args[1]

                    if len(inp_t.shape) != 2:
                        continue

                    # Not clear why these show up in the loader; skip empty inputs.
                    if inp_t.numel() == 0:
                        continue

                    print(f"{inp_t.shape[0]}_{inp_t.shape[1]}_{weight_t.shape[1]}")
                    speeds = []
                    M, K, N = inp_t.shape[0], inp_t.shape[1], weight_t.shape[1]

                    for new_configs in [False, True]:
                        with open(f"{op}_{dtype}_benchmark_results.csv", "a") as file2:
                            if new_configs:
                                file2.write("New Configs\n")
                            else:
                                file2.write("Old Configs\n")
                        torch._dynamo.reset()

                        context = config.patch({
                            "fx_graph_cache": False,
                            "force_disable_caches": True,
                            "new_configs": new_configs,
                            "max_autotune_gemm_backends": "TRITON",
                        })

                        with context:
                            # in1 = torch.zeros((M, K)).cuda().to(dtype=dtype)
                            in2 = torch.zeros((K, N)).cuda().to(dtype=dtype)
                            if op == aten.addmm.default:
                                in3 = torch.zeros((M, N)).cuda().to(dtype=dtype)

                                def fn(inp):
                                    return inp @ in2 + in3
                            else:
                                def fn(inp):
                                    return inp @ in2

                            mod = torch.compile(fn, mode="max-autotune-no-cudagraphs", fullgraph=True, dynamic=False)
                            speeds.append(do_bench(lambda: mod(inp_t)))

                    writer.writerow([M, K, N, speeds[0], speeds[1]])
                    file.flush()
            clear_feedback_savers()


# compare_op("new_old_config_compare_addmm_float16.csv", aten.addmm.default, dtype=torch.float16)
# compare_op("new_old_config_compare_addmm_bfloat16.csv", aten.addmm.default, dtype=torch.bfloat16)
# compare_op("new_old_config_compare_addmm_float32.csv", aten.addmm.default, dtype=torch.float32)
compare_op()
@@ -905,6 +905,8 @@ profile_bandwidth_with_do_bench_using_profiling = (
disable_cpp_codegen = False


new_configs: bool = True

# Freezing will attempt to inline weights as constants in optimization
# and run constant folding and other optimizations on them. After freezing, weights
# can no longer be updated.
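The new `new_configs` flag gates the model-driven GEMM config selection added in `torch/_inductor/kernel/mm.py` below (which also checks the `TORCHINDUCTOR_NEW_CONFIGS` environment variable). A sketch of toggling it, mirroring how the benchmark script above scopes it with `config.patch`:

```python
import torch
from torch._inductor import config as inductor_config

# Global toggle (defaults to True in this patch).
inductor_config.new_configs = False

# Or scope it to a single compile, as benchmarks/dynamo/shapes.py does.
with inductor_config.patch({"new_configs": True, "max_autotune_gemm_backends": "TRITON"}):
    fn = torch.compile(lambda a, b: a @ b, mode="max-autotune-no-cudagraphs")
```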
torch/_inductor/kernel/gemm_modeling.py (new file, 384 lines)
@@ -0,0 +1,384 @@
"""
Neural network model for predicting triton kernel performance.

This module provides functionality to load and use a pre-trained neural network
for predicting the performance of triton kernels.
"""

import copy
import logging
import os
import time
from collections.abc import Sequence
from typing import Any

import numpy as np
import pandas as pd  # type: ignore[import-untyped]
from pyre_extensions import assert_is_instance  # type: ignore[import-untyped]

import torch
import torch.nn as nn
from torch._inductor.kernel_lut import TritonGEMMConfig
from torch.optim.lr_scheduler import StepLR


# Default model path - can be overridden by environment variable
script_dir = os.path.dirname(__file__)
DEFAULT_MODEL_PATH = os.path.join(script_dir, "triton_h100_from_arm_108.pkl")
MODEL_PATH = os.environ.get("TRITON_KERNEL_SELECTION_MODEL_PATH", DEFAULT_MODEL_PATH)

log = logging.getLogger(__name__)


class NeuralNetwork(nn.Module):
|
||||
"""
|
||||
Multilayer perceptron with a single output.
|
||||
|
||||
It is designed for modeling runtime when there is a constant overhead of
|
||||
`kernel_overhead` and the non-overhead runtime tends to be easier to model
|
||||
on a log scale (e.g. doubling a dimension involved in a matrix
|
||||
multiplication results in runtime roughly doubling.)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
n_inputs: int,
|
||||
hidden_layer_widths: Sequence[int],
|
||||
kernel_overhead: float = 0.00541,
|
||||
) -> None:
|
||||
"""
|
||||
Args:
|
||||
n_inputs: Number of inputs
|
||||
hidden_layer_widths: Hidden layer widths
|
||||
kernel_overhead: Overhead of the kernel, assumed to be constant. The
|
||||
default of 0.00541 is the lowest runtime seen in Triton H100 data.
|
||||
"""
|
||||
super().__init__()
|
||||
self.n_inputs = n_inputs
|
||||
self.kernel_overhead = kernel_overhead
|
||||
self.log_kernel_overhead: float = torch.log(
|
||||
torch.tensor(kernel_overhead)
|
||||
).item()
|
||||
all_layer_widths = list(hidden_layer_widths) + [1]
|
||||
all_input_widths = [n_inputs] + list(hidden_layer_widths)
|
||||
layers: list[nn.Module] = []
|
||||
for n_in, n_out in zip(all_input_widths, all_layer_widths, strict=True):
|
||||
layers.append(nn.Linear(n_in, n_out))
|
||||
layers.append(nn.BatchNorm1d(n_out))
|
||||
layers.append(nn.ReLU())
|
||||
|
||||
self.linear_relu_stack = nn.Sequential(*layers[:-2])
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Predict as log(exp(inputs) + self.kernel_overhead).
|
||||
|
||||
Works well for predicting log(runtime) when runtime contains a constant
|
||||
overhead of `kernel_overhead`. (The log specification means that this
|
||||
wouldn't be trivially modeled with a bias term.)
|
||||
|
||||
Probably could have fit the overhead rather than hard-coding it by
|
||||
having `self.kernel_overhead` be a tunable parameter or by having exp
|
||||
and log layers.
|
||||
"""
|
||||
log_base_pred = self.linear_relu_stack(x)
|
||||
log_overhead_tsr = torch.full_like(
|
||||
input=log_base_pred, fill_value=self.log_kernel_overhead
|
||||
)
|
||||
return torch.logsumexp(
|
||||
torch.stack([log_base_pred, log_overhead_tsr], dim=-1), dim=-1
|
||||
)
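As a quick sanity check of the overhead handling described above, taking `logsumexp` over the prediction and the constant `log(kernel_overhead)` is just a numerically stable way of computing `log(exp(pred) + kernel_overhead)` (illustrative values):

```python
import torch

overhead = 0.00541
log_pred = torch.tensor([-4.0, -2.0])  # hypothetical network outputs (log of base runtime)
log_overhead = torch.full_like(log_pred, fill_value=float(torch.log(torch.tensor(overhead))))

stable = torch.logsumexp(torch.stack([log_pred, log_overhead], dim=-1), dim=-1)
naive = torch.log(torch.exp(log_pred) + overhead)
assert torch.allclose(stable, naive)
```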
|
||||
|
||||
|
||||
def get_nn_x(
|
||||
df: pd.DataFrame, mean: torch.Tensor | None = None, std: torch.Tensor | None = None
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""Standardize the data and convert it to a tensor."""
|
||||
x_df = df[
|
||||
[
|
||||
"dtype_size",
|
||||
"dim_m",
|
||||
"dim_n",
|
||||
"dim_k",
|
||||
"total_gb",
|
||||
"total_gflop",
|
||||
"flops_per_byte",
|
||||
"config_block_k",
|
||||
"config_block_m",
|
||||
"config_block_n",
|
||||
"config_num_stages",
|
||||
"config_num_warps",
|
||||
]
|
||||
].copy()
|
||||
for col in x_df.columns:
|
||||
x_df[col] = np.log(x_df[col])
|
||||
|
||||
x_tens = torch.from_numpy(x_df.astype(float).to_numpy()).to(device="cuda")
|
||||
if mean is None:
|
||||
mean = torch.from_numpy(assert_is_instance(x_df.mean(), pd.Series).to_numpy()).to(device="cuda")
|
||||
if std is None:
|
||||
std = torch.from_numpy(assert_is_instance(x_df.std(), pd.Series).to_numpy()).to(device="cuda")
|
||||
x_tens -= mean
|
||||
x_tens /= std
|
||||
return x_tens.to(torch.float32), mean, std
|
||||
|
||||
|
||||
def get_total_gb_feature(df: pd.DataFrame) -> pd.Series:
|
||||
"""
|
||||
Calculate the total gigabytes feature from the dataframe.
|
||||
|
||||
Args:
|
||||
df: DataFrame containing the necessary columns for calculation
|
||||
|
||||
Returns:
|
||||
Series containing the calculated total gigabytes
|
||||
"""
|
||||
# Calculate memory access in bytes
|
||||
m, n, k = df["dim_m"], df["dim_n"], df["dim_k"]
|
||||
dtype_size = df["dtype_size"] / 8 # Convert bits to bytes
|
||||
|
||||
# A: m×k, B: k×n, C: m×n
|
||||
return ((m * k + k * n + m * n) * dtype_size) / 1e9 # Convert to GB
|
||||
|
||||
|
||||
def get_total_gflop_feature(df: pd.DataFrame) -> pd.Series:
|
||||
"""
|
||||
Calculate the total gigaflops feature from the dataframe.
|
||||
|
||||
Args:
|
||||
df: DataFrame containing the necessary columns for calculation
|
||||
|
||||
Returns:
|
||||
Series containing the calculated total gigaflops
|
||||
"""
|
||||
# For matrix multiplication, flops = 2 * m * n * k
|
||||
m, n, k = df["dim_m"], df["dim_n"], df["dim_k"]
|
||||
return (2 * m * n * k) / 1e9 # Convert to GFLOP
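For a concrete feel for these two features, take a hypothetical 4096 x 4096 x 4096 GEMM with 16-bit inputs:

```python
m = n = k = 4096
dtype_size_bits = 16

total_gb = (m * k + k * n + m * n) * (dtype_size_bits / 8) / 1e9   # ~0.1007 GB of A, B and C traffic
total_gflop = 2 * m * n * k / 1e9                                  # ~137.44 GFLOP
flops_per_byte = total_gflop / total_gb                            # ~1365, i.e. heavily compute-bound
```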
|
||||
|
||||
|
||||
class ModelWrapper:
|
||||
"""
|
||||
Wrapper for the neural network model that handles encoding inputs and decoding outputs.
|
||||
|
||||
This class provides methods to prepare inputs for the model and interpret its outputs,
|
||||
handling the necessary standardization and feature engineering.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Initialize the model wrapper with the pre-trained model and standardization parameters."""
|
||||
start_time = time.time()
|
||||
self.model = NeuralNetwork(
|
||||
n_inputs=12, hidden_layer_widths=[2**8 for _ in range(6)]
|
||||
)
|
||||
self.model.load_state_dict(torch.load(MODEL_PATH))
|
||||
self.model.eval()
|
||||
# export the model.
|
||||
|
||||
end_time = time.time()
|
||||
|
||||
log.info("NN Kernel Prediction Model loaded.")
|
||||
log.info("Took: %s seconds", end_time - start_time)
|
||||
|
||||
# Mean values for standardizing input features
|
||||
self.mean_for_standardization = torch.tensor(
|
||||
[
|
||||
2.78275084,
|
||||
8.23996746,
|
||||
7.27791873,
|
||||
7.92035942,
|
||||
-2.39558163,
|
||||
3.40679233,
|
||||
5.80237395,
|
||||
3.95781827,
|
||||
4.19478321,
|
||||
4.19098234,
|
||||
0.9045909,
|
||||
1.28331208,
|
||||
]
|
||||
)
|
||||
|
||||
# Standard deviation values for standardizing input features
|
||||
self.std_for_standardization = torch.tensor(
|
||||
[
|
||||
0.08322756,
|
||||
2.31893439,
|
||||
1.65605574,
|
||||
2.15447078,
|
||||
2.19682881,
|
||||
2.99600806,
|
||||
1.24328795,
|
||||
0.92352521,
|
||||
0.93849802,
|
||||
0.93872011,
|
||||
0.57455891,
|
||||
0.5837217,
|
||||
]
|
||||
)
|
||||
|
||||
def vec(
|
||||
self, m: int, n: int, k: int, dsize: int, config: Any
|
||||
) -> tuple[int, int, int, int, int, int, int, int, int]:
|
||||
"""
|
||||
Convert matrix multiplication parameters and config to a feature vector.
|
||||
|
||||
Args:
|
||||
m: First dimension of matrix multiplication
|
||||
n: Second dimension of matrix multiplication
|
||||
k: Third dimension of matrix multiplication
|
||||
dsize: Data size in bits (e.g., 16 for float16, 32 for float32)
|
||||
config: Configuration object containing kernel parameters
|
||||
|
||||
Returns:
|
||||
Tuple containing the extracted features
|
||||
"""
|
||||
kwargs = config.all_kwargs()
|
||||
|
||||
return (
|
||||
int(m),
|
||||
int(n),
|
||||
int(k),
|
||||
int(dsize),
|
||||
int(kwargs["BLOCK_M"]),
|
||||
int(kwargs["BLOCK_N"]),
|
||||
int(kwargs["BLOCK_K"]),
|
||||
int(kwargs["num_stages"]),
|
||||
int(kwargs["num_warps"]),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def vec_params(
|
||||
m: int, n: int, k: int, dsize: int, params: TritonGEMMConfig
|
||||
) -> tuple[int, int, int, int, int, int, int, int, int]:
|
||||
"""
|
||||
Convert matrix multiplication parameters and config to a feature vector.
|
||||
|
||||
Args:
|
||||
m: First dimension of matrix multiplication
|
||||
n: Second dimension of matrix multiplication
|
||||
k: Third dimension of matrix multiplication
|
||||
dsize: Data size in bits (e.g., 16 for float16, 32 for float32)
|
||||
config: Configuration object containing kernel parameters
|
||||
|
||||
Returns:
|
||||
Tuple containing the extracted features
|
||||
"""
|
||||
|
||||
return (
|
||||
int(m),
|
||||
int(n),
|
||||
int(k),
|
||||
int(dsize),
|
||||
int(params.block_m),
|
||||
int(params.block_n),
|
||||
int(params.block_k),
|
||||
int(params.num_stages),
|
||||
int(params.num_warps),
|
||||
)
|
||||
|
||||
def encode(
|
||||
self, m: int, n: int, k: int, dtype: torch.dtype, configs: list[Any]
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Encode the matrix multiplication parameters and configs as input tensors for the model.
|
||||
|
||||
Args:
|
||||
m: First dimension of matrix multiplication
|
||||
n: Second dimension of matrix multiplication
|
||||
k: Third dimension of matrix multiplication
|
||||
dtype: Data type of the matrices
|
||||
configs: List of configuration objects
|
||||
|
||||
Returns:
|
||||
Tensor containing the encoded inputs ready for the model
|
||||
|
||||
Raises:
|
||||
ValueError: If the dtype is not supported
|
||||
"""
|
||||
# Determine data size based on dtype
|
||||
if dtype == torch.bfloat16 or dtype == torch.float16:
|
||||
dsize = 16
|
||||
elif dtype == torch.float32:
|
||||
dsize = 32
|
||||
else:
|
||||
raise ValueError(f"Unsupported dtype: {dtype}. Add support for this dtype.")
|
||||
|
||||
# Create feature dataframe
|
||||
df = pd.DataFrame(
|
||||
columns=[
|
||||
"dim_m",
|
||||
"dim_n",
|
||||
"dim_k",
|
||||
"dtype_size",
|
||||
"config_block_m",
|
||||
"config_block_n",
|
||||
"config_block_k",
|
||||
"config_num_stages",
|
||||
"config_num_warps",
|
||||
],
|
||||
data=[self.vec(m, n, k, dsize, config) for config in configs],
|
||||
)
|
||||
# Reorder columns to match expected model input
|
||||
df = df[
|
||||
[
|
||||
"dtype_size",
|
||||
"dim_m",
|
||||
"dim_n",
|
||||
"dim_k",
|
||||
"total_gb",
|
||||
"total_gflop",
|
||||
"flops_per_byte",
|
||||
"config_block_k",
|
||||
"config_block_m",
|
||||
"config_block_n",
|
||||
"config_num_stages",
|
||||
"config_num_warps",
|
||||
]
|
||||
]
|
||||
|
||||
# Calculate derived features
|
||||
df["total_gb"] = get_total_gb_feature(df=df).astype(np.float32)
|
||||
df["total_gflop"] = get_total_gflop_feature(df=df).astype(np.float32)
|
||||
df["flops_per_byte"] = df["total_gflop"] / df["total_gb"]
|
||||
|
||||
# Standardize the input
|
||||
inp, _, _ = get_nn_x(
|
||||
df=df, mean=self.mean_for_standardization, std=self.std_for_standardization
|
||||
)
|
||||
|
||||
return inp
|
||||
|
||||
def inference(self, inp_tensor: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Run inference on the model with the given input tensor.
|
||||
|
||||
Args:
|
||||
inp_tensor: Input tensor for the model
|
||||
|
||||
Returns:
|
||||
Output tensor from the model
|
||||
"""
|
||||
with torch.no_grad():
|
||||
return self.model(inp_tensor)
|
||||
|
||||
def decode(self, ret_tensor: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Decode the model output tensor.
|
||||
|
||||
Args:
|
||||
ret_tensor: Output tensor from the model
|
||||
|
||||
Returns:
|
||||
Decoded tensor representing runtime predictions
|
||||
"""
|
||||
return ret_tensor
|
||||
|
||||
|
||||
# Create a singleton instance of the model wrapper
|
||||
import functools
|
||||
|
||||
|
||||
@functools.lru_cache
|
||||
def get_model() -> ModelWrapper:
|
||||
return ModelWrapper()
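Putting the pieces together, the intended flow (mirroring how `tuned_mm` in `kernel/mm.py` drives the wrapper; `candidate_configs` here stands in for the exhaustive Triton mm configs) is roughly:

```python
import torch

wrapper = get_model()                        # cached ModelWrapper singleton
x = wrapper.encode(1024, 4096, 1024, torch.float16, candidate_configs)
log_runtime = wrapper.decode(wrapper.inference(x))
predicted_runtime = torch.exp(log_runtime)   # the model works in log space
best_config = candidate_configs[int(predicted_runtime.argmin())]
```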
|
||||
@@ -1,10 +1,14 @@
# mypy: allow-untyped-defs

import timeit
import functools
import logging
from typing import Any, Optional

import sympy

from torch._inductor.kernel.gemm_modeling import get_nn_x, NeuralNetwork, get_total_gb_feature, get_total_gflop_feature
import numpy as np
import torch
from torch._dynamo.utils import counters
from torch._inductor.autoheuristic.autoheuristic import AutoHeuristicSelectAlgorithm
@@ -25,6 +29,7 @@ from ..codegen.rocm.ck_tile_universal_gemm_template import CKTileGemmTemplate
from ..codegen.rocm.ck_universal_gemm_template import CKGemmTemplate
from ..codegen.subgraph import SubgraphTemplate
from ..ir import FlexibleLayout, is_triton
import pandas as pd
from ..lowering import (
    add_layout_constraint,
    constrain_to_fx_strides,
@@ -63,6 +68,7 @@ from .mm_common import (
    scale_mm_epilogue,
    scaled_mm_options,
)
from ..template_heuristics import CUDAConfigHeuristic


try:
@@ -659,7 +665,147 @@ def decomposeK(a, b, k_splits):
    result = torch.bmm(a_reshaped, b_reshaped, out_dtype=torch.float32)
    reduced_buf = torch.sum(result, 0)
    return reduced_buf.to(a.dtype)


def get_model():
    fname = '/home/gabeferns/manifold/triton_h100_from_arm_108.pkl'
    import sys
    sys.path.append('/home/santorella/fbsource/fbcode/scripts/santorella/gemm_modeling')
    import time
    start_time = time.time()
    model = NeuralNetwork(n_inputs=12, hidden_layer_widths=[2**8 for _ in range(6)])
    model.load_state_dict(torch.load(fname))
    model.to("cuda")
    model.eval()
    end_time = time.time()
    print("model loaded!")
    print(f"took: {end_time - start_time} seconds")
    return model


class ModelWrapper:
|
||||
def __init__(self):
|
||||
self.model = get_model()
|
||||
self.mean_for_standardization = torch.tensor([
|
||||
2.78275084,
|
||||
8.23996746,
|
||||
7.27791873,
|
||||
7.92035942,
|
||||
-2.39558163,
|
||||
3.40679233,
|
||||
5.80237395,
|
||||
3.95781827,
|
||||
4.19478321,
|
||||
4.19098234,
|
||||
0.9045909,
|
||||
1.28331208,
|
||||
], device="cuda")
|
||||
|
||||
self.std_for_standardization = torch.tensor([
|
||||
0.08322756,
|
||||
2.31893439,
|
||||
1.65605574,
|
||||
2.15447078,
|
||||
2.19682881,
|
||||
2.99600806,
|
||||
1.24328795,
|
||||
0.92352521,
|
||||
0.93849802,
|
||||
0.93872011,
|
||||
0.57455891,
|
||||
0.5837217,
|
||||
], device="cuda")
|
||||
|
||||
    def vec(self, m: int, n: int, k: int, dsize: int, config) -> tuple[int, int, int, int, int, int, int, int, int]:
        kwargs = config.all_kwargs()
        return (
            int(m),
            int(n),
            int(k),
            int(dsize),
            int(kwargs["BLOCK_M"]),
            int(kwargs["BLOCK_N"]),
            int(kwargs["BLOCK_K"]),
            int(kwargs["num_stages"]),
            int(kwargs["num_warps"]),
        )
|
||||
def encode(self, m: int, n: int, k: int, dtype: torch.dtype, configs) -> torch.Tensor:
|
||||
# encodes the triton autotune config as the vector expected by the model
|
||||
if dtype == torch.bfloat16 or dtype == torch.float16:
|
||||
dsize = 16
|
||||
elif dtype == torch.float32:
|
||||
dsize = 32
|
||||
else:
|
||||
raise Exception("missing dtype in encode, add")
|
||||
df = pd.DataFrame(
|
||||
columns = [
|
||||
"dim_m",
|
||||
"dim_n",
|
||||
"dim_k",
|
||||
"dtype_size",
|
||||
"config_block_m",
|
||||
"config_block_n",
|
||||
"config_block_k",
|
||||
"config_num_stages",
|
||||
"config_num_warps",
|
||||
],
|
||||
data = [self.vec(m, n, k, dsize, config) for config in configs]
|
||||
)
|
||||
df["total_gb"] = get_total_gb_feature(df=df).astype(np.float32)
|
||||
df["total_gflop"] = get_total_gflop_feature(df=df).astype(np.float32)
|
||||
df["flops_per_byte"] = df["total_gflop"] / df["total_gb"]
|
||||
df = df[[
|
||||
'dtype_size',
|
||||
'dim_m',
|
||||
'dim_n',
|
||||
'dim_k',
|
||||
'total_gb',
|
||||
'total_gflop',
|
||||
'flops_per_byte',
|
||||
'config_block_k',
|
||||
'config_block_m',
|
||||
'config_block_n',
|
||||
'config_num_stages',
|
||||
'config_num_warps'
|
||||
]]
|
||||
inp, _, _ = get_nn_x(df=df, mean=self.mean_for_standardization, std=self.std_for_standardization)
|
||||
return inp
|
||||
    def inference(self, inp_tensor: torch.Tensor) -> torch.Tensor:
        from torch.export import Dim, export

        inp_tensor = inp_tensor.to("cuda")
        batch = Dim("batch")
        example_args = (inp_tensor,)
        with torch.no_grad():
            exported_program: torch.export.ExportedProgram = export(
                self.model, args=example_args, dynamic_shapes={"x": {0: batch}}
            )
            print(exported_program)
            # torch._inductor.aoti_compile_and_package(
            #     exported_program,
            #     package_path="mm_model.pt2",
            # )
            aoti = torch._inductor.aoti_compile_and_package(exported_program, package_path="aoti_mm_model.pt2")
            loaded = torch._inductor.aoti_load_package("aoti_mm_model.pt2")
            print('reloaded')
            return self.model(inp_tensor)

    def decode(self, ret_tensor: torch.Tensor) -> torch.Tensor:
        # returns the runtime
        # could just run exp in here
        return ret_tensor


wrappedmodel = ModelWrapper()


@register_lowering(aten.mm, type_promotion_kind=None)
def tuned_mm(mat1, mat2, *, layout=None):
@@ -667,6 +813,7 @@ def tuned_mm(mat1, mat2, *, layout=None):
    Lowering for autotuning aten.mm with different backends (Aten, Triton, CUTLASS, etc.)
    """
    m, n, k, layout, mat1, mat2 = mm_args(mat1, mat2, layout=layout)
    # m, n, k = 1024, 4096, 1024
    device_type = ir.get_device_type(mat1)
    name = "mm"

@@ -689,6 +836,7 @@ def tuned_mm(mat1, mat2, *, layout=None):
    )

    # options to tune from
    print(f"using aten? {use_aten_gemm_kernels()}")
    choices = (
        [aten_mm.bind((mat1, mat2), aten_layout)] if use_aten_gemm_kernels() else []
    )
@@ -698,13 +846,43 @@ def tuned_mm(mat1, mat2, *, layout=None):
    persistent_mm_configs = V.choices.get_persistent_mm_configs(device_type)
    extra_mm_configs = V.choices.get_extra_mm_configs(device_type)

    import time
    if is_nonzero and use_triton_template(layout):
        for config in mm_configs(
            m,
            n,
            k,
            **mm_config_kwargs(device_type, _is_large_block_for_cpu),
        import os
        if (
            os.environ.get("TORCHINDUCTOR_NEW_CONFIGS", "0") == "1"
            or torch._inductor.config.new_configs
        ):
            exhaustive_configs = CUDAConfigHeuristic().get_exhaustive_mm_configs()
            config_list = list(exhaustive_configs(m, n, k))
            start_time = time.time()
            t: torch.Tensor = wrappedmodel.encode(m, n, k, mat1.get_dtype(), config_list)
            end_time = time.time()
            print(f"Encoding exhaustive configs took {end_time - start_time:.4f} seconds")
            start_time = time.time()
            res = torch.exp(wrappedmodel.inference(t))
            end_time = time.time()
            total_time = end_time - start_time
            print(f"running inference on exhaustive configs took {total_time:.4f} seconds, {total_time / len(config_list):.4f} seconds per config")
            timings = list(zip(res.flatten().tolist(), config_list))
            timings.sort(key=lambda x: x[0])

            def print_timings(arst):
                for timing, config in arst:
                    kw = config.kwargs
                    print(f"{timing}, Config(M: {kw['BLOCK_M']}, K: {kw['BLOCK_K']}, N: {kw['BLOCK_N']}, num_stages: {config.num_stages}, num_warps: {config.num_warps})")

            top20 = timings[:20]
            print(f"Top 20 predicted configs on M:{m} K:{k} N:{n}: ")
            print_timings(top20)
            prelim_configs = [cfg for _, cfg in top20]
        else:
            print(f"Running original configs on M:{m} K:{k} N:{n}... ")
            prelim_configs = mm_configs(
                m,
                n,
                k,
                **mm_config_kwargs(device_type, _is_large_block_for_cpu),
            )
        for config in prelim_configs:
            mm_template.maybe_append_choice(
                choices,
                input_nodes=(mat1, mat2),

torch/_inductor/kernel_lut.py (new file, 532 lines)
@@ -0,0 +1,532 @@
|
||||
import json
|
||||
import logging
|
||||
import typing
|
||||
from collections import OrderedDict
|
||||
from dataclasses import asdict, dataclass, field, fields
|
||||
from functools import lru_cache
|
||||
from typing import Any, get_origin, Optional, TYPE_CHECKING, TypeVar, Union
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from triton import Config as TritonConfig
|
||||
|
||||
|
||||
try:
|
||||
from typing_extensions import Self
|
||||
except ImportError:
|
||||
from typing import Self
|
||||
|
||||
import torch
|
||||
from torch.utils._ordered_set import OrderedSet
|
||||
|
||||
|
||||
# Set up logging for kernel LUT
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
T = TypeVar("T", bound="JSONSerializable")
|
||||
LeafType = Union[
|
||||
None, bool, int, float, str, OrderedDict[str, Any], torch.dtype, list[Any]
|
||||
]
|
||||
JSONType = Union[T, LeafType]
|
||||
|
||||
|
||||
@dataclass(kw_only=True)
|
||||
class JSONSerializable:
|
||||
"""
|
||||
A lightweight, Pydantic-style base class for validating and serializing the dataclasses in this module to and from JSON.
|
||||
"""
|
||||
|
||||
# Incrementing version will invalidate all LUT entries, in the case of major perf update or
|
||||
# changes to the Ontology.
|
||||
version: int = 1
|
||||
_is_leaf: bool = False
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, inp: OrderedDict[str, Any] | str) -> Self:
|
||||
"""
|
||||
Convert a dictionary representation of the object.
|
||||
"""
|
||||
try:
|
||||
ret = OrderedDict()
|
||||
if isinstance(inp, str):
|
||||
if cls._is_leaf:
|
||||
return cls.parse(inp)
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
f"String representation not implemented for base {cls.__name__}"
|
||||
)
|
||||
for k, v in inp.items():
|
||||
v_type = cls.__dataclass_fields__[k].type
|
||||
if get_origin(v_type) is OrderedDict:
|
||||
k1_type, v1_type = typing.get_args(v_type)
|
||||
if isinstance(k1_type, type) and issubclass(
|
||||
k1_type, JSONSerializable
|
||||
):
|
||||
|
||||
def kp(tmpk: Any) -> Any:
|
||||
return k1_type.from_dict(tmpk)
|
||||
|
||||
k_process = kp
|
||||
else:
|
||||
|
||||
def k_process(tmpk: Any) -> Any:
|
||||
return tmpk
|
||||
|
||||
if isinstance(v1_type, type) and issubclass(
|
||||
v1_type, JSONSerializable
|
||||
):
|
||||
|
||||
def vp(tmpv: Any) -> Any:
|
||||
return v1_type.from_dict(tmpv)
|
||||
|
||||
v_process = vp
|
||||
else:
|
||||
|
||||
def v_process(tmpv: Any) -> Any:
|
||||
return tmpv
|
||||
|
||||
v_new: Any = OrderedDict(
|
||||
(k_process(key), v_process(val)) for key, val in v.items()
|
||||
)
|
||||
|
||||
elif get_origin(v_type) is list:
|
||||
elem_type = typing.get_args(v_type)[0]
|
||||
if isinstance(elem_type, type) and issubclass(
|
||||
elem_type, JSONSerializable
|
||||
):
|
||||
v_new = [elem_type.from_dict(x) for x in v]
|
||||
else:
|
||||
v_new = v
|
||||
elif isinstance(v_type, type) and issubclass(v_type, JSONSerializable):
|
||||
v_new = v_type.from_dict(v)
|
||||
else:
|
||||
v_new = v
|
||||
ret[k] = v_new
|
||||
return cls(**ret) # type: ignore[arg-type]
|
||||
except Exception as e:
|
||||
logger.error("Failed to deserialize %s from dict: %s", cls.__name__, e)
|
||||
raise ValueError(f"Malformed data for {cls.__name__}: {e}") from e
|
||||
|
||||
def to_dict(self) -> OrderedDict[str, Any]:
|
||||
"""
|
||||
Convert the object to a dictionary representation.
|
||||
Will be written to and from using json.dumps and json.loads.
|
||||
"""
|
||||
# get the fields of the dataclass
|
||||
field_list = fields(self)
|
||||
# filter out the _ fields
|
||||
field_list = [field for field in field_list if not field.name.startswith("_")]
|
||||
# ensure the fields are sorted for consistent serialization
|
||||
field_list.sort(key=lambda x: x.name)
|
||||
ret: OrderedDict[str, Any] = OrderedDict()
|
||||
for field_obj in field_list:
|
||||
field_val = getattr(self, field_obj.name)
|
||||
if isinstance(field_val, JSONSerializable):
|
||||
if field_val._is_leaf:
|
||||
ret[field_obj.name] = str(field_val)
|
||||
else:
|
||||
ret[field_obj.name] = field_val.to_dict()
|
||||
elif isinstance(field_val, list):
|
||||
if len(field_val) == 0:
|
||||
ret[field_obj.name] = []
|
||||
elif isinstance(field_val[0], JSONSerializable):
|
||||
if field_val[0]._is_leaf:
|
||||
ret[field_obj.name] = [str(x) for x in field_val]
|
||||
else:
|
||||
ret[field_obj.name] = [x.to_dict() for x in field_val]
|
||||
else:
|
||||
ret[field_obj.name] = field_val
|
||||
elif isinstance(field_val, OrderedDict):
|
||||
tmp: OrderedDict[Any, Any] = OrderedDict()
|
||||
for k, v in field_val.items():
|
||||
if isinstance(v, JSONSerializable):
|
||||
if v._is_leaf:
|
||||
new_v: Any = str(v)
|
||||
else:
|
||||
new_v = v.to_dict()
|
||||
else:
|
||||
new_v = v
|
||||
if isinstance(k, JSONSerializable):
|
||||
if k._is_leaf:
|
||||
new_k: Any = str(k)
|
||||
else:
|
||||
new_k = k.to_dict()
|
||||
else:
|
||||
new_k = k
|
||||
tmp[new_k] = new_v
|
||||
ret[field_obj.name] = tmp
|
||||
else:
|
||||
ret[field_obj.name] = field_val
|
||||
return ret
|
||||
|
||||
def __str__(self) -> str:
|
||||
"""
|
||||
Return a string representation of the object.
|
||||
"""
|
||||
return json.dumps(self.to_dict())
|
||||
|
||||
@classmethod
|
||||
def parse(cls, string: str) -> Self:
|
||||
"""
|
||||
Parse the string representation of the object. Only required for leaf nodes.
|
||||
"""
|
||||
raise NotImplementedError(
|
||||
f"String representation not implemented for base {cls.__name__}"
|
||||
)
|
||||
|
||||
|
||||
@dataclass(kw_only=True)
|
||||
class TritonGEMMConfig(JSONSerializable):
|
||||
_is_leaf: bool = True
|
||||
name: str
|
||||
grid: int
|
||||
block_m: int
|
||||
block_n: int
|
||||
block_k: int
|
||||
group_m: int
|
||||
num_stages: int
|
||||
num_warps: int
|
||||
EVEN_K: bool = False
|
||||
ALLOW_TF32: bool = False
|
||||
USE_FAST_ACCUM: bool = False
|
||||
ACC_TYPE: str = "tl.float32"
|
||||
|
||||
def __hash__(self) -> int:
|
||||
return hash(
|
||||
(
|
||||
self.name,
|
||||
self.grid,
|
||||
self.block_m,
|
||||
self.block_n,
|
||||
self.block_k,
|
||||
self.group_m,
|
||||
self.num_stages,
|
||||
self.num_warps,
|
||||
self.EVEN_K,
|
||||
self.ALLOW_TF32,
|
||||
self.USE_FAST_ACCUM,
|
||||
self.ACC_TYPE,
|
||||
)
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def parse(cls, string: str) -> Self:
|
||||
d = json.loads(string, object_pairs_hook=OrderedDict)
|
||||
# validate types, yay python :P
|
||||
if "name" not in d:
|
||||
raise KeyError("Missing required field: name")
|
||||
if not isinstance(d["name"], str):
|
||||
raise TypeError(f"name must be a string, got {type(d['name'])}")
|
||||
if "grid" not in d:
|
||||
raise KeyError("Missing required field: grid")
|
||||
if not isinstance(d["grid"], int):
|
||||
raise TypeError(f"grid must be an int, got {type(d['grid'])}")
|
||||
if "block_m" not in d:
|
||||
raise KeyError("Missing required field: block_m")
|
||||
if not isinstance(d["block_m"], int):
|
||||
raise TypeError(f"block_m must be an int, got {type(d['block_m'])}")
|
||||
if "block_n" not in d:
|
||||
raise KeyError("Missing required field: block_n")
|
||||
if not isinstance(d["block_n"], int):
|
||||
raise TypeError(f"block_n must be an int, got {type(d['block_n'])}")
|
||||
if "block_k" not in d:
|
||||
raise KeyError("Missing required field: block_k")
|
||||
if not isinstance(d["block_k"], int):
|
||||
raise TypeError(f"block_k must be an int, got {type(d['block_k'])}")
|
||||
if "group_m" not in d:
|
||||
raise KeyError("Missing required field: group_m")
|
||||
if not isinstance(d["group_m"], int):
|
||||
raise TypeError(f"group_m must be an int, got {type(d['group_m'])}")
|
||||
if "num_stages" not in d:
|
||||
raise KeyError("Missing required field: num_stages")
|
||||
if not isinstance(d["num_stages"], int):
|
||||
raise TypeError(f"num_stages must be an int, got {type(d['num_stages'])}")
|
||||
if "num_warps" not in d:
|
||||
raise KeyError("Missing required field: num_warps")
|
||||
if not isinstance(d["num_warps"], int):
|
||||
raise TypeError(f"num_warps must be an int, got {type(d['num_warps'])}")
|
||||
if "EVEN_K" in d and not isinstance(d["EVEN_K"], bool):
|
||||
raise TypeError(f"EVEN_K must be a bool, got {type(d['EVEN_K'])}")
|
||||
if "ALLOW_TF32" in d and not isinstance(d["ALLOW_TF32"], bool):
|
||||
raise TypeError(f"ALLOW_TF32 must be a bool, got {type(d['ALLOW_TF32'])}")
|
||||
if "USE_FAST_ACCUM" in d and not isinstance(d["USE_FAST_ACCUM"], bool):
|
||||
raise TypeError(
|
||||
f"USE_FAST_ACCUM must be a bool, got {type(d['USE_FAST_ACCUM'])}"
|
||||
)
|
||||
if "ACC_TYPE" in d and not isinstance(d["ACC_TYPE"], str):
|
||||
raise TypeError(f"ACC_TYPE must be a string, got {type(d['ACC_TYPE'])}")
|
||||
return cls(**d)
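Since `TritonGEMMConfig` is a leaf node, it serializes to a flat JSON object via `__str__` and comes back through `parse`; a minimal round-trip looks like:

```python
cfg = TritonGEMMConfig(
    name="mm_128x64x32",
    grid=1,
    block_m=128,
    block_n=64,
    block_k=32,
    group_m=8,
    num_stages=3,
    num_warps=4,
)
serialized = str(cfg)                       # JSON with fields sorted by name
assert TritonGEMMConfig.parse(serialized) == cfg
```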
|
||||
|
||||
|
||||
@dataclass(kw_only=True)
|
||||
class MMProblem(JSONSerializable):
|
||||
_is_leaf: bool = True
|
||||
B: int
|
||||
M: int
|
||||
M_dtype: torch.dtype
|
||||
N: int
|
||||
K: int
|
||||
K_dtype: torch.dtype
|
||||
out_dtype: torch.dtype
|
||||
out_size: tuple[int, int, int]
|
||||
out_stride: tuple[int, int, int]
|
||||
|
||||
def __hash__(self) -> int:
|
||||
return hash(
|
||||
(
|
||||
self.B,
|
||||
self.M,
|
||||
self.M_dtype,
|
||||
self.N,
|
||||
self.K_dtype,
|
||||
self.K,
|
||||
self.out_dtype,
|
||||
self.out_size,
|
||||
self.out_stride,
|
||||
)
|
||||
)
|
||||
|
||||
def __str__(self) -> str:
|
||||
"""
|
||||
Return a string representation of the object.
|
||||
"""
|
||||
d = asdict(self)
|
||||
d["M_dtype"] = str(d["M_dtype"]).split(".")[-1]
|
||||
d["K_dtype"] = str(d["K_dtype"]).split(".")[-1]
|
||||
d["out_dtype"] = str(d["out_dtype"]).split(".")[-1]
|
||||
d["out_size"] = list(d["out_size"])
|
||||
d["out_stride"] = list(d["out_stride"])
|
||||
d = OrderedDict((k, v) for k, v in d.items() if not k.startswith("_"))
|
||||
return json.dumps(d)
|
||||
|
||||
@classmethod
|
||||
def parse(cls, string: str) -> Self:
|
||||
d = json.loads(string, object_pairs_hook=OrderedDict)
|
||||
# validate types, yay python :P
|
||||
if "B" not in d:
|
||||
raise KeyError("Missing required field: B")
|
||||
if not isinstance(d["B"], int):
|
||||
raise TypeError(f"B must be an int, got {type(d['B'])}")
|
||||
if "M" not in d:
|
||||
raise KeyError("Missing required field: M")
|
||||
if not isinstance(d["M"], int):
|
||||
raise TypeError(f"M must be an int, got {type(d['M'])}")
|
||||
if "N" not in d:
|
||||
raise KeyError("Missing required field: N")
|
||||
if not isinstance(d["N"], int):
|
||||
raise TypeError(f"N must be an int, got {type(d['N'])}")
|
||||
if "K" not in d:
|
||||
raise KeyError("Missing required field: K")
|
||||
if not isinstance(d["K"], int):
|
||||
raise TypeError(f"K must be an int, got {type(d['K'])}")
|
||||
if "M_dtype" not in d:
|
||||
raise KeyError("Missing required field: M_dtype")
|
||||
if not isinstance(d["M_dtype"], str):
|
||||
raise TypeError(f"M_dtype must be a string, got {type(d['M_dtype'])}")
|
||||
if "K_dtype" not in d:
|
||||
raise KeyError("Missing required field: K_dtype")
|
||||
if not isinstance(d["K_dtype"], str):
|
||||
raise TypeError(f"K_dtype must be a string, got {type(d['K_dtype'])}")
|
||||
if "out_dtype" not in d:
|
||||
raise KeyError("Missing required field: out_dtype")
|
||||
if not isinstance(d["out_dtype"], str):
|
||||
raise TypeError(f"out_dtype must be a string, got {type(d['out_dtype'])}")
|
||||
if "out_size" not in d:
|
||||
raise KeyError("Missing required field: out_size")
|
||||
if not isinstance(d["out_size"], list):
|
||||
raise TypeError(f"out_size must be a list, got {type(d['out_size'])}")
|
||||
if "out_stride" not in d:
|
||||
raise KeyError("Missing required field: out_stride")
|
||||
if not isinstance(d["out_stride"], list):
|
||||
raise TypeError(f"out_stride must be a list, got {type(d['out_stride'])}")
|
||||
|
||||
# Validate torch dtype strings
|
||||
try:
|
||||
d["M_dtype"] = getattr(torch, d["M_dtype"])
|
||||
except AttributeError:
|
||||
raise ValueError(f"Invalid torch dtype: {d['M_dtype']}") from None
|
||||
try:
|
||||
d["K_dtype"] = getattr(torch, d["K_dtype"])
|
||||
except AttributeError:
|
||||
raise ValueError(f"Invalid torch dtype: {d['K_dtype']}") from None
|
||||
try:
|
||||
d["out_dtype"] = getattr(torch, d["out_dtype"])
|
||||
except AttributeError:
|
||||
raise ValueError(f"Invalid torch dtype: {d['out_dtype']}") from None
|
||||
|
||||
d["out_size"] = tuple(d["out_size"])
|
||||
d["out_stride"] = tuple(d["out_stride"])
|
||||
return cls(**d)
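For reference, a serialized `MMProblem` stores dtypes as bare strings (e.g. "float16") and size/stride as lists, which `parse` converts back to `torch.dtype` and tuples; the shape values below are illustrative:

```python
import torch

s = (
    '{"B": 1, "M": 1024, "M_dtype": "float16", "N": 4096, "K": 1024, '
    '"K_dtype": "float16", "out_dtype": "float16", '
    '"out_size": [1, 1024, 4096], "out_stride": [4194304, 4096, 1]}'
)
problem = MMProblem.parse(s)
assert problem.M_dtype is torch.float16
assert problem.out_size == (1, 1024, 4096)
```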
|
||||
|
||||
|
||||
@dataclass(kw_only=True)
|
||||
class Solution(JSONSerializable):
|
||||
# like mm or addmm
|
||||
name: str
|
||||
# mapping
|
||||
config: list[TritonGEMMConfig]
|
||||
|
||||
|
||||
@dataclass(kw_only=True)
|
||||
class Operation(JSONSerializable):
|
||||
name: str
|
||||
solution: OrderedDict[MMProblem, Solution]
|
||||
|
||||
|
||||
@dataclass(kw_only=True)
|
||||
class Hardware(JSONSerializable):
|
||||
# like gfx942:sramecc+:xnack-
|
||||
operation: OrderedDict[str, Operation]
|
||||
|
||||
|
||||
@dataclass(kw_only=True)
|
||||
class Table(JSONSerializable):
|
||||
hardware: OrderedDict[str, Hardware]
|
||||
_set_cache: OrderedDict[
|
||||
tuple[str, str, MMProblem], OrderedSet[TritonGEMMConfig]
|
||||
] = field(default_factory=OrderedDict)
|
||||
|
||||
def serialize(self) -> str:
|
||||
foo = self.to_dict()
|
||||
return json.dumps(foo, indent=2)
|
||||
|
||||
@classmethod
|
||||
def deserialize(cls, s: str) -> Optional[Self]:
|
||||
try:
|
||||
return cls.from_dict(json.loads(s, object_pairs_hook=OrderedDict))
|
||||
except (json.JSONDecodeError, TypeError, ValueError) as e:
|
||||
logger.error("Failed to deserialize table: %s", e)
|
||||
return None
|
||||
|
||||
def lookup(
|
||||
self, hardware: str, op_name: str, problem: MMProblem
|
||||
) -> Optional[list[TritonGEMMConfig]]:
|
||||
"""
|
||||
Lookup the best TritonGEMMConfig for a given problem.
|
||||
"""
|
||||
if hardware not in self.hardware:
|
||||
return None
|
||||
tmp = self.hardware[hardware].operation
|
||||
if op_name not in tmp:
|
||||
return None
|
||||
tmp = tmp[op_name].solution
|
||||
if problem not in tmp:
|
||||
return None
|
||||
return tmp[problem].config
|
||||
|
||||
def lookup_set(
|
||||
self, hardware: str, op_name: str, problem: MMProblem
|
||||
) -> Optional[OrderedSet[TritonGEMMConfig]]:
|
||||
"""
|
||||
Easier and faster to check membership in a set, but cache the sets for runtime.
|
||||
"""
|
||||
if (hardware, op_name, problem) in self._set_cache:
|
||||
return self._set_cache[(hardware, op_name, problem)]
|
||||
problem_list = self.lookup(hardware, op_name, problem)
|
||||
problem_set = OrderedSet(problem_list) if problem_list is not None else None
|
||||
if problem_set is None:
|
||||
return None
|
||||
self._set_cache[(hardware, op_name, problem)] = problem_set
|
||||
return problem_set
|
||||
|
||||
def filter(
|
||||
self,
|
||||
hardware: str,
|
||||
op_name: str,
|
||||
problem: MMProblem,
|
||||
to_filter: list[TritonGEMMConfig],
|
||||
) -> Optional[list[TritonGEMMConfig]]:
|
||||
"""
|
||||
Filter a list of TritonGEMMConfig for a given problem.
|
||||
"""
|
||||
|
||||
problem_set = self.lookup_set(hardware, op_name, problem)
|
||||
if problem_set is None:
|
||||
return None
|
||||
ret = [x for x in to_filter if x in problem_set]
|
||||
if len(ret) == 0:
|
||||
return None
|
||||
return ret
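Usage of the table is a three-level lookup keyed by hardware name, op name, and problem; a hedged sketch, where the path, the hardware key, `problem`, and `candidate_configs` are placeholders:

```python
table = get_table("/path/to/gemm_lut.json")   # cached load; returns None on failure
if table is not None:
    configs = table.lookup("example-hw", "mm", problem)                 # list of known-good configs, or None
    pruned = table.filter("example-hw", "mm", problem, candidate_configs)
```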
|
||||
|
||||
|
||||
def convert_triton_configs_to_gemm_configs(
|
||||
triton_configs: list["TritonConfig"], name_prefix: str = "triton_config"
|
||||
) -> list[TritonGEMMConfig]:
|
||||
"""
|
||||
Convert a list of triton.runtime.autotuner.Config objects to TritonGEMMConfig objects.
|
||||
Args:
|
||||
triton_configs: List of triton.runtime.autotuner.Config objects
|
||||
name_prefix: Prefix for generated config names (default: "triton_config")
|
||||
Returns:
|
||||
List of TritonGEMMConfig objects
|
||||
"""
|
||||
gemm_configs = []
|
||||
|
||||
for i, config in enumerate(triton_configs):
|
||||
# Extract kwargs which contain the block sizes
|
||||
kwargs = getattr(config, "kwargs", {})
|
||||
|
||||
# Handle case where kwargs is None
|
||||
if kwargs is None:
|
||||
kwargs = {}
|
||||
|
||||
# Extract required parameters from kwargs
|
||||
block_m = kwargs.get("BLOCK_M", 64) # Default fallback values
|
||||
block_n = kwargs.get("BLOCK_N", 64)
|
||||
block_k = kwargs.get("BLOCK_K", 32)
|
||||
group_m = kwargs.get("GROUP_M", 8)
|
||||
|
||||
# Extract other parameters directly from config object
|
||||
num_stages = getattr(config, "num_stages", 2)
|
||||
num_warps = getattr(config, "num_warps", 4)
|
||||
|
||||
# Extract optional parameters with defaults
|
||||
even_k = kwargs.get("EVEN_K", False)
|
||||
allow_tf32 = kwargs.get("ALLOW_TF32", False)
|
||||
use_fast_accum = kwargs.get("USE_FAST_ACCUM", False)
|
||||
acc_type = kwargs.get("ACC_TYPE", "tl.float32")
|
||||
|
||||
# Generate a unique name for this config
|
||||
config_name = f"{name_prefix}_{i}"
|
||||
|
||||
# Create TritonGEMMConfig object
|
||||
gemm_config = TritonGEMMConfig(
|
||||
name=config_name,
|
||||
grid=1, # Default grid value, can be adjusted based on requirements
|
||||
block_m=block_m,
|
||||
block_n=block_n,
|
||||
block_k=block_k,
|
||||
group_m=group_m,
|
||||
num_stages=num_stages,
|
||||
num_warps=num_warps,
|
||||
EVEN_K=even_k,
|
||||
ALLOW_TF32=allow_tf32,
|
||||
USE_FAST_ACCUM=use_fast_accum,
|
||||
ACC_TYPE=acc_type,
|
||||
)
|
||||
|
||||
gemm_configs.append(gemm_config)
|
||||
|
||||
return gemm_configs
|
||||
|
||||
|
||||
@lru_cache
|
||||
def get_table(path: str) -> Optional[Table]:
|
||||
"""Load a table from a file path."""
|
||||
try:
|
||||
with open(path) as f:
|
||||
return Table.deserialize(f.read())
|
||||
except OSError as e:
|
||||
logger.error("Failed to read table from %s: %s", path, e)
|
||||
return None
|
||||
|
||||
|
||||
def get_table_safe(path: str) -> Optional[Table]:
|
||||
"""Safely load a table from a file path without caching."""
|
||||
try:
|
||||
with open(path) as f:
|
||||
return Table.deserialize(f.read())
|
||||
except OSError as e:
|
||||
logger.error("Failed to read table from %s: %s", path, e)
|
||||
return None
|
||||
torch/_inductor/models/mm_kernel_prediction_model.py (new file, 383 lines)
@@ -0,0 +1,383 @@
|
||||
"""
|
||||
Neural network model for predicting triton kernel performance.
|
||||
|
||||
This module provides functionality to load and use a pre-trained neural network
|
||||
for predicting the performance of triton kernels.
|
||||
"""
|
||||
|
||||
import copy
|
||||
import os
|
||||
import time
|
||||
from collections.abc import Sequence
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd # type: ignore[import-untyped]
|
||||
from pyre_extensions import assert_is_instance # type: ignore[import-untyped]
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch._inductor.kernel_lut import TritonGEMMConfig
|
||||
from torch.optim.lr_scheduler import StepLR
|
||||
|
||||
|
||||
# Default model path - can be overridden by environment variable
script_dir = os.path.dirname(__file__)
DEFAULT_MODEL_PATH = os.path.join(script_dir, "mm_model.pt2")
|
||||
MODEL_PATH = os.environ.get("TRITON_KERNEL_SELECTION_MODEL_PATH", DEFAULT_MODEL_PATH)
|
||||
import logging
|
||||
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class NeuralNetwork(nn.Module):
|
||||
"""
|
||||
Multilayer perceptron with a single output.
|
||||
|
||||
It is designed for modeling runtime when there is a constant overhead of
|
||||
`kernel_overhead` and the non-overhead runtime tends to be easier to model
|
||||
on a log scale (e.g. doubling a dimension involved in a matrix
|
||||
multiplication results in runtime roughly doubling.)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
n_inputs: int,
|
||||
hidden_layer_widths: Sequence[int],
|
||||
kernel_overhead: float = 0.00541,
|
||||
) -> None:
|
||||
"""
|
||||
Args:
|
||||
n_inputs: Number of inputs
|
||||
hidden_layer_widths: Hidden layer widths
|
||||
kernel_overhead: Overhead of the kernel, assumed to be constant. The
|
||||
default of 0.00541 is the lowest runtime seen in Triton H100 data.
|
||||
"""
|
||||
super().__init__()
|
||||
self.n_inputs = n_inputs
|
||||
self.kernel_overhead = kernel_overhead
|
||||
self.log_kernel_overhead: float = torch.log(
|
||||
torch.tensor(kernel_overhead)
|
||||
).item()
|
||||
all_layer_widths = list(hidden_layer_widths) + [1]
|
||||
all_input_widths = [n_inputs] + list(hidden_layer_widths)
|
||||
layers: list[nn.Module] = []
|
||||
for n_in, n_out in zip(all_input_widths, all_layer_widths, strict=True):
|
||||
layers.append(nn.Linear(n_in, n_out))
|
||||
layers.append(nn.BatchNorm1d(n_out))
|
||||
layers.append(nn.ReLU())
|
||||
|
||||
self.linear_relu_stack = nn.Sequential(*layers[:-2])
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Predict as log(exp(inputs) + self.kernel_overhead).
|
||||
|
||||
Works well for predicting log(runtime) when runtime contains a constant
|
||||
overhead of `kernel_overhead`. (The log specification means that this
|
||||
wouldn't be trivially modeled with a bias term.)
|
||||
|
||||
Probably could have fit the overhead rather than hard-coding it by
|
||||
having `self.kernel_overhead` be a tunable parameter or by having exp
|
||||
and log layers.
|
||||
"""
|
||||
log_base_pred = self.linear_relu_stack(x)
|
||||
log_overhead_tsr = torch.full_like(
|
||||
input=log_base_pred, fill_value=self.log_kernel_overhead
|
||||
)
|
||||
return torch.logsumexp(
|
||||
torch.stack([log_base_pred, log_overhead_tsr], dim=-1), dim=-1
|
||||
)
|
||||
|
||||
|
||||
def get_nn_x(
|
||||
df: pd.DataFrame, mean: torch.Tensor | None = None, std: torch.Tensor | None = None
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""Standardize the data and convert it to a tensor."""
|
||||
x_df = df[
|
||||
[
|
||||
"dtype_size",
|
||||
"dim_m",
|
||||
"dim_n",
|
||||
"dim_k",
|
||||
"total_gb",
|
||||
"total_gflop",
|
||||
"flops_per_byte",
|
||||
"config_block_k",
|
||||
"config_block_m",
|
||||
"config_block_n",
|
||||
"config_num_stages",
|
||||
"config_num_warps",
|
||||
]
|
||||
].copy()
|
||||
for col in x_df.columns:
|
||||
x_df[col] = np.log(x_df[col])
|
||||
|
||||
x_tens = torch.from_numpy(x_df.astype(float).to_numpy())
|
||||
if mean is None:
|
||||
mean = torch.from_numpy(assert_is_instance(x_df.mean(), pd.Series).to_numpy())
|
||||
if std is None:
|
||||
std = torch.from_numpy(assert_is_instance(x_df.std(), pd.Series).to_numpy())
|
||||
x_tens -= mean
|
||||
x_tens /= std
|
||||
return x_tens.to(torch.float32), mean, std
|
||||
|
||||
|
||||
def get_total_gb_feature(df: pd.DataFrame) -> pd.Series:
|
||||
"""
|
||||
Calculate the total gigabytes feature from the dataframe.
|
||||
|
||||
Args:
|
||||
df: DataFrame containing the necessary columns for calculation
|
||||
|
||||
Returns:
|
||||
Series containing the calculated total gigabytes
|
||||
"""
|
||||
# Calculate memory access in bytes
|
||||
m, n, k = df["dim_m"], df["dim_n"], df["dim_k"]
|
||||
dtype_size = df["dtype_size"] / 8 # Convert bits to bytes
|
||||
|
||||
# A: m×k, B: k×n, C: m×n
|
||||
return ((m * k + k * n + m * n) * dtype_size) / 1e9 # Convert to GB
|
||||
|
||||
|
||||
def get_total_gflop_feature(df: pd.DataFrame) -> pd.Series:
|
||||
"""
|
||||
Calculate the total gigaflops feature from the dataframe.
|
||||
|
||||
Args:
|
||||
df: DataFrame containing the necessary columns for calculation
|
||||
|
||||
Returns:
|
||||
Series containing the calculated total gigaflops
|
||||
"""
|
||||
# For matrix multiplication, flops = 2 * m * n * k
|
||||
m, n, k = df["dim_m"], df["dim_n"], df["dim_k"]
|
||||
return (2 * m * n * k) / 1e9 # Convert to GFLOP
|
||||
|
||||
|
||||
class ModelWrapper:
|
||||
"""
|
||||
Wrapper for the neural network model that handles encoding inputs and decoding outputs.
|
||||
|
||||
This class provides methods to prepare inputs for the model and interpret its outputs,
|
||||
handling the necessary standardization and feature engineering.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Initialize the model wrapper with the pre-trained model and standardization parameters."""
|
||||
start_time = time.time()
|
||||
self.model = NeuralNetwork(
|
||||
n_inputs=12, hidden_layer_widths=[2**8 for _ in range(6)]
|
||||
)
|
||||
self.model = torch.export.load(MODEL_PATH)
|
||||
end_time = time.time()
|
||||
|
||||
log.info("NN Kernel Prediction Model loaded.")
|
||||
log.info("Took: %s seconds", end_time - start_time)
|
||||
|
||||
# Mean values for standardizing input features
|
||||
self.mean_for_standardization = torch.tensor(
|
||||
[
|
||||
2.78275084,
|
||||
8.23996746,
|
||||
7.27791873,
|
||||
7.92035942,
|
||||
-2.39558163,
|
||||
3.40679233,
|
||||
5.80237395,
|
||||
3.95781827,
|
||||
4.19478321,
|
||||
4.19098234,
|
||||
0.9045909,
|
||||
1.28331208,
|
||||
]
|
||||
)
|
||||
|
||||
# Standard deviation values for standardizing input features
|
||||
self.std_for_standardization = torch.tensor(
|
||||
[
|
||||
0.08322756,
|
||||
2.31893439,
|
||||
1.65605574,
|
||||
2.15447078,
|
||||
2.19682881,
|
||||
2.99600806,
|
||||
1.24328795,
|
||||
0.92352521,
|
||||
0.93849802,
|
||||
0.93872011,
|
||||
0.57455891,
|
||||
0.5837217,
|
||||
]
|
||||
)
|
||||
|
||||
def vec(
|
||||
self, m: int, n: int, k: int, dsize: int, config: Any
|
||||
) -> tuple[int, int, int, int, int, int, int, int, int]:
|
||||
"""
|
||||
Convert matrix multiplication parameters and config to a feature vector.
|
||||
|
||||
Args:
|
||||
m: First dimension of matrix multiplication
|
||||
n: Second dimension of matrix multiplication
|
||||
k: Third dimension of matrix multiplication
|
||||
dsize: Data size in bits (e.g., 16 for float16, 32 for float32)
|
||||
config: Configuration object containing kernel parameters
|
||||
|
||||
Returns:
|
||||
Tuple containing the extracted features
|
||||
"""
|
||||
kwargs = config.all_kwargs()
|
||||
|
||||
return (
|
||||
int(m),
|
||||
int(n),
|
||||
int(k),
|
||||
int(dsize),
|
||||
int(kwargs["BLOCK_M"]),
|
||||
int(kwargs["BLOCK_N"]),
|
||||
int(kwargs["BLOCK_K"]),
|
||||
int(kwargs["num_stages"]),
|
||||
int(kwargs["num_warps"]),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def vec_params(
|
||||
m: int, n: int, k: int, dsize: int, params: TritonGEMMConfig
|
||||
) -> tuple[int, int, int, int, int, int, int, int, int]:
|
||||
"""
|
||||
Convert matrix multiplication parameters and config to a feature vector.
|
||||
|
||||
Args:
|
||||
m: First dimension of matrix multiplication
|
||||
n: Second dimension of matrix multiplication
|
||||
k: Third dimension of matrix multiplication
|
||||
dsize: Data size in bits (e.g., 16 for float16, 32 for float32)
|
||||
config: Configuration object containing kernel parameters
|
||||
|
||||
Returns:
|
||||
Tuple containing the extracted features
|
||||
"""
|
||||
|
||||
return (
|
||||
int(m),
|
||||
int(n),
|
||||
int(k),
|
||||
int(dsize),
|
||||
int(params.block_m),
|
||||
int(params.block_n),
|
||||
int(params.block_k),
|
||||
int(params.num_stages),
|
||||
int(params.num_warps),
|
||||
)
|
||||
|
||||
def encode(
|
||||
self, m: int, n: int, k: int, dtype: torch.dtype, configs: list[Any]
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Encode the matrix multiplication parameters and configs as input tensors for the model.
|
||||
|
||||
Args:
|
||||
m: First dimension of matrix multiplication
|
||||
n: Second dimension of matrix multiplication
|
||||
k: Third dimension of matrix multiplication
|
||||
dtype: Data type of the matrices
|
||||
configs: List of configuration objects
|
||||
|
||||
Returns:
|
||||
Tensor containing the encoded inputs ready for the model
|
||||
|
||||
Raises:
|
||||
ValueError: If the dtype is not supported
|
||||
"""
|
||||
# Determine data size based on dtype
|
||||
if dtype == torch.bfloat16 or dtype == torch.float16:
|
||||
dsize = 16
|
||||
elif dtype == torch.float32:
|
||||
dsize = 32
|
||||
else:
|
||||
raise ValueError(f"Unsupported dtype: {dtype}. Add support for this dtype.")
|
||||
|
||||
# Create feature dataframe
|
||||
df = pd.DataFrame(
|
||||
columns=[
|
||||
"dim_m",
|
||||
"dim_n",
|
||||
"dim_k",
|
||||
"dtype_size",
|
||||
"config_block_m",
|
||||
"config_block_n",
|
||||
"config_block_k",
|
||||
"config_num_stages",
|
||||
"config_num_warps",
|
||||
],
|
||||
data=[self.vec(m, n, k, dsize, config) for config in configs],
|
||||
)
|
||||
|
||||
# Calculate derived features
|
||||
df["total_gb"] = get_total_gb_feature(df=df).astype(np.float32)
|
||||
df["total_gflop"] = get_total_gflop_feature(df=df).astype(np.float32)
|
||||
df["flops_per_byte"] = df["total_gflop"] / df["total_gb"]
|
||||
|
||||
# Reorder columns to match expected model input
|
||||
df = df[
|
||||
[
|
||||
"dtype_size",
|
||||
"dim_m",
|
||||
"dim_n",
|
||||
"dim_k",
|
||||
"total_gb",
|
||||
"total_gflop",
|
||||
"flops_per_byte",
|
||||
"config_block_k",
|
||||
"config_block_m",
|
||||
"config_block_n",
|
||||
"config_num_stages",
|
||||
"config_num_warps",
|
||||
]
|
||||
]
|
||||
|
||||
# Standardize the input
|
||||
inp, _, _ = get_nn_x(
|
||||
df=df, mean=self.mean_for_standardization, std=self.std_for_standardization
|
||||
)
|
||||
|
||||
return inp
|
||||
|
||||
def inference(self, inp_tensor: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Run inference on the model with the given input tensor.
|
||||
|
||||
Args:
|
||||
inp_tensor: Input tensor for the model
|
||||
|
||||
Returns:
|
||||
Output tensor from the model
|
||||
"""
|
||||
        with torch.no_grad():
            return self.model.forward(inp_tensor)

def decode(self, ret_tensor: torch.Tensor) -> torch.Tensor:
    """
    Decode the model output tensor.

    Args:
        ret_tensor: Output tensor from the model

    Returns:
        Decoded tensor representing runtime predictions
    """
    return ret_tensor


# Create a singleton instance of the model wrapper
import functools


@functools.lru_cache
def get_model() -> ModelWrapper:
    return ModelWrapper()
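

# Hedged usage sketch (editor addition, not part of the change): shows how the cached
# wrapper is expected to be driven end to end. `candidate_configs` is a hypothetical
# list of TritonGEMMConfig-like objects, and picking argmin assumes the decoded output
# holds one predicted runtime per config, as the decode() docstring suggests.
def _example_predict_best_config(m, n, k, dtype, candidate_configs):
    model = get_model()
    inp = model.encode(m, n, k, dtype, candidate_configs)
    runtimes = model.decode(model.inference(inp))
    # Lower predicted runtime is assumed better; return the matching candidate config.
    return candidate_configs[int(torch.argmin(runtimes).item())]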
@@ -3044,6 +3044,9 @@ class AlgorithmSelectorCache(PersistentCache):
    def add_feedback_saver(self, fn: FeedbackFunction):
        self.feedback_saver_fns.append(fn)

    def clear_feedback_savers(self):
        self.feedback_saver_fns = []


_ALGORITHM_SELECTOR_CACHE: Optional[AlgorithmSelectorCache] = None

@@ -3071,6 +3074,11 @@ def add_feedback_saver(
    if _ALGORITHM_SELECTOR_CACHE is None:
        _ALGORITHM_SELECTOR_CACHE = AlgorithmSelectorCache()
    _ALGORITHM_SELECTOR_CACHE.add_feedback_saver(fn)
def clear_feedback_savers():
    global _ALGORITHM_SELECTOR_CACHE
    if _ALGORITHM_SELECTOR_CACHE is None:
        _ALGORITHM_SELECTOR_CACHE = AlgorithmSelectorCache()
    _ALGORITHM_SELECTOR_CACHE.clear_feedback_savers()
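# Hedged usage note (editor assumption): the two module-level helpers are meant to be
# paired so a registered hook does not leak into unrelated compilations; `my_hook`,
# `fn`, and `example_inputs` below are hypothetical caller-supplied names.
#
#     add_feedback_saver(my_hook)
#     try:
#         torch.compile(fn)(*example_inputs)  # any call that hits GEMM autotuning
#     finally:
#         clear_feedback_savers()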


def realize_inputs(*args):

@@ -19,7 +19,7 @@ if TYPE_CHECKING:
    from triton import Config as TritonConfig


@dataclasses.dataclass
@dataclasses.dataclass(frozen=True)
class BaseConfig:
    """
    Base Gemm configuration used for most backends (CPU, CUDA)
@@ -32,7 +32,7 @@ class BaseConfig:
    num_warps: int


@dataclasses.dataclass
@dataclasses.dataclass(frozen=True)
class GemmConfig(BaseConfig):
    """
    Gemm configuration used for most backends (CPU, CUDA)
@@ -44,7 +44,7 @@ class GemmConfig(BaseConfig):
ConvConfig = BaseConfig
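# Hedged note on the frozen=True change (editor assumption, not stated in the diff):
# freezing these dataclasses makes config instances immutable and hashable, so duplicate
# configs can be collapsed with a set or used as dict keys, e.g.:
#
#     unique = {GemmConfig(64, 64, 32, 2, 4), GemmConfig(64, 64, 32, 2, 4)}
#     assert len(unique) == 1  # equal field values hash identically once frozen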


@dataclasses.dataclass
@dataclasses.dataclass(frozen=True)
class ROCmGemmConfig(GemmConfig):
    """
    ROCm subclass for GEMMs, with AMD backend specific tuneable kernargs
@@ -55,7 +55,7 @@ class ROCmGemmConfig(GemmConfig):
    kpack: int = 2


@dataclasses.dataclass
@dataclasses.dataclass(frozen=True)
class ROCmConvConfig(ConvConfig):
    """
    ROCm subclass for Conv, with AMD backend specific tuneable kernargs
@@ -85,16 +85,92 @@ class BaseHeuristicSingleton(type):
        return cls._instances[cls]


import os


class BaseConfigHeuristic(metaclass=BaseHeuristicSingleton):
    """
    Base class for mm_configs, device specific triton kernels config inherit from here
    """

    def __init__(self) -> None:
        # List of dictionaries to store the kernel configs. Configs that evaluate to true
        # will be utilised on the target platform. The configs are as follows:
        # (BLOCK_M, BLOCK_N, BLOCK_K, num_stages, num_warps)
        self.mm_configs: list[BaseConfig] = [
    @property
    def mm_configs(self) -> list[BaseConfig]:
        # if (
        #     os.environ.get("TORCHINDUCTOR_NEW_CONFIGS", "0") == "1"
        #     or config.new_configs
        # ):
        #     return [
        #         GemmConfig(128, 16, 128, 5, 8),  # new
        #         GemmConfig(64, 16, 128, 4, 4),  # new
        #         GemmConfig(64, 128, 128, 4, 8),
        #         GemmConfig(128, 32, 32, 4, 4),
        #         GemmConfig(64, 128, 64, 4, 4),  # new
        #         GemmConfig(64, 128, 128, 3, 4),  # new
        #         GemmConfig(128, 16, 128, 5, 8),  # new
        #         # GemmConfig(64, 128, 256, 4, 8),
        #         GemmConfig(64, 64, 128, 5, 4),
        #         GemmConfig(128, 32, 16, 5, 2),
        #         GemmConfig(32, 128, 128, 4, 4),
        #         GemmConfig(64, 64, 32, 5, 4),
        #         GemmConfig(64, 128, 128, 3, 8),
        #         GemmConfig(32, 64, 64, 3, 4),
        #         # GemmConfig(64, 128, 256, 3, 8),
        #         GemmConfig(64, 16, 32, 4, 2),
        #         GemmConfig(128, 64, 64, 5, 4),
        #         # GemmConfig(32, 128, 256, 4, 8),
        #         GemmConfig(128, 32, 32, 5, 4),
        #         GemmConfig(256, 32, 16, 3, 2),
        #         GemmConfig(16, 128, 64, 1, 4),
        #         GemmConfig(256, 64, 16, 1, 4),
        #         GemmConfig(64, 128, 64, 4, 4),
        #     ]
        # # return [
        # #     GemmConfig(64, 128, 128, 4, 8),
        # #     GemmConfig(128, 32, 32, 4, 4),
        # #     GemmConfig(64, 128, 256, 4, 8),
        # #     GemmConfig(64, 64, 128, 5, 4),
        # #     GemmConfig(128, 32, 16, 5, 2),
        # #     #GemmConfig(32, 256, 128, 4, 8),
        # #     GemmConfig(32, 128, 128, 4, 4),
        # #     GemmConfig(64, 64, 32, 5, 4),
        # #     GemmConfig(64, 128, 128, 3, 8),
        # #     GemmConfig(32, 64, 64, 3, 4),
        # #     GemmConfig(64, 128, 256, 3, 8),
        # #     GemmConfig(64, 16, 32, 4, 2),
        # #     GemmConfig(128, 64, 64, 5, 4),
        # #     GemmConfig(32, 128, 256, 4, 8),
        # #     GemmConfig(128, 32, 32, 5, 4),
        # #     # GemmConfig(64, 128, 128, 5, 4),
        # #     GemmConfig(256, 32, 16, 3, 2),
        # #     GemmConfig(16, 128, 64, 1, 4),
        # #     GemmConfig(256, 64, 16, 1, 4),
        # #     GemmConfig(64, 128, 64, 4, 4),
        # # ]
        # # return [
        # #     # GemmConfig(16, 16, 128, 5, 1),
        # #     # GemmConfig(16, 16, 256, 4, 1),
        # #     GemmConfig(64, 16, 128, 4, 4),
        # #     GemmConfig(64, 16, 256, 4, 4),
        # #     GemmConfig(64, 32, 128, 4, 4),
        # #     GemmConfig(64, 32, 128, 5, 8),
        # #     # GemmConfig(63, 32, 256, 1, 8),
        # #     GemmConfig(64, 64, 128, 4, 4),
        # #     GemmConfig(64, 128, 64, 4, 4),
        # #     GemmConfig(64, 128, 128, 3, 4),
        # #     GemmConfig(128, 16, 128, 5, 8),
        # #     GemmConfig(128, 128, 32, 5, 8),
        # #     GemmConfig(128, 128, 64, 3, 4),
        # #     GemmConfig(128, 128, 64, 3, 8),
        # #     GemmConfig(128, 128, 64, 4, 4),
        # #     GemmConfig(128, 128, 64, 4, 8),
        # #     GemmConfig(128, 256, 32, 5, 8),
        # #     GemmConfig(128, 256, 64, 3, 8),
        # #     GemmConfig(128, 256, 64, 4, 8),
        # #     #GemmConfig(128, 256, 64, 5, 8),
        # #     GemmConfig(256, 128, 32, 5, 8)
        # # ]
        # else:
        return [
            GemmConfig(32, 32, 16, 1, 2),
            GemmConfig(32, 32, 128, 2, 4),
            GemmConfig(32, 64, 32, 5, 8),
@@ -115,6 +191,35 @@ class BaseConfigHeuristic(metaclass=BaseHeuristicSingleton):
            GemmConfig(128, 128, 64, 3, 4),
            GemmConfig(128, 128, 64, 5, 8),
        ]
    # TODO more results on compile time overhead
    # TODO compare to exhaustive results
    # TODO ask about cudagraph benchmarking

    def __init__(self) -> None:
        # List of dictionaries to store the kernel configs. Configs that evaluate to true
        # will be utilised on the target platform. The configs are as follows:
        # (BLOCK_M, BLOCK_N, BLOCK_K, num_stages, num_warps)
        # self.mm_configs: list[BaseConfig] = [
        #     GemmConfig(32, 32, 16, 1, 2),
        #     GemmConfig(32, 32, 128, 2, 4),
        #     GemmConfig(32, 64, 32, 5, 8),
        #     GemmConfig(64, 32, 32, 5, 8),
        #     GemmConfig(64, 32, 128, 5, 4),
        #     GemmConfig(64, 64, 16, 2, 4),
        #     GemmConfig(64, 64, 32, 2, 4),
        #     GemmConfig(64, 64, 64, 3, 8),
        #     GemmConfig(64, 64, 128, 5, 4),
        #     GemmConfig(64, 128, 32, 3, 4),
        #     GemmConfig(64, 128, 32, 4, 8),
        #     GemmConfig(64, 128, 64, 3, 4),
        #     GemmConfig(64, 128, 128, 4, 4),
        #     GemmConfig(128, 64, 32, 3, 4),
        #     GemmConfig(128, 64, 32, 4, 8),
        #     GemmConfig(128, 128, 32, 2, 8),
        #     GemmConfig(128, 128, 32, 3, 4),
        #     GemmConfig(128, 128, 64, 3, 4),
        #     GemmConfig(128, 128, 64, 5, 8),
        # ]

        # Exhaustive search for mm configs
        self.exhaustive_configs: list[BaseConfig] = [