mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Add option to skip functional passes in the pattern matcher's replacement graph (#134364)
Add option to skip functional passes in the pattern matcher's replacement graph (#134364)

The pattern matcher runs DCE and remove_noop_ops on the replacement graph by default. Previously we had a switch for the DCE; this PR changes that switch to also control whether we run remove_noop_ops.

The context is that there was silent incorrectness with auto_functionalized: we use the pattern matcher to decompose auto_functionalized into a mutable op + clones, and remove_noop_ops was deleting the clones.

Future: can try #134363

Test Plan:
- New test. I wasn't able to produce a silently incorrect example, so I settled for asserting that clones still exist in the post-grad graph.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/134364
Approved by: https://github.com/eellison
ghstack dependencies: #133639
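For intuition, a minimal eager-mode sketch of why those clones are not removable no-ops (the tensors and the in-place add_ below are illustrative stand-ins, not the actual decomposition):

    import torch

    # Stand-in for the clone the pattern matcher inserts when it
    # decomposes auto_functionalized into a mutable op + clones.
    x = torch.zeros(3)
    y = x.clone()   # in functional IR this looks like a removable no-op
    y.add_(1)       # the mutable op writes into the clone, not into x
    # If a pass deleted the clone, y would alias x, and the mutation
    # would silently corrupt x for its other readers.
    assert torch.equal(x, torch.zeros(3))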
torch/_inductor/pattern_matcher.py

@@ -226,15 +226,25 @@ class Match:
         replacement_fn: ReplaceFn,
         args: Sequence[Any],
         trace_fn: Optional[TraceFn] = None,
-        run_dce: bool = True,
+        run_functional_passes: bool = True,
     ) -> None:
+        """Replace with a graph generated by tracing the replacement_fn.
+
+        Args:
+            run_functional_passes (bool). If we should run passes that
+                assume functional IR (like DCE, remove_noop_ops), on the
+                replacement graph.
+
+        """
         from torch._inductor.virtualized import V
 
         context = V.fake_mode if V.fake_mode is not None else contextlib.nullcontext
 
         with context:
             if trace_fn is None:
-                trace_fn = functools.partial(fwd_only, run_dce=run_dce)
+                trace_fn = functools.partial(
+                    fwd_only, run_functional_passes=run_functional_passes
+                )
             replacement = trace_fn(
                 replacement_fn, torch.fx.map_arg(args, lambda arg: arg.meta["val"])  # type: ignore[arg-type]
             )
@@ -1810,7 +1820,7 @@ def fwd_only(
     fn: Callable[..., Any],
     args: Sequence[Any],
     *,
-    run_dce: bool = True,
+    run_functional_passes: bool = True,
     get_decomp_fn: Optional[Callable[..., Any]] = None,
 ) -> torch.fx.GraphModule:
     """Build a normalized inference graph, for use with fx_to_pattern"""
@@ -1823,10 +1833,10 @@ def fwd_only(
         from .fx_passes.post_grad import remove_noop_ops
 
-        remove_noop_ops(gm.graph)
-
-        if run_dce:
+        if run_functional_passes:
+            remove_noop_ops(gm.graph)
             gm.graph.eliminate_dead_code()
 
     gm.recompile()
     return gm
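After this change, a caller that replaces a match with non-functional IR can opt out of both passes in one place. A hedged usage sketch, assuming a Match object for an auto_functionalized node (the names m, decomp_fn, and flat_args are hypothetical; the method is the one patched in the first hunk, replace_by_example in the PyTorch source, and the run_functional_passes keyword comes from the diff above):

    # Hypothetical call site: decomp_fn re-emits the mutable op plus
    # explicit clones, and flat_args are the matched node's arguments.
    # run_functional_passes=False keeps DCE and remove_noop_ops from
    # deleting those clones in the traced replacement graph.
    m.replace_by_example(decomp_fn, flat_args, run_functional_passes=False)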