mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Custom FX pass for inductor's backend registration (#154841)
This PR is related to RFC #153532. It extends Inductor's backend registration interface so that a backend can also register a custom FX pass.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/154841
Approved by: https://github.com/jansel
Co-authored-by: Jason Ansel <jansel@jansel.net>
committed by PyTorch MergeBot
parent c6b4f98625
commit e694280d12
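In short, register_backend_for_device() gains a device_custom_pass argument; the registered pass runs during post-grad lowering whenever the compiled graph touches that device, and its uuid() feeds the FX graph cache key (see the hunks below). A minimal sketch of the new hook (the pass class is hypothetical, and CppScheduling/PythonWrapperCodegen merely stand in for a real backend's scheduling and wrapper codegen):

    import torch
    from torch._inductor.codegen.common import register_backend_for_device
    from torch._inductor.codegen.cpp import CppScheduling
    from torch._inductor.codegen.wrapper import PythonWrapperCodegen
    from torch._inductor.custom_graph_pass import (
        CustomGraphModulePass,
        get_hash_for_files,
    )


    class MyBackendPass(CustomGraphModulePass):
        """Hypothetical backend pass; mutates the GraphModule in place."""

        def __call__(self, gm: torch.fx.GraphModule) -> None:
            # A real pass would rewrite or annotate nodes in gm.graph here.
            pass

        def uuid(self):
            # Key caches on this file's contents, so edits invalidate them.
            return get_hash_for_files((__file__,))


    register_backend_for_device(
        "cpu", CppScheduling, PythonWrapperCodegen, device_custom_pass=MyBackendPass()
    )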
@@ -27,7 +27,11 @@ from torch._inductor.codecache import (
     TensorMetadata,
     TensorMetadataAndValues,
 )
-from torch._inductor.custom_graph_pass import CustomGraphPass, get_hash_for_files
+from torch._inductor.custom_graph_pass import (
+    CustomGraphModulePass,
+    CustomGraphPass,
+    get_hash_for_files,
+)
 from torch._inductor.graph import GraphLowering
 from torch._inductor.mock_cache import global_stats, PatchCaches, Stats
 from torch._inductor.runtime.runtime_utils import cache_dir

@@ -53,6 +57,7 @@ from torch.testing._internal.inductor_utils import (
     HAS_GPU,
     HAS_MULTIGPU,
     HAS_TRITON,
+    patch_inductor_backend,
     requires_gpu,
     requires_triton,
 )

@@ -2183,6 +2188,42 @@ class TestFxGraphCacheHashing(TestCase):
             pickler.dumps(details3),
         )
 
+    def test_hash_custom_backend_pass(self):
+        """
+        Test CustomGraphModulePass usage.
+        """
+
+        class TestCustomGraphModulePass(CustomGraphModulePass):
+            def __init__(self):
+                self._uuid = None
+
+            def __call__(self, gm: torch.fx.GraphModule) -> None:
+                return None
+
+            def uuid(self) -> Optional[Union[bytes, str]]:
+                return self._uuid
+
+        custom_pass = TestCustomGraphModulePass()
+        with patch_inductor_backend("cpu", custom_pass=custom_pass):
+            custom_pass._uuid = "1"
+            details1 = FxGraphHashDetails(None, [], {}, [])
+            details2 = FxGraphHashDetails(None, [], {}, [])
+
+            custom_pass._uuid = "2"
+            details3 = FxGraphHashDetails(None, [], {}, [])
+
+            gm = torch.fx.GraphModule({}, torch.fx.Graph())
+            pickler = FxGraphCachePickler(gm)
+
+            self.assertEqual(
+                pickler.dumps(details1),
+                pickler.dumps(details2),
+            )
+            self.assertNotEqual(
+                pickler.dumps(details1),
+                pickler.dumps(details3),
+            )
+
     def test_bypass_unsupported(self):
         """
         Test _reduce_unsupported

@@ -8,12 +8,17 @@ import torch._inductor.pattern_matcher as pattern_matcher
 import torch.fx as fx
 from torch._dynamo.utils import counters
 from torch._inductor import config
-from torch._inductor.custom_graph_pass import CustomGraphPass, get_hash_for_files
+from torch._inductor.codegen.common import get_custom_backend_pass_for_device
+from torch._inductor.custom_graph_pass import (
+    CustomGraphModulePass,
+    CustomGraphPass,
+    get_hash_for_files,
+)
 from torch._inductor.lowering import lowerings as L
 from torch._inductor.pattern_matcher import Arg, CallFunction, PatternMatcherPass
 from torch._inductor.test_case import run_tests, TestCase
 from torch.testing._internal.common_utils import IS_LINUX
-from torch.testing._internal.inductor_utils import HAS_CPU
+from torch.testing._internal.inductor_utils import HAS_CPU, patch_inductor_backend
 
 
 @config.patch({"freezing": True})

@@ -264,6 +269,35 @@ class TestPostGradCustomPrePostPass(TestCustomPassBase):
 
         inner_test()
 
+    def test_custom_backend_pass(self):
+        class CustomBackendPass(CustomGraphModulePass):
+            def __init__(self, existing_pass: CustomGraphModulePass = None):
+                super().__init__()
+                self.existing_pass = existing_pass
+
+            def __call__(self, gm: fx.GraphModule) -> None:
+                if self.existing_pass:
+                    self.existing_pass(gm)
+
+                change_cos_pass(gm.graph)
+
+            def uuid(self) -> bytes:
+                return get_hash_for_files((__file__,))
+
+        custom_backend_pass = CustomBackendPass(
+            get_custom_backend_pass_for_device("cpu")
+        )
+        with patch_inductor_backend("cpu", custom_pass=custom_backend_pass):
+
+            def g(x):
+                return x.sin().sin().sin()
+
+            def f(x):
+                return x.cos().cos().cos()
+
+            x = torch.randn(8, dtype=torch.float32)
+            torch.testing.assert_close(torch.compile(f)(x), g(x))
+
 
 if __name__ == "__main__":
     if IS_LINUX and HAS_CPU and torch.backends.mkldnn.is_available():

@@ -12,8 +12,6 @@ import torch
 import torch.library
 from torch._dynamo.testing import CompileCounterWithBackend, make_test_cls_with_patches
 from torch._inductor import metrics
-from torch._inductor.codegen.common import device_codegens, register_backend_for_device
-from torch._inductor.codegen.cpp import CppScheduling
 from torch._inductor.codegen.wrapper import PythonWrapperCodegen
 from torch._inductor.test_case import TestCase
 from torch._inductor.utils import run_and_get_code

@@ -34,7 +32,12 @@ from torch.testing._internal.common_utils import (
     TEST_WITH_ASAN,
     TEST_WITH_ROCM,
 )
-from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_CPU, HAS_GPU
+from torch.testing._internal.inductor_utils import (
+    GPU_TYPE,
+    HAS_CPU,
+    HAS_GPU,
+    patch_inductor_backend,
+)
 
 
 # Make the helper files in test/ importable

@@ -932,23 +935,13 @@ class TestInductorDynamic(TestCase):
                 _test_wrapper_codegen_statically_known_int_or_none_in_context()
                 return super().generate(is_inference, *args, **kwargs)
 
-        if "cpu" not in device_codegens:
-            register_backend_for_device("cpu", CppScheduling, PythonWrapperCodegen)
-        orig_cpu_codegens = device_codegens["cpu"]
-        try:
-            register_backend_for_device(
-                "cpu", orig_cpu_codegens.scheduling, TestWrapperCodegen
-            )
+        with patch_inductor_backend("cpu", python_wrapper_codegen=TestWrapperCodegen):
             # Compile each of the functions above, with an example input
             # that has 5 in the first dimension, but is marked as dynamic
 
             torch.compile(backend="inductor", dynamic=None)(fn_1)(_x)
             torch.compile(backend="inductor", dynamic=None)(fn_2)(_x)
             torch.compile(backend="inductor", dynamic=None)(fn_3)(_x)
-        finally:
-            register_backend_for_device(
-                "cpu", orig_cpu_codegens.scheduling, orig_cpu_codegens.wrapper_codegen
-            )
 
     @torch._dynamo.config.patch(capture_scalar_outputs=True)
     def test_item_unbacked_stride_nobreak(self, device):

@@ -51,6 +51,7 @@ from torch import SymInt, Tensor
 from torch._dynamo.exc import SkipFrame
 from torch._dynamo.utils import CompileEventLogger, counters, dynamo_timed
 from torch._inductor import config, exc, metrics
+from torch._inductor.codegen.common import custom_backend_passes
 from torch._inductor.codegen.cuda import cuda_env
 from torch._inductor.codegen.rocm.compile_command import (
     rocm_compile_command,

@@ -72,7 +73,11 @@ from torch._inductor.cpp_builder import (
     normalize_path_separator,
 )
 from torch._inductor.cpu_vec_isa import pick_vec_isa
-from torch._inductor.custom_graph_pass import CustomGraphPass, CustomGraphPassType
+from torch._inductor.custom_graph_pass import (
+    CustomGraphModulePass,
+    CustomGraphPass,
+    CustomGraphPassType,
+)
 from torch._inductor.freezing_utils import has_frozen_params, is_frozen_param
 from torch._inductor.runtime.compile_tasks import _reload_python_module
 from torch._inductor.runtime.runtime_utils import cache_dir, default_cache_dir

@@ -891,12 +896,16 @@ class FxGraphHashDetails:
             config.post_grad_custom_post_pass
         )
 
+        self.custom_backend_passes = tuple(
+            map(self._get_custom_pass_detail, custom_backend_passes.values())
+        )
+
     def _get_custom_pass_detail(
-        self, custom_pass: CustomGraphPassType
+        self, custom_pass: Union[CustomGraphPassType, CustomGraphModulePass]
     ) -> Optional[Any]:
         if not custom_pass:
             return None
-        assert isinstance(custom_pass, CustomGraphPass)
+        assert isinstance(custom_pass, (CustomGraphPass, CustomGraphModulePass))
         return custom_pass.uuid()

@@ -65,6 +65,7 @@ if TYPE_CHECKING:
 
     from torch.fx import GraphModule
 
+    from ..custom_graph_pass import CustomGraphModulePass
    from ..ir import Buffer, ChoiceCaller, FixedLayout, IRNode
     from ..loop_body import LoopBody
     from ..scheduler import BaseScheduling, Scheduler, SchedulerNode

@@ -351,6 +352,7 @@ class DeviceOpOverrides:
 
 
 device_op_overrides_dict: dict[str, DeviceOpOverrides] = {}
+custom_backend_passes: dict[str, Optional[CustomGraphModulePass]] = {}
 
 
 # The code generated by Inductor consists of two main parts: kernel code and wrapper code.

@@ -379,10 +381,12 @@ def register_backend_for_device(
     device_scheduling: SchedulingConstructor,
     device_wrapper_codegen: WrapperConstructor,
     device_cpp_wrapper_codegen: Optional[WrapperConstructor] = None,
+    device_custom_pass: Optional[CustomGraphModulePass] = None,
 ) -> None:
     device_codegens[device] = DeviceCodegen(
         device_scheduling, device_wrapper_codegen, device_cpp_wrapper_codegen
     )
+    custom_backend_passes[device] = device_custom_pass
 
 
 class BackendFeature(Enum):

@@ -441,6 +445,10 @@ def get_wrapper_codegen_for_device(
     return None
 
 
+def get_custom_backend_pass_for_device(device: str) -> Optional[CustomGraphModulePass]:
+    return custom_backend_passes[device] if device in custom_backend_passes else None
+
+
 @functools.lru_cache(None)
 def init_backend_registration() -> None:
     from .cpp import CppScheduling

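Note that register_backend_for_device() overwrites custom_backend_passes[device], so a pass that must compose with whatever was registered before can look the old one up via the new accessor and delegate to it, as the pattern-matcher test above does. A hypothetical sketch of that chaining pattern:

    import torch
    from torch._inductor.codegen.common import get_custom_backend_pass_for_device
    from torch._inductor.custom_graph_pass import (
        CustomGraphModulePass,
        get_hash_for_files,
    )


    class ChainedPass(CustomGraphModulePass):
        def __init__(self, inner):
            self.inner = inner  # previously registered pass, or None

        def __call__(self, gm: torch.fx.GraphModule) -> None:
            if self.inner is not None:
                self.inner(gm)  # run the pre-existing pass first
            # ... then apply this pass's own graph rewrites

        def uuid(self):
            return get_hash_for_files((__file__,))


    chained = ChainedPass(get_custom_backend_pass_for_device("cpu"))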
@@ -76,6 +76,7 @@ from torch._inductor.utils import (
     BoxedBool,
     count_tangents,
     fresh_inductor_cache,
+    get_all_devices,
     InputType,
     is_gpu,
     should_assume_input_aligned,

@@ -1901,22 +1902,6 @@ def get_cpp_wrapper_config() -> dict[str, object]:
     }
 
 
-def get_all_devices(gm: torch.fx.GraphModule) -> OrderedSet[torch.device]:
-    placeholder_nodes = gm.graph.find_nodes(op="placeholder")
-    input_devices: OrderedSet[torch.device] = OrderedSet(
-        node.meta["val"].device
-        for node in placeholder_nodes
-        if isinstance(node.meta.get("val"), torch.Tensor)
-    )
-
-    out_devices: OrderedSet[torch.device] = OrderedSet(
-        arg.meta["val"].device
-        for arg in output_node(gm).args[0]  # type: ignore[union-attr]
-        if isinstance(arg, fx.Node) and isinstance(arg.meta.get("val"), torch.Tensor)
-    )
-    return input_devices | out_devices
-
-
 def get_cuda_device_context(gm: torch.fx.GraphModule) -> AbstractContextManager[None]:
     """
     Returns a cuda device context manager if there is a single device in the graph

@@ -53,6 +53,38 @@ class CustomGraphPass(ABC):
     """
 
 
+class CustomGraphModulePass(ABC):
+    """
+    Implement this interface for custom Graph passes:
+
+    1) The __call__() method contains the implementation of the custom pass.
+
+    2) The uuid() method enables inductor to cache compiled graphs when your custom
+    passes are applied. This method can return any identifier as long as it uniquely
+    identifies your implementation (and can be pickled). The caching logic includes this
+    identifier in its key calculation, i.e., any new value will effectively invalidate
+    existing entries. We expect custom passes to typically depend purely on the
+    textual representation of the implementation. In that case, we recommend using the
+    'get_hash_for_files' helper below to compute a unique hash from the contents of a
+    static list of source files, i.e., the source(s) containing the custom pass
+    implementation. That approach ensures that any change to the implementation will
+    mean a new uuid.
+    """
+
+    @abstractmethod
+    def __call__(self, gm: torch.fx.GraphModule) -> None:
+        """
+        Implementation of the custom pass.
+        """
+
+    @abstractmethod
+    def uuid(self) -> Optional[Any]:
+        """
+        Return an ID to uniquely identify your custom pass implementation. Return None
+        to skip inductor code caching entirely.
+        """
+
+
 CustomGraphPassType: TypeAlias = Optional[
     Union[CustomGraphPass, Callable[[torch.fx.graph.Graph], None]]
 ]

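Worth noting from the docstring above: a pass may deliberately opt out of caching by returning None from uuid(). A small hypothetical sketch of that choice:

    from typing import Any, Optional

    import torch
    from torch._inductor.custom_graph_pass import CustomGraphModulePass


    class UncachedPass(CustomGraphModulePass):
        def __call__(self, gm: torch.fx.GraphModule) -> None:
            pass  # no-op body for illustration

        def uuid(self) -> Optional[Any]:
            # None skips inductor code caching entirely, so the pass always
            # runs against a freshly compiled graph; a stable hash (e.g. via
            # get_hash_for_files) is usually the better default.
            return None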
@@ -22,6 +22,7 @@ from torch.fx.experimental.symbolic_shapes import statically_known_true, sym_eq
 from torch.utils._ordered_set import OrderedSet
 
 from .. import config, ir, pattern_matcher
+from ..codegen.common import custom_backend_passes
 from ..comms import remove_fsdp2_unsharded_param_graph_input_usage
 from ..fx_utils import FakeTensorUpdater, get_fake_args_kwargs, get_node_storage
 from ..lowering import lowerings as L

@@ -48,6 +49,7 @@ from ..pattern_matcher import (
 )
 from ..utils import (
     decode_device,
+    get_all_devices,
     get_gpu_type,
     is_gpu,
     is_pointwise_use,

@@ -182,6 +184,13 @@ def post_grad_passes(gm: torch.fx.GraphModule, is_inference: bool):
 
     fake_tensor_updater.incremental_update()
 
+    for device, custom_backend_pass in custom_backend_passes.items():
+        if custom_backend_pass is not None:
+            gm_devices = [d.type for d in get_all_devices(gm)]
+            if device in gm_devices:
+                pass_name = "custom_backend_passes_" + device
+                GraphTransformObserver(gm, pass_name).apply_gm_pass(custom_backend_pass)
+
     # Keep these last, since they introduce mutation. Look at
     # ./fx_passes/README.md for a discussion of mutation invariants.
     GraphTransformObserver(gm, "reinplace_inplaceable_ops").apply_graph_pass(

@@ -994,6 +994,25 @@ def output_node(gm: torch.fx.GraphModule) -> Node:
     return last_node
 
 
+def get_all_devices(gm: torch.fx.GraphModule) -> OrderedSet[torch.device]:
+    placeholder_nodes = gm.graph.find_nodes(op="placeholder")
+    input_devices: OrderedSet[torch.device] = OrderedSet(
+        node.meta["val"].device
+        for node in placeholder_nodes
+        if isinstance(node.meta.get("val"), torch.Tensor)
+    )
+
+    out_arg = output_node(gm).args[0]  # type: ignore[union-attr]
+    out_args = out_arg if isinstance(out_arg, tuple) else (out_arg,)
+    out_devices: OrderedSet[torch.device] = OrderedSet(
+        arg.meta["val"].device
+        for arg in out_args
+        if isinstance(arg, torch.fx.Node)
+        and isinstance(arg.meta.get("val"), torch.Tensor)
+    )
+    return input_devices | out_devices
+
+
 _registered_caches: list[Any] = []

@@ -14,6 +14,15 @@ from torch.fx.experimental.proxy_tensor import make_fx
 from torch._inductor.graph import GraphLowering
 from torch._inductor.compile_fx import shape_env_from_inputs
 from torch._inductor.codecache import CppCodeCache
+from torch._inductor.custom_graph_pass import CustomGraphModulePass
+from torch._inductor.codegen.common import (
+    get_custom_backend_pass_for_device,
+    get_scheduling_for_device,
+    get_wrapper_codegen_for_device,
+    init_backend_registration,
+    register_backend_for_device
+)
+from torch._inductor.codegen.wrapper import PythonWrapperCodegen
 from torch._inductor.utils import get_gpu_shared_memory, is_big_gpu
 from torch._inductor.utils import GPU_TYPES, get_gpu_type, is_gpu
 from torch.utils._triton import has_triton

@@ -290,3 +299,41 @@ def _quantize_rowwise(x: Tensor, float8_dtype: torch.dtype):
     x_fp8 = _to_fp8_saturated(x * scale, float8_dtype)
     inverse_scale = scale.reciprocal()
     return x_fp8, inverse_scale
+
+@contextlib.contextmanager
+def patch_inductor_backend(
+    device: str,
+    python_wrapper_codegen: PythonWrapperCodegen = None,
+    custom_pass: CustomGraphModulePass = None
+):
+    """
+    Patch the inductor backend for a specific device.
+    """
+    # Make sure the backend is already registered
+    init_backend_registration()
+
+    # Get the original registration parameters
+    original_scheduling = get_scheduling_for_device(device)
+    original_python_wrapper = get_wrapper_codegen_for_device(device, False)
+    original_cpp_wrapper = get_wrapper_codegen_for_device(device, True)
+    original_custom_pass = get_custom_backend_pass_for_device(device)
+
+    try:
+        # Register modified backend for the device
+        register_backend_for_device(
+            device,
+            original_scheduling,
+            python_wrapper_codegen if python_wrapper_codegen is not None else original_python_wrapper,
+            original_cpp_wrapper,
+            custom_pass if custom_pass is not None else original_custom_pass
+        )
+        yield
+    finally:
+        # Restore the original backend
+        register_backend_for_device(
+            device,
+            original_scheduling,
+            original_python_wrapper,
+            original_cpp_wrapper,
+            original_custom_pass
+        )

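Finally, a usage sketch for the new patch_inductor_backend helper, along the lines of the tests above (the no-op pass is illustrative):

    import torch
    from torch._inductor.custom_graph_pass import (
        CustomGraphModulePass,
        get_hash_for_files,
    )
    from torch.testing._internal.inductor_utils import patch_inductor_backend


    class NoopPass(CustomGraphModulePass):
        def __call__(self, gm: torch.fx.GraphModule) -> None:
            pass  # a real pass would rewrite gm.graph here

        def uuid(self):
            return get_hash_for_files((__file__,))


    def f(x):
        return x.cos().cos()


    # Temporarily attach the pass to the cpu backend; the original registration
    # (scheduling, wrappers, any prior custom pass) is restored on exit.
    with patch_inductor_backend("cpu", custom_pass=NoopPass()):
        y = torch.compile(f)(torch.randn(8))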