mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
Autoheuristic: Do not store choices as metadata (#130304)
For optimizations like pad_mm, there are always exactly two possible choices, but for other decision procedures, like kernel choice selection, the set of "available" choices depends on the input. Instead of storing the choices as metadata, we can look at all choices for which we have collected data (i.e. `df[CHOICE_COL].unique()`). In this PR, I also try to replace the "choice" and "feedback" strings with the global constants CHOICE_COL and FEEDBACK_COL. Pull Request resolved: https://github.com/pytorch/pytorch/pull/130304 Approved by: https://github.com/eellison
This commit is contained in:
committed by
PyTorch MergeBot
parent
4d9f2a6d56
commit
7f1cda1533
@ -108,7 +108,6 @@ class AutoHeuristicTest(TestCase):
|
||||
lines = file.readlines()
|
||||
self.assertTrue('"numerical_features": ["fa"]' in lines[0])
|
||||
self.assertTrue('"categorical_features": []' in lines[0])
|
||||
self.assertTrue('"choices": ["a", "b", "c"]' in lines[0])
|
||||
self.assertTrue(f'"shared_memory": {shared_memory}' in lines[0])
|
||||
self.assertTrue(f'"device_capa": [{fst}, {snd}]' in lines[0])
|
||||
self.assertTrue('"name": "test"' in lines[0])
|
||||
|
@ -4,6 +4,7 @@ from torch._inductor.autoheuristic.autoheuristic_utils import (
|
||||
AHContext,
|
||||
AHMetadata,
|
||||
Choice,
|
||||
CHOICE_COL,
|
||||
)
|
||||
from torch._inductor.autoheuristic.learnedheuristic_interface import LearnedHeuristic
|
||||
|
||||
@ -24,7 +25,7 @@ class PadMMA100(LearnedHeuristic):
|
||||
)
|
||||
|
||||
def get_feedback(self, context: AHContext, choice: Choice) -> float:
|
||||
context.context_dict["choice"] = choice
|
||||
context.context_dict[CHOICE_COL] = choice
|
||||
return self.predict(context)
|
||||
|
||||
def get_speedup_threshold(self) -> float:
|
||||
|
@ -106,7 +106,6 @@ class AHMetadata:
|
||||
return {
|
||||
"shared_memory": self.shared_memory,
|
||||
"device_capa": self.device_capa,
|
||||
"choices": self.choices,
|
||||
"name": self.name,
|
||||
}
|
||||
|
||||
|
@ -11,6 +11,7 @@ from sklearn.model_selection import train_test_split # type: ignore[import-unty
|
||||
from sklearn.tree import DecisionTreeRegressor # type: ignore[import-untyped]
|
||||
|
||||
from torch._inductor.autoheuristic.autoheuristic import deserialize_data
|
||||
from torch._inductor.autoheuristic.autoheuristic_utils import CHOICE_COL, FEEDBACK_COL
|
||||
|
||||
|
||||
# TODO (AlnisM): Fix these warnings
|
||||
@ -83,7 +84,7 @@ class AHTrain:
|
||||
# We will do a grid search over the values
|
||||
max_depths = [5, 10, 13, 15, 17, 20, 23, None]
|
||||
min_samples_leafs = [1, 2, 5, 10]
|
||||
choice_columns = ["choice_" + choice for choice in choices]
|
||||
choice_columns = [f"{CHOICE_COL}_{choice}" for choice in choices]
|
||||
(results_df, best_model, threshold) = self.train_and_evaluate_models(
|
||||
datasets, feature_columns, choice_columns, max_depths, min_samples_leafs
|
||||
)
|
||||
@ -114,7 +115,7 @@ class AHTrain:
|
||||
(df, metadata) = deserialize_data(log_path)
|
||||
numerical_features = metadata["numerical_features"]
|
||||
categorical_features = metadata["categorical_features"]
|
||||
choices = metadata["choices"]
|
||||
choices = df[CHOICE_COL].unique().tolist()
|
||||
features = numerical_features + categorical_features
|
||||
if nrows is not None:
|
||||
df = df.head(nrows)
|
||||
@ -133,10 +134,10 @@ class AHTrain:
|
||||
# Calculate statistics for each input and choice combination
|
||||
def calculate_stats(group):
|
||||
count = len(group)
|
||||
mean = group["feedback"].mean()
|
||||
std = group["feedback"].std()
|
||||
mean = group[FEEDBACK_COL].mean()
|
||||
std = group[FEEDBACK_COL].std()
|
||||
relative_std = (std / mean) * 100 if mean != 0 else np.inf
|
||||
median = group["feedback"].median()
|
||||
median = group[FEEDBACK_COL].median()
|
||||
return pd.Series(
|
||||
{
|
||||
"count": count,
|
||||
@ -146,7 +147,7 @@ class AHTrain:
|
||||
)
|
||||
|
||||
stats = (
|
||||
df.groupby(feature_columns + ["choice"])
|
||||
df.groupby(feature_columns + [CHOICE_COL])
|
||||
.apply(calculate_stats)
|
||||
.reset_index()
|
||||
)
|
||||
@ -167,7 +168,7 @@ class AHTrain:
|
||||
# Compute the winner and ratios for each input
|
||||
def get_winner_and_speedups(group):
|
||||
mean_time = group["median_execution_time"].mean()
|
||||
winner = group.loc[group["median_execution_time"].idxmin(), "choice"]
|
||||
winner = group.loc[group["median_execution_time"].idxmin(), CHOICE_COL]
|
||||
min_time = group["median_execution_time"].min()
|
||||
max_time = group["median_execution_time"].max()
|
||||
|
||||
@ -176,7 +177,7 @@ class AHTrain:
|
||||
group["target"] = mean_time / group["median_execution_time"]
|
||||
|
||||
return group[
|
||||
feature_columns + ["choice", "winner", "speedup", "target"]
|
||||
feature_columns + [CHOICE_COL, "winner", "speedup", "target"]
|
||||
]
|
||||
|
||||
results = (
|
||||
@ -190,7 +191,7 @@ class AHTrain:
|
||||
results = process_data(df, feature_columns, apply_filters)
|
||||
(results, added_categorical_features) = self.add_new_features(results)
|
||||
categorical_features += added_categorical_features
|
||||
categorical_features += ["choice"]
|
||||
categorical_features += [CHOICE_COL]
|
||||
|
||||
# Doing this here because if we create another df for testing purposes
|
||||
# and that other df does not contain all categories for a categorical feature,
|
||||
@ -225,7 +226,7 @@ class AHTrain:
|
||||
feature_columns = [
|
||||
col
|
||||
for col in df.columns
|
||||
if col not in exclude_columns and not col.startswith("choice_")
|
||||
if col not in exclude_columns and not col.startswith(CHOICE_COL + "_")
|
||||
]
|
||||
df["input_id"] = df.groupby(feature_columns).ngroup()
|
||||
|
||||
@ -404,7 +405,7 @@ class AHTrain:
|
||||
):
|
||||
boiler_plate = f"""# flake8: noqa: B950
|
||||
|
||||
from torch._inductor.autoheuristic.autoheuristic_utils import AHContext, AHMetadata, Choice
|
||||
from torch._inductor.autoheuristic.autoheuristic_utils import AHContext, AHMetadata, Choice, CHOICE_COL
|
||||
from torch._inductor.autoheuristic.learnedheuristic_interface import (
|
||||
LearnedHeuristic,
|
||||
)
|
||||
@ -422,7 +423,7 @@ class {heuristic_name}(LearnedHeuristic):
|
||||
)
|
||||
|
||||
def get_feedback(self, context: AHContext, choice: Choice) -> float:
|
||||
context.context_dict["choice"] = choice
|
||||
context.context_dict[CHOICE_COL] = choice
|
||||
return self.predict(context)
|
||||
|
||||
def get_speedup_threshold(self) -> float:
|
||||
|
Reference in New Issue
Block a user