[Inductor] Allow customizing decompositions for the fwd_only trace function (#131329)
Summary: Inductor aggressively tries to decompose and lower ops into a smaller opset. However, this may not align with kernel coverage (or performance preferences) on different backends: for example, Inductor will decompose Gelu into primitive ops, but certain backends already have a fused Gelu kernel. We therefore need a mechanism for customizing the decompositions used by the trace function, so that Inductor simply passes such an op through.

Test Plan:

Reviewers: @eellison

Subscribers:

Tasks:

Tags:

Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/131329
Approved by: https://github.com/eellison
committed by PyTorch MergeBot
parent eb146b10db
commit 027f35d9e6
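To make the Summary concrete: a backend that already ships a fused Gelu kernel could build its own decomposition table that leaves aten.gelu intact and hand it to the trace function via the new keyword shown in the diff below. This sketch is illustrative rather than taken from the PR; no_gelu_decomps is a hypothetical helper, and since whether Gelu appears in Inductor's default table can vary by version, the pop() is defensive.

    import torch
    from torch._inductor.decomposition import select_decomp_table

    def no_gelu_decomps():
        # Hypothetical helper: start from Inductor's default decomposition
        # table, then drop the entry that would lower aten.gelu into
        # primitive ops, so the op passes through to the backend unchanged.
        decomps = dict(select_decomp_table())
        decomps.pop(torch.ops.aten.gelu.default, None)
        return decomps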
@@ -1802,12 +1802,19 @@ def fx_to_pattern(
 @torch.no_grad()
 def fwd_only(
-    fn: Callable[..., Any], args: Sequence[Any], *, run_dce: bool = True
+    fn: Callable[..., Any],
+    args: Sequence[Any],
+    *,
+    run_dce: bool = True,
+    get_decomp_fn: Optional[Callable[..., Any]] = None,
 ) -> torch.fx.GraphModule:
     """Build a normalized inference graph, for use with fx_to_pattern"""
     # TODO - look into using aot autograd, asserting no mutating ops here
     with enable_python_dispatcher():
-        gm = make_fx(fn, select_decomp_table(), tracing_mode="real")(*args)
+        decompositions = (
+            get_decomp_fn() if get_decomp_fn is not None else select_decomp_table()
+        )
+        gm = make_fx(fn, decompositions, tracing_mode="real")(*args)
 
     from .fx_passes.post_grad import remove_noop_ops
 
     remove_noop_ops(gm.graph)
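With the new get_decomp_fn keyword, tracing a pattern while keeping Gelu opaque would then look roughly like this (continuing the hypothetical sketch above; fwd_only is the internal torch._inductor.pattern_matcher helper modified by this diff, so this is a usage sketch, not a stable public API):

    from torch._inductor.pattern_matcher import fwd_only

    def f(x):
        return torch.nn.functional.gelu(x) + 1

    # get_decomp_fn is called with no arguments at trace time, so passing
    # select_decomp_table itself would reproduce the default behavior,
    # while no_gelu_decomps keeps aten.gelu in the traced graph.
    gm = fwd_only(f, (torch.randn(8),), get_decomp_fn=no_gelu_decomps)
    print(gm.graph)

Note that the keyword takes a callable that builds the table, not the table itself, mirroring how the default path calls select_decomp_table() lazily inside the dispatcher context.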