AutoHeuristic: Support ranking/pruning choices (#131705)

This PR adds support in train_decision for learning a heuristic that ranks choices. The user provides the number of choices the heuristic should return, and the learned decision tree is pruned such that it always returns the number of choices provided by the user.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/131705
Approved by: https://github.com/eellison
This commit is contained in:
Alnis Murtovi
2024-08-15 13:02:51 -07:00
committed by PyTorch MergeBot
parent 929d2f8253
commit add0f0085c
4 changed files with 515 additions and 133 deletions

View File

@ -0,0 +1,262 @@
from typing import Any, Dict, List, Optional, Tuple
import numpy as np
from sklearn.tree import _tree # type: ignore[import-untyped]
class DecisionTreeNode:
    """
    A single node of a DecisionTree.

    A node that has both a ``left`` and a ``right`` child is a split on
    ``feature`` at ``threshold``; any other node is treated as a leaf.
    """

    def __init__(
        self,
        feature: Optional[str] = None,
        threshold: Optional[float] = None,
        left: Optional["DecisionTreeNode"] = None,
        right: Optional["DecisionTreeNode"] = None,
        class_probs: Any = None,
        num_samples: int = 0,
        node_id: int = 0,
    ) -> None:
        # name of the feature this node splits on (None for leaves)
        self.feature = feature
        # split value that is compared against the feature (None for leaves)
        self.threshold = threshold
        # child nodes of the split
        self.left = left
        self.right = right
        # per-class values stored for this node
        self.class_probs = class_probs
        # number of samples recorded for this node
        self.num_samples = num_samples
        # identifier of the node; exposed as ``id``
        self.id = node_id

    def is_leaf(self) -> bool:
        """Return True if at least one of the two children is missing."""
        return self.left is None or self.right is None

    def __repr__(self) -> str:
        """Return a short, non-recursive description for debugging."""
        if self.is_leaf():
            return (
                f"DecisionTreeNode(id={self.id}, leaf, num_samples={self.num_samples})"
            )
        return (
            f"DecisionTreeNode(id={self.id}, feature={self.feature!r}, "
            f"threshold={self.threshold}, num_samples={self.num_samples})"
        )
class DecisionTree:
    """
    Custom decision tree implementation that mimics some of the sklearn API.

    The purpose of this class is to be able to perform transformations, such as custom pruning, which
    does not seem to be easy with sklearn.
    """

    def __init__(self, sklearn_tree: Any, feature_names: List[str]) -> None:
        """
        Build the tree from a fitted sklearn estimator.

        Args:
            sklearn_tree: fitted estimator; its ``tree_`` and ``classes_`` attributes are read.
            feature_names: names used to translate the feature indices stored in ``tree_``.
        """
        self.feature_names = feature_names
        self.root = self._convert_sklearn_tree(sklearn_tree.tree_)
        self.classes_: List[str] = sklearn_tree.classes_

    def _convert_sklearn_tree(
        self, sklearn_tree: Any, node_id: int = 0
    ) -> DecisionTreeNode:
        """Recursively convert the sklearn node ``node_id`` (and its subtree) into DecisionTreeNodes."""
        class_probs = sklearn_tree.value[node_id][0]
        num_samples = sklearn_tree.n_node_samples[node_id]

        # nodes whose feature is TREE_UNDEFINED have no split, i.e. they are leaves
        if sklearn_tree.feature[node_id] == _tree.TREE_UNDEFINED:
            return DecisionTreeNode(
                class_probs=class_probs, num_samples=num_samples, node_id=node_id
            )

        feature_index = sklearn_tree.feature[node_id]
        feature = self.feature_names[feature_index]
        left = self._convert_sklearn_tree(
            sklearn_tree, sklearn_tree.children_left[node_id]
        )
        right = self._convert_sklearn_tree(
            sklearn_tree, sklearn_tree.children_right[node_id]
        )
        return DecisionTreeNode(
            feature=feature,
            threshold=sklearn_tree.threshold[node_id],
            left=left,
            right=right,
            class_probs=class_probs,
            num_samples=num_samples,
            node_id=node_id,
        )

    def prune(self, df: Any, target_col: str, k: int) -> None:
        """
        Prune the tree in place so that every split leaves at least ``k`` distinct
        ``target_col`` values of ``df`` on both of its sides.

        Args:
            df: dataframe that contains the feature columns and ``target_col``.
            target_col: name of the column holding the class of each row.
            k: minimum number of distinct classes required on each side of a split.
        """
        self.root = self._prune_tree(self.root, df, target_col, k)

    def _prune_tree(
        self, node: DecisionTreeNode, df: Any, target_col: str, k: int
    ) -> DecisionTreeNode:
        """Prune the subtree rooted at ``node`` using the rows of ``df`` that reach it."""
        if node.is_leaf():
            return node

        left_df = df[df[node.feature] <= node.threshold]
        right_df = df[df[node.feature] > node.threshold]

        # number of unique classes in the left and right subtrees
        left_counts = left_df[target_col].nunique()
        right_counts = right_df[target_col].nunique()

        # for ranking, we want to ensure that we return at least k classes, so if we have less than k classes in the
        # left or right subtree, we remove the split and make this node a leaf node
        if left_counts < k or right_counts < k:
            # keep the identity and sample count of the node that is collapsed; otherwise every
            # pruned leaf would report the default id 0 (the id of the root) and 0 samples
            return DecisionTreeNode(
                class_probs=node.class_probs,
                num_samples=node.num_samples,
                node_id=node.id,
            )

        assert node.left is not None, "expected left child to exist"
        node.left = self._prune_tree(node.left, left_df, target_col, k)
        assert node.right is not None, "expected right child to exist"
        node.right = self._prune_tree(node.right, right_df, target_col, k)
        return node

    def to_dot(self) -> str:
        """Return the tree as a graphviz dot string."""
        dot = "digraph DecisionTree {\n"
        dot += ' node [fontname="helvetica"];\n'
        dot += ' edge [fontname="helvetica"];\n'
        dot += self._node_to_dot(self.root)
        dot += "}"
        return dot

    def _node_to_dot(
        self, node: DecisionTreeNode, parent_id: int = 0, edge_label: str = ""
    ) -> str:
        """
        Return the dot statements for ``node`` and its subtree.

        ``parent_id`` is the dot identifier of the parent; 0 means "no parent", in which
        case no incoming edge is emitted.
        """
        if node is None:
            return ""

        # the object identity is used as the dot identifier of the node
        node_id = id(node)

        # Format class_probs array with line breaks
        class_probs_str = self._format_class_probs_array(
            node.class_probs, node.num_samples
        )

        if node.is_leaf():
            label = class_probs_str
            shape = "box"
        else:
            feature_name = f"{node.feature}"
            label = f"{feature_name} <= {node.threshold:.2f}\\n{class_probs_str}"
            shape = "oval"

        dot = f' {node_id} [label="{label}", shape={shape}];\n'
        if parent_id != 0:
            dot += f' {parent_id} -> {node_id} [label="{edge_label}"];\n'

        if not node.is_leaf():
            assert node.left is not None, "expected left child to exist"
            dot += self._node_to_dot(node.left, node_id, "<=")
            assert node.right is not None, "expected right child to exist"
            dot += self._node_to_dot(node.right, node_id, ">")
        return dot

    def _format_class_prob(self, num: float) -> str:
        """Format a single value: "0" for zero, two decimals otherwise."""
        if num == 0:
            return "0"
        return f"{num:.2f}"

    def _format_class_probs_array(
        self, class_probs: Any, num_samples: int, max_per_line: int = 5
    ) -> str:
        """Format ``class_probs`` for a dot label, with at most ``max_per_line`` values per line."""
        # add line breaks to avoid very long lines
        flat_class_probs = class_probs.flatten()
        formatted = [self._format_class_prob(v) for v in flat_class_probs]
        lines = [
            formatted[i : i + max_per_line]
            for i in range(0, len(formatted), max_per_line)
        ]
        return f"num_samples={num_samples}\\n" + "\\n".join(
            [", ".join(line) for line in lines]
        )

    def predict(self, X: Any) -> Any:
        """Return an array with the predicted class for every row of ``X``."""
        predictions = [self._predict_single(x) for _, x in X.iterrows()]
        return np.array(predictions)

    def predict_proba(self, X: Any) -> Any:
        """Return an array with the ``class_probs`` of the leaf reached by every row of ``X``."""
        return np.array([self._predict_proba_single(x) for _, x in X.iterrows()])

    def _get_leaf(self, X: Any) -> DecisionTreeNode:
        """Walk from the root to the leaf that the single sample ``X`` ends up in."""
        node = self.root
        while not node.is_leaf():
            if X[node.feature] <= node.threshold:
                assert node.left is not None, "expected left child to exist"
                node = node.left
            else:
                assert node.right is not None, "expected right child to exist"
                node = node.right
        return node

    def _predict_single(self, x: Any) -> str:
        """Return the class with the highest ``class_probs`` entry in the leaf reached by ``x``."""
        node = self._get_leaf(x)
        # map index to class name
        return self.classes_[np.argmax(node.class_probs)]

    def _predict_proba_single(self, x: Any) -> Any:
        """Return the ``class_probs`` of the leaf reached by ``x``."""
        node = self._get_leaf(x)
        return node.class_probs

    def apply(self, X: Any) -> Any:
        """Return an array with the id of the leaf reached by every row of ``X``."""
        ids = [self._apply_single(x) for _, x in X.iterrows()]
        return np.array(ids)

    def _apply_single(self, x: Any) -> int:
        """Return the id of the leaf reached by ``x``."""
        node = self._get_leaf(x)
        return node.id

    def codegen(
        self,
        dummy_col_2_col_val: Dict[str, Tuple[str, Any]],
        lines: List[str],
        unsafe_leaves: List[int],
    ) -> None:
        """
        Generate python code for the decision tree and append it to ``lines``.

        Args:
            dummy_col_2_col_val: maps the name of a dummy column to the tuple
                (original column name, value) it was derived from.
            lines: list the generated source lines are appended to.
            unsafe_leaves: ids of leaves for which the generated code returns None;
                ``None`` is treated like an empty list.
        """
        if unsafe_leaves is None:
            unsafe_leaves = []

        def best_probas_and_indices(class_probas: Any) -> str:
            """
            Given the per-class values of a leaf, this function returns a string with the tuples
            (proba, idx) sorted by proba in descending order. E.g.:
            Given class_probas=[0.3, 0.5, 0.2]
            this function returns
            "[(0.500, 1), (0.300, 0), (0.200, 2)]"
            """
            # we generate a list of tuples (proba, idx) sorted by proba in descending order
            # idx is the index of a choice
            # we only generate a tuple if proba > 0
            probas_indices_sorted = sorted(
                [
                    (proba, index)
                    for index, proba in enumerate(class_probas)
                    if proba > 0
                ],
                key=lambda x: x[0],
                reverse=True,
            )
            probas_indices_sorted_str = ", ".join(
                f"({value:.3f}, {index})" for value, index in probas_indices_sorted
            )
            return f"[{probas_indices_sorted_str}]"

        def handle_leaf(
            node: DecisionTreeNode, indent: str, unsafe_leaves: List[int]
        ) -> str:
            """
            This generates the code for a leaf node in the decision tree. If the leaf is unsafe, the learned heuristic
            will return "unsure" (i.e. None).
            """
            if node.id in unsafe_leaves:
                return f"{indent}return None"
            class_probas = node.class_probs
            return f"{indent}return {best_probas_and_indices(class_probas)}"

        def codegen_node(node: DecisionTreeNode, depth: int) -> None:
            """Emit the code for ``node`` and, recursively, for its children."""
            indent = " " * (depth + 1)
            if node.is_leaf():
                lines.append(handle_leaf(node, indent, unsafe_leaves))
                return

            name = node.feature
            threshold = node.threshold
            if name in dummy_col_2_col_val:
                # dummy columns are compared against the original column's value;
                # their split threshold is expected to be exactly 0.5
                (orig_name, value) = dummy_col_2_col_val[name]
                predicate = f"{indent}if str(context.get_value('{orig_name}')) != '{value}':"
                assert (
                    threshold == 0.5
                ), f"expected threshold to be 0.5 but is {threshold}"
            else:
                predicate = f"{indent}if context.get_value('{name}') <= {threshold}:"
            lines.append(predicate)
            assert node.left is not None, "expected left child to exist"
            codegen_node(node.left, depth + 1)
            lines.append(f"{indent}else:")
            assert node.right is not None, "expected right child to exist"
            codegen_node(node.right, depth + 1)

        codegen_node(self.root, 1)

View File

@ -2,7 +2,6 @@
import argparse
import json
import sys
import warnings
import pandas as pd # type: ignore[import-untyped]
@ -64,6 +63,15 @@ class AHTrain:
action="store_true",
help="Export heuristic to graphviz dot.",
)
self.parser.add_argument(
"--ranking",
type=int,
default=None,
help="""
Makes AutoHeuristic learn a heuristic that ranks choices instead of predicting a single choice.
The argument is the number of choices the heuristic will provide.
""",
)
def parse_args(self):
return self.parser.parse_args()
@ -87,6 +95,7 @@ class AHTrain:
self.args.nrows,
self.args.heuristic_name,
self.args.save_dot,
self.args.ranking is not None,
)
def filter_df(self, df):
@ -138,9 +147,6 @@ class AHTrain:
and str(metadata.device_capa) == "{device_capa}"
)"""
def handle_leaf(self, tree_, node, indent, unsafe_leaves):
pass
def codegen_boilerplate(
self, heuristic_name, opt_name, threshold, shared_memory, device_capa, dt
):
@ -149,63 +155,7 @@ class AHTrain:
def gen_predict_fn_def(self):
pass
def dt_to_python(
self,
dt,
metadata,
feature_names,
dummy_col_2_col_val,
heuristic_name,
threshold,
unsafe_leaves=None,
):
tree_ = dt.tree_
feature_name = [
feature_names[i] if i != -1 else "undefined!" for i in tree_.feature
]
lines = []
device_capa = metadata["device_capa"]
device_capa_str = f"({device_capa[0]}, {device_capa[1]})"
opt_name = metadata["name"]
lines.append(
self.codegen_boilerplate(
heuristic_name,
opt_name,
threshold,
metadata["shared_memory"],
device_capa_str,
dt,
)
)
fn_def = f"\n {self.gen_predict_fn_def()}"
lines.append(fn_def)
def dt_to_python(node, depth):
indent = " " * (depth + 1)
false_predicate = ""
if tree_.feature[node] != -2:
name = feature_name[node]
threshold = tree_.threshold[node]
if name in dummy_col_2_col_val:
(orig_name, value) = dummy_col_2_col_val[name]
predicate = f"{indent}if str(context.get_value('{orig_name}')) != '{value}':"
if threshold != 0.5:
print(f"expected threshold to be 0.5 but is {threshold}")
sys.exit(1)
else:
predicate = (
f"{indent}if context.get_value('{name}') <= {threshold}:"
)
lines.append(predicate)
dt_to_python(tree_.children_left[node], depth + 1)
lines.append(f"{indent}else:")
dt_to_python(tree_.children_right[node], depth + 1)
else:
lines.append(self.handle_leaf(tree_, node, indent, unsafe_leaves))
dt_to_python(0, 1)
def write_heuristic_to_file(self, lines, heuristic_name):
output_file = (
f"../../../torch/_inductor/autoheuristic/artifacts/_{heuristic_name}.py"
)

View File

@ -16,6 +16,7 @@ from dataclasses import dataclass
import numpy as np
import pandas as pd # type: ignore[import-untyped]
from ah_tree import DecisionTree
from scipy.stats import gmean
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier
@ -102,8 +103,21 @@ class AHTrainDecisionTree(AHTrain):
leaf_ids = model.apply(df[feature_columns])
return predictions, proba, leaf_ids
def ranking_num_choices(self):
    """
    Return the number of choices the heuristic will return when it is used for ranking.

    This is the value of ``self.args.ranking``; 5 is used when that value is None.
    """
    requested = self.args.ranking
    if requested is not None:
        return requested
    return 5
def train_and_evaluate_models(
self, datasets, max_depths, min_samples_leafs, criterion_list, feature_columns
self,
datasets,
max_depths,
min_samples_leafs,
criterion_list,
feature_columns,
ranking=False,
):
"""
Does a grid search over max_depths, min_samples_leafs, and criterion_list and returns the best model.
@ -131,7 +145,20 @@ class AHTrainDecisionTree(AHTrain):
)
df_train = datasets["train"]
df_val = datasets["val"]
model.fit(df_train[feature_columns], df_train["winner"])
if ranking:
model.fit(
df_train[feature_columns],
df_train["winner"],
sample_weight=df_train["relative_performance"],
)
else:
model.fit(df_train[feature_columns], df_train["winner"])
model = DecisionTree(model, feature_columns)
if ranking:
model.prune(df_train, "winner", k=self.ranking_num_choices())
unsafe_leaves = self.get_unsafe_leaves(model, df_train, feature_columns)
predictions, proba, leaf_ids = self.predict(model, df_val, feature_columns)
@ -145,11 +172,18 @@ class AHTrainDecisionTree(AHTrain):
wrong_pct=wrong_pct,
unsafe_leaves=unsafe_leaves,
leaf_ids=leaf_ids,
k=self.ranking_num_choices(),
ranking=ranking,
)
safe_proba = evaluator.get_safe_proba()
print(f"safe_proba={safe_proba}")
def eval(name, df):
if ranking:
# when ranking is enabled, we duplicate each input for each choice that
# is almost as good as the best choice
# we do not want to evaluate the same input multiple times, so we remove duplicates here
df = df[df["winner"] == df["actual_winner"]]
predictions, proba, leaf_ids = self.predict(model, df, feature_columns)
evaluator = DecisionEvaluator(
self,
@ -161,6 +195,8 @@ class AHTrainDecisionTree(AHTrain):
threshold=safe_proba,
unsafe_leaves=unsafe_leaves,
leaf_ids=leaf_ids,
k=self.ranking_num_choices(),
ranking=ranking,
)
return evaluator.get_results()
@ -202,7 +238,7 @@ class AHTrainDecisionTree(AHTrain):
"""
return (0.15, 0.15)
def prepare_datasets(self, df, other_datasets, cat_feature2cats):
def prepare_datasets(self, df, other_datasets, cat_feature2cats, ranking=False):
"""
Splits the dataframe into train, val, and test sets.
Also adds other datasets, specified by the user, to the train set.
@ -219,24 +255,16 @@ class AHTrainDecisionTree(AHTrain):
df_train_val, test_size=val_size / train_val_size, random_state=42
)
datasets = {"train": df_train, "val": df_val, "test": df_test}
self.add_real_datasets(datasets, other_datasets, cat_feature2cats)
self.add_real_datasets(datasets, other_datasets, cat_feature2cats, ranking)
return datasets
def export_to_dot(self, best_model, df, feature_columns):
"""
Export a learned decision tree to a dot file.
"""
from sklearn import tree
tree.export_graphviz(
best_model,
out_file="best_model.dot",
feature_names=df[feature_columns].columns,
class_names=[str(c) for c in best_model.classes_],
filled=True,
rounded=True,
special_characters=True,
)
dot_str = best_model.to_dot()
with open("best_model.dot", "w") as f:
f.write(dot_str)
def get_feature_columns(self, df):
"""
@ -250,20 +278,36 @@ class AHTrainDecisionTree(AHTrain):
"avail_choices",
"choice2time",
"index",
"actual_winner",
"relative_performance",
]
feature_columns = [col for col in df.columns if col not in exclude_columns]
return feature_columns
def main(self, log_path, other_datasets, nrows, heuristic_name, save_dot=False):
def add_training_data(self, df_train, datasets):
    """
    Return the dataframe that is used for training.

    This implementation returns ``datasets["train"]`` unchanged; ``df_train`` is accepted
    but not used here.
    NOTE(review): presumably a hook for subclasses that want to add data - confirm.
    """
    return datasets["train"]
def main(
self,
log_path,
other_datasets,
nrows,
heuristic_name,
save_dot=False,
ranking=False,
):
"""
Main function that trains a decision tree and generates a heuristic.
"""
# TODO: Enable apply_filters
(df, choices, cat_feature2cats, dummy_col_2_col_val, metadata) = self.get_df(
log_path, nrows=nrows, apply_filters=False
log_path, nrows=nrows, apply_filters=False, add_near_best=ranking
)
print(df["winner"].value_counts())
datasets = self.prepare_datasets(df, other_datasets, cat_feature2cats)
datasets = self.prepare_datasets(df, other_datasets, cat_feature2cats, ranking)
df_train = self.add_training_data(datasets["train"], datasets)
datasets["train"] = df_train
feature_columns = self.get_feature_columns(df)
grid_search_values = self.get_grid_search_values()
max_depths = grid_search_values["max_depth"]
@ -275,28 +319,44 @@ class AHTrainDecisionTree(AHTrain):
best_model_safe_proba,
unsafe_leaves,
) = self.train_and_evaluate_models(
datasets, max_depths, min_samples_leafs, criterion_list, feature_columns
datasets,
max_depths,
min_samples_leafs,
criterion_list,
feature_columns,
ranking=ranking,
)
if ranking:
columns_to_keep = [
"set",
"total",
"top_k_correct",
"top_k_wrong",
"top_k_unsure",
"wrong_max_spdup_k",
"wrong_gman_spdup_k",
]
results_df = results_df[columns_to_keep]
# prints results for all models and datasets
print(results_df.to_string())
# prints results grouped by dataset
for set_name in results_df["set"].unique():
dataset_results = results_df[results_df["set"] == set_name]
dataset_results = dataset_results.sort_values(by="correct")
print(dataset_results.to_string() + "\n")
if not ranking:
# prints results grouped by dataset
for set_name in results_df["set"].unique():
dataset_results = results_df[results_df["set"] == set_name]
dataset_results = dataset_results.sort_values(by="correct")
print(dataset_results.to_string() + "\n")
if best_model is not None:
if save_dot:
self.export_to_dot(best_model, df, feature_columns)
self.dt_to_python(
self.codegen(
best_model,
metadata,
feature_columns,
dummy_col_2_col_val,
heuristic_name,
best_model_safe_proba,
dummy_col_2_col_val,
unsafe_leaves,
)
else:
@ -304,7 +364,14 @@ class AHTrainDecisionTree(AHTrain):
"All learned models have too many wrong predictions, so no heuristic was generated"
)
def get_df(self, log_path, cat_feature2cats=None, nrows=None, apply_filters=False):
def get_df(
self,
log_path,
cat_feature2cats=None,
nrows=None,
apply_filters=False,
add_near_best=False,
):
"""
Parses the log file and processes the data into a dataframe that can be used for training.
"""
@ -314,14 +381,19 @@ class AHTrainDecisionTree(AHTrain):
def calculate_stats(group):
count = len(group)
mean = group["feedback"].mean()
std = group["feedback"].std()
relative_std = (std / mean) * 100 if mean != 0 else np.inf
has_inf = np.isinf(group["feedback"]).any()
if has_inf:
relative_std = np.inf
median = np.inf
else:
mean = group["feedback"].mean()
std = group["feedback"].std()
relative_std = (std / mean) * 100 if mean != 0 else np.inf
median = group["feedback"].median()
if relative_std > 5:
times = group["feedback"].tolist()
times_str = ", ".join([f"{t:.3f}" for t in sorted(times)])
log.debug("High relative std: %f. times=%s", relative_std, times_str)
median = group["feedback"].median()
return pd.Series(
{
"count": count,
@ -385,6 +457,28 @@ class AHTrainDecisionTree(AHTrain):
.reset_index()
)
def add_near_best_configs(df):
    """
    Expand every row of ``df`` into one row per choice that is almost as good as the best one.

    The ``choice2time`` column of each row is parsed as a JSON object mapping a choice to
    its time. Every choice whose relative performance ``min_time / time`` is at least 0.95
    produces a copy of the row in which

    * ``winner`` is that choice,
    * ``actual_winner`` is the winner of the original row, and
    * ``relative_performance`` is the computed ratio.

    Returns a new dataframe with a fresh 0..n-1 index.
    """
    near_best_threshold = 0.95
    new_rows = []
    for _, row in df.iterrows():
        choice2time = json.loads(row["choice2time"])
        min_value = min(choice2time.values())
        for key, value in choice2time.items():
            relative_performance = min_value / value
            # a non-finite ratio (e.g. inf / inf == nan when every time is inf) can never be
            # a valid near-best choice, so skip it instead of dropping into a debugger
            if not np.isfinite(relative_performance):
                continue
            if relative_performance < near_best_threshold:
                continue
            # only copy the row once we know it is kept
            new_row = row.copy()
            new_row["relative_performance"] = relative_performance
            new_row["actual_winner"] = row["winner"]
            new_row["winner"] = key
            new_rows.append(new_row)
    return pd.DataFrame(new_rows).reset_index(drop=True)
if add_near_best:
results = add_near_best_configs(results)
(results, added_categorical_features) = self.add_new_features(results)
categorical_features += added_categorical_features
@ -409,27 +503,6 @@ class AHTrainDecisionTree(AHTrain):
indent = " " * num_spaces
return "\n".join([f"{indent}self.choices.append('{c}')" for c in classes])
def best_probas_and_indices(self, class_probas):
"""
Given a list of tuples (proba, idx), this function returns a string in which the tuples are sorted by proba in
descending order. E.g.:
Given class_probas=[(0.3, 0), (0.5, 1), (0.2, 2)]
this function returns
"[(0.5, 1), (0.3, 0), (0.2, 2)]"
"""
# we generate a list of tuples (proba, idx) sorted by proba in descending order
# idx is the index of a choice
# we only generate a tuple if proba > 0
probas_indices_sorted = sorted(
[(proba, index) for index, proba in enumerate(class_probas) if proba > 0],
key=lambda x: x[0],
reverse=True,
)
probas_indices_sorted_str = ", ".join(
f"({value:.3f}, {index})" for value, index in probas_indices_sorted
)
return f"[{probas_indices_sorted_str}]"
def get_default_config(self, row):
"""
Returns the default config for a given sample. The default config could for example be the config that is
@ -438,17 +511,6 @@ class AHTrainDecisionTree(AHTrain):
"""
return None
def handle_leaf(self, tree_, node, indent, unsafe_leaves):
"""
This generates the code for a leaf node in the decision tree. If the leaf is unsafe, the learned heuristic
will return "unsure" (i.e. None).
"""
if node in unsafe_leaves:
return f"{indent}return None"
leaf_num_samples = tree_.n_node_samples[node]
class_probas = tree_.value[node][0]
return f"{indent}return {self.best_probas_and_indices(class_probas)}"
def gen_predict_fn_def(self):
"""
Generates the definition of the predict function.
@ -456,7 +518,7 @@ class AHTrainDecisionTree(AHTrain):
return "def get_best_choices(self, context: AHContext) -> Optional[List[Tuple[float, int]]]:"
def codegen_boilerplate(
self, heuristic_name, opt_name, threshold, shared_memory, device_capa, dt
self, heuristic_name, opt_name, threshold, shared_memory, device_capa, classes
):
"""
Generates the boilerplate code for the generated heuristic. This includes things like imports, class definition,
@ -496,23 +558,56 @@ class {heuristic_name}(LearnedHeuristicDecision):
return None
def fill_choices(self) -> None:
{self.gen_classes(dt.classes_, num_spaces=8)}
{self.gen_classes(classes, num_spaces=8)}
def get_name(self) -> str:
return '{opt_name}'"""
return boiler_plate
def add_real_datasets(self, datasets, other_datasets, cat_feature2cats):
def add_real_datasets(
self, datasets, other_datasets, cat_feature2cats, ranking=False
):
"""
Adds datasets specified by the user to the datasets dictionary.
"""
if other_datasets:
for name, path in other_datasets:
(df_other, choices, _, _, _) = self.get_df(
path, cat_feature2cats=cat_feature2cats, apply_filters=False
path,
cat_feature2cats=cat_feature2cats,
apply_filters=False,
add_near_best=ranking,
)
datasets[name] = df_other
def codegen(
    self,
    tree,
    metadata,
    heuristic_name,
    threshold,
    dummy_col_2_col_val,
    unsafe_leaves,
):
    """
    Generate the source of the learned heuristic and write it to a file.

    The output consists of the boilerplate from ``codegen_boilerplate``, the predict
    function definition from ``gen_predict_fn_def`` and the body that ``tree.codegen``
    appends; the collected lines are handed to ``write_heuristic_to_file``.
    """
    capability = metadata["device_capa"]
    capability_str = f"({capability[0]}, {capability[1]})"
    optimization_name = metadata["name"]

    boilerplate = self.codegen_boilerplate(
        heuristic_name,
        optimization_name,
        threshold,
        metadata["shared_memory"],
        capability_str,
        tree.classes_,
    )
    generated = [boilerplate]
    generated.append(f"\n {self.gen_predict_fn_def()}")

    # the tree appends the code of the decision logic to the same list
    tree.codegen(dummy_col_2_col_val, generated, unsafe_leaves)
    self.write_heuristic_to_file(generated, heuristic_name)
@dataclass
class AccuracyMetrics:
@ -552,6 +647,8 @@ class WrongSpeedupMetrics:
class RankingMetrics:
# Number of predictions where best choice is in top k choices
num_correct: int
# Number of predictions where best choice is not in top k choices
num_wrong: int
# Maximum speedup of best choice over best choice in top k (this tells us how much better the best choice, which
# is not in top k, is over the best choice in top k)
max_speedup: float
@ -563,6 +660,7 @@ class RankingMetrics:
def to_map(self):
return {
"top_k_correct": self.num_correct,
"top_k_wrong": self.num_wrong,
"wrong_max_speedup_k": self.max_speedup,
"wrong_gmean_speedup_k": self.gmean_speedup,
"top_k_unsure": self.unsure,
@ -618,9 +716,10 @@ class DecisionEvaluator:
probas,
wrong_pct=0.01,
threshold=0.0,
k=3,
k=10,
unsafe_leaves=None,
leaf_ids=None,
ranking=False,
) -> None:
self.train = train
self.model = model
@ -632,6 +731,7 @@ class DecisionEvaluator:
self.k = k
self.unsafe_leaves = unsafe_leaves
self.leaf_ids = leaf_ids
self.ranking = ranking
self.num_correct = 0
self.num_wrong = 0
@ -639,6 +739,7 @@ class DecisionEvaluator:
self.wrong_probas = []
self.speedups_wrong = []
self.num_correct_top_k = 0
self.num_wrong_top_k = 0
self.wrong_speedups_top_k = []
self.top_k_unsure = 0
self.num_non_default_predictions = 0
@ -718,6 +819,7 @@ class DecisionEvaluator:
if min_time is not None:
speedup = min_time / best_time
self.wrong_speedups_top_k.append(speedup)
self.num_wrong_top_k += 1
else:
self.top_k_unsure += 1
# TODO (AlnisM): print more info (input and choices)
@ -743,7 +845,7 @@ class DecisionEvaluator:
Custom evaluation function that evaluates a learned decision tree.
"""
y_true = self.df["winner"]
y_true = self.df["actual_winner"] if self.ranking else self.df["winner"]
i = 0
for pred, true, prob, leaf_id in zip(
self.predictions, y_true, self.probas, self.leaf_ids
@ -790,6 +892,7 @@ class DecisionEvaluator:
wrongSpeedupMetrics = WrongSpeedupMetrics(max_speedup, gmean_speedup)
rankingMetrics = RankingMetrics(
self.num_correct_top_k,
self.num_wrong_top_k,
max_speedup_top_k,
gmean_speedup_top_k,
self.top_k_unsure,

View File

@ -34,7 +34,15 @@ class AHTrainRegressionTree(AHTrain):
def __init__(self):
super().__init__()
def main(self, log_path, other_datasets, nrows, heuristic_name, save_dot=False):
def main(
self,
log_path,
other_datasets,
nrows,
heuristic_name,
save_dot=False,
ranking=False,
):
"""
Main function that trains a decision tree and generates a heuristic.
"""
@ -357,6 +365,65 @@ class AHTrainRegressionTree(AHTrain):
"wrong_max_ratio": wrong_max_ratio,
}
def dt_to_python(
    self,
    dt,
    metadata,
    feature_names,
    dummy_col_2_col_val,
    heuristic_name,
    threshold,
    unsafe_leaves=None,
):
    """
    Translate a fitted sklearn tree into python code and write it to a file.

    Args:
        dt: fitted sklearn tree; its ``tree_`` attribute is traversed.
        metadata: dict providing "device_capa", "name" and "shared_memory".
        feature_names: names used to translate the feature indices stored in ``tree_``.
        dummy_col_2_col_val: maps the name of a dummy column to the tuple
            (original column name, value) it was derived from.
        heuristic_name: forwarded to ``codegen_boilerplate`` and ``write_heuristic_to_file``.
        threshold: forwarded to ``codegen_boilerplate``.
        unsafe_leaves: forwarded to ``handle_leaf`` for every leaf.
    """
    tree_ = dt.tree_
    # Nodes without a split carry a negative feature index (the traversal below tests for -2).
    # Only non-negative indices may be looked up: a negative index would silently pick a name
    # from the end of feature_names or raise an IndexError for short lists.
    feature_name = [
        feature_names[i] if i >= 0 else "undefined!" for i in tree_.feature
    ]

    lines = []
    device_capa = metadata["device_capa"]
    device_capa_str = f"({device_capa[0]}, {device_capa[1]})"
    opt_name = metadata["name"]
    lines.append(
        self.codegen_boilerplate(
            heuristic_name,
            opt_name,
            threshold,
            metadata["shared_memory"],
            device_capa_str,
            dt,
        )
    )
    fn_def = f"\n {self.gen_predict_fn_def()}"
    lines.append(fn_def)

    def codegen_node(node, depth):
        """Emit the code for ``node`` and, recursively, for its children."""
        indent = " " * (depth + 1)
        # -2 marks a node without a split, i.e. a leaf
        if tree_.feature[node] == -2:
            lines.append(self.handle_leaf(tree_, node, indent, unsafe_leaves))
            return

        name = feature_name[node]
        split_threshold = tree_.threshold[node]
        if name in dummy_col_2_col_val:
            # dummy columns are compared against the original column's value;
            # their split threshold is expected to be exactly 0.5
            (orig_name, value) = dummy_col_2_col_val[name]
            predicate = f"{indent}if str(context.get_value('{orig_name}')) != '{value}':"
            assert (
                split_threshold == 0.5
            ), f"expected threshold to be 0.5 but is {split_threshold}"
        else:
            predicate = (
                f"{indent}if context.get_value('{name}') <= {split_threshold}:"
            )
        lines.append(predicate)
        codegen_node(tree_.children_left[node], depth + 1)
        lines.append(f"{indent}else:")
        codegen_node(tree_.children_right[node], depth + 1)

    codegen_node(0, 1)
    self.write_heuristic_to_file(lines, heuristic_name)
def handle_leaf(self, tree_, node, indent, unsafe_leaves):
"""
Generates the code for a leaf node. This is just the value predicted by the regression tree.
@ -368,7 +435,7 @@ class AHTrainRegressionTree(AHTrain):
return "def predict(self, context: AHContext) -> float:"
def codegen_boilerplate(
self, heuristic_name, opt_name, threshold, shared_memory, device_capa, dt
self, heuristic_name, opt_name, threshold, shared_memory, device_capa, classes
):
"""
Generates the boilerplate code for the generated heuristic. This includes things like imports, class definition,