[AOTI-FX] Support registering custom FX backends (#162317)
# Feature

Currently, `torch._inductor.compile_aot` always uses the `WrapperFxCodegen` class. In contrast, Python and C++ codegen allow users to register custom backends. This PR brings that feature to FX codegen.

# Test plan

Added a CI test registering a custom FX backend.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162317
Approved by: https://github.com/jansel
Commit 9aedb3cd87 (parent 0ff8eabf13), committed by PyTorch MergeBot
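In user terms, the feature works as in the sketch below: subclass `WrapperFxCodegen` and hand the subclass to the registration hook extended by this PR. A minimal sketch, assuming a CUDA device; `MyWrapperFxCodegen` is a hypothetical name, and the scheduling and Python wrapper classes are the stock ones re-registered alongside it.

from torch._inductor.codegen.common import register_backend_for_device
from torch._inductor.codegen.triton import TritonScheduling
from torch._inductor.codegen.wrapper import PythonWrapperCodegen
from torch._inductor.codegen.wrapper_fxir import WrapperFxCodegen


class MyWrapperFxCodegen(WrapperFxCodegen):  # hypothetical custom backend
    def compile_graph(self, gm):
        # Inspect or transform the generated FX GraphModule before handing
        # it back to the default implementation.
        return super().compile_graph(gm)


# Fill the new FX slot when (re-)registering the device's backends.
register_backend_for_device(
    "cuda",
    TritonScheduling,
    PythonWrapperCodegen,
    device_fx_wrapper_codegen=MyWrapperFxCodegen,
)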
test/inductor/test_fxir_backend.py
@@ -20,6 +20,7 @@ from torch._inductor import config
 from torch._inductor.codegen.common import register_backend_for_device
 from torch._inductor.codegen.cpp import CppScheduling
 from torch._inductor.codegen.triton import TritonScheduling
+from torch._inductor.codegen.wrapper import PythonWrapperCodegen
 from torch._inductor.codegen.wrapper_fxir import FxConverter, WrapperFxCodegen
 from torch._inductor.test_case import TestCase as InductorTestCase
 from torch.export import Dim
@@ -783,6 +784,43 @@ class AOTFxirTestCase(InductorTestCase):
             strict=True,
         )
 
+    def test_custom_backend(self):
+        """
+        Test registering a custom FX backend.
+        """
+        called = False
+
+        class CustomWrapperCodegen(WrapperFxCodegen):
+            def compile_graph(self, gm):
+                """
+                Simply records whether this override was called.
+                """
+                nonlocal called
+                called = True
+                return super().compile_graph(gm)
+
+        class M(torch.nn.Module):
+            def forward(self, x):
+                return x + 1
+
+        # Register a custom FX backend.
+        custom_backend = common.DeviceCodegen(
+            TritonScheduling,
+            PythonWrapperCodegen,
+            fx_wrapper_codegen=CustomWrapperCodegen,
+        )
+        with unittest.mock.patch.dict(
+            common.device_codegens, {self.device: custom_backend}
+        ):
+            # The backend should not have been called yet.
+            self.assertFalse(called)
+
+            inp = (torch.randn(8, device=self.device),)
+            self.check(M().to(self.device), inp)
+
+            # Now the backend should have been called.
+            self.assertTrue(called)
+
 
 if __name__ == "__main__":
     from torch._inductor.test_case import run_tests
torch/_inductor/codegen/common.py
@@ -308,6 +308,7 @@ class DeviceCodegen:
     scheduling: SchedulingConstructor
     wrapper_codegen: WrapperConstructor
     cpp_wrapper_codegen: Optional[WrapperConstructor] = None
+    fx_wrapper_codegen: Optional[WrapperConstructor] = None
 
 
 KernelArgType = Union[WorkspaceArg, TensorArg, SizeArg, TMADescriptorArg, ConstexprArg]
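With the new field, a single `DeviceCodegen` record describes every codegen entry point for a device. A minimal sketch of building one directly, mirroring the test above:

from torch._inductor.codegen import common
from torch._inductor.codegen.triton import TritonScheduling
from torch._inductor.codegen.wrapper import PythonWrapperCodegen
from torch._inductor.codegen.wrapper_fxir import WrapperFxCodegen

backend = common.DeviceCodegen(
    TritonScheduling,
    PythonWrapperCodegen,
    fx_wrapper_codegen=WrapperFxCodegen,
)
assert backend.cpp_wrapper_codegen is None  # both wrapper slots stay optional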
@@ -402,11 +403,15 @@ def register_backend_for_device(
     device_scheduling: SchedulingConstructor,
     device_wrapper_codegen: WrapperConstructor,
     device_cpp_wrapper_codegen: Optional[WrapperConstructor] = None,
+    device_fx_wrapper_codegen: Optional[WrapperConstructor] = None,
     device_custom_pass: Optional[CustomGraphModulePass] = None,
     device_custom_config: Optional[ConfigModule] = None,
 ) -> None:
     device_codegens[device] = DeviceCodegen(
-        device_scheduling, device_wrapper_codegen, device_cpp_wrapper_codegen
+        device_scheduling,
+        device_wrapper_codegen,
+        device_cpp_wrapper_codegen,
+        device_fx_wrapper_codegen,
     )
     custom_backend_passes[device] = device_custom_pass
     if device_custom_config:
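Since the new parameter defaults to None, existing call sites keep working unchanged; a device registered without an FX wrapper simply has nothing for the lookup below to return. A minimal sketch; "my_device" is a made-up key used only for illustration.

from torch._inductor.codegen.common import (
    get_wrapper_codegen_for_device,
    register_backend_for_device,
)
from torch._inductor.codegen.triton import TritonScheduling
from torch._inductor.codegen.wrapper import PythonWrapperCodegen

# Registering without the new argument remains valid...
register_backend_for_device("my_device", TritonScheduling, PythonWrapperCodegen)

# ...but such a device has no FX wrapper codegen to resolve.
assert get_wrapper_codegen_for_device("my_device", fx_wrapper=True) is None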
@@ -468,9 +473,7 @@ def get_wrapper_codegen_for_device(
     if device in device_codegens:
         wrapper_codegen_obj: DeviceCodegen = device_codegens[device]
         if fx_wrapper:
-            from .wrapper_fxir import WrapperFxCodegen
-
-            return WrapperFxCodegen
+            return wrapper_codegen_obj.fx_wrapper_codegen
         elif cpp_wrapper:
             return wrapper_codegen_obj.cpp_wrapper_codegen
         else:
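After this change the FX lookup resolves per device rather than returning a hard-coded class. A quick sketch, assuming the built-in registrations added below have run:

from torch._inductor.codegen.common import (
    get_wrapper_codegen_for_device,
    init_backend_registration,
)

init_backend_registration()

# Built-in devices now register WrapperFxCodegen by default (see the
# following hunks), so the per-device lookup finds it.
assert get_wrapper_codegen_for_device("cpu", fx_wrapper=True) is not None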
@@ -507,6 +510,7 @@ def init_backend_registration() -> None:
     from .python_wrapper_mtia import PythonWrapperMtia
     from .triton import TritonScheduling
     from .wrapper import PythonWrapperCodegen
+    from .wrapper_fxir import WrapperFxCodegen
 
     if get_scheduling_for_device("cpu") is None:
         cpu_backends = {
@@ -521,6 +525,7 @@ def init_backend_registration() -> None:
             CppWrapperCpuArrayRef
             if config.aot_inductor.allow_stack_allocation
             else CppWrapperCpu,
+            WrapperFxCodegen,
         )
 
     if get_scheduling_for_device("cuda") is None:
@@ -534,6 +539,7 @@ def init_backend_registration() -> None:
             lambda scheduling: cuda_backends[config.cuda_backend](scheduling),
             PythonWrapperCodegen,
             CppWrapperGpu,
+            WrapperFxCodegen,
         )
 
     if get_scheduling_for_device("xpu") is None:
@@ -542,6 +548,7 @@ def init_backend_registration() -> None:
             TritonScheduling,
             PythonWrapperCodegen,
             CppWrapperGpu,
+            WrapperFxCodegen,
         )
 
     if get_scheduling_for_device("mps") is None:
@@ -550,6 +557,7 @@ def init_backend_registration() -> None:
             MetalScheduling,
             PythonWrapperCodegen,
             CppWrapperMps,
+            WrapperFxCodegen,
         )
 
     if get_scheduling_for_device("mtia") is None:
@@ -558,6 +566,7 @@ def init_backend_registration() -> None:
             TritonScheduling,
             PythonWrapperMtia,
             CppWrapperGpu,
+            WrapperFxCodegen,
         )
 
     private_backend = torch._C._get_privateuse1_backend_name()
@@ -571,12 +580,14 @@ def init_backend_registration() -> None:
            device_scheduling = _get_custom_mod_func("Scheduling")
            wrapper_codegen = _get_custom_mod_func("PythonWrapperCodegen")
            cpp_wrapper_codegen = _get_custom_mod_func("CppWrapperCodegen")
+           fx_wrapper_codegen = _get_custom_mod_func("WrapperFxCodegen")
            if device_scheduling and wrapper_codegen and cpp_wrapper_codegen:
                register_backend_for_device(
                    private_backend,
                    device_scheduling,
                    wrapper_codegen,
                    cpp_wrapper_codegen,
+                   fx_wrapper_codegen,
                )
        except RuntimeError:
            pass
(separate file; path not shown)
@@ -342,6 +342,7 @@ def patch_inductor_backend(
     original_scheduling = get_scheduling_for_device(device)
     original_python_wrapper = get_wrapper_codegen_for_device(device, False)
     original_cpp_wrapper = get_wrapper_codegen_for_device(device, True)
+    original_fx_wrapper = get_wrapper_codegen_for_device(device, fx_wrapper=True)
     original_custom_pass = get_custom_backend_pass_for_device(device)
     original_custom_backend_config = get_custom_backend_config_for_device(device)
 
@@ -352,6 +353,7 @@ def patch_inductor_backend(
         original_scheduling,
         python_wrapper_codegen if python_wrapper_codegen is not None else original_python_wrapper,
         original_cpp_wrapper,
+        original_fx_wrapper,
         custom_pass if custom_pass is not None else original_custom_pass,
         custom_backend_config if custom_backend_config is not None else original_custom_backend_config
     )
@@ -363,6 +365,7 @@ def patch_inductor_backend(
         original_scheduling,
         original_python_wrapper,
         original_cpp_wrapper,
+        original_fx_wrapper,
         original_custom_pass,
         original_custom_backend_config
     )