mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Select Algorithm clear feedback savers (#161654)
Add `clear_feedback_savers` and tests for the feedback functionality. Pull Request resolved: https://github.com/pytorch/pytorch/pull/161654 Approved by: https://github.com/masnesral
This commit is contained in:
committed by
PyTorch MergeBot
parent
95516ad7e6
commit
d2d4a3c539
@ -32,7 +32,9 @@ from torch._inductor.ir import Buffer, ChoiceCaller, FixedLayout, InputBuffer
|
||||
from torch._inductor.kernel.mm_plus_mm import aten_mm_plus_mm
|
||||
from torch._inductor.kernel_inputs import MMKernelInputs
|
||||
from torch._inductor.select_algorithm import (
|
||||
add_feedback_saver,
|
||||
AlgorithmSelectorCache,
|
||||
clear_feedback_savers,
|
||||
TritonTemplate,
|
||||
TritonTemplateCaller,
|
||||
)
|
||||
@ -2250,6 +2252,118 @@ class TestTuningProcessPool(TestCase):
|
||||
|
||||
tuning_pool.shutdown()
|
||||
|
||||
def test_add_feedback_saver(self):
|
||||
"""Test that add_feedback_saver correctly adds feedback functions."""
|
||||
from torch._inductor.select_algorithm import get_algorithm_selector_cache
|
||||
|
||||
# Clear any existing feedback savers
|
||||
clear_feedback_savers()
|
||||
|
||||
# Create a simple feedback saver function
|
||||
feedback_calls = []
|
||||
|
||||
def simple_feedback_saver(timings, name, input_nodes, choices, profiled_time):
|
||||
feedback_calls.append(
|
||||
{
|
||||
"name": name,
|
||||
"num_choices": len(choices),
|
||||
"num_timings": len(timings),
|
||||
"has_profiled_time": profiled_time is not None,
|
||||
}
|
||||
)
|
||||
|
||||
# Add the feedback saver
|
||||
add_feedback_saver(simple_feedback_saver)
|
||||
|
||||
# Get the global cache and verify the function was added
|
||||
cache = get_algorithm_selector_cache()
|
||||
self.assertEqual(len(cache.feedback_saver_fns), 1)
|
||||
self.assertEqual(cache.feedback_saver_fns[0], simple_feedback_saver)
|
||||
|
||||
# Test that we can add multiple feedback savers
|
||||
def another_feedback_saver(timings, name, input_nodes, choices, profiled_time):
|
||||
pass
|
||||
|
||||
add_feedback_saver(another_feedback_saver)
|
||||
self.assertEqual(len(cache.feedback_saver_fns), 2)
|
||||
|
||||
# Clean up
|
||||
clear_feedback_savers()
|
||||
|
||||
def test_clear_feedback_savers(self):
|
||||
"""Test that clear_feedback_savers removes all feedback functions."""
|
||||
from torch._inductor.select_algorithm import get_algorithm_selector_cache
|
||||
|
||||
# Add some feedback savers first
|
||||
def feedback_saver1(timings, name, input_nodes, choices, profiled_time):
|
||||
pass
|
||||
|
||||
def feedback_saver2(timings, name, input_nodes, choices, profiled_time):
|
||||
pass
|
||||
|
||||
add_feedback_saver(feedback_saver1)
|
||||
add_feedback_saver(feedback_saver2)
|
||||
|
||||
# Verify they were added
|
||||
cache = get_algorithm_selector_cache()
|
||||
self.assertEqual(len(cache.feedback_saver_fns), 2)
|
||||
|
||||
# Clear all feedback savers
|
||||
clear_feedback_savers()
|
||||
|
||||
# Verify they were cleared
|
||||
self.assertEqual(len(cache.feedback_saver_fns), 0)
|
||||
|
||||
def test_feedback_saver_integration(self):
|
||||
"""Test that feedback savers are actually called during autotuning."""
|
||||
# Clear any existing feedback savers
|
||||
clear_feedback_savers()
|
||||
|
||||
feedback_calls = []
|
||||
|
||||
def test_feedback_saver(timings, name, input_nodes, choices, profiled_time):
|
||||
# Store information about the call for verification
|
||||
feedback_calls.append(
|
||||
{
|
||||
"name": name,
|
||||
"num_choices": len(choices),
|
||||
"num_timings": len(timings),
|
||||
"input_node_count": len(input_nodes),
|
||||
}
|
||||
)
|
||||
|
||||
# Add our test feedback saver
|
||||
add_feedback_saver(test_feedback_saver)
|
||||
|
||||
# Create a simple matrix multiplication that will trigger autotuning
|
||||
def mm(a, b):
|
||||
return a @ b
|
||||
|
||||
a = torch.randn(32, 32, device=GPU_TYPE)
|
||||
b = torch.randn(32, 32, device=GPU_TYPE)
|
||||
|
||||
with config.patch(
|
||||
{"max_autotune": True, "max_autotune_gemm_backends": "TRITON"}
|
||||
):
|
||||
torch.compile(mm)(a, b)
|
||||
|
||||
# Verify that our feedback saver was called
|
||||
self.assertGreater(
|
||||
len(feedback_calls), 0, "Feedback saver should have been called"
|
||||
)
|
||||
|
||||
# Verify the structure of the feedback call
|
||||
call = feedback_calls[0]
|
||||
self.assertIn("name", call)
|
||||
self.assertIn("num_choices", call)
|
||||
self.assertIn("num_timings", call)
|
||||
self.assertIn("input_node_count", call)
|
||||
self.assertGreater(call["num_choices"], 0)
|
||||
self.assertEqual(call["input_node_count"], 2) # Two input matrices
|
||||
|
||||
# Clean up
|
||||
clear_feedback_savers()
|
||||
|
||||
|
||||
@instantiate_parametrized_tests
|
||||
class TestPrologueFusion(TestCase):
|
||||
|
@ -3358,6 +3358,9 @@ class AlgorithmSelectorCache(PersistentCache):
|
||||
def add_feedback_saver(self, fn: FeedbackFunction):
|
||||
self.feedback_saver_fns.append(fn)
|
||||
|
||||
def clear_feedback_savers(self):
|
||||
self.feedback_saver_fns = []
|
||||
|
||||
def add_preprocessing_fn(self, fn: PreprocessingFunction):
|
||||
self.preprocessing_fns.append(fn)
|
||||
|
||||
@ -3405,6 +3408,12 @@ def add_feedback_saver(
|
||||
cache.add_feedback_saver(fn)
|
||||
|
||||
|
||||
def clear_feedback_savers():
|
||||
"""Clear all feedback saver functions."""
|
||||
cache = get_algorithm_selector_cache()
|
||||
cache.clear_feedback_savers()
|
||||
|
||||
|
||||
def add_preprocessing_fn(
|
||||
fn: PreprocessingFunction,
|
||||
):
|
||||
|
Reference in New Issue
Block a user