From d3d9bc1c312cb8415d504a7af5682e75a97d3541 Mon Sep 17 00:00:00 2001 From: Mwiza Kunda Date: Wed, 23 Jul 2025 15:56:06 +0000 Subject: [PATCH] [inductor] Allow backends to register their own custom config object (#158254) An out of tree backend can have its own configuration options that the user can enable to control inductor compilation. These config options need to be taken into account when calculating the key that is used to determine cache miss / hits. This PR allows out of tree backends to specify a custom config module that has the same type as `torch._inductor.config` that can be used to control codegen (in addition to the default config), and will be used when creating the cache key. Pull Request resolved: https://github.com/pytorch/pytorch/pull/158254 Approved by: https://github.com/eellison --- test/inductor/custom_inductor_config.py | 15 +++++++ test/inductor/test_codecache.py | 50 +++++++++++++++++++++++ torch/_inductor/codecache.py | 8 ++++ torch/_inductor/codegen/common.py | 19 +++++++++ torch/testing/_internal/inductor_utils.py | 12 ++++-- 5 files changed, 101 insertions(+), 3 deletions(-) create mode 100644 test/inductor/custom_inductor_config.py diff --git a/test/inductor/custom_inductor_config.py b/test/inductor/custom_inductor_config.py new file mode 100644 index 000000000000..e29430728f94 --- /dev/null +++ b/test/inductor/custom_inductor_config.py @@ -0,0 +1,15 @@ +# Owner(s): ["module: inductor"] + +# This module is used in test_codecache.py to verify the correctness +# of FXGraphHashDetails when a custom inductor backend registers its own +# config object + +import sys + +from torch.utils._config_module import install_config_module + + +enable_optimisation: bool = False + +# adds patch, save_config, etc +install_config_module(sys.modules[__name__]) diff --git a/test/inductor/test_codecache.py b/test/inductor/test_codecache.py index 51af64153500..93545ed93cc3 100644 --- a/test/inductor/test_codecache.py +++ b/test/inductor/test_codecache.py @@ -66,6 +66,12 @@ from torch.testing._internal.inductor_utils import ( from torch.testing._internal.triton_utils import requires_cuda +try: + from . import custom_inductor_config +except ImportError: + import custom_inductor_config + + if HAS_TRITON: import triton # @manual @@ -2463,6 +2469,50 @@ class TestFxGraphCacheHashing(TestCase): pickler.dumps(details3), ) + def test_hash_custom_backend_config(self): + """ + Test cache correctness when a custom inductor codegen config + is installed + """ + with patch_inductor_backend( + "cpu", custom_backend_config=custom_inductor_config + ): + gm = torch.fx.GraphModule({}, torch.fx.Graph()) + pickler = FxGraphCachePickler(gm) + details1 = FxGraphHashDetails(None, [], {}, []) + details2 = FxGraphHashDetails(None, [], {}, []) + self.assertEqual(pickler.dumps(details1), pickler.dumps(details2)) + + custom_inductor_config.enable_optimisation = True + details3 = FxGraphHashDetails(None, [], {}, []) + self.assertNotEqual(pickler.dumps(details2), pickler.dumps(details3)) + + torch._dynamo.reset() + counters.clear() + + custom_inductor_config.enable_optimisation = False + x = torch.zeros(32) + y = torch.zeros(32) + compiled_fn = torch.compile(torch.add) + + compiled_fn(x, y) + self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 1) + self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 0) + torch._dynamo.reset() + counters.clear() + + compiled_fn(x, y) + self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 0) + self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 1) + torch._dynamo.reset() + counters.clear() + + # Changing the custom config should trigger a recompilation + custom_inductor_config.enable_optimisation = True + compiled_fn(x, y) + self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 1) + self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 0) + def test_bypass_unsupported(self): """ Test _reduce_unsupported diff --git a/torch/_inductor/codecache.py b/torch/_inductor/codecache.py index c8b23aded15c..442d36e0d117 100644 --- a/torch/_inductor/codecache.py +++ b/torch/_inductor/codecache.py @@ -52,6 +52,7 @@ from torch._dynamo.exc import SkipFrame from torch._dynamo.utils import CompileEventLogger, counters, dynamo_timed from torch._inductor import config, exc, metrics from torch._inductor.codegen.common import ( + custom_backend_codegen_configs, custom_backend_passes, init_backend_registration, ) @@ -854,6 +855,13 @@ class FxGraphHashDetails: map(self._get_custom_pass_detail, custom_backend_passes.values()) ) + # Save custom inductor codegen configs + self.custom_backend_codegen_configs = { + device: custom_config.save_config_portable(ignore_private_configs=False) + for device, custom_config in custom_backend_codegen_configs.items() + if custom_config is not None + } + # This is mainly added to handle these two inductor configs, which are (unfortunately) # sometimes cache safe: # - _pre_fusion_custom_pass diff --git a/torch/_inductor/codegen/common.py b/torch/_inductor/codegen/common.py index 828050d6da14..92ee9e28be74 100644 --- a/torch/_inductor/codegen/common.py +++ b/torch/_inductor/codegen/common.py @@ -34,6 +34,7 @@ import torch import torch.fx from torch._prims_common import ELEMENTWISE_TYPE_PROMOTION_KIND from torch.utils import _pytree as pytree +from torch.utils._config_module import ConfigModule from torch.utils._ordered_set import OrderedSet from torch.utils._sympy.numbers import int_oo from torch.utils._sympy.printers import PythonPrinter as _PythonPrinter @@ -367,6 +368,7 @@ class DeviceOpOverrides: device_op_overrides_dict: dict[str, DeviceOpOverrides] = {} custom_backend_passes: dict[str, Optional[CustomGraphModulePass]] = {} +custom_backend_codegen_configs: dict[str, Optional[ConfigModule]] = {} # The code generated by Inductor consists of two main parts: kernel code and wrapper code. @@ -396,11 +398,20 @@ def register_backend_for_device( device_wrapper_codegen: WrapperConstructor, device_cpp_wrapper_codegen: Optional[WrapperConstructor] = None, device_custom_pass: Optional[CustomGraphModulePass] = None, + device_custom_config: Optional[ConfigModule] = None, ) -> None: device_codegens[device] = DeviceCodegen( device_scheduling, device_wrapper_codegen, device_cpp_wrapper_codegen ) custom_backend_passes[device] = device_custom_pass + if device_custom_config: + assert ( + isinstance(device_custom_config, ConfigModule) + and device_custom_config is not config + ), ( + f"{device_custom_config=} cannot be the same as the default inductor config {config=}" + ) + custom_backend_codegen_configs[device] = device_custom_config class BackendFeature(Enum): @@ -463,6 +474,14 @@ def get_custom_backend_pass_for_device(device: str) -> Optional[CustomGraphModul return custom_backend_passes[device] if device in custom_backend_passes else None +def get_custom_backend_config_for_device(device: str) -> Optional[ConfigModule]: + return ( + custom_backend_codegen_configs[device] + if device in custom_backend_codegen_configs + else None + ) + + @functools.cache def init_backend_registration() -> None: from .cpp import CppScheduling diff --git a/torch/testing/_internal/inductor_utils.py b/torch/testing/_internal/inductor_utils.py index 91a4aaa5728a..8a521d56f5f8 100644 --- a/torch/testing/_internal/inductor_utils.py +++ b/torch/testing/_internal/inductor_utils.py @@ -16,6 +16,7 @@ from torch._inductor.compile_fx import shape_env_from_inputs from torch._inductor.codecache import CppCodeCache from torch._inductor.custom_graph_pass import CustomGraphModulePass from torch._inductor.codegen.common import ( + get_custom_backend_config_for_device, get_custom_backend_pass_for_device, get_scheduling_for_device, get_wrapper_codegen_for_device, @@ -27,6 +28,7 @@ from torch._inductor.utils import get_gpu_shared_memory, is_big_gpu from torch._inductor.utils import GPU_TYPES, get_gpu_type, is_gpu from torch.utils._helion import has_helion from torch.utils._triton import has_triton +from torch.utils._config_module import ConfigModule from torch.testing._internal.common_device_type import ( get_desired_device_type_test_bases, ) @@ -308,7 +310,8 @@ def _quantize_rowwise(x: Tensor, float8_dtype: torch.dtype): def patch_inductor_backend( device: str, python_wrapper_codegen: PythonWrapperCodegen = None, - custom_pass: CustomGraphModulePass = None + custom_pass: CustomGraphModulePass = None, + custom_backend_config: ConfigModule = None ): """ Patch the inductor backend for a specific device. @@ -321,6 +324,7 @@ def patch_inductor_backend( original_python_wrapper = get_wrapper_codegen_for_device(device, False) original_cpp_wrapper = get_wrapper_codegen_for_device(device, True) original_custom_pass = get_custom_backend_pass_for_device(device) + original_custom_backend_config = get_custom_backend_config_for_device(device) try: # Register modified backend for the device @@ -329,7 +333,8 @@ def patch_inductor_backend( original_scheduling, python_wrapper_codegen if python_wrapper_codegen is not None else original_python_wrapper, original_cpp_wrapper, - custom_pass if custom_pass is not None else original_custom_pass + custom_pass if custom_pass is not None else original_custom_pass, + custom_backend_config if custom_backend_config is not None else original_custom_backend_config ) yield finally: @@ -339,5 +344,6 @@ def patch_inductor_backend( original_scheduling, original_python_wrapper, original_cpp_wrapper, - original_custom_pass + original_custom_pass, + original_custom_backend_config )