[inductor][heuristics registry] missing heuristic is not an error anymore, cross device heuristics (#161767)
# why

- not having a heuristic is currently a hard error; it should not crash, just provide 0 configs
- some heuristics are cross device type
- it is cleaner to be explicit about being cross device type than having to enumerate every possible device type

# what

- on registration, supply device_type=None (explicitly) to say this heuristic is cross device
- a missing heuristic now logs an error and returns a fallback heuristic (0 configs) instead of raising
- test to guard the heuristics hierarchies

# testing

```
python3 -bb -m pytest test/inductor/test_template_heuristics_registry.py
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/161767
Approved by: https://github.com/PaulZhang12
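To make the new semantics concrete, here is a minimal usage sketch; the APIs and import paths come from the diff and test file below, while the heuristic class name is made up for illustration:

```python
from torch._inductor.template_heuristics.base import TemplateConfigHeuristics
from torch._inductor.template_heuristics.registry import (
    get_template_heuristic,
    register_template_heuristic,
)


# device_type=None explicitly marks the heuristic as cross-device:
# it applies on cuda, xpu, cpu, ... without enumerating each type.
@register_template_heuristic("triton::mm", None)
class MyCrossDeviceMMHeuristic(TemplateConfigHeuristics):  # illustrative name
    pass


# Any device/op query for this template falls through the lookup
# hierarchy to the cross-device entry registered above.
heuristic = get_template_heuristic("triton::mm", "xpu", "bmm")
assert isinstance(heuristic, MyCrossDeviceMMHeuristic)

# A missing heuristic is no longer a hard error: the lookup logs an
# error and returns a bare TemplateConfigHeuristics (0 configs).
fallback = get_template_heuristic("no_such_template", "cuda", "mm")
assert type(fallback) is TemplateConfigHeuristics
```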
committed by PyTorch MergeBot
parent 037f3bd475
commit 45eccf414f

test/inductor/test_template_heuristics_registry.py · 171 lines · new file
@@ -0,0 +1,171 @@
```python
# Owner(s): ["module: inductor"]
from torch._inductor.template_heuristics.base import TemplateConfigHeuristics
from torch._inductor.template_heuristics.registry import (
    _TEMPLATE_HEURISTIC_REGISTRY,
    clear_registry,
    get_template_heuristic,
    register_template_heuristic,
)
from torch._inductor.test_case import run_tests, TestCase


class TestTemplateHeuristicsRegistry(TestCase):
    def setUp(self):
        super().setUp()
        # Save original registry state
        self.original_registry = _TEMPLATE_HEURISTIC_REGISTRY.copy()
        clear_registry()

    def tearDown(self):
        # Restore original registry
        clear_registry()
        _TEMPLATE_HEURISTIC_REGISTRY.update(self.original_registry)
        super().tearDown()

    def test_register_class(self):
        """Test basic registration of a heuristic class."""
        # Clear registry for this isolated test
        clear_registry()

        # Test heuristic class registered via the decorator
        @register_template_heuristic("test_mm", "cuda")
        class TestHeuristic(TemplateConfigHeuristics):
            pass

        # Verify registration
        key = ("test_mm", "cuda", None)
        self.assertIn(key, _TEMPLATE_HEURISTIC_REGISTRY)
        self.assertEqual(_TEMPLATE_HEURISTIC_REGISTRY[key], TestHeuristic)

    def test_assertion_existing_class(self):
        """Test that a registered class can be retrieved."""

        @register_template_heuristic("triton::mm", "cuda")
        class _CrossOpHeuristic(TemplateConfigHeuristics):
            """(template, device, None) - Cross-op for specific device"""

        # _CrossOpHeuristic is registered for ("triton::mm", "cuda", None),
        # so retrieval should match for any op on the cuda device
        heuristic = get_template_heuristic("triton::mm", "cuda", "bmm")
        self.assertIsInstance(heuristic, _CrossOpHeuristic)

    def test_hierarchy_lookup(self):
        """Test complete hierarchy: (template, device, op) -> (template, None, None)"""

        @register_template_heuristic("triton::mm", "cuda", op_name="scaled_mm")
        class _MostSpecificHeuristic(TemplateConfigHeuristics):
            """(template, device, op) - Most specific"""

        @register_template_heuristic("triton::mm", None, op_name="scaled_mm")
        class _CrossDeviceHeuristic(TemplateConfigHeuristics):
            """(template, None, op) - Cross-device for specific op"""

        @register_template_heuristic("triton::mm", "cuda")
        class _CrossOpHeuristic(TemplateConfigHeuristics):
            """(template, device, None) - Cross-op for specific device"""

        @register_template_heuristic("triton::mm", None)
        class _MostGeneralHeuristic(TemplateConfigHeuristics):
            """(template, None, None) - Most general"""

        # All classes are registered via the decorators above:
        # _MostSpecificHeuristic:  ("triton::mm", "cuda", "scaled_mm") - Most specific
        # _CrossDeviceHeuristic:   ("triton::mm", None, "scaled_mm")   - Cross-device for specific op
        # _CrossOpHeuristic:       ("triton::mm", "cuda", None)        - Cross-op for specific device
        # _MostGeneralHeuristic:   ("triton::mm", None, None)          - Most general

        # Test 1: Exact match - should get most specific
        heuristic = get_template_heuristic("triton::mm", "cuda", "scaled_mm")
        self.assertIsInstance(heuristic, _MostSpecificHeuristic)

        # Test 2: Different device, same op - should get cross-device
        heuristic = get_template_heuristic("triton::mm", "xpu", "scaled_mm")
        self.assertIsInstance(heuristic, _CrossDeviceHeuristic)

        # Test 3: Same device, different op - should get cross-op
        heuristic = get_template_heuristic("triton::mm", "cuda", "bmm")
        self.assertIsInstance(heuristic, _CrossOpHeuristic)

        # Test 4: Different device and op - should get most general
        heuristic = get_template_heuristic("triton::mm", "xpu", "bmm")
        self.assertIsInstance(heuristic, _MostGeneralHeuristic)

    def test_partial_hierarchy_scenarios(self):
        """Test hierarchy behavior with partial registrations"""

        # Scenario 1: Register partial hierarchy using decorators
        @register_template_heuristic("triton::tma", None, op_name="scaled_tma")
        class _TestCrossDeviceHeuristic(TemplateConfigHeuristics):
            pass

        @register_template_heuristic("triton::tma", None)
        class _TestGeneralHeuristic(TemplateConfigHeuristics):
            pass

        # Should get cross-device for matching op, regardless of device
        heuristic = get_template_heuristic("triton::tma", "cuda", "scaled_tma")
        self.assertIsInstance(heuristic, _TestCrossDeviceHeuristic)

        # Should fall back to general for a different op
        heuristic = get_template_heuristic("triton::tma", "cuda", "scaled_mm")
        self.assertIsInstance(heuristic, _TestGeneralHeuristic)

        # Scenario 2: Only specific device exists
        @register_template_heuristic("triton::bmm", "cuda")
        class _TestDeviceSpecificHeuristic(TemplateConfigHeuristics):
            pass

        # Should get device-specific for cuda
        heuristic = get_template_heuristic("triton::bmm", "cuda", "any_op")
        self.assertIsInstance(heuristic, _TestDeviceSpecificHeuristic)

        # Should return fallback instance for other devices (no specific heuristic registered)
        heuristic = get_template_heuristic("triton::bmm", "xpu", "any_op")
        self.assertIsInstance(heuristic, TemplateConfigHeuristics)
        # Make sure it's not the registered specific heuristic
        self.assertNotIsInstance(heuristic, _TestDeviceSpecificHeuristic)

        # Scenario 3: Only most general exists
        @register_template_heuristic("triton::mm", None)
        class _TestMostGeneralHeuristic(TemplateConfigHeuristics):
            pass

        # Should always get general regardless of device/op
        heuristic = get_template_heuristic("triton::mm", "cuda", "scaled_addmm")
        self.assertIsInstance(heuristic, _TestMostGeneralHeuristic)

        heuristic = get_template_heuristic("triton::mm", "xpu", "regular_addmm")
        self.assertIsInstance(heuristic, _TestMostGeneralHeuristic)

    def test_fallback_behavior(self):
        """Test that a fallback TemplateConfigHeuristics is returned when no heuristic is found"""

        # Test 1: Get fallback for unregistered template
        heuristic = get_template_heuristic("unknown_template", "cuda", "unknown_op")
        self.assertIsInstance(heuristic, TemplateConfigHeuristics)
        # Make sure it's the base class and not a subclass
        self.assertEqual(type(heuristic), TemplateConfigHeuristics)

        # Test 2: Verify fallback instances are NOT cached (new instance each time)
        heuristic2 = get_template_heuristic("unknown_template", "cuda", "unknown_op")
        self.assertIsInstance(heuristic2, TemplateConfigHeuristics)
        self.assertEqual(type(heuristic2), TemplateConfigHeuristics)
        # Should be different instances (not cached)
        self.assertIsNot(heuristic, heuristic2)

        # Test 3: After registering a heuristic, should get the registered one instead
        @register_template_heuristic("unknown_template", "cuda", op_name="unknown_op")
        class _NewlyRegisteredHeuristic(TemplateConfigHeuristics):
            pass

        # Now should get the registered heuristic, not the fallback
        heuristic3 = get_template_heuristic("unknown_template", "cuda", "unknown_op")
        self.assertIsInstance(heuristic3, _NewlyRegisteredHeuristic)
        self.assertNotEqual(type(heuristic3), TemplateConfigHeuristics)

        # Test 4: Verify registered instances ARE cached (same instance each time)
        heuristic4 = get_template_heuristic("unknown_template", "cuda", "unknown_op")
        self.assertIsInstance(heuristic4, _NewlyRegisteredHeuristic)
        self.assertIs(heuristic3, heuristic4)  # Should be same cached instance


if __name__ == "__main__":
    run_tests()
```
```diff
@@ -20,9 +20,15 @@ if TYPE_CHECKING:
     from ..ir import Layout
 
 
+# on CUDA, we don't support hip for decompose_k yet
 @register_template_heuristic(
     "decompose_k", "cuda", register=torch.version.hip is None, op_name="mm"
 )
+# TODO(coconutruben): enable decompose k on AMD by removing the register bool
+# and benchmarking it for performance and stability
+# TODO(coconutruben): enable decompose k on other devices (xpu, cpu, mps, mtia)
+# by either adding specific register_template_heuristic tags, or setting the
+# device to None (enabled on all devices)
 class DecomposeKConfigHeuristics(TemplateConfigHeuristics):
     def get_template_configs(
         self,
```
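As the TODOs in this hunk suggest, extending decompose_k beyond non-HIP CUDA would amount to dropping the register gate or widening the device type. A hypothetical sketch of what that future registration could look like, assuming the imports from the earlier example (this is not part of this PR):

```python
# Hypothetical future registration per the TODOs above - not current code.
# device_type=None would enable decompose_k on every device type at once,
# once it has been benchmarked for performance and stability there.
@register_template_heuristic("decompose_k", None, op_name="mm")
class DecomposeKConfigHeuristics(TemplateConfigHeuristics):
    ...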
```diff
@@ -10,8 +10,7 @@ from __future__ import annotations
 
 import contextlib
 import logging
-from functools import cache
-from typing import Any, Optional, TYPE_CHECKING
+from typing import Any, Optional, TYPE_CHECKING, Union
 
 from .base import TemplateConfigHeuristics
 
@@ -21,14 +20,19 @@ if TYPE_CHECKING:
 
 
 # Module-wide registry for template heuristics
-_TEMPLATE_HEURISTIC_REGISTRY: dict[tuple[str, ...], type[TemplateConfigHeuristics]] = {}
+_TEMPLATE_HEURISTIC_REGISTRY: dict[
+    tuple[Union[str, None], ...], type[TemplateConfigHeuristics]
+] = {}
+
+# Manual cache for successful lookups only (fallback instances are not cached)
+_HEURISTIC_CACHE: dict[tuple[str, str, str], TemplateConfigHeuristics] = {}
 
 log = logging.getLogger(__name__)
 
 
 def register_template_heuristic(
     template_name: str,
-    device_type: str,
+    device_type: Union[str, None],
     register: bool = True,
     op_name: Optional[str] = None,
 ) -> Any:
@@ -38,6 +42,7 @@ def register_template_heuristic(
     Args:
         template_name: Name of the template (e.g., "mm", "bmm", "scaled_mm")
         device_type: Device type ("cuda", "cpu", "xpu")
+            Set this to None to indicate that the heuristic is applicable to all device types.
         register: Whether to register this heuristic. Caller should pass the condition directly.
         op_name: Name of the operator (e.g., "mm", "bmm", "scaled_mm"). This is optional
             and is only used when a template uses different heuristics for different ops
@@ -55,9 +60,7 @@ def register_template_heuristic(
         cls: type[TemplateConfigHeuristics],
     ) -> type[TemplateConfigHeuristics]:
         if register:
-            key: tuple[str, ...] = (device_type, template_name)
-            if op_name is not None:
-                key = (device_type, template_name, op_name)
+            key: tuple[Union[str, None], ...] = (template_name, device_type, op_name)
             _TEMPLATE_HEURISTIC_REGISTRY[key] = cls
             log.info(
                 f"Registered template heuristic: {cls.__name__} for '{template_name=}', '{device_type=}', '{op_name=}'"  # noqa: G004
@@ -67,7 +70,6 @@ def register_template_heuristic(
     return decorator
 
 
-@cache
 def get_template_heuristic(
     template_name: str, device_type: str, op_name: str
 ) -> TemplateConfigHeuristics:
@@ -77,15 +79,27 @@ def get_template_heuristic(
     Args:
         template_name: Name of the template (e.g., "mm", "bmm", "scaled_mm")
         device_type: Device type ("cuda", "cpu", "xpu")
         op_name: Name of the operator (e.g., "mm", "bmm", "scaled_mm")
 
     Returns:
-        Template heuristic instance.
-
-    Raises:
-        ValueError: If no heuristic is found for the given combination.
+        Template heuristic instance. If no specific heuristic is found,
+        returns a fallback TemplateConfigHeuristics() instance (uncached).
     """
-    # First check the more specific key
-    keys = [(device_type, template_name, op_name), (device_type, template_name)]
+    # Check cache first
+    cache_key = (template_name, device_type, op_name)
+    if cache_key in _HEURISTIC_CACHE:
+        return _HEURISTIC_CACHE[cache_key]
+
+    keys = [
+        # everything is specified
+        (template_name, device_type, op_name),
+        # heuristic is valid across all devices
+        (template_name, None, op_name),
+        # heuristic is valid across all ops for that device
+        (template_name, device_type, None),
+        # heuristic is always valid for that template
+        (template_name, None, None),
+    ]
 
     # Look up in registry
     heuristic_class = None
@@ -93,13 +107,33 @@ def get_template_heuristic(
         if key in _TEMPLATE_HEURISTIC_REGISTRY:
             heuristic_class = _TEMPLATE_HEURISTIC_REGISTRY[key]
             break
 
     if heuristic_class is None:
-        raise ValueError(
-            f"No template heuristic found for '{template_name=}', "
-            f"'{device_type=}', '{op_name=}'. "
-            f"Available combinations: {list(_TEMPLATE_HEURISTIC_REGISTRY.keys())}"
+        # Log error and return fallback instance (uncached)
+        log.error(
+            "No template heuristic found - template_name=%s, device_type=%s, op_name=%s. "
+            "Available combinations: %s. Using fallback TemplateConfigHeuristics instance.",
+            template_name,
+            device_type,
+            op_name,
+            list(_TEMPLATE_HEURISTIC_REGISTRY.keys()),
         )
-    return heuristic_class()
+        return TemplateConfigHeuristics()
+
+    # Cache successful lookup and return
+    instance = heuristic_class()
+    _HEURISTIC_CACHE[cache_key] = instance
+    return instance
 
 
 def clear_registry() -> None:
     """
     Clear all registered template heuristics.
 
     This is primarily useful for testing purposes to ensure a clean state.
     """
     _TEMPLATE_HEURISTIC_REGISTRY.clear()
+    _HEURISTIC_CACHE.clear()
 
 
 @contextlib.contextmanager
@@ -120,7 +154,7 @@ def override_template_heuristics(
     # Save original entries to restore later
     original_entries = {}
     new_keys = []
-    get_template_heuristic.cache_clear()
+    _HEURISTIC_CACHE.clear()
    try:
         for template_name, op_name in template_op_pairs:
             assert op_name is not None
@@ -138,4 +172,4 @@ def override_template_heuristics(
             _TEMPLATE_HEURISTIC_REGISTRY.pop(key, None)
             if key in original_entries:
                 _TEMPLATE_HEURISTIC_REGISTRY[key] = original_entries[key]
-    get_template_heuristic.cache_clear()
+    _HEURISTIC_CACHE.clear()
```
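One subtlety in the hunks above deserves a callout: successful lookups are memoized in the manual _HEURISTIC_CACHE, while fallback instances are deliberately rebuilt on every call, so a registration that arrives later for the same key is still picked up. A small sketch of the observable behavior, mirroring test_fallback_behavior and assuming the imports from the first example (the `triton::foo` template name is made up for illustration):

```python
# Fallback lookups are NOT cached: two queries for an unregistered key
# return two distinct bare TemplateConfigHeuristics instances.
a = get_template_heuristic("triton::foo", "cuda", "mm")
b = get_template_heuristic("triton::foo", "cuda", "mm")
assert a is not b


# Registering afterwards immediately changes what the same query returns...
@register_template_heuristic("triton::foo", "cuda", op_name="mm")
class FooHeuristic(TemplateConfigHeuristics):
    pass


# ...and successful lookups ARE cached: the same instance comes back.
c = get_template_heuristic("triton::foo", "cuda", "mm")
d = get_template_heuristic("triton::foo", "cuda", "mm")
assert isinstance(c, FooHeuristic) and c is d
```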