[inductor][heuristics registry] missing heuristic is not an error anymore, cross device heuristics (#161767)

# why

- a missing heuristic is still an error, but it should not crash compilation;
  instead, log the error and just provide 0 configs
- some heuristics are valid across device types
- it is cleaner to be explicit about a heuristic being cross-device than to
  enumerate every possible device type

# what

- on registration, supply device_type=None (explicitly) to mark the
  heuristic as cross-device, as shown in the sketch below
- add a test to guard the heuristic lookup hierarchy
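
A minimal sketch of both behaviors under the new contract; `my_template`, `my_op`, and the class name are placeholders, while the imports and functions are the ones touched in this PR:

```python
from torch._inductor.template_heuristics.base import TemplateConfigHeuristics
from torch._inductor.template_heuristics.registry import (
    get_template_heuristic,
    register_template_heuristic,
)


# device_type=None is passed explicitly to mark the heuristic as cross-device
@register_template_heuristic("my_template", None, op_name="my_op")
class MyHeuristics(TemplateConfigHeuristics):
    pass


# resolves on any device type without per-device registrations
assert isinstance(get_template_heuristic("my_template", "xpu", "my_op"), MyHeuristics)

# a missing heuristic no longer raises: it logs an error and returns a base
# TemplateConfigHeuristics instance, which contributes 0 configs
fallback = get_template_heuristic("not_registered", "cuda", "mm")
assert type(fallback) is TemplateConfigHeuristics
```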

# testing

```
python3 -bb -m pytest test/inductor/test_template_heuristics_registry.py
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/161767
Approved by: https://github.com/PaulZhang12
Author: Ruben Rodriguez Buchillon
Date: 2025-08-29 12:20:16 -07:00
Committed by: PyTorch MergeBot
Commit: 45eccf414f (parent: 037f3bd475)
3 changed files with 232 additions and 21 deletions

test/inductor/test_template_heuristics_registry.py (new file)

@@ -0,0 +1,171 @@
# Owner(s): ["module: inductor"]
from torch._inductor.template_heuristics.base import TemplateConfigHeuristics
from torch._inductor.template_heuristics.registry import (
    _TEMPLATE_HEURISTIC_REGISTRY,
    clear_registry,
    get_template_heuristic,
    register_template_heuristic,
)
from torch._inductor.test_case import run_tests, TestCase


class TestTemplateHeuristicsRegistry(TestCase):
    def setUp(self):
        super().setUp()
        # Save original registry state
        self.original_registry = _TEMPLATE_HEURISTIC_REGISTRY.copy()
        clear_registry()

    def tearDown(self):
        # Restore original registry
        clear_registry()
        _TEMPLATE_HEURISTIC_REGISTRY.update(self.original_registry)
        super().tearDown()

    # Test heuristic classes using the decorator registration

    def test_register_class(self):
        """Test basic registration of a heuristic class."""
        # Clear registry for this isolated test
        clear_registry()

        @register_template_heuristic("test_mm", "cuda")
        class TestHeuristic(TemplateConfigHeuristics):
            pass

        # Verify registration
        key = ("test_mm", "cuda", None)
        self.assertIn(key, _TEMPLATE_HEURISTIC_REGISTRY)
        self.assertEqual(_TEMPLATE_HEURISTIC_REGISTRY[key], TestHeuristic)

    def test_assertion_existing_class(self):
        """Test that a registered class can be retrieved."""

        @register_template_heuristic("triton::mm", "cuda")
        class _CrossOpHeuristic(TemplateConfigHeuristics):
            """(template, device, None) - Cross-op for specific device"""

        # _CrossOpHeuristic is registered above for ("triton::mm", "cuda", None)
        # Test retrieval - it should match any op on the cuda device
        heuristic = get_template_heuristic("triton::mm", "cuda", "bmm")
        self.assertIsInstance(heuristic, _CrossOpHeuristic)

    def test_hierarchy_lookup(self):
        """Test complete hierarchy: (template, device, op) -> (template, None, None)"""

        @register_template_heuristic("triton::mm", "cuda", op_name="scaled_mm")
        class _MostSpecificHeuristic(TemplateConfigHeuristics):
            """(template, device, op) - Most specific"""

        @register_template_heuristic("triton::mm", None, op_name="scaled_mm")
        class _CrossDeviceHeuristic(TemplateConfigHeuristics):
            """(template, None, op) - Cross-device for specific op"""

        @register_template_heuristic("triton::mm", "cuda")
        class _CrossOpHeuristic(TemplateConfigHeuristics):
            """(template, device, None) - Cross-op for specific device"""

        @register_template_heuristic("triton::mm", None)
        class _MostGeneralHeuristic(TemplateConfigHeuristics):
            """(template, None, None) - Most general"""

        # All classes are registered via the decorators above:
        # _MostSpecificHeuristic: ("triton::mm", "cuda", "scaled_mm") - Most specific
        # _CrossDeviceHeuristic: ("triton::mm", None, "scaled_mm") - Cross-device for specific op
        # _CrossOpHeuristic: ("triton::mm", "cuda", None) - Cross-op for specific device
        # _MostGeneralHeuristic: ("triton::mm", None, None) - Most general

        # Test 1: Exact match - should get most specific
        heuristic = get_template_heuristic("triton::mm", "cuda", "scaled_mm")
        self.assertIsInstance(heuristic, _MostSpecificHeuristic)
        # Test 2: Different device, same op - should get cross-device
        heuristic = get_template_heuristic("triton::mm", "xpu", "scaled_mm")
        self.assertIsInstance(heuristic, _CrossDeviceHeuristic)
        # Test 3: Same device, different op - should get cross-op
        heuristic = get_template_heuristic("triton::mm", "cuda", "bmm")
        self.assertIsInstance(heuristic, _CrossOpHeuristic)
        # Test 4: Different device and op - should get most general
        heuristic = get_template_heuristic("triton::mm", "xpu", "bmm")
        self.assertIsInstance(heuristic, _MostGeneralHeuristic)

    def test_partial_hierarchy_scenarios(self):
        """Test hierarchy behavior with partial registrations"""

        # Scenario 1: Register partial hierarchy using decorators
        @register_template_heuristic("triton::tma", None, op_name="scaled_tma")
        class _TestCrossDeviceHeuristic(TemplateConfigHeuristics):
            pass

        @register_template_heuristic("triton::tma", None)
        class _TestGeneralHeuristic(TemplateConfigHeuristics):
            pass

        # Should get cross-device for the matching op, regardless of device
        heuristic = get_template_heuristic("triton::tma", "cuda", "scaled_tma")
        self.assertIsInstance(heuristic, _TestCrossDeviceHeuristic)
        # Should fall back to general for a different op
        heuristic = get_template_heuristic("triton::tma", "cuda", "scaled_mm")
        self.assertIsInstance(heuristic, _TestGeneralHeuristic)

        # Scenario 2: Only a device-specific heuristic exists
        @register_template_heuristic("triton::bmm", "cuda")
        class _TestDeviceSpecificHeuristic(TemplateConfigHeuristics):
            pass

        # Should get device-specific for cuda
        heuristic = get_template_heuristic("triton::bmm", "cuda", "any_op")
        self.assertIsInstance(heuristic, _TestDeviceSpecificHeuristic)
        # Should return fallback instance for other devices (no heuristic registered)
        heuristic = get_template_heuristic("triton::bmm", "xpu", "any_op")
        self.assertIsInstance(heuristic, TemplateConfigHeuristics)
        # Make sure it's not the registered device-specific heuristic
        self.assertNotIsInstance(heuristic, _TestDeviceSpecificHeuristic)

        # Scenario 3: Only the most general heuristic exists
        @register_template_heuristic("triton::mm", None)
        class _TestMostGeneralHeuristic(TemplateConfigHeuristics):
            pass

        # Should always get general regardless of device/op
        heuristic = get_template_heuristic("triton::mm", "cuda", "scaled_addmm")
        self.assertIsInstance(heuristic, _TestMostGeneralHeuristic)
        heuristic = get_template_heuristic("triton::mm", "xpu", "regular_addmm")
        self.assertIsInstance(heuristic, _TestMostGeneralHeuristic)

    def test_fallback_behavior(self):
        """Test that a fallback TemplateConfigHeuristics is returned when no heuristic is found"""
        # Test 1: Get fallback for unregistered template
        heuristic = get_template_heuristic("unknown_template", "cuda", "unknown_op")
        self.assertIsInstance(heuristic, TemplateConfigHeuristics)
        # Make sure it's the base class and not a subclass
        self.assertEqual(type(heuristic), TemplateConfigHeuristics)

        # Test 2: Verify fallback instances are NOT cached (new instance each time)
        heuristic2 = get_template_heuristic("unknown_template", "cuda", "unknown_op")
        self.assertIsInstance(heuristic2, TemplateConfigHeuristics)
        self.assertEqual(type(heuristic2), TemplateConfigHeuristics)
        # Should be different instances (not cached)
        self.assertIsNot(heuristic, heuristic2)

        # Test 3: After registering a heuristic, should get the registered one instead
        @register_template_heuristic("unknown_template", "cuda", op_name="unknown_op")
        class _NewlyRegisteredHeuristic(TemplateConfigHeuristics):
            pass

        # Now should get the registered heuristic, not the fallback
        heuristic3 = get_template_heuristic("unknown_template", "cuda", "unknown_op")
        self.assertIsInstance(heuristic3, _NewlyRegisteredHeuristic)
        self.assertNotEqual(type(heuristic3), TemplateConfigHeuristics)

        # Test 4: Verify registered instances ARE cached (same instance each time)
        heuristic4 = get_template_heuristic("unknown_template", "cuda", "unknown_op")
        self.assertIsInstance(heuristic4, _NewlyRegisteredHeuristic)
        self.assertIs(heuristic3, heuristic4)  # Should be same cached instance


if __name__ == "__main__":
    run_tests()

torch/_inductor/template_heuristics/ (decompose_k heuristics)

@@ -20,9 +20,15 @@ if TYPE_CHECKING:
     from ..ir import Layout


 # on CUDA, we don't support hip for decompose_k yet
 @register_template_heuristic(
     "decompose_k", "cuda", register=torch.version.hip is None, op_name="mm"
 )
+# TODO(coconutruben): enable decompose k on AMD by removing the register bool
+# and benchmarking it for performance and stability
+# TODO(coconutruben): enable decompose k on other devices (xpu, cpu, mps, mtia)
+# by either adding specific register_template_heuristic tags, or setting the
+# device to None (enabled on all devices)
 class DecomposeKConfigHeuristics(TemplateConfigHeuristics):
     def get_template_configs(
         self,
torch/_inductor/template_heuristics/registry.py

@@ -10,8 +10,7 @@ from __future__ import annotations
 import contextlib
 import logging
-from functools import cache
-from typing import Any, Optional, TYPE_CHECKING
+from typing import Any, Optional, TYPE_CHECKING, Union

 from .base import TemplateConfigHeuristics

@@ -21,14 +20,19 @@ if TYPE_CHECKING:
 # Module-wide registry for template heuristics
-_TEMPLATE_HEURISTIC_REGISTRY: dict[tuple[str, ...], type[TemplateConfigHeuristics]] = {}
+_TEMPLATE_HEURISTIC_REGISTRY: dict[
+    tuple[Union[str, None], ...], type[TemplateConfigHeuristics]
+] = {}
+
+# Manual cache for successful lookups only (fallback instances are not cached)
+_HEURISTIC_CACHE: dict[tuple[str, str, str], TemplateConfigHeuristics] = {}

 log = logging.getLogger(__name__)


 def register_template_heuristic(
     template_name: str,
-    device_type: str,
+    device_type: Union[str, None],
     register: bool = True,
     op_name: Optional[str] = None,
 ) -> Any:

@@ -38,6 +42,7 @@ def register_template_heuristic(
     Args:
         template_name: Name of the template (e.g., "mm", "bmm", "scaled_mm")
         device_type: Device type ("cuda", "cpu", "xpu")
+            Set this to None to indicate that the heuristic is applicable to all device types.
         register: Whether to register this heuristic. Caller should pass the condition directly.
         op_name: Name of the operator (e.g., "mm", "bmm", "scaled_mm"). This is optional
             and is only used when a template uses different heuristics for different ops

@@ -55,9 +60,7 @@
         cls: type[TemplateConfigHeuristics],
     ) -> type[TemplateConfigHeuristics]:
         if register:
-            key: tuple[str, ...] = (device_type, template_name)
-            if op_name is not None:
-                key = (device_type, template_name, op_name)
+            key: tuple[Union[str, None], ...] = (template_name, device_type, op_name)
             _TEMPLATE_HEURISTIC_REGISTRY[key] = cls
             log.info(
                 f"Registered template heuristic: {cls.__name__} for '{template_name=}', '{device_type=}', '{op_name=}'"  # noqa: G004

@@ -67,7 +70,6 @@
     return decorator


-@cache
 def get_template_heuristic(
     template_name: str, device_type: str, op_name: str
 ) -> TemplateConfigHeuristics:

@@ -77,15 +79,27 @@
     Args:
         template_name: Name of the template (e.g., "mm", "bmm", "scaled_mm")
         device_type: Device type ("cuda", "cpu", "xpu")
         op_name: Name of the operator (e.g., "mm", "bmm", "scaled_mm")

     Returns:
-        Template heuristic instance.
-
-    Raises:
-        ValueError: If no heuristic is found for the given combination.
+        Template heuristic instance. If no specific heuristic is found,
+        returns a fallback TemplateConfigHeuristics() instance (uncached).
     """
-    # First check the more specific key
-    keys = [(device_type, template_name, op_name), (device_type, template_name)]
+    # Check cache first
+    cache_key = (template_name, device_type, op_name)
+    if cache_key in _HEURISTIC_CACHE:
+        return _HEURISTIC_CACHE[cache_key]
+
+    keys = [
+        # everything is specified
+        (template_name, device_type, op_name),
+        # heuristic is valid across all devices
+        (template_name, None, op_name),
+        # heuristic is valid across all ops for that device
+        (template_name, device_type, None),
+        # heuristic is always valid for that template
+        (template_name, None, None),
+    ]

     # Look up in registry
     heuristic_class = None

@@ -93,13 +107,33 @@
         if key in _TEMPLATE_HEURISTIC_REGISTRY:
             heuristic_class = _TEMPLATE_HEURISTIC_REGISTRY[key]
             break

     if heuristic_class is None:
-        raise ValueError(
-            f"No template heuristic found for '{template_name=}', "
-            f"'{device_type=}', '{op_name=}'. "
-            f"Available combinations: {list(_TEMPLATE_HEURISTIC_REGISTRY.keys())}"
+        # Log error and return fallback instance (uncached)
+        log.error(
+            "No template heuristic found - template_name=%s, device_type=%s, op_name=%s. "
+            "Available combinations: %s. Using fallback TemplateConfigHeuristics instance.",
+            template_name,
+            device_type,
+            op_name,
+            list(_TEMPLATE_HEURISTIC_REGISTRY.keys()),
         )
-    return heuristic_class()
+        return TemplateConfigHeuristics()
+
+    # Cache successful lookup and return
+    instance = heuristic_class()
+    _HEURISTIC_CACHE[cache_key] = instance
+    return instance
+
+
+def clear_registry() -> None:
+    """
+    Clear all registered template heuristics.
+
+    This is primarily useful for testing purposes to ensure a clean state.
+    """
+    _TEMPLATE_HEURISTIC_REGISTRY.clear()
+    _HEURISTIC_CACHE.clear()


 @contextlib.contextmanager

@@ -120,7 +154,7 @@
     # Save original entries to restore later
     original_entries = {}
     new_keys = []
-    get_template_heuristic.cache_clear()
+    _HEURISTIC_CACHE.clear()
     try:
         for template_name, op_name in template_op_pairs:
             assert op_name is not None

@@ -138,4 +172,4 @@
             _TEMPLATE_HEURISTIC_REGISTRY.pop(key, None)
             if key in original_entries:
                 _TEMPLATE_HEURISTIC_REGISTRY[key] = original_entries[key]
-    get_template_heuristic.cache_clear()
+    _HEURISTIC_CACHE.clear()
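
Taken together, a lookup walks a fixed four-step hierarchy and only caches real hits. A minimal sketch of the resolution order; `tpl` and the class names are placeholders:

```python
from torch._inductor.template_heuristics.base import TemplateConfigHeuristics
from torch._inductor.template_heuristics.registry import (
    get_template_heuristic,
    register_template_heuristic,
)


@register_template_heuristic("tpl", None, op_name="mm")  # key ("tpl", None, "mm")
class CrossDevice(TemplateConfigHeuristics):
    pass


@register_template_heuristic("tpl", "cuda")  # key ("tpl", "cuda", None)
class CudaAnyOp(TemplateConfigHeuristics):
    pass


# ("tpl", "cuda", "mm") misses, then step 2 ("tpl", None, "mm") hits:
# cross-device wins over cross-op because it matches the op exactly
assert isinstance(get_template_heuristic("tpl", "cuda", "mm"), CrossDevice)

# for a different op on cuda, step 3 ("tpl", "cuda", None) hits
assert isinstance(get_template_heuristic("tpl", "cuda", "bmm"), CudaAnyOp)

# nothing matches ("tpl", "xpu", "bmm") at any step, so an uncached
# fallback instance with 0 configs is returned
assert type(get_template_heuristic("tpl", "xpu", "bmm")) is TemplateConfigHeuristics
```

Registered hits are cached per (template, device, op) triple in `_HEURISTIC_CACHE`; the fallback path deliberately skips the cache so that a later registration for the same triple takes effect immediately.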