mirror of
				https://github.com/pytorch/pytorch.git
				synced 2025-10-20 21:14:14 +08:00 
			
		
		
		
	Custom FX pass for inductor's backend registration (#154841)
This PR is related to RFC #153532. It extends Inductor's backend registration interface so that a backend can register its own custom FX passes. Pull Request resolved: https://github.com/pytorch/pytorch/pull/154841 Approved by: https://github.com/jansel Co-authored-by: Jason Ansel <jansel@jansel.net>
This commit is contained in:
		
				
					committed by
					
						 PyTorch MergeBot
						PyTorch MergeBot
					
				
			
			
				
	
			
			
			
						parent
						
							c6b4f98625
						
					
				
				
					commit
					e694280d12
				
			| @ -27,7 +27,11 @@ from torch._inductor.codecache import ( | ||||
|     TensorMetadata, | ||||
|     TensorMetadataAndValues, | ||||
| ) | ||||
| from torch._inductor.custom_graph_pass import CustomGraphPass, get_hash_for_files | ||||
| from torch._inductor.custom_graph_pass import ( | ||||
|     CustomGraphModulePass, | ||||
|     CustomGraphPass, | ||||
|     get_hash_for_files, | ||||
| ) | ||||
| from torch._inductor.graph import GraphLowering | ||||
| from torch._inductor.mock_cache import global_stats, PatchCaches, Stats | ||||
| from torch._inductor.runtime.runtime_utils import cache_dir | ||||
| @ -53,6 +57,7 @@ from torch.testing._internal.inductor_utils import ( | ||||
|     HAS_GPU, | ||||
|     HAS_MULTIGPU, | ||||
|     HAS_TRITON, | ||||
|     patch_inductor_backend, | ||||
|     requires_gpu, | ||||
|     requires_triton, | ||||
| ) | ||||
| @ -2183,6 +2188,42 @@ class TestFxGraphCacheHashing(TestCase): | ||||
|                 pickler.dumps(details3), | ||||
|             ) | ||||
|  | ||||
|     def test_hash_custom_backend_pass(self): | ||||
|         """ | ||||
|         Test CustomGraphModulePass usage. | ||||
|  | ||||
|         Registers a custom backend pass for the "cpu" device and checks that | ||||
|         the pickled FxGraphHashDetails are equal when the pass reports the | ||||
|         same uuid and differ when the uuid changes. | ||||
|         """ | ||||
|  | ||||
|         class TestCustomGraphModulePass(CustomGraphModulePass): | ||||
|             def __init__(self): | ||||
|                 # Mutable on purpose: the test changes it between hash computations. | ||||
|                 self._uuid = None | ||||
|  | ||||
|             def __call__(self, gm: torch.fx.GraphModule) -> None: | ||||
|                 # No-op pass; only the uuid matters for this test. | ||||
|                 return None | ||||
|  | ||||
|             def uuid(self) -> Optional[Union[bytes, str]]: | ||||
|                 return self._uuid | ||||
|  | ||||
|         custom_pass = TestCustomGraphModulePass() | ||||
|         with patch_inductor_backend("cpu", custom_pass=custom_pass): | ||||
|             # details1 and details2 are both built while the uuid is "1". | ||||
|             custom_pass._uuid = "1" | ||||
|             details1 = FxGraphHashDetails(None, [], {}, []) | ||||
|             details2 = FxGraphHashDetails(None, [], {}, []) | ||||
|  | ||||
|             # details3 is built after the uuid has changed to "2". | ||||
|             custom_pass._uuid = "2" | ||||
|             details3 = FxGraphHashDetails(None, [], {}, []) | ||||
|  | ||||
|             gm = torch.fx.GraphModule({}, torch.fx.Graph()) | ||||
|             pickler = FxGraphCachePickler(gm) | ||||
|  | ||||
|             # Same uuid -> identical pickled details. | ||||
|             self.assertEqual( | ||||
|                 pickler.dumps(details1), | ||||
|                 pickler.dumps(details2), | ||||
|             ) | ||||
|             # Different uuid -> different pickled details. | ||||
|             self.assertNotEqual( | ||||
|                 pickler.dumps(details1), | ||||
|                 pickler.dumps(details3), | ||||
|             ) | ||||
|  | ||||
|     def test_bypass_unsupported(self): | ||||
|         """ | ||||
|         Test _reduce_unsupported | ||||
|  | ||||
| @ -8,12 +8,17 @@ import torch._inductor.pattern_matcher as pattern_matcher | ||||
| import torch.fx as fx | ||||
| from torch._dynamo.utils import counters | ||||
| from torch._inductor import config | ||||
| from torch._inductor.custom_graph_pass import CustomGraphPass, get_hash_for_files | ||||
| from torch._inductor.codegen.common import get_custom_backend_pass_for_device | ||||
| from torch._inductor.custom_graph_pass import ( | ||||
|     CustomGraphModulePass, | ||||
|     CustomGraphPass, | ||||
|     get_hash_for_files, | ||||
| ) | ||||
| from torch._inductor.lowering import lowerings as L | ||||
| from torch._inductor.pattern_matcher import Arg, CallFunction, PatternMatcherPass | ||||
| from torch._inductor.test_case import run_tests, TestCase | ||||
| from torch.testing._internal.common_utils import IS_LINUX | ||||
| from torch.testing._internal.inductor_utils import HAS_CPU | ||||
| from torch.testing._internal.inductor_utils import HAS_CPU, patch_inductor_backend | ||||
|  | ||||
|  | ||||
| @config.patch({"freezing": True}) | ||||
| @ -264,6 +269,35 @@ class TestPostGradCustomPrePostPass(TestCustomPassBase): | ||||
|  | ||||
|         inner_test() | ||||
|  | ||||
|     def test_custom_backend_pass(self): | ||||
|         """ | ||||
|         Register a custom backend pass for the "cpu" device and check that a | ||||
|         compiled chain of cos() calls produces the same result as an eager | ||||
|         chain of sin() calls. | ||||
|         """ | ||||
|  | ||||
|         class CustomBackendPass(CustomGraphModulePass): | ||||
|             def __init__(self, existing_pass: CustomGraphModulePass = None): | ||||
|                 super().__init__() | ||||
|                 # Pass that was registered before this one (may be None). | ||||
|                 self.existing_pass = existing_pass | ||||
|  | ||||
|             def __call__(self, gm: fx.GraphModule) -> None: | ||||
|                 # Run the previously registered pass first, if there was one. | ||||
|                 if self.existing_pass: | ||||
|                     self.existing_pass(gm) | ||||
|  | ||||
|                 # change_cos_pass is defined elsewhere in this file; presumably | ||||
|                 # it rewrites cos into sin, which is what the assertion below | ||||
|                 # relies on -- TODO confirm against its definition. | ||||
|                 change_cos_pass(gm.graph) | ||||
|  | ||||
|             def uuid(self) -> bytes: | ||||
|                 # Derive the cache identifier from this file. | ||||
|                 return get_hash_for_files((__file__,)) | ||||
|  | ||||
|         # Wrap whatever pass is currently registered for "cpu" instead of | ||||
|         # discarding it. | ||||
|         custom_backend_pass = CustomBackendPass( | ||||
|             get_custom_backend_pass_for_device("cpu") | ||||
|         ) | ||||
|         with patch_inductor_backend("cpu", custom_pass=custom_backend_pass): | ||||
|  | ||||
|             def g(x): | ||||
|                 return x.sin().sin().sin() | ||||
|  | ||||
|             def f(x): | ||||
|                 return x.cos().cos().cos() | ||||
|  | ||||
|             x = torch.randn(8, dtype=torch.float32) | ||||
|             # f is compiled (so the custom pass applies); g runs eagerly. | ||||
|             torch.testing.assert_close(torch.compile(f)(x), g(x)) | ||||
|  | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|     if IS_LINUX and HAS_CPU and torch.backends.mkldnn.is_available(): | ||||
|  | ||||
| @ -12,8 +12,6 @@ import torch | ||||
| import torch.library | ||||
| from torch._dynamo.testing import CompileCounterWithBackend, make_test_cls_with_patches | ||||
| from torch._inductor import metrics | ||||
| from torch._inductor.codegen.common import device_codegens, register_backend_for_device | ||||
| from torch._inductor.codegen.cpp import CppScheduling | ||||
| from torch._inductor.codegen.wrapper import PythonWrapperCodegen | ||||
| from torch._inductor.test_case import TestCase | ||||
| from torch._inductor.utils import run_and_get_code | ||||
| @ -34,7 +32,12 @@ from torch.testing._internal.common_utils import ( | ||||
|     TEST_WITH_ASAN, | ||||
|     TEST_WITH_ROCM, | ||||
| ) | ||||
| from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_CPU, HAS_GPU | ||||
| from torch.testing._internal.inductor_utils import ( | ||||
|     GPU_TYPE, | ||||
|     HAS_CPU, | ||||
|     HAS_GPU, | ||||
|     patch_inductor_backend, | ||||
| ) | ||||
|  | ||||
|  | ||||
| # Make the helper files in test/ importable | ||||
| @ -932,23 +935,13 @@ class TestInductorDynamic(TestCase): | ||||
|                 _test_wrapper_codegen_statically_known_int_or_none_in_context() | ||||
|                 return super().generate(is_inference, *args, **kwargs) | ||||
|  | ||||
|         if "cpu" not in device_codegens: | ||||
|             register_backend_for_device("cpu", CppScheduling, PythonWrapperCodegen) | ||||
|         orig_cpu_codegens = device_codegens["cpu"] | ||||
|         try: | ||||
|             register_backend_for_device( | ||||
|                 "cpu", orig_cpu_codegens.scheduling, TestWrapperCodegen | ||||
|             ) | ||||
|         with patch_inductor_backend("cpu", python_wrapper_codegen=TestWrapperCodegen): | ||||
|             # Compile each of the functions above, with an example input | ||||
|             # that has 5 in the first dimension, but is marked as dynamic | ||||
|  | ||||
|             torch.compile(backend="inductor", dynamic=None)(fn_1)(_x) | ||||
|             torch.compile(backend="inductor", dynamic=None)(fn_2)(_x) | ||||
|             torch.compile(backend="inductor", dynamic=None)(fn_3)(_x) | ||||
|         finally: | ||||
|             register_backend_for_device( | ||||
|                 "cpu", orig_cpu_codegens.scheduling, orig_cpu_codegens.wrapper_codegen | ||||
|             ) | ||||
|  | ||||
|     @torch._dynamo.config.patch(capture_scalar_outputs=True) | ||||
|     def test_item_unbacked_stride_nobreak(self, device): | ||||
|  | ||||
| @ -51,6 +51,7 @@ from torch import SymInt, Tensor | ||||
| from torch._dynamo.exc import SkipFrame | ||||
| from torch._dynamo.utils import CompileEventLogger, counters, dynamo_timed | ||||
| from torch._inductor import config, exc, metrics | ||||
| from torch._inductor.codegen.common import custom_backend_passes | ||||
| from torch._inductor.codegen.cuda import cuda_env | ||||
| from torch._inductor.codegen.rocm.compile_command import ( | ||||
|     rocm_compile_command, | ||||
| @ -72,7 +73,11 @@ from torch._inductor.cpp_builder import ( | ||||
|     normalize_path_separator, | ||||
| ) | ||||
| from torch._inductor.cpu_vec_isa import pick_vec_isa | ||||
| from torch._inductor.custom_graph_pass import CustomGraphPass, CustomGraphPassType | ||||
| from torch._inductor.custom_graph_pass import ( | ||||
|     CustomGraphModulePass, | ||||
|     CustomGraphPass, | ||||
|     CustomGraphPassType, | ||||
| ) | ||||
| from torch._inductor.freezing_utils import has_frozen_params, is_frozen_param | ||||
| from torch._inductor.runtime.compile_tasks import _reload_python_module | ||||
| from torch._inductor.runtime.runtime_utils import cache_dir, default_cache_dir | ||||
| @ -891,12 +896,16 @@ class FxGraphHashDetails: | ||||
|             config.post_grad_custom_post_pass | ||||
|         ) | ||||
|  | ||||
|         self.custom_backend_passes = tuple( | ||||
|             map(self._get_custom_pass_detail, custom_backend_passes.values()) | ||||
|         ) | ||||
|  | ||||
|     def _get_custom_pass_detail( | ||||
|         self, custom_pass: CustomGraphPassType | ||||
|         self, custom_pass: Union[CustomGraphPassType, CustomGraphModulePass] | ||||
|     ) -> Optional[Any]: | ||||
|         if not custom_pass: | ||||
|             return None | ||||
|         assert isinstance(custom_pass, CustomGraphPass) | ||||
|         assert isinstance(custom_pass, (CustomGraphPass, CustomGraphModulePass)) | ||||
|         return custom_pass.uuid() | ||||
|  | ||||
|  | ||||
|  | ||||
| @ -65,6 +65,7 @@ if TYPE_CHECKING: | ||||
|  | ||||
|     from torch.fx import GraphModule | ||||
|  | ||||
|     from ..custom_graph_pass import CustomGraphModulePass | ||||
|     from ..ir import Buffer, ChoiceCaller, FixedLayout, IRNode | ||||
|     from ..loop_body import LoopBody | ||||
|     from ..scheduler import BaseScheduling, Scheduler, SchedulerNode | ||||
| @ -351,6 +352,7 @@ class DeviceOpOverrides: | ||||
|  | ||||
|  | ||||
| device_op_overrides_dict: dict[str, DeviceOpOverrides] = {} | ||||
| custom_backend_passes: dict[str, Optional[CustomGraphModulePass]] = {} | ||||
|  | ||||
|  | ||||
| # The code generated by Inductor consists of two main parts: kernel code and wrapper code. | ||||
| @ -379,10 +381,12 @@ def register_backend_for_device( | ||||
|     device_scheduling: SchedulingConstructor, | ||||
|     device_wrapper_codegen: WrapperConstructor, | ||||
|     device_cpp_wrapper_codegen: Optional[WrapperConstructor] = None, | ||||
|     device_custom_pass: Optional[CustomGraphModulePass] = None, | ||||
| ) -> None: | ||||
|     device_codegens[device] = DeviceCodegen( | ||||
|         device_scheduling, device_wrapper_codegen, device_cpp_wrapper_codegen | ||||
|     ) | ||||
|     custom_backend_passes[device] = device_custom_pass | ||||
|  | ||||
|  | ||||
| class BackendFeature(Enum): | ||||
| @ -441,6 +445,10 @@ def get_wrapper_codegen_for_device( | ||||
|     return None | ||||
|  | ||||
|  | ||||
| def get_custom_backend_pass_for_device(device: str) -> Optional[CustomGraphModulePass]: | ||||
|     return custom_backend_passes[device] if device in custom_backend_passes else None | ||||
|  | ||||
|  | ||||
| @functools.lru_cache(None) | ||||
| def init_backend_registration() -> None: | ||||
|     from .cpp import CppScheduling | ||||
|  | ||||
| @ -76,6 +76,7 @@ from torch._inductor.utils import ( | ||||
|     BoxedBool, | ||||
|     count_tangents, | ||||
|     fresh_inductor_cache, | ||||
|     get_all_devices, | ||||
|     InputType, | ||||
|     is_gpu, | ||||
|     should_assume_input_aligned, | ||||
| @ -1901,22 +1902,6 @@ def get_cpp_wrapper_config() -> dict[str, object]: | ||||
|     } | ||||
|  | ||||
|  | ||||
| def get_all_devices(gm: torch.fx.GraphModule) -> OrderedSet[torch.device]: | ||||
|     placeholder_nodes = gm.graph.find_nodes(op="placeholder") | ||||
|     input_devices: OrderedSet[torch.device] = OrderedSet( | ||||
|         node.meta["val"].device | ||||
|         for node in placeholder_nodes | ||||
|         if isinstance(node.meta.get("val"), torch.Tensor) | ||||
|     ) | ||||
|  | ||||
|     out_devices: OrderedSet[torch.device] = OrderedSet( | ||||
|         arg.meta["val"].device | ||||
|         for arg in output_node(gm).args[0]  # type: ignore[union-attr] | ||||
|         if isinstance(arg, fx.Node) and isinstance(arg.meta.get("val"), torch.Tensor) | ||||
|     ) | ||||
|     return input_devices | out_devices | ||||
|  | ||||
|  | ||||
| def get_cuda_device_context(gm: torch.fx.GraphModule) -> AbstractContextManager[None]: | ||||
|     """ | ||||
|     Returns a cuda device context manager if there is a single device in the graph | ||||
|  | ||||
| @ -53,6 +53,38 @@ class CustomGraphPass(ABC): | ||||
|         """ | ||||
|  | ||||
|  | ||||
| class CustomGraphModulePass(ABC): | ||||
|     """ | ||||
|     Implement this interface for custom passes that operate on a GraphModule: | ||||
|  | ||||
|     1) The __call__() method contains the implementation of the custom pass. It | ||||
|     receives the torch.fx.GraphModule and returns None. | ||||
|  | ||||
|     2) The uuid() method enables inductor to cache compiled graphs when your custom | ||||
|     passes are applied. This method can return any identifier as long as it uniquely | ||||
|     identifies your implementation (and can be pickled). The caching logic includes this | ||||
|     identifier in its key calculation, i.e., any new value will effectively invalidate | ||||
|     existing entries. We expect custom passes would typically depend purely on the | ||||
|     textual representation of the implementation. In that case, we recommend using the | ||||
|     'get_hash_for_files' helper below to compute a unique hash from the contents of a | ||||
|     static list of source files, i.e., the source(s) containing the custom pass | ||||
|     implementation. That approach ensures that any change to the implementation will | ||||
|     mean a new uuid. | ||||
|     """ | ||||
|  | ||||
|     @abstractmethod | ||||
|     def __call__(self, gm: torch.fx.GraphModule) -> None: | ||||
|         """ | ||||
|         Implementation of the custom pass. | ||||
|         """ | ||||
|  | ||||
|     @abstractmethod | ||||
|     def uuid(self) -> Optional[Any]: | ||||
|         """ | ||||
|         Return an ID to uniquely identify your custom pass implementation. Return None | ||||
|         to skip inductor code caching entirely. | ||||
|         """ | ||||
|  | ||||
|  | ||||
| CustomGraphPassType: TypeAlias = Optional[ | ||||
|     Union[CustomGraphPass, Callable[[torch.fx.graph.Graph], None]] | ||||
| ] | ||||
|  | ||||
| @ -22,6 +22,7 @@ from torch.fx.experimental.symbolic_shapes import statically_known_true, sym_eq | ||||
| from torch.utils._ordered_set import OrderedSet | ||||
|  | ||||
| from .. import config, ir, pattern_matcher | ||||
| from ..codegen.common import custom_backend_passes | ||||
| from ..comms import remove_fsdp2_unsharded_param_graph_input_usage | ||||
| from ..fx_utils import FakeTensorUpdater, get_fake_args_kwargs, get_node_storage | ||||
| from ..lowering import lowerings as L | ||||
| @ -48,6 +49,7 @@ from ..pattern_matcher import ( | ||||
| ) | ||||
| from ..utils import ( | ||||
|     decode_device, | ||||
|     get_all_devices, | ||||
|     get_gpu_type, | ||||
|     is_gpu, | ||||
|     is_pointwise_use, | ||||
| @ -182,6 +184,13 @@ def post_grad_passes(gm: torch.fx.GraphModule, is_inference: bool): | ||||
|  | ||||
|     fake_tensor_updater.incremental_update() | ||||
|  | ||||
|     for device, custom_backend_pass in custom_backend_passes.items(): | ||||
|         if custom_backend_pass is not None: | ||||
|             gm_devices = [d.type for d in get_all_devices(gm)] | ||||
|             if device in gm_devices: | ||||
|                 pass_name = "custom_backend_passes_" + device | ||||
|                 GraphTransformObserver(gm, pass_name).apply_gm_pass(custom_backend_pass) | ||||
|  | ||||
|     # Keep these last, since they introduces mutation. Look at | ||||
|     # ./fx_passes/README.md for a discussion of mutation invariants. | ||||
|     GraphTransformObserver(gm, "reinplace_inplaceable_ops").apply_graph_pass( | ||||
|  | ||||
| @ -994,6 +994,25 @@ def output_node(gm: torch.fx.GraphModule) -> Node: | ||||
|     return last_node | ||||
|  | ||||
|  | ||||
| def get_all_devices(gm: torch.fx.GraphModule) -> OrderedSet[torch.device]: | ||||
|     placeholder_nodes = gm.graph.find_nodes(op="placeholder") | ||||
|     input_devices: OrderedSet[torch.device] = OrderedSet( | ||||
|         node.meta["val"].device | ||||
|         for node in placeholder_nodes | ||||
|         if isinstance(node.meta.get("val"), torch.Tensor) | ||||
|     ) | ||||
|  | ||||
|     out_arg = output_node(gm).args[0]  # type: ignore[union-attr] | ||||
|     out_args = out_arg if isinstance(out_arg, tuple) else (out_arg,) | ||||
|     out_devices: OrderedSet[torch.device] = OrderedSet( | ||||
|         arg.meta["val"].device | ||||
|         for arg in out_args | ||||
|         if isinstance(arg, torch.fx.Node) | ||||
|         and isinstance(arg.meta.get("val"), torch.Tensor) | ||||
|     ) | ||||
|     return input_devices | out_devices | ||||
|  | ||||
|  | ||||
| _registered_caches: list[Any] = [] | ||||
|  | ||||
|  | ||||
|  | ||||
| @ -14,6 +14,15 @@ from torch.fx.experimental.proxy_tensor import make_fx | ||||
| from torch._inductor.graph import GraphLowering | ||||
| from torch._inductor.compile_fx import shape_env_from_inputs | ||||
| from torch._inductor.codecache import CppCodeCache | ||||
| from torch._inductor.custom_graph_pass import CustomGraphModulePass | ||||
| from torch._inductor.codegen.common import ( | ||||
|     get_custom_backend_pass_for_device, | ||||
|     get_scheduling_for_device, | ||||
|     get_wrapper_codegen_for_device, | ||||
|     init_backend_registration, | ||||
|     register_backend_for_device | ||||
| ) | ||||
| from torch._inductor.codegen.wrapper import PythonWrapperCodegen | ||||
| from torch._inductor.utils import get_gpu_shared_memory, is_big_gpu | ||||
| from torch._inductor.utils import GPU_TYPES, get_gpu_type, is_gpu | ||||
| from torch.utils._triton import has_triton | ||||
| @ -290,3 +299,41 @@ def _quantize_rowwise(x: Tensor, float8_dtype: torch.dtype): | ||||
|     x_fp8 = _to_fp8_saturated(x * scale, float8_dtype) | ||||
|     inverse_scale = scale.reciprocal() | ||||
|     return x_fp8, inverse_scale | ||||
|  | ||||
| @contextlib.contextmanager | ||||
| def patch_inductor_backend( | ||||
|     device: str, | ||||
|     python_wrapper_codegen: PythonWrapperCodegen = None, | ||||
|     custom_pass: CustomGraphModulePass = None | ||||
| ): | ||||
|     """ | ||||
|     Patch the inductor backend for a specific device. | ||||
|  | ||||
|     Context manager that re-registers the backend for ``device`` with the | ||||
|     given overrides and restores the original registration on exit, even if | ||||
|     the body raises. | ||||
|  | ||||
|     Args: | ||||
|         device: device name whose backend registration is patched. | ||||
|         python_wrapper_codegen: replacement Python wrapper codegen; None keeps | ||||
|             the currently registered one. | ||||
|         custom_pass: replacement custom backend pass; None keeps the | ||||
|             currently registered one. | ||||
|  | ||||
|     The scheduling and the C++ wrapper codegen are never changed. | ||||
|     """ | ||||
|     # Make sure the backend is already registered | ||||
|     init_backend_registration() | ||||
|  | ||||
|     # Get the original registration parameters | ||||
|     original_scheduling = get_scheduling_for_device(device) | ||||
|     original_python_wrapper = get_wrapper_codegen_for_device(device, False) | ||||
|     original_cpp_wrapper = get_wrapper_codegen_for_device(device, True) | ||||
|     original_custom_pass = get_custom_backend_pass_for_device(device) | ||||
|  | ||||
|     try: | ||||
|         # Register modified backend for the device | ||||
|         register_backend_for_device( | ||||
|             device, | ||||
|             original_scheduling, | ||||
|             python_wrapper_codegen if python_wrapper_codegen is not None else original_python_wrapper, | ||||
|             original_cpp_wrapper, | ||||
|             custom_pass if custom_pass is not None else original_custom_pass | ||||
|         ) | ||||
|         yield | ||||
|     finally: | ||||
|         # Restore the original backend | ||||
|         register_backend_for_device( | ||||
|             device, | ||||
|             original_scheduling, | ||||
|             original_python_wrapper, | ||||
|             original_cpp_wrapper, | ||||
|             original_custom_pass | ||||
|         ) | ||||
|  | ||||
		Reference in New Issue
	
	Block a user