Files
pytorch/torch/_inductor/autoheuristic/autoheuristic.py
2025-01-19 01:22:47 +00:00

316 lines
12 KiB
Python

import json
import os
from functools import partial
from typing import Any, Callable, Optional
import torch
from torch._inductor.autoheuristic.autoheuristic_utils import (
AHContext,
AHMetadata,
AHOperation,
Choice,
CHOICE_COL,
Feedback,
FEEDBACK_COL,
get_metadata_str_from_log,
)
from torch._inductor.autoheuristic.learned_heuristic_controller import (
LearnedHeuristicController,
)
from torch._inductor.ir import ChoiceCaller
from torch._inductor.runtime.runtime_utils import cache_dir
from torch._inductor.utils import get_gpu_shared_memory
class LocalFeedback:
    """
    Wraps a user-supplied function that scores a single choice.

    Use this when AutoHeuristic should run the function immediately to
    collect feedback for every choice (see pad_mm.py, where the autotuning
    happens locally, for an example).
    """

    def __init__(self, feedback_fn: Callable[[Choice], Feedback]) -> None:
        # Store the callable; it is invoked lazily, once per choice.
        self.feedback_fn = feedback_fn

    def __call__(self, choice: Choice) -> Feedback:
        """Returns the feedback produced by the wrapped function for `choice`."""
        return self.feedback_fn(choice)
class InconsistentMetadata(Exception):
    """
    Raised when AutoHeuristic tries to append to a log file whose stored
    metadata differs from the metadata it would write if the file were
    created fresh.
    """
class AutoHeuristic:
    """
    AutoHeuristic is a framework that allows one to collect data, learn a
    heuristic (i.e. a regression tree) and generate the heuristic to code.
    This class handles the data-collection side; the collected data can then
    be used to train a heuristic (see torchgen/autoheuristic/).
    """

    # Maps each choice to the feedback recorded for it during collection.
    collected_feedback: dict[Choice, Feedback]

    def __init__(
        self,
        fallback: Callable[[], Choice],
        choices: list[Choice],
        feedback: Optional[LocalFeedback],
        context: AHContext,
        name: str,
        augment_context: Optional[list[AHOperation]] = None,
        precondition: Optional[Callable[[AHMetadata, AHContext], bool]] = None,
    ) -> None:
        """
        Initializes an instance of the AutoHeuristic class.

        Args:
            fallback: Returns a Choice when the heuristic is unsure which
                choice to make, or when AutoHeuristic is in data collection
                mode.
            choices: All choices the heuristic may pick from.
            feedback: LocalFeedback used to score each choice, or None.
            context: Context stored with each choice and feedback.
            name: String identifying this heuristic.
            augment_context: Optional AHOperations that augment the context.
            precondition: Optional predicate deciding whether AutoHeuristic
                should run at all.
        """
        self.fallback = fallback
        self.choices = choices
        self.feedback = feedback
        self.context = context
        self.name = name
        self.collected_feedback = {}
        self.augment_context = augment_context
        self.metadata = AHMetadata(
            get_gpu_shared_memory(),
            torch.cuda.get_device_capability(),
            self.choices,
            self.name,
        )
        self.precondition = precondition

        # If the precondition fails, this heuristic is inactive here and
        # neither a log path nor collected data is needed.
        if not self.satisfies_precondition():
            return

        if torch._inductor.config.autoheuristic_log_path == "DEFAULT":
            self.log_path = self.get_default_log_path()
        else:
            self.log_path = torch._inductor.config.autoheuristic_log_path

        # In collection mode, immediately score and log every choice.
        if torch._inductor.config.collect_autoheuristic(self.name):
            if self.feedback is not None:
                for option in self.choices:
                    self.save_data(option, self.feedback(option))

    def satisfies_precondition(self) -> bool:
        """Returns True when no precondition was given or it holds."""
        if self.precondition is None:
            return True
        return self.precondition(self.metadata, self.context)

    def get_choice(self) -> Choice:
        """
        Returns the chosen option based on the value of autoheuristic_use.

        If self.name is one of the comma separated strings in
        autoheuristic_use, a learned heuristic is queried; otherwise the
        fallback option is returned.
        """
        if not self.satisfies_precondition():
            return self.fallback()
        if not torch._inductor.config.use_autoheuristic(self.name):
            return self.fallback()

        if self.augment_context is not None:
            self.context.apply_operations(self.augment_context)
        controller = LearnedHeuristicController(self.metadata, self.context)
        decision = controller.get_decision()
        # A decision outside the known choices is rejected.
        # TODO(AlnisM): We might want to allow this in the future
        if decision is not None and decision in self.choices:
            return decision
        return self.fallback()

    def get_top_k_choices(
        self, top_k: int, always_included: Optional[list[str]] = None
    ) -> Optional[list[Choice]]:
        """
        Returns up to top_k ranked choices from the learned heuristic, with
        any `always_included` entries appended if missing, or None when no
        heuristic is used or no ranking is available.
        """
        if not self.satisfies_precondition():
            return None
        if not torch._inductor.config.use_autoheuristic(self.name):
            return None

        if self.augment_context is not None:
            self.context.apply_operations(self.augment_context)
        controller = LearnedHeuristicController(self.metadata, self.context)
        ranked = controller.get_decisions_ranked(top_k)
        if ranked is None:
            return None
        for extra in always_included or []:
            if extra not in ranked:
                ranked.append(extra)
        return ranked

    def get_collected_feedback(self, choice: Choice) -> Any:
        """Returns the feedback recorded for `choice`, or None."""
        return self.collected_feedback.get(choice)

    @staticmethod
    def get_device_identifier() -> str:
        # A heuristic that works well for one GPU may not for another, so
        # data is stored and heuristics are learned per GPU model.
        # TODO(AlnisM): just using the device name for now, but the same GPU
        # model can have different names
        return torch.cuda.get_device_name().replace(" ", "_")

    def get_default_log_path(self) -> str:
        """Returns the per-device default log file path, creating its directory."""
        directory = f"{cache_dir()}/autoheuristic/{self.get_device_identifier()}/"
        os.makedirs(directory, exist_ok=True)
        return directory + f"{self.name}.txt"

    def serialize_metadata(self) -> str:
        """Returns the metadata (plus feature names) serialized as JSON."""
        meta = self.metadata.to_dict()
        numerical, categorical = self.context.get_numerical_and_categorical_features()
        meta["numerical_features"] = numerical
        meta["categorical_features"] = categorical
        return json.dumps(meta)

    def save_data(self, choice: Choice, feedback_val: Feedback) -> None:
        """
        Records feedback for `choice` in memory and appends it to the log
        file, writing metadata and a CSV header first if the file is new.

        Raises:
            InconsistentMetadata: If the existing log's metadata differs from
                the metadata this run would write.
        """
        self.collected_feedback[choice] = feedback_val
        path = self.log_path

        entries: list[str] = []
        if os.path.exists(path):
            # An existing log must have been written with identical metadata.
            existing_metadata = get_metadata_str_from_log(self.log_path)
            if existing_metadata != self.serialize_metadata():
                raise InconsistentMetadata(
                    "Given metadata does not match existing metadata"
                )
        else:
            entries.append(self.serialize_metadata())
            header = (
                self.context.get_feature_names_csv()
                + ","
                + CHOICE_COL
                + ","
                + FEEDBACK_COL
            )
            entries.append(header)

        row = (
            self.context.get_feature_values_csv()
            + ","
            + choice
            + ","
            + str(feedback_val)
        )
        entries.append(row)
        with open(path, "a") as fh:
            fh.write("\n".join(entries) + "\n")
class AutoHeuristicSelectAlgorithm(AutoHeuristic):
    """
    AutoHeuristicSelectAlgorithm is a subclass of AutoHeuristic that allows
    one to collect data and learn a heuristic when AutoHeuristic is used for
    kernel choice selection.
    """

    def __init__(
        self,
        fallback: Callable[[], Optional[ChoiceCaller]],
        choices: list[ChoiceCaller],
        input_nodes: list[Any],
        context: AHContext,
        name: str,
        augment_context: Optional[list[AHOperation]] = None,
        precondition: Optional[Callable[[AHMetadata, AHContext], bool]] = None,
    ) -> None:
        """
        The arguments choices, input_nodes and name have to match the ones
        used in the call to autotune_select_algorithm(), e.g. if the call
        autotune_select_algorithm(name, choices, input_nodes, layout) is
        made, the same name, choices and input_nodes have to be used here.
        """
        self.input_nodes = input_nodes
        # Index ChoiceCallers by their string id so string-level decisions
        # can be mapped back to callers.
        self.choicestr2choice: dict[str, ChoiceCaller] = {
            c.autoheuristic_id(): c for c in choices
        }

        def fallback_str() -> str:
            chosen = fallback()
            if chosen is None:
                # TODO: Find a nicer way to handle this
                return "unsure"
            return chosen.autoheuristic_id()

        super().__init__(
            fallback_str,
            list(self.choicestr2choice.keys()),
            None,
            context,
            name,
            augment_context,
            precondition,
        )

        collecting = torch._inductor.config.collect_autoheuristic(self.name)
        if collecting and self.satisfies_precondition():
            self.register_global_feedback(input_nodes, choices)

    def register_global_feedback(
        self, input_nodes: list[Any], choices: list[ChoiceCaller]
    ) -> None:
        """
        Registers a callback in select_algorithm, which is called with the
        timing of each choice.
        """
        from torch._inductor.select_algorithm import (
            add_feedback_saver,
            create_inputs_key,
            create_precompile_key,
        )

        def store_global_feedback(
            ah_inputs_key: str,
            ah_precompile_key: str,
            timings: dict[ChoiceCaller, float],
            name: str,
            input_nodes: list[Any],
            choices: list[ChoiceCaller],
        ) -> None:
            # Only record timings that belong to this exact autotuning call.
            key_now = create_inputs_key(input_nodes)
            if key_now != ah_inputs_key:
                return
            if create_precompile_key(name, key_now, choices) != ah_precompile_key:
                return
            for timed_choice, elapsed in timings.items():
                self.save_data(timed_choice.autoheuristic_id(), elapsed)

        inputs_key = create_inputs_key(input_nodes)
        precompile_key = create_precompile_key(self.name, inputs_key, choices)
        add_feedback_saver(partial(store_global_feedback, inputs_key, precompile_key))

    def get_choice_caller(self) -> Optional[ChoiceCaller]:
        """Returns the ChoiceCaller for the heuristic's decision, or None."""
        return self.choicestr2choice.get(self.get_choice())

    def get_top_k_choices_caller(
        self, top_k: int, always_included: Optional[list[str]] = None
    ) -> Optional[list[ChoiceCaller]]:
        """Returns the ChoiceCallers for the top-k decisions, or None."""
        ranked = self.get_top_k_choices(top_k, always_included)
        if ranked is None:
            return None
        return [self.choicestr2choice[name] for name in ranked]