From add0f0085c4ede7699dbedfeb10136cf18fb6248 Mon Sep 17 00:00:00 2001
From: Alnis Murtovi
Date: Thu, 15 Aug 2024 13:02:51 -0700
Subject: [PATCH] AutoHeuristic: Support ranking/pruning choices (#131705)

This PR adds support to train_decision for learning a heuristic that
ranks choices instead of predicting a single choice. The main idea is
that the user provides the number of choices the heuristic should
return. I added a way to prune the learned decision tree such that it
always returns at least the number of choices provided by the user.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/131705
Approved by: https://github.com/eellison
---
 torchgen/_autoheuristic/ah_tree.py          | 262 ++++++++++++++++++++
 torchgen/_autoheuristic/train.py            |  72 +-----
 torchgen/_autoheuristic/train_decision.py   | 243 ++++++++++++------
 torchgen/_autoheuristic/train_regression.py |  71 +++++-
 4 files changed, 515 insertions(+), 133 deletions(-)
 create mode 100644 torchgen/_autoheuristic/ah_tree.py

diff --git a/torchgen/_autoheuristic/ah_tree.py b/torchgen/_autoheuristic/ah_tree.py
new file mode 100644
index 000000000000..3991ffc87f88
--- /dev/null
+++ b/torchgen/_autoheuristic/ah_tree.py
@@ -0,0 +1,262 @@
+from typing import Any, Dict, List, Optional, Tuple
+
+import numpy as np
+from sklearn.tree import _tree  # type: ignore[import-untyped]
+
+
+class DecisionTreeNode:
+    def __init__(
+        self,
+        feature: Optional[str] = None,
+        threshold: Optional[float] = None,
+        left: Optional["DecisionTreeNode"] = None,
+        right: Optional["DecisionTreeNode"] = None,
+        class_probs: Any = None,
+        num_samples: int = 0,
+        node_id: int = 0,
+    ) -> None:
+        self.feature = feature
+        self.threshold = threshold
+        self.left = left
+        self.right = right
+        self.class_probs = class_probs
+        self.num_samples = num_samples
+        self.id = node_id
+
+    def is_leaf(self) -> bool:
+        return self.left is None or self.right is None
+
+
+class DecisionTree:
+    """
+    Custom decision tree implementation that mimics some of the sklearn API.
+    The purpose of this class is to be able to perform transformations, such as
+    custom pruning, which does not seem to be easy with sklearn.
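+
+    A minimal usage sketch (hypothetical names; assumes `clf` is a fitted
+    sklearn DecisionTreeClassifier and `df` is a pandas DataFrame that
+    contains the feature columns and a "winner" target column):
+
+        tree = DecisionTree(clf, feature_names=feature_columns)
+        tree.prune(df, target_col="winner", k=3)
+        predicted_winners = tree.predict(df[feature_columns])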
+ """ + + def __init__(self, sklearn_tree: Any, feature_names: List[str]) -> None: + self.feature_names = feature_names + self.root = self._convert_sklearn_tree(sklearn_tree.tree_) + self.classes_: List[str] = sklearn_tree.classes_ + + def _convert_sklearn_tree( + self, sklearn_tree: Any, node_id: int = 0 + ) -> DecisionTreeNode: + class_probs = sklearn_tree.value[node_id][0] + num_samples = sklearn_tree.n_node_samples[node_id] + if sklearn_tree.feature[node_id] != _tree.TREE_UNDEFINED: + feature_index = sklearn_tree.feature[node_id] + feature = self.feature_names[feature_index] + left = self._convert_sklearn_tree( + sklearn_tree, sklearn_tree.children_left[node_id] + ) + right = self._convert_sklearn_tree( + sklearn_tree, sklearn_tree.children_right[node_id] + ) + return DecisionTreeNode( + feature=feature, + threshold=sklearn_tree.threshold[node_id], + left=left, + right=right, + class_probs=class_probs, + num_samples=num_samples, + node_id=node_id, + ) + else: + return DecisionTreeNode( + class_probs=class_probs, num_samples=num_samples, node_id=node_id + ) + + def prune(self, df: Any, target_col: str, k: int) -> None: + self.root = self._prune_tree(self.root, df, target_col, k) + + def _prune_tree( + self, node: DecisionTreeNode, df: Any, target_col: str, k: int + ) -> DecisionTreeNode: + if node.is_leaf(): + return node + + left_df = df[df[node.feature] <= node.threshold] + right_df = df[df[node.feature] > node.threshold] + + # number of unique classes in the left and right subtrees + left_counts = left_df[target_col].nunique() + right_counts = right_df[target_col].nunique() + + # for ranking, we want to ensure that we return at least k classes, so if we have less than k classes in the + # left or right subtree, we remove the split and make this node a leaf node + if left_counts < k or right_counts < k: + return DecisionTreeNode(class_probs=node.class_probs) + + assert node.left is not None, "expected left child to exist" + node.left = self._prune_tree(node.left, left_df, target_col, k) + assert node.right is not None, "expected right child to exist" + node.right = self._prune_tree(node.right, right_df, target_col, k) + + return node + + def to_dot(self) -> str: + dot = "digraph DecisionTree {\n" + dot += ' node [fontname="helvetica"];\n' + dot += ' edge [fontname="helvetica"];\n' + dot += self._node_to_dot(self.root) + dot += "}" + return dot + + def _node_to_dot( + self, node: DecisionTreeNode, parent_id: int = 0, edge_label: str = "" + ) -> str: + if node is None: + return "" + + node_id = id(node) + + # Format class_probs array with line breaks + class_probs_str = self._format_class_probs_array( + node.class_probs, node.num_samples + ) + + if node.is_leaf(): + label = class_probs_str + shape = "box" + else: + feature_name = f"{node.feature}" + label = f"{feature_name} <= {node.threshold:.2f}\\n{class_probs_str}" + shape = "oval" + + dot = f' {node_id} [label="{label}", shape={shape}];\n' + + if parent_id != 0: + dot += f' {parent_id} -> {node_id} [label="{edge_label}"];\n' + + if not node.is_leaf(): + assert node.left is not None, "expected left child to exist" + dot += self._node_to_dot(node.left, node_id, "<=") + assert node.right is not None, "expected right child to exist" + dot += self._node_to_dot(node.right, node_id, ">") + + return dot + + def _format_class_prob(self, num: float) -> str: + if num == 0: + return "0" + return f"{num:.2f}" + + def _format_class_probs_array( + self, class_probs: Any, num_samples: int, max_per_line: int = 5 + ) -> str: + # add line breaks to 
+        flat_class_probs = class_probs.flatten()
+        formatted = [self._format_class_prob(v) for v in flat_class_probs]
+        lines = [
+            formatted[i : i + max_per_line]
+            for i in range(0, len(formatted), max_per_line)
+        ]
+        return f"num_samples={num_samples}\\n" + "\\n".join(
+            [", ".join(line) for line in lines]
+        )
+
+    def predict(self, X: Any) -> Any:
+        predictions = [self._predict_single(x) for _, x in X.iterrows()]
+        return np.array(predictions)
+
+    def predict_proba(self, X: Any) -> Any:
+        return np.array([self._predict_proba_single(x) for _, x in X.iterrows()])
+
+    def _get_leaf(self, X: Any) -> DecisionTreeNode:
+        node = self.root
+        while not node.is_leaf():
+            if X[node.feature] <= node.threshold:
+                assert node.left is not None, "expected left child to exist"
+                node = node.left
+            else:
+                assert node.right is not None, "expected right child to exist"
+                node = node.right
+        return node
+
+    def _predict_single(self, x: Any) -> str:
+        node = self._get_leaf(x)
+        # map index to class name
+        return self.classes_[np.argmax(node.class_probs)]
+
+    def _predict_proba_single(self, x: Any) -> Any:
+        node = self._get_leaf(x)
+        return node.class_probs
+
+    def apply(self, X: Any) -> Any:
+        ids = [self._apply_single(x) for _, x in X.iterrows()]
+        return np.array(ids)
+
+    def _apply_single(self, x: Any) -> int:
+        node = self._get_leaf(x)
+        return node.id
+
+    def codegen(
+        self,
+        dummy_col_2_col_val: Dict[str, Tuple[str, Any]],
+        lines: List[str],
+        unsafe_leaves: List[int],
+    ) -> None:
+        # generates python code for the decision tree
+        def codegen_node(node: DecisionTreeNode, depth: int) -> None:
+            indent = "    " * (depth + 1)
+            if node.is_leaf():
+                lines.append(handle_leaf(node, indent, unsafe_leaves))
+            else:
+                name = node.feature
+                threshold = node.threshold
+                if name in dummy_col_2_col_val:
+                    (orig_name, value) = dummy_col_2_col_val[name]
+                    predicate = f"{indent}if str(context.get_value('{orig_name}')) != '{value}':"
+                    assert (
+                        threshold == 0.5
+                    ), f"expected threshold to be 0.5 but is {threshold}"
+                else:
+                    predicate = (
+                        f"{indent}if context.get_value('{name}') <= {threshold}:"
+                    )
+                lines.append(predicate)
+                assert node.left is not None, "expected left child to exist"
+                codegen_node(node.left, depth + 1)
+                lines.append(f"{indent}else:")
+                assert node.right is not None, "expected right child to exist"
+                codegen_node(node.right, depth + 1)
+
+        def handle_leaf(
+            node: DecisionTreeNode, indent: str, unsafe_leaves: List[int]
+        ) -> str:
+            """
+            This generates the code for a leaf node in the decision tree. If the leaf is unsafe, the learned heuristic
+            will return "unsure" (i.e. None).
+            """
+            if node.id in unsafe_leaves:
+                return f"{indent}return None"
+            class_probas = node.class_probs
+            return f"{indent}return {best_probas_and_indices(class_probas)}"
+
+        def best_probas_and_indices(class_probas: Any) -> str:
+            """
+            Given an array of class probabilities, this function returns a string of
+            (proba, idx) tuples, sorted by proba in descending order.
+            E.g.:
+            Given class_probas=[0.3, 0.5, 0.2]
+            this function returns
+            "[(0.500, 1), (0.300, 0), (0.200, 2)]"
+            """
+            # we generate a list of tuples (proba, idx) sorted by proba in descending order
+            # idx is the index of a choice
+            # we only generate a tuple if proba > 0
+            probas_indices_sorted = sorted(
+                [
+                    (proba, index)
+                    for index, proba in enumerate(class_probas)
+                    if proba > 0
+                ],
+                key=lambda x: x[0],
+                reverse=True,
+            )
+            probas_indices_sorted_str = ", ".join(
+                f"({value:.3f}, {index})" for value, index in probas_indices_sorted
+            )
+            return f"[{probas_indices_sorted_str}]"
+
+        codegen_node(self.root, 1)
diff --git a/torchgen/_autoheuristic/train.py b/torchgen/_autoheuristic/train.py
index 78a16c42d79e..4e8dd330a132 100644
--- a/torchgen/_autoheuristic/train.py
+++ b/torchgen/_autoheuristic/train.py
@@ -2,7 +2,6 @@
 
 import argparse
 import json
-import sys
 import warnings
 
 import pandas as pd  # type: ignore[import-untyped]
@@ -64,6 +63,15 @@ class AHTrain:
             action="store_true",
             help="Export heuristic to graphviz dot.",
         )
+        self.parser.add_argument(
+            "--ranking",
+            type=int,
+            default=None,
+            help="""
+            Makes AutoHeuristic learn a heuristic that ranks choices instead of predicting a single choice.
+            The argument is the number of choices the heuristic will provide.
+            """,
+        )
 
     def parse_args(self):
         return self.parser.parse_args()
@@ -87,6 +95,7 @@ class AHTrain:
             self.args.nrows,
             self.args.heuristic_name,
             self.args.save_dot,
+            self.args.ranking is not None,
         )
 
     def filter_df(self, df):
@@ -138,9 +147,6 @@ class AHTrain:
     and str(metadata.device_capa) == "{device_capa}"
 )"""
 
-    def handle_leaf(self, tree_, node, indent, unsafe_leaves):
-        pass
-
     def codegen_boilerplate(
         self, heuristic_name, opt_name, threshold, shared_memory, device_capa, dt
     ):
@@ -149,63 +155,7 @@ class AHTrain:
     def gen_predict_fn_def(self):
         pass
 
-    def dt_to_python(
-        self,
-        dt,
-        metadata,
-        feature_names,
-        dummy_col_2_col_val,
-        heuristic_name,
-        threshold,
-        unsafe_leaves=None,
-    ):
-        tree_ = dt.tree_
-        feature_name = [
-            feature_names[i] if i != -1 else "undefined!" for i in tree_.feature
-        ]
-
-        lines = []
-        device_capa = metadata["device_capa"]
-        device_capa_str = f"({device_capa[0]}, {device_capa[1]})"
-        opt_name = metadata["name"]
-        lines.append(
-            self.codegen_boilerplate(
-                heuristic_name,
-                opt_name,
-                threshold,
-                metadata["shared_memory"],
-                device_capa_str,
-                dt,
-            )
-        )
-        fn_def = f"\n    {self.gen_predict_fn_def()}"
-        lines.append(fn_def)
-
-        def dt_to_python(node, depth):
-            indent = "    " * (depth + 1)
-            false_predicate = ""
-            if tree_.feature[node] != -2:
-                name = feature_name[node]
-                threshold = tree_.threshold[node]
-                if name in dummy_col_2_col_val:
-                    (orig_name, value) = dummy_col_2_col_val[name]
-                    predicate = f"{indent}if str(context.get_value('{orig_name}')) != '{value}':"
-                    if threshold != 0.5:
-                        print(f"expected threshold to be 0.5 but is {threshold}")
-                        sys.exit(1)
-                else:
-                    predicate = (
-                        f"{indent}if context.get_value('{name}') <= {threshold}:"
-                    )
-                lines.append(predicate)
-                dt_to_python(tree_.children_left[node], depth + 1)
-                lines.append(f"{indent}else:")
-                dt_to_python(tree_.children_right[node], depth + 1)
-            else:
-                lines.append(self.handle_leaf(tree_, node, indent, unsafe_leaves))
-
-        dt_to_python(0, 1)
-
+    def write_heuristic_to_file(self, lines, heuristic_name):
         output_file = (
             f"../../../torch/_inductor/autoheuristic/artifacts/_{heuristic_name}.py"
         )
diff --git a/torchgen/_autoheuristic/train_decision.py b/torchgen/_autoheuristic/train_decision.py
index f2270d264488..cd0b3ff4ad07 100644
--- a/torchgen/_autoheuristic/train_decision.py
+++ b/torchgen/_autoheuristic/train_decision.py
@@ -16,6 +16,7 @@ from dataclasses import dataclass
 
 import numpy as np
 import pandas as pd  # type: ignore[import-untyped]
+from ah_tree import DecisionTree
 from scipy.stats import gmean
 from sklearn.model_selection import train_test_split
 from sklearn.tree import DecisionTreeClassifier
@@ -102,8 +103,21 @@ class AHTrainDecisionTree(AHTrain):
         leaf_ids = model.apply(df[feature_columns])
         return predictions, proba, leaf_ids
 
+    def ranking_num_choices(self):
+        # if the heuristic is used for ranking, this function returns the number
+        # of choices that the heuristic will return
+        if self.args.ranking is None:
+            return 5
+        return self.args.ranking
+
     def train_and_evaluate_models(
-        self, datasets, max_depths, min_samples_leafs, criterion_list, feature_columns
+        self,
+        datasets,
+        max_depths,
+        min_samples_leafs,
+        criterion_list,
+        feature_columns,
+        ranking=False,
     ):
         """
         Does a grid search over max_depths, min_samples_leafs, and criterion_list and returns the best model.
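+
+        A hedged sketch of the dict shape expected from get_grid_search_values()
+        (the keys match how main() reads them; the values here are illustrative,
+        not the real defaults):
+
+            {"max_depth": [5, 10], "min_samples_leaf": [1, 10], "criterion": ["gini", "entropy"]}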
@@ -131,7 +145,20 @@ class AHTrainDecisionTree(AHTrain):
             )
             df_train = datasets["train"]
             df_val = datasets["val"]
-            model.fit(df_train[feature_columns], df_train["winner"])
+            if ranking:
+                model.fit(
+                    df_train[feature_columns],
+                    df_train["winner"],
+                    sample_weight=df_train["relative_performance"],
+                )
+            else:
+                model.fit(df_train[feature_columns], df_train["winner"])
+
+            model = DecisionTree(model, feature_columns)
+
+            if ranking:
+                model.prune(df_train, "winner", k=self.ranking_num_choices())
+
             unsafe_leaves = self.get_unsafe_leaves(model, df_train, feature_columns)
             predictions, proba, leaf_ids = self.predict(model, df_val, feature_columns)
@@ -145,11 +172,18 @@ class AHTrainDecisionTree(AHTrain):
                 wrong_pct=wrong_pct,
                 unsafe_leaves=unsafe_leaves,
                 leaf_ids=leaf_ids,
+                k=self.ranking_num_choices(),
+                ranking=ranking,
             )
             safe_proba = evaluator.get_safe_proba()
             print(f"safe_proba={safe_proba}")
 
         def eval(name, df):
+            if ranking:
+                # when ranking is enabled, we duplicate each input for each choice that
+                # is almost as good as the best choice
+                # we do not want to evaluate the same input multiple times, so we remove duplicates here
+                df = df[df["winner"] == df["actual_winner"]]
             predictions, proba, leaf_ids = self.predict(model, df, feature_columns)
             evaluator = DecisionEvaluator(
                 self,
@@ -161,6 +195,8 @@ class AHTrainDecisionTree(AHTrain):
                 threshold=safe_proba,
                 unsafe_leaves=unsafe_leaves,
                 leaf_ids=leaf_ids,
+                k=self.ranking_num_choices(),
+                ranking=ranking,
             )
             return evaluator.get_results()
 
@@ -202,7 +238,7 @@ class AHTrainDecisionTree(AHTrain):
         """
         return (0.15, 0.15)
 
-    def prepare_datasets(self, df, other_datasets, cat_feature2cats):
+    def prepare_datasets(self, df, other_datasets, cat_feature2cats, ranking=False):
         """
         Splits the dataframe into train, val, and test sets.
         Also adds other datasets, specified by the user, to the train set.
@@ -219,24 +255,16 @@ class AHTrainDecisionTree(AHTrain):
             df_train_val, test_size=val_size / train_val_size, random_state=42
         )
         datasets = {"train": df_train, "val": df_val, "test": df_test}
-        self.add_real_datasets(datasets, other_datasets, cat_feature2cats)
+        self.add_real_datasets(datasets, other_datasets, cat_feature2cats, ranking)
         return datasets
 
     def export_to_dot(self, best_model, df, feature_columns):
         """
         Export a learned decision tree to a dot file.
         """
-        from sklearn import tree
-
-        tree.export_graphviz(
-            best_model,
-            out_file="best_model.dot",
-            feature_names=df[feature_columns].columns,
-            class_names=[str(c) for c in best_model.classes_],
-            filled=True,
-            rounded=True,
-            special_characters=True,
-        )
+        dot_str = best_model.to_dot()
+        with open("best_model.dot", "w") as f:
+            f.write(dot_str)
 
     def get_feature_columns(self, df):
         """
@@ -250,20 +278,36 @@ class AHTrainDecisionTree(AHTrain):
             "avail_choices",
             "choice2time",
             "index",
+            "actual_winner",
+            "relative_performance",
         ]
         feature_columns = [col for col in df.columns if col not in exclude_columns]
         return feature_columns
 
-    def main(self, log_path, other_datasets, nrows, heuristic_name, save_dot=False):
+    def add_training_data(self, df_train, datasets):
+        return datasets["train"]
+
+    def main(
+        self,
+        log_path,
+        other_datasets,
+        nrows,
+        heuristic_name,
+        save_dot=False,
+        ranking=False,
+    ):
         """
         Main function that trains a decision tree and generates a heuristic.
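+
+        A hedged invocation sketch (the log path and heuristic name are
+        hypothetical; `ranking=True` corresponds to passing --ranking on the
+        command line):
+
+            train = AHTrainDecisionTree()
+            train.main("mm_data.txt", other_datasets=[], nrows=None,
+                       heuristic_name="MMRanking", ranking=True)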
""" # TODO: Enable apply_filters (df, choices, cat_feature2cats, dummy_col_2_col_val, metadata) = self.get_df( - log_path, nrows=nrows, apply_filters=False + log_path, nrows=nrows, apply_filters=False, add_near_best=ranking ) print(df["winner"].value_counts()) - datasets = self.prepare_datasets(df, other_datasets, cat_feature2cats) + datasets = self.prepare_datasets(df, other_datasets, cat_feature2cats, ranking) + df_train = self.add_training_data(datasets["train"], datasets) + datasets["train"] = df_train + feature_columns = self.get_feature_columns(df) grid_search_values = self.get_grid_search_values() max_depths = grid_search_values["max_depth"] @@ -275,28 +319,44 @@ class AHTrainDecisionTree(AHTrain): best_model_safe_proba, unsafe_leaves, ) = self.train_and_evaluate_models( - datasets, max_depths, min_samples_leafs, criterion_list, feature_columns + datasets, + max_depths, + min_samples_leafs, + criterion_list, + feature_columns, + ranking=ranking, ) + if ranking: + columns_to_keep = [ + "set", + "total", + "top_k_correct", + "top_k_wrong", + "top_k_unsure", + "wrong_max_spdup_k", + "wrong_gman_spdup_k", + ] + results_df = results_df[columns_to_keep] # prints results for all models and datasets print(results_df.to_string()) - # prints results grouped by dataset - for set_name in results_df["set"].unique(): - dataset_results = results_df[results_df["set"] == set_name] - dataset_results = dataset_results.sort_values(by="correct") - print(dataset_results.to_string() + "\n") + if not ranking: + # prints results grouped by dataset + for set_name in results_df["set"].unique(): + dataset_results = results_df[results_df["set"] == set_name] + dataset_results = dataset_results.sort_values(by="correct") + print(dataset_results.to_string() + "\n") if best_model is not None: if save_dot: self.export_to_dot(best_model, df, feature_columns) - self.dt_to_python( + self.codegen( best_model, metadata, - feature_columns, - dummy_col_2_col_val, heuristic_name, best_model_safe_proba, + dummy_col_2_col_val, unsafe_leaves, ) else: @@ -304,7 +364,14 @@ class AHTrainDecisionTree(AHTrain): "All learned models have too many wrong predictions, so no heuristic was generated" ) - def get_df(self, log_path, cat_feature2cats=None, nrows=None, apply_filters=False): + def get_df( + self, + log_path, + cat_feature2cats=None, + nrows=None, + apply_filters=False, + add_near_best=False, + ): """ Parses the log file and processes the data into a dataframe that can be used for training. """ @@ -314,14 +381,19 @@ class AHTrainDecisionTree(AHTrain): def calculate_stats(group): count = len(group) - mean = group["feedback"].mean() - std = group["feedback"].std() - relative_std = (std / mean) * 100 if mean != 0 else np.inf + has_inf = np.isinf(group["feedback"]).any() + if has_inf: + relative_std = np.inf + median = np.inf + else: + mean = group["feedback"].mean() + std = group["feedback"].std() + relative_std = (std / mean) * 100 if mean != 0 else np.inf + median = group["feedback"].median() if relative_std > 5: times = group["feedback"].tolist() times_str = ", ".join([f"{t:.3f}" for t in sorted(times)]) log.debug("High relative std: %f. 
times=%s", relative_std, times_str) - median = group["feedback"].median() return pd.Series( { "count": count, @@ -385,6 +457,28 @@ class AHTrainDecisionTree(AHTrain): .reset_index() ) + def add_near_best_configs(df): + new_rows = [] + + for index, row in df.iterrows(): + dictionary = json.loads(row["choice2time"]) + min_value = min(dictionary.values()) + + for key, value in dictionary.items(): + new_row = row.copy() + relative_performance = min_value / value + new_row["relative_performance"] = relative_performance + if relative_performance is None or relative_performance is np.inf: + breakpoint() + new_row["actual_winner"] = row["winner"] + new_row["winner"] = key + if relative_performance >= 0.95: + new_rows.append(new_row) + + return pd.DataFrame(new_rows).reset_index(drop=True) + + if add_near_best: + results = add_near_best_configs(results) (results, added_categorical_features) = self.add_new_features(results) categorical_features += added_categorical_features @@ -409,27 +503,6 @@ class AHTrainDecisionTree(AHTrain): indent = " " * num_spaces return "\n".join([f"{indent}self.choices.append('{c}')" for c in classes]) - def best_probas_and_indices(self, class_probas): - """ - Given a list of tuples (proba, idx), this function returns a string in which the tuples are sorted by proba in - descending order. E.g.: - Given class_probas=[(0.3, 0), (0.5, 1), (0.2, 2)] - this function returns - "[(0.5, 1), (0.3, 0), (0.2, 2)]" - """ - # we generate a list of tuples (proba, idx) sorted by proba in descending order - # idx is the index of a choice - # we only generate a tuple if proba > 0 - probas_indices_sorted = sorted( - [(proba, index) for index, proba in enumerate(class_probas) if proba > 0], - key=lambda x: x[0], - reverse=True, - ) - probas_indices_sorted_str = ", ".join( - f"({value:.3f}, {index})" for value, index in probas_indices_sorted - ) - return f"[{probas_indices_sorted_str}]" - def get_default_config(self, row): """ Returns the default config for a given sample. The default config could for example be the config that is @@ -438,17 +511,6 @@ class AHTrainDecisionTree(AHTrain): """ return None - def handle_leaf(self, tree_, node, indent, unsafe_leaves): - """ - This generates the code for a leaf node in the decision tree. If the leaf is unsafe, the learned heuristic - will return "unsure" (i.e. None). - """ - if node in unsafe_leaves: - return f"{indent}return None" - leaf_num_samples = tree_.n_node_samples[node] - class_probas = tree_.value[node][0] - return f"{indent}return {self.best_probas_and_indices(class_probas)}" - def gen_predict_fn_def(self): """ Generates the definition of the predict function. @@ -456,7 +518,7 @@ class AHTrainDecisionTree(AHTrain): return "def get_best_choices(self, context: AHContext) -> Optional[List[Tuple[float, int]]]:" def codegen_boilerplate( - self, heuristic_name, opt_name, threshold, shared_memory, device_capa, dt + self, heuristic_name, opt_name, threshold, shared_memory, device_capa, classes ): """ Generates the boilerplate code for the generated heuristic. 
@@ -496,23 +558,56 @@ class {heuristic_name}(LearnedHeuristicDecision):
         return None
 
     def fill_choices(self) -> None:
-{self.gen_classes(dt.classes_, num_spaces=8)}
+{self.gen_classes(classes, num_spaces=8)}
 
     def get_name(self) -> str:
         return '{opt_name}'"""
         return boiler_plate
 
-    def add_real_datasets(self, datasets, other_datasets, cat_feature2cats):
+    def add_real_datasets(
+        self, datasets, other_datasets, cat_feature2cats, ranking=False
+    ):
         """
         Adds datasets specified by the user to the datasets dictionary.
         """
         if other_datasets:
             for name, path in other_datasets:
                 (df_other, choices, _, _, _) = self.get_df(
-                    path, cat_feature2cats=cat_feature2cats, apply_filters=False
+                    path,
+                    cat_feature2cats=cat_feature2cats,
+                    apply_filters=False,
+                    add_near_best=ranking,
                 )
                 datasets[name] = df_other
 
+    def codegen(
+        self,
+        tree,
+        metadata,
+        heuristic_name,
+        threshold,
+        dummy_col_2_col_val,
+        unsafe_leaves,
+    ):
+        lines = []
+        device_capa = metadata["device_capa"]
+        device_capa_str = f"({device_capa[0]}, {device_capa[1]})"
+        opt_name = metadata["name"]
+        lines.append(
+            self.codegen_boilerplate(
+                heuristic_name,
+                opt_name,
+                threshold,
+                metadata["shared_memory"],
+                device_capa_str,
+                tree.classes_,
+            )
+        )
+        fn_def = f"\n    {self.gen_predict_fn_def()}"
+        lines.append(fn_def)
+        tree.codegen(dummy_col_2_col_val, lines, unsafe_leaves)
+        self.write_heuristic_to_file(lines, heuristic_name)
+
 
 @dataclass
 class AccuracyMetrics:
@@ -552,6 +647,8 @@ class WrongSpeedupMetrics:
 class RankingMetrics:
     # Number of predictions where best choice is in top k choices
     num_correct: int
+    # Number of predictions where best choice is not in top k choices
+    num_wrong: int
     # Maximum speedup of best choice over best choice in top k (this tells us how much better the best choice, which
     # is not in top k, is over the best choice in top k)
     max_speedup: float
@@ -563,6 +660,7 @@ class RankingMetrics:
     def to_map(self):
         return {
             "top_k_correct": self.num_correct,
+            "top_k_wrong": self.num_wrong,
             "wrong_max_speedup_k": self.max_speedup,
             "wrong_gmean_speedup_k": self.gmean_speedup,
             "top_k_unsure": self.unsure,
@@ -618,9 +716,10 @@ class DecisionEvaluator:
         probas,
         wrong_pct=0.01,
         threshold=0.0,
-        k=3,
+        k=10,
         unsafe_leaves=None,
         leaf_ids=None,
+        ranking=False,
     ) -> None:
         self.train = train
         self.model = model
@@ -632,6 +731,7 @@ class DecisionEvaluator:
         self.k = k
         self.unsafe_leaves = unsafe_leaves
         self.leaf_ids = leaf_ids
+        self.ranking = ranking
 
         self.num_correct = 0
         self.num_wrong = 0
@@ -639,6 +739,7 @@ class DecisionEvaluator:
         self.wrong_probas = []
         self.speedups_wrong = []
         self.num_correct_top_k = 0
+        self.num_wrong_top_k = 0
         self.wrong_speedups_top_k = []
         self.top_k_unsure = 0
         self.num_non_default_predictions = 0
@@ -718,6 +819,7 @@ class DecisionEvaluator:
             if min_time is not None:
                 speedup = min_time / best_time
                 self.wrong_speedups_top_k.append(speedup)
+                self.num_wrong_top_k += 1
             else:
                 self.top_k_unsure += 1
                 # TODO (AlnisM): print more info (input and choices)
@@ -743,7 +845,7 @@ class DecisionEvaluator:
         """
         Custom evaluation function that evaluates a learned decision tree.
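+
+        For ranking, a hedged sketch of the top-k criterion (hypothetical
+        names): a prediction counts as top-k correct when the true best
+        choice is among the k returned choices, e.g.
+
+            top_k_indices = [idx for _, idx in best_choices[: self.k]]
+            is_correct = true_choice_idx in top_k_indices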
""" - y_true = self.df["winner"] + y_true = self.df["actual_winner"] if self.ranking else self.df["winner"] i = 0 for pred, true, prob, leaf_id in zip( self.predictions, y_true, self.probas, self.leaf_ids @@ -790,6 +892,7 @@ class DecisionEvaluator: wrongSpeedupMetrics = WrongSpeedupMetrics(max_speedup, gmean_speedup) rankingMetrics = RankingMetrics( self.num_correct_top_k, + self.num_wrong_top_k, max_speedup_top_k, gmean_speedup_top_k, self.top_k_unsure, diff --git a/torchgen/_autoheuristic/train_regression.py b/torchgen/_autoheuristic/train_regression.py index 024095e3c139..1fc487320425 100644 --- a/torchgen/_autoheuristic/train_regression.py +++ b/torchgen/_autoheuristic/train_regression.py @@ -34,7 +34,15 @@ class AHTrainRegressionTree(AHTrain): def __init__(self): super().__init__() - def main(self, log_path, other_datasets, nrows, heuristic_name, save_dot=False): + def main( + self, + log_path, + other_datasets, + nrows, + heuristic_name, + save_dot=False, + ranking=False, + ): """ Main function that trains a decision tree and generates a heuristic. """ @@ -357,6 +365,65 @@ class AHTrainRegressionTree(AHTrain): "wrong_max_ratio": wrong_max_ratio, } + def dt_to_python( + self, + dt, + metadata, + feature_names, + dummy_col_2_col_val, + heuristic_name, + threshold, + unsafe_leaves=None, + ): + tree_ = dt.tree_ + feature_name = [ + feature_names[i] if i != -1 else "undefined!" for i in tree_.feature + ] + + lines = [] + device_capa = metadata["device_capa"] + device_capa_str = f"({device_capa[0]}, {device_capa[1]})" + opt_name = metadata["name"] + lines.append( + self.codegen_boilerplate( + heuristic_name, + opt_name, + threshold, + metadata["shared_memory"], + device_capa_str, + dt, + ) + ) + fn_def = f"\n {self.gen_predict_fn_def()}" + lines.append(fn_def) + + def dt_to_python(node, depth): + indent = " " * (depth + 1) + false_predicate = "" + if tree_.feature[node] != -2: + name = feature_name[node] + threshold = tree_.threshold[node] + if name in dummy_col_2_col_val: + (orig_name, value) = dummy_col_2_col_val[name] + predicate = f"{indent}if str(context.get_value('{orig_name}')) != '{value}':" + assert ( + threshold == 0.5 + ), f"expected threshold to be 0.5 but is {threshold}" + else: + predicate = ( + f"{indent}if context.get_value('{name}') <= {threshold}:" + ) + lines.append(predicate) + dt_to_python(tree_.children_left[node], depth + 1) + lines.append(f"{indent}else:") + dt_to_python(tree_.children_right[node], depth + 1) + else: + lines.append(self.handle_leaf(tree_, node, indent, unsafe_leaves)) + + dt_to_python(0, 1) + + self.write_heuristic_to_file(lines, heuristic_name) + def handle_leaf(self, tree_, node, indent, unsafe_leaves): """ Generates the code for a leaf node. This is just the value predicted by the regression tree. @@ -368,7 +435,7 @@ class AHTrainRegressionTree(AHTrain): return "def predict(self, context: AHContext) -> float:" def codegen_boilerplate( - self, heuristic_name, opt_name, threshold, shared_memory, device_capa, dt + self, heuristic_name, opt_name, threshold, shared_memory, device_capa, classes ): """ Generates the boilerplate code for the generated heuristic. This includes things like imports, class definition,