[inductor] Allow backends to register their own custom config object (#158254)

An out-of-tree backend can have its own configuration options that the user can enable to control inductor compilation. These config options need to be taken into account when calculating the key that is used to determine cache misses/hits. This PR allows out-of-tree backends to specify a custom config module, of the same type as `torch._inductor.config`, that can be used to control codegen (in addition to the default config) and will be included when creating the cache key.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/158254
Approved by: https://github.com/eellison
This commit is contained in:
Mwiza Kunda
2025-07-23 15:56:06 +00:00
committed by PyTorch MergeBot
parent 7d296d5c19
commit d3d9bc1c31
5 changed files with 101 additions and 3 deletions

View File

@ -0,0 +1,15 @@
# Owner(s): ["module: inductor"]
# This module is used in test_codecache.py to verify the correctness
# of FXGraphHashDetails when a custom inductor backend registers its own
# config object
import sys
from torch.utils._config_module import install_config_module
# Example backend-specific option; the cache tests toggle this flag to check
# that changing a custom backend config produces a different FX graph cache key.
enable_optimisation: bool = False
# adds patch, save_config, etc. — install_config_module converts this plain
# module into a ConfigModule, giving it the same patch/save machinery as
# torch._inductor.config
install_config_module(sys.modules[__name__])

View File

@ -66,6 +66,12 @@ from torch.testing._internal.inductor_utils import (
from torch.testing._internal.triton_utils import requires_cuda
try:
from . import custom_inductor_config
except ImportError:
import custom_inductor_config
if HAS_TRITON:
import triton # @manual
@ -2463,6 +2469,50 @@ class TestFxGraphCacheHashing(TestCase):
pickler.dumps(details3),
)
def test_hash_custom_backend_config(self):
    """
    Test cache correctness when a custom inductor codegen config
    is installed
    """
    with patch_inductor_backend(
        "cpu", custom_backend_config=custom_inductor_config
    ):
        # Save the module-level flag so this test cannot leak config state
        # into other tests (previously it exited with the flag left True).
        saved_flag = custom_inductor_config.enable_optimisation
        try:
            gm = torch.fx.GraphModule({}, torch.fx.Graph())
            pickler = FxGraphCachePickler(gm)

            # Identical custom config -> identical hash details.
            details1 = FxGraphHashDetails(None, [], {}, [])
            details2 = FxGraphHashDetails(None, [], {}, [])
            self.assertEqual(pickler.dumps(details1), pickler.dumps(details2))

            # Flipping a custom config option must change the cache key.
            custom_inductor_config.enable_optimisation = True
            details3 = FxGraphHashDetails(None, [], {}, [])
            self.assertNotEqual(pickler.dumps(details2), pickler.dumps(details3))

            torch._dynamo.reset()
            counters.clear()
            custom_inductor_config.enable_optimisation = False
            x = torch.zeros(32)
            y = torch.zeros(32)
            compiled_fn = torch.compile(torch.add)

            # Cold cache: first compilation is a miss.
            compiled_fn(x, y)
            self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 1)
            self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 0)

            # Same config again: must be served from the cache.
            torch._dynamo.reset()
            counters.clear()
            compiled_fn(x, y)
            self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 0)
            self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 1)

            # Changing the custom config should trigger a recompilation
            torch._dynamo.reset()
            counters.clear()
            custom_inductor_config.enable_optimisation = True
            compiled_fn(x, y)
            self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 1)
            self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 0)
        finally:
            # Restore the custom config to its pre-test value.
            custom_inductor_config.enable_optimisation = saved_flag
def test_bypass_unsupported(self):
"""
Test _reduce_unsupported

View File

@ -52,6 +52,7 @@ from torch._dynamo.exc import SkipFrame
from torch._dynamo.utils import CompileEventLogger, counters, dynamo_timed
from torch._inductor import config, exc, metrics
from torch._inductor.codegen.common import (
custom_backend_codegen_configs,
custom_backend_passes,
init_backend_registration,
)
@ -854,6 +855,13 @@ class FxGraphHashDetails:
map(self._get_custom_pass_detail, custom_backend_passes.values())
)
# Save custom inductor codegen configs
self.custom_backend_codegen_configs = {
device: custom_config.save_config_portable(ignore_private_configs=False)
for device, custom_config in custom_backend_codegen_configs.items()
if custom_config is not None
}
# This is mainly added to handle these two inductor configs, which are (unfortunately)
# sometimes cache safe:
# - _pre_fusion_custom_pass

View File

@ -34,6 +34,7 @@ import torch
import torch.fx
from torch._prims_common import ELEMENTWISE_TYPE_PROMOTION_KIND
from torch.utils import _pytree as pytree
from torch.utils._config_module import ConfigModule
from torch.utils._ordered_set import OrderedSet
from torch.utils._sympy.numbers import int_oo
from torch.utils._sympy.printers import PythonPrinter as _PythonPrinter
@ -367,6 +368,7 @@ class DeviceOpOverrides:
# Registry of DeviceOpOverrides keyed by device type string.
device_op_overrides_dict: dict[str, DeviceOpOverrides] = {}
# Custom FX graph passes registered per device by out-of-tree backends
# (value is None when the backend registered no pass).
custom_backend_passes: dict[str, Optional[CustomGraphModulePass]] = {}
# Custom inductor config modules registered per device via
# register_backend_for_device; folded into the FX graph cache key.
custom_backend_codegen_configs: dict[str, Optional[ConfigModule]] = {}
# The code generated by Inductor consists of two main parts: kernel code and wrapper code.
@ -396,11 +398,20 @@ def register_backend_for_device(
device_wrapper_codegen: WrapperConstructor,
device_cpp_wrapper_codegen: Optional[WrapperConstructor] = None,
device_custom_pass: Optional[CustomGraphModulePass] = None,
device_custom_config: Optional[ConfigModule] = None,
) -> None:
device_codegens[device] = DeviceCodegen(
device_scheduling, device_wrapper_codegen, device_cpp_wrapper_codegen
)
custom_backend_passes[device] = device_custom_pass
if device_custom_config:
assert (
isinstance(device_custom_config, ConfigModule)
and device_custom_config is not config
), (
f"{device_custom_config=} cannot be the same as the default inductor config {config=}"
)
custom_backend_codegen_configs[device] = device_custom_config
class BackendFeature(Enum):
@ -463,6 +474,14 @@ def get_custom_backend_pass_for_device(device: str) -> Optional[CustomGraphModul
return custom_backend_passes[device] if device in custom_backend_passes else None
def get_custom_backend_config_for_device(device: str) -> Optional[ConfigModule]:
return (
custom_backend_codegen_configs[device]
if device in custom_backend_codegen_configs
else None
)
@functools.cache
def init_backend_registration() -> None:
from .cpp import CppScheduling

View File

@ -16,6 +16,7 @@ from torch._inductor.compile_fx import shape_env_from_inputs
from torch._inductor.codecache import CppCodeCache
from torch._inductor.custom_graph_pass import CustomGraphModulePass
from torch._inductor.codegen.common import (
get_custom_backend_config_for_device,
get_custom_backend_pass_for_device,
get_scheduling_for_device,
get_wrapper_codegen_for_device,
@ -27,6 +28,7 @@ from torch._inductor.utils import get_gpu_shared_memory, is_big_gpu
from torch._inductor.utils import GPU_TYPES, get_gpu_type, is_gpu
from torch.utils._helion import has_helion
from torch.utils._triton import has_triton
from torch.utils._config_module import ConfigModule
from torch.testing._internal.common_device_type import (
get_desired_device_type_test_bases,
)
@ -308,7 +310,8 @@ def _quantize_rowwise(x: Tensor, float8_dtype: torch.dtype):
def patch_inductor_backend(
device: str,
python_wrapper_codegen: PythonWrapperCodegen = None,
custom_pass: CustomGraphModulePass = None
custom_pass: CustomGraphModulePass = None,
custom_backend_config: ConfigModule = None
):
"""
Patch the inductor backend for a specific device.
@ -321,6 +324,7 @@ def patch_inductor_backend(
original_python_wrapper = get_wrapper_codegen_for_device(device, False)
original_cpp_wrapper = get_wrapper_codegen_for_device(device, True)
original_custom_pass = get_custom_backend_pass_for_device(device)
original_custom_backend_config = get_custom_backend_config_for_device(device)
try:
# Register modified backend for the device
@ -329,7 +333,8 @@ def patch_inductor_backend(
original_scheduling,
python_wrapper_codegen if python_wrapper_codegen is not None else original_python_wrapper,
original_cpp_wrapper,
custom_pass if custom_pass is not None else original_custom_pass
custom_pass if custom_pass is not None else original_custom_pass,
custom_backend_config if custom_backend_config is not None else original_custom_backend_config
)
yield
finally:
@ -339,5 +344,6 @@ def patch_inductor_backend(
original_scheduling,
original_python_wrapper,
original_cpp_wrapper,
original_custom_pass
original_custom_pass,
original_custom_backend_config
)