documentation for pattern_matcher.py (#127459)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/127459
Approved by: https://github.com/oulgen
ghstack dependencies: #127457, #127458
This commit is contained in:
Aaron Orenstein
2024-06-03 11:53:13 -07:00
committed by PyTorch MergeBot
parent 7a60a75256
commit 97ea2b5d83

View File

@ -1,4 +1,40 @@
"""
# Inductor Pattern Matcher
The pattern matcher enables search/replace within an FX graph.
The main entrypoint to the pattern matcher is register_replacement(). Given a
search function and a replacement function this will register a replacement with
a pass (such as torch._inductor.fx_passes.joint_graph.patterns).
Internally the pattern matcher represents patterns as a graph (a DAG). Creating
new patterns manually as a graph is cumbersome and error-prone so the standard
way to create patterns (using register_replacement()) is to provide a search
function and a replacement function which is traced and converted into a graph.
Because the search functions are built somewhat generic (they tend to ignore
tensor sizes, for example) register_replacement() allows you to specify an
`extra_check` function which performs additional checks to verify that the
matched pattern fully matches before returning it.
## Precompiled Patterns
New patterns are added using register_replacement(). Patterns added in this way
can have a compile-time overhead because they need to be traced before
use. Patterns can be precompiled and added using gen_register_replacement()
instead. To do this you call gen_register_replacement() instead of
register_replacement(). The arguments are the same except for an additional
unique name which is used as a lookup key.
## Internals
The match DAG is represented by a graph of `PatternExpr` nodes. Each PatternExpr
implements a `_match` method which returns either a `Match` object for a
successful match or a `FailedMatch` object for a failure to match.
"""
# mypy: disallow-untyped-defs
from __future__ import annotations
import contextlib
@ -104,6 +140,13 @@ MULTIPLE = Multiple()
class Match:
"""
Represents a successfully matched pattern.
The `Match` object is returned to represent a successfully matched
pattern. Included in the Match are the pattern that was matched, the graph
nodes matched, and any args that were used during the matching.
The args and kwargs are specific to the type of pattern that was matched and
provide hints about what was matched.
"""
pattern: PatternExpr
@ -202,6 +245,13 @@ class Match:
class FailedMatch(RuntimeError):
"""
Represents a unsuccessful match.
The `FailedMatch` object is returned to represent a failure to match a
pattern.
"""
format_string: str
def __init__(self, format_string: str, *args: Any, **kwargs: Any) -> None:
@ -235,7 +285,7 @@ def is_match(m: MatchResult) -> TypeGuard[Match]:
class MatchContext:
"""
State needed while running PatternExpr._match().
Internal state needed while running PatternExpr._match().
"""
outputs: List[Optional[PatternExpr]]
@ -277,7 +327,7 @@ class MatchContext:
class PatternExpr(ABC):
"""
Base class for types of patterns
Base class for types of patterns.
"""
@abstractmethod