Custom FX pass for inductor's backend registration (#154841)

This PR is related to RFC #153532. It extends Inductor's backend registration interface so that a backend can register its own custom FX passes.
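
Below is a minimal sketch (not part of the diff) of how a backend might use the new hook. MyBackendFxPass is a hypothetical pass, and the call simply re-registers the in-tree CPU backend with an extra FX pass for illustration; an out-of-tree backend would pass its own scheduling and wrapper codegen classes.

from typing import Optional, Union

import torch
from torch._inductor.codegen.common import register_backend_for_device
from torch._inductor.codegen.cpp import CppScheduling
from torch._inductor.codegen.wrapper import PythonWrapperCodegen
from torch._inductor.custom_graph_pass import CustomGraphModulePass, get_hash_for_files


class MyBackendFxPass(CustomGraphModulePass):
    def __call__(self, gm: torch.fx.GraphModule) -> None:
        # Placeholder body: a real pass would rewrite gm.graph in place here.
        return None

    def uuid(self) -> Optional[Union[bytes, str]]:
        # Tie Inductor's FX graph cache key to this source file, so any edit
        # to the pass invalidates previously cached compiled graphs.
        return get_hash_for_files((__file__,))


register_backend_for_device(
    "cpu",  # a custom backend would use its own device key here
    CppScheduling,
    PythonWrapperCodegen,
    device_custom_pass=MyBackendFxPass(),
)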

Pull Request resolved: https://github.com/pytorch/pytorch/pull/154841
Approved by: https://github.com/jansel

Co-authored-by: Jason Ansel <jansel@jansel.net>
Author: Marcin Pioch
Date: 2025-06-06 06:49:44 +00:00
Committed by: PyTorch MergeBot
Parent: c6b4f98625
Commit: e694280d12
10 changed files with 213 additions and 36 deletions


@@ -27,7 +27,11 @@ from torch._inductor.codecache import (
     TensorMetadata,
     TensorMetadataAndValues,
 )
-from torch._inductor.custom_graph_pass import CustomGraphPass, get_hash_for_files
+from torch._inductor.custom_graph_pass import (
+    CustomGraphModulePass,
+    CustomGraphPass,
+    get_hash_for_files,
+)
 from torch._inductor.graph import GraphLowering
 from torch._inductor.mock_cache import global_stats, PatchCaches, Stats
 from torch._inductor.runtime.runtime_utils import cache_dir
@@ -53,6 +57,7 @@ from torch.testing._internal.inductor_utils import (
     HAS_GPU,
     HAS_MULTIGPU,
     HAS_TRITON,
+    patch_inductor_backend,
     requires_gpu,
     requires_triton,
 )
@@ -2183,6 +2188,42 @@ class TestFxGraphCacheHashing(TestCase):
             pickler.dumps(details3),
         )

+    def test_hash_custom_backend_pass(self):
+        """
+        Test CustomGraphModulePass usage.
+        """
+
+        class TestCustomGraphModulePass(CustomGraphModulePass):
+            def __init__(self):
+                self._uuid = None
+
+            def __call__(self, gm: torch.fx.GraphModule) -> None:
+                return None
+
+            def uuid(self) -> Optional[Union[bytes, str]]:
+                return self._uuid
+
+        custom_pass = TestCustomGraphModulePass()
+        with patch_inductor_backend("cpu", custom_pass=custom_pass):
+            custom_pass._uuid = "1"
+            details1 = FxGraphHashDetails(None, [], {}, [])
+            details2 = FxGraphHashDetails(None, [], {}, [])
+
+            custom_pass._uuid = "2"
+            details3 = FxGraphHashDetails(None, [], {}, [])
+
+            gm = torch.fx.GraphModule({}, torch.fx.Graph())
+            pickler = FxGraphCachePickler(gm)
+
+            self.assertEqual(
+                pickler.dumps(details1),
+                pickler.dumps(details2),
+            )
+            self.assertNotEqual(
+                pickler.dumps(details1),
+                pickler.dumps(details3),
+            )
+
     def test_bypass_unsupported(self):
         """
         Test _reduce_unsupported


@@ -8,12 +8,17 @@ import torch._inductor.pattern_matcher as pattern_matcher
 import torch.fx as fx
 from torch._dynamo.utils import counters
 from torch._inductor import config
-from torch._inductor.custom_graph_pass import CustomGraphPass, get_hash_for_files
+from torch._inductor.codegen.common import get_custom_backend_pass_for_device
+from torch._inductor.custom_graph_pass import (
+    CustomGraphModulePass,
+    CustomGraphPass,
+    get_hash_for_files,
+)
 from torch._inductor.lowering import lowerings as L
 from torch._inductor.pattern_matcher import Arg, CallFunction, PatternMatcherPass
 from torch._inductor.test_case import run_tests, TestCase
 from torch.testing._internal.common_utils import IS_LINUX
-from torch.testing._internal.inductor_utils import HAS_CPU
+from torch.testing._internal.inductor_utils import HAS_CPU, patch_inductor_backend


 @config.patch({"freezing": True})
@@ -264,6 +269,35 @@ class TestPostGradCustomPrePostPass(TestCustomPassBase):

         inner_test()

+    def test_custom_backend_pass(self):
+        class CustomBackendPass(CustomGraphModulePass):
+            def __init__(self, existing_pass: CustomGraphModulePass = None):
+                super().__init__()
+                self.existing_pass = existing_pass
+
+            def __call__(self, gm: fx.GraphModule) -> None:
+                if self.existing_pass:
+                    self.existing_pass(gm)
+                change_cos_pass(gm.graph)
+
+            def uuid(self) -> bytes:
+                return get_hash_for_files((__file__,))
+
+        custom_backend_pass = CustomBackendPass(
+            get_custom_backend_pass_for_device("cpu")
+        )
+        with patch_inductor_backend("cpu", custom_pass=custom_backend_pass):
+
+            def g(x):
+                return x.sin().sin().sin()
+
+            def f(x):
+                return x.cos().cos().cos()
+
+            x = torch.randn(8, dtype=torch.float32)
+
+            torch.testing.assert_close(torch.compile(f)(x), g(x))
+

 if __name__ == "__main__":
     if IS_LINUX and HAS_CPU and torch.backends.mkldnn.is_available():


@@ -12,8 +12,6 @@ import torch
 import torch.library
 from torch._dynamo.testing import CompileCounterWithBackend, make_test_cls_with_patches
 from torch._inductor import metrics
-from torch._inductor.codegen.common import device_codegens, register_backend_for_device
-from torch._inductor.codegen.cpp import CppScheduling
 from torch._inductor.codegen.wrapper import PythonWrapperCodegen
 from torch._inductor.test_case import TestCase
 from torch._inductor.utils import run_and_get_code
@@ -34,7 +32,12 @@ from torch.testing._internal.common_utils import (
     TEST_WITH_ASAN,
     TEST_WITH_ROCM,
 )
-from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_CPU, HAS_GPU
+from torch.testing._internal.inductor_utils import (
+    GPU_TYPE,
+    HAS_CPU,
+    HAS_GPU,
+    patch_inductor_backend,
+)


 # Make the helper files in test/ importable
@@ -932,23 +935,13 @@ class TestInductorDynamic(TestCase):
                 _test_wrapper_codegen_statically_known_int_or_none_in_context()
                 return super().generate(is_inference, *args, **kwargs)

-        if "cpu" not in device_codegens:
-            register_backend_for_device("cpu", CppScheduling, PythonWrapperCodegen)
-        orig_cpu_codegens = device_codegens["cpu"]
-        try:
-            register_backend_for_device(
-                "cpu", orig_cpu_codegens.scheduling, TestWrapperCodegen
-            )
+        with patch_inductor_backend("cpu", python_wrapper_codegen=TestWrapperCodegen):
             # Compile each of the functions above, with an example input
             # that has 5 in the first dimension, but is marked as dynamic
             torch.compile(backend="inductor", dynamic=None)(fn_1)(_x)
             torch.compile(backend="inductor", dynamic=None)(fn_2)(_x)
             torch.compile(backend="inductor", dynamic=None)(fn_3)(_x)
-        finally:
-            register_backend_for_device(
-                "cpu", orig_cpu_codegens.scheduling, orig_cpu_codegens.wrapper_codegen
-            )

     @torch._dynamo.config.patch(capture_scalar_outputs=True)
     def test_item_unbacked_stride_nobreak(self, device):


@@ -51,6 +51,7 @@ from torch import SymInt, Tensor
 from torch._dynamo.exc import SkipFrame
 from torch._dynamo.utils import CompileEventLogger, counters, dynamo_timed
 from torch._inductor import config, exc, metrics
+from torch._inductor.codegen.common import custom_backend_passes
 from torch._inductor.codegen.cuda import cuda_env
 from torch._inductor.codegen.rocm.compile_command import (
     rocm_compile_command,
@@ -72,7 +73,11 @@ from torch._inductor.cpp_builder import (
     normalize_path_separator,
 )
 from torch._inductor.cpu_vec_isa import pick_vec_isa
-from torch._inductor.custom_graph_pass import CustomGraphPass, CustomGraphPassType
+from torch._inductor.custom_graph_pass import (
+    CustomGraphModulePass,
+    CustomGraphPass,
+    CustomGraphPassType,
+)
 from torch._inductor.freezing_utils import has_frozen_params, is_frozen_param
 from torch._inductor.runtime.compile_tasks import _reload_python_module
 from torch._inductor.runtime.runtime_utils import cache_dir, default_cache_dir
@@ -891,12 +896,16 @@ class FxGraphHashDetails:
             config.post_grad_custom_post_pass
         )

+        self.custom_backend_passes = tuple(
+            map(self._get_custom_pass_detail, custom_backend_passes.values())
+        )
+
     def _get_custom_pass_detail(
-        self, custom_pass: CustomGraphPassType
+        self, custom_pass: Union[CustomGraphPassType, CustomGraphModulePass]
     ) -> Optional[Any]:
         if not custom_pass:
             return None
-        assert isinstance(custom_pass, CustomGraphPass)
+        assert isinstance(custom_pass, (CustomGraphPass, CustomGraphModulePass))
         return custom_pass.uuid()


@@ -65,6 +65,7 @@ if TYPE_CHECKING:
     from torch.fx import GraphModule

+    from ..custom_graph_pass import CustomGraphModulePass
     from ..ir import Buffer, ChoiceCaller, FixedLayout, IRNode
     from ..loop_body import LoopBody
     from ..scheduler import BaseScheduling, Scheduler, SchedulerNode
@@ -351,6 +352,7 @@ class DeviceOpOverrides:
 device_op_overrides_dict: dict[str, DeviceOpOverrides] = {}
+custom_backend_passes: dict[str, Optional[CustomGraphModulePass]] = {}


 # The code generated by Inductor consists of two main parts: kernel code and wrapper code.
@@ -379,10 +381,12 @@ def register_backend_for_device(
     device_scheduling: SchedulingConstructor,
     device_wrapper_codegen: WrapperConstructor,
     device_cpp_wrapper_codegen: Optional[WrapperConstructor] = None,
+    device_custom_pass: Optional[CustomGraphModulePass] = None,
 ) -> None:
     device_codegens[device] = DeviceCodegen(
         device_scheduling, device_wrapper_codegen, device_cpp_wrapper_codegen
     )
+    custom_backend_passes[device] = device_custom_pass


 class BackendFeature(Enum):
@@ -441,6 +445,10 @@ def get_wrapper_codegen_for_device(
     return None


+def get_custom_backend_pass_for_device(device: str) -> Optional[CustomGraphModulePass]:
+    return custom_backend_passes[device] if device in custom_backend_passes else None
+
+
 @functools.lru_cache(None)
 def init_backend_registration() -> None:
     from .cpp import CppScheduling

@@ -76,6 +76,7 @@ from torch._inductor.utils import (
     BoxedBool,
     count_tangents,
     fresh_inductor_cache,
+    get_all_devices,
     InputType,
     is_gpu,
     should_assume_input_aligned,
@@ -1901,22 +1902,6 @@ def get_cpp_wrapper_config() -> dict[str, object]:
     }


-def get_all_devices(gm: torch.fx.GraphModule) -> OrderedSet[torch.device]:
-    placeholder_nodes = gm.graph.find_nodes(op="placeholder")
-    input_devices: OrderedSet[torch.device] = OrderedSet(
-        node.meta["val"].device
-        for node in placeholder_nodes
-        if isinstance(node.meta.get("val"), torch.Tensor)
-    )
-
-    out_devices: OrderedSet[torch.device] = OrderedSet(
-        arg.meta["val"].device
-        for arg in output_node(gm).args[0]  # type: ignore[union-attr]
-        if isinstance(arg, fx.Node) and isinstance(arg.meta.get("val"), torch.Tensor)
-    )
-    return input_devices | out_devices
-
-
 def get_cuda_device_context(gm: torch.fx.GraphModule) -> AbstractContextManager[None]:
     """
     Returns a cuda device context manager if there is a single device in the graph


@@ -53,6 +53,38 @@ class CustomGraphPass(ABC):
         """


+class CustomGraphModulePass(ABC):
+    """
+    Implement this interface for custom Graph passes:
+
+    1) The __call__() method contains the implementation of the custom pass.
+
+    2) The uuid() method enables inductor to cache compiled graphs when your custom
+    passes are applied. This method can return any identifier as long as it uniquely
+    identifies your implementation (and can be pickled). The caching logic includes this
+    identifier in its key calculation, i.e., any new value will effectively invalidate
+    existing entries. We expect custom passes would typically depend purely on the
+    textual representation of the implementation. In that case, we recommend using the
+    'get_hash_for_files' helper below to compute a unique hash from the contents of a
+    static list of source files, i.e., the source(s) containing the custom pass
+    implementation. That approach ensures that any change to the implementation will
+    mean a new uuid.
+    """
+
+    @abstractmethod
+    def __call__(self, gm: torch.fx.GraphModule) -> None:
+        """
+        Implementation of the custom pass.
+        """
+
+    @abstractmethod
+    def uuid(self) -> Optional[Any]:
+        """
+        Return an ID to uniquely identify your custom pass implementation. Return None
+        to skip inductor code caching entirely.
+        """
+
+
 CustomGraphPassType: TypeAlias = Optional[
     Union[CustomGraphPass, Callable[[torch.fx.graph.Graph], None]]
 ]


@@ -22,6 +22,7 @@ from torch.fx.experimental.symbolic_shapes import statically_known_true, sym_eq
 from torch.utils._ordered_set import OrderedSet

 from .. import config, ir, pattern_matcher
+from ..codegen.common import custom_backend_passes
 from ..comms import remove_fsdp2_unsharded_param_graph_input_usage
 from ..fx_utils import FakeTensorUpdater, get_fake_args_kwargs, get_node_storage
 from ..lowering import lowerings as L
@@ -48,6 +49,7 @@ from ..pattern_matcher import (
 )
 from ..utils import (
     decode_device,
+    get_all_devices,
     get_gpu_type,
     is_gpu,
     is_pointwise_use,
@@ -182,6 +184,13 @@ def post_grad_passes(gm: torch.fx.GraphModule, is_inference: bool):
         fake_tensor_updater.incremental_update()

+    for device, custom_backend_pass in custom_backend_passes.items():
+        if custom_backend_pass is not None:
+            gm_devices = [d.type for d in get_all_devices(gm)]
+            if device in gm_devices:
+                pass_name = "custom_backend_passes_" + device
+                GraphTransformObserver(gm, pass_name).apply_gm_pass(custom_backend_pass)
+
     # Keep these last, since they introduces mutation. Look at
     # ./fx_passes/README.md for a discussion of mutation invariants.
     GraphTransformObserver(gm, "reinplace_inplaceable_ops").apply_graph_pass(


@@ -994,6 +994,25 @@ def output_node(gm: torch.fx.GraphModule) -> Node:
     return last_node


+def get_all_devices(gm: torch.fx.GraphModule) -> OrderedSet[torch.device]:
+    placeholder_nodes = gm.graph.find_nodes(op="placeholder")
+    input_devices: OrderedSet[torch.device] = OrderedSet(
+        node.meta["val"].device
+        for node in placeholder_nodes
+        if isinstance(node.meta.get("val"), torch.Tensor)
+    )
+
+    out_arg = output_node(gm).args[0]  # type: ignore[union-attr]
+    out_args = out_arg if isinstance(out_arg, tuple) else (out_arg,)
+    out_devices: OrderedSet[torch.device] = OrderedSet(
+        arg.meta["val"].device
+        for arg in out_args
+        if isinstance(arg, torch.fx.Node)
+        and isinstance(arg.meta.get("val"), torch.Tensor)
+    )
+    return input_devices | out_devices
+
+
 _registered_caches: list[Any] = []


@@ -14,6 +14,15 @@ from torch.fx.experimental.proxy_tensor import make_fx
 from torch._inductor.graph import GraphLowering
 from torch._inductor.compile_fx import shape_env_from_inputs
 from torch._inductor.codecache import CppCodeCache
+from torch._inductor.custom_graph_pass import CustomGraphModulePass
+from torch._inductor.codegen.common import (
+    get_custom_backend_pass_for_device,
+    get_scheduling_for_device,
+    get_wrapper_codegen_for_device,
+    init_backend_registration,
+    register_backend_for_device
+)
+from torch._inductor.codegen.wrapper import PythonWrapperCodegen
 from torch._inductor.utils import get_gpu_shared_memory, is_big_gpu
 from torch._inductor.utils import GPU_TYPES, get_gpu_type, is_gpu
 from torch.utils._triton import has_triton
@@ -290,3 +299,41 @@ def _quantize_rowwise(x: Tensor, float8_dtype: torch.dtype):
     x_fp8 = _to_fp8_saturated(x * scale, float8_dtype)
     inverse_scale = scale.reciprocal()
     return x_fp8, inverse_scale
+
+
+@contextlib.contextmanager
+def patch_inductor_backend(
+    device: str,
+    python_wrapper_codegen: PythonWrapperCodegen = None,
+    custom_pass: CustomGraphModulePass = None
+):
+    """
+    Patch the inductor backend for a specific device.
+    """
+    # Make sure the backend is already registered
+    init_backend_registration()
+
+    # Get the original registration parameters
+    original_scheduling = get_scheduling_for_device(device)
+    original_python_wrapper = get_wrapper_codegen_for_device(device, False)
+    original_cpp_wrapper = get_wrapper_codegen_for_device(device, True)
+    original_custom_pass = get_custom_backend_pass_for_device(device)
+
+    try:
+        # Register modified backend for the device
+        register_backend_for_device(
+            device,
+            original_scheduling,
+            python_wrapper_codegen if python_wrapper_codegen is not None else original_python_wrapper,
+            original_cpp_wrapper,
+            custom_pass if custom_pass is not None else original_custom_pass
+        )
+        yield
+    finally:
+        # Restore the original backend
+        register_backend_for_device(
+            device,
+            original_scheduling,
+            original_python_wrapper,
+            original_cpp_wrapper,
+            original_custom_pass
+        )
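
A short usage sketch of the helper above (not part of the diff; NoopPass is a hypothetical do-nothing pass used only to exercise the plumbing):

import torch
from torch._inductor.custom_graph_pass import CustomGraphModulePass
from torch.testing._internal.inductor_utils import patch_inductor_backend


class NoopPass(CustomGraphModulePass):
    def __call__(self, gm: torch.fx.GraphModule) -> None:
        # Intentionally does nothing.
        return None

    def uuid(self):
        return "noop-pass"


# The custom pass is attached to the CPU backend only inside the block;
# the original registration is restored when the context manager exits.
with patch_inductor_backend("cpu", custom_pass=NoopPass()):
    fn = torch.compile(lambda x: x.sin() + 1)
    fn(torch.randn(8))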