mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Autoheuristic: add config options for specifying optimizations to collect data for and use heuristics (#130245)
Previously, it was only possible to collect data or use a heuristic regardless of where autoheuristic is used. This PR makes it possible to collect data for some optimizations while using a learned heuristic for other optimizations. Pull Request resolved: https://github.com/pytorch/pytorch/pull/130245 Approved by: https://github.com/shunting314
This commit is contained in:
committed by
PyTorch MergeBot
parent
051971ab32
commit
d818c3319f
@ -38,19 +38,17 @@ class AutoHeuristicTest(TestCase):
|
||||
return path
|
||||
|
||||
def test_autoheuristic_pad_mm_default(self):
|
||||
# this test ensure that data is not collected for pad_mm when autoheuristic_mode is set to its default value ("OFF")
|
||||
# this test ensures that data is not collected for pad_mm when autoheuristic config is set to its default value
|
||||
self.run_mm()
|
||||
self.assertFalse(os.path.exists(self.get_path_to_autoheuristic_log("pad_mm")))
|
||||
|
||||
@inductor_config.patch(autoheuristic_mode="OFF")
|
||||
@inductor_config.patch(autoheuristic_collect="foo")
|
||||
def test_autoheuristic_pad_mm_off(self):
|
||||
# this test ensure that data is not collected for pad_mm when autoheuristic_mode="OFF"
|
||||
# this test ensures that data is not collected for pad_mm when autoheuristic_collect does not contain "pad_mm"
|
||||
self.run_mm()
|
||||
self.assertFalse(os.path.exists(self.get_path_to_autoheuristic_log("pad_mm")))
|
||||
|
||||
@inductor_config.patch(autoheuristic_mode="COLLECT_DATA")
|
||||
def test_autoheuristic_pad_mm_collect_data(self):
|
||||
# this test ensure that data is collected for pad_mm when autoheuristic_mode="COLLECT_DATA"
|
||||
def assert_autoheuristic_collected_data(self):
|
||||
self.run_mm()
|
||||
device_name = AutoHeuristic.get_device_identifier()
|
||||
path = self.get_path_to_autoheuristic_log("pad_mm")
|
||||
@ -60,7 +58,17 @@ class AutoHeuristicTest(TestCase):
|
||||
# 1 line for metadata, 1 line for header, 1 line per choice (orig, padded)
|
||||
self.assertEqual(num_lines, 4)
|
||||
|
||||
@inductor_config.patch(autoheuristic_mode="COLLECT_DATA")
|
||||
@inductor_config.patch(autoheuristic_collect="pad_mm")
|
||||
def test_autoheuristic_pad_mm_collect_data(self):
|
||||
# this test ensures that data is collected for pad_mm when autoheuristic_collect="pad_mm"
|
||||
self.assert_autoheuristic_collected_data()
|
||||
|
||||
@inductor_config.patch(autoheuristic_collect="foo,pad_mm")
|
||||
def test_autoheuristic_pad_mm_collect_data2(self):
|
||||
# this test ensures that data is collected for "pad_mm" when autoheuristic_collect contains "pad_mm"
|
||||
self.assert_autoheuristic_collected_data()
|
||||
|
||||
@inductor_config.patch(autoheuristic_collect="test")
|
||||
def test_autoheuristic(self):
|
||||
# test basic functionality of autoheuristic
|
||||
def fallback():
|
||||
@ -84,7 +92,7 @@ class AutoHeuristicTest(TestCase):
|
||||
name = "test"
|
||||
autoheuristic = AutoHeuristic(fallback, choices, feedback, context, name)
|
||||
|
||||
# when autoheuristic_mode is COLLECT_DATA, we always return fallback
|
||||
# when autoheuristic is configured to only collect data, we always return fallback
|
||||
self.assertEqual(autoheuristic.get_choice(), "fallback")
|
||||
self.assertEqual(autoheuristic.get_collected_feedback("a"), 1)
|
||||
self.assertEqual(autoheuristic.get_collected_feedback("b"), 2)
|
||||
@ -112,14 +120,14 @@ class AutoHeuristicTest(TestCase):
|
||||
self.assertEqual("5,c,3", lines[4].rstrip())
|
||||
|
||||
@unittest.skipIf(not IS_A100, "heuristic only run on A100")
|
||||
@inductor_config.patch(autoheuristic_mode="USE_HEURISTIC")
|
||||
@inductor_config.patch(autoheuristic_use="pad_mm")
|
||||
def test_autoheuristic_a100(self):
|
||||
# Make sure heuristic does not break anything
|
||||
# TODO (AlnisM): Find a way to check whether heuristic is used
|
||||
self.run_mm()
|
||||
|
||||
@unittest.skipIf(not IS_H100, "heuristic only run on H100")
|
||||
@inductor_config.patch(autoheuristic_mode="USE_HEURISTIC")
|
||||
@inductor_config.patch(autoheuristic_use="pad_mm")
|
||||
def test_autoheuristic_h100(self):
|
||||
# Make sure heuristic does not break anything
|
||||
# TODO (AlnisM): Find a way to check whether heuristic is used
|
||||
|
@ -139,9 +139,7 @@ class AutoHeuristic:
|
||||
else:
|
||||
self.log_path = torch._inductor.config.autoheuristic_log_path
|
||||
|
||||
# TODO(AlnisM): Allow something like AUTOHEURISTIC_MODE="collect:pad_mm,foo,bar"
|
||||
# to be able to collect data only for specific heuristics
|
||||
if torch._inductor.config.autoheuristic_mode == "COLLECT_DATA" and isinstance(
|
||||
if torch._inductor.config.collect_autoheuristic(self.name) and isinstance(
|
||||
self.feedback, LocalFeedback
|
||||
):
|
||||
for choice in self.choices:
|
||||
@ -155,19 +153,15 @@ class AutoHeuristic:
|
||||
|
||||
def get_choice(self) -> Choice:
|
||||
"""
|
||||
Returns the chosen option based on the autoheuristic mode.
|
||||
|
||||
If the mode is "USE_HEURISTIC", it queries a learned heuristic to make a decision.
|
||||
If the mode is not "USE_HEURISTIC", it falls back to the self.fallback() method.
|
||||
|
||||
Returns:
|
||||
Choice: The chosen option.
|
||||
Returns the chosen option based on the value of autoheuristic_use.
|
||||
If self.name is one of the comma separated strings in autoheuristic_use,
|
||||
it queries a learned heuristic to make a decision. Otherwise, it returns the fallback option.
|
||||
"""
|
||||
|
||||
if not self.satisfies_precondition():
|
||||
return self.fallback()
|
||||
|
||||
if torch._inductor.config.autoheuristic_mode == "USE_HEURISTIC":
|
||||
if torch._inductor.config.use_autoheuristic(self.name):
|
||||
if self.augment_context is not None:
|
||||
self.context.apply_operations(self.augment_context)
|
||||
controller = LearnedHeuristicController(
|
||||
|
@ -324,11 +324,24 @@ coordinate_descent_search_radius = int(
|
||||
)
|
||||
|
||||
# AutoHeuristic is a framework that allows one to collect data from autotuning, use the data to learn a heuristic, and
|
||||
# generate the learned heursitic to code which is shipped with the compiler. For now, this is only enabled for pad_mm.
|
||||
# If set to "OFF", this will not run AutoHeuristic.
|
||||
# If set to "COLLECT_DATA", this will store data about the inputs and autotuning results.
|
||||
# If set to "USE_HEURISTIC", this will use the learned heuristic to make a choice in pad_mm.
|
||||
autoheuristic_mode = os.environ.get("TORCHINDUCTOR_AUTOHEURISTIC_MODE", "OFF")
|
||||
# generate code for the learned heuristic, which is shipped with the compiler
|
||||
# Specify a list of comma separated optimizations to collect data for
|
||||
autoheuristic_collect = os.environ.get("TORCHINDUCTOR_AUTOHEURISTIC_COLLECT", "")
|
||||
# Specify a list of comma separated optimizations to use learned heuristics for
|
||||
autoheuristic_use = os.environ.get("TORCHINDUCTOR_AUTOHEURISTIC_USE", "")
|
||||
|
||||
|
||||
def run_autoheuristic(name):
|
||||
return collect_autoheuristic(name) or use_autoheuristic(name)
|
||||
|
||||
|
||||
def collect_autoheuristic(name):
|
||||
return name in torch._inductor.config.autoheuristic_collect.split(",")
|
||||
|
||||
|
||||
def use_autoheuristic(name):
|
||||
return name in torch._inductor.config.autoheuristic_use.split(",")
|
||||
|
||||
|
||||
# If set to "DEFAULT", this will use the default log path specified in autoheuristic.py.
|
||||
# If set to another path, autoheuristic will instead log results to the given path.
|
||||
|
@ -514,7 +514,7 @@ def should_pad_bench(
|
||||
fn()
|
||||
|
||||
if (
|
||||
torch._inductor.config.autoheuristic_mode != "OFF"
|
||||
torch._inductor.config.run_autoheuristic("pad_mm")
|
||||
and op is torch.ops.aten.mm
|
||||
):
|
||||
ah_should_pad = run_autoheuristic(
|
||||
@ -638,7 +638,7 @@ def run_autoheuristic(
|
||||
choice2should_pad = {orig_choice: False, pad_choice: True, "autotune": None}
|
||||
ah_should_pad = choice2should_pad.get(choice, None)
|
||||
|
||||
if torch._inductor.config.autoheuristic_mode == "COLLECT_DATA":
|
||||
if torch._inductor.config.collect_autoheuristic(name):
|
||||
ah_ori_time = autoheuristic.get_collected_feedback(orig_choice)
|
||||
ah_pad_time = autoheuristic.get_collected_feedback(pad_choice)
|
||||
|
||||
|
@ -187,10 +187,9 @@ if __name__ == "__main__":
|
||||
help="torch.cuda.set_device(device) will be used",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--autoheuristic-mode",
|
||||
type=str,
|
||||
default="COLLECT_DATA",
|
||||
help="COLLECT_DATA to collect Data. USE_HEURISTIC to test heuristic.",
|
||||
"--use-heuristic",
|
||||
action="store_true",
|
||||
help="Use learned heuristic instead of collecting data.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-o",
|
||||
@ -206,7 +205,10 @@ if __name__ == "__main__":
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
torch._inductor.config.autoheuristic_mode = args.autoheuristic_mode
|
||||
if args.use_heuristic:
|
||||
torch._inductor.config.autoheuristic_use = "pad_mm"
|
||||
else:
|
||||
torch._inductor.config.autoheuristic_collect = "pad_mm"
|
||||
torch._inductor.config.autoheuristic_log_path = args.o
|
||||
if args.device is not None:
|
||||
torch.cuda.set_device(args.device)
|
||||
|
Reference in New Issue
Block a user