mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
See #127836 for details. Pull Request resolved: https://github.com/pytorch/pytorch/pull/127843 Approved by: https://github.com/oulgen ghstack dependencies: #127842
239 lines
10 KiB
Python
239 lines
10 KiB
Python
# mypy: allow-untyped-defs
|
|
import functools
|
|
from contextlib import contextmanager
|
|
from dataclasses import dataclass, field
|
|
from typing import Any, Callable, Dict, List, NamedTuple, Optional, Set, Tuple
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
|
|
|
|
@dataclass
class TracingConfig:
    """
    Configuration describing how to symbolically trace a module.

    Args:
        tracer (torch.fx.Tracer): The :class:`torch.fx.Tracer` instance that
            performs the symbolic trace. When omitted, a new
            :class:`torch.fx.Tracer` built with its default arguments is
            used. Supply a different tracer when the model needs one, for
            example ``HFTracer`` for models from the HuggingFace
            Transformers_ library.

            .. _Transformers: https://huggingface.co/docs/transformers/index

        concrete_args (Optional[Dict[str, Any]]): Inputs to the module's
            ``forward()`` that should be kept as concrete values instead of
            being turned into ``torch.fx.Proxy`` objects while tracing. This
            partially specializes ``forward()``, e.g. to strip out control
            flow or data structures. It is forwarded as the argument of the
            same name to :meth:`~torch.fx.Tracer.trace`.
    """

    # ``default_factory`` builds a separate tracer for every config instance,
    # so two configs never share one tracer object.
    tracer: torch.fx.Tracer = field(default_factory=torch.fx.Tracer)
    # ``None`` means no forward() inputs are specialized.
    concrete_args: Optional[Dict[str, Any]] = field(default=None)
|
|
|
|
|
|
class _ParamUsageInfo(NamedTuple):
    """
    A single group of parameters observed being used together while tracing.

    Instances populate ``_ExecutionInfo.module_to_param_usage_infos``, which
    maps each module to a list of these records kept in execution order.
    Under a given module key, a record takes one of two forms:

    (1) The key module paired with the subset of its parameters that were
        consumed together directly by an op, without a ``forward()`` call
        (produced by ``_patched_create_proxy()``).
    (2) A submodule paired with every entry of
        ``submodule.named_parameters()``, standing for a call to that
        submodule's ``forward()`` (produced by ``_patched_call_module()``).

    Form (1) therefore captures direct parameter use inside ops, while
    form (2) captures parameter use by way of a ``forward()`` call.
    """

    # Owner of the recorded parameters: the key module itself, or a submodule
    # that was called.
    module: nn.Module
    # ``(name, parameter)`` pairs. Records built from op arguments in
    # ``_patched_create_proxy()`` use root-module FQNs as names; records built
    # from ``named_parameters()`` use names relative to that module.
    named_params: List[Tuple[str, nn.Parameter]]
|
|
|
|
|
|
class _ExecutionInfo:
    """
    Bookkeeping for the execution order observed during a traced forward pass.

    Attributes:
        curr_module (nn.Module): The module whose ``forward()`` is being
            traced right now.
        module_forward_order (List[nn.Module]): Modules listed in the order
            their ``forward()`` is entered (pre-forward order). There is one
            entry per call, so a module invoked twice appears twice.
        module_to_param_usage_infos (Dict[nn.Module, List[_ParamUsageInfo]]):
            For each module, the parameter-usage records collected for it; see
            :class:`_ParamUsageInfo` for what each record means.
        param_forward_order (List[nn.Parameter]): Parameters in the order of
            their first use in the forward pass. Subsequent uses of an
            already-recorded parameter are not appended again.
        visited_params (Set[nn.Parameter]): Parameters seen so far, kept only
            to make the "already recorded?" check fast during tracing. It
            always contains exactly the parameters in
            ``param_forward_order``.
    """

    def __init__(self, root_module: nn.Module) -> None:
        # Tracing begins inside the root module: it is the current module,
        # the first entry of the forward order, and it starts with an empty
        # list of recorded parameter usages.
        root = root_module
        self.curr_module: nn.Module = root
        self.module_forward_order: List[nn.Module] = [root]
        self.module_to_param_usage_infos: Dict[nn.Module, List[_ParamUsageInfo]] = dict(
            [(root, [])]
        )
        self.param_forward_order: List[nn.Parameter] = list()
        self.visited_params: Set[nn.Parameter] = set()
|
|
|
|
|
|
class _ExecOrderTracer:
    """
    Records module and parameter execution order during ``torch.fx``
    symbolic tracing. It does this by temporarily replacing a tracer's
    ``call_module`` and ``create_proxy`` with recording wrappers (see
    :meth:`patch_tracer`).
    """

    def __init__(self) -> None:
        # Set to a fresh ``_ExecutionInfo`` by each ``patch_tracer()`` call.
        self.exec_info: Optional[_ExecutionInfo] = None

    @contextmanager
    def patch_tracer(self, tracer: torch.fx.Tracer, root_module: nn.Module):
        """
        Context manager that patches ``tracer`` so that tracing
        ``root_module`` inside the ``with`` body records execution
        information into a new ``self.exec_info``.

        The tracer's original ``call_module`` and ``create_proxy`` are
        reassigned on exit, including when the body raises.

        Args:
            tracer (torch.fx.Tracer): Tracer instance to patch in place.
            root_module (nn.Module): Module that will be traced. Its
                ``named_parameters()`` FQNs are used to recognize parameter
                accesses during tracing.
        """
        self.exec_info = _ExecutionInfo(root_module)
        # Compute this before patching anything. If it raised after the
        # patch but before the ``try``, the tracer would stay patched.
        fqn_to_param = dict(root_module.named_parameters())
        orig_call_module = tracer.call_module
        orig_create_proxy = tracer.create_proxy
        tracer.call_module = functools.partial(
            self._patched_call_module, orig_call_module, self.exec_info
        )
        tracer.create_proxy = functools.partial(
            self._patched_create_proxy,
            orig_create_proxy,
            self.exec_info,
            fqn_to_param,
        )
        try:
            yield
        finally:
            tracer.call_module = orig_call_module
            tracer.create_proxy = orig_create_proxy

    def _patched_call_module(
        self,
        call_module: Callable,
        exec_info: _ExecutionInfo,
        # Below are the expected arguments to `call_module()`
        module: nn.Module,
        forward: Callable,
        args: Tuple[Any, ...],
        kwargs: Dict[str, Any],
    ) -> Any:
        """
        Overrides ``call_module`` to save execution information to
        ``exec_info``. Note that ``call_module`` is called during symbolic
        tracing for each non-root module.

        Args:
            call_module (Callable): Original ``call_module`` to override.
            exec_info (_ExecutionInfo): Used to record execution information.
            module (nn.Module): Module corresponding to this ``call_module``.
            forward (Callable): ``forward()`` method of ``module`` to be called
                for this ``call_module``.
            args (Tuple[Any, ...]): Positional arguments for ``forward``.
            kwargs (Dict[str, Any]): Keyword arguments for ``forward``.

        Returns:
            Same return value as ``call_module``.
        """
        exec_info.module_forward_order.append(module)
        named_params = list(module.named_parameters())
        curr_module = exec_info.curr_module
        if named_params:
            assert (
                curr_module in exec_info.module_to_param_usage_infos
            ), "The current module should have already been processed by a patched `call_module`"
            # Record the submodule call (with all its parameters) under the
            # parent module.
            exec_info.module_to_param_usage_infos[curr_module].append(
                _ParamUsageInfo(module, named_params)
            )
        prev_curr_module = curr_module
        exec_info.curr_module = module
        exec_info.module_to_param_usage_infos[module] = []
        try:
            return call_module(module, forward, args, kwargs)
        finally:
            # Restore the parent even if tracing the submodule raises. This
            # keeps ``exec_info.curr_module`` from pointing at a module that
            # has already been exited.
            exec_info.curr_module = prev_curr_module

    def _patched_create_proxy(
        self,
        create_proxy: Callable,
        exec_info: _ExecutionInfo,
        fqn_to_param: Dict[str, nn.Parameter],
        # Below are the expected arguments to `create_proxy()`
        kind: str,
        target: torch.fx.node.Target,
        args: Tuple[Any, ...],
        kwargs: Dict[str, Any],
        name: Optional[str] = None,
        type_expr: Optional[Any] = None,
        proxy_factory_fn: Optional[Callable[[torch.fx.Node], torch.fx.Proxy]] = None,
    ) -> torch.fx.Proxy:
        """
        Overrides ``create_proxy`` to save execution information to
        ``exec_info``. Note that ``create_proxy`` is called during symbolic
        tracing for each leaf function/method/module.

        For ``call_function``/``call_method``, every top-level positional or
        keyword argument that is a ``get_attr`` proxy naming one of the
        root's parameters is recorded as a parameter use by the current
        module. For ``call_module``, all parameters of the current module are
        recorded.

        Args:
            create_proxy (Callable): Original ``create_proxy`` to override.
            exec_info (_ExecutionInfo): Used to record execution information.
            fqn_to_param (Dict[str, nn.Parameter]): ``dict`` version of the
                root module's ``named_parameters()`` with FQN as key and
                parameter as value.
            kind (str): Kind of the target method ('call_function',
                'call_method', 'get_attr', 'call_module', 'placeholder', or
                'output'). See :class:`torch.fx.Graph` for details. This is
                passed to ``create_proxy``.
            target (torch.fx.node.Target): Contains the string name of the
                function/method/module. This is passed to ``create_proxy``.
            args (Tuple[Any, ...]): Positional arguments for the function/
                method/module. This is passed to ``create_proxy``.
            kwargs (Dict[str, Any]): Keyword arguments for the function/method/
                module. This is passed to ``create_proxy``.
            name (Optional[str]): An optional string name for the ``Node``
                created in ``create_proxy``. This is passed to
                ``create_proxy``.
            type_expr (Optional[Any]): An optional type annotation representing
                the Python type that the output of the node has. This is passed
                to ``create_proxy``.
            proxy_factory_fn (Callable[[torch.fx.Node], torch.fx.Proxy]):
                An alternative proxy constructor used in ``create_proxy``. This
                is passed to ``create_proxy``.

        Returns:
            torch.fx.Proxy: Created ``Node`` wrapped in a ``Proxy`` object.
        """
        proxy = create_proxy(
            kind, target, args, kwargs, name, type_expr, proxy_factory_fn
        )
        curr_module = exec_info.curr_module
        if kind in ("call_function", "call_method"):
            # Parameters can reach an op as positional or keyword arguments.
            # Only top-level arguments are inspected; nested containers are not.
            candidate_args: List[Any] = list(args) if args is not None else []
            if kwargs:
                candidate_args.extend(kwargs.values())
            named_params: List[Tuple[str, nn.Parameter]] = []
            for arg in candidate_args:
                # Only ``get_attr`` nodes carry an attribute FQN as their
                # target. Without the ``op`` check, e.g. a ``placeholder``
                # whose target is a ``forward()`` argument name that happens
                # to equal a root parameter FQN would be mistaken for that
                # parameter.
                if (
                    isinstance(arg, torch.fx.Proxy)
                    and arg.node.op == "get_attr"
                    and arg.node.target in fqn_to_param
                ):
                    fqn = arg.node.target
                    param = fqn_to_param[fqn]
                    named_params.append((fqn, param))
                    if param not in exec_info.visited_params:
                        exec_info.visited_params.add(param)
                        exec_info.param_forward_order.append(param)
            if named_params:
                exec_info.module_to_param_usage_infos[curr_module].append(
                    _ParamUsageInfo(curr_module, named_params)
                )
        elif kind == "call_module":
            # The patched ``call_module`` has already made the leaf module the
            # current module, so all of its parameters are used by this call.
            named_params = list(curr_module.named_parameters())
            if named_params:
                exec_info.module_to_param_usage_infos[curr_module].append(
                    _ParamUsageInfo(curr_module, named_params)
                )
            for _, param in named_params:
                if param not in exec_info.visited_params:
                    exec_info.visited_params.add(param)
                    exec_info.param_forward_order.append(param)
        return proxy
|