Add option to skip functional passes in the pattern matcher's replacement graph (#134364)

The pattern matcher runs DCE and remove_noop_ops on the replacement
graph by default. Previously we had a switch for the DCE. This PR
changes that switch to also control whether we run remove_noop_ops.

The context was that there was silent incorrectness with
auto_functionalized. We use the pattern matcher to decompose
auto_functionalized into a mutable op + clones; remove_noop_ops was
deleting the clones.

Future: can try #134363

Test Plan:
- new test. I wasn't able to produce a silently incorrect example so I
  settled for asserting that clones still exist in the post-grad graph.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/134364
Approved by: https://github.com/eellison
ghstack dependencies: #133639
This commit is contained in:
rzou
2024-08-23 12:49:17 -07:00
committed by PyTorch MergeBot
parent 2c8fc3f4ce
commit a1d0b4d568
3 changed files with 47 additions and 7 deletions

View File

@ -226,15 +226,25 @@ class Match:
replacement_fn: ReplaceFn,
args: Sequence[Any],
trace_fn: Optional[TraceFn] = None,
run_dce: bool = True,
run_functional_passes: bool = True,
) -> None:
"""Replace with a graph generated by tracing the replacement_fn.
Args:
run_functional_passes (bool). If we should run passes that
assume functional IR (like DCE, remove_noop_ops), on the
replacement graph.
"""
from torch._inductor.virtualized import V
context = V.fake_mode if V.fake_mode is not None else contextlib.nullcontext
with context:
if trace_fn is None:
trace_fn = functools.partial(fwd_only, run_dce=run_dce)
trace_fn = functools.partial(
fwd_only, run_functional_passes=run_functional_passes
)
replacement = trace_fn(
replacement_fn, torch.fx.map_arg(args, lambda arg: arg.meta["val"]) # type: ignore[arg-type]
)
@ -1810,7 +1820,7 @@ def fwd_only(
fn: Callable[..., Any],
args: Sequence[Any],
*,
run_dce: bool = True,
run_functional_passes: bool = True,
get_decomp_fn: Optional[Callable[..., Any]] = None,
) -> torch.fx.GraphModule:
"""Build a normalized inference graph, for use with fx_to_pattern"""
@ -1823,10 +1833,10 @@ def fwd_only(
from .fx_passes.post_grad import remove_noop_ops
remove_noop_ops(gm.graph)
if run_dce:
if run_functional_passes:
remove_noop_ops(gm.graph)
gm.graph.eliminate_dead_code()
gm.recompile()
return gm