mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
This is follow-up of #164653 to continue applying `UP035` fixes. The purpose is to finally enable this rule. Pull Request resolved: https://github.com/pytorch/pytorch/pull/165214 Approved by: https://github.com/ezyang
311 lines
10 KiB
Python
311 lines
10 KiB
Python
# mypy: allow-untyped-defs
|
|
import inspect
import logging
from collections import deque
from collections.abc import Callable
from functools import wraps
from queue import Queue

import torch.nn as nn
from torch.fx._compatibility import compatibility
from torch.fx.graph_module import GraphModule
from torch.fx.passes.infra.pass_base import PassResult
|
|
|
|
|
|
# Module-level logger. Its level is pinned to WARNING here, so DEBUG-level
# messages logged through it stay hidden unless the level is lowered later.
logger = logging.getLogger(name=__name__)
logger.setLevel(level=logging.WARNING)
|
|
|
|
# Public API of this module; everything else here is an internal helper.
__all__ = [
    "pass_result_wrapper",
    "this_before_that_pass_constraint",
    "PassManager",
]
|
|
|
|
|
|
@compatibility(is_backward_compatible=False)
def pass_result_wrapper(fn: Callable) -> Callable:
    """
    Wrapper for passes which currently do not return a PassResult.
    This wrapper makes them return a PassResult containing the modified object
    and True for the "modified" flag.

    The wrapped pass may return:
      * ``None``: the input module is returned, marked as modified;
      * a ``PassResult``: it is passed through unchanged;
      * an ``nn.Module``: it is returned, marked as modified.
    Any other return value raises ``TypeError`` instead of silently
    producing ``None``.

    Args:
        fn (Callable[Module, Any]): The pass to wrap. If ``None``, ``None``
            is returned and nothing is wrapped.

    Returns:
        wrapped_fn (Callable[Module, PassResult])
    """
    if fn is None:
        # pyrefly: ignore # bad-return
        return None

    @wraps(fn)
    def wrapped_fn(gm):
        res = fn(gm)
        if res is None:
            # The pass returned nothing; report the input module as modified.
            return PassResult(gm, True)
        if isinstance(res, PassResult):
            return res
        if isinstance(res, nn.Module):
            return PassResult(res, True)
        raise TypeError(
            f"Pass {wrapped_fn.__name__} returned an object of type "
            f"{type(res).__name__}; expected None, a PassResult, or an nn.Module."
        )

    # Non-function callables (e.g. instances defining ``__call__``) are named
    # after their class, matching how PassManager names such passes.
    if not inspect.isfunction(fn):
        wrapped_fn.__name__ = type(fn).__name__

    return wrapped_fn
|
|
|
|
|
|
def _validate_pass_schedule_constraint(
|
|
constraint: Callable[[Callable, Callable], bool], passes: list[Callable]
|
|
) -> None:
|
|
for i, a in enumerate(passes):
|
|
for j, b in enumerate(passes[i + 1 :]):
|
|
if constraint(a, b):
|
|
continue
|
|
raise RuntimeError(
|
|
f"pass schedule constraint violated. Expected {a} before {b}"
|
|
f" but found {a} at index {i} and {b} at index{j} in pass"
|
|
f" list."
|
|
)
|
|
|
|
|
|
def _topological_sort_passes(
|
|
passes: list[Callable], constraints: list[Callable]
|
|
) -> list[Callable]:
|
|
"""
|
|
Args
|
|
passes: Passes that we are ordering
|
|
constraints: Constraints applied on these passes
|
|
|
|
Returns
|
|
A sorted list of callables and a boolean of if a circular dependency
|
|
existed
|
|
"""
|
|
if len(constraints) == 0:
|
|
return passes
|
|
|
|
# Construct a graph mapping nodes to a list of their users
|
|
graph: dict[Callable, list[Callable]] = {p: [] for p in passes}
|
|
indegree_map: dict[Callable, int] = dict.fromkeys(passes, 0)
|
|
candidates: Queue = Queue()
|
|
for a in passes:
|
|
for b in passes:
|
|
if a == b:
|
|
continue
|
|
|
|
for constraint in constraints:
|
|
if not constraint(a, b):
|
|
graph[b].append(a)
|
|
indegree_map[a] += 1
|
|
|
|
if indegree_map[a] == 0:
|
|
candidates.put(a)
|
|
|
|
visited: dict[Callable, bool] = dict.fromkeys(passes, False)
|
|
sorted_passes: list[Callable] = []
|
|
|
|
while not candidates.empty():
|
|
p = candidates.get()
|
|
sorted_passes.append(p)
|
|
visited[p] = True
|
|
|
|
for n in graph[p]:
|
|
if not visited[n]:
|
|
indegree_map[n] -= 1
|
|
if indegree_map[n] == 0:
|
|
candidates.put(n)
|
|
|
|
# Check if there are unvisited nodes (aka cycles in the graph)
|
|
cycle_passes = list(filter(lambda p: indegree_map[p] != 0, indegree_map.keys()))
|
|
if len(cycle_passes) != 0:
|
|
error = (
|
|
f"Circular dependency detected within the following passes: {cycle_passes}"
|
|
)
|
|
raise RuntimeError(error)
|
|
|
|
return sorted_passes
|
|
|
|
|
|
@compatibility(is_backward_compatible=False)
def this_before_that_pass_constraint(this: Callable, that: Callable) -> Callable:
    """
    Build a partial-order constraint requiring `this` to run before `that`.

    The returned callable is invoked as ``depends_on(a, b)``, asking whether
    ``a`` may be scheduled before ``b``. It returns False only for the one
    forbidden ordering, ``a == that`` and ``b == this``, and True for every
    other pair.

    For example, the following pass list and constraint list would be invalid:
    ```
    passes = [pass_b, pass_a]

    constraints = [this_before_that_pass_constraint(pass_a, pass_b)]
    ```

    Args:
        this (Callable): pass which should occur first
        that (Callable): pass which should occur later

    Returns:
        depends_on (Callable[[Object, Object], bool])
    """

    def depends_on(a: Callable, b: Callable):
        # Only "that before this" is disallowed.
        forbidden = a == that and b == this
        return not forbidden

    return depends_on
|
|
|
|
|
|
@compatibility(is_backward_compatible=False)
class PassManager:
    """
    Construct a PassManager.

    Collects passes and constraints. This defines the pass schedule, manages
    pass constraints and pass execution.

    Args:
        passes (Optional[List[Callable]]): List of passes. A pass is a
            callable which modifies an object and returns a PassResult
        constraints (Optional[List[Callable]]): List of constraints. A
            constraint is a callable which takes two passes (A, B) and returns
            False if A must not be scheduled before B, and True otherwise. See
            implementation of `this_before_that_pass_constraint` for example.
        steps (int): Max number of times we run the passes (default = 1). A
            falsy value (``None`` or ``0``) keeps the default.
        run_checks_after_each_pass (bool): Whether to run `check` after each
            pass. `check` always runs once on the input module regardless.
        suppress_check_failures (bool): Stored on the instance but not
            currently consulted by this class; check failures still raise.
    """

    passes: list[Callable[[nn.Module], PassResult]]
    constraints: list[Callable[[Callable, Callable], bool]]
    _validated: bool = False
    steps: int = 1

    def __init__(
        self,
        passes=None,
        constraints=None,
        steps=None,
        run_checks_after_each_pass: bool = False,
        suppress_check_failures: bool = False,
    ):
        # NOTE(review): a non-empty list argument is stored by reference, so
        # `add_pass` / `add_constraint` also mutate the caller's list.
        self.passes = passes or []
        self.constraints = constraints or []
        # Falsy `steps` (None or 0) falls back to the class default of 1.
        if steps:
            self.steps = steps

        self.run_checks_after_each_pass = run_checks_after_each_pass
        self.suppress_check_failures = suppress_check_failures

    @staticmethod
    def _pass_name(fn: Callable) -> str:
        """Return the function name of `fn`, or its class name for other callables."""
        return fn.__name__ if inspect.isfunction(fn) else type(fn).__name__

    def add_pass(self, _pass: Callable):
        """
        Adds a pass into the current list of passes and marks the schedule as
        needing re-validation.
        """
        self.passes.append(_pass)
        self._validated = False

    def add_constraint(self, constraint: Callable):
        """
        Adds a constraint into the current list of constraints and marks the
        schedule as needing re-validation.
        """
        self.constraints.append(constraint)
        self._validated = False

    def validate_constraints(self):
        """
        Validates that current pass schedule defined by `self.passes` is valid
        according to all constraints in `self.constraints`. A no-op if the
        schedule is already marked as validated.

        Raises:
            RuntimeError: If any constraint is violated by the current order.
        """
        if self._validated:
            return
        for constraint in self.constraints:
            _validate_pass_schedule_constraint(constraint, self.passes)
        self._validated = True

    def solve_constraints(self):
        """
        Finds a valid traversal order based on the given constraints and
        reorders `self.passes` accordingly.

        Raises:
            RuntimeError: If the constraints contain a circular dependency.
                This is raised regardless of the value of `steps`.
        """
        self.passes = _topological_sort_passes(self.passes, self.constraints)
        self._validated = True

    def add_checks(self, check: Callable) -> None:
        """
        Adds a function which runs various checks on a given graph module.
        It replaces `check` on this instance. It is run once before the
        passes and, if the `run_checks_after_each_pass` flag is enabled,
        after each pass.

        Raises:
            TypeError: If `check` does not take exactly one parameter.
        """
        sig = inspect.signature(check)

        if len(list(sig.parameters.values())) != 1:
            raise TypeError(
                "PassManager check function should only take in one variable, a module"
            )

        # Shadow the default no-op `check` method on this instance only.
        setattr(self, "check", check)  # noqa: B010

    def check(self, module: nn.Module) -> None:
        """Default check: does nothing. Replace it via `add_checks`."""

    def __call__(self, module: nn.Module) -> PassResult:
        """
        Runs a list of passes in the order based on `self.passes` on the given
        module. `check` is run once on the input module and, if
        `run_checks_after_each_pass` is set, again after every pass.

        The full list of passes is run repeatedly until an iteration reports
        no modification, or until `steps` iterations have run. Whenever a pass
        yields a `GraphModule`, it is recompiled.

        Returns:
            PassResult: The final module, and whether any pass modified it.

        Raises:
            Exception: Raised from (``__cause__``) any error raised while
                running a pass or the per-pass check, naming the passes that
                ran before it.
        """
        # Order the passes based on the constraints
        if not self._validated:
            self.solve_constraints()

        # Check graph invariants
        self.check(module)

        # Run the set of passes `steps` number of times or until the graph stops
        # changing
        overall_modified = False
        for _ in range(self.steps):
            modified = False

            # Run the set of passes on the graph module
            for i, fn in enumerate(self.passes):
                fn_name = self._pass_name(fn)
                logger.debug("Running pass '%s'", fn_name)

                try:
                    res = fn(module)

                    if not isinstance(res, PassResult) and not hasattr(
                        res, "graph_module"
                    ):
                        raise TypeError(
                            f"The result of the pass {fn_name} should be type PassResult. "
                            + "Please wrap it with pass_result_wrapper()"
                        )
                    module = res.graph_module
                    modified = modified or res.modified

                    if isinstance(module, GraphModule):
                        logger.debug("Graph after pass '%s': %s", fn_name, module.graph)
                        module.recompile()

                    # Check graph invariants
                    if self.run_checks_after_each_pass:
                        self.check(module)

                except Exception as e:
                    prev_pass_names = [self._pass_name(p) for p in self.passes[:i]]
                    msg = f"An error occurred when running the '{fn_name}' pass after the following passes: {prev_pass_names}"
                    raise Exception(msg) from e  # noqa: TRY002

            # If the graph no longer changes, then we can stop running these passes
            overall_modified = overall_modified or modified
            if not modified:
                break

        return PassResult(module, overall_modified)
|