[AOTI-FX] Support registering custom FX backends (#162317)

# Feature
Currently, `torch._inductor.compile_aot` always uses the `WrapperFxCodegen` class. In contrast, the Python and C++ codegens already allow users to register custom backends per device. This PR brings that capability to FX codegen.
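
As a rough sketch of the intended usage (the `MyWrapperFxCodegen` class is a hypothetical placeholder, and registering under an existing device key such as `"cuda"` replaces that device's built-in codegen entry):

```python
from torch._inductor.codegen.common import register_backend_for_device
from torch._inductor.codegen.triton import TritonScheduling
from torch._inductor.codegen.wrapper import PythonWrapperCodegen
from torch._inductor.codegen.wrapper_fxir import WrapperFxCodegen


class MyWrapperFxCodegen(WrapperFxCodegen):
    """Hypothetical custom FX backend."""

    def compile_graph(self, gm):
        # Inspect or post-process the generated FX GraphModule before
        # deferring to the default compilation path.
        return super().compile_graph(gm)


# Register the custom FX wrapper codegen alongside the usual Python wrapper codegen.
register_backend_for_device(
    "cuda",
    TritonScheduling,
    PythonWrapperCodegen,
    device_fx_wrapper_codegen=MyWrapperFxCodegen,
)
```

The test added below exercises the same path non-destructively by patching `common.device_codegens` with `unittest.mock.patch.dict` instead of registering globally.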

# Test plan
Added a CI test that registers a custom FX backend and verifies it is invoked during compilation.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162317
Approved by: https://github.com/jansel
Author: Blaine Burton Rister
Authored: 2025-09-06 07:32:00 +00:00
Committed by: PyTorch MergeBot
Parent: 0ff8eabf13
Commit: 9aedb3cd87

3 changed files with 56 additions and 4 deletions


@@ -20,6 +20,7 @@ from torch._inductor import config
 from torch._inductor.codegen.common import register_backend_for_device
 from torch._inductor.codegen.cpp import CppScheduling
 from torch._inductor.codegen.triton import TritonScheduling
+from torch._inductor.codegen.wrapper import PythonWrapperCodegen
 from torch._inductor.codegen.wrapper_fxir import FxConverter, WrapperFxCodegen
 from torch._inductor.test_case import TestCase as InductorTestCase
 from torch.export import Dim
@@ -783,6 +784,43 @@ class AOTFxirTestCase(InductorTestCase):
             strict=True,
         )

+    def test_custom_backend(self):
+        """
+        Test registering a custom FX backend.
+        """
+        called = False
+
+        class CustomWrapperCodegen(WrapperFxCodegen):
+            def compile_graph(self, gm):
+                """
+                Simply records whether this override was called.
+                """
+                nonlocal called
+                called = True
+                return super().compile_graph(gm)
+
+        class M(torch.nn.Module):
+            def forward(self, x):
+                return x + 1
+
+        # Register a custom FX backend.
+        custom_backend = common.DeviceCodegen(
+            TritonScheduling,
+            PythonWrapperCodegen,
+            fx_wrapper_codegen=CustomWrapperCodegen,
+        )
+        with unittest.mock.patch.dict(
+            common.device_codegens, {self.device: custom_backend}
+        ):
+            # The backend should not have been called yet.
+            self.assertFalse(called)
+
+            inp = (torch.randn(8, device=self.device),)
+            self.check(M().to(self.device), inp)
+
+        # Now the backend should have been called.
+        self.assertTrue(called)


 if __name__ == "__main__":
     from torch._inductor.test_case import run_tests


@@ -308,6 +308,7 @@ class DeviceCodegen:
     scheduling: SchedulingConstructor
     wrapper_codegen: WrapperConstructor
     cpp_wrapper_codegen: Optional[WrapperConstructor] = None
+    fx_wrapper_codegen: Optional[WrapperConstructor] = None


 KernelArgType = Union[WorkspaceArg, TensorArg, SizeArg, TMADescriptorArg, ConstexprArg]
@@ -402,11 +403,15 @@ def register_backend_for_device(
     device_scheduling: SchedulingConstructor,
     device_wrapper_codegen: WrapperConstructor,
     device_cpp_wrapper_codegen: Optional[WrapperConstructor] = None,
+    device_fx_wrapper_codegen: Optional[WrapperConstructor] = None,
     device_custom_pass: Optional[CustomGraphModulePass] = None,
     device_custom_config: Optional[ConfigModule] = None,
 ) -> None:
     device_codegens[device] = DeviceCodegen(
-        device_scheduling, device_wrapper_codegen, device_cpp_wrapper_codegen
+        device_scheduling,
+        device_wrapper_codegen,
+        device_cpp_wrapper_codegen,
+        device_fx_wrapper_codegen,
     )
     custom_backend_passes[device] = device_custom_pass
     if device_custom_config:
@@ -468,9 +473,7 @@ def get_wrapper_codegen_for_device(
     if device in device_codegens:
         wrapper_codegen_obj: DeviceCodegen = device_codegens[device]
         if fx_wrapper:
-            from .wrapper_fxir import WrapperFxCodegen
-
-            return WrapperFxCodegen
+            return wrapper_codegen_obj.fx_wrapper_codegen
         elif cpp_wrapper:
             return wrapper_codegen_obj.cpp_wrapper_codegen
         else:
@@ -507,6 +510,7 @@ def init_backend_registration() -> None:
     from .python_wrapper_mtia import PythonWrapperMtia
     from .triton import TritonScheduling
     from .wrapper import PythonWrapperCodegen
+    from .wrapper_fxir import WrapperFxCodegen

     if get_scheduling_for_device("cpu") is None:
         cpu_backends = {
@@ -521,6 +525,7 @@
             CppWrapperCpuArrayRef
             if config.aot_inductor.allow_stack_allocation
             else CppWrapperCpu,
+            WrapperFxCodegen,
         )

     if get_scheduling_for_device("cuda") is None:
@@ -534,6 +539,7 @@
             lambda scheduling: cuda_backends[config.cuda_backend](scheduling),
             PythonWrapperCodegen,
             CppWrapperGpu,
+            WrapperFxCodegen,
         )

     if get_scheduling_for_device("xpu") is None:
@@ -542,6 +548,7 @@
             TritonScheduling,
             PythonWrapperCodegen,
             CppWrapperGpu,
+            WrapperFxCodegen,
         )

     if get_scheduling_for_device("mps") is None:
@@ -550,6 +557,7 @@
             MetalScheduling,
             PythonWrapperCodegen,
             CppWrapperMps,
+            WrapperFxCodegen,
         )

     if get_scheduling_for_device("mtia") is None:
@@ -558,6 +566,7 @@
             TritonScheduling,
             PythonWrapperMtia,
             CppWrapperGpu,
+            WrapperFxCodegen,
         )

     private_backend = torch._C._get_privateuse1_backend_name()
@@ -571,12 +580,14 @@
             device_scheduling = _get_custom_mod_func("Scheduling")
             wrapper_codegen = _get_custom_mod_func("PythonWrapperCodegen")
             cpp_wrapper_codegen = _get_custom_mod_func("CppWrapperCodegen")
+            fx_wrapper_codegen = _get_custom_mod_func("WrapperFxCodegen")
             if device_scheduling and wrapper_codegen and cpp_wrapper_codegen:
                 register_backend_for_device(
                     private_backend,
                     device_scheduling,
                     wrapper_codegen,
                     cpp_wrapper_codegen,
+                    fx_wrapper_codegen,
                 )
         except RuntimeError:
             pass


@@ -342,6 +342,7 @@ def patch_inductor_backend(
     original_scheduling = get_scheduling_for_device(device)
     original_python_wrapper = get_wrapper_codegen_for_device(device, False)
     original_cpp_wrapper = get_wrapper_codegen_for_device(device, True)
+    original_fx_wrapper = get_wrapper_codegen_for_device(device, fx_wrapper=True)
     original_custom_pass = get_custom_backend_pass_for_device(device)
     original_custom_backend_config = get_custom_backend_config_for_device(device)
@@ -352,6 +353,7 @@
         original_scheduling,
         python_wrapper_codegen if python_wrapper_codegen is not None else original_python_wrapper,
         original_cpp_wrapper,
+        original_fx_wrapper,
         custom_pass if custom_pass is not None else original_custom_pass,
         custom_backend_config if custom_backend_config is not None else original_custom_backend_config
     )
@@ -363,6 +365,7 @@
             original_scheduling,
             original_python_wrapper,
             original_cpp_wrapper,
+            original_fx_wrapper,
             original_custom_pass,
             original_custom_backend_config
         )