This PR adds support to train_decision for learning a heuristic that performs ranking. The main idea is that the user provides the number of choices the heuristic should return. I added a way to prune the learned decision tree so that it always returns the number of choices provided by the user.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/131705
Approved by: https://github.com/eellison
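For context, the sketch below shows how the hooks in the file are meant to be used by a subclass. It is illustrative only and not part of this PR: the class name MyRankingTrain and the 2x threshold are made up, while is_unsafe_leaf, ranking_num_choices, self.args.ranking, and generate_heuristic are the real hooks used by this file and by train.py.

    # Hypothetical subclass; the 2x rule mirrors the example given in the
    # is_unsafe_leaf docstring below.
    class MyRankingTrain(AHTrainDecisionTree):
        def is_unsafe_leaf(self, row, predicted_config, choice2time):
            # Mark a leaf as unsafe if the predicted choice is more than 2x
            # slower than the fastest available choice for this sample.
            best_time = min(choice2time.values())
            return choice2time.get(predicted_config, float("inf")) > 2 * best_time

        def ranking_num_choices(self):
            # Return the top 3 choices when no ranking count was provided.
            return 3 if self.args.ranking is None else self.args.ranking

    if __name__ == "__main__":
        MyRankingTrain().generate_heuristic()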
918 lines · 32 KiB · Python
# mypy: ignore-errors
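"""
Trains a decision tree on AutoHeuristic data and generates a learned heuristic
that picks one of the available choices for an optimization. The script also
supports learning a heuristic for ranking: the user provides the number of
choices the heuristic should return, and the learned tree is pruned so that it
always returns that many choices (see the PR description above).
"""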

import itertools
import json
import logging
import math
import warnings


warnings.filterwarnings(
    "ignore",
    message="The behavior of DataFrame concatenation with empty or all-NA entries is deprecated",
)

from dataclasses import dataclass

import numpy as np
import pandas as pd  # type: ignore[import-untyped]
from ah_tree import DecisionTree
from scipy.stats import gmean
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier
from train import AHTrain


log = logging.getLogger(__name__)
DEBUG = True
if DEBUG:
    ch = logging.StreamHandler()
    ch.setLevel(logging.DEBUG)
    formatter = logging.Formatter(
        "%(asctime)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S"
    )
    ch.setFormatter(formatter)
    log.addHandler(ch)


class AHTrainDecisionTree(AHTrain):
    def __init__(self):
        super().__init__()

    def is_unsafe_leaf(self, row, predicted_config, choice2time):
        """
        Can be overridden by subclasses to define their own logic for deciding when a leaf is unsafe. It is given a
        sample that landed in the leaf, the choice predicted by the tree, and a dictionary that maps each choice to
        its execution time. One can for example decide to mark a leaf as unsafe if the predicted choice is 2x slower
        than the fastest choice.
        If a leaf is unsafe, the learned heuristic will always return 'unsure' if an input lands in that leaf.
        """

        return False

    def get_unsafe_leaves(self, model, df, feature_columns):
        """
        Given a trained decision tree, and a dataframe containing the training data, returns a list of unsafe leaves.
        """
        X = df[feature_columns]
        y = df["winner"]
        leaf_ids = model.apply(X)
        unique_leaves = np.unique(leaf_ids)

        unsafe_leaves = []
        # Iterate over each leaf
        for leaf in unique_leaves:
            leaf_mask = leaf_ids == leaf
            # Get samples that land in this leaf
            leaf_X = X[leaf_mask]

            predicted_config = model.predict(leaf_X.iloc[[0]])[0]

            # For each sample, check if we should mark the leaf as unsafe
            for idx, row in leaf_X.iterrows():
                choice2time = json.loads(df.loc[idx, "choice2time"])
                if self.is_unsafe_leaf(row, predicted_config, choice2time):
                    unsafe_leaves.append(leaf)
                    break
        return unsafe_leaves

    def get_allowed_wrong_prediction_pct(self):
        """
        This is used to determine a threshold for when a learned heuristic returns 'unsure'.
        If this function returns 0.01, we will set the probability required for the decision tree to return a decision
        such that at most 1% of the predictions will be wrong on the validation set.
        """
        return 0.01

    def get_grid_search_values(self):
        """
        Standard values for grid search. Can be overridden.
        """
        return {
            "max_depth": [5, 6, 7],
            "min_samples_leaf": [1, 5, 10, 0.01, 0.05, 0.02],
            "criterion": ["gini", "entropy"],
        }

    def predict(self, model, df, feature_columns):
        """
        Returns the predictions, probabilities, and leaf ids for a given dataframe.
        """
        predictions = model.predict(df[feature_columns])
        proba = model.predict_proba(df[feature_columns])
        leaf_ids = model.apply(df[feature_columns])
        return predictions, proba, leaf_ids

    def ranking_num_choices(self):
        # if the heuristic is used for ranking, this function returns the number
        # of choices that the heuristic will return
        if self.args.ranking is None:
            return 5
        return self.args.ranking

    def train_and_evaluate_models(
        self,
        datasets,
        max_depths,
        min_samples_leafs,
        criterion_list,
        feature_columns,
        ranking=False,
    ):
        """
        Does a grid search over max_depths, min_samples_leafs, and criterion_list and returns the best model.
        """

        results = []
        best_model = None
        best_model_safe_proba = 0
        best_model_num_correct = 0
        best_model_num_wrong = 0
        best_model_unsafe_leaves = []
        columns = ["set", "crit", "max_depth", "min_samples_leaf"]
        metrics_columns = []
        for max_depth, min_samples_leaf, criterion in itertools.product(
            max_depths, min_samples_leafs, criterion_list
        ):
            print(
                f"max_depth={max_depth} min_samples_leaf={min_samples_leaf} criterion={criterion}"
            )
            model = DecisionTreeClassifier(
                max_depth=max_depth,
                min_samples_leaf=min_samples_leaf,
                criterion=criterion,
                random_state=42,
            )
            df_train = datasets["train"]
            df_val = datasets["val"]
            if ranking:
                model.fit(
                    df_train[feature_columns],
                    df_train["winner"],
                    sample_weight=df_train["relative_performance"],
                )
            else:
                model.fit(df_train[feature_columns], df_train["winner"])

            model = DecisionTree(model, feature_columns)

            if ranking:
                model.prune(df_train, "winner", k=self.ranking_num_choices())

            unsafe_leaves = self.get_unsafe_leaves(model, df_train, feature_columns)
            predictions, proba, leaf_ids = self.predict(model, df_val, feature_columns)

            wrong_pct = self.get_allowed_wrong_prediction_pct()
            evaluator = DecisionEvaluator(
                self,
                model,
                predictions,
                df_val,
                proba,
                wrong_pct=wrong_pct,
                unsafe_leaves=unsafe_leaves,
                leaf_ids=leaf_ids,
                k=self.ranking_num_choices(),
                ranking=ranking,
            )
            safe_proba = evaluator.get_safe_proba()
            print(f"safe_proba={safe_proba}")

            def eval(name, df):
                if ranking:
                    # when ranking is enabled, we duplicate each input for each choice that
                    # is almost as good as the best choice
                    # we do not want to evaluate the same input multiple times, so we remove duplicates here
                    df = df[df["winner"] == df["actual_winner"]]
                predictions, proba, leaf_ids = self.predict(model, df, feature_columns)
                evaluator = DecisionEvaluator(
                    self,
                    model,
                    predictions,
                    df,
                    proba,
                    wrong_pct=wrong_pct,
                    threshold=safe_proba,
                    unsafe_leaves=unsafe_leaves,
                    leaf_ids=leaf_ids,
                    k=self.ranking_num_choices(),
                    ranking=ranking,
                )
                return evaluator.get_results()

            for dataset_name, dataset in datasets.items():
                eval_result: EvalResults = eval(dataset_name, dataset)
                eval_result_metrics = eval_result.to_map()
                if dataset_name == "val":
                    num_correct = eval_result.accuracy.num_correct
                    num_wrong = eval_result.accuracy.num_wrong
                    num_total = eval_result.accuracy.total
                    if num_wrong <= num_total * wrong_pct:
                        if num_correct > best_model_num_correct:
                            print(
                                f"new best model with {num_correct} correct and {num_wrong} wrong"
                            )
                            best_model = model
                            best_model_num_correct = num_correct
                            best_model_num_wrong = num_wrong
                            best_model_safe_proba = safe_proba
                            best_model_unsafe_leaves = unsafe_leaves

                result = (dataset_name, criterion, max_depth, min_samples_leaf)
                result += tuple(eval_result_metrics.values())
                results.append(result)
                if len(metrics_columns) == 0:
                    metrics_columns = list(eval_result_metrics.keys())
                    columns += metrics_columns

        return (
            pd.DataFrame(results, columns=columns),
            best_model,
            best_model_safe_proba,
            best_model_unsafe_leaves,
        )

    def get_test_and_val_size(self):
        """
        Returns the size of the test and validation sets.
        """
        return (0.15, 0.15)

    def prepare_datasets(self, df, other_datasets, cat_feature2cats, ranking=False):
        """
        Splits the dataframe into train, val, and test sets.
        Also adds other datasets, specified by the user, to the train set.
        """
        test_size, val_size = self.get_test_and_val_size()
        # Split into train+val and test
        df_train_val, df_test = train_test_split(
            df, test_size=test_size, random_state=42
        )

        # Split train+val inputs into train and val
        train_val_size = 1 - test_size
        df_train, df_val = train_test_split(
            df_train_val, test_size=val_size / train_val_size, random_state=42
        )
        datasets = {"train": df_train, "val": df_val, "test": df_test}
        self.add_real_datasets(datasets, other_datasets, cat_feature2cats, ranking)
        return datasets
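
    # For example, with get_test_and_val_size() == (0.15, 0.15): the first split
    # holds out 15% of the data as the test set, and the second split takes
    # 0.15 / 0.85 (about 17.6%) of the remaining train+val data as the validation
    # set, so the final proportions are roughly 70% train, 15% val, 15% test.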

    def export_to_dot(self, best_model, df, feature_columns):
        """
        Export a learned decision tree to a dot file.
        """
        dot_str = best_model.to_dot()
        with open("best_model.dot", "w") as f:
            f.write(dot_str)

    def get_feature_columns(self, df):
        """
        The dataframe contains columns that are not features, such as 'winner' and 'speedup', which are only used for
        debugging purposes. This function returns the columns that are actually features.
        """
        exclude_columns = [
            "speedup",
            "winner",
            "target",
            "avail_choices",
            "choice2time",
            "index",
            "actual_winner",
            "relative_performance",
        ]
        feature_columns = [col for col in df.columns if col not in exclude_columns]
        return feature_columns

    def add_training_data(self, df_train, datasets):
        return datasets["train"]

    def main(
        self,
        log_path,
        other_datasets,
        nrows,
        heuristic_name,
        save_dot=False,
        ranking=False,
    ):
        """
        Main function that trains a decision tree and generates a heuristic.
        """
        # TODO: Enable apply_filters
        (df, choices, cat_feature2cats, dummy_col_2_col_val, metadata) = self.get_df(
            log_path, nrows=nrows, apply_filters=False, add_near_best=ranking
        )
        print(df["winner"].value_counts())
        datasets = self.prepare_datasets(df, other_datasets, cat_feature2cats, ranking)
        df_train = self.add_training_data(datasets["train"], datasets)
        datasets["train"] = df_train

        feature_columns = self.get_feature_columns(df)
        grid_search_values = self.get_grid_search_values()
        max_depths = grid_search_values["max_depth"]
        min_samples_leafs = grid_search_values["min_samples_leaf"]
        criterion_list = grid_search_values["criterion"]
        (
            results_df,
            best_model,
            best_model_safe_proba,
            unsafe_leaves,
        ) = self.train_and_evaluate_models(
            datasets,
            max_depths,
            min_samples_leafs,
            criterion_list,
            feature_columns,
            ranking=ranking,
        )

        if ranking:
            # column names must match the keys produced by RankingMetrics.to_map()
            columns_to_keep = [
                "set",
                "total",
                "top_k_correct",
                "top_k_wrong",
                "top_k_unsure",
                "wrong_max_speedup_k",
                "wrong_gmean_speedup_k",
            ]
            results_df = results_df[columns_to_keep]
        # prints results for all models and datasets
        print(results_df.to_string())

        if not ranking:
            # prints results grouped by dataset
            for set_name in results_df["set"].unique():
                dataset_results = results_df[results_df["set"] == set_name]
                dataset_results = dataset_results.sort_values(by="correct")
                print(dataset_results.to_string() + "\n")

        if best_model is not None:
            if save_dot:
                self.export_to_dot(best_model, df, feature_columns)
            self.codegen(
                best_model,
                metadata,
                heuristic_name,
                best_model_safe_proba,
                dummy_col_2_col_val,
                unsafe_leaves,
            )
        else:
            print(
                "All learned models have too many wrong predictions, so no heuristic was generated"
            )

    def get_df(
        self,
        log_path,
        cat_feature2cats=None,
        nrows=None,
        apply_filters=False,
        add_near_best=False,
    ):
        """
        Parses the log file and processes the data into a dataframe that can be used for training.
        """
        (df, metadata, features, categorical_features, choices) = self.parse_log(
            log_path, nrows
        )

        def calculate_stats(group):
            count = len(group)
            has_inf = np.isinf(group["feedback"]).any()
            if has_inf:
                relative_std = np.inf
                median = np.inf
            else:
                mean = group["feedback"].mean()
                std = group["feedback"].std()
                relative_std = (std / mean) * 100 if mean != 0 else np.inf
                median = group["feedback"].median()
            if relative_std > 5:
                times = group["feedback"].tolist()
                times_str = ", ".join([f"{t:.3f}" for t in sorted(times)])
                log.debug("High relative std: %f. times=%s", relative_std, times_str)
            return pd.Series(
                {
                    "count": count,
                    "relative_std": relative_std,
                    "median_execution_time": median,
                }
            )

        feature_columns = features
        stats = (
            df.groupby(feature_columns + ["choice"], as_index=False)
            .apply(calculate_stats, include_groups=False)
            .reset_index()
        )

        # TODO: We have to be careful with removing certain choices, because if we e.g. remove the winner, the
        # heuristic will end up learning wrong things. But, execution times with high variance are also bad
        if apply_filters:
            # Filter out inputs with less than 3 measurements or high relative std
            valid_stats = stats[(stats["count"] >= 3) & (stats["relative_std"] <= 5)]
            # Group by input features and count how many valid choices we have for each input
            valid_inputs = valid_stats.groupby(feature_columns).filter(
                lambda x: len(x) >= 2
            )
        else:
            valid_inputs = stats

        # Compute the winner and speedup for each valid input
        def get_winner_and_speedup(group):
            assert len(group) >= 2, "Need at least 2 choices"

            sorted_group = group.sort_values("median_execution_time")
            winner = sorted_group.iloc[0]["choice"]
            winning_time = sorted_group.iloc[0]["median_execution_time"]
            second_best_time = sorted_group.iloc[1]["median_execution_time"]
            speedup = second_best_time / winning_time
            unique_choices = group["choice"].unique()

            choice2time = {}
            for row in group.itertuples():
                choice2time[row.choice] = row.median_execution_time

            assert len(unique_choices) == len(
                group
            ), f"len(unique_choices) != len(group): {len(unique_choices)} != {len(group)}"

            return pd.Series(
                {
                    "winner": winner,
                    "speedup": speedup,
                    "avail_choices": unique_choices,
                    "choice2time": json.dumps(choice2time),
                }
            )

        results = (
            valid_inputs.groupby(feature_columns, as_index=False)
            .filter(lambda x: len(x) >= 2)
            .groupby(feature_columns, as_index=False)
            .apply(get_winner_and_speedup, include_groups=False)
            .reset_index()
        )

        def add_near_best_configs(df):
            new_rows = []

            for index, row in df.iterrows():
                dictionary = json.loads(row["choice2time"])
                min_value = min(dictionary.values())

                for key, value in dictionary.items():
                    new_row = row.copy()
                    relative_performance = min_value / value
                    new_row["relative_performance"] = relative_performance
                    if relative_performance is None or relative_performance is np.inf:
                        breakpoint()
                    new_row["actual_winner"] = row["winner"]
                    new_row["winner"] = key
                    if relative_performance >= 0.95:
                        new_rows.append(new_row)

            return pd.DataFrame(new_rows).reset_index(drop=True)
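
        # Illustrative example of the near-best duplication above: with
        # choice2time = {"a": 1.00, "b": 1.04, "c": 1.30}, the rows kept are the
        # ones for "a" (relative_performance 1.0) and "b" (~0.96), while "c"
        # (~0.77) is dropped. Both kept rows get actual_winner="a", so during
        # training any choice within 5% of the best one is treated as a valid label.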

        if add_near_best:
            results = add_near_best_configs(results)
        (results, added_categorical_features) = self.add_new_features(results)
        categorical_features += added_categorical_features

        (
            results,
            cat_feature2cats,
            dummy_col_2_col_val,
        ) = self.handle_categorical_features(
            cat_feature2cats, categorical_features, results
        )
        return (results, choices, cat_feature2cats, dummy_col_2_col_val, metadata)

    def gen_classes(self, classes, num_spaces):
        """
        If classes=['choice1', 'choice2', 'choice3'], then this function returns
        the following string:
        self.choices.append('choice1')
        self.choices.append('choice2')
        self.choices.append('choice3')
        Used in the generated heuristic to map the index of a choice to its name.
        """
        indent = " " * num_spaces
        return "\n".join([f"{indent}self.choices.append('{c}')" for c in classes])

    def get_default_config(self, row):
        """
        Returns the default config for a given sample. The default config could for example be the config that is
        chosen by a current handwritten heuristic. This can for example be used in is_unsafe_leaf to
        compare the predicted config with the default config.
        """
        return None

    def gen_predict_fn_def(self):
        """
        Generates the definition of the predict function.
        """
        return "def get_best_choices(self, context: AHContext) -> Optional[List[Tuple[float, int]]]:"

    def codegen_boilerplate(
        self, heuristic_name, opt_name, threshold, shared_memory, device_capa, classes
    ):
        """
        Generates the boilerplate code for the generated heuristic. This includes things like imports, class definition,
        etc.
        """

        boiler_plate = f"""# flake8: noqa: B950
# fmt: off
# This file was generated by AutoHeuristic. Do not modify it manually!
# To regenerate this file, take a look at the steps in the README.md file inside torchgen/_autoheuristic/{opt_name}/
from typing import List, Optional, Tuple

from torch._inductor.autoheuristic.autoheuristic_utils import (
    AHContext,
    AHMetadata,
    Choice,
)
from torch._inductor.autoheuristic.learnedheuristic_interface import (
    LearnedHeuristicDecision,
)


class {heuristic_name}(LearnedHeuristicDecision):

    def __init__(self) -> None:
        self.choices: List[Choice] = []
        self.fill_choices()

{self.gen_precondition(opt_name, shared_memory, device_capa)}

    def get_confidence_threshold(self) -> float:
        return {threshold}

    def get_choice(self, idx: int) -> Optional[str]:
        if idx < len(self.choices):
            return self.choices[idx]
        return None

    def fill_choices(self) -> None:
{self.gen_classes(classes, num_spaces=8)}

    def get_name(self) -> str:
        return '{opt_name}'"""
        return boiler_plate

    def add_real_datasets(
        self, datasets, other_datasets, cat_feature2cats, ranking=False
    ):
        """
        Adds datasets specified by the user to the datasets dictionary.
        """
        if other_datasets:
            for name, path in other_datasets:
                (df_other, choices, _, _, _) = self.get_df(
                    path,
                    cat_feature2cats=cat_feature2cats,
                    apply_filters=False,
                    add_near_best=ranking,
                )
                datasets[name] = df_other

    def codegen(
        self,
        tree,
        metadata,
        heuristic_name,
        threshold,
        dummy_col_2_col_val,
        unsafe_leaves,
    ):
        lines = []
        device_capa = metadata["device_capa"]
        device_capa_str = f"({device_capa[0]}, {device_capa[1]})"
        opt_name = metadata["name"]
        lines.append(
            self.codegen_boilerplate(
                heuristic_name,
                opt_name,
                threshold,
                metadata["shared_memory"],
                device_capa_str,
                tree.classes_,
            )
        )
        fn_def = f"\n    {self.gen_predict_fn_def()}"
        lines.append(fn_def)
        tree.codegen(dummy_col_2_col_val, lines, unsafe_leaves)
        self.write_heuristic_to_file(lines, heuristic_name)


@dataclass
class AccuracyMetrics:
    # Number of correct predictions
    num_correct: int
    # Number of wrong predictions
    num_wrong: int
    # Number of predictions where model is unsure
    num_unsure: int
    # Total number of predictions
    total: int

    def to_map(self):
        return {
            "correct": self.num_correct,
            "wrong": self.num_wrong,
            "unsure": self.num_unsure,
            "total": self.total,
        }


@dataclass
class WrongSpeedupMetrics:
    # If the model predicted the wrong choice, this is the maximum speedup of the best choice over the predicted choice
    max_speedup: float
    # For all wrong predictions, this is the geometric mean of the speedups of the best choices over the predicted choices
    gmean_speedup: float

    def to_map(self):
        return {
            "wrong_max_speedup": self.max_speedup,
            "wrong_gmean_speedup": self.gmean_speedup,
        }


@dataclass
class RankingMetrics:
    # Number of predictions where best choice is in top k choices
    num_correct: int
    # Number of predictions where best choice is not in top k choices
    num_wrong: int
    # Maximum speedup of the best choice over the best choice within the top k (i.e. how much better the true best
    # choice, which is not in the top k, is than the best choice that is in the top k)
    max_speedup: float
    # Geometric mean of speedups of best choice over best choice in top k
    gmean_speedup: float
    # Number of predictions where model is unsure
    unsure: int

    def to_map(self):
        return {
            "top_k_correct": self.num_correct,
            "top_k_wrong": self.num_wrong,
            "wrong_max_speedup_k": self.max_speedup,
            "wrong_gmean_speedup_k": self.gmean_speedup,
            "top_k_unsure": self.unsure,
        }


@dataclass
class DefaultComparisonMetrics:
    # Maximum speedup of predicted choice over default choice
    max_speedup: float
    # Geometric mean of speedups of predicted choices over default choices
    gmean_speedup: float
    # Maximum speedup of default choice over predicted choice
    max_slowdown: float
    # Number of predictions where the predicted choice is not the default choice
    non_default_predictions: int
    # Number of predictions where the default choice is better than the predicted choice
    default_better: int

    def to_map(self):
        return {
            "max_speedup_over_default": self.max_speedup,
            "gmean_speedup_over_default": self.gmean_speedup,
            "max_speedup_default_over_heuristic": self.max_slowdown,
            "non_default_predictions": self.non_default_predictions,
            "default_better": self.default_better,
        }


@dataclass
class EvalResults:
    accuracy: AccuracyMetrics
    speedup: WrongSpeedupMetrics
    ranking: RankingMetrics
    default_comparison: DefaultComparisonMetrics

    def to_map(self):
        return {
            **self.accuracy.to_map(),
            **self.speedup.to_map(),
            **self.ranking.to_map(),
            **self.default_comparison.to_map(),
        }


class DecisionEvaluator:
    def __init__(
        self,
        train,
        model,
        predictions,
        df,
        probas,
        wrong_pct=0.01,
        threshold=0.0,
        k=10,
        unsafe_leaves=None,
        leaf_ids=None,
        ranking=False,
    ) -> None:
        self.train = train
        self.model = model
        self.predictions = predictions
        self.df = df
        self.probas = probas
        self.wrong_pct = wrong_pct
        self.threshold = threshold
        self.k = k
        self.unsafe_leaves = unsafe_leaves
        self.leaf_ids = leaf_ids
        self.ranking = ranking

        self.num_correct = 0
        self.num_wrong = 0
        self.num_unsure = 0
        self.wrong_probas = []
        self.speedups_wrong = []
        self.num_correct_top_k = 0
        self.num_wrong_top_k = 0
        self.wrong_speedups_top_k = []
        self.top_k_unsure = 0
        self.num_non_default_predictions = 0
        self.speedups_over_default = []
        self.num_default_better = 0

    def compute_speedup_over_default(self, default_config, pred, i, predicted_time):
        if default_config is not None:
            if pred != default_config:
                self.num_non_default_predictions += 1
            default_time = self.get_time(self.df.iloc[i], default_config)
            # TODO: We should keep track of how often this happens
            if default_time is not None and not math.isinf(default_time):
                speedup_over_default = default_time / predicted_time
                if speedup_over_default < 1:
                    self.num_default_better += 1
                self.speedups_over_default.append(speedup_over_default)
            else:
                log.debug(
                    "cannot compute speedup over default because default_time=%s",
                    default_time,
                )

    def get_time(self, row, choice):
        choices_feedback = json.loads(row["choice2time"])
        return choices_feedback.get(choice, None)

    def top_k_classes(self, model, probas, k, avail_choices):
        # Get classes and their corresponding probabilities
        classes = model.classes_
        class_proba_pairs = list(zip(classes, probas))

        # Sort by probability (descending) and filter out zero probabilities
        sorted_classes = [
            c
            for c, p in sorted(zip(classes, probas), key=lambda x: x[1], reverse=True)
            if p > 0 and c in avail_choices
        ]

        # Return top k choices
        return sorted_classes[:k]

    def eval_prediction(
        self, avail_choices, leaf_id, pred, true, prob, threshold, default_config, i
    ):
        predicted_time = self.get_time(self.df.iloc[i], pred)
        max_prob = max(prob)
        if (
            leaf_id in self.unsafe_leaves
            or pred not in avail_choices
            or (max_prob != 1.0 and max_prob <= threshold)
        ):
            self.num_unsure += 1
            self.speedups_over_default.append(1.0)
        elif pred == true:
            self.compute_speedup_over_default(default_config, pred, i, predicted_time)
            self.num_correct += 1
        else:
            self.compute_speedup_over_default(default_config, pred, i, predicted_time)
            self.num_wrong += 1
            self.wrong_probas.append(max_prob)
            best_time = self.get_time(self.df.iloc[i], true)
            wrong_speedup = predicted_time / best_time
            self.speedups_wrong.append(wrong_speedup)

    def eval_ranking_prediction(self, true, top_k_choices, i):
        if true in top_k_choices:
            self.num_correct_top_k += 1
        else:
            top_k_choices_times = []
            for choice in top_k_choices:
                time = self.get_time(self.df.iloc[i], choice)
                if time is not None:
                    top_k_choices_times.append(time)
            best_time = self.get_time(self.df.iloc[i], true)
            min_time = min(top_k_choices_times, default=None)
            if min_time is not None:
                speedup = min_time / best_time
                self.wrong_speedups_top_k.append(speedup)
                self.num_wrong_top_k += 1
            else:
                self.top_k_unsure += 1
                # TODO (AlnisM): print more info (input and choices)
                log.debug(
                    "All top k choices have no time which means all top k are unavailable"
                )

    def get_safe_proba(self):
        return self.get_results(return_safe_proba=True)

    def compute_safe_proba(self, num_predictions, wrong_probas, wrong_pct):
        wrong_probas.sort()
        num_wrong = len(wrong_probas)
        allowed_wrong = int(num_predictions * wrong_pct)
        if allowed_wrong >= num_wrong:
            return 0.0
        too_many_wrong = num_wrong - allowed_wrong
        idx = min(too_many_wrong, len(wrong_probas) - 1)
        return wrong_probas[idx]
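
    # Illustrative example of the threshold computation above: with 100 validation
    # predictions and wrong_pct=0.01, allowed_wrong is 1. If the wrong predictions
    # had top probabilities [0.55, 0.70, 0.80], too_many_wrong is 2 and the returned
    # threshold is 0.80, so any prediction whose top probability is <= 0.80 (and not
    # exactly 1.0) is treated as 'unsure' by eval_prediction.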

    def get_results(self, return_safe_proba=False) -> EvalResults:
        """
        Custom evaluation function that evaluates a learned decision tree.
        """

        y_true = self.df["actual_winner"] if self.ranking else self.df["winner"]
        i = 0
        for pred, true, prob, leaf_id in zip(
            self.predictions, y_true, self.probas, self.leaf_ids
        ):
            avail_choices = self.df["avail_choices"].iloc[i]
            top_k_choices = self.top_k_classes(
                self.model, prob, k=self.k, avail_choices=avail_choices
            )
            assert (
                true in avail_choices
            ), f"Best choice {true} not in available choices {avail_choices}"
            default_config = self.train.get_default_config(self.df.iloc[i])
            self.eval_prediction(
                avail_choices,
                leaf_id,
                pred,
                true,
                prob,
                self.threshold,
                default_config,
                i,
            )
            self.eval_ranking_prediction(true, top_k_choices, i)
            i += 1

        total = len(self.predictions)
        if return_safe_proba:
            return self.compute_safe_proba(total, self.wrong_probas, self.wrong_pct)

        def safe_gmean(x):
            return gmean(x) if x else 0

        max_speedup = max(self.speedups_wrong, default=0)
        gmean_speedup = safe_gmean(self.speedups_wrong)
        max_speedup_top_k = max(self.wrong_speedups_top_k, default=0)
        gmean_speedup_top_k = safe_gmean(self.wrong_speedups_top_k)
        max_speedup_over_default = max(self.speedups_over_default, default=0)
        gmean_speedup_over_default = safe_gmean(self.speedups_over_default)
        max_slowdown_over_default = min(self.speedups_over_default, default=0)

        accuracyMetrics = AccuracyMetrics(
            self.num_correct, self.num_wrong, self.num_unsure, total
        )
        wrongSpeedupMetrics = WrongSpeedupMetrics(max_speedup, gmean_speedup)
        rankingMetrics = RankingMetrics(
            self.num_correct_top_k,
            self.num_wrong_top_k,
            max_speedup_top_k,
            gmean_speedup_top_k,
            self.top_k_unsure,
        )
        defaultComparisonMetrics = DefaultComparisonMetrics(
            max_speedup_over_default,
            gmean_speedup_over_default,
            max_slowdown_over_default,
            self.num_non_default_predictions,
            self.num_default_better,
        )
        return EvalResults(
            accuracyMetrics,
            wrongSpeedupMetrics,
            rankingMetrics,
            defaultComparisonMetrics,
        )


if __name__ == "__main__":
    train = AHTrainDecisionTree()
    train.generate_heuristic()