[dynamo] Allow guards to be dropped with custom filter functions. (#150936)

Summary: A follow up of https://github.com/pytorch/pytorch/pull/150689.

Test Plan: test_dynamo -k test_guard_filter_fn

Differential Revision: D72722322

Pull Request resolved: https://github.com/pytorch/pytorch/pull/150936
Approved by: https://github.com/jansel
This commit is contained in:
Zhengxu Chen
2025-04-11 03:06:34 +00:00
committed by PyTorch MergeBot
parent 4b0cf9fc00
commit 86370fd658
8 changed files with 202 additions and 60 deletions

View File

@@ -2424,7 +2424,9 @@ def compile(
     dynamic: _Optional[builtins.bool] = None,
     backend: _Union[str, _Callable] = "inductor",
     mode: _Union[str, None] = None,
-    options: _Optional[dict[str, _Union[str, builtins.int, builtins.bool]]] = None,
+    options: _Optional[
+        dict[str, _Union[str, builtins.int, builtins.bool, _Callable]]
+    ] = None,
     disable: builtins.bool = False,
 ) -> _Callable[_InputT, _RetT]: ...
@@ -2437,7 +2439,9 @@ def compile(
     dynamic: _Optional[builtins.bool] = None,
     backend: _Union[str, _Callable] = "inductor",
     mode: _Union[str, None] = None,
-    options: _Optional[dict[str, _Union[str, builtins.int, builtins.bool]]] = None,
+    options: _Optional[
+        dict[str, _Union[str, builtins.int, builtins.bool, _Callable]]
+    ] = None,
     disable: builtins.bool = False,
 ) -> _Callable[[_Callable[_InputT, _RetT]], _Callable[_InputT, _RetT]]: ...
@@ -2449,7 +2453,9 @@ def compile(
     dynamic: _Optional[builtins.bool] = None,
     backend: _Union[str, _Callable] = "inductor",
     mode: _Union[str, None] = None,
-    options: _Optional[dict[str, _Union[str, builtins.int, builtins.bool]]] = None,
+    options: _Optional[
+        dict[str, _Union[str, builtins.int, builtins.bool, _Callable]]
+    ] = None,
     disable: builtins.bool = False,
 ) -> _Union[
     _Callable[[_Callable[_InputT, _RetT]], _Callable[_InputT, _RetT]],
@@ -2585,6 +2591,10 @@ def compile(
     if bisect_backend := CompilerBisector.get_backend():
         backend = bisect_backend
+    guard_filter_fn = None
+    if options and isinstance(options, dict):
+        guard_filter_fn = options.pop("guard_filter_fn", None)
+
     if backend == "inductor":
         backend = _TorchCompileInductorWrapper(mode, options, dynamic)
     else:
@@ -2595,6 +2605,7 @@ def compile(
         nopython=fullgraph,
         dynamic=dynamic,
         disable=disable,
+        guard_filter_fn=guard_filter_fn,
     )(model)  # type: ignore[return-value]