mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
[dynamo] Allow guards to be dropped with custom filter functions. (#150936)
Summary: A follow up of https://github.com/pytorch/pytorch/pull/150689. Test Plan: test_dynamo -k test_guard_filter_fn Differential Revision: D72722322 Pull Request resolved: https://github.com/pytorch/pytorch/pull/150936 Approved by: https://github.com/jansel
This commit is contained in:
committed by
PyTorch MergeBot
parent
4b0cf9fc00
commit
86370fd658
@ -2424,7 +2424,9 @@ def compile(
|
||||
dynamic: _Optional[builtins.bool] = None,
|
||||
backend: _Union[str, _Callable] = "inductor",
|
||||
mode: _Union[str, None] = None,
|
||||
options: _Optional[dict[str, _Union[str, builtins.int, builtins.bool]]] = None,
|
||||
options: _Optional[
|
||||
dict[str, _Union[str, builtins.int, builtins.bool, _Callable]]
|
||||
] = None,
|
||||
disable: builtins.bool = False,
|
||||
) -> _Callable[_InputT, _RetT]: ...
|
||||
|
||||
@ -2437,7 +2439,9 @@ def compile(
|
||||
dynamic: _Optional[builtins.bool] = None,
|
||||
backend: _Union[str, _Callable] = "inductor",
|
||||
mode: _Union[str, None] = None,
|
||||
options: _Optional[dict[str, _Union[str, builtins.int, builtins.bool]]] = None,
|
||||
options: _Optional[
|
||||
dict[str, _Union[str, builtins.int, builtins.bool, _Callable]]
|
||||
] = None,
|
||||
disable: builtins.bool = False,
|
||||
) -> _Callable[[_Callable[_InputT, _RetT]], _Callable[_InputT, _RetT]]: ...
|
||||
|
||||
@ -2449,7 +2453,9 @@ def compile(
|
||||
dynamic: _Optional[builtins.bool] = None,
|
||||
backend: _Union[str, _Callable] = "inductor",
|
||||
mode: _Union[str, None] = None,
|
||||
options: _Optional[dict[str, _Union[str, builtins.int, builtins.bool]]] = None,
|
||||
options: _Optional[
|
||||
dict[str, _Union[str, builtins.int, builtins.bool, _Callable]]
|
||||
] = None,
|
||||
disable: builtins.bool = False,
|
||||
) -> _Union[
|
||||
_Callable[[_Callable[_InputT, _RetT]], _Callable[_InputT, _RetT]],
|
||||
@ -2585,6 +2591,10 @@ def compile(
|
||||
if bisect_backend := CompilerBisector.get_backend():
|
||||
backend = bisect_backend
|
||||
|
||||
guard_filter_fn = None
|
||||
if options and isinstance(options, dict):
|
||||
guard_filter_fn = options.pop("guard_filter_fn", None)
|
||||
|
||||
if backend == "inductor":
|
||||
backend = _TorchCompileInductorWrapper(mode, options, dynamic)
|
||||
else:
|
||||
@ -2595,6 +2605,7 @@ def compile(
|
||||
nopython=fullgraph,
|
||||
dynamic=dynamic,
|
||||
disable=disable,
|
||||
guard_filter_fn=guard_filter_fn,
|
||||
)(model) # type: ignore[return-value]
|
||||
|
||||
|
||||
|
Reference in New Issue
Block a user