mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
See #145101 for details. Pull Request resolved: https://github.com/pytorch/pytorch/pull/145137 Approved by: https://github.com/bobrenjc93
316 lines
12 KiB
Python
316 lines
12 KiB
Python
import json
|
|
import os
|
|
from functools import partial
|
|
from typing import Any, Callable, Optional
|
|
|
|
import torch
|
|
from torch._inductor.autoheuristic.autoheuristic_utils import (
|
|
AHContext,
|
|
AHMetadata,
|
|
AHOperation,
|
|
Choice,
|
|
CHOICE_COL,
|
|
Feedback,
|
|
FEEDBACK_COL,
|
|
get_metadata_str_from_log,
|
|
)
|
|
from torch._inductor.autoheuristic.learned_heuristic_controller import (
|
|
LearnedHeuristicController,
|
|
)
|
|
from torch._inductor.ir import ChoiceCaller
|
|
from torch._inductor.runtime.runtime_utils import cache_dir
|
|
from torch._inductor.utils import get_gpu_shared_memory
|
|
|
|
|
|
class LocalFeedback:
    """
    Callable wrapper around a user-supplied feedback function.

    To collect data for a choice, AutoHeuristic needs a function that, given a
    choice, returns feedback for it. LocalFeedback is the variant used when
    that feedback can be computed immediately and locally for every choice
    (see pad_mm.py, where the autotuning happens locally, for an example).
    """

    def __init__(self, feedback_fn: Callable[[Choice], Feedback]) -> None:
        # Stored under the same attribute name the rest of the code expects.
        self.feedback_fn = feedback_fn

    def __call__(self, choice: Choice) -> Feedback:
        # Delegate straight to the wrapped function.
        return self.feedback_fn(choice)
|
class InconsistentMetadata(Exception):
    """
    Raised when AutoHeuristic attempts to append data to an existing log file
    whose stored metadata differs from the metadata AutoHeuristic would write
    if the file were created fresh.
    """
|
|
class AutoHeuristic:
    """
    AutoHeuristic is a framework that allows one to collect data, learn a heuristic (i.e. a regression tree) and
    generate the heuristic to code. This class allows one to collect data. The collected data can then be used to train
    a heuristic (see torchgen/autoheuristic/).
    """

    # Maps each choice for which feedback has been collected to its feedback value.
    collected_feedback: dict[Choice, Feedback]

    def __init__(
        self,
        fallback: Callable[[], Choice],
        choices: list[Choice],
        feedback: Optional[LocalFeedback],
        context: AHContext,
        name: str,
        augment_context: Optional[list[AHOperation]] = None,
        precondition: Optional[Callable[[AHMetadata, AHContext], bool]] = None,
    ) -> None:
        """
        Initializes an instance of the AutoHeuristic class.

        Args:
            fallback: A callable that returns a Choice when the heuristic is unsure which choice to make, or
                AutoHeuristic is in data collection mode.
            choices: A list of possible choices the heuristic can make.
            feedback: An instance of LocalFeedback that provides feedback for a given choice.
            context: Context to store with each choice and feedback.
            name: A string that identifies the heuristic.
            augment_context: An optional list of AHOperation instances that augment the context.
            precondition: A callable that returns a boolean indicating whether AutoHeuristic should run.
        """
        self.fallback = fallback
        self.choices = choices
        self.feedback = feedback
        self.context = context
        self.name = name
        self.collected_feedback = {}
        self.augment_context = augment_context
        # Metadata identifies the hardware/choice-set this data belongs to; it is
        # built unconditionally because satisfies_precondition() reads it below.
        self.metadata = AHMetadata(
            get_gpu_shared_memory(),
            torch.cuda.get_device_capability(),
            self.choices,
            self.name,
        )
        self.precondition = precondition

        # If the precondition is not met, skip log-path setup and data collection
        # entirely (self.log_path is then never set).
        if not self.satisfies_precondition():
            return

        # "DEFAULT" means: derive a per-GPU-model path under the inductor cache dir.
        if torch._inductor.config.autoheuristic_log_path == "DEFAULT":
            self.log_path = self.get_default_log_path()
        else:
            self.log_path = torch._inductor.config.autoheuristic_log_path

        # Data-collection mode: when a local feedback function is available,
        # eagerly evaluate every choice and log the results right away.
        if torch._inductor.config.collect_autoheuristic(self.name):
            if self.feedback is not None:
                for choice in self.choices:
                    feedback_val = self.feedback(choice)
                    self.save_data(choice, feedback_val)

    def satisfies_precondition(self) -> bool:
        """Return True if no precondition was given, or it holds for this metadata/context."""
        return self.precondition is None or self.precondition(
            self.metadata, self.context
        )

    def get_choice(self) -> Choice:
        """
        Returns the chosen option based on the value of autoheuristic_use.
        If self.name is one of the comma separated strings in autoheuristic_use,
        it queries a learned heuristic to make a decision. Otherwise, it returns the fallback option.
        """

        if not self.satisfies_precondition():
            return self.fallback()

        if torch._inductor.config.use_autoheuristic(self.name):
            # Optionally derive extra features before querying the heuristic.
            if self.augment_context is not None:
                self.context.apply_operations(self.augment_context)
            controller = LearnedHeuristicController(
                self.metadata,
                self.context,
            )
            decision = controller.get_decision()
            # A decision outside the known choice set (including None) falls back.
            if decision not in self.choices:
                # TODO(AlnisM): We might want to allow this in the future
                return self.fallback()
            if decision is not None:
                return decision
        return self.fallback()

    def get_top_k_choices(
        self, top_k: int, always_included: Optional[list[str]] = None
    ) -> Optional[list[Choice]]:
        """
        Return the heuristic's top_k ranked choices, or None when the heuristic
        is disabled, the precondition fails, or no ranking is available.
        Choices in always_included are appended if the ranking omitted them.
        """
        if not self.satisfies_precondition():
            return None
        if torch._inductor.config.use_autoheuristic(self.name):
            if self.augment_context is not None:
                self.context.apply_operations(self.augment_context)
            controller = LearnedHeuristicController(
                self.metadata,
                self.context,
            )
            choices = controller.get_decisions_ranked(top_k)
            if choices is None:
                return None
            if always_included is not None:
                for choice in always_included:
                    if choice not in choices:
                        choices.append(choice)
            return choices
        return None

    def get_collected_feedback(self, choice: Choice) -> Any:
        """Return the feedback recorded for choice, or None if none was collected."""
        return self.collected_feedback.get(choice, None)

    @staticmethod
    def get_device_identifier() -> str:
        # a heuristic might work well for one GPU, but not for another
        # we store the collected data per GPU model and learn a heuristic per GPU model

        # TODO(AlnisM): just using the device name for now, but the same GPU model can have different names
        device_name = torch.cuda.get_device_name().replace(" ", "_")
        return device_name

    def get_default_log_path(self) -> str:
        """Return (and create the directory for) the per-GPU default log file path."""
        device_name = self.get_device_identifier()
        path = f"{cache_dir()}/autoheuristic/{device_name}/"
        os.makedirs(path, exist_ok=True)
        path += f"{self.name}.txt"
        return path

    def serialize_metadata(self) -> str:
        """Serialize metadata plus the context's feature names to a JSON string."""
        metadata_dict = self.metadata.to_dict()
        (
            num_features,
            cat_features,
        ) = self.context.get_numerical_and_categorical_features()
        metadata_dict["numerical_features"] = num_features
        metadata_dict["categorical_features"] = cat_features
        return json.dumps(metadata_dict)

    def save_data(self, choice: Choice, feedback_val: Feedback) -> None:
        """
        Record feedback for a choice in memory and append it as a CSV row to the log.

        On first write the JSON metadata line and the CSV header are emitted;
        on subsequent writes the existing metadata must match, otherwise
        InconsistentMetadata is raised.
        """
        self.collected_feedback[choice] = feedback_val
        log_path = self.log_path

        lines = []
        log_exists = os.path.exists(log_path)
        if log_exists:
            # if log already exists, make sure it is consistent
            metadata = self.serialize_metadata()
            existing_metadata = get_metadata_str_from_log(self.log_path)
            if existing_metadata != metadata:
                raise InconsistentMetadata(
                    "Given metadata does not match existing metadata"
                )
        else:
            # Fresh log: write metadata first, then the CSV header row.
            lines.append(self.serialize_metadata())
            feature_header = self.context.get_feature_names_csv()
            header = feature_header + "," + CHOICE_COL + "," + FEEDBACK_COL
            lines.append(header)

        line = ""
        feature_values = self.context.get_feature_values_csv()
        line += feature_values + "," + choice + "," + str(feedback_val)
        lines.append(line)

        # Append mode: earlier rows (and the metadata/header) are preserved.
        with open(log_path, "a") as f:
            f.write("\n".join(lines) + "\n")
|
class AutoHeuristicSelectAlgorithm(AutoHeuristic):
    """
    AutoHeuristicSelectAlgorithm is a subclass of AutoHeuristic that allows one to collect data and learn a heuristic
    when one wants to use AutoHeuristic for kernel choice selection.
    """

    def __init__(
        self,
        fallback: Callable[[], Optional[ChoiceCaller]],
        choices: list[ChoiceCaller],
        input_nodes: list[Any],
        context: AHContext,
        name: str,
        augment_context: Optional[list[AHOperation]] = None,
        precondition: Optional[Callable[[AHMetadata, AHContext], bool]] = None,
    ) -> None:
        """
        The arguments choices, input_nodes and name have to match the ones used in the call to
        autotune_select_algorithm(), e.g. if the following call is made
        autotune_select_algorithm(name, choices, input_nodes, layout), the same name, choices and input_nodes
        have to be used here.
        """
        self.input_nodes = input_nodes
        # Map each choice's string id back to its ChoiceCaller so the string-based
        # base class can be translated to/from ChoiceCaller objects.
        self.choicestr2choice: dict[str, ChoiceCaller] = {}
        for choice in choices:
            self.choicestr2choice[choice.autoheuristic_id()] = choice
        choices_str = list(self.choicestr2choice.keys())

        def fallback_str() -> str:
            # Adapter: the base class works with string choices, so translate
            # the ChoiceCaller-returning fallback into a string-returning one.
            fallback_choice = fallback()
            if fallback_choice is None:
                # TODO: Find a nicer way to handle this
                return "unsure"
            return fallback_choice.autoheuristic_id()

        # feedback is None here: timings are collected asynchronously via
        # register_global_feedback() instead of a LocalFeedback function.
        super().__init__(
            fallback_str,
            choices_str,
            None,
            context,
            name,
            augment_context,
            precondition,
        )

        if (
            torch._inductor.config.collect_autoheuristic(self.name)
            and self.satisfies_precondition()
        ):
            self.register_global_feedback(input_nodes, choices)

    def register_global_feedback(
        self, input_nodes: list[Any], choices: list[ChoiceCaller]
    ) -> None:
        """
        Registers a callback in select_algorithm, which is called with the timing of each choice.
        """

        from torch._inductor.select_algorithm import (
            add_feedback_saver,
            create_inputs_key,
            create_precompile_key,
        )

        def store_global_feedback(
            ah_inputs_key: str,
            ah_precompile_key: str,
            timings: dict[ChoiceCaller, float],
            name: str,
            input_nodes: list[Any],
            choices: list[ChoiceCaller],
        ) -> None:
            # The saver is invoked for every autotuning run; only log timings
            # when both keys match the run this instance was created for.
            current_inputs_key = create_inputs_key(input_nodes)
            if current_inputs_key != ah_inputs_key:
                return
            current_precompile_key = create_precompile_key(
                name, current_inputs_key, choices
            )
            if current_precompile_key != ah_precompile_key:
                return
            for choice, time in timings.items():
                self.save_data(choice.autoheuristic_id(), time)

        inputs_key = create_inputs_key(input_nodes)
        precompile_key = create_precompile_key(self.name, inputs_key, choices)
        # Bind this instance's keys now; the remaining args are supplied by the
        # select_algorithm machinery when timings become available.
        feedback_saver = partial(store_global_feedback, inputs_key, precompile_key)
        add_feedback_saver(feedback_saver)

    def get_choice_caller(self) -> Optional[ChoiceCaller]:
        """Return the ChoiceCaller for the heuristic's choice, or None if unknown."""
        choice = self.get_choice()
        return self.choicestr2choice.get(choice, None)

    def get_top_k_choices_caller(
        self, top_k: int, always_included: Optional[list[str]] = None
    ) -> Optional[list[ChoiceCaller]]:
        """Return the top_k ranked choices as ChoiceCaller objects, or None."""
        choices = self.get_top_k_choices(top_k, always_included)
        if choices is None:
            return None
        return [self.choicestr2choice[choice] for choice in choices]