mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
documentation for pattern_matcher.py (#127459)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/127459 Approved by: https://github.com/oulgen ghstack dependencies: #127457, #127458
This commit is contained in:
committed by
PyTorch MergeBot
parent
7a60a75256
commit
97ea2b5d83
@ -1,4 +1,40 @@
|
||||
"""
|
||||
# Inductor Pattern Matcher
|
||||
|
||||
The pattern matcher enables search/replace within an FX graph.
|
||||
|
||||
The main entrypoint to the pattern matcher is register_replacement(). Given a
|
||||
search function and a replacement function this will register a replacement with
|
||||
a pass (such as torch._inductor.fx_passes.joint_graph.patterns).
|
||||
|
||||
Internally the pattern matcher represents patterns as a graph (a DAG). Creating
|
||||
new patterns manually as a graph is cumbersome and error-prone so the standard
|
||||
way to create patterns (using register_replacement()) is to provide a search
|
||||
function and a replacement function which is traced and converted into a graph.
|
||||
|
||||
Because the search functions are built somewhat generic (they tend to ignore
|
||||
tensor sizes, for example) register_replacement() allows you to specify an
|
||||
`extra_check` function which performs additional checks to verify that the
|
||||
matched pattern fully matches before returning it.
|
||||
|
||||
## Precompiled Patterns
|
||||
|
||||
New patterns are added using register_replacement(). Patterns added in this way
|
||||
can have a compile-time overhead because they need to be traced before
|
||||
use. Patterns can be precompiled and added using gen_register_replacement()
|
||||
instead. To do this you call gen_register_replacement() instead of
|
||||
register_replacement(). The arguments are the same except for an additional
|
||||
unique name which is used as a lookup key.
|
||||
|
||||
## Internals
|
||||
|
||||
The match DAG is represented by a graph of `PatternExpr` nodes. Each PatternExpr
|
||||
implements a `_match` method which returns either a `Match` object for a
|
||||
successful match or a `FailedMatch` object for a failure to match.
|
||||
"""
|
||||
|
||||
# mypy: disallow-untyped-defs
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import contextlib
|
||||
@ -104,6 +140,13 @@ MULTIPLE = Multiple()
|
||||
class Match:
|
||||
"""
|
||||
Represents a successfully matched pattern.
|
||||
|
||||
The `Match` object is returned to represent a successfully matched
|
||||
pattern. Included in the Match are the pattern that was matched, the graph
|
||||
nodes matched, and any args that were used during the matching.
|
||||
|
||||
The args and kwargs are specific to the type of pattern that was matched and
|
||||
provide hints about what was matched.
|
||||
"""
|
||||
|
||||
pattern: PatternExpr
|
||||
@ -202,6 +245,13 @@ class Match:
|
||||
|
||||
|
||||
class FailedMatch(RuntimeError):
|
||||
"""
|
||||
Represents a unsuccessful match.
|
||||
|
||||
The `FailedMatch` object is returned to represent a failure to match a
|
||||
pattern.
|
||||
"""
|
||||
|
||||
format_string: str
|
||||
|
||||
def __init__(self, format_string: str, *args: Any, **kwargs: Any) -> None:
|
||||
@ -235,7 +285,7 @@ def is_match(m: MatchResult) -> TypeGuard[Match]:
|
||||
|
||||
class MatchContext:
|
||||
"""
|
||||
State needed while running PatternExpr._match().
|
||||
Internal state needed while running PatternExpr._match().
|
||||
"""
|
||||
|
||||
outputs: List[Optional[PatternExpr]]
|
||||
@ -277,7 +327,7 @@ class MatchContext:
|
||||
|
||||
class PatternExpr(ABC):
|
||||
"""
|
||||
Base class for types of patterns
|
||||
Base class for types of patterns.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
|
Reference in New Issue
Block a user