mirror of https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
AutoHeuristic: Support ranking/pruning choices (#131705)
This PR adds support in train_decision for learning a heuristic that ranks choices instead of predicting a single choice. The main idea is that the user provides the number of choices the heuristic should return. I added a way to prune the learned decision tree so that it always returns the number of choices provided by the user. Pull Request resolved: https://github.com/pytorch/pytorch/pull/131705 Approved by: https://github.com/eellison
committed by PyTorch MergeBot
parent 929d2f8253
commit add0f0085c
torchgen/_autoheuristic/ah_tree.py (new file, 262 lines)
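Below is a minimal usage sketch of the ranking/pruning flow this PR describes (the toy dataframe, feature names, and choice names are invented for illustration; DecisionTree and its methods are the ones added in ah_tree.py in this diff):

import pandas as pd
from sklearn.tree import DecisionTreeClassifier

from ah_tree import DecisionTree

# toy data: two numeric features plus the winning choice per row (made up)
df = pd.DataFrame(
    {
        "m": [16, 32, 64, 128, 256, 512],
        "n": [16, 32, 64, 128, 256, 512],
        "winner": ["triton", "triton", "aten", "aten", "aten", "triton"],
    }
)
feature_columns = ["m", "n"]

sk_model = DecisionTreeClassifier(max_depth=5, random_state=0)
sk_model.fit(df[feature_columns], df["winner"])

# wrap the sklearn tree, then prune: any split whose left or right side covers
# fewer than k distinct winners is collapsed into a leaf, so every leaf can
# still rank at least k choices
model = DecisionTree(sk_model, feature_columns)
model.prune(df, "winner", k=2)
print(model.predict(df[feature_columns]))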
@ -0,0 +1,262 @@
from typing import Any, Dict, List, Optional, Tuple

import numpy as np
from sklearn.tree import _tree  # type: ignore[import-untyped]


class DecisionTreeNode:
    def __init__(
        self,
        feature: Optional[str] = None,
        threshold: Optional[float] = None,
        left: Optional["DecisionTreeNode"] = None,
        right: Optional["DecisionTreeNode"] = None,
        class_probs: Any = None,
        num_samples: int = 0,
        node_id: int = 0,
    ) -> None:
        self.feature = feature
        self.threshold = threshold
        self.left = left
        self.right = right
        self.class_probs = class_probs
        self.num_samples = num_samples
        self.id = node_id

    def is_leaf(self) -> bool:
        return self.left is None or self.right is None

class DecisionTree:
    """
    Custom decision tree implementation that mimics some of the sklearn API.
    The purpose of this class is to be able to perform transformations, such as
    custom pruning, which do not seem to be easy to do with sklearn.
    """

    def __init__(self, sklearn_tree: Any, feature_names: List[str]) -> None:
        self.feature_names = feature_names
        self.root = self._convert_sklearn_tree(sklearn_tree.tree_)
        self.classes_: List[str] = sklearn_tree.classes_

    def _convert_sklearn_tree(
        self, sklearn_tree: Any, node_id: int = 0
    ) -> DecisionTreeNode:
        class_probs = sklearn_tree.value[node_id][0]
        num_samples = sklearn_tree.n_node_samples[node_id]
        if sklearn_tree.feature[node_id] != _tree.TREE_UNDEFINED:
            feature_index = sklearn_tree.feature[node_id]
            feature = self.feature_names[feature_index]
            left = self._convert_sklearn_tree(
                sklearn_tree, sklearn_tree.children_left[node_id]
            )
            right = self._convert_sklearn_tree(
                sklearn_tree, sklearn_tree.children_right[node_id]
            )
            return DecisionTreeNode(
                feature=feature,
                threshold=sklearn_tree.threshold[node_id],
                left=left,
                right=right,
                class_probs=class_probs,
                num_samples=num_samples,
                node_id=node_id,
            )
        else:
            return DecisionTreeNode(
                class_probs=class_probs, num_samples=num_samples, node_id=node_id
            )

    def prune(self, df: Any, target_col: str, k: int) -> None:
        self.root = self._prune_tree(self.root, df, target_col, k)

    def _prune_tree(
        self, node: DecisionTreeNode, df: Any, target_col: str, k: int
    ) -> DecisionTreeNode:
        if node.is_leaf():
            return node

        left_df = df[df[node.feature] <= node.threshold]
        right_df = df[df[node.feature] > node.threshold]

        # number of unique classes in the left and right subtrees
        left_counts = left_df[target_col].nunique()
        right_counts = right_df[target_col].nunique()

        # for ranking, we want to ensure that we return at least k classes, so if
        # we have fewer than k classes in the left or right subtree, we remove the
        # split and make this node a leaf node
        if left_counts < k or right_counts < k:
            return DecisionTreeNode(class_probs=node.class_probs)

        assert node.left is not None, "expected left child to exist"
        node.left = self._prune_tree(node.left, left_df, target_col, k)
        assert node.right is not None, "expected right child to exist"
        node.right = self._prune_tree(node.right, right_df, target_col, k)

        return node

    def to_dot(self) -> str:
        dot = "digraph DecisionTree {\n"
        dot += '    node [fontname="helvetica"];\n'
        dot += '    edge [fontname="helvetica"];\n'
        dot += self._node_to_dot(self.root)
        dot += "}"
        return dot

    def _node_to_dot(
        self, node: DecisionTreeNode, parent_id: int = 0, edge_label: str = ""
    ) -> str:
        if node is None:
            return ""

        node_id = id(node)

        # Format class_probs array with line breaks
        class_probs_str = self._format_class_probs_array(
            node.class_probs, node.num_samples
        )

        if node.is_leaf():
            label = class_probs_str
            shape = "box"
        else:
            feature_name = f"{node.feature}"
            label = f"{feature_name} <= {node.threshold:.2f}\\n{class_probs_str}"
            shape = "oval"

        dot = f'    {node_id} [label="{label}", shape={shape}];\n'

        if parent_id != 0:
            dot += f'    {parent_id} -> {node_id} [label="{edge_label}"];\n'

        if not node.is_leaf():
            assert node.left is not None, "expected left child to exist"
            dot += self._node_to_dot(node.left, node_id, "<=")
            assert node.right is not None, "expected right child to exist"
            dot += self._node_to_dot(node.right, node_id, ">")

        return dot

    def _format_class_prob(self, num: float) -> str:
        if num == 0:
            return "0"
        return f"{num:.2f}"

    def _format_class_probs_array(
        self, class_probs: Any, num_samples: int, max_per_line: int = 5
    ) -> str:
        # add line breaks to avoid very long lines
        flat_class_probs = class_probs.flatten()
        formatted = [self._format_class_prob(v) for v in flat_class_probs]
        lines = [
            formatted[i : i + max_per_line]
            for i in range(0, len(formatted), max_per_line)
        ]
        return f"num_samples={num_samples}\\n" + "\\n".join(
            [", ".join(line) for line in lines]
        )

    def predict(self, X: Any) -> Any:
        predictions = [self._predict_single(x) for _, x in X.iterrows()]
        return np.array(predictions)

    def predict_proba(self, X: Any) -> Any:
        return np.array([self._predict_proba_single(x) for _, x in X.iterrows()])

    def _get_leaf(self, X: Any) -> DecisionTreeNode:
        node = self.root
        while not node.is_leaf():
            if X[node.feature] <= node.threshold:
                assert node.left is not None, "expected left child to exist"
                node = node.left
            else:
                assert node.right is not None, "expected right child to exist"
                node = node.right
        return node

    def _predict_single(self, x: Any) -> str:
        node = self._get_leaf(x)
        # map index to class name
        return self.classes_[np.argmax(node.class_probs)]

    def _predict_proba_single(self, x: Any) -> Any:
        node = self._get_leaf(x)
        return node.class_probs

    def apply(self, X: Any) -> Any:
        ids = [self._apply_single(x) for _, x in X.iterrows()]
        return np.array(ids)

    def _apply_single(self, x: Any) -> int:
        node = self._get_leaf(x)
        return node.id

    def codegen(
        self,
        dummy_col_2_col_val: Dict[str, Tuple[str, Any]],
        lines: List[str],
        unsafe_leaves: List[int],
    ) -> None:
        # generates python code for the decision tree
        def codegen_node(node: DecisionTreeNode, depth: int) -> None:
            indent = "    " * (depth + 1)
            if node.is_leaf():
                lines.append(handle_leaf(node, indent, unsafe_leaves))
            else:
                name = node.feature
                threshold = node.threshold
                if name in dummy_col_2_col_val:
                    (orig_name, value) = dummy_col_2_col_val[name]
                    predicate = f"{indent}if str(context.get_value('{orig_name}')) != '{value}':"
                    assert (
                        threshold == 0.5
                    ), f"expected threshold to be 0.5 but is {threshold}"
                else:
                    predicate = (
                        f"{indent}if context.get_value('{name}') <= {threshold}:"
                    )
                lines.append(predicate)
                assert node.left is not None, "expected left child to exist"
                codegen_node(node.left, depth + 1)
                lines.append(f"{indent}else:")
                assert node.right is not None, "expected right child to exist"
                codegen_node(node.right, depth + 1)

        def handle_leaf(
            node: DecisionTreeNode, indent: str, unsafe_leaves: List[int]
        ) -> str:
            """
            This generates the code for a leaf node in the decision tree. If the leaf is unsafe, the learned heuristic
            will return "unsure" (i.e. None).
            """
            if node.id in unsafe_leaves:
                return f"{indent}return None"
            class_probas = node.class_probs
            return f"{indent}return {best_probas_and_indices(class_probas)}"

        def best_probas_and_indices(class_probas: Any) -> str:
            """
            Given an array of class probabilities, this function returns a string of
            (proba, idx) tuples sorted by proba in descending order. E.g.:
            Given class_probas=[0.3, 0.5, 0.2]
            this function returns
            "[(0.500, 1), (0.300, 0), (0.200, 2)]"
            """
            # we generate a list of tuples (proba, idx) sorted by proba in descending order
            # idx is the index of a choice
            # we only generate a tuple if proba > 0
            probas_indices_sorted = sorted(
                [
                    (proba, index)
                    for index, proba in enumerate(class_probas)
                    if proba > 0
                ],
                key=lambda x: x[0],
                reverse=True,
            )
            probas_indices_sorted_str = ", ".join(
                f"({value:.3f}, {index})" for value, index in probas_indices_sorted
            )
            return f"[{probas_indices_sorted_str}]"

        codegen_node(self.root, 1)
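To make the codegen above concrete, here is a hypothetical excerpt of a generated heuristic's predict function (feature names, thresholds, and probabilities are invented; the signature is the one produced by gen_predict_fn_def in the train_decision hunks below). Each returned pair is (probability, choice index), sorted in descending order:

def get_best_choices(self, context: AHContext) -> Optional[List[Tuple[float, int]]]:
    if context.get_value('m*n') <= 4096.5:
        if str(context.get_value('mat1_dtype')) != 'torch.float32':
            return [(0.581, 1), (0.419, 0)]
        else:
            return None  # unsafe leaf: the heuristic answers "unsure"
    else:
        return [(0.733, 0), (0.214, 2), (0.053, 1)]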
@ -2,7 +2,6 @@

import argparse
import json
import sys
import warnings

import pandas as pd  # type: ignore[import-untyped]
@ -64,6 +63,15 @@ class AHTrain:
            action="store_true",
            help="Export heuristic to graphviz dot.",
        )
        self.parser.add_argument(
            "--ranking",
            type=int,
            default=None,
            help="""
Makes AutoHeuristic learn a heuristic that ranks choices instead of predicting a single choice.
The argument is the number of choices the heuristic will provide.
""",
        )

    def parse_args(self):
        return self.parser.parse_args()
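For concreteness, a hypothetical invocation of a training script built on AHTrain (the script, log-file, and heuristic names are invented; only --ranking and its semantics come from this diff):

python train_decision_mm.py mm_train_log.txt --heuristic-name MMRanking --ranking 5

With --ranking 5, the learned tree is pruned so that the generated heuristic returns a ranked list covering at least five choices instead of a single prediction.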
@ -87,6 +95,7 @@ class AHTrain:
            self.args.nrows,
            self.args.heuristic_name,
            self.args.save_dot,
            self.args.ranking is not None,
        )

    def filter_df(self, df):
@ -138,9 +147,6 @@ class AHTrain:
            and str(metadata.device_capa) == "{device_capa}"
        )"""

    def handle_leaf(self, tree_, node, indent, unsafe_leaves):
        pass

    def codegen_boilerplate(
        self, heuristic_name, opt_name, threshold, shared_memory, device_capa, dt
    ):
@ -149,63 +155,7 @@ class AHTrain:
    def gen_predict_fn_def(self):
        pass

    def dt_to_python(
        self,
        dt,
        metadata,
        feature_names,
        dummy_col_2_col_val,
        heuristic_name,
        threshold,
        unsafe_leaves=None,
    ):
        tree_ = dt.tree_
        feature_name = [
            feature_names[i] if i != -1 else "undefined!" for i in tree_.feature
        ]

        lines = []
        device_capa = metadata["device_capa"]
        device_capa_str = f"({device_capa[0]}, {device_capa[1]})"
        opt_name = metadata["name"]
        lines.append(
            self.codegen_boilerplate(
                heuristic_name,
                opt_name,
                threshold,
                metadata["shared_memory"],
                device_capa_str,
                dt,
            )
        )
        fn_def = f"\n    {self.gen_predict_fn_def()}"
        lines.append(fn_def)

        def dt_to_python(node, depth):
            indent = "    " * (depth + 1)
            false_predicate = ""
            if tree_.feature[node] != -2:
                name = feature_name[node]
                threshold = tree_.threshold[node]
                if name in dummy_col_2_col_val:
                    (orig_name, value) = dummy_col_2_col_val[name]
                    predicate = f"{indent}if str(context.get_value('{orig_name}')) != '{value}':"
                    if threshold != 0.5:
                        print(f"expected threshold to be 0.5 but is {threshold}")
                        sys.exit(1)
                else:
                    predicate = (
                        f"{indent}if context.get_value('{name}') <= {threshold}:"
                    )
                lines.append(predicate)
                dt_to_python(tree_.children_left[node], depth + 1)
                lines.append(f"{indent}else:")
                dt_to_python(tree_.children_right[node], depth + 1)
            else:
                lines.append(self.handle_leaf(tree_, node, indent, unsafe_leaves))

        dt_to_python(0, 1)

    def write_heuristic_to_file(self, lines, heuristic_name):
        output_file = (
            f"../../../torch/_inductor/autoheuristic/artifacts/_{heuristic_name}.py"
        )
@ -16,6 +16,7 @@ from dataclasses import dataclass

import numpy as np
import pandas as pd  # type: ignore[import-untyped]
from ah_tree import DecisionTree
from scipy.stats import gmean
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier
@ -102,8 +103,21 @@ class AHTrainDecisionTree(AHTrain):
        leaf_ids = model.apply(df[feature_columns])
        return predictions, proba, leaf_ids

    def ranking_num_choices(self):
        # if the heuristic is used for ranking, this function returns the number
        # of choices that the heuristic will return
        if self.args.ranking is None:
            return 5
        return self.args.ranking

    def train_and_evaluate_models(
        self, datasets, max_depths, min_samples_leafs, criterion_list, feature_columns
        self,
        datasets,
        max_depths,
        min_samples_leafs,
        criterion_list,
        feature_columns,
        ranking=False,
    ):
        """
        Does a grid search over max_depths, min_samples_leafs, and criterion_list and returns the best model.
@ -131,7 +145,20 @@ class AHTrainDecisionTree(AHTrain):
        )
        df_train = datasets["train"]
        df_val = datasets["val"]
        model.fit(df_train[feature_columns], df_train["winner"])
        if ranking:
            model.fit(
                df_train[feature_columns],
                df_train["winner"],
                sample_weight=df_train["relative_performance"],
            )
        else:
            model.fit(df_train[feature_columns], df_train["winner"])

        model = DecisionTree(model, feature_columns)

        if ranking:
            model.prune(df_train, "winner", k=self.ranking_num_choices())

        unsafe_leaves = self.get_unsafe_leaves(model, df_train, feature_columns)
        predictions, proba, leaf_ids = self.predict(model, df_val, feature_columns)

@ -145,11 +172,18 @@ class AHTrainDecisionTree(AHTrain):
            wrong_pct=wrong_pct,
            unsafe_leaves=unsafe_leaves,
            leaf_ids=leaf_ids,
            k=self.ranking_num_choices(),
            ranking=ranking,
        )
        safe_proba = evaluator.get_safe_proba()
        print(f"safe_proba={safe_proba}")

        def eval(name, df):
            if ranking:
                # when ranking is enabled, each input is duplicated for every choice
                # that is almost as good as the best choice; we do not want to
                # evaluate the same input multiple times, so we remove duplicates here
                df = df[df["winner"] == df["actual_winner"]]
            predictions, proba, leaf_ids = self.predict(model, df, feature_columns)
            evaluator = DecisionEvaluator(
                self,
@ -161,6 +195,8 @@ class AHTrainDecisionTree(AHTrain):
                threshold=safe_proba,
                unsafe_leaves=unsafe_leaves,
                leaf_ids=leaf_ids,
                k=self.ranking_num_choices(),
                ranking=ranking,
            )
            return evaluator.get_results()

@ -202,7 +238,7 @@ class AHTrainDecisionTree(AHTrain):
        """
        return (0.15, 0.15)

    def prepare_datasets(self, df, other_datasets, cat_feature2cats):
    def prepare_datasets(self, df, other_datasets, cat_feature2cats, ranking=False):
        """
        Splits the dataframe into train, val, and test sets.
        Also adds other datasets, specified by the user, to the train set.
@ -219,24 +255,16 @@ class AHTrainDecisionTree(AHTrain):
            df_train_val, test_size=val_size / train_val_size, random_state=42
        )
        datasets = {"train": df_train, "val": df_val, "test": df_test}
        self.add_real_datasets(datasets, other_datasets, cat_feature2cats)
        self.add_real_datasets(datasets, other_datasets, cat_feature2cats, ranking)
        return datasets

    def export_to_dot(self, best_model, df, feature_columns):
        """
        Export a learned decision tree to a dot file.
        """
        from sklearn import tree

        tree.export_graphviz(
            best_model,
            out_file="best_model.dot",
            feature_names=df[feature_columns].columns,
            class_names=[str(c) for c in best_model.classes_],
            filled=True,
            rounded=True,
            special_characters=True,
        )
        dot_str = best_model.to_dot()
        with open("best_model.dot", "w") as f:
            f.write(dot_str)

    def get_feature_columns(self, df):
        """
@ -250,20 +278,36 @@ class AHTrainDecisionTree(AHTrain):
            "avail_choices",
            "choice2time",
            "index",
            "actual_winner",
            "relative_performance",
        ]
        feature_columns = [col for col in df.columns if col not in exclude_columns]
        return feature_columns

    def main(self, log_path, other_datasets, nrows, heuristic_name, save_dot=False):
    def add_training_data(self, df_train, datasets):
        return datasets["train"]

    def main(
        self,
        log_path,
        other_datasets,
        nrows,
        heuristic_name,
        save_dot=False,
        ranking=False,
    ):
        """
        Main function that trains a decision tree and generates a heuristic.
        """
        # TODO: Enable apply_filters
        (df, choices, cat_feature2cats, dummy_col_2_col_val, metadata) = self.get_df(
            log_path, nrows=nrows, apply_filters=False
            log_path, nrows=nrows, apply_filters=False, add_near_best=ranking
        )
        print(df["winner"].value_counts())
        datasets = self.prepare_datasets(df, other_datasets, cat_feature2cats)
        datasets = self.prepare_datasets(df, other_datasets, cat_feature2cats, ranking)
        df_train = self.add_training_data(datasets["train"], datasets)
        datasets["train"] = df_train

        feature_columns = self.get_feature_columns(df)
        grid_search_values = self.get_grid_search_values()
        max_depths = grid_search_values["max_depth"]
@ -275,28 +319,44 @@ class AHTrainDecisionTree(AHTrain):
            best_model_safe_proba,
            unsafe_leaves,
        ) = self.train_and_evaluate_models(
            datasets, max_depths, min_samples_leafs, criterion_list, feature_columns
            datasets,
            max_depths,
            min_samples_leafs,
            criterion_list,
            feature_columns,
            ranking=ranking,
        )

        if ranking:
            columns_to_keep = [
                "set",
                "total",
                "top_k_correct",
                "top_k_wrong",
                "top_k_unsure",
                "wrong_max_speedup_k",
                "wrong_gmean_speedup_k",
            ]
            results_df = results_df[columns_to_keep]
        # prints results for all models and datasets
        print(results_df.to_string())

        # prints results grouped by dataset
        for set_name in results_df["set"].unique():
            dataset_results = results_df[results_df["set"] == set_name]
            dataset_results = dataset_results.sort_values(by="correct")
            print(dataset_results.to_string() + "\n")
        if not ranking:
            # prints results grouped by dataset
            for set_name in results_df["set"].unique():
                dataset_results = results_df[results_df["set"] == set_name]
                dataset_results = dataset_results.sort_values(by="correct")
                print(dataset_results.to_string() + "\n")

        if best_model is not None:
            if save_dot:
                self.export_to_dot(best_model, df, feature_columns)
            self.dt_to_python(
            self.codegen(
                best_model,
                metadata,
                feature_columns,
                dummy_col_2_col_val,
                heuristic_name,
                best_model_safe_proba,
                dummy_col_2_col_val,
                unsafe_leaves,
            )
        else:
@ -304,7 +364,14 @@ class AHTrainDecisionTree(AHTrain):
                "All learned models have too many wrong predictions, so no heuristic was generated"
            )

    def get_df(self, log_path, cat_feature2cats=None, nrows=None, apply_filters=False):
    def get_df(
        self,
        log_path,
        cat_feature2cats=None,
        nrows=None,
        apply_filters=False,
        add_near_best=False,
    ):
        """
        Parses the log file and processes the data into a dataframe that can be used for training.
        """
@ -314,14 +381,19 @@ class AHTrainDecisionTree(AHTrain):

        def calculate_stats(group):
            count = len(group)
            mean = group["feedback"].mean()
            std = group["feedback"].std()
            relative_std = (std / mean) * 100 if mean != 0 else np.inf
            has_inf = np.isinf(group["feedback"]).any()
            if has_inf:
                relative_std = np.inf
                median = np.inf
            else:
                mean = group["feedback"].mean()
                std = group["feedback"].std()
                relative_std = (std / mean) * 100 if mean != 0 else np.inf
                median = group["feedback"].median()
            if relative_std > 5:
                times = group["feedback"].tolist()
                times_str = ", ".join([f"{t:.3f}" for t in sorted(times)])
                log.debug("High relative std: %f. times=%s", relative_std, times_str)
            median = group["feedback"].median()
            return pd.Series(
                {
                    "count": count,
@ -385,6 +457,28 @@ class AHTrainDecisionTree(AHTrain):
            .reset_index()
        )

        def add_near_best_configs(df):
            new_rows = []

            for index, row in df.iterrows():
                dictionary = json.loads(row["choice2time"])
                min_value = min(dictionary.values())

                for key, value in dictionary.items():
                    new_row = row.copy()
                    relative_performance = min_value / value
                    new_row["relative_performance"] = relative_performance
                    if relative_performance is None or relative_performance is np.inf:
                        breakpoint()
                    new_row["actual_winner"] = row["winner"]
                    new_row["winner"] = key
                    if relative_performance >= 0.95:
                        new_rows.append(new_row)

            return pd.DataFrame(new_rows).reset_index(drop=True)

        if add_near_best:
            results = add_near_best_configs(results)
        (results, added_categorical_features) = self.add_new_features(results)
        categorical_features += added_categorical_features
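To illustrate add_near_best_configs above with made-up timings: a row whose choice2time is {"triton": 1.00, "aten": 1.03, "cutlass": 2.10} expands into one row per choice whose relative performance is at least 0.95:

import json

# invented timings; choice2time is stored as a JSON string per row
times = json.loads('{"triton": 1.00, "aten": 1.03, "cutlass": 2.10}')
best = min(times.values())  # 1.00
for choice, t in times.items():
    rel = best / t
    # triton -> 1.000 (kept), aten -> 0.971 (kept as a duplicate row with
    # winner="aten", actual_winner="triton"), cutlass -> 0.476 (dropped)
    print(choice, f"{rel:.3f}", "kept" if rel >= 0.95 else "dropped")

Each kept near-best choice becomes its own training row; this is also why eval() above de-duplicates by keeping only rows where winner equals actual_winner.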
@ -409,27 +503,6 @@ class AHTrainDecisionTree(AHTrain):
        indent = " " * num_spaces
        return "\n".join([f"{indent}self.choices.append('{c}')" for c in classes])

    def best_probas_and_indices(self, class_probas):
        """
        Given a list of tuples (proba, idx), this function returns a string in which the tuples are sorted by proba in
        descending order. E.g.:
        Given class_probas=[(0.3, 0), (0.5, 1), (0.2, 2)]
        this function returns
        "[(0.5, 1), (0.3, 0), (0.2, 2)]"
        """
        # we generate a list of tuples (proba, idx) sorted by proba in descending order
        # idx is the index of a choice
        # we only generate a tuple if proba > 0
        probas_indices_sorted = sorted(
            [(proba, index) for index, proba in enumerate(class_probas) if proba > 0],
            key=lambda x: x[0],
            reverse=True,
        )
        probas_indices_sorted_str = ", ".join(
            f"({value:.3f}, {index})" for value, index in probas_indices_sorted
        )
        return f"[{probas_indices_sorted_str}]"

    def get_default_config(self, row):
        """
        Returns the default config for a given sample. The default config could for example be the config that is
@ -438,17 +511,6 @@ class AHTrainDecisionTree(AHTrain):
        """
        return None

    def handle_leaf(self, tree_, node, indent, unsafe_leaves):
        """
        This generates the code for a leaf node in the decision tree. If the leaf is unsafe, the learned heuristic
        will return "unsure" (i.e. None).
        """
        if node in unsafe_leaves:
            return f"{indent}return None"
        leaf_num_samples = tree_.n_node_samples[node]
        class_probas = tree_.value[node][0]
        return f"{indent}return {self.best_probas_and_indices(class_probas)}"

    def gen_predict_fn_def(self):
        """
        Generates the definition of the predict function.
@ -456,7 +518,7 @@ class AHTrainDecisionTree(AHTrain):
        return "def get_best_choices(self, context: AHContext) -> Optional[List[Tuple[float, int]]]:"

    def codegen_boilerplate(
        self, heuristic_name, opt_name, threshold, shared_memory, device_capa, dt
        self, heuristic_name, opt_name, threshold, shared_memory, device_capa, classes
    ):
        """
        Generates the boilerplate code for the generated heuristic. This includes things like imports, class definition,
@ -496,23 +558,56 @@ class {heuristic_name}(LearnedHeuristicDecision):
        return None

    def fill_choices(self) -> None:
{self.gen_classes(dt.classes_, num_spaces=8)}
{self.gen_classes(classes, num_spaces=8)}

    def get_name(self) -> str:
        return '{opt_name}'"""
        return boiler_plate

    def add_real_datasets(self, datasets, other_datasets, cat_feature2cats):
    def add_real_datasets(
        self, datasets, other_datasets, cat_feature2cats, ranking=False
    ):
        """
        Adds datasets specified by the user to the datasets dictionary.
        """
        if other_datasets:
            for name, path in other_datasets:
                (df_other, choices, _, _, _) = self.get_df(
                    path, cat_feature2cats=cat_feature2cats, apply_filters=False
                    path,
                    cat_feature2cats=cat_feature2cats,
                    apply_filters=False,
                    add_near_best=ranking,
                )
                datasets[name] = df_other

    def codegen(
        self,
        tree,
        metadata,
        heuristic_name,
        threshold,
        dummy_col_2_col_val,
        unsafe_leaves,
    ):
        lines = []
        device_capa = metadata["device_capa"]
        device_capa_str = f"({device_capa[0]}, {device_capa[1]})"
        opt_name = metadata["name"]
        lines.append(
            self.codegen_boilerplate(
                heuristic_name,
                opt_name,
                threshold,
                metadata["shared_memory"],
                device_capa_str,
                tree.classes_,
            )
        )
        fn_def = f"\n    {self.gen_predict_fn_def()}"
        lines.append(fn_def)
        tree.codegen(dummy_col_2_col_val, lines, unsafe_leaves)
        self.write_heuristic_to_file(lines, heuristic_name)


@dataclass
class AccuracyMetrics:
@ -552,6 +647,8 @@ class WrongSpeedupMetrics:
class RankingMetrics:
    # Number of predictions where best choice is in top k choices
    num_correct: int
    # Number of predictions where best choice is not in top k choices
    num_wrong: int
    # Maximum speedup of best choice over best choice in top k (this tells us how much better the best choice, which
    # is not in top k, is over the best choice in top k)
    max_speedup: float
@ -563,6 +660,7 @@ class RankingMetrics:
    def to_map(self):
        return {
            "top_k_correct": self.num_correct,
            "top_k_wrong": self.num_wrong,
            "wrong_max_speedup_k": self.max_speedup,
            "wrong_gmean_speedup_k": self.gmean_speedup,
            "top_k_unsure": self.unsure,
@ -618,9 +716,10 @@ class DecisionEvaluator:
        probas,
        wrong_pct=0.01,
        threshold=0.0,
        k=3,
        k=10,
        unsafe_leaves=None,
        leaf_ids=None,
        ranking=False,
    ) -> None:
        self.train = train
        self.model = model
@ -632,6 +731,7 @@ class DecisionEvaluator:
        self.k = k
        self.unsafe_leaves = unsafe_leaves
        self.leaf_ids = leaf_ids
        self.ranking = ranking

        self.num_correct = 0
        self.num_wrong = 0
@ -639,6 +739,7 @@ class DecisionEvaluator:
        self.wrong_probas = []
        self.speedups_wrong = []
        self.num_correct_top_k = 0
        self.num_wrong_top_k = 0
        self.wrong_speedups_top_k = []
        self.top_k_unsure = 0
        self.num_non_default_predictions = 0
@ -718,6 +819,7 @@ class DecisionEvaluator:
            if min_time is not None:
                speedup = min_time / best_time
                self.wrong_speedups_top_k.append(speedup)
                self.num_wrong_top_k += 1
            else:
                self.top_k_unsure += 1
                # TODO (AlnisM): print more info (input and choices)
@ -743,7 +845,7 @@ class DecisionEvaluator:
        Custom evaluation function that evaluates a learned decision tree.
        """

        y_true = self.df["winner"]
        y_true = self.df["actual_winner"] if self.ranking else self.df["winner"]
        i = 0
        for pred, true, prob, leaf_id in zip(
            self.predictions, y_true, self.probas, self.leaf_ids
@ -790,6 +892,7 @@ class DecisionEvaluator:
        wrongSpeedupMetrics = WrongSpeedupMetrics(max_speedup, gmean_speedup)
        rankingMetrics = RankingMetrics(
            self.num_correct_top_k,
            self.num_wrong_top_k,
            max_speedup_top_k,
            gmean_speedup_top_k,
            self.top_k_unsure,
@ -34,7 +34,15 @@ class AHTrainRegressionTree(AHTrain):
    def __init__(self):
        super().__init__()

    def main(self, log_path, other_datasets, nrows, heuristic_name, save_dot=False):
    def main(
        self,
        log_path,
        other_datasets,
        nrows,
        heuristic_name,
        save_dot=False,
        ranking=False,
    ):
        """
        Main function that trains a decision tree and generates a heuristic.
        """
@ -357,6 +365,65 @@ class AHTrainRegressionTree(AHTrain):
            "wrong_max_ratio": wrong_max_ratio,
        }

    def dt_to_python(
        self,
        dt,
        metadata,
        feature_names,
        dummy_col_2_col_val,
        heuristic_name,
        threshold,
        unsafe_leaves=None,
    ):
        tree_ = dt.tree_
        feature_name = [
            feature_names[i] if i != -1 else "undefined!" for i in tree_.feature
        ]

        lines = []
        device_capa = metadata["device_capa"]
        device_capa_str = f"({device_capa[0]}, {device_capa[1]})"
        opt_name = metadata["name"]
        lines.append(
            self.codegen_boilerplate(
                heuristic_name,
                opt_name,
                threshold,
                metadata["shared_memory"],
                device_capa_str,
                dt,
            )
        )
        fn_def = f"\n    {self.gen_predict_fn_def()}"
        lines.append(fn_def)

        def dt_to_python(node, depth):
            indent = "    " * (depth + 1)
            false_predicate = ""
            if tree_.feature[node] != -2:
                name = feature_name[node]
                threshold = tree_.threshold[node]
                if name in dummy_col_2_col_val:
                    (orig_name, value) = dummy_col_2_col_val[name]
                    predicate = f"{indent}if str(context.get_value('{orig_name}')) != '{value}':"
                    assert (
                        threshold == 0.5
                    ), f"expected threshold to be 0.5 but is {threshold}"
                else:
                    predicate = (
                        f"{indent}if context.get_value('{name}') <= {threshold}:"
                    )
                lines.append(predicate)
                dt_to_python(tree_.children_left[node], depth + 1)
                lines.append(f"{indent}else:")
                dt_to_python(tree_.children_right[node], depth + 1)
            else:
                lines.append(self.handle_leaf(tree_, node, indent, unsafe_leaves))

        dt_to_python(0, 1)

        self.write_heuristic_to_file(lines, heuristic_name)

    def handle_leaf(self, tree_, node, indent, unsafe_leaves):
        """
        Generates the code for a leaf node. This is just the value predicted by the regression tree.
@ -368,7 +435,7 @@ class AHTrainRegressionTree(AHTrain):
        return "def predict(self, context: AHContext) -> float:"

    def codegen_boilerplate(
        self, heuristic_name, opt_name, threshold, shared_memory, device_capa, dt
        self, heuristic_name, opt_name, threshold, shared_memory, device_capa, classes
    ):
        """
        Generates the boilerplate code for the generated heuristic. This includes things like imports, class definition,