Select Algorithm clear feedback savers (#161654)

Add `clear_feedback_savers` and tests for the feedback functionality.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/161654
Approved by: https://github.com/masnesral
This commit is contained in:
Gabriel Ferns
2025-08-28 06:56:03 +00:00
committed by PyTorch MergeBot
parent 95516ad7e6
commit d2d4a3c539
2 changed files with 123 additions and 0 deletions

View File

@@ -32,7 +32,9 @@ from torch._inductor.ir import Buffer, ChoiceCaller, FixedLayout, InputBuffer
from torch._inductor.kernel.mm_plus_mm import aten_mm_plus_mm
from torch._inductor.kernel_inputs import MMKernelInputs
from torch._inductor.select_algorithm import (
add_feedback_saver,
AlgorithmSelectorCache,
clear_feedback_savers,
TritonTemplate,
TritonTemplateCaller,
)
@@ -2250,6 +2252,118 @@ class TestTuningProcessPool(TestCase):
tuning_pool.shutdown()
def test_add_feedback_saver(self):
    """Registering feedback savers via add_feedback_saver must append them,
    in order, to the global AlgorithmSelectorCache's feedback_saver_fns list."""
    from torch._inductor.select_algorithm import get_algorithm_selector_cache

    # Start from a known-empty state so prior tests can't leak savers in.
    clear_feedback_savers()

    recorded_calls = []

    def simple_feedback_saver(timings, name, input_nodes, choices, profiled_time):
        recorded_calls.append(
            {
                "name": name,
                "num_choices": len(choices),
                "num_timings": len(timings),
                "has_profiled_time": profiled_time is not None,
            }
        )

    add_feedback_saver(simple_feedback_saver)

    # The registration must land on the process-global cache object.
    cache = get_algorithm_selector_cache()
    self.assertEqual(len(cache.feedback_saver_fns), 1)
    self.assertEqual(cache.feedback_saver_fns[0], simple_feedback_saver)

    # A second registration appends rather than replaces.
    def another_feedback_saver(timings, name, input_nodes, choices, profiled_time):
        pass

    add_feedback_saver(another_feedback_saver)
    self.assertEqual(len(cache.feedback_saver_fns), 2)

    # Leave global state clean for subsequent tests.
    clear_feedback_savers()
def test_clear_feedback_savers(self):
    """clear_feedback_savers must drop every previously registered saver."""
    from torch._inductor.select_algorithm import get_algorithm_selector_cache

    # Register two no-op savers so there is something to clear.
    def feedback_saver1(timings, name, input_nodes, choices, profiled_time):
        pass

    def feedback_saver2(timings, name, input_nodes, choices, profiled_time):
        pass

    for saver in (feedback_saver1, feedback_saver2):
        add_feedback_saver(saver)

    cache = get_algorithm_selector_cache()
    self.assertEqual(len(cache.feedback_saver_fns), 2)

    # Clearing empties the global registry entirely.
    clear_feedback_savers()
    self.assertEqual(len(cache.feedback_saver_fns), 0)
def test_feedback_saver_integration(self):
    """End-to-end check: a registered feedback saver is invoked while
    max-autotune selects a GEMM kernel for a compiled matmul."""
    # Known-empty registry so only our saver is observed.
    clear_feedback_savers()

    observed = []

    def test_feedback_saver(timings, name, input_nodes, choices, profiled_time):
        # Record call shape only; the values themselves are backend-dependent.
        observed.append(
            {
                "name": name,
                "num_choices": len(choices),
                "num_timings": len(timings),
                "input_node_count": len(input_nodes),
            }
        )

    add_feedback_saver(test_feedback_saver)

    def mm(a, b):
        return a @ b

    lhs = torch.randn(32, 32, device=GPU_TYPE)
    rhs = torch.randn(32, 32, device=GPU_TYPE)

    # Force autotuning so the feedback path actually runs.
    with config.patch(
        {"max_autotune": True, "max_autotune_gemm_backends": "TRITON"}
    ):
        torch.compile(mm)(lhs, rhs)

    self.assertGreater(
        len(observed), 0, "Feedback saver should have been called"
    )

    # Sanity-check the payload handed to the saver.
    first = observed[0]
    for key in ("name", "num_choices", "num_timings", "input_node_count"):
        self.assertIn(key, first)
    self.assertGreater(first["num_choices"], 0)
    self.assertEqual(first["input_node_count"], 2)  # Two input matrices

    # Leave global state clean for subsequent tests.
    clear_feedback_savers()
@instantiate_parametrized_tests
class TestPrologueFusion(TestCase):

View File

@@ -3358,6 +3358,9 @@ class AlgorithmSelectorCache(PersistentCache):
def add_feedback_saver(self, fn: FeedbackFunction):
    """Register *fn* on this cache's list of feedback savers.

    Savers are called with (timings, name, input_nodes, choices,
    profiled_time) — see the FeedbackFunction signature — presumably once
    per autotuning round; registration order is preserved.
    """
    self.feedback_saver_fns.append(fn)
def clear_feedback_savers(self):
    """Remove all registered feedback saver functions.

    NOTE(review): rebinds the attribute rather than clearing in place, so
    any external alias of the old list keeps the stale entries — confirm
    no caller holds a reference to feedback_saver_fns across a clear.
    """
    self.feedback_saver_fns = []
def add_preprocessing_fn(self, fn: PreprocessingFunction):
    """Register *fn* on this cache's list of preprocessing functions;
    registration order is preserved."""
    self.preprocessing_fns.append(fn)
@@ -3405,6 +3408,12 @@ def add_feedback_saver(
cache.add_feedback_saver(fn)
def clear_feedback_savers():
    """Clear all feedback saver functions."""
    # Module-level convenience wrapper: forwards to the process-global
    # AlgorithmSelectorCache instance.
    get_algorithm_selector_cache().clear_feedback_savers()
def add_preprocessing_fn(
fn: PreprocessingFunction,
):