[inductor] Allow backends to register their own custom config object (#158254)

An out of tree backend can have its own configuration options that the user can enable to control inductor compilation. These config options need to be taken into account when calculating the key that is used to determine cache miss / hits. This PR allows out of tree backends to specify a custom config module that has the same type as `torch._inductor.config` that can be used to control codegen (in addition to the default config), and will be used when creating the cache key.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/158254
Approved by: https://github.com/eellison
This commit is contained in:
Mwiza Kunda
2025-07-23 15:56:06 +00:00
committed by PyTorch MergeBot
parent 7d296d5c19
commit d3d9bc1c31
5 changed files with 101 additions and 3 deletions

View File

@ -34,6 +34,7 @@ import torch
import torch.fx
from torch._prims_common import ELEMENTWISE_TYPE_PROMOTION_KIND
from torch.utils import _pytree as pytree
from torch.utils._config_module import ConfigModule
from torch.utils._ordered_set import OrderedSet
from torch.utils._sympy.numbers import int_oo
from torch.utils._sympy.printers import PythonPrinter as _PythonPrinter
@ -367,6 +368,7 @@ class DeviceOpOverrides:
device_op_overrides_dict: dict[str, DeviceOpOverrides] = {}
custom_backend_passes: dict[str, Optional[CustomGraphModulePass]] = {}
custom_backend_codegen_configs: dict[str, Optional[ConfigModule]] = {}
# The code generated by Inductor consists of two main parts: kernel code and wrapper code.
@ -396,11 +398,20 @@ def register_backend_for_device(
device_wrapper_codegen: WrapperConstructor,
device_cpp_wrapper_codegen: Optional[WrapperConstructor] = None,
device_custom_pass: Optional[CustomGraphModulePass] = None,
device_custom_config: Optional[ConfigModule] = None,
) -> None:
device_codegens[device] = DeviceCodegen(
device_scheduling, device_wrapper_codegen, device_cpp_wrapper_codegen
)
custom_backend_passes[device] = device_custom_pass
if device_custom_config:
assert (
isinstance(device_custom_config, ConfigModule)
and device_custom_config is not config
), (
f"{device_custom_config=} cannot be the same as the default inductor config {config=}"
)
custom_backend_codegen_configs[device] = device_custom_config
class BackendFeature(Enum):
@ -463,6 +474,14 @@ def get_custom_backend_pass_for_device(device: str) -> Optional[CustomGraphModul
return custom_backend_passes[device] if device in custom_backend_passes else None
def get_custom_backend_config_for_device(device: str) -> Optional[ConfigModule]:
return (
custom_backend_codegen_configs[device]
if device in custom_backend_codegen_configs
else None
)
@functools.cache
def init_backend_registration() -> None:
from .cpp import CppScheduling