Autoheuristic: add config options for specifying optimizations to collect data for and use heuristics (#130245)

Previously, it was only possible to either collect data or use a heuristic globally, for every place autoheuristic is used. This PR makes it possible to collect data for some optimizations while using a learned heuristic for other optimizations.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/130245
Approved by: https://github.com/shunting314
This commit is contained in:
Alnis Murtovi
2024-07-18 01:04:34 +00:00
committed by PyTorch MergeBot
parent 051971ab32
commit d818c3319f
5 changed files with 50 additions and 33 deletions

View File

@ -38,19 +38,17 @@ class AutoHeuristicTest(TestCase):
return path
def test_autoheuristic_pad_mm_default(self):
# this test ensure that data is not collected for pad_mm when autoheuristic_mode is set to its default value ("OFF")
# this test ensures that data is not collected for pad_mm when autoheuristic config is set to its default value
self.run_mm()
self.assertFalse(os.path.exists(self.get_path_to_autoheuristic_log("pad_mm")))
@inductor_config.patch(autoheuristic_mode="OFF")
@inductor_config.patch(autoheuristic_collect="foo")
def test_autoheuristic_pad_mm_off(self):
# this test ensure that data is not collected for pad_mm when autoheuristic_mode="OFF"
# this test ensures that data is not collected for pad_mm when autoheuristic_collect does not contain "pad_mm"
self.run_mm()
self.assertFalse(os.path.exists(self.get_path_to_autoheuristic_log("pad_mm")))
@inductor_config.patch(autoheuristic_mode="COLLECT_DATA")
def test_autoheuristic_pad_mm_collect_data(self):
# this test ensure that data is collected for pad_mm when autoheuristic_mode="COLLECT_DATA"
def assert_autoheuristic_collected_data(self):
self.run_mm()
device_name = AutoHeuristic.get_device_identifier()
path = self.get_path_to_autoheuristic_log("pad_mm")
@ -60,7 +58,17 @@ class AutoHeuristicTest(TestCase):
# 1 line for metadata, 1 line for header, 1 line per choice (orig, padded)
self.assertEqual(num_lines, 4)
@inductor_config.patch(autoheuristic_mode="COLLECT_DATA")
@inductor_config.patch(autoheuristic_collect="pad_mm")
def test_autoheuristic_pad_mm_collect_data(self):
# this test ensures that data is collected for pad_mm when autoheuristic_collect="pad_mm"
self.assert_autoheuristic_collected_data()
@inductor_config.patch(autoheuristic_collect="foo,pad_mm")
def test_autoheuristic_pad_mm_collect_data2(self):
# this test ensures that data is collected for "pad_mm" when autoheuristic_collect contains "pad_mm"
self.assert_autoheuristic_collected_data()
@inductor_config.patch(autoheuristic_collect="test")
def test_autoheuristic(self):
# test basic functionality of autoheuristic
def fallback():
@ -84,7 +92,7 @@ class AutoHeuristicTest(TestCase):
name = "test"
autoheuristic = AutoHeuristic(fallback, choices, feedback, context, name)
# when autoheuristic_mode is COLLECT_DATA, we always return fallback
# when autoheuristic is configured to only collect data, we always return fallback
self.assertEqual(autoheuristic.get_choice(), "fallback")
self.assertEqual(autoheuristic.get_collected_feedback("a"), 1)
self.assertEqual(autoheuristic.get_collected_feedback("b"), 2)
@ -112,14 +120,14 @@ class AutoHeuristicTest(TestCase):
self.assertEqual("5,c,3", lines[4].rstrip())
@unittest.skipIf(not IS_A100, "heuristic only run on A100")
@inductor_config.patch(autoheuristic_mode="USE_HEURISTIC")
@inductor_config.patch(autoheuristic_use="pad_mm")
def test_autoheuristic_a100(self):
# Make sure heuristic does not break anything
# TODO (AlnisM): Find a way to check whether heuristic is used
self.run_mm()
@unittest.skipIf(not IS_H100, "heuristic only run on H100")
@inductor_config.patch(autoheuristic_mode="USE_HEURISTIC")
@inductor_config.patch(autoheuristic_use="pad_mm")
def test_autoheuristic_h100(self):
# Make sure heuristic does not break anything
# TODO (AlnisM): Find a way to check whether heuristic is used

View File

@ -139,9 +139,7 @@ class AutoHeuristic:
else:
self.log_path = torch._inductor.config.autoheuristic_log_path
# TODO(AlnisM): Allow something like AUTOHEURISTIC_MODE="collect:pad_mm,foo,bar"
# to be able to collect data only for specific heuristics
if torch._inductor.config.autoheuristic_mode == "COLLECT_DATA" and isinstance(
if torch._inductor.config.collect_autoheuristic(self.name) and isinstance(
self.feedback, LocalFeedback
):
for choice in self.choices:
@ -155,19 +153,15 @@ class AutoHeuristic:
def get_choice(self) -> Choice:
"""
Returns the chosen option based on the autoheuristic mode.
If the mode is "USE_HEURISTIC", it queries a learned heuristic to make a decision.
If the mode is not "USE_HEURISTIC", it falls back to the self.fallback() method.
Returns:
Choice: The chosen option.
Returns the chosen option based on the value of autoheuristic_use.
If self.name is one of the comma separated strings in autoheuristic_use,
it queries a learned heuristic to make a decision. Otherwise, it returns the fallback option.
"""
if not self.satisfies_precondition():
return self.fallback()
if torch._inductor.config.autoheuristic_mode == "USE_HEURISTIC":
if torch._inductor.config.use_autoheuristic(self.name):
if self.augment_context is not None:
self.context.apply_operations(self.augment_context)
controller = LearnedHeuristicController(

View File

@ -324,11 +324,24 @@ coordinate_descent_search_radius = int(
)
# AutoHeuristic is a framework that allows one to collect data from autotuning, use the data to learn a heuristic, and
# generate the learned heuristic to code which is shipped with the compiler. For now, this is only enabled for pad_mm.
# If set to "OFF", this will not run AutoHeuristic.
# If set to "COLLECT_DATA", this will store data about the inputs and autotuning results.
# If set to "USE_HEURISTIC", this will use the learned heuristic to make a choice in pad_mm.
autoheuristic_mode = os.environ.get("TORCHINDUCTOR_AUTOHEURISTIC_MODE", "OFF")
# generate the learned heuristic to code which is shipped with the compiler
# Specify a list of comma separated optimizations to collect data for
autoheuristic_collect = os.environ.get("TORCHINDUCTOR_AUTOHEURISTIC_COLLECT", "")
# Specify a list of comma separated optimizations to use learned heuristics for
autoheuristic_use = os.environ.get("TORCHINDUCTOR_AUTOHEURISTIC_USE", "")
def run_autoheuristic(name):
return collect_autoheuristic(name) or use_autoheuristic(name)
def collect_autoheuristic(name):
return name in torch._inductor.config.autoheuristic_collect.split(",")
def use_autoheuristic(name):
return name in torch._inductor.config.autoheuristic_use.split(",")
# If set to "DEFAULT", this will use the default log path specified in autoheuristic.py.
# If set to another path, autoheuristic will instead log results to the given path.

View File

@ -514,7 +514,7 @@ def should_pad_bench(
fn()
if (
torch._inductor.config.autoheuristic_mode != "OFF"
torch._inductor.config.run_autoheuristic("pad_mm")
and op is torch.ops.aten.mm
):
ah_should_pad = run_autoheuristic(
@ -638,7 +638,7 @@ def run_autoheuristic(
choice2should_pad = {orig_choice: False, pad_choice: True, "autotune": None}
ah_should_pad = choice2should_pad.get(choice, None)
if torch._inductor.config.autoheuristic_mode == "COLLECT_DATA":
if torch._inductor.config.collect_autoheuristic(name):
ah_ori_time = autoheuristic.get_collected_feedback(orig_choice)
ah_pad_time = autoheuristic.get_collected_feedback(pad_choice)

View File

@ -187,10 +187,9 @@ if __name__ == "__main__":
help="torch.cuda.set_device(device) will be used",
)
parser.add_argument(
"--autoheuristic-mode",
type=str,
default="COLLECT_DATA",
help="COLLECT_DATA to collect Data. USE_HEURISTIC to test heuristic.",
"--use-heuristic",
action="store_true",
help="Use learned heuristic instead of collecting data.",
)
parser.add_argument(
"-o",
@ -206,7 +205,10 @@ if __name__ == "__main__":
)
args = parser.parse_args()
torch._inductor.config.autoheuristic_mode = args.autoheuristic_mode
if args.use_heuristic:
torch._inductor.config.autoheuristic_use = "pad_mm"
else:
torch._inductor.config.autoheuristic_collect = "pad_mm"
torch._inductor.config.autoheuristic_log_path = args.o
if args.device is not None:
torch.cuda.set_device(args.device)