Autoheuristic: Do not store choices as metadata (#130304)

For optimizations like pad_mm, there are always only two possible choices, but for other decision procedures, such as kernel choice selection, the set of "available" choices depends on the input. Rather than storing the choices as metadata, we can look at all choices for which we have collected data (i.e. `df[CHOICE_COL].unique()`).

In this PR, I also try to replace the hard-coded string literals "choice" and "feedback" with the global constants `CHOICE_COL` and `FEEDBACK_COL`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/130304
Approved by: https://github.com/eellison
This commit is contained in:
Alnis Murtovi
2024-07-18 21:39:40 +00:00
committed by PyTorch MergeBot
parent 4d9f2a6d56
commit 7f1cda1533
4 changed files with 15 additions and 15 deletions

View File

@ -108,7 +108,6 @@ class AutoHeuristicTest(TestCase):
lines = file.readlines() lines = file.readlines()
self.assertTrue('"numerical_features": ["fa"]' in lines[0]) self.assertTrue('"numerical_features": ["fa"]' in lines[0])
self.assertTrue('"categorical_features": []' in lines[0]) self.assertTrue('"categorical_features": []' in lines[0])
self.assertTrue('"choices": ["a", "b", "c"]' in lines[0])
self.assertTrue(f'"shared_memory": {shared_memory}' in lines[0]) self.assertTrue(f'"shared_memory": {shared_memory}' in lines[0])
self.assertTrue(f'"device_capa": [{fst}, {snd}]' in lines[0]) self.assertTrue(f'"device_capa": [{fst}, {snd}]' in lines[0])
self.assertTrue('"name": "test"' in lines[0]) self.assertTrue('"name": "test"' in lines[0])

View File

@ -4,6 +4,7 @@ from torch._inductor.autoheuristic.autoheuristic_utils import (
AHContext, AHContext,
AHMetadata, AHMetadata,
Choice, Choice,
CHOICE_COL,
) )
from torch._inductor.autoheuristic.learnedheuristic_interface import LearnedHeuristic from torch._inductor.autoheuristic.learnedheuristic_interface import LearnedHeuristic
@ -24,7 +25,7 @@ class PadMMA100(LearnedHeuristic):
) )
def get_feedback(self, context: AHContext, choice: Choice) -> float: def get_feedback(self, context: AHContext, choice: Choice) -> float:
context.context_dict["choice"] = choice context.context_dict[CHOICE_COL] = choice
return self.predict(context) return self.predict(context)
def get_speedup_threshold(self) -> float: def get_speedup_threshold(self) -> float:

View File

@ -106,7 +106,6 @@ class AHMetadata:
return { return {
"shared_memory": self.shared_memory, "shared_memory": self.shared_memory,
"device_capa": self.device_capa, "device_capa": self.device_capa,
"choices": self.choices,
"name": self.name, "name": self.name,
} }

View File

@ -11,6 +11,7 @@ from sklearn.model_selection import train_test_split # type: ignore[import-unty
from sklearn.tree import DecisionTreeRegressor # type: ignore[import-untyped] from sklearn.tree import DecisionTreeRegressor # type: ignore[import-untyped]
from torch._inductor.autoheuristic.autoheuristic import deserialize_data from torch._inductor.autoheuristic.autoheuristic import deserialize_data
from torch._inductor.autoheuristic.autoheuristic_utils import CHOICE_COL, FEEDBACK_COL
# TODO (AlnisM): Fix these warnings # TODO (AlnisM): Fix these warnings
@ -83,7 +84,7 @@ class AHTrain:
# We will do a grid search over the values # We will do a grid search over the values
max_depths = [5, 10, 13, 15, 17, 20, 23, None] max_depths = [5, 10, 13, 15, 17, 20, 23, None]
min_samples_leafs = [1, 2, 5, 10] min_samples_leafs = [1, 2, 5, 10]
choice_columns = ["choice_" + choice for choice in choices] choice_columns = [f"{CHOICE_COL}_{choice}" for choice in choices]
(results_df, best_model, threshold) = self.train_and_evaluate_models( (results_df, best_model, threshold) = self.train_and_evaluate_models(
datasets, feature_columns, choice_columns, max_depths, min_samples_leafs datasets, feature_columns, choice_columns, max_depths, min_samples_leafs
) )
@ -114,7 +115,7 @@ class AHTrain:
(df, metadata) = deserialize_data(log_path) (df, metadata) = deserialize_data(log_path)
numerical_features = metadata["numerical_features"] numerical_features = metadata["numerical_features"]
categorical_features = metadata["categorical_features"] categorical_features = metadata["categorical_features"]
choices = metadata["choices"] choices = df[CHOICE_COL].unique().tolist()
features = numerical_features + categorical_features features = numerical_features + categorical_features
if nrows is not None: if nrows is not None:
df = df.head(nrows) df = df.head(nrows)
@ -133,10 +134,10 @@ class AHTrain:
# Calculate statistics for each input and choice combination # Calculate statistics for each input and choice combination
def calculate_stats(group): def calculate_stats(group):
count = len(group) count = len(group)
mean = group["feedback"].mean() mean = group[FEEDBACK_COL].mean()
std = group["feedback"].std() std = group[FEEDBACK_COL].std()
relative_std = (std / mean) * 100 if mean != 0 else np.inf relative_std = (std / mean) * 100 if mean != 0 else np.inf
median = group["feedback"].median() median = group[FEEDBACK_COL].median()
return pd.Series( return pd.Series(
{ {
"count": count, "count": count,
@ -146,7 +147,7 @@ class AHTrain:
) )
stats = ( stats = (
df.groupby(feature_columns + ["choice"]) df.groupby(feature_columns + [CHOICE_COL])
.apply(calculate_stats) .apply(calculate_stats)
.reset_index() .reset_index()
) )
@ -167,7 +168,7 @@ class AHTrain:
# Compute the winner and ratios for each input # Compute the winner and ratios for each input
def get_winner_and_speedups(group): def get_winner_and_speedups(group):
mean_time = group["median_execution_time"].mean() mean_time = group["median_execution_time"].mean()
winner = group.loc[group["median_execution_time"].idxmin(), "choice"] winner = group.loc[group["median_execution_time"].idxmin(), CHOICE_COL]
min_time = group["median_execution_time"].min() min_time = group["median_execution_time"].min()
max_time = group["median_execution_time"].max() max_time = group["median_execution_time"].max()
@ -176,7 +177,7 @@ class AHTrain:
group["target"] = mean_time / group["median_execution_time"] group["target"] = mean_time / group["median_execution_time"]
return group[ return group[
feature_columns + ["choice", "winner", "speedup", "target"] feature_columns + [CHOICE_COL, "winner", "speedup", "target"]
] ]
results = ( results = (
@ -190,7 +191,7 @@ class AHTrain:
results = process_data(df, feature_columns, apply_filters) results = process_data(df, feature_columns, apply_filters)
(results, added_categorical_features) = self.add_new_features(results) (results, added_categorical_features) = self.add_new_features(results)
categorical_features += added_categorical_features categorical_features += added_categorical_features
categorical_features += ["choice"] categorical_features += [CHOICE_COL]
# Doing this here because if we create another df for testing purposes # Doing this here because if we create another df for testing purposes
# and that other df does not contain all categories for a categorical feature, # and that other df does not contain all categories for a categorical feature,
@ -225,7 +226,7 @@ class AHTrain:
feature_columns = [ feature_columns = [
col col
for col in df.columns for col in df.columns
if col not in exclude_columns and not col.startswith("choice_") if col not in exclude_columns and not col.startswith(CHOICE_COL + "_")
] ]
df["input_id"] = df.groupby(feature_columns).ngroup() df["input_id"] = df.groupby(feature_columns).ngroup()
@ -404,7 +405,7 @@ class AHTrain:
): ):
boiler_plate = f"""# flake8: noqa: B950 boiler_plate = f"""# flake8: noqa: B950
from torch._inductor.autoheuristic.autoheuristic_utils import AHContext, AHMetadata, Choice from torch._inductor.autoheuristic.autoheuristic_utils import AHContext, AHMetadata, Choice, CHOICE_COL
from torch._inductor.autoheuristic.learnedheuristic_interface import ( from torch._inductor.autoheuristic.learnedheuristic_interface import (
LearnedHeuristic, LearnedHeuristic,
) )
@ -422,7 +423,7 @@ class {heuristic_name}(LearnedHeuristic):
) )
def get_feedback(self, context: AHContext, choice: Choice) -> float: def get_feedback(self, context: AHContext, choice: Choice) -> float:
context.context_dict["choice"] = choice context.context_dict[CHOICE_COL] = choice
return self.predict(context) return self.predict(context)
def get_speedup_threshold(self) -> float: def get_speedup_threshold(self) -> float: