[Inductor] Allow customize decompositions for fwd_only trace function (#131329)

Summary:

Inductor will aggressively try to decompose and lower ops into a smaller opset. However, sometimes it may not align with kernel coverage (or perf preference) on different backends. (eg. Inductor will decompose Gelu into primitive ops, but certain backends already has a Gelu op) Therefore, we need a mechanism to allow customization of decomp for trace function so that Inductor will simply pass this op through.

Test Plan:

Reviewers:
@eellison
Subscribers:

Tasks:

Tags:

Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/131329
Approved by: https://github.com/eellison
This commit is contained in:
Xinran / Allan Rui
2024-07-23 13:10:48 +00:00
committed by PyTorch MergeBot
parent eb146b10db
commit 027f35d9e6

View File

@ -1802,12 +1802,19 @@ def fx_to_pattern(
@torch.no_grad()
def fwd_only(
fn: Callable[..., Any], args: Sequence[Any], *, run_dce: bool = True
fn: Callable[..., Any],
args: Sequence[Any],
*,
run_dce: bool = True,
get_decomp_fn: Optional[Callable[..., Any]] = None,
) -> torch.fx.GraphModule:
"""Build a normalized inference graph, for use with fx_to_pattern"""
# TODO - look into using aot autograd, asserting no mutating ops here
with enable_python_dispatcher():
gm = make_fx(fn, select_decomp_table(), tracing_mode="real")(*args)
decompositions = (
get_decomp_fn() if get_decomp_fn is not None else select_decomp_table()
)
gm = make_fx(fn, decompositions, tracing_mode="real")(*args)
from .fx_passes.post_grad import remove_noop_ops