pytorch/test/inductor/test_autoheuristic.py
Alnis Murtovi 7f1cda1533 Autoheuristic: Do not store choices as metadata (#130304)
While optimizations like pad_mm always have exactly two possible choices, for other decision procedures, such as kernel choice selection, the set of available choices depends on the input. Instead of storing the choices as metadata, we can derive them from the collected data itself (i.e. `df[CHOICE_COL].unique()`).

In this PR, I also replace the string literals "choice" and "feedback" with the global constants CHOICE_COL and FEEDBACK_COL.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/130304
Approved by: https://github.com/eellison
2024-07-18 21:39:42 +00:00
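
A minimal sketch of the idea (assuming the collected data has been loaded into a pandas DataFrame `df`; `get_choices` is a hypothetical helper used for illustration, not part of this PR):

    import pandas as pd

    CHOICE_COL = "choice"      # global constant introduced in this PR
    FEEDBACK_COL = "feedback"  # likewise

    def get_choices(df: pd.DataFrame) -> list:
        # Derive the available choices from the data itself instead of
        # reading them from stored metadata.
        return sorted(df[CHOICE_COL].unique())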


# Owner(s): ["module: inductor"]
import os
import unittest

import torch
import torch._inductor.config as inductor_config
from torch._inductor.autoheuristic.autoheuristic import (
    AHContext,
    AutoHeuristic,
    LocalFeedback,
)
from torch._inductor.runtime.runtime_utils import cache_dir
from torch._inductor.test_case import run_tests, TestCase
from torch._inductor.utils import get_gpu_shared_memory
from torch.testing._internal.inductor_utils import HAS_CUDA, IS_A100, IS_H100


class AutoHeuristicTest(TestCase):
    def count_lines_in_file(self, file_path):
        with open(file_path) as file:
            line_count = sum(1 for line in file)
        return line_count

    def run_mm(self):
        def f(a, b):
            return torch.mm(a, b)

        cf = torch.compile(f)
        a = torch.randn(2047, 2048, device="cuda", dtype=torch.float16)
        b = torch.randn(2048, 2048, device="cuda", dtype=torch.float16)
        cf(a, b)

    def get_path_to_autoheuristic_log(self, name):
        device_name = AutoHeuristic.get_device_identifier()
        path = cache_dir() + "/autoheuristic/" + device_name + "/" + name + ".txt"
        return path
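
    # Example of the resulting log path (illustrative; the device identifier
    # comes from AutoHeuristic.get_device_identifier() and varies by GPU):
    #   <cache_dir()>/autoheuristic/<device_name>/pad_mm.txt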

    def test_autoheuristic_pad_mm_default(self):
        # this test ensures that data is not collected for pad_mm when the
        # autoheuristic config is set to its default value
        self.run_mm()
        self.assertFalse(os.path.exists(self.get_path_to_autoheuristic_log("pad_mm")))
@inductor_config.patch(autoheuristic_collect="foo")
def test_autoheuristic_pad_mm_off(self):
# this test ensures that data is not collected for pad_mm when autoheuristic_collect does not contain "pad_mm"
self.run_mm()
self.assertFalse(os.path.exists(self.get_path_to_autoheuristic_log("pad_mm")))

    def assert_autoheuristic_collected_data(self):
        self.run_mm()
        path = self.get_path_to_autoheuristic_log("pad_mm")
        self.assertTrue(os.path.exists(path))
        num_lines = self.count_lines_in_file(path)
        # 1 line for metadata, 1 line for header, 1 line per choice (orig, padded)
        self.assertEqual(num_lines, 4)
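
    # Illustrative layout of the collected log (values hypothetical; the exact
    # format for the "test" heuristic is asserted in test_autoheuristic below):
    #   {"numerical_features": [...], ...}   <- metadata
    #   <feature columns>,choice,feedback    <- header
    #   <features>,orig,<feedback>           <- one row per choice
    #   <features>,padded,<feedback>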
@inductor_config.patch(autoheuristic_collect="pad_mm")
def test_autoheuristic_pad_mm_collect_data(self):
# this test ensures that data is collected for pad_mm when autoheuristic_collect="pad_mm"
self.assert_autoheuristic_collected_data()
@inductor_config.patch(autoheuristic_collect="foo,pad_mm")
def test_autoheuristic_pad_mm_collect_data2(self):
# this test ensures that data is collected for "pad_mm" when autoheuristic_collect contains "pad_mm"
self.assert_autoheuristic_collected_data()
@inductor_config.patch(autoheuristic_collect="test")
def test_autoheuristic(self):
# test basic functionality of autoheuristic
def fallback():
return "fallback"
choices = ["a", "b", "c"]
def feedback_fn(choice):
if choice == "a":
return 1
elif choice == "b":
return 2
elif choice == "c":
return 3
else:
raise RuntimeError("unexpected choice")
feedback = LocalFeedback(feedback_fn)
context = AHContext()
context.add_feature("fa", 5)
name = "test"
autoheuristic = AutoHeuristic(fallback, choices, feedback, context, name)
# when autoheuristic is configured to only collect data, we always return fallback
self.assertEqual(autoheuristic.get_choice(), "fallback")
self.assertEqual(autoheuristic.get_collected_feedback("a"), 1)
self.assertEqual(autoheuristic.get_collected_feedback("b"), 2)
self.assertEqual(autoheuristic.get_collected_feedback("c"), 3)
path = self.get_path_to_autoheuristic_log(name)
self.assertTrue(os.path.exists(path))
num_lines = self.count_lines_in_file(path)
self.assertEqual(num_lines, 5)
shared_memory = get_gpu_shared_memory()
(fst, snd) = torch.cuda.get_device_capability()
with open(path) as file:
lines = file.readlines()
self.assertTrue('"numerical_features": ["fa"]' in lines[0])
self.assertTrue('"categorical_features": []' in lines[0])
self.assertTrue(f'"shared_memory": {shared_memory}' in lines[0])
self.assertTrue(f'"device_capa": [{fst}, {snd}]' in lines[0])
self.assertTrue('"name": "test"' in lines[0])
self.assertEqual("fa,choice,feedback", lines[1].rstrip())
self.assertEqual("5,a,1", lines[2].rstrip())
self.assertEqual("5,b,2", lines[3].rstrip())
self.assertEqual("5,c,3", lines[4].rstrip())

    @unittest.skipIf(not IS_A100, "heuristic only run on A100")
    @inductor_config.patch(autoheuristic_use="pad_mm")
    def test_autoheuristic_a100(self):
        # Make sure the heuristic does not break anything
        # TODO (AlnisM): Find a way to check whether the heuristic is used
        self.run_mm()

    @unittest.skipIf(not IS_H100, "heuristic only run on H100")
    @inductor_config.patch(autoheuristic_use="pad_mm")
    def test_autoheuristic_h100(self):
        # Make sure the heuristic does not break anything
        # TODO (AlnisM): Find a way to check whether the heuristic is used
        self.run_mm()


if __name__ == "__main__":
    if HAS_CUDA:
        run_tests()